Messy CIFAR-10 LEACE testing: trains a small MLP on CIFAR-10 with the class labels optionally erased from the pixels via LEACE, oracle LEACE, or quadratic LEACE.
from argparse import ArgumentParser
from typing import Callable, Sized

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics as tm
import torchvision as tv
from concept_erasure import LeaceFitter, OracleFitter, QuadraticFitter
from pytorch_lightning.loggers import WandbLogger
from torch import Tensor, nn
from torch.utils.data import Dataset, random_split
from torchvision.datasets import CIFAR10
class Mlp(pl.LightningModule):
    def __init__(self, k: int, h: int = 128, eraser: Callable | None = None):
        super().__init__()
        self.h = h
        self.k = k

        # Tracks the running cross-covariance between the (possibly erased)
        # inputs and the one-hot labels during training.
        self.fitter = OracleFitter(32 * 32 * 3, k, device="cuda", dtype=torch.float64)

        # The dataset wrapper below already applies the eraser; this hook is
        # only used if an eraser is passed directly to the model.
        self.eraser = eraser

        self.layers = torch.nn.Sequential(
            torch.nn.Linear(32 * 32 * 3, h),
            torch.nn.ReLU(),
            torch.nn.Linear(h, h),
            torch.nn.ReLU(),
            torch.nn.Linear(h, k),
        )
        self.train_acc = tm.Accuracy("multiclass", num_classes=k)
        self.val_acc = tm.Accuracy("multiclass", num_classes=k)
        self.test_acc = tm.Accuracy("multiclass", num_classes=k)
    def forward(self, x):
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = self.eraser(x, y).type_as(x) if self.eraser else x
        x = x.view(x.shape[0], -1)

        # The fitter expects one-hot labels of width k, mirroring how the
        # erasers are fit in the __main__ block below.
        self.fitter.update(x, F.one_hot(y, self.k).to(x.dtype))

        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)

        self.train_acc(y_hat, y)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        # LEACE guarantees zero cross-covariance between the erased inputs and
        # the labels, so this norm should stay near zero when erasure works.
        self.log("sigma_xz_norm", self.fitter.sigma_xz.norm())
        return loss
    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = self.eraser(x, y).type_as(x) if self.eraser else x

        y_hat = self(x.view(x.shape[0], -1))
        loss = F.cross_entropy(y_hat, y)
        self.val_acc(y_hat, y)
        self.log("val_loss", loss)
        self.log("val_acc", self.val_acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        x = self.eraser(x, y).type_as(x) if self.eraser else x

        y_hat = self(x.view(x.shape[0], -1))
        loss = F.cross_entropy(y_hat, y)
        self.test_acc(y_hat, y)
        self.log("test_loss", loss)
        self.log("test_acc", self.test_acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters())
class LeacedDataset(Dataset):
    """Wrapper for a dataset of (X, Z) pairs that erases Z from X."""

    def __init__(
        self,
        inner: Dataset[tuple[Tensor, ...]],
        eraser: Callable,
        transform: Callable[[Tensor], Tensor] = lambda x: x,
    ):
        # Pylance actually keeps track of the intersection type
        assert isinstance(inner, Sized), "inner dataset must be sized"
        assert len(inner) > 0, "inner dataset must be non-empty"

        self.dataset = inner
        self.eraser = eraser
        self.transform = transform

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        x, z = self.dataset[idx]

        # Erase BEFORE transforming
        x = self.eraser(x, z)
        return self.transform(x), z

    def __len__(self):
        return len(self.dataset)
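
# Hedged reference sketch (not part of the original gist): concept_erasure also
# exposes a one-shot batch API, LeaceEraser.fit, in addition to the incremental
# Fitter API used below. The helper name and shapes here are assumptions for
# illustration only; this function is never called by the script.
def _leace_batch_sketch(x: Tensor, z: Tensor) -> Tensor:
    """Fit a LEACE eraser on (N, d) features x and (N, k) one-hot labels z,
    then return the erased features. After erasure, no *linear* probe should
    be able to recover z from the result."""
    from concept_erasure import LeaceEraser

    eraser = LeaceEraser.fit(x, z)
    return eraser(x)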
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("name", type=str)
    parser.add_argument(
        "--noise-scale",
        type=float,
        default=0.0,
        # NOTE: currently unused in this script
        help="Std of Gaussian noise to add to the statistics",
    )
    parser.add_argument(
        "--eraser",
        type=str,
        # Default to "none" so the dict lookup below never sees None
        default="none",
        choices=("none", "leace", "oleace", "qleace"),
    )
    args = parser.parse_args()
    # Split the "train" set into train and validation
    nontest = CIFAR10(
        "/home/nora/Data/cifar10", download=True, transform=tv.transforms.ToTensor()
    )
    train, val = random_split(nontest, [0.9, 0.1])

    # Test set is entirely separate
    test = CIFAR10(
        "/home/nora/Data/cifar10-test",
        download=True,
        train=False,
        transform=tv.transforms.ToTensor(),
    )

    k = 10  # Number of classes
    train_trf = tv.transforms.Compose([
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.RandomCrop(32, padding=4),
        nn.Flatten(),
    ])
    if args.eraser != "none":
        cls = {
            "leace": LeaceFitter,
            "oleace": OracleFitter,
            "qleace": QuadraticFitter,
        }[args.eraser]

        # Accumulate the statistics one sample at a time
        fitter = cls(3 * 32 * 32, k, dtype=torch.float64)
        for x, y in train:
            y = torch.as_tensor(y).view(1)

            # QuadraticFitter takes integer class labels; the linear fitters
            # take one-hot labels
            if args.eraser != "qleace":
                y = F.one_hot(y, k)

            fitter.update(x.view(1, -1), y)

        eraser = fitter.eraser
    else:
        eraser = lambda x, y: x
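
    # Note (assumption, not in the original gist): the fitters also accept
    # batched updates, so the per-sample loop above could be vectorized with a
    # DataLoader along these lines:
    #
    #   from torch.utils.data import DataLoader
    #   for xb, yb in DataLoader(train, batch_size=512):
    #       zb = yb if args.eraser == "qleace" else F.one_hot(yb, k)
    #       fitter.update(xb.view(len(xb), -1), zb)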
    train = LeacedDataset(train, eraser, transform=train_trf)
    val = LeacedDataset(val, eraser, transform=nn.Flatten())
    test = LeacedDataset(test, eraser, transform=nn.Flatten())

    # Create the data module
    dm = pl.LightningDataModule.from_datasets(
        train, val, test, batch_size=128, num_workers=8
    )

    # Create the model here so that we don't advance the PRNG before the eraser
    model = Mlp(k)
    trainer = pl.Trainer(
        callbacks=[
            # EarlyStopping(monitor="val_loss", patience=5),
        ],
        logger=WandbLogger(name=args.name, project="mdl", entity="eleutherai"),
        max_epochs=200,
    )
    trainer.fit(model, dm)
    trainer.test(model, dm)
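
# Example invocations (the filename is an assumption; the gist doesn't name it):
#   python cifar_leace.py baseline --eraser none
#   python cifar_leace.py erased --eraser leace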