Skip to content

Instantly share code, notes, and snippets.

@norabelrose
Last active September 30, 2023 07:36
Show Gist options
  • Save norabelrose/6c69d76d00ab5b77734203a3c4bf5162 to your computer and use it in GitHub Desktop.
Save norabelrose/6c69d76d00ab5b77734203a3c4bf5162 to your computer and use it in GitHub Desktop.
messy cifar leace testing
from argparse import ArgumentParser
from typing import Any, Callable, Protocol, Sized, Type
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics as tm
import torchvision as tv
from concept_erasure import LeaceFitter, OracleFitter, QuadraticFitter
from pytorch_lightning.loggers import WandbLogger
from torch import Tensor, nn
from torch.utils.data import Dataset, TensorDataset, random_split
from torchvision.datasets import CIFAR10
class Mlp(pl.LightningModule):
    """Three-layer ReLU MLP classifier over flattened 32x32x3 inputs.

    During training every (optionally erased) batch is also fed to an
    ``OracleFitter`` and the norm of its ``sigma_xz`` statistic is logged as
    ``sigma_xz_norm``.

    Args:
        k: Number of classes.
        h: Width of both hidden layers.
        eraser: Optional callable ``eraser(x, y)`` applied to each batch in the
            train/val/test steps before the forward pass.
        fitter_device: Device holding the fitter's running statistics.
            Defaults to ``"cuda"`` when CUDA is available, else ``"cpu"``
            (previously it was always ``"cuda"``, which crashed on CPU-only
            machines).
    """

    def __init__(self, k, h=128, eraser=None, fitter_device=None):
        super().__init__()
        self.h = h
        self.k = k
        if fitter_device is None:
            fitter_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.fitter = OracleFitter(
            32 * 32 * 3, k, device=fitter_device, dtype=torch.float64
        )
        self.eraser = eraser
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(32 * 32 * 3, h),
            torch.nn.ReLU(),
            torch.nn.Linear(h, h),
            torch.nn.ReLU(),
            torch.nn.Linear(h, k),
        )
        self.train_acc = tm.Accuracy("multiclass", num_classes=k)
        self.val_acc = tm.Accuracy("multiclass", num_classes=k)
        self.test_acc = tm.Accuracy("multiclass", num_classes=k)

    def forward(self, x):
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        if self.eraser:
            # Cast back like the eval steps do, in case the eraser returns a
            # different dtype than the float32 the linear layers expect.
            x = self.eraser(x, y).type_as(x)

        # The fitter was constructed with k label dimensions, and the script
        # feeds this fitter kind one-hot labels; use the same encoding here
        # instead of raw class indices. Inputs follow the fitter's device,
        # which need not match the model's.
        stats_device = self.fitter.sigma_xz.device
        self.fitter.update(
            x.reshape(x.shape[0], -1).to(stats_device),
            F.one_hot(y, self.k).to(stats_device),
        )

        y_hat = self(x.view(x.shape[0], -1))
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        self.train_acc(y_hat, y)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.log("sigma_xz_norm", self.fitter.sigma_xz.norm())
        return loss

    def _eval_step(self, batch, acc, stage):
        """Shared body of the validation and test steps.

        Erases the batch (if an eraser is set), updates ``acc``, logs
        ``{stage}_loss`` and ``{stage}_acc``, and returns the loss.
        """
        x, y = batch
        if self.eraser:
            x = self.eraser(x, y).type_as(x)
        y_hat = self(x.view(x.shape[0], -1))
        loss = F.cross_entropy(y_hat, y)
        acc(y_hat, y)
        self.log(f"{stage}_loss", loss)
        self.log(f"{stage}_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        return self._eval_step(batch, self.val_acc, "val")

    def test_step(self, batch, batch_idx):
        return self._eval_step(batch, self.test_acc, "test")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters())
class LeacedDataset(Dataset):
    """Wrapper for a dataset of ``(x, z)`` pairs that erases ``z`` from ``x``.

    Item ``idx`` is ``(transform(eraser(x, z)), z)``: erasure is applied to
    the raw input BEFORE the transform, and the label ``z`` is returned
    unchanged.

    Args:
        inner: Sized, non-empty dataset whose items unpack to ``(x, z)``.
        eraser: Callable ``eraser(x, z)`` returning the erased input.
        transform: Applied to the erased input; defaults to the identity.

    Raises:
        TypeError: If ``inner`` does not implement ``__len__``.
        ValueError: If ``inner`` is empty.
    """

    def __init__(
        self,
        inner: Dataset[tuple[Tensor, ...]],
        eraser: Callable,
        transform: Callable[[Tensor], Tensor] = lambda x: x,
    ):
        # Explicit raises rather than ``assert`` so the checks survive
        # ``python -O``; the isinstance check still narrows ``inner`` for
        # type checkers such as Pylance.
        if not isinstance(inner, Sized):
            raise TypeError("inner dataset must be sized")
        if len(inner) == 0:
            raise ValueError("inner dataset must be non-empty")

        self.dataset = inner
        self.eraser = eraser
        self.transform = transform

    def __getitem__(self, idx: int) -> tuple[Tensor, Any]:
        x, z = self.dataset[idx]

        # Erase BEFORE transforming: the eraser sees the raw input.
        x = self.eraser(x, z)
        return self.transform(x), z

    def __len__(self) -> int:
        return len(self.dataset)
def _make_sample_eraser(raw_eraser: Callable, kind: str, k: int) -> Callable:
    """Adapt a fitted eraser to the per-sample ``eraser(x, z)`` dataset calls.

    The fitters below are updated with flattened rows ``x.view(1, -1)`` in the
    fitter dtype (float64), and for every kind except ``"qleace"`` with one-hot
    labels. Dataset items instead arrive as an unflattened image tensor plus an
    integer label, so this wrapper converts them to the representation the
    eraser was fit on and converts the result back to ``x``'s shape and dtype.

    Args:
        raw_eraser: ``fitter.eraser`` of the chosen fitter.
        kind: One of ``"leace"``, ``"oleace"``, ``"qleace"``.
        k: Number of classes (width of the one-hot encoding).

    Returns:
        A callable ``erase(x, z)`` returning a tensor shaped and typed like ``x``.
    """

    def erase(x: Tensor, z: Any) -> Tensor:
        # Match the fitting representation: one flattened float64 row.
        row = x.reshape(1, -1).to(torch.float64)
        label = torch.as_tensor(z).view(1)
        if kind == "leace":
            # The LEACE eraser does not use labels; it is called with x only.
            out = raw_eraser(row)
        elif kind == "oleace":
            out = raw_eraser(row, F.one_hot(label, k))
        else:
            # "qleace": integer class indices, as passed to QuadraticFitter.
            out = raw_eraser(row, label)
        return out.reshape(x.shape).type_as(x)

    return erase


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("name", type=str)
    parser.add_argument(
        "--noise-scale",
        type=float,
        default=0.0,
        help="Std of Gaussian noise to add to the statistics",
    )
    parser.add_argument(
        "--eraser",
        type=str,
        choices=("none", "leace", "oleace", "qleace"),
        # Without a default, omitting the flag leaves args.eraser as None,
        # which is != "none" and then fails the fitter lookup with KeyError.
        default="none",
    )
    args = parser.parse_args()
    # NOTE(review): --noise-scale is parsed but not used anywhere below.

    # Split the "train" set into train and validation
    nontest = CIFAR10(
        "/home/nora/Data/cifar10", download=True, transform=tv.transforms.ToTensor()
    )
    train, val = random_split(nontest, [0.9, 0.1])

    # Test set is entirely separate
    test = CIFAR10(
        "/home/nora/Data/cifar10-test",
        download=True,
        train=False,
        transform=tv.transforms.ToTensor(),
    )
    k = 10  # Number of classes

    # Training-only augmentation; val/test are only flattened.
    train_trf = tv.transforms.Compose([
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.RandomCrop(32, padding=4),
        nn.Flatten(),
    ])

    if args.eraser != "none":
        cls = {
            "leace": LeaceFitter,
            "oleace": OracleFitter,
            "qleace": QuadraticFitter,
        }[args.eraser]
        fitter = cls(3 * 32 * 32, k, dtype=torch.float64)
        for x, y in train:
            y = torch.as_tensor(y).view(1)
            if args.eraser != "qleace":
                y = F.one_hot(y, k)
            fitter.update(x.view(1, -1), y)

        # The raw eraser expects flattened rows and encoded labels, not the
        # (image, int label) pairs the datasets yield.
        eraser = _make_sample_eraser(fitter.eraser, args.eraser, k)
    else:
        def eraser(x, y):
            """Identity: no erasure."""
            return x

    train = LeacedDataset(train, eraser, transform=train_trf)
    val = LeacedDataset(val, eraser, transform=nn.Flatten())
    test = LeacedDataset(test, eraser, transform=nn.Flatten())

    # Create the data module
    dm = pl.LightningDataModule.from_datasets(
        train, val, test, batch_size=128, num_workers=8
    )

    # Create the model here so that we don't advance the PRNG before the eraser
    model = Mlp(k)
    trainer = pl.Trainer(
        callbacks=[
            # EarlyStopping(monitor="val_loss", patience=5),
        ],
        logger=WandbLogger(name=args.name, project="mdl", entity="eleutherai"),
        max_epochs=200,
    )
    trainer.fit(model, dm)
    trainer.test(model, dm)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment