training code
from itertools import pairwise
from typing import Literal

import pytorch_lightning as pl
import torch
import torchmetrics as tm
import torchvision as tv
from torch import nn
from torch.optim import RAdam
from torch.optim.lr_scheduler import CosineAnnealingLR

class Mlp(pl.LightningModule):
    """Residual MLP classifier; subclasses override build_net() to swap architectures."""

    def __init__(self, k, h=512, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.build_net()

        self.train_acc = tm.Accuracy("multiclass", num_classes=k)
        self.val_acc = tm.Accuracy("multiclass", num_classes=k)
        self.test_acc = tm.Accuracy("multiclass", num_classes=k)

    def build_net(self):
        sizes = [3 * 32 * 32] + [self.hparams["h"]] * 4
        self.net = nn.Sequential(
            *[
                MlpBlock(
                    in_dim,
                    out_dim,
                    device=self.device,
                    dtype=self.dtype,
                    residual=True,
                    act="gelu",
                )
                for in_dim, out_dim in pairwise(sizes)
            ]
        )
        self.net.append(nn.Linear(self.hparams["h"], self.hparams["k"]))

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)

        self.train_acc(y_hat, y)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        # Log the norm of the final linear layer's weights
        fc = self.net[-1] if isinstance(self.net, nn.Sequential) else None
        if isinstance(fc, nn.Linear):
            self.log("weight_norm", fc.weight.data.norm())

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)

        self.val_acc(y_hat, y)
        self.log("val_loss", loss)
        self.log("val_acc", self.val_acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)

        self.test_acc(y_hat, y)
        self.log("test_loss", loss)
        self.log("test_acc", self.test_acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        opt = RAdam(self.parameters(), lr=1e-4)
        return [opt], [CosineAnnealingLR(opt, T_max=200)]
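
# A minimal usage sketch (shapes assume the CIFAR-10 setup in the script below):
# the residual MLP takes flattened 3 * 32 * 32 inputs and returns `k` logits, e.g.
#
#   model = Mlp(k=10)                              # 10 classes, hidden width 512
#   logits = model(torch.randn(8, 3 * 32 * 32))    # -> shape (8, 10)
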
class MlpMixer(Mlp):
    def build_net(self):
        from mlp_mixer_pytorch import MLPMixer

        self.net = MLPMixer(
            image_size=32,
            channels=3,
            patch_size=self.hparams.get("patch_size", 4),
            num_classes=self.hparams["k"],
            dim=512,
            depth=6,
            dropout=0.1,
        )


class ResNet(Mlp):
    def build_net(self):
        self.net = tv.models.resnet18(pretrained=False, num_classes=self.hparams["k"])


class ViT(MlpMixer):
    def build_net(self):
        from vit_pytorch import ViT

        self.net = ViT(
            image_size=32,
            patch_size=self.hparams.get("patch_size", 4),
            num_classes=self.hparams["k"],
            dim=512,
            depth=6,
            heads=8,
            mlp_dim=512,
            dropout=0.1,
            emb_dropout=0.1,
        )

class MlpBlock(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        device=None,
        dtype=None,
        residual: bool = True,
        *,
        act: Literal["relu", "gelu"] = "relu",
        norm: Literal["batch", "layer"] = "batch",
    ):
        super().__init__()

        self.linear1 = nn.Linear(
            in_features, out_features, bias=False, device=device, dtype=dtype
        )
        self.linear2 = nn.Linear(
            out_features, out_features, bias=False, device=device, dtype=dtype
        )
        self.act_fn = nn.ReLU() if act == "relu" else nn.GELU()

        norm_cls = nn.BatchNorm1d if norm == "batch" else nn.LayerNorm
        self.bn1 = norm_cls(out_features, device=device, dtype=dtype)
        self.bn2 = norm_cls(out_features, device=device, dtype=dtype)

        # Project the skip connection when the input and output widths differ
        self.downsample = (
            nn.Linear(in_features, out_features, bias=False, device=device, dtype=dtype)
            if in_features != out_features
            else None
        )
        self.residual = residual

    def forward(self, x):
        identity = x

        out = self.linear1(x)
        out = self.bn1(out)
        out = self.act_fn(out)

        out = self.linear2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(identity)

        if self.residual:
            out += identity

        out = self.act_fn(out)
        return out
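
# Sketch of what a single MlpBlock computes (shapes are illustrative): it maps
# (batch, in_features) -> (batch, out_features), projecting the skip connection
# with a linear "downsample" whenever the two widths differ, e.g.
#
#   block = MlpBlock(3 * 32 * 32, 512, act="gelu")
#   out = block(torch.randn(8, 3 * 32 * 32))   # -> shape (8, 512)
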
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Sized

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from torch import Tensor, nn
from torch.utils.data import Dataset, random_split
from torchvision.datasets import CIFAR10
from tqdm.auto import tqdm

from concept_erasure import LeaceFitter, OracleFitter, QuadraticFitter

# Use faster matmul precision
torch.set_float32_matmul_precision("high")

@dataclass
class LogSpacedCheckpoint(Callback):
    """Save checkpoints at log-spaced intervals"""

    dirpath: str

    base: float = 2.0
    next: int = 1

    def on_train_batch_end(self, trainer: Trainer, *_):
        if trainer.global_step >= self.next:
            self.next = round(self.next * self.base)
            trainer.save_checkpoint(self.dirpath + f"/step={trainer.global_step}.ckpt")
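
# With the defaults (base=2.0, next=1), checkpoints are written after steps
# 1, 2, 4, 8, 16, ..., so early training is sampled densely and later training
# only sparsely.
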
class LeacedDataset(Dataset):
    """Wrapper for a dataset of (X, Z) pairs that erases Z from X"""

    def __init__(
        self,
        inner: Dataset[tuple[Tensor, ...]],
        eraser: Callable,
        transform: Callable[[Tensor], Tensor] = lambda x: x,
        p: float = 1.0,
    ):
        # Pylance actually keeps track of the intersection type
        assert isinstance(inner, Sized), "inner dataset must be sized"
        assert len(inner) > 0, "inner dataset must be non-empty"

        self.cache: dict[int, tuple[Tensor, Tensor]] = {}
        self.dataset = inner
        self.eraser = eraser
        self.transform = transform
        self.p = p

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        if idx not in self.cache:
            x, z = self.dataset[idx]
            x = self.eraser(x, z)
            self.cache[idx] = x, z
        else:
            x, z = self.cache[idx]

        return self.transform(x), z

    def __len__(self):
        return len(self.dataset)
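
# Erasure is applied once per example and memoized in `cache`, while `transform`
# (the augmentations defined below) is re-applied on every access, so random crops
# and flips still vary across epochs. A minimal usage sketch with a no-op eraser
# (the root path here is illustrative):
#
#   base = CIFAR10("/tmp/cifar10", download=True, transform=T.ToTensor())
#   ds = LeacedDataset(base, eraser=lambda x, z: x)
#   x, z = ds[0]
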
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("name", type=str)
    parser.add_argument(
        "--eraser",
        type=str,
        choices=("none", "leace", "oleace", "qleace"),
        default="none",
    )
    parser.add_argument(
        "--patch-size", type=int, default=4, help="patch size for mixer and vit"
    )
    parser.add_argument("--net", type=str, choices=("mixer", "resmlp", "resnet", "vit"))
    args = parser.parse_args()
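
    # Example invocations (the filename train.py and the run names are assumptions,
    # not part of the gist):
    #   python train.py resnet-baseline --net resnet
    #   python train.py vit-qleace --net vit --eraser qleace
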
    # Split the "train" set into train and validation
    nontest = CIFAR10("/home/nora/Data/cifar10", download=True, transform=T.ToTensor())
    train, val = random_split(nontest, [0.9, 0.1])

    X = torch.from_numpy(nontest.data).div(255)
    Y = torch.tensor(nontest.targets)

    # Test set is entirely separate
    test = CIFAR10(
        "/home/nora/Data/cifar10-test",
        download=True,
        train=False,
        transform=T.ToTensor(),
    )
    k = 10  # Number of classes
    final = nn.Identity() if args.net in ("mixer", "resnet", "vit") else nn.Flatten(0)
    train_trf = T.Compose(
        [
            T.RandomHorizontalFlip(),
            T.RandomCrop(32, padding=4),
            final,
        ]
    )
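
    # For "mixer", "resnet", and "vit" the images stay (3, 32, 32); for the
    # residual MLP ("resmlp") `final` flattens them to 3 * 32 * 32 vectors.
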
    if args.eraser != "none":
        cache_dir = Path(f"/home/nora/Data/cifar10-{args.eraser}.pt")
        if cache_dir.exists():
            eraser = torch.load(cache_dir)
            print("Loaded cached eraser")
        else:
            print("No eraser cached; fitting a fresh one")
            cls = {
                "leace": LeaceFitter,
                "oleace": OracleFitter,
                "qleace": QuadraticFitter,
            }[args.eraser]
            fitter = cls(3 * 32 * 32, k, dtype=torch.float64)

            for x, y in tqdm(train):
                y = torch.as_tensor(y).view(1)

                # LEACE and O-LEACE are fit on one-hot labels; QLEACE uses the
                # integer class indices directly
                if args.eraser != "qleace":
                    y = F.one_hot(y, k)

                fitter.update(x.view(1, -1), y)

            eraser = fitter.eraser
            torch.save(eraser, cache_dir)
            print(f"Saved eraser to {cache_dir}")
    else:
        # No-op eraser
        def eraser(x, y):
            return x
    train = LeacedDataset(train, eraser, transform=train_trf)
    val = LeacedDataset(val, eraser, transform=final)
    test = LeacedDataset(test, eraser, transform=final)

    # Create the data module
    dm = pl.LightningDataModule.from_datasets(
        train, val, test, batch_size=128, num_workers=8
    )

    model_cls = {
        "mixer": MlpMixer,
        "resmlp": Mlp,
        "resnet": ResNet,
        "vit": ViT,
    }[args.net]
    model = model_cls(k, patch_size=args.patch_size)

    checkpointer = LogSpacedCheckpoint(f"/home/nora/Data/cifar-ckpts/{args.name}")
    trainer = pl.Trainer(
        callbacks=[checkpointer],
        logger=WandbLogger(
            name=args.name, project="concept-erasure", entity="eleutherai"
        ),
        max_epochs=200,
    )
    trainer.fit(model, dm)
    trainer.test(model, dm)
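
    # The log-spaced checkpoints written above can later be reloaded for analysis
    # with Lightning's standard API; a sketch (the step number is hypothetical):
    #   ckpt = f"/home/nora/Data/cifar-ckpts/{args.name}/step=1024.ckpt"
    #   model = model_cls.load_from_checkpoint(ckpt)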