PyTorch MNIST example — GitHub gist kdubovikov/eb2a4c3ecadd5295f68c126542e59f0a, created June 18, 2017. Save this gist to your computer or open it in GitHub Desktop.
Note: this file may contain hidden or bidirectional Unicode text that could be interpreted or compiled differently than it appears; review it in an editor that reveals hidden Unicode characters.
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
from torchvision import datasets, transforms | |
from torch.autograd import Variable | |
# Shared MNIST preprocessing: convert to tensor, then standardize with the
# dataset's per-channel mean/std.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),                       # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,)),  # normalize inputs
])

# download and transform train dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../mnist_data',
                   download=True,
                   train=True,
                   transform=mnist_transform),
    batch_size=10,
    shuffle=True)

# download and transform test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../mnist_data',
                   download=True,
                   train=False,
                   transform=mnist_transform),
    batch_size=10,
    shuffle=True)
class CNNClassifier(nn.Module):
    """Simple convnet classifier for 28x28x1 MNIST digits.

    Architecture: conv(5x5, 10) -> maxpool(2) -> conv(5x5, 20) + dropout
    -> maxpool(2) -> fc(320 -> 50) -> dropout -> fc(50 -> 10) -> log-softmax.
    """

    def __init__(self):
        super(CNNClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.dropout = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10).

        Do not be afraid of F's - those are just functional wrappers for
        modules from the nn package:
        http://pytorch.org/docs/_modules/torch/nn/functional.html
        """
        # input is 28x28x1
        # conv1(kernel=5, filters=10) 28x28x10 -> 24x24x10
        # max_pool(kernel=2) 24x24x10 -> 12x12x10
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # conv2(kernel=5, filters=20) 12x12x20 -> 8x8x20
        # max_pool(kernel=2) 8x8x20 -> 4x4x20
        x = F.relu(F.max_pool2d(self.dropout(self.conv2(x)), 2))
        # flatten 4x4x20 = 320
        x = x.view(-1, 320)
        # 320 -> 50
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        # 50 -> 10
        x = self.fc2(x)
        # FIX: log_softmax without an explicit `dim` is deprecated and
        # ambiguous; class scores live along dim 1.
        return F.log_softmax(x, dim=1)
# Instantiate the classifier and a momentum-SGD optimizer over its weights.
clf = CNNClassifier()
opt = optim.SGD(clf.parameters(), lr=0.01, momentum=0.5)

# Per-batch training losses and per-epoch test accuracies (for plotting).
loss_history = []
acc_history = []
def train(epoch):
    """Run one training epoch over train_loader, recording per-batch loss.

    Args:
        epoch: epoch index (currently unused inside; kept for the caller).
    """
    clf.train()  # set model in training mode (need this because of dropout)
    # dataset API gives us pythonic batching.
    # FIX: Variable wrapping is obsolete since PyTorch 0.4 — the tensors
    # from the DataLoader can be fed to the model directly.
    for batch_id, (data, target) in enumerate(train_loader):
        # forward pass, calculate loss and backprop!
        opt.zero_grad()
        preds = clf(data)
        loss = F.nll_loss(preds, target)
        loss.backward()
        # FIX: loss.data[0] indexes a 0-d tensor and raises on modern
        # PyTorch; .item() extracts the Python scalar portably.
        loss_history.append(loss.item())
        opt.step()
        if batch_id % 100 == 0:
            print(loss.item())
def test(epoch):
    """Evaluate on test_loader; print and record average loss and accuracy.

    Args:
        epoch: epoch index (unused; kept for signature symmetry with train).
    """
    clf.eval()  # set model in inference mode (need this because of dropout)
    test_loss = 0
    correct = 0
    # FIX: Variable(..., volatile=True) was removed in PyTorch 0.4;
    # torch.no_grad() is the supported way to disable autograd in eval.
    with torch.no_grad():
        for data, target in test_loader:
            output = clf(data)
            # FIX: .data[0] -> .item() (0-d tensor indexing raises >= 0.4)
            test_loss += F.nll_loss(output, target).item()
            pred = output.data.max(1)[1]  # get the index of the max log-probability
            # .item() keeps `correct` a Python int so the accuracy division
            # below is true float division, not tensor integer division.
            correct += pred.eq(target.data).cpu().sum().item()
    test_loss /= len(test_loader)  # loss function already averages over batch size
    accuracy = 100. * correct / len(test_loader.dataset)
    acc_history.append(accuracy)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        accuracy))
# Train for three epochs, evaluating on the held-out test set after each one.
for epoch in range(3):
    print("Epoch %d" % epoch)
    train(epoch)
    test(epoch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Initialize your weights please.