import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import os


class LightningMNISTClassifier(pl.LightningModule):

    def __init__(self):
        super().__init__()
        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # layer 1 (b, 1*28*28) -> (b, 128)
        x = self.layer_1(x)
        x = torch.relu(x)

        # layer 2 (b, 128) -> (b, 256)
        x = self.layer_2(x)
        x = torch.relu(x)

        # layer 3 (b, 256) -> (b, 10)
        x = self.layer_3(x)

        # probability distribution over labels
        x = torch.log_softmax(x, dim=1)

        return x

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


class MNISTDataModule(pl.LightningDataModule):

    def setup(self, stage):
        # transforms standard to MNIST
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])

        self.mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
        self.mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)

    def val_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)


data_module = MNISTDataModule()

# train
model = LightningMNISTClassifier()
trainer = pl.Trainer()
trainer.fit(model, data_module)
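One detail worth noting in the model above: forward ends in log_softmax, so pairing it with F.nll_loss is correct, and the combination is equivalent to calling F.cross_entropy on the raw scores directly. A minimal standalone sketch of that equivalence (the tensor shapes here are illustrative only):

import torch
from torch.nn import functional as F

logits = torch.randn(4, 10)          # (batch, classes) raw scores
labels = torch.randint(0, 10, (4,))  # integer class labels

# log_softmax followed by nll_loss ...
loss_a = F.nll_loss(F.log_softmax(logits, dim=1), labels)
# ... matches cross_entropy applied to the raw scores
loss_b = F.cross_entropy(logits, labels)
assert torch.allclose(loss_a, loss_b)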
In addition to the comment above: is line 74 supposed to say the following?
self.mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
Hi,
I am currently working on this file, but customizing it for my own dataset. I have noticed a few things:
The mnist_test variable (a Dataset) wasn't made an instance variable, so test_dataloader won't be able to access it. Also, in the test_dataloader method there is a comma where a dot should be:
return DataLoader(self,mnist_test, batch_size=64)
=> return DataLoader(self.mnist_test, batch_size=64)
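For reference, a minimal sketch of what the corrected method could look like inside MNISTDataModule (assuming self.mnist_test is assigned in setup, as in the code above):

    def test_dataloader(self):
        # self.mnist_test is an attribute set in setup(), so attribute access
        # (a dot) is needed; a comma would instead pass self and an undefined
        # name mnist_test as two separate arguments
        return DataLoader(self.mnist_test, batch_size=64)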