Last active
January 13, 2021 20:51
-
-
Save imflash217/6e412c3bf57f694009aadd3071bbe3c4 to your computer and use it in GitHub Desktop.
PyTorch Lightning Model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch as pt
import pytorch_lightning as pl
#######################################################################
class FlashModel(pl.LightningModule):
    """This defines a MODEL: a bare stack of layers, no training logic.

    Args:
        num_layers: intended number of layers (currently informational only —
            the three layers below are fixed; TODO wire this up if needed).
        input_dim: feature size of the input to ``layer1``.
        hidden_dim: feature size used between the hidden layers.
        output_dim: feature size produced by ``layer3``.
    """

    def __init__(self,
                 num_layers: int = 3,
                 input_dim: int = 128,
                 hidden_dim: int = 128,
                 output_dim: int = 10):
        super().__init__()
        # BUG FIX: the original called pt.nn.Linear() with no arguments,
        # which raises TypeError — in_features/out_features are required.
        # Dimensions are now defaulted, backward-compatible parameters.
        self.layer1 = pt.nn.Linear(input_dim, hidden_dim)
        self.layer2 = pt.nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = pt.nn.Linear(hidden_dim, output_dim)
class FlashModel(pl.LightningModule):
    """This defines a SYSTEM: it composes externally-built sub-modules
    (an encoder and a decoder) rather than constructing layers itself."""

    def __init__(self,
                 encoder: pt.nn.Module = None,
                 decoder: pt.nn.Module = None):
        super().__init__()
        # Hold references to the injected sub-modules; both default to
        # None, so construction succeeds even before they are supplied.
        self.decoder = decoder
        self.encoder = encoder
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
##### INIT ################################################################## | |
class FlashModel(pl.LightningModule):
    """ DON'T DO THIS"""
    # Anti-pattern, kept intentionally broken as the "before" example:
    # - a single opaque `params` bag hides which hyperparameters the
    #   module actually uses (no names, no types, no defaults)
    # - super().__init__() is never called, so the LightningModule base
    #   class is never initialized
    def __init__(self, params):
        self.lr = params.lr
        self.coeff_x = params.coeff_x
class FlashModel(pl.LightningModule):
    """Instead DO THIS: take each hyperparameter as an explicit, typed,
    defaulted argument so the module's configuration is self-documenting.

    Args:
        encoder: sub-module used by the system (default None).
        coeff_x: loss/term coefficient (default 0.2).
        lr: learning rate (default 1e-3).
    """

    def __init__(self,
                 encoder: pt.nn.Module = None,
                 coeff_x: float = 0.2,
                 lr: float = 1e-3):
        # BUG FIX: the original body was `pass`, which silently discarded
        # every argument and skipped base-class initialization — the
        # opposite of what this "DO THIS" example is meant to teach.
        super().__init__()
        self.encoder = encoder
        self.coeff_x = coeff_x
        self.lr = lr
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
####################################################################### | |
## A typical PyTorch Lightning Model looks like this => | |
class FlashModel(pl.LightningModule):
    """Skeleton of a typical LightningModule: the hooks Lightning calls,
    listed in their usual lifecycle order. Each body is a placeholder.
    """
    # BUG FIX: every method below was defined without `self` (e.g.
    # `def __init__(): pass`), so instantiating or calling any of them
    # would raise TypeError. Signatures now follow the Lightning hook
    # conventions (batch, batch_idx / outputs).
    def __init__(self): pass
    def forward(self, x): pass
    def training_step(self, batch, batch_idx): pass
    def training_step_end(self, outputs): pass
    def training_epoch_end(self, outputs): pass
    def validation_step(self, batch, batch_idx): pass
    def validation_step_end(self, outputs): pass
    def validation_epoch_end(self, outputs): pass
    def test_step(self, batch, batch_idx): pass
    def test_step_end(self, outputs): pass
    def test_epoch_end(self, outputs): pass
    def configure_optimizers(self): pass
    def any_other_custom_hooks(self): pass
#######################################################################
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#### FORWARD & TRAINING STEP ########################################################################
class FlashModel(pl.LightningModule):
    """Shows the split between ``forward`` (inference) and
    ``training_step`` (training-only logic)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Use this for inference/predictions.

        BUG FIX: the original signature was ``forward(self, x, ...)`` —
        ``...`` is not a valid parameter, a SyntaxError — and the computed
        embeddings were never returned.
        """
        embeddings = self.encoder(x)
        return embeddings

    def training_step(self, batch, batch_idx):
        """Use this for training only (Lightning passes batch, batch_idx)."""
        x, y = batch
        z = self.encoder(x)
        # When using data-parallel DP/DDP, call self(x) — which routes
        # through forward() — instead of self.encoder(x) directly:
        z = self(x)
        pred = self.decoder(z)
        ...
####################################################################################################
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment