# Gist by @norabelrose, last active April 14, 2023 22:02
from dataclasses import dataclass
from elk.metrics import to_one_hot
from elk.training import Classifier
from scipy.optimize import brentq
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from torch import Tensor
import numpy as np
import torch
import torch.nn.functional as F


@dataclass
class RlaceResult:
    """The result of applying R-LACE."""

    P: Tensor
    """The orthogonal projection matrix."""

    P_relaxed: Tensor
    """The relaxed projection matrix."""

    subspace: Tensor
    """The subspace erased by the projection matrix."""

    clf: Classifier
    """The best-response classifier."""

    clf_loss: float
    """The loss of the best-response classifier."""


@torch.no_grad()
def fantope_project(A: Tensor, d: int = 1) -> Tensor:
    """Project `A` onto the Fantope.

    The Fantope is the set of symmetric matrices whose eigenvalues lie in
    [0, 1] and sum to `d`; its extreme points are the rank-`d` orthogonal
    projection matrices.
    """
    L, Q = torch.linalg.eigh((A + A.T) / 2)

    # Solve the eigenvalue constraint on the CPU
    L_cpu = L.cpu()
    L -= brentq(
        lambda theta: torch.clamp(L_cpu - theta, 0, 1).sum() - d,
        a=L_cpu.max(), b=L_cpu.min() - 1,
    )
    return Q @ L.clamp(0, 1).diag() @ Q.T
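

# Minimal sanity check, illustrative only; `_check_fantope` is a helper added
# here and is not part of the original gist. The output of `fantope_project`
# should be symmetric with eigenvalues in [0, 1] that sum to `d`.
def _check_fantope(dim: int = 8, d: int = 2, atol: float = 1e-4) -> None:
    A = torch.randn(dim, dim)
    F_A = fantope_project(A, d)
    evals = torch.linalg.eigvalsh(F_A)

    # Symmetric, spectrum in [0, 1], trace equal to d (up to float error)
    assert torch.allclose(F_A, F_A.T, atol=atol)
    assert evals.min() > -atol and evals.max() < 1 + atol
    assert abs(float(evals.sum()) - d) < atol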


def sal(
    X: torch.Tensor,
    y: torch.Tensor,
    num_classes: int,
    rank: int = 1,
) -> Tensor:
    """Spectral Attribute Removal <https://arxiv.org/abs/2203.07893>."""
    n = len(y)

    # Compute the direction of highest covariance with the labels
    # and use this to initialize the projection matrix. This usually
    # gets us most of the way to the optimal solution.
    y_one_hot = to_one_hot(y, num_classes).float() if num_classes > 2 else y
    cross_cov = (X - X.mean(0)).T @ (y_one_hot - y_one_hot.mean(0)) / n
    if num_classes > 2:
        u, _, _ = torch.svd_lowrank(cross_cov, q=rank)
    else:
        # We can skip the SVD entirely for binary classification
        u = F.normalize(cross_cov, dim=0).unsqueeze(1)

    return u @ u.T
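

# Illustrative sketch, not part of the original gist; `_check_sal` is a
# hypothetical helper. Projecting the data onto the complement of the SAL
# subspace, X @ (I - P), should essentially zero out the cross-covariance
# between the features and the labels that `sal` used to build P.
def _check_sal(n: int = 1000, dim: int = 16) -> None:
    X = torch.randn(n, dim)
    y = (X[:, 0] + 0.1 * torch.randn(n) > 0).float()

    P = sal(X, y, num_classes=2)
    X_erased = X @ (torch.eye(dim) - P)

    # Cross-covariance with the labels after erasure should be near zero
    cov = (X_erased - X_erased.mean(0)).T @ (y - y.mean()) / n
    assert cov.norm() < 0.1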


def rlace(
    X: torch.Tensor,
    y: torch.Tensor,
    rank: int = 1,
    *,
    max_iter: int = 100,
    lr: float = 1e-2,
    tolerance_grad: float = 1e-5,
    tolerance_loss: float = 1e-2,
) -> RlaceResult:
    """
    Apply Relaxed Linear Adversarial Concept Erasure (R-LACE) to `X` and `y`.

    R-LACE locates a rank k projection matrix P which maximizes the loss of the
    optimal classifier on the projected data. Method from Ravfogel et al. (2022)
    <https://arxiv.org/abs/2201.12091>.

    Args:
        X: The data matrix (n x d)
        y: The labels (n)
        rank: The rank of the projection matrix
        max_iter: The maximum number of iterations
        lr: The learning rate for the projection matrix
        tolerance_grad: The tolerance for the squared gradient norm
        tolerance_loss: How close the classifier loss must be to the random baseline
            before we break.
    """
    if X.ndim != 2:
        raise ValueError("X must be a 2D tensor.")
    if y.ndim != 1:
        raise ValueError("y must be a 1D tensor.")

    n, d = X.shape
    if n < d:
        raise ValueError("Must have n >= d.")
    if n != len(y):
        raise ValueError("Number of labels must match number of rows in X.")

    class_sizes = torch.bincount(y.long())
    num_classes = len(class_sizes)
    if num_classes == 1:
        raise ValueError("Must have at least two classes.")
    elif num_classes == 2:
        loss_fn = F.binary_cross_entropy_with_logits
        y = y.float()
    else:
        loss_fn = F.cross_entropy
        y = y.long()

    # Compute the entropy of the labels; this is the loss a classifier incurs
    # when the features carry no information about y (the random baseline).
    fracs = class_sizes / n
    eps = torch.finfo(fracs.dtype).eps
    H = -torch.sum(fracs * fracs.add(eps).log())

    # Initialize with Spectral Attribute Removal
    P = sal(X, y, num_classes, rank)
    P.requires_grad = True

    # We use a small learning rate for the projection matrix instead of strong Wolfe
    # line search because the projection matrix is usually quite close to the optimal
    # solution at initialization, and line search seems to end up overshooting and
    # causing divergence.
    adv_opt = torch.optim.LBFGS([P], lr=lr, tolerance_grad=1e-4)
    clf = Classifier(d, num_classes=num_classes, device=X.device)

    def adv_closure():
        adv_opt.zero_grad()
        P.data = fantope_project(P, rank)

        I = torch.eye(d, device=X.device)
        loss_P = -loss_fn(clf(X @ (I - P)).squeeze(), y)
        loss_P.backward()

        return float(loss_P)

    # "best" here means HIGHEST classifier loss; we're trying to maximize the loss of
    # the best-response classifier. We want to ensure that even if we start diverging
    # for some reason, we can still recover the best solution we've seen.
    best_loss: float = -torch.inf
    best_P = P.detach().clone()

    for _ in range(max_iter):
        # Alternate between optimizing the projection matrix and the classifier
        clf.requires_grad_(True)
        clf_loss = clf.fit(
            X @ (torch.eye(d, device=X.device) - P.detach()), y,
            l2_penalty=0.0,
        )
        if clf_loss > best_loss:
            best_loss = clf_loss
            best_P.copy_(P.detach())

        # Check if we've reached the random baseline
        if H - clf_loss < tolerance_loss:
            break

        clf.requires_grad_(False)
        adv_opt.step(adv_closure)

        # Check if we're very close to a saddle point
        grads = [p.grad for p in clf.parameters()] + [P.grad]
        grad_norm_sq = torch.cat(
            [g.view(-1) for g in grads if g is not None]
        ).square().sum()
        if grad_norm_sq < tolerance_grad:
            break

    # Make P an actual orthogonal projection matrix: the top `rank` eigenvectors
    # of the relaxed P span the subspace we erase.
    _, U = torch.linalg.eigh(best_P)
    U = U.T

    W = U[-rank:]
    P_final = torch.eye(d, device=W.device) - W.T @ W

    return RlaceResult(
        P=P_final.detach(),
        P_relaxed=torch.eye(d, device=U.device) - best_P,
        subspace=W.detach(),
        clf=clf,
        clf_loss=best_loss,
    )
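

# Hedged usage sketch; `_erase_concept` is a name introduced here, not part of
# the original gist. The returned `P` is an orthogonal projection matrix, so
# erasing the concept from new data is a single matrix multiply with the same
# `P` that was fit on the training split.
def _erase_concept(result: RlaceResult, X_new: Tensor) -> Tensor:
    return X_new @ result.P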


def get_majority_acc(y):
    """Get the majority accuracy of a set of labels."""
    from collections import Counter

    c = Counter(y)
    fracts = [v / sum(c.values()) for v in c.values()]
    maj = max(fracts)
    return maj


if __name__ == "__main__":
    # Create a synthetic dataset
    n, dim = 15000, 200
    X, y = make_classification(n, dim, n_classes=2)

    l_train = int(0.6 * n)
    X_train, y_train = X[:l_train], y[:l_train]
    X_dev, y_dev = X[l_train:], y[l_train:]

    rank = 3
    output = rlace(
        torch.from_numpy(X_train).float(),
        torch.from_numpy(y_train).float(),
        rank=rank,
    )

    # Train logistic regression probes on the original and projected data
    P_svd = output.P.numpy()
    P_relaxed = output.P_relaxed.numpy()

    model = LogisticRegression().fit(X_train, y_train)
    score_original = model.score(X_dev, y_dev)

    model = LogisticRegression().fit(X_train @ P_relaxed, y_train)
    score_projected_no_svd = model.score(X_dev @ P_relaxed, y_dev)

    model = LogisticRegression().fit(X_train @ P_svd, y_train)
    score_projected_svd_dev = model.score(X_dev @ P_svd, y_dev)
    score_projected_svd_train = model.score(X_train @ P_svd, y_train)

    maj_acc_dev = get_majority_acc(y_dev)
    maj_acc_train = get_majority_acc(y_train)

    print("===================================================")
    print(
        f"Original Acc, dev: {score_original * 100:.3f}%; "
        f"Acc, relaxed projection, dev: {score_projected_no_svd * 100:.3f}%; "
        f"Acc, orth. projection, train: {score_projected_svd_train * 100:.3f}%; "
        f"Acc, orth. projection, dev: {score_projected_svd_dev * 100:.3f}%"
    )
    print(f"Majority Acc, dev: {maj_acc_dev * 100:.3f} %")
    print(f"Majority Acc, train: {maj_acc_train * 100:.3f} %")
    print(
        f"Gap, dev: {np.abs(maj_acc_dev - score_projected_svd_dev) * 100:.3f} %"
    )
    print(
        f"Gap, train: {np.abs(maj_acc_train - score_projected_svd_train) * 100:.3f} %"
    )
    print("===================================================")

    eigs_before_svd, _ = np.linalg.eigh(P_relaxed)
    print(f"Eigenvalues, before SVD: {eigs_before_svd}")

    eigs_after_svd, _ = np.linalg.eigh(P_svd)
    print(f"Eigenvalues, after SVD: {eigs_after_svd}")

    # The final orthogonal projection should remove exactly `rank` dimensions
    eps = 1e-6
    assert (eigs_after_svd > eps).sum() == dim - rank