Skip to content

Instantly share code, notes, and snippets.

@ptrblck
Created April 11, 2018 08:32
Show Gist options
  • Save ptrblck/25197491711e38f0d9cb72fe2d9a3f2a to your computer and use it in GitHub Desktop.
Save ptrblck/25197491711e38f0d9cb72fe2d9a3f2a to your computer and use it in GitHub Desktop.
#Load packages
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable
import torch.nn.functional as F
# Seed torch's global RNG so later weight initialisation and DataLoader
# shuffling are reproducible across runs.
torch.manual_seed(2809)

# MNIST splits, each sample converted to a tensor by ToTensor().
# download=True fetches the data into `root` if it is not already there.
train_dataset = dsets.MNIST(
    root='/root/workspace/data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
# NOTE(review): no download flag here; presumably relies on the train-split
# download above having fetched the test files as well — confirm for the
# installed torchvision version.
test_dataset = dsets.MNIST(
    root='/root/workspace/data',
    train=False,
    transform=transforms.ToTensor(),
)

batch_size = 1000

# Shuffle only the training data; test order is irrelevant for accuracy.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, shuffle=True, batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, shuffle=False, batch_size=batch_size,
)
class FFN(nn.Module):
    """Three-layer fully connected classifier.

    With the default arguments the network maps 28*28 -> 120 -> 84 -> 10,
    applying a non-linearity after each hidden layer and returning raw
    logits from the final layer.

    Args:
        in_features: length of each (already flattened) input vector.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        num_classes: number of output logits.
        activation: element-wise non-linearity applied after each hidden
            layer. Defaults to ``F.relu``; ``torch.tanh`` is a drop-in
            alternative.
    """

    def __init__(self, in_features=28 * 28, hidden1=120, hidden2=84,
                 num_classes=10, activation=F.relu):
        super().__init__()
        # Linear functions. Attribute names are kept as fc1/fc2/fc3 so that
        # state_dict keys stay compatible with previously saved weights.
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)
        # A plain function rather than a submodule, so it adds no parameters
        # and does not appear in state_dict.
        self.activation = activation

    def forward(self, x):
        """Return unnormalised class logits for ``x``.

        Args:
            x: input tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor of logits whose last dimension is ``num_classes``.
            No softmax is applied here.
        """
        out = self.activation(self.fc1(x))
        out = self.activation(self.fc2(out))
        # Linear readout: raw logits.
        return self.fc3(out)
model = FFN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
epochs = 1


def evaluate(net, loader):
    """Return the classification accuracy of ``net`` on ``loader`` in percent.

    Runs with ``net`` in eval mode and without building an autograd graph,
    then restores the module's previous train/eval mode.

    Args:
        net: model mapping flattened 28*28 inputs to class logits.
        loader: iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy as a Python float, or ``nan`` if ``loader`` yields no samples.
    """
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            # Flatten each image into a 28*28 row for the linear layers.
            outputs = net(images.view(-1, 28 * 28))
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            # .item() keeps the running count a plain int, not a tensor.
            correct += (predicted == labels).sum().item()
    net.train(was_training)
    if total == 0:
        # No samples: accuracy is undefined rather than 0%.
        return float('nan')
    return 100.0 * correct / total


for epoch in range(epochs):
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        # Plain tensors suffice; Variable wrapping is no longer needed.
        images = images.view(-1, 28 * 28)

        def closure():
            """Recompute the loss and its gradients for the current batch.

            LBFGS calls the closure (possibly several times per step) to
            re-evaluate the objective.
            """
            # Clear gradients so they are not accumulated across evaluations.
            optimizer.zero_grad()
            outputs = model(images)
            # CrossEntropyLoss combines log-softmax and negative log-likelihood.
            loss = criterion(outputs, labels)
            loss.backward()
            return loss

        # Update parameters.
        loss = optimizer.step(closure)
        # .item() instead of .data[0]: indexing a 0-dim tensor raises in
        # PyTorch >= 0.4.
        print('Epoch: {}, Loss: {}'.format(epoch, loss.item()))

    # Evaluate on the test set once per epoch, after all optimisation steps.
    accuracy = evaluate(model, test_loader)
    print('Epoch: {}, Loss: {}, Accuracy on testset: {}'.format(epoch, loss.item(), accuracy))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment