Test partial gradient stopping in PyTorch
# coding: utf-8

"""
Setup via

> pip install torch torchvision
"""

from __future__ import annotations

from contextlib import contextmanager

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
batch_size = 512

train_loader = torch.utils.data.DataLoader(
    dataset=dsets.MNIST(
        root="./data",
        train=True,
        transform=transforms.ToTensor(),
        download=True,
    ),
    batch_size=batch_size,
    shuffle=True,
)

valid_loader = torch.utils.data.DataLoader(
    dataset=dsets.MNIST(
        root="./data",
        train=False,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=False,
)
@contextmanager
def empty_context():
    # no-op context manager, used when the pre-NN gradients should flow normally
    yield
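# Note (addition, not in the original gist): on Python 3.7+ the standard library
# provides contextlib.nullcontext, which behaves exactly like empty_context and
# could be used instead:
#
#   from contextlib import nullcontext
#
#   context = torch.no_grad if stop_pre_gradients else nullcontext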
class NN(nn.Module):

    def __init__(self, *, n_in: int | None, latent_space: list[int], n_out: int | None):
        super().__init__()

        self.n_layers = len(latent_space)

        # linear layers, each fed by the previous layer's width (or n_in for the first one)
        for i, n_units in enumerate(latent_space):
            in_features = latent_space[i - 1] if i > 0 else (n_in if n_in is not None else n_units)
            setattr(self, f"linear_{i}", nn.Linear(in_features, n_units))

        # activations
        for i in range(self.n_layers):
            setattr(self, f"activation_{i}", nn.Tanh())

        # optional output layer
        self.last_layer = None
        if n_out is not None:
            self.last_layer = nn.Linear(latent_space[-1], n_out)

    def forward(self, x):
        out = x
        for i in range(self.n_layers):
            linear = getattr(self, f"linear_{i}")
            act = getattr(self, f"activation_{i}")
            out = act(linear(out))

        # optional last layer
        if self.last_layer is not None:
            out = self.last_layer(out)

        return out
class CombinedNN(NN):

    def __init__(self, *, pre_latent_space: list[int], latent_space: list[int]):
        super().__init__(n_in=pre_latent_space[-1], latent_space=latent_space, n_out=10)

        # preprocessing NN
        self.pre_nn = NN(n_in=28 * 28, latent_space=pre_latent_space, n_out=None)

    def forward(self, x, stop_pre_gradients: bool = False):
        # evaluate the pre NN, with or without gradients
        context = torch.no_grad if stop_pre_gradients else empty_context
        with context():
            out = self.pre_nn(x)

        # normal forward pass of this nn
        return super().forward(out)
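# Sketch (addition, not in the original gist and not used below): the same
# partial gradient stopping can be achieved by detaching the pre-NN output
# instead of wrapping its forward pass in torch.no_grad(). Unlike no_grad(),
# this still records the graph inside pre_nn and only cuts the connection
# afterwards, so it is slightly more expensive but leads to the same updates.
class DetachingCombinedNN(CombinedNN):

    def forward(self, x, stop_pre_gradients: bool = False):
        out = self.pre_nn(x)
        if stop_pre_gradients:
            # cut the graph so no gradients flow back into the pre-NN
            out = out.detach()
        # normal forward pass of the head layers (NN.forward)
        return NN.forward(self, out)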
# define model, loss function and optimizer
model = CombinedNN(pre_latent_space=[32], latent_space=[10])
model_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# training loop
step = 0
for epoch in range(100):
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()

        # forward pass, enable the pre-NN training only after a certain point!
        outputs = model(
            images.view(-1, 28 * 28).requires_grad_(),
            stop_pre_gradients=step < 1000,
        )

        # loss, back-prop and update step
        loss = model_loss(outputs, labels)
        loss.backward()
        optimizer.step()

        # validation
        if step % 200 == 0:
            correct = 0
            total = 0
            with torch.no_grad():
                for images, labels in valid_loader:
                    predicted = torch.max(model(images.view(-1, 28 * 28)), 1)[1]
                    correct += (predicted == labels).sum().item()
                    total += labels.size(0)
            accuracy = 100 * correct / total
            print(f"step {step}, training loss: {loss.item()}, valid accuracy: {accuracy:.2f}%")

        step += 1
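# Quick check (addition, not in the original gist): verify on a fresh model that
# a backward pass with stop_pre_gradients=True leaves the pre-NN parameters
# without gradients while the head parameters do receive them.
check_model = CombinedNN(pre_latent_space=[32], latent_space=[10])
check_x = torch.randn(4, 28 * 28)
check_y = torch.randint(0, 10, (4,))
model_loss(check_model(check_x, stop_pre_gradients=True), check_y).backward()
assert check_model.pre_nn.linear_0.weight.grad is None
assert check_model.linear_0.weight.grad is not None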