Last active
January 14, 2021 21:50
-
-
Save imflash217/f2c6e41db9b1013b35210d42f0713ab7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch | |
import torch.nn.Functional as F | |
from torchvision import datasets, transforms | |
from torch.utils.data import DataLoader | |
import pytorch_lightning as pl | |
########################################################################################### | |
class FlashModel(pl.LightningModule): | |
"""DOCSTRING""" | |
def __init__(self): | |
super().__init__() | |
self.layer1 = torch.Linear(28*28, 10) | |
def forward(self, x): | |
out = self.layer1(x.view(x.shape[0], -1)) | |
out = self.relu(out) | |
return out | |
def training_step(self, batch, batch_idx): | |
x, y = batch | |
y_hat = self(x) | |
loss = F.cross_entropy(y_hat, y) | |
return loss | |
def configure_optimizers(self): | |
"""Define optimizers and LR schedulers""" | |
return torch.optim.Adam(self.parameters(), lr=0.02) | |
########################################################################################### | |
train_loader = DataLoader(datasets.MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())) | |
model = FlashModel() | |
trainer = pl.Trainer() | |
trainer.fit(model, train_loader) | |
########################################################################################### |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment