Fastest possible example of GPU training and inference that I could get working
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import umap
import matplotlib.pyplot as plt
# Check whether GPU is available and define it as a device
is_gpu_available = torch.cuda.is_available()
if is_gpu_available:
    print(f"GPU is available: {is_gpu_available}")
else:
    raise Exception("GPU unavailable")
device = torch.device('cuda')
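# Optional sanity check (an addition, not part of the original gist): report which GPU was
# found and how much memory it has, using torch.cuda's introspection helpers.
print(f"Using {torch.cuda.get_device_name(device)} with "
      f"{torch.cuda.get_device_properties(device).total_memory / 1e9:.1f} GB of memory")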
# Define the ResNet-18 architecture. Since no weights are requested, the model is randomly
# initialized; torchvision's ResNet defaults to ImageNet's 1000 output classes. To get something
# working quickly, we train on CIFAR-100 instead, which has 100 classes, so the final linear
# layer is replaced with one that has only 100 outputs.
# To use the GPU, the model has to be explicitly moved there with .to()
model = torchvision.models.resnet18()
model.fc = torch.nn.Linear(in_features=512, out_features=100, bias=True)
model.to(device)
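# Optional sanity check (an addition, not part of the original gist): confirm the new head has
# 100 outputs and count the trainable parameters.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Classifier outputs: {model.fc.out_features}, trainable parameters: {num_params}")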
# Define the loss function and optimizer. The optimizer decides how the gradients computed
# during backpropagation are used to update the model's weights.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
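# Hedged illustration (an addition, not used by training): CrossEntropyLoss expects raw logits
# of shape (batch, num_classes) and integer class labels of shape (batch,). With all-zero logits
# every class gets probability 1/100, so the loss is exactly ln(100) ≈ 4.605, roughly where an
# untrained model should start.
example_logits = torch.zeros(4, 100)
example_labels = torch.randint(0, 100, (4,))
print(f"Cross-entropy of uniform logits: {criterion(example_logits, example_labels).item():.3f}")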
# Define a PyTorch transform and use it with the CIFAR-100 training set.
# Here, the transform converts PIL images to PyTorch tensors, and normalizes these tensors.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
training_set = torchvision.datasets.CIFAR100(root="data", download=True, transform=transform)
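# Quick look at one sample (an illustrative addition): each CIFAR-100 image comes back as a
# 3x32x32 tensor, and after the Normalize transform its values lie roughly in [-1, 1].
sample_image, sample_label = training_set[0]
print(f"Sample shape: {tuple(sample_image.shape)}, "
      f"value range: [{sample_image.min().item():.2f}, {sample_image.max().item():.2f}], "
      f"label: {sample_label}")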
# The dataloader needs a batch size. At training time, the batch size affects both training
# speed and the accuracy of the resulting model.
batch_size = 32
train_loader = torch.utils.data.DataLoader(training_set, batch_size=batch_size, shuffle=True, num_workers=2)
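# Illustrative addition: with 50,000 training images and a batch size of 32, the loader yields
# 1,563 batches per epoch, each a (32, 3, 32, 32) tensor of images plus a (32,) tensor of labels
# (the last, smaller batch aside).
print(f"{len(training_set)} training images -> {len(train_loader)} batches per epoch")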
# Create a UMAP reducer that will project the flattened weights of the final linear layer
# down to 3 dimensions for plotting.
reducer = umap.UMAP(n_components=3)
# The training loop.
num_epochs = 20
linear_weights = []
for epoch in range(num_epochs):
    running_loss = 0.0
    running_correct = 0
    for i, data in enumerate(train_loader, 0):
        # Read a batch of data from the loader, and move it to the GPU.
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # The optimizer object tracks gradients. Zero them out from the previous iteration.
        optimizer.zero_grad()
        # Snapshot the final linear layer's weights and flatten them; the snapshots are
        # UMAP-reduced after training.
        linear_weights.append(model.fc.weight.detach().cpu().numpy().flatten())
        # The forward pass. PyTorch records the intermediate values the backward pass needs.
        outputs = model(inputs)
        running_correct += torch.eq(torch.argmax(outputs, dim=1), labels).sum().item()
        # Compute a loss, and use backprop to compute gradients.
        loss = criterion(outputs, labels)
        loss.backward()
        # Use the optimizer to update model weights with the gradients.
        optimizer.step()
        # Print the loss and accuracy every 200 batches.
        running_loss += loss.item()
        if i % 200 == 199:
            print(f'[{epoch + 1}, '
                  f'{i + 1:5d}] loss: {running_loss / 200:.3f}, '
                  f'accuracy: {running_correct / (200 * batch_size):.3f}')
            running_loss = 0.0
            running_correct = 0
print('Finished Training')
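# Optional addition (the filename here is arbitrary): persist the trained weights so inference
# can be rerun later without retraining, via torch.save / model.load_state_dict.
torch.save(model.state_dict(), "resnet18_cifar100.pt")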
# UMAP-reduce the recorded weight snapshots: each flattened weight vector becomes one 3D point.
reduced_weights = reducer.fit_transform(linear_weights)
# Evaluate model accuracy on test set
model.eval()
test_set = torchvision.datasets.CIFAR100(root="data", download=True, transform=transform, train=False)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        predicted = torch.argmax(outputs, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Test accuracy: {correct / total}')
# Plot the UMAP-reduced trajectory of the final linear layer's weights over the course of training.
ax = plt.axes(projection='3d')
x, y, z = zip(*([list(vector) for vector in reduced_weights]))
ax.plot3D(x, y, z, 'gray')
plt.show()