Created March 29, 2020 07:25
Save lezwon/4ab03f09bcffe684c22095fac38cb7d2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
from torch.utils.data import DataLoader | |
from torchvision.datasets import MNIST | |
import torchvision.transforms as transforms | |
from torch.utils.data import random_split | |
import torch.optim as optim | |
from torch.optim.lr_scheduler import StepLR | |
from pytorch_lightning import LightningModule | |
from pytorch_lightning import Trainer | |
from pytorch_lightning.callbacks import EarlyStopping | |
class Net(LightningModule):
    """Convolutional MNIST classifier wrapped as a PyTorch Lightning module.

    All keyword arguments are stored verbatim as instance attributes. The
    methods below read ``BATCH_SIZE`` and ``TEST_BATCH_SIZE`` (dataloader
    batch sizes), ``LR`` (Adadelta learning rate) and ``GAMMA`` (StepLR decay
    factor), so those must be supplied by the caller.
    """

    def __init__(self, **kwargs):
        # Initialise nn.Module / LightningModule first (the conventional
        # order), then attach the user-supplied hyperparameters.
        super(Net, self).__init__()
        self.__dict__.update(kwargs)
        # Retained for backward compatibility only. Accuracy is aggregated
        # from per-batch step outputs instead, because a single counter shared
        # by the training, validation and test loops could mix their counts.
        self.correct_counter = 0
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        # dropout2 acts on the flattened (N, 128) fc1 output. Dropout2d on a
        # 2-D input is deprecated and behaves like plain dropout there, so
        # nn.Dropout is used (same behaviour, no parameters, no warning).
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12: a 28x28 input after two 3x3 convs
        # (28 -> 26 -> 24) and one 2x2 max-pool (24 -> 12).
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (N, 10) for ``x``.

        ``x`` must be (N, 1, 28, 28): conv1 takes one channel and fc1's 9216
        input features fix the spatial size.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def configure_optimizers(self):
        """Build an Adadelta optimizer plus a StepLR(step_size=1) scheduler.

        Returns:
            ``([optimizer], [scheduler])``, the list-pair form Lightning accepts.
        """
        # Optimise this module's own parameters. Referring to a module-level
        # global would silently bind the optimizer to whatever object happens
        # to be named ``model``.
        optimizer = optim.Adadelta(self.parameters(), lr=self.LR)
        scheduler = StepLR(optimizer, step_size=1, gamma=self.GAMMA)
        return [optimizer], [scheduler]

    @staticmethod
    def _mnist_transform():
        """Tensor conversion followed by the fixed normalisation used for MNIST."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def prepare_data(self):
        """Download MNIST to ``../data`` and build the train/val/test datasets.

        The training images are split at random into 50000 training and
        10000 validation samples, so the split differs between runs.
        """
        # NOTE(review): newer Lightning versions recommend assigning dataset
        # attributes in ``setup()``; in distributed runs ``prepare_data`` may
        # only execute on one process — confirm for the targeted version.
        dataset = MNIST(
            '../data',
            train=True,
            download=True,
            transform=self._mnist_transform(),
        )
        self.train_dataset, self.validation_dataset = random_split(
            dataset, [50000, 10000])
        self.test_dataset = MNIST(
            '../data',
            train=False,
            transform=self._mnist_transform(),
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.BATCH_SIZE, shuffle=True
        )

    def val_dataloader(self):
        """Loader over the validation split (order does not affect metrics)."""
        return DataLoader(
            self.validation_dataset,
            batch_size=self.BATCH_SIZE, shuffle=False
        )

    def test_dataloader(self):
        """Loader over the MNIST test set (order does not affect metrics)."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.TEST_BATCH_SIZE, shuffle=False
        )

    @staticmethod
    def _count_correct(y_hat, y):
        """Return how many rows of ``y_hat`` have their argmax equal to ``y``."""
        return (y_hat.argmax(dim=1) == y).sum()

    @staticmethod
    def _summarise(outputs, loss_key):
        """Combine per-batch step outputs into ``(mean loss, accuracy %)``.

        Each output dict must hold ``loss_key`` (batch-mean loss tensor),
        ``'n_correct'`` and ``'n_samples'``. Returns NaN tensors when no
        samples were seen instead of dividing by zero.
        """
        total = sum(out['n_samples'] for out in outputs)
        if total == 0:
            nan = torch.tensor(float('nan'))
            return nan, nan
        # Weight each batch-mean loss by its batch size so a short final batch
        # does not skew the epoch average.
        avg_loss = sum(out[loss_key] * out['n_samples'] for out in outputs) / total
        correct = sum(out['n_correct'] for out in outputs)
        # Cast before dividing: integer-tensor division truncates (or errors)
        # on older PyTorch releases.
        avg_acc = 100.0 * correct.float() / total
        return avg_loss, avg_acc

    def training_step(self, batch, batch_idx):
        """Return the NLL loss of one training batch under the ``'loss'`` key."""
        x, y = batch
        return {'loss': F.nll_loss(self(x), y)}

    def validation_step(self, batch, batch_idx):
        """Return the batch loss plus the counts needed for epoch accuracy."""
        x, y = batch
        y_hat = self(x)
        return {
            'val_loss': F.nll_loss(y_hat, y),
            'n_correct': self._count_correct(y_hat, y),
            'n_samples': y.size(0),
        }

    def validation_epoch_end(self, outputs):
        """Aggregate validation outputs into ``val_loss`` and ``accuracy`` (%)."""
        avg_loss, avg_acc = self._summarise(outputs, 'val_loss')
        self.correct_counter = 0
        return {'val_loss': avg_loss, 'accuracy': avg_acc}

    def test_step(self, batch, batch_idx):
        """Return the batch loss plus the counts needed for test accuracy."""
        x, y = batch
        y_hat = self(x)
        return {
            'test_loss': F.nll_loss(y_hat, y),
            'n_correct': self._count_correct(y_hat, y),
            'n_samples': y.size(0),
        }

    def test_epoch_end(self, outputs):
        """Aggregate test outputs into average loss, a log dict and accuracy (%)."""
        avg_loss, avg_acc = self._summarise(outputs, 'test_loss')
        self.correct_counter = 0
        tensorboard_logs = {'avg_test_loss': avg_loss}
        # 'avg_val_loss' is a misnomer for the test loss, kept so existing
        # consumers of this dict keep working; 'avg_test_loss' is the clear name.
        return {
            'avg_val_loss': avg_loss,
            'avg_test_loss': avg_loss,
            'log': tensorboard_logs,
            'accuracy': avg_acc,
        }

    def on_epoch_start(self):
        # Legacy reset of the compatibility counter; metrics no longer use it.
        self.correct_counter = 0

    def on_pre_performance_check(self):
        # Legacy reset of the compatibility counter; metrics no longer use it.
        self.correct_counter = 0
# Hyperparameters are read by Net's methods as plain attributes.
model = Net(
    BATCH_SIZE=64,
    TEST_BATCH_SIZE=1000,
    LR=1.0,
    GAMMA=0.7,
)

# Monitors the 'val_loss' key returned by Net.validation_epoch_end.
# NOTE(review): with patience=2 and max_epochs=2 early stopping can probably
# never trigger — confirm the intended values.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=2,
    verbose=True,
    mode='min',
)

# Use one GPU when available; otherwise run on CPU (gpus=None) instead of
# failing at start-up on machines without CUDA.
trainer = Trainer(
    gpus=1 if torch.cuda.is_available() else None,
    max_epochs=2,
    early_stop_callback=early_stop_callback,
)
trainer.fit(model)
trainer.test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.