Created January 13, 2021 22:50
# A LightningModule ORGANIZES the PyTorch code into the following modules:
# 1. Computations (init)
# 2. Training loop (training_step)
# 3. Validation loop (validation_step)
# 4. Test loop (test_step)
# 5. Optimizers (configure_optimizers)
##############################################################################
model = FlashModel()
trainer = Trainer()
trainer.fit(model)
### NO .cuda() or .to() calls in PL #######################################
# DO NOT do this with PL -- it hard-codes the device:
x = torch.Tensor(2, 3)
x = x.cuda()
x.to(device)
# INSTEAD DO THIS
x = x  # leave it alone!
# or to init a new tensor, do this ->
new_x = torch.Tensor(2, 3)
new_x = new_x.type_as(x)
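`type_as` copies both the dtype and the device from an existing tensor, which keeps the code device-agnostic. A small sketch with plain CPU tensors; the same call works unchanged when `x` lives on a GPU:

```python
import torch

x = torch.randn(2, 3, dtype=torch.float64)  # stand-in for a tensor PL has already placed
new_x = torch.Tensor(2, 3)                  # defaults to float32 on CPU
new_x = new_x.type_as(x)                    # now matches x's dtype *and* device
print(new_x.dtype)  # torch.float64
```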
############# NO SAMPLERS for distributed ###################################
# DON'T DO THIS
data = MNIST(...)
sampler = DistributedSampler(data)
DataLoader(data, sampler=sampler)
# DO THIS
data = MNIST(...)
DataLoader(data)
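Lightning injects the `DistributedSampler` itself when the Trainer is launched with a distributed strategy, so the dataloader stays a plain one. A sketch, with a dummy tensor dataset standing in for `MNIST(...)`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# dummy stand-in for MNIST(...)
data = TensorDataset(torch.randn(100, 28 * 28), torch.randint(0, 10, (100,)))

# plain DataLoader, no sampler; when the Trainer runs in DDP mode,
# Lightning wraps this with a DistributedSampler behind the scenes
loader = DataLoader(data, batch_size=32, shuffle=True)
```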
############# A LightningModule is a torch.nn.Module with added functionality. Use it as such
model = FlashModel.load_from_checkpoint(PATH)
model.freeze()
out = model(x)
###########################################################################################