@norabelrose
Created January 5, 2023 08:39
Causal basis extraction
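A greedy procedure for finding, at each layer of a transformer, an ordered orthonormal set of activation directions whose ablation maximally perturbs the model's output distribution, as measured by KL divergence from the unablated ("control") logits.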
from typing import Literal, NamedTuple, Sequence, Union

import torch as th
import torch.distributed as dist
from tqdm.auto import trange
from transformers import PreTrainedModel

from white_box import TunedLens
from white_box.causal import ablate_subspace, remove_subspace
from white_box.nn import Decoder

class CausalBasis(NamedTuple):
    """An ordered orthonormal basis for a subspace of activations.

    Attributes:
        basis: A matrix of shape (d, k), where d is the ambient dimension
            and k is the dimension of the subspace. The columns of this
            matrix are basis vectors, ordered by decreasing energy.
        energies: A vector of shape (k,) containing the energies of the
            basis vectors. Each energy is the expected KL divergence of
            the post-intervention logits w.r.t. the control logits when
            the corresponding basis vector is ablated.
    """

    basis: th.Tensor
    energies: th.Tensor
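
# Illustrative sanity checks one might run on a returned basis (a sketch;
# `cb` is a hypothetical CausalBasis instance, not defined in this file):
#
#     cb = extract_causal_bases(lens, hiddens, k=8)[0]
#     assert th.allclose(cb.basis.T @ cb.basis, th.eye(cb.basis.shape[1]), atol=1e-4)
#     assert (cb.energies.diff() <= 0).all()  # energies sorted in descending order
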
def extract_causal_bases(
    model: Union[Decoder, PreTrainedModel, TunedLens],
    inputs: Union[th.Tensor, Sequence[th.Tensor]],
    k: int,
    *,
    max_iter: int = 100,
    min_energy: float = 0.0,
    mode: Literal["mean", "resample", "zero"] = "mean",
) -> list[CausalBasis]:
    """Compute a causal basis for each layer of a model.

    Args:
        model: A model to compute causal bases for. This can be a `Decoder`,
            a `PreTrainedModel`, or a `TunedLens` instance. If it is a
            `Decoder` or a `TunedLens`, then `inputs` must be a sequence of
            hidden states; if it is a `PreTrainedModel`, then `inputs` must
            be a tensor of integer token IDs.
        inputs: A tensor of token IDs or a sequence of hidden states,
            depending on the type of `model`.
        k: The number of basis vectors to compute for each layer. If k < 1,
            the full hidden dimension is used.
        max_iter: The maximum number of iterations to run L-BFGS for each
            vector.
        min_energy: The minimum energy a basis vector needs in order to be
            included in the basis.
        mode: The ablation mode passed to `ablate_subspace` or
            `remove_subspace`.
    """
    model.requires_grad_(False)

    log_p = None
    if isinstance(model, PreTrainedModel):
        assert isinstance(inputs, th.Tensor) and not th.is_floating_point(inputs)
        device = model.device
        dtype = model.dtype
        d = model.config.hidden_size
        num_layers = model.config.num_hidden_layers

        # Compute the control logits once up front. `hidden_states` is a
        # tuple; drop the final hidden state so that hiddens[i] is the
        # residual stream entering layer i.
        outputs = model(inputs, output_hidden_states=True)
        hiddens = list(outputs.hidden_states[:-1])
        log_p = outputs.logits.log_softmax(-1)
    else:
        assert isinstance(inputs, Sequence) and len(inputs) > 1
        device = inputs[0].device
        dtype = inputs[0].dtype
        d = inputs[0].shape[-1]
        hiddens = [h.detach() for h in inputs]
        num_layers = len(hiddens) - 1

    assert k <= d
    if k < 1:
        k = d

    bases: list[CausalBasis] = []
    I = th.eye(d, device=device, dtype=dtype)
    tol = th.finfo(dtype).eps
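
    # The search is greedy, with deflation: for each layer, find the unit
    # vector whose ablation maximizes the KL divergence from the control
    # distribution, project it out of the search space, and repeat up to
    # k times.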
    # Outer loop iterates over layers
    pbar = trange(num_layers * k)
    for i in range(num_layers):
        basis = CausalBasis(I[:, :k].clone(), th.zeros(k, device=device))

        # Compute the control distribution for this layer. For a
        # `PreTrainedModel` it was already computed from the final logits.
        if isinstance(model, Decoder):
            log_p = model(hiddens[i]).log_softmax(-1)
        elif isinstance(model, TunedLens):
            log_p = model(hiddens[i], i).log_softmax(-1)
        elif not isinstance(model, PreTrainedModel):
            raise NotImplementedError(f"Unsupported model type {type(model)}")

        assert log_p is not None
        p = log_p.exp()

        # Inner loop iterates over directions
        for j in range(k):
            pbar.set_description(f"Layer {i + 1}/{num_layers}, vector {j + 1}/{k}")

            # Construct the operator for projecting away from the previously
            # identified basis vectors: I - A @ A.T maps x onto the orthogonal
            # complement of span(A).
            if j:
                A = basis.basis[:, :j]
                proj = I - A @ A.T
            else:
                proj = I

            def project(x: th.Tensor) -> th.Tensor:
                # Project away from previously identified basis vectors
                x = proj @ x
                # Project onto the unit sphere
                return x / (x.norm() + th.finfo(x.dtype).eps)

            basis.basis[:, j] = project(basis.basis[:, j])
            v = th.nn.Parameter(basis.basis[:, j])
            opt = th.optim.LBFGS(
                [v],
                line_search_fn="strong_wolfe",
                max_iter=max_iter,
                tolerance_change=tol,
            )

            last_energy = th.tensor(0.0, device=device)
            nfev = 0

            def closure():
                nonlocal last_energy, nfev
                nfev += 1
                opt.zero_grad()

                v_ = project(v)
                if isinstance(model, PreTrainedModel):
                    with ablate_subspace(model, v_, i, mode=mode, orthonormal=True):
                        log_q = model(inputs).logits.log_softmax(-1)
                else:
                    h_ = remove_subspace(hiddens[i], v_, mode=mode, orthonormal=True)
                    if isinstance(model, Decoder):
                        log_q = model(h_).log_softmax(dim=-1)
                    elif isinstance(model, TunedLens):
                        log_q = model(h_, i).log_softmax(dim=-1)
                    else:
                        raise TypeError(f"Unknown lens type {type(model)}")

                # The loss is the negative KL divergence KL(p || q), so
                # minimizing it maximizes the effect of ablating v.
                loss = -th.sum(p * (log_p - log_q), dim=-1).mean()
                loss.backward()
                assert v.grad is not None

                # Average gradients across processes when running distributed
                if dist.is_initialized():
                    dist.all_reduce(v.grad)
                    v.grad /= dist.get_world_size()

                last_energy = -loss.detach()
                pbar.set_postfix(energy=last_energy.item(), nfev=nfev)

                if not loss.isfinite():
                    pbar.write("Loss is not finite")
                    loss = th.tensor(0.0, device=device)
                    opt.zero_grad()

                return loss

            opt.step(closure)  # type: ignore[arg-type]
            basis.basis[:, j] = project(v.data)
            basis.energies[j] = last_energy

            # If the energy is too low, stop looking for more basis vectors
            if last_energy < min_energy:
                basis = CausalBasis(basis.basis[:, : j + 1], basis.energies[: j + 1])
                pbar.update(k - j)
                pbar.write("Hit minimum energy threshold; skipping remaining basis vectors")
                break

            pbar.update()

        # L-BFGS isn't guaranteed to find directions in strictly decreasing
        # order of energy, so sort the basis by energy before returning it.
        indices = basis.energies.argsort(descending=True)
        bases.append(CausalBasis(basis.basis[:, indices], basis.energies[indices]))

    pbar.close()
    return bases
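
# A minimal usage sketch for the `PreTrainedModel` path (an illustrative
# addition, not part of the original gist). The model and prompt are
# arbitrary, and it assumes `white_box.causal.ablate_subspace` supports
# the chosen architecture.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
    bases = extract_causal_bases(model, input_ids, k=4, mode="mean")

    for layer, cb in enumerate(bases):
        print(
            f"Layer {layer}: {cb.basis.shape[1]} vectors, "
            f"top energy {cb.energies[0].item():.4f}"
        )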