A hello world example for PyTorch implementing a ConvNet for MNIST
"""A basic PyTorch ConvNet implementation on the MNIST dataset
From http://adventuresinmachinelearning.com/convolutional-neural-networks-tutorial-in-pytorch/
"""
import os
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets
import torch
import torch.nn as nn
# One epoch is a presentation of all training data to the network
num_epochs = 5
# MNIST has 10 output classes
num_classes = 10
# One batch is averaged to compute a loss gradient
batch_size = 100
# Learning rate
learning_rate = 0.001
DATA_PATH = "."
MODEL_STORE_PATH = "."
# transforms to apply to the data
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
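# 0.1307 and 0.3081 are the commonly quoted mean and standard deviation of the
# MNIST training images, so the normalised inputs are roughly zero-mean, unit-variance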
# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root=DATA_PATH, train=True, transform=trans, download=True)
test_dataset = torchvision.datasets.MNIST(root=DATA_PATH, train=False, transform=trans)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
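# Optional sanity check (an addition, not part of the original tutorial): pull
# one batch to confirm the expected shapes - 100 greyscale 28x28 images and
# 100 integer labels per batch
sample_images, sample_labels = next(iter(train_loader))
print(sample_images.shape)  # torch.Size([100, 1, 28, 28])
print(sample_labels.shape)  # torch.Size([100])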
class MyConvNet(nn.Module):
    def __init__(self):
        """Constructor"""
        super(MyConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.drop_out = nn.Dropout()
        self.fc1 = nn.Linear(7 * 7 * 64, 1000)
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        """Forward data flow"""
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.drop_out(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out
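# Optional shape check (an addition, not part of the original tutorial): a dummy
# forward pass through the two conv blocks shows why fc1 expects 7 * 7 * 64
# features - each 2x2 max-pool halves the 28x28 input (28 -> 14 -> 7) and
# layer2 outputs 64 channels
with torch.no_grad():
    _shape_check = MyConvNet()
    _dummy = torch.zeros(1, 1, 28, 28)
    assert _shape_check.layer2(_shape_check.layer1(_dummy)).shape == (1, 64, 7, 7)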
model = MyConvNet()
# Loss function
criterion = nn.CrossEntropyLoss()
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
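# nn.CrossEntropyLoss combines log-softmax and negative log-likelihood, which
# is why forward() returns raw logits and no explicit softmax layer is needed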
# Train the model
print("Training...")
total_step = len(train_loader)
loss_list = []
acc_list = []
for epoch in range(num_epochs):
    # Present all training data
    for i, (images, labels) in enumerate(train_loader):
        # Run forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss_list.append(loss.item())

        # Back-propagate and perform Adam optimisation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track the accuracy
        total = labels.size(0)
        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        acc_list.append(correct / total)

        if (i + 1) % 100 == 0:
            print(
                "Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%".format(
                    epoch + 1,
                    num_epochs,
                    i + 1,
                    total_step,
                    loss.item(),
                    (correct / total) * 100
                )
            )
# Test the model
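# eval() switches the Dropout layer to inference mode (it becomes a no-op), so
# the test pass below is deterministic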
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print("Test accuracy of the model on the 10000 test images: {}%".format(
        correct / total * 100
    ))
# Save the model
torch.save(
    model.state_dict(),
    os.path.join(MODEL_STORE_PATH, "my_conv_net_model.ckpt")
)
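# Sketch (an addition, not part of the original tutorial): the checkpoint saved
# above can later be restored into a fresh MyConvNet instance for inference
restored_model = MyConvNet()
restored_model.load_state_dict(
    torch.load(os.path.join(MODEL_STORE_PATH, "my_conv_net_model.ckpt"))
)
restored_model.eval()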
# Plot the results
from bokeh.plotting import figure
from bokeh.io import show
from bokeh.models import LinearAxis, Range1d
import numpy as np
p = figure(
    y_axis_label="Loss",
    width=850,
    y_range=(0, 1),
    title="PyTorch ConvNet results"
)
p.extra_y_ranges = {
    "Accuracy": Range1d(start=0, end=100)
}
p.add_layout(
    LinearAxis(y_range_name="Accuracy", axis_label="Accuracy (%)"),
    "right"
)
p.line(np.arange(len(loss_list)), loss_list)
p.line(
    np.arange(len(loss_list)),
    np.array(acc_list) * 100,
    y_range_name="Accuracy",
    color="red"
)
show(p)