Skip to content

Instantly share code, notes, and snippets.

@iacolippo
Last active September 4, 2017 15:21
Show Gist options
  • Save iacolippo/df46eebd6d7ea20402e87229e7258a7a to your computer and use it in GitHub Desktop.
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import torch
import torch.legacy.nn as lnn
import torch.legacy.optim as loptim
# MNIST train and test splits, both stored under ./data and converted to
# tensors with ToTensor.
# Only the training split passes download=True; the test split presumably
# relies on the files that first call fetched, so keep this order.
train_dataset = MNIST(
    root='./data',
    download=True,
    train=True,
    transform=transforms.ToTensor(),
)
test_dataset = MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)
# Mini-batch size shared by the train and test loaders.
batch_size = 100
# Number of mini-batches per pass over the training set. Derived from the
# dataset size and batch_size rather than the old hard-coded 60000/100, which
# duplicated batch_size and evaluated to a float (600.0) under Python 3 true
# division. Ceiling division also counts a trailing partial batch, which the
# DataLoader still yields because drop_last is not set.
# NOTE(review): n_batches is not referenced elsewhere in the visible script.
n_batches = (len(train_dataset) + batch_size - 1) // batch_size
# Training batches are shuffled; test batches keep dataset order.
# NOTE(review): test_loader is not used in the visible script.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Fully connected classifier built with legacy nn: 784 -> 200 -> 100 -> 10,
# ReLU after each hidden Linear, and LogSoftMax on the output so the NLL
# criterion below receives log-probabilities.
model = lnn.Sequential()
for layer in (
    lnn.Linear(784, 200),
    lnn.ReLU(),
    lnn.Linear(200, 100),
    lnn.ReLU(),
    lnn.Linear(100, 10),
    lnn.LogSoftMax(),
):
    model.add(layer)
# Negative log-likelihood over class indices, paired with LogSoftMax above.
criterion = lnn.ClassNLLCriterion()
# Training hyperparameters, previously inlined as magic numbers.
n_epochs = 2000     # full passes over train_loader
learning_rate = 1   # step size handed to model.updateParameters
                    # NOTE(review): 1 is a large SGD step; confirm intended.
log_every = 200     # print the loss once every this many epochs

# NOTE(review): the published source lost all indentation, so the loop as
# shown raised IndentationError. This reconstruction assumes one parameter
# update per mini-batch inside the inner loop, and a logging check once per
# epoch after it. Confirm against the original gist.
for epoch in range(n_epochs):
    for images, labels in train_loader:
        # Flatten each image to a 784-vector to match lnn.Linear(784, 200).
        images = images.view(images.size(0), 28 * 28)
        # Reset parameter gradients left over from the previous step.
        model.zeroGradParameters()
        output = model.forward(images)
        loss = criterion.forward(output, labels)
        # Gradient of the loss with respect to the model output...
        grad_output = criterion.backward(output, labels)
        # ...propagated back through the network; legacy backward takes the
        # original input together with that gradient. Its return value (the
        # gradient w.r.t. the input) was never used, so it is not bound.
        model.backward(images, grad_output)
        model.updateParameters(learning_rate)
    if epoch % log_every == 0:
        # Reports the loss of the last mini-batch of this epoch.
        print("Error:" + str(loss))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment