Created
January 5, 2023 08:39
-
-
Save norabelrose/89594b6e297b2c5505da24e64016a4a5 to your computer and use it in GitHub Desktop.
Causal basis extraction
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy | |
from einops import rearrange | |
from tqdm.auto import tqdm, trange | |
from transformers import PreTrainedModel | |
from typing import ( | |
Literal, NamedTuple, Optional, Union, Sequence | |
) | |
from white_box import TunedLens | |
from white_box.causal import ablate_subspace, remove_subspace | |
from white_box.nn import Decoder | |
from white_box.utils import ( | |
maybe_shift_labels, | |
maybe_shift_preds | |
) | |
import torch.distributed as dist | |
import torch.nn.functional as F | |
import torch as th | |
import warnings | |
class CausalBasis(NamedTuple):
    """Orthonormal basis of an activation subspace, sorted by causal importance.

    Fields:
        basis: A ``(d, k)`` matrix, where ``d`` is the ambient activation
            dimension and ``k`` the subspace dimension. Its columns are the
            orthonormal basis vectors, most energetic first.
        energies: A ``(k,)`` vector aligned with the columns of ``basis``.
            Entry ``i`` is the expected KL divergence between the control
            output distribution and the distribution obtained when column
            ``i`` alone is ablated.
    """

    # Column-stacked basis vectors, shape (d, k).
    basis: th.Tensor
    # Energy of each column of `basis`, shape (k,).
    energies: th.Tensor
def extract_causal_bases(
    model: Union[Decoder, PreTrainedModel, TunedLens],
    inputs: Union[th.Tensor, Sequence[th.Tensor]],
    k: int,
    *,
    max_iter: int = 100,
    min_energy: float = 0.0,
    mode: Literal["mean", "resample", "zero"] = "mean",
) -> list[CausalBasis]:
    """Compute a causal basis for each layer of a model.

    For every layer, basis vectors are found greedily, one at a time, with
    L-BFGS. Each new vector is constrained to be orthogonal to the vectors
    already found for that layer. It is chosen to maximize its "energy": the
    mean KL divergence ``KL(p || q)`` between the control output distribution
    ``p`` and the distribution ``q`` obtained after ablating that single
    vector.

    Args:
        model: The model to compute causal bases for. This can be a
            `Decoder`, a `PreTrainedModel`, or a `TunedLens` instance.
            - If it is a `Decoder` or a `TunedLens`, then `inputs` must be a
              sequence of hidden states.
            - If it is a `PreTrainedModel`, then `inputs` must be a
              non-floating-point tensor of token IDs.
        inputs: Either a sequence of at least two hidden-state tensors (for a
            `Decoder` / `TunedLens`) or a tensor of token IDs (for a
            `PreTrainedModel`).
        k: The number of basis vectors to compute for each layer. Must not
            exceed the hidden size ``d``; values below 1 mean "use ``d``".
        max_iter: The maximum number of L-BFGS iterations for each vector.
        min_energy: The minimum energy a basis vector needs in order to be
            included. The first vector that falls below this threshold is
            discarded, and the search for that layer stops.
        mode: Ablation mode, forwarded to `ablate_subspace` (for a
            `PreTrainedModel`) or `remove_subspace` (for lenses).

    Returns:
        One `CausalBasis` per layer. In each, the basis vectors are sorted by
        decreasing energy.

    Raises:
        NotImplementedError: If `model` is not one of the supported types.

    When `torch.distributed` is initialized, both the gradient and the loss
    are averaged across ranks. This keeps the optimization identical on every
    rank.
    """
    model.requires_grad_(False)

    log_p = None
    if isinstance(model, PreTrainedModel):
        assert isinstance(inputs, th.Tensor) and not th.is_floating_point(inputs)
        device = model.device
        dtype = model.dtype
        d = model.config.hidden_size
        num_layers = model.config.num_hidden_layers

        outputs = model(inputs, output_hidden_states=True)
        hiddens: list[th.Tensor] = list(outputs.hidden_states[:-1])
        # For a full model, the control distribution is shared by all layers.
        log_p = outputs.logits.log_softmax(-1)
    else:
        assert isinstance(inputs, Sequence) and len(inputs) > 1
        device = inputs[0].device
        dtype = inputs[0].dtype
        d = inputs[0].shape[-1]
        hiddens = [h.detach() for h in inputs]
        num_layers = len(hiddens) - 1

    assert k <= d
    if k < 1:
        k = d

    bases: list[CausalBasis] = []
    I = th.eye(d, device=device, dtype=dtype)
    tol = th.finfo(dtype).eps
    distributed = dist.is_initialized()

    # Outer loop iterates over layers
    pbar = trange(num_layers * k)
    for i in range(num_layers):
        basis = CausalBasis(I[:, :k].clone(), th.zeros(k, device=device))

        # Lenses need a per-layer control distribution. For a
        # PreTrainedModel, `log_p` was already computed above.
        if isinstance(model, Decoder):
            log_p = model(hiddens[i]).log_softmax(-1)
        elif isinstance(model, TunedLens):
            log_p = model(hiddens[i], i).log_softmax(-1)
        elif not isinstance(model, PreTrainedModel):
            raise NotImplementedError(f"Unsupported model type {type(model)}")

        assert log_p is not None
        p = log_p.exp()

        # Inner loop iterates over directions
        for j in range(k):
            pbar.set_description(f"Layer {i + 1}/{num_layers}, vector {j + 1}/{k}")

            # Projector onto the orthogonal complement of the vectors already
            # found for this layer.
            if j:
                A = basis.basis[:, :j]
                proj = I - A @ A.T
            else:
                proj = I

            def project(x: th.Tensor) -> th.Tensor:
                # Project away from previously identified basis vectors
                x = proj @ x
                # Project to the unit sphere
                return x / (x.norm() + th.finfo(x.dtype).eps)

            basis.basis[:, j] = project(basis.basis[:, j])
            # Clone so the optimizer owns its own contiguous storage rather
            # than aliasing a strided column of the basis matrix.
            v = th.nn.Parameter(basis.basis[:, j].clone())
            opt = th.optim.LBFGS(
                [v],
                line_search_fn="strong_wolfe",
                max_iter=max_iter,
                tolerance_change=tol,
            )

            last_energy = th.tensor(0.0, device=device)
            nfev = 0

            def closure():
                nonlocal last_energy, nfev
                nfev += 1

                opt.zero_grad()
                v_ = project(v)
                if isinstance(model, PreTrainedModel):
                    with ablate_subspace(model, v_, i, mode=mode, orthonormal=True):
                        log_q = model(inputs).logits.log_softmax(-1)
                else:
                    h_ = remove_subspace(hiddens[i], v_, mode=mode, orthonormal=True)
                    if isinstance(model, Decoder):
                        log_q = model(h_).log_softmax(dim=-1)
                    elif isinstance(model, TunedLens):
                        log_q = model(h_, i).log_softmax(dim=-1)
                    else:
                        raise TypeError(f"Unknown lens type {type(model)}")

                # Negative mean KL(p || q): minimizing it maximizes the energy.
                loss = -th.sum(p * (log_p - log_q), dim=-1).mean()
                loss.backward()

                assert v.grad is not None
                if distributed:
                    # Average the gradient AND the loss. The strong-Wolfe line
                    # search compares loss values, so every rank must see the
                    # same objective or their iterates (and bases) diverge.
                    world_size = dist.get_world_size()
                    dist.all_reduce(v.grad)
                    v.grad /= world_size
                    loss = loss.detach().clone()
                    dist.all_reduce(loss)
                    loss /= world_size

                if loss.isfinite():
                    last_energy = -loss.detach()
                else:
                    # Report a diverged evaluation as zero loss with zero
                    # gradient. Don't record NaN as the vector's energy.
                    pbar.write("Loss is not finite")
                    opt.zero_grad()
                    loss = th.tensor(0.0, device=device)
                    last_energy = th.tensor(0.0, device=device)

                pbar.set_postfix(energy=last_energy.item(), nfev=nfev)
                return loss

            opt.step(closure)  # type: ignore[arg-type]
            basis.basis[:, j] = project(v.data)
            basis.energies[j] = last_energy

            # A vector below the threshold is not included. Drop it and stop
            # looking for more basis vectors in this layer.
            if last_energy < min_energy:
                basis = CausalBasis(basis.basis[:, :j], basis.energies[:j])
                pbar.update(k - j)
                pbar.write("Hit minimum energy threshold; skipping remaining basis vectors")
                break

            pbar.update()

        # The greedy search does not guarantee monotonically decreasing
        # energies, so sort explicitly.
        indices = basis.energies.argsort(descending=True)
        bases.append(CausalBasis(basis.basis[:, indices], basis.energies[indices]))

    pbar.close()
    return bases
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment