Last active
June 11, 2023 02:59
-
-
Save vhxs/e510f36c923cfad7e55705488a54a9ce to your computer and use it in GitHub Desktop.
Fastest possible example of GPU training and inference that I could get working
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torchvision | |
import torchvision.transforms as transforms | |
import torch.optim as optim | |
import torch.nn as nn | |
import umap | |
import matplotlib.pyplot as plt | |
# This script is GPU-only: probe for CUDA once and stop immediately when no
# device is present, instead of silently training on the CPU.
is_gpu_available = torch.cuda.is_available()
if not is_gpu_available:
    raise Exception("GPU unavailable")
print(f"GPU is available: {is_gpu_available}")
# Device handle used when moving the model and the data batches with .to(...)
device = torch.device('cuda')
# Define the ResNet architecture. torchvision's ResNet-18 ends in a 1000-way
# classifier (the ImageNet class count). To get something working quickly we
# train on CIFAR-100 instead, which has 100 classes, so the final linear layer
# is replaced by one with 100 outputs.
# resnet18() is called without a weights argument, so no pre-trained weights
# are loaded and training starts from scratch.
num_classes = 100
model = torchvision.models.resnet18()
# Read the input width from the layer being replaced rather than hard-coding
# 512, so the head stays correct if the backbone is swapped for another ResNet.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes, bias=True)
# To use the GPU, the model has to be moved there explicitly with .to
model.to(device)
# Define the loss function and optimizer. The optimizer decides how the
# gradients computed during backpropagation are applied to the parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Preprocessing applied to every CIFAR-100 image: convert the PIL image to a
# PyTorch tensor, then normalize each of the three channels with mean 0.5 and
# standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
# Training split of CIFAR-100, downloaded into ./data when not already there.
training_set = torchvision.datasets.CIFAR100(
    root="data",
    train=True,
    download=True,
    transform=transform,
)
# The dataloader needs a batch size. At training time, the batch size affects
# both performance and accuracy.
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    training_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
# UMAP model used after training to project the recorded weight vectors to 3-D.
reducer = umap.UMAP(n_components=3)
# The training loop.
num_epochs = 20
# Number of batches summarised by each progress line.
log_every = 200
# One flattened copy of the final layer's weights is recorded per batch, so
# this list grows with the total number of training steps.
linear_weights = []
# Put the model in training mode explicitly so the loop behaves the same even
# if it is re-run after model.eval() has been called.
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    running_correct = 0
    running_total = 0
    for i, (inputs, labels) in enumerate(train_loader):
        # Move the batch read from the loader to the GPU.
        inputs, labels = inputs.to(device), labels.to(device)
        # Gradients accumulate across backward() calls; zero them out from the
        # previous iteration.
        optimizer.zero_grad()
        # Snapshot the final linear layer's weights (as they are before this
        # step's update), flattened to one vector. The UMAP reduction itself
        # happens once, after training has finished.
        linear_weights.append(model.fc.weight.detach().cpu().numpy().flatten())
        # The forward pass. Autograd records the operations so that the
        # backward pass can compute gradients.
        outputs = model(inputs)
        # Compute a loss, and use backprop to compute gradients.
        loss = criterion(outputs, labels)
        loss.backward()
        # Use the optimizer to update model weights with the gradients.
        optimizer.step()
        # Accumulate statistics as plain Python numbers. Reducing on the device
        # with .sum() avoids iterating over the tensor element by element, and
        # counting the samples seen keeps the accuracy right for a short batch.
        running_loss += loss.item()
        running_correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
        running_total += labels.size(0)
        # Print the loss and accuracy averaged over the last log_every batches.
        if i % log_every == log_every - 1:
            print(f'[{epoch + 1}, '
                  f'{i + 1:5d}] loss: {running_loss / log_every:.3f}, '
                  f'accuracy: {running_correct / running_total:.3f}')
            running_loss = 0.0
            running_correct = 0
            running_total = 0
print('Finished Training')
# Project the recorded weight vectors down to the reducer's output dimension.
reduced_weights = reducer.fit_transform(linear_weights)
# Evaluate model accuracy on the test set.
# eval() switches layers such as batch norm to their inference behaviour.
model.eval()
test_set = torchvision.datasets.CIFAR100(root="data", download=True, transform=transform, train=False)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
correct = 0
total = 0
# No gradients are needed for evaluation; disabling autograd avoids building a
# graph for every batch, which saves memory and time.
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # The predicted class is the index of the largest output per sample.
        predicted = torch.argmax(outputs, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Test accuracy: {correct / total}')
# Plot the UMAP-reduced weight vectors as a single 3-D curve, connected in the
# order they were recorded.
ax = plt.axes(projection='3d')
# Transpose the sequence of (x, y, z) rows into three coordinate sequences.
x, y, z = zip(*reduced_weights)
ax.plot3D(x, y, z, 'gray')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment