Skip to content

Instantly share code, notes, and snippets.

@williamFalcon
Created February 27, 2020 14:58
Show Gist options
  • Save williamFalcon/38255319e7c17ceeab78e74a7ad9e6c0 to your computer and use it in GitHub Desktop.
Save williamFalcon/38255319e7c17ceeab78e74a7ad9e6c0 to your computer and use it in GitHub Desktop.
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import os
# -----------------
# MODEL
# -----------------
class LightningMNISTClassifier(pl.LightningModule):
    """Three-layer fully connected classifier for MNIST digits.

    Each image is flattened to a vector, passed through two ReLU-activated
    hidden layers (128 and 256 units) and a 10-unit output layer, and the
    result is returned as log-probabilities over the ten classes.
    """

    def __init__(self):
        super(LightningMNISTClassifier, self).__init__()

        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions flatten to 28 * 28 values, e.g. (b, 1, 28, 28).

        Returns:
            Tensor of shape (b, 10) containing log-softmax outputs.
        """
        # Only the batch dimension is needed to flatten the input.
        batch_size = x.size(0)

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # layer 1
        x = self.layer_1(x)
        x = torch.relu(x)

        # layer 2
        x = self.layer_2(x)
        x = torch.relu(x)

        # layer 3
        x = self.layer_3(x)

        # log-probability distribution over labels (pairs with NLL loss)
        x = torch.log_softmax(x, dim=1)

        return x
# ----------------
# DATA
# ----------------
# Convert images to tensors and normalize with the constants used for MNIST
# in this script (mean 0.1307, std 0.3081).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Datasets are downloaded into (and read from) the current working directory.
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
# The test set must use the same transform as training: without it the
# dataset yields PIL images, which the DataLoader cannot collate into tensors.
mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)

# train (55,000 images), val split (5,000 images)
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

# The dataloaders handle shuffling, batching, etc...
# Only the training loader shuffles; evaluation order does not matter.
mnist_train = DataLoader(mnist_train, batch_size=64, shuffle=True)
mnist_val = DataLoader(mnist_val, batch_size=64)
mnist_test = DataLoader(mnist_test, batch_size=64)
# ----------------
# OPTIMIZER
# ----------------
# Instantiate the classifier defined in this file; it is driven here by a
# plain PyTorch loop rather than a Lightning Trainer.
pytorch_model = LightningMNISTClassifier()
# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(pytorch_model.parameters(), lr=1e-3)
# ----------------
# LOSS
# ----------------
def cross_entropy_loss(logits, labels):
    """Return the mean negative log-likelihood of `labels` under `logits`.

    Despite its name, `logits` is passed straight to `F.nll_loss`, so it is
    expected to already contain log-probabilities (e.g. log-softmax output).
    """
    nll = F.nll_loss(input=logits, target=labels)
    return nll
# ----------------
# TRAINING LOOP
# ----------------
num_epochs = 1

for epoch in range(num_epochs):
    # Training pass: one optimizer update per mini-batch.
    for x, y in mnist_train:
        loss = cross_entropy_loss(pytorch_model(x), y)
        print('train loss: ', loss.item())

        loss.backward()
        optimizer.step()
        # Clear gradients so they do not accumulate into the next batch.
        optimizer.zero_grad()

    # Validation pass: gradient tracking is disabled while scoring batches.
    with torch.no_grad():
        batch_losses = [
            cross_entropy_loss(pytorch_model(x), y).item()
            for x, y in mnist_val
        ]
        # Unweighted mean of the per-batch losses.
        val_loss = torch.tensor(batch_losses).mean()
        print('val_loss: ', val_loss.item())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment