Skip to content

Instantly share code, notes, and snippets.

@williamFalcon
Last active March 26, 2020 13:25
Show Gist options
  • Save williamFalcon/a2437938e38bb58f2859383f93954def to your computer and use it in GitHub Desktop.
Save williamFalcon/a2437938e38bb58f2859383f93954def to your computer and use it in GitHub Desktop.
# model
class Net(LightningModule):
    """Two-layer fully connected classifier for 28x28 MNIST digits."""

    def __init__(self):
        # nn.Module must be initialised before any sub-module is assigned
        # to ``self``; otherwise the assignment raises AttributeError.
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Flatten each sample in the batch and return raw class logits."""
        # Collapse everything after the batch dimension into one axis so
        # the input matches layer_1's 28 * 28 input features.
        x = x.view(x.size(0), -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x

    def train_dataloader(self):
        """Return a DataLoader over the MNIST training split.

        The dataset is downloaded into the current working directory if it
        is not already present there.
        """
        mnist_train = MNIST(
            os.getcwd(),
            train=True,
            download=True,
            transform=transforms.ToTensor(),
        )
        return DataLoader(mnist_train, batch_size=64)

    def configure_optimizers(self):
        """Return the Adam optimizer and a per-step StepLR scheduler.

        Both are wrapped in lists: Lightning's documented form for
        "optimizers plus schedulers" is two lists, and a bare
        ``(optimizer, scheduler)`` tuple may be read as two optimizers
        depending on the Lightning version.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Compute the classification loss for one training batch."""
        data, target = batch
        # Call the module itself rather than ``forward`` so hooks run.
        output = self(data)
        # ``forward`` returns unnormalised logits; cross_entropy applies
        # log_softmax internally, whereas nll_loss expects log-probabilities.
        loss = F.cross_entropy(output, target)
        return {'loss': loss}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment