# Created October 7, 2020 05:01
# Save nrupatunga/f9b4d3ad557b79cd48353c715b62ef62 to your computer and use it in GitHub Desktop.
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import transforms | |
from torchvision.datasets import MNIST | |
from torch.utils.data import random_split, DataLoader | |
import pytorch_lightning as pl | |
class LitModel(pl.LightningModule):
    """Fully connected classifier that outputs log-probabilities.

    The network flattens each input, applies two hidden stages of
    ``Linear -> ReLU -> Dropout(0.1)`` of width ``hidden_size``, then projects
    to ``num_classes`` outputs. ``forward`` applies ``log_softmax`` over dim 1,
    and ``training_step`` scores it with the negative log-likelihood loss.

    Args:
        channels, width, height: Input dimensions; their product is the size
            of the flattened vector fed to the first linear layer.
        num_classes: Number of output classes.
        hidden_size: Width of both hidden layers.
        learning_rate: Learning rate handed to Adam.
    """

    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()
        # Input dimensions arrive as parameters so the network can be sized dynamically.
        self.channels = channels
        self.width = width
        self.height = height
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        flat_features = channels * width * height
        stack = [nn.Flatten()]
        # Two identical hidden stages; the first consumes the flattened input,
        # the second maps hidden -> hidden.
        for fan_in in (flat_features, hidden_size):
            stack.extend((nn.Linear(fan_in, hidden_size), nn.ReLU(), nn.Dropout(0.1)))
        stack.append(nn.Linear(hidden_size, num_classes))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return log-softmax class scores (taken over dim 1) for ``x``."""
        scores = self.model(x)
        return F.log_softmax(scores, dim=1)

    def training_step(self, batch, batch_idx):
        """Return the NLL loss for one ``(inputs, targets)`` batch."""
        inputs, targets = batch
        # forward already yields log-probabilities, which is what nll_loss expects.
        log_probs = self(inputs)
        return F.nll_loss(log_probs, targets)

    def configure_optimizers(self):
        """Optimize every parameter with Adam at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
class MNISTDataModule(pl.LightningDataModule):
    """MNIST data module: downloads the data, splits train/val, serves a train loader.

    Args:
        data_dir: Directory MNIST is downloaded into and read from.
        batch_size: Batch size used by ``train_dataloader``. Defaults to 32,
            the value that was previously hard-coded.
    """

    def __init__(self, data_dir: str = './', batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # ToTensor, then normalize with mean 0.1307 / std 0.3081 (single channel).
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download the MNIST train and test splits into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build datasets for ``stage`` ('fit', 'test', or None for both)."""
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # Fixed sizes assume the full train split has 60000 samples.
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Print the current epoch and return the training DataLoader.

        Requires ``setup('fit')`` (or ``setup()``) to have run so that
        ``self.mnist_train`` exists.
        """
        # The epoch lives on the attached trainer; LightningDataModule has no
        # `current_epoch` attribute (the old `self.current_epoch` check raised
        # AttributeError). The trainer may be absent, e.g. outside trainer.fit.
        trainer = getattr(self, 'trainer', None)
        epoch = trainer.current_epoch if trainer is not None else None
        print(f'Current epoch: {epoch}')
        # The former `epoch > 2` branch returned an identical loader on both
        # sides, so it has been collapsed into a single return.
        return DataLoader(self.mnist_train, batch_size=self.batch_size)
if __name__ == '__main__':
    # Init DataModule
    dm = MNISTDataModule()
    # Init model from the datamodule's attributes: dm.size() unpacks dm.dims
    # into (channels, width, height), followed by the class count.
    model = LitModel(*dm.size(), dm.num_classes)
    # Init trainer. Request a GPU only when CUDA is available, so the script
    # also runs on CPU-only machines instead of failing on `gpus=1`.
    # reload_dataloaders_every_epoch makes train_dataloader be re-invoked
    # per epoch (presumably why it prints the epoch).
    trainer = pl.Trainer(
        max_epochs=5,
        progress_bar_refresh_rate=20,
        gpus=1 if torch.cuda.is_available() else None,
        reload_dataloaders_every_epoch=True,
    )
    # Pass the datamodule as arg to trainer.fit to override model hooks :)
    trainer.fit(model, dm)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.