Created
April 11, 2018 08:32
-
-
Save ptrblck/25197491711e38f0d9cb72fe2d9a3f2a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Load packages | |
import torch | |
import torch.nn as nn | |
import torchvision.transforms as transforms | |
import torchvision.datasets as dsets | |
from torch.autograd import Variable | |
import torch.nn.functional as F | |
# Seed the global RNG before anything random happens (weight init,
# DataLoader shuffling) so runs are reproducible.
torch.manual_seed(2809)

# Both MNIST splits live under the same directory; only the training split
# is asked to download, and each sample is converted with ToTensor().
data_root = '/root/workspace/data'
train_dataset = dsets.MNIST(
    root=data_root,
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = dsets.MNIST(
    root=data_root,
    train=False,
    transform=transforms.ToTensor(),
)

# Large mini-batches of 1000 samples; only the training stream is shuffled.
batch_size = 1000
loader_kwargs = dict(batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)
class FFN(nn.Module):
    """Fully connected classifier with layer sizes 784 -> 120 -> 84 -> 10.

    Inputs must already be flattened to 28*28 = 784 features per sample
    (the script's loops do this with ``images.view(-1, 28*28)``). The
    output is raw class scores (logits); no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        # Created in the order fc1, fc2, fc3 so that, under a fixed seed,
        # parameter initialisation consumes random numbers in the same order.
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a batch of flattened images to 10 class logits."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            # ReLU after every hidden layer; swap in e.g. torch.tanh to experiment.
            hidden = F.relu(layer(hidden))
        # Final linear readout, left unnormalised for nn.CrossEntropyLoss.
        return self.fc3(hidden)
model = FFN()
criterion = nn.CrossEntropyLoss()
# L-BFGS may evaluate the objective several times per step, so step() takes
# a closure that recomputes the loss and gradients on demand.
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
epochs = 1

for epoch in range(epochs):
    model.train()
    for images, labels in train_loader:
        # Flatten each image to a 784-dim vector for the fully connected net.
        # Plain tensors track gradients directly; Variable wrappers are not needed.
        images = images.view(-1, 28 * 28)

        def closure():
            """Re-run forward/backward on the current batch and return the loss."""
            # Clear gradients so repeated evaluations do not accumulate.
            optimizer.zero_grad()
            outputs = model(images)
            # CrossEntropyLoss = log-softmax + negative log-likelihood.
            loss = criterion(outputs, labels)
            loss.backward()
            return loss

        # Update parameters; the loss returned by step() is what gets reported.
        loss = optimizer.step(closure)
        # .item() extracts the Python number; indexing a 0-dim tensor
        # (loss.data[0]) raises IndexError on current PyTorch.
        print('Epoch: {}, Loss: {}'.format(epoch, loss.item()))

    # Accuracy on the test set. eval() + no_grad() skip building autograd
    # graphs that are never used for a backward pass.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.view(-1, 28 * 28)
            outputs = model(images)
            # Index of the highest logit is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            # Accumulate a Python int so the accuracy below is a plain float,
            # not a tensor (and not integer-truncated).
            correct += (predicted == labels).sum().item()
    accuracy = 100.0 * correct / total
    print('Epoch: {}, Loss: {}, Accuracy on testset: {}'.format(epoch, loss.item(), accuracy))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment