CETT Functions

Low-level CETT extraction and causal intervention utilities. These are used internally by HProbes but are also available for custom pipelines.

hprobes.cett.available_layers(model)

Return list of all available layer indices for the model.

Source code in src/hprobes/cett.py
def available_layers(model: torch.nn.Module) -> List[int]:
    """Return list of all available layer indices for the model."""
    return list(range(len(_get_transformer_layers(model))))
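
A minimal usage sketch; the model name and the Hugging Face loading code are illustrative assumptions, not part of hprobes:

from transformers import AutoModelForCausalLM
from hprobes.cett import available_layers

model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
layers = available_layers(model)   # [0, 1, ..., 11] for the 12-layer GPT-2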

hprobes.cett.precompute_col_norms(model, layers)

Precompute ‖W_down[:, j]‖₂ for each layer.

Returns dict mapping layer_idx → (intermediate_dim,) tensor of column norms. Computed once and reused across all samples.

Source code in src/hprobes/cett.py
def precompute_col_norms(
    model: torch.nn.Module,
    layers: List[int],
) -> Dict[int, torch.Tensor]:
    """Precompute ‖W_down[:, j]‖₂ for each layer.

    Returns dict mapping layer_idx → (intermediate_dim,) tensor of column norms.
    Computed once and reused across all samples.
    """
    col_norms = {}
    for layer_idx in layers:
        down_proj = get_mlp_down_proj(model, layer_idx)
        W = down_proj.weight.detach().float()  # (hidden_dim, intermediate_dim) usually

        # GPT-2 uses Conv1D where weight is (intermediate_dim, hidden_dim)
        if type(down_proj).__name__ == "Conv1D":
            col_norms[layer_idx] = torch.norm(W, dim=1).cpu()  # (intermediate_dim,)
        else:
            col_norms[layer_idx] = torch.norm(W, dim=0).cpu()  # (intermediate_dim,)
    return col_norms
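
Continuing the sketch above, the column norms are computed once per model and then reused by every extraction call:

from hprobes.cett import precompute_col_norms

col_norms = precompute_col_norms(model, layers)
col_norms[0].shape   # torch.Size([3072]) for GPT-2 (intermediate_dim = 4 * hidden_dim)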

hprobes.cett.forward_cett(model, tokens, layers, col_norms, token_position=-1)

Single forward pass — extract CETT at a given token position.

Parameters

model : causal LM
tokens : tokenizer output (input_ids, attention_mask on correct device)
layers : list of layer indices to hook
col_norms : precomputed column norms from precompute_col_norms()
token_position : which token to extract CETT at (-1 = last token)

Returns

cett_vec : (n_layers * intermediate_dim,) float32 — concatenated CETT values
logits : (vocab_size,) float32 — output logits at the last token

Source code in src/hprobes/cett.py
def forward_cett(
    model: torch.nn.Module,
    tokens: Dict[str, torch.Tensor],
    layers: List[int],
    col_norms: Dict[int, torch.Tensor],
    token_position: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Single forward pass — extract CETT at a given token position.

    Parameters
    ----------
    model : causal LM
    tokens : tokenizer output (input_ids, attention_mask on correct device)
    layers : list of layer indices to hook
    col_norms : precomputed column norms from precompute_col_norms()
    token_position : which token to extract CETT at (-1 = last token)

    Returns
    -------
    cett_vec : (n_layers * intermediate_dim,) float32 — concatenated CETT values
    logits   : (vocab_size,) float32 — output logits at the last token
    """
    z_cache: Dict[int, torch.Tensor] = {}
    h_cache: Dict[int, torch.Tensor] = {}
    handles = []

    for layer_idx in layers:
        down_proj = get_mlp_down_proj(model, layer_idx)

        def make_hook(idx: int):
            def hook(module, input, output):
                z = input[0]
                h = output
                z_cache[idx] = z[0, token_position, :].detach().float().cpu()
                h_cache[idx] = h[0, token_position, :].detach().float().cpu()
                return output

            return hook

        handles.append(down_proj.register_forward_hook(make_hook(layer_idx)))

    try:
        with torch.no_grad():
            out = model(**tokens)
    finally:
        for h in handles:
            h.remove()

    logits = out.logits[0, -1, :].detach().float().cpu()

    cett_parts = []
    for layer_idx in layers:
        z = z_cache[layer_idx]
        h = h_cache[layer_idx]
        h_norm = torch.norm(h).item() + 1e-8
        cett = (torch.abs(z) * col_norms[layer_idx]) / h_norm
        cett_parts.append(cett)

    return torch.cat(cett_parts, dim=0), logits
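
A usage sketch, continuing the setup above; the prompt text is illustrative and the tokenizer is loaded with the standard transformers API:

from transformers import AutoTokenizer
from hprobes.cett import forward_cett

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokens = tokenizer("The capital of France is", return_tensors="pt").to(model.device)

cett_vec, logits = forward_cett(model, tokens, layers, col_norms, token_position=-1)
# cett_vec has shape (len(layers) * intermediate_dim,): CETT values at the last prompt token
# logits has shape (vocab_size,): next-token logits at the same position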

hprobes.cett.forward_cett_at_token(model, tokens, extra_token_id, layers, col_norms)

Append one token to the input and capture CETT at that appended position.

Returns

cett_answer : (n_layers * intermediate_dim,) float32

Source code in src/hprobes/cett.py
def forward_cett_at_token(
    model: torch.nn.Module,
    tokens: Dict[str, torch.Tensor],
    extra_token_id: int,
    layers: List[int],
    col_norms: Dict[int, torch.Tensor],
) -> torch.Tensor:
    """Append one token to the input and capture CETT at that appended position.

    Returns
    -------
    cett_answer : (n_layers * intermediate_dim,) float32
    """
    input_ids = tokens["input_ids"]
    extra_t = torch.tensor([[extra_token_id]], device=input_ids.device)
    extended_ids = torch.cat([input_ids, extra_t], dim=1)

    extended: Dict[str, torch.Tensor] = {"input_ids": extended_ids}
    if "attention_mask" in tokens:
        m = tokens["attention_mask"]
        extended["attention_mask"] = torch.cat(
            [m, torch.ones((1, 1), device=m.device, dtype=m.dtype)], dim=1
        )

    z_cache: Dict[int, torch.Tensor] = {}
    h_cache: Dict[int, torch.Tensor] = {}
    handles = []

    for layer_idx in layers:
        down_proj = get_mlp_down_proj(model, layer_idx)

        def make_hook(idx: int):
            def hook(module, input, output):
                z_cache[idx] = input[0][0, -1, :].detach().float().cpu()
                h_cache[idx] = output[0, -1, :].detach().float().cpu()
                return output

            return hook

        handles.append(down_proj.register_forward_hook(make_hook(layer_idx)))

    try:
        with torch.no_grad():
            model(**extended)
    finally:
        for h in handles:
            h.remove()

    cett_parts = []
    for layer_idx in layers:
        h_norm = torch.norm(h_cache[layer_idx]).item() + 1e-8
        cett = (torch.abs(z_cache[layer_idx]) * col_norms[layer_idx]) / h_norm
        cett_parts.append(cett)

    return torch.cat(cett_parts, dim=0)
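
A sketch of capturing CETT at a candidate answer token, continuing the example above; the answer string is illustrative:

from hprobes.cett import forward_cett_at_token

answer_id = tokenizer.encode(" Paris")[0]   # first sub-token of the candidate answer
cett_answer = forward_cett_at_token(model, tokens, answer_id, layers, col_norms)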

hprobes.cett.forward_cett_span(model, tokens, span_start, span_end, layers, col_norms, aggregation='mean')

Forward pass over a full sequence — extract CETT aggregated over a token span.

Source code in src/hprobes/cett.py
def forward_cett_span(
    model: torch.nn.Module,
    tokens: Dict[str, torch.Tensor],
    span_start: int,
    span_end: int,
    layers: List[int],
    col_norms: Dict[int, torch.Tensor],
    aggregation: str = "mean",
) -> torch.Tensor:
    """Forward pass over a full sequence — extract CETT aggregated over a token span."""
    z_cache: Dict[int, torch.Tensor] = {}
    h_cache: Dict[int, torch.Tensor] = {}
    handles = []

    for layer_idx in layers:
        down_proj = get_mlp_down_proj(model, layer_idx)

        def make_hook(idx: int):
            def hook(module, input, output):
                z_cache[idx] = input[0][0].detach().float().cpu()
                h_cache[idx] = output[0].detach().float().cpu()
                return output

            return hook

        handles.append(down_proj.register_forward_hook(make_hook(layer_idx)))

    try:
        with torch.no_grad():
            model(**tokens)
    finally:
        for h in handles:
            h.remove()

    cett_parts = []
    for layer_idx in layers:
        z_span = z_cache[layer_idx][span_start:span_end]
        h_span = h_cache[layer_idx][span_start:span_end]
        h_norms = torch.norm(h_span, dim=-1, keepdim=True) + 1e-8
        cett_span = (torch.abs(z_span) * col_norms[layer_idx].unsqueeze(0)) / h_norms
        if aggregation == "max":
            cett_agg = cett_span.max(dim=0).values
        else:
            cett_agg = cett_span.mean(dim=0)
        cett_parts.append(cett_agg)

    return torch.cat(cett_parts, dim=0)
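
A span-level sketch, continuing the example above; the span indices are illustrative, and aggregation is "mean" by default or "max":

from hprobes.cett import forward_cett_span

cett_span = forward_cett_span(
    model, tokens, span_start=1, span_end=4,
    layers=layers, col_norms=col_norms, aggregation="max",
)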

hprobes.cett.forward_cett_batch(model, batch_tokens, layers, col_norms, token_positions)

Batched forward pass — extract CETT for each sample.

Source code in src/hprobes/cett.py
def forward_cett_batch(
    model: torch.nn.Module,
    batch_tokens: Dict[str, torch.Tensor],
    layers: List[int],
    col_norms: Dict[int, torch.Tensor],
    token_positions: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Batched forward pass — extract CETT for each sample."""
    batch_size = batch_tokens["input_ids"].shape[0]
    device = batch_tokens["input_ids"].device
    batch_idx = torch.arange(batch_size, device=device)
    token_pos_t = torch.tensor(token_positions, device=device)

    z_cache: Dict[int, torch.Tensor] = {}
    h_cache: Dict[int, torch.Tensor] = {}
    handles = []

    for layer_idx in layers:
        down_proj = get_mlp_down_proj(model, layer_idx)

        def make_hook(idx: int):
            def hook(module, input, output):
                z_cache[idx] = input[0][batch_idx, token_pos_t].detach().float()
                h_cache[idx] = output[batch_idx, token_pos_t].detach().float()
                return output

            return hook

        handles.append(down_proj.register_forward_hook(make_hook(layer_idx)))

    if "attention_mask" in batch_tokens:
        position_ids = (batch_tokens["attention_mask"].cumsum(dim=-1) - 1).clamp(min=0)
    else:
        seq_len = batch_tokens["input_ids"].shape[1]
        position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)

    try:
        with torch.no_grad():
            out = model(**batch_tokens, position_ids=position_ids)
    finally:
        for h in handles:
            h.remove()

    logits_matrix = out.logits[batch_idx, token_pos_t].detach().float().cpu()

    z_all = torch.stack([z_cache[li] for li in layers], dim=0)
    h_all = torch.stack([h_cache[li] for li in layers], dim=0)
    col_norms_gpu = torch.stack([col_norms[li].to(device) for li in layers])

    h_norm = torch.norm(h_all, dim=-1, keepdim=True) + 1e-8
    cett = (torch.abs(z_all) * col_norms_gpu.unsqueeze(1)) / h_norm
    cett_matrix = cett.permute(1, 0, 2).reshape(batch_size, -1).cpu()

    return cett_matrix, logits_matrix
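
A batched sketch, continuing the setup above. The pad-token assignment and the per-sample position computation (last non-padded token, assuming right padding) are assumptions about the caller's tokenizer setup, not part of hprobes:

from hprobes.cett import forward_cett_batch

tokenizer.pad_token = tokenizer.eos_token   # GPT-2 has no pad token by default
prompts = ["The capital of France is", "2 + 2 ="]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# index of the last real token in each (right-padded) sample
positions = (batch["attention_mask"].sum(dim=1) - 1).tolist()

cett_matrix, logits_matrix = forward_cett_batch(model, batch, layers, col_norms, positions)
# cett_matrix: (batch_size, len(layers) * intermediate_dim); logits_matrix: (batch_size, vocab_size)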

hprobes.cett.forward_cett_at_token_batch(model, batch_tokens, extra_token_ids, layers, col_norms)

Batched version of forward_cett_at_token.

Source code in src/hprobes/cett.py
def forward_cett_at_token_batch(
    model: torch.nn.Module,
    batch_tokens: Dict[str, torch.Tensor],
    extra_token_ids: List[int],
    layers: List[int],
    col_norms: Dict[int, torch.Tensor],
) -> torch.Tensor:
    """Batched version of forward_cett_at_token."""
    batch_size = batch_tokens["input_ids"].shape[0]
    device = batch_tokens["input_ids"].device

    extra_t = torch.tensor(extra_token_ids, device=device).unsqueeze(1)
    extended_ids = torch.cat([batch_tokens["input_ids"], extra_t], dim=1)

    extended: Dict[str, torch.Tensor] = {"input_ids": extended_ids}
    if "attention_mask" in batch_tokens:
        m = batch_tokens["attention_mask"]
        extended["attention_mask"] = torch.cat(
            [m, torch.ones((batch_size, 1), device=device, dtype=m.dtype)], dim=1
        )

    z_cache: Dict[int, torch.Tensor] = {}
    h_cache: Dict[int, torch.Tensor] = {}
    handles = []

    for layer_idx in layers:
        down_proj = get_mlp_down_proj(model, layer_idx)

        def make_hook(idx: int):
            def hook(module, input, output):
                z_cache[idx] = input[0][:, -1, :].detach().float()
                h_cache[idx] = output[:, -1, :].detach().float()
                return output

            return hook

        handles.append(down_proj.register_forward_hook(make_hook(layer_idx)))

    if "attention_mask" in extended:
        position_ids = (extended["attention_mask"].cumsum(dim=-1) - 1).clamp(min=0)
    else:
        seq_len = extended["input_ids"].shape[1]
        position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)

    try:
        with torch.no_grad():
            model(**extended, position_ids=position_ids)
    finally:
        for h in handles:
            h.remove()

    z_all = torch.stack([z_cache[li] for li in layers], dim=0)
    h_all = torch.stack([h_cache[li] for li in layers], dim=0)
    col_norms_gpu = torch.stack([col_norms[li].to(device) for li in layers])

    h_norm = torch.norm(h_all, dim=-1, keepdim=True) + 1e-8
    cett = (torch.abs(z_all) * col_norms_gpu.unsqueeze(1)) / h_norm
    cett_matrix = cett.permute(1, 0, 2).reshape(batch_size, -1).cpu()

    return cett_matrix
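
A batched answer-token sketch, continuing the example above; the candidate answers are illustrative:

from hprobes.cett import forward_cett_at_token_batch

answer_ids = [tokenizer.encode(" Paris")[0], tokenizer.encode(" 4")[0]]
cett_answers = forward_cett_at_token_batch(model, batch, answer_ids, layers, col_norms)

The extended attention mask and recomputed position_ids place each appended token directly after its prompt's last real token, so right-padded batches behave like the single-sample variant.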

hprobes.cett.scale_h_neurons(model, tokens, h_neurons, alpha, layers)

Forward pass scaling H-Neuron activations by alpha.

Source code in src/hprobes/cett.py
def scale_h_neurons(
    model: torch.nn.Module,
    tokens: Dict[str, torch.Tensor],
    h_neurons: List[Tuple[int, int]],
    alpha: float,
    layers: List[int],
) -> torch.Tensor:
    """Forward pass scaling H-Neuron activations by alpha."""
    neurons_by_layer: Dict[int, List[int]] = {}
    for layer_idx, neuron_idx in h_neurons:
        neurons_by_layer.setdefault(layer_idx, []).append(neuron_idx)

    handles = []
    for layer_idx in layers:
        if layer_idx not in neurons_by_layer:
            continue
        indices = torch.tensor(neurons_by_layer[layer_idx], dtype=torch.long)
        down_proj = get_mlp_down_proj(model, layer_idx)

        def make_pre_hook(idx: torch.Tensor, a: float):
            def pre_hook(module, input):
                z = input[0].clone()
                z[..., idx.to(z.device)] *= a
                return (z,) + input[1:]

            return pre_hook

        handles.append(down_proj.register_forward_pre_hook(make_pre_hook(indices, alpha)))

    try:
        with torch.no_grad():
            out = model(**tokens)
    finally:
        for h in handles:
            h.remove()

    return out.logits[0, -1, :].detach().float().cpu()
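
An intervention sketch, continuing the single-sample example above. The (layer, neuron) pairs are hypothetical placeholders for H-Neurons identified elsewhere (for example by a trained HProbe):

from hprobes.cett import scale_h_neurons

h_neurons = [(5, 1023), (7, 88)]   # hypothetical (layer_idx, neuron_idx) pairs
logits_ablate = scale_h_neurons(model, tokens, h_neurons, alpha=0.0, layers=layers)
logits_boost = scale_h_neurons(model, tokens, h_neurons, alpha=2.0, layers=layers)
# alpha=0.0 zeroes the selected pre-down-projection activations; alpha > 1.0 amplifies them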