-
-
Save maxidl/4cebf7b7e2a2de62f0699aff68193e68 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.optim as optim | |
import torchvision | |
import torchvision.transforms as transforms | |
from pathlib import Path | |
from tqdm.auto import tqdm | |
# Report CUDA availability and select the device: GPU when available, else CPU.
print(torch.cuda.is_available())
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# dev = torch.device("cpu")
### finetune on cifar-10
# Finetuning hyperparameters.
batch_size = 40
learning_rate = 0.001
EPOCHS = 5 # change to 5 later on
# Cached finetuned weights; if this file exists, training is skipped and it is loaded instead.
WEIGHTS = Path('cifar10_weights.pt')
# Per-channel mean/std — presumably the standard CIFAR-10 training-set statistics (TODO confirm).
transform = transforms.Compose([
    #transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )
])
# Downloads CIFAR-10 into ./data on first run.
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
classes = train_dataset.classes
print(f'classes: {classes}\nnumber of instances:\n\ttrain: {len(train_dataset)}\n\tval: {len(val_dataset)}')
import matplotlib.pyplot as plt | |
def show_examples(n):
    """Display `n` randomly chosen training images with their class labels.

    NOTE(review): the images are normalized, so directly imshow-ing them
    shows clipped/washed-out colors; acceptable for a quick sanity check.
    """
    for _ in range(n):
        # BUG FIX: .item() converts the 1-element tensor to a plain int.
        # Indexing the dataset with a tensor fancy-indexes the underlying
        # numpy array to shape (1, 32, 32, 3), which breaks the PIL
        # conversion inside CIFAR10.__getitem__.
        index = torch.randint(0, len(train_dataset), size=(1,)).item()
        image, target = train_dataset[index]
        print(f'image of shape: {image.shape}')
        print(f'label: {classes[target]}')
        plt.imshow(image.permute(1,2,0).numpy())
        plt.show()
# show_examples(4) | |
from torch.utils.data import DataLoader
# Shuffle only the training loader; validation order has no effect on the
# accuracy metric, so shuffling it was pure overhead.
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
import torch.nn as nn | |
import torch.nn.functional as F | |
def get_vgg_model():
    """Return an ImageNet-pretrained VGG-16 with its head adapted to 10 classes,
    moved to the module-level device `dev`.
    """
    # NOTE(review): `pretrained=True` is deprecated in newer torchvision in
    # favor of `weights=...`; kept as-is for compatibility with older versions.
    vgg16 = torchvision.models.vgg16(pretrained = True)
    input_lastLayer = vgg16.classifier[6].in_features
    # classifier[5] is replaced by Identity (presumably the final Dropout in
    # torchvision's VGG-16 head — TODO confirm), and classifier[6] becomes a
    # fresh 10-way linear layer.
    vgg16.classifier[5] = nn.Identity()
    vgg16.classifier[6] = nn.Linear(input_lastLayer,10)
    vgg16 = vgg16.to(dev)
    return vgg16
# Finetune on CIFAR-10 unless cached weights already exist on disk.
vgg16_1 = get_vgg_model()
if not WEIGHTS.exists():
    print(f'Could not find {WEIGHTS}, finetuning...')
    optimizer = optim.SGD(vgg16_1.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    n_total_step = len(train_dl)
    vgg16_1.train()
    for epoch in range(EPOCHS):
        for i, (imgs, labels) in enumerate(tqdm(train_dl, desc=f'Training epoch {epoch+1}')):
            imgs, labels = imgs.to(dev), labels.to(dev)
            outputs = vgg16_1(imgs)
            # Per-batch accuracy, used only for periodic logging below.
            n_correct = (outputs.argmax(axis=1)==labels).sum().item()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # zero_grad after step (instead of before backward) is equivalent:
            # gradients start at zero on the first iteration.
            optimizer.zero_grad()
            if (i+1) % 250 == 0:
                print(f"epoch {epoch+1}/{EPOCHS}, step: {i+1}/{n_total_step}: loss = {loss:.5f}, acc = {100*(n_correct/labels.size(0)):.2f}%")
    # Evaluate on the held-out set, then cache the weights for future runs.
    vgg16_1.eval()
    with torch.no_grad():
        n_correct = 0
        n_samples = 0
        for i, (imgs, labels) in enumerate(tqdm(val_dl, desc='Validation')):
            imgs, labels = imgs.to(dev), labels.to(dev)
            outputs = vgg16_1(imgs)
            n_correct += (outputs.argmax(axis=1)==labels).sum().item()
            n_samples += labels.size(0)
        print(f"Validation accuracy {(n_correct / n_samples)*100}%")
    torch.save(vgg16_1.state_dict(), WEIGHTS)
else:
    print(f'Loading weights from {WEIGHTS}')
    vgg16_1.load_state_dict(torch.load(WEIGHTS))
### Extension: manipulate the model so its explanations change
TOPK = 100   # number of highest-attribution pixels per image to suppress
LAMBDA = 1   # weight of the explanation loss relative to the CE loss
LR = 1e-5    # learning rate for the manipulation finetuning
def get_simple_gradient_expl(model, image, target, absolute=False): | |
image.requires_grad = True | |
output = model(image) | |
grad = torch.autograd.grad(output[:, target], image, create_graph=True)[0] # create_graph=True for second order derivative | |
expl = grad.abs() if absolute else grad | |
return expl.sum(1).squeeze() | |
# # test expl generation | |
# expl = get_simple_gradient_expl(vgg16_1, train_dataset[0][0].unsqueeze(0).to(dev), target=predictions[0]) | |
# assert expl.grad_fn | |
class CombinedDataset(torch.utils.data.Dataset):
    """Zip several same-length, index-aligned datasets into one.

    Item `idx` is the list of each underlying dataset's `idx`-th element.
    The length is taken from the first dataset.
    """

    def __init__(self, datasets):
        super().__init__()
        self.datasets = datasets

    def __len__(self):
        first = self.datasets[0]
        return len(first)

    def __getitem__(self, idx):
        items = []
        for ds in self.datasets:
            items.append(ds[idx])
        return items
# optionally, take subset of training data (first 200 images, in order)
dataset = torch.utils.data.Subset(train_dataset, torch.arange(0, 200))
# dataset = train_dataset
dl = DataLoader(dataset, batch_size=batch_size) # no shuffle: keep order aligned with predictions
# Record the finetuned model's predicted class per image; explanations below
# are computed w.r.t. the predicted (not ground-truth) label.
vgg16_1.eval()
predictions = []
with torch.inference_mode():
    for i, (imgs, labels) in enumerate(tqdm(dl,desc='getting predictions')):
        outputs = vgg16_1(imgs.to(dev))
        predictions.extend(outputs.argmax(1).tolist())
# Pair each (image, label) sample with its prediction, index-aligned.
dl = DataLoader(CombinedDataset([dataset, predictions]), batch_size=batch_size) # no shuffle
# get explanations: absolute simple-gradient saliency of each image w.r.t.
# the model's own prediction, detached for storage.
expls_original = []
for i, ((imgs, labels), preds) in enumerate(tqdm(dl, desc='getting explanations')):
    imgs, preds = imgs.to(dev), preds.to(dev)
    # One image at a time: get_simple_gradient_expl expects a single (1,C,H,W) input.
    expl_batch = torch.stack([get_simple_gradient_expl(vgg16_1, imgs[i].unsqueeze(0), preds[i],True) for i in range(len(imgs))])
    expls_original.extend([expl.detach() for expl in expl_batch])
# Binary masks marking each explanation's TOPK highest-attribution pixels.
# FIX: the original referenced expls_original[0] for the mask's shape/zeros;
# that only worked because every explanation happens to share one shape —
# use the current `expl` so each mask is tied to its own explanation.
topk_masks = []
for expl in expls_original:
    flat = expl.view(-1)
    topk_indices = flat.argsort(descending=True)[:TOPK]
    mask = torch.zeros_like(flat, dtype=torch.long)
    mask[topk_indices] = 1
    topk_masks.append(mask.view(expl.shape))
# Finetune a copy of the model so that attribution mass moves away from the
# originally most-important pixels, while the CE loss keeps accuracy intact.
EPOCHS = 10
BATCH_SIZE=8
dl = DataLoader(CombinedDataset([dataset, predictions, expls_original, topk_masks]), batch_size=BATCH_SIZE, shuffle=True)
import copy
vgg16_2 = get_vgg_model()
vgg16_2.load_state_dict(copy.deepcopy(vgg16_1.state_dict()))
criterion = nn.CrossEntropyLoss()
optimizer2 = optim.Adam(vgg16_2.parameters(), lr=LR)
vgg16_2.train()
for epoch in range(EPOCHS):
    total_losses = []
    ce_losses = []
    expl_losses = []
    # BUG FIX: the loop variable was named `topk_masks`, shadowing (and, on the
    # first iteration, clobbering) the module-level list of masks built above;
    # renamed to `mask_batch`.
    for i, ((imgs, labels), preds, expls, mask_batch) in enumerate(tqdm(dl, desc='manipulating vgg16_2')):
        optimizer2.zero_grad()
        imgs, labels, preds = imgs.to(dev), labels.to(dev), preds.to(dev)
        output = vgg16_2(imgs)
        ce_loss = criterion(output, labels)
        # Differentiable explanations under the current (manipulated) weights.
        fooled_expls = torch.stack([get_simple_gradient_expl(vgg16_2, imgs[i].unsqueeze(0), preds[i], True) for i in range(len(imgs))])
        # Penalize attribution that remains on the originally top-TOPK pixels.
        loss_expl = (fooled_expls * mask_batch).sum() / TOPK
        loss_expl = LAMBDA * loss_expl
        total_loss = ce_loss + loss_expl
        total_loss.backward()
        optimizer2.step()
        total_losses.append(total_loss.item())
        ce_losses.append(ce_loss.item())
        expl_losses.append(loss_expl.item())
    # print average losses in this epoch
    print(f'Epoch {epoch}\t \ttotal:{torch.tensor(total_losses).mean().item():.3f}\tce loss:{torch.tensor(ce_losses).mean().item():.3f}\texpl loss:{torch.tensor(expl_losses).mean().item():.3f}')
# get manipulated explanations: same procedure as for expls_original, but
# using the manipulated model vgg16_2 (same order, so indices stay aligned).
dl = DataLoader(CombinedDataset([dataset, predictions]), batch_size=batch_size) # no shuffle
expls_manipulated = []
for i, ((imgs, labels), preds) in enumerate(tqdm(dl, desc='getting explanations')):
    imgs, preds = imgs.to(dev), preds.to(dev)
    expl_batch = torch.stack([get_simple_gradient_expl(vgg16_2, imgs[i].unsqueeze(0), preds[i],True) for i in range(len(imgs))])
    expls_manipulated.extend([expl.detach() for expl in expl_batch])
# simple vis: original image | original explanation | manipulated explanation.
# vis_dataset indices align with `dataset` because the subset is arange(0, 200).
vis_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())# no normalization
for i in range(10):
    fig, axs = plt.subplots(1, 3)
    axs[0].imshow(vis_dataset[i][0].permute(1,2,0))
    axs[0].set_axis_off()
    axs[1].imshow(expls_original[i].cpu())
    axs[1].set_axis_off()
    axs[2].imshow(expls_manipulated[i].cpu())
    axs[2].set_axis_off()
    # NOTE(review): no plt.show()/savefig — figures only render in an
    # interactive/notebook session; confirm the intended run context.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment