from dataclasses import dataclass
from elk.metrics import to_one_hot
from elk.training import Classifier
from scipy.optimize import brentq
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from torch import Tensor
import numpy as np
import torch
import torch.nn.functional as F


@dataclass
class RlaceResult:
    """The result of applying R-LACE."""

    P: Tensor
    """The orthogonal projection matrix."""
    P_relaxed: Tensor
    """The relaxed projection matrix."""
    subspace: Tensor
    """The subspace erased by the projection matrix."""
    clf: Classifier
    """The best-response classifier."""
    clf_loss: float
    """The loss of the best-response classifier."""


@torch.no_grad()
def fantope_project(A: Tensor, d: int = 1) -> Tensor:
    """Project `A` to the Fantope."""
    L, Q = torch.linalg.eigh((A + A.T) / 2)

    # Solve the eigenvalue constraint on the CPU
    L_cpu = L.cpu()
    L -= brentq(
        lambda theta: torch.clamp(L_cpu - theta, 0, 1).sum() - d,
        a=L_cpu.max(), b=L_cpu.min() - 1
    )
    return Q @ L.clamp(0, 1).diag() @ Q.T

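
# The rank-d Fantope is the set {A : A = A.T, 0 <= eig(A) <= 1, tr(A) = d}, i.e. the
# convex hull of all rank-d orthogonal projection matrices. `fantope_project` shifts
# the eigenvalues by the scalar root found by `brentq`, then clamps them to [0, 1] so
# that they sum to d. The helper below is a hypothetical sanity check of those
# invariants (illustrative only; nothing else in this file calls it).
def _check_fantope(A: Tensor, d: int, atol: float = 1e-4) -> bool:
    """Return True if `A` (approximately) lies in the rank-`d` Fantope."""
    L = torch.linalg.eigvalsh((A + A.T) / 2)
    eigs_ok = bool(torch.all((L > -atol) & (L < 1 + atol)))
    trace_ok = abs(float(L.sum()) - d) < atol
    return eigs_ok and trace_ok
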
def sal(
    X: torch.Tensor,
    y: torch.Tensor,
    num_classes: int,
    rank: int = 1,
):
    """Spectral Attribute Removal <https://arxiv.org/abs/2203.07893>."""
    # Compute the direction of highest covariance with the labels
    # and use this to initialize the projection matrix. This usually
    # gets us most of the way to the optimal solution.
    y_one_hot = to_one_hot(y, num_classes).float() if num_classes > 2 else y
    cross_cov = (X - X.mean(0)).T @ (y_one_hot - y_one_hot.mean(0)) / len(X)
    if num_classes > 2:
        u, _, _ = torch.svd_lowrank(cross_cov, q=rank)
    else:
        # We can skip the SVD entirely for binary classification
        u = F.normalize(cross_cov, dim=0).unsqueeze(1)
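    # `u` has orthonormal columns, so `u @ u.T` is an orthogonal projector onto the
    # top cross-covariance directions (rank `rank`, or rank 1 in the binary case).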
    return u @ u.T

def rlace(
    X: torch.Tensor,
    y: torch.Tensor,
    rank: int = 1,
    *,
    max_iter: int = 100,
    lr: float = 1e-2,
    tolerance_grad: float = 1e-5,
    tolerance_loss: float = 1e-2,
) -> RlaceResult:
    """
    Apply Relaxed Linear Adversarial Concept Erasure (R-LACE) to `X` and `y`.

    R-LACE locates a rank k projection matrix P which maximizes the loss of the
    optimal classifier on the projected data. Method from Ravfogel et al. (2022)
    <https://arxiv.org/abs/2201.12091>.

    Args:
        X: The data matrix (n x d)
        y: The labels (n)
        rank: The rank of the projection matrix
        max_iter: The maximum number of iterations
        lr: The learning rate for the projection matrix
        tolerance_loss: How close the classifier loss must be to the random baseline
            before we break.
        tolerance_grad: The tolerance for the squared gradient norm
    """
    if X.ndim != 2:
        raise ValueError("X must be a 2D tensor.")
    if y.ndim != 1:
        raise ValueError("y must be a 1D tensor.")

    n, d = X.shape
    if n < d:
        raise ValueError("Must have n >= d.")
    if n != len(y):
        raise ValueError("Number of labels must match number of rows in X.")

    class_sizes = torch.bincount(y.long())
    num_classes = len(class_sizes)
    if num_classes == 1:
        raise ValueError("Must have at least two classes.")
    elif num_classes == 2:
        loss_fn = F.binary_cross_entropy_with_logits
        y = y.float()
    else:
        loss_fn = F.cross_entropy
        y = y.long()

    # Compute entropy of the labels
    fracs = class_sizes / n
    eps = torch.finfo(fracs.dtype).eps
    H = -torch.sum(fracs * fracs.add(eps).log())
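    # H (in nats) is the loss of the best constant predictor; the adversary tries to
    # push the best-response classifier's loss up to this random baseline.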
    # Initialize with Spectral Attribute Removal
    P = sal(X, y, num_classes, rank)
    P.requires_grad = True

    # We use a small learning rate for the projection matrix instead of strong Wolfe
    # line search because the projection matrix is usually quite close to the optimal
    # solution at initialization, and line search seems to end up overshooting and
    # causing divergence.
    adv_opt = torch.optim.LBFGS([P], lr=lr, tolerance_grad=1e-4)
    clf = Classifier(d, num_classes=num_classes, device=X.device)

    def adv_closure():
        adv_opt.zero_grad()
        P.data = fantope_project(P, rank)

        I = torch.eye(d, device=X.device)
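        # LBFGS minimizes, so we negate the loss: the adversary wants to *maximize*
        # the best-response classifier's loss on the projected data X @ (I - P).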
        loss_P = -loss_fn(clf(X @ (I - P)).squeeze(), y)
        loss_P.backward()
        return float(loss_P)

    # "best" here means HIGHEST classifier loss; we're trying to maximize the loss of
    # the best-response classifier. We want to ensure that even if we start diverging
    # for some reason, we can still recover the best solution we've seen.
    best_loss: float = -torch.inf
    best_P = P.detach().clone()

    for _ in range(max_iter):
        # Alternate between optimizing the projection matrix and the classifier
        clf.requires_grad_(True)
        clf_loss = clf.fit(
            X @ (torch.eye(d, device=X.device) - P.detach()), y, l2_penalty=0.0
        )
        if clf_loss > best_loss:
            best_loss = clf_loss
            best_P.copy_(P.detach())

        # Check if we've reached the random baseline
        if H - clf_loss < tolerance_loss:
            break

        clf.requires_grad_(False)
        adv_opt.step(adv_closure)

        # Check if we're very close to a saddle point
        grads = [p.grad for p in clf.parameters()] + [P.grad]
        grad_norm_sq = torch.cat(
            [g.view(-1) for g in grads if g is not None]
        ).square().sum()
        if grad_norm_sq < tolerance_grad:
            break

    # Make P an actual orthogonal projection matrix
    _, U = torch.linalg.eigh(best_P)
    U = U.T
    W = U[-rank:]
    P_final = torch.eye(d, device=W.device) - W.T @ W
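    # The rows of W span the erased concept subspace; P_final projects onto its
    # orthogonal complement, so P_final has rank d - `rank`.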
    return RlaceResult(
        P=P_final.detach(),
        P_relaxed=torch.eye(d, device=U.device) - best_P,
        subspace=W.detach(),
        clf=clf,
        clf_loss=best_loss,
    )


def get_majority_acc(y):
    """Get the majority accuracy of a set of labels."""
    from collections import Counter

    c = Counter(y)
    fracts = [v / sum(c.values()) for v in c.values()]
    maj = max(fracts)
    return maj

if __name__ == "__main__":
    # create a synthetic dataset
    n, dim = 15000, 200
    X, y = make_classification(n, dim, n_classes=2)

    l_train = int(0.6 * n)
    X_train, y_train = X[:l_train], y[:l_train]
    X_dev, y_dev = X[l_train:], y[l_train:]

    rank = 3
    output = rlace(
        torch.from_numpy(X_train).float(),
        torch.from_numpy(y_train).float(),
        rank=rank,
    )

    # train a classifier
    P_svd = output.P.numpy()
    P_relaxed = output.P_relaxed.numpy()

    model = LogisticRegression().fit(X_train[:], y_train[:])
    score_original = model.score(X_dev, y_dev)

    model = LogisticRegression().fit(X_train[:] @ P_relaxed, y_train[:])
    score_projected_no_svd = model.score(X_dev @ P_relaxed, y_dev)

    model = LogisticRegression().fit(X_train[:] @ P_svd, y_train[:])
    score_projected_svd_dev = model.score(X_dev @ P_svd, y_dev)
    score_projected_svd_train = model.score(X_train @ P_svd, y_train)

    maj_acc_dev = get_majority_acc(y_dev)
    maj_acc_train = get_majority_acc(y_train)

    print("===================================================")
    print(
        f"Original Acc, dev: {score_original * 100:.3f}%; "
        f"Acc, relaxed projection, dev: {score_projected_no_svd * 100:.3f}%; "
        f"Acc, orth. projection, train: {score_projected_svd_train * 100:.3f}%; "
        f"Acc, orth. projection, dev: {score_projected_svd_dev * 100:.3f}%"
    )
    print(f"Majority Acc, dev: {maj_acc_dev * 100:.3f} %")
    print(f"Majority Acc, train: {maj_acc_train * 100:.3f} %")
    print(f"Gap, dev: {np.abs(maj_acc_dev - score_projected_svd_dev) * 100:.3f} %")
    print(
        f"Gap, train: {np.abs(maj_acc_train - score_projected_svd_train) * 100:.3f} %"
    )
    print("===================================================")

    eigs_before_svd, _ = np.linalg.eigh(P_relaxed)
    print(f"Eigenvalues, before SVD: {eigs_before_svd}")

    eigs_after_svd, _ = np.linalg.eigh(P_svd)
    print(f"Eigenvalues, after SVD: {eigs_after_svd}")

    eps = 1e-6
    assert np.abs((eigs_after_svd > eps).sum() - (dim - rank)) < eps
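
    # Optional sanity checks (illustrative): P_svd should be a symmetric, idempotent
    # orthogonal projection, and the returned subspace should have orthonormal rows
    # spanning what was erased. Tolerances assume float32 arithmetic.
    assert np.allclose(P_svd, P_svd.T, atol=1e-4)
    assert np.allclose(P_svd @ P_svd, P_svd, atol=1e-4)
    W = output.subspace.numpy()
    assert np.allclose(W @ W.T, np.eye(rank), atol=1e-4)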