Skip to content

Instantly share code, notes, and snippets.

@williamFalcon
Last active March 6, 2020 11:11
Show Gist options
  • Save williamFalcon/b68e8f7786ca2d39d3c96123dd9e44f2 to your computer and use it in GitHub Desktop.
Save williamFalcon/b68e8f7786ca2d39d3c96123dd9e44f2 to your computer and use it in GitHub Desktop.
# model
class Net(LightningModule):
    """Two-layer fully connected classifier trained on MNIST.

    The module bundles the network, the MNIST data pipeline, the optimizer
    configuration and the train/validation steps in one LightningModule.
    """

    def __init__(self):
        # nn.Module must be initialised before submodules can be assigned;
        # without this call the Linear assignments below raise AttributeError.
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of inputs.

        Args:
            x: Input batch; everything after the first dimension is flattened
                and must total 28 * 28 features to match ``layer_1``.

        Returns:
            Tensor with 10 log-probabilities per sample.
        """
        x = x.view(x.size(0), -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        # F.nll_loss (used by the steps below) expects log-probabilities,
        # not raw logits, so normalise here.
        return F.log_softmax(x, dim=1)

    def prepare_data(self):
        """Download MNIST and create the train / validation / test datasets."""
        # Normalisation constants are presumably the MNIST mean/std -- confirm.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        data_root = os.getcwd()
        full_train = MNIST(data_root, train=True, download=True, transform=transform)
        self.mnist_test = MNIST(data_root, train=False, download=True, transform=transform)
        # Hold out 5000 of the training samples for validation.
        self.mnist_train, self.mnist_val = random_split(full_train, [55000, 5000])

    def train_dataloader(self):
        """Return the loader over the training split."""
        # Reshuffle every epoch so batches are not seen in a fixed order.
        return DataLoader(self.mnist_train, batch_size=64, shuffle=True)

    def val_dataloader(self):
        """Return the loader over the validation split."""
        return DataLoader(self.mnist_val, batch_size=64)

    def test_dataloader(self):
        """Return the loader over the test set."""
        return DataLoader(self.mnist_test, batch_size=64)

    def configure_optimizers(self):
        """Return the Adam optimizer and its StepLR scheduler.

        Returns:
            Two lists, ``[optimizer], [scheduler]``. A flat
            ``(optimizer, scheduler)`` tuple would be read by Lightning as a
            sequence of optimizers, so the scheduler must be in its own list.
        """
        optimizer = Adam(self.parameters(), lr=1e-3)
        scheduler = StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Compute the negative log-likelihood loss for one training batch."""
        data, target = batch
        # Call the module rather than .forward() so module hooks run.
        output = self(data)
        loss = F.nll_loss(output, target)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute loss and number of correct predictions for one batch.

        Returns:
            Dict with ``'val_loss'`` (tensor) and ``'correct'`` (int count of
            samples whose arg-max class equals the target).
        """
        data, target = batch
        output = self(data)
        loss = F.nll_loss(output, target)
        pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
        correct = pred.eq(target.view_as(pred)).sum().item()
        return {'val_loss': loss, 'correct': correct}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and log the result.

        Args:
            outputs: List of the dicts returned by ``validation_step``.

        Returns:
            Dict with ``'avg_val_loss'`` and a ``'log'`` sub-dict holding
            ``'val_loss'``.
        """
        # NOTE: the 'correct' counts from validation_step are not aggregated here.
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
if __name__ == '__main__':
    # Script entry point: build the model and fit it with a default Trainer.
    net = Net()
    trainer = Trainer()
    trainer.fit(net)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment