# Gist by @imflash217, created January 13, 2021 22:50
# A LightningModule ORGANIZES the PyTorch code into the following sections:
# 1. Computations (__init__)
# 2. Training loop (training_step)
# 3. Validation loop (validation_step)
# 4. Test loop (test_step)
# 5. Optimizers (configure_optimizers)
##############################################################################
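import torch
from pytorch_lightning import Trainer

model = FlashModel()  # FlashModel: your LightningModule subclass
trainer = Trainer()
# fit() needs training data: define train_dataloader() on the model,
# or pass a DataLoader directly: trainer.fit(model, train_loader)
trainer.fit(model)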
### NO .cuda() or .to() calls in PL: Lightning moves the model and each batch to the right device
# DO NOT do this with PL
x = torch.Tensor(2, 3)
x = x.cuda()
x = x.to(device)
# INSTEAD DO THIS
x = x  # leave it alone!
# or, to init a new tensor, do this ->
new_x = torch.Tensor(2, 3)
new_x = new_x.type_as(x)  # matches x's dtype and device type
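A small plain-PyTorch sketch of the `type_as` pattern, next to the explicit `dtype=`/`device=` form that does the same job. The helper name `make_like` is hypothetical, and the float64 CPU tensor merely stands in for a batch Lightning has already placed on a device. Inside a LightningModule, `self.device` can also be passed as `device=`.

```python
import torch


def make_like(reference: torch.Tensor, *shape: int) -> torch.Tensor:
    """Hypothetical helper: zeros with the dtype and device type of `reference`."""
    # type_as copies the dtype and the device type (CPU vs. CUDA) of `reference`,
    # so the same line works on CPU, a single GPU, or under DDP.
    return torch.zeros(*shape).type_as(reference)


x = torch.randn(2, 3, dtype=torch.float64)  # stands in for a batch Lightning already moved
new_x = make_like(x, 2, 3)

# Equivalent explicit form, reading dtype and device straight from x
also_new = torch.zeros(2, 3, dtype=x.dtype, device=x.device)
```

The explicit form is more verbose but pins the exact device (including its index), whereas `type_as` only matches the device type.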
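############# NO SAMPLERS for distributed: Lightning adds a DistributedSampler for you
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.datasets import MNIST

# DON'T DO THIS
data = MNIST(...)
sampler = DistributedSampler(data)
DataLoader(data, sampler=sampler)
# DO THIS
data = MNIST(...)
DataLoader(data)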
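By default Lightning gives each DDP process its own `DistributedSampler` (the Trainer flag is `replace_sampler_ddp` in 1.x and `use_distributed_sampler` in 2.x). To see what that sampler does, here is a small sketch in plain PyTorch: passing `num_replicas` and `rank` explicitly lets it run without launching any processes, and the 10-element list is a stand-in for MNIST.

```python
from torch.utils.data import DistributedSampler

# Toy stand-in dataset: the sampler only needs len() to build indices
dataset = list(range(10))

# One sampler per (simulated) process; shuffle=False keeps the split readable
shards = [
    list(DistributedSampler(dataset, num_replicas=2, rank=r, shuffle=False))
    for r in range(2)
]
print(shards)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
```

Each rank sees a disjoint slice of the indices, which is the bookkeeping Lightning takes off your hands.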
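############# A LightningModule is a torch.nn.Module with added functionality, so use it as one
model = FlashModel.load_from_checkpoint(PATH)  # PATH: path to a saved .ckpt file
model.freeze()  # sets requires_grad=False on all parameters and switches to eval mode
out = model(x)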
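`freeze()` is a convenience for inference. A rough plain-PyTorch equivalent, as a sketch: the helper name `freeze_like_lightning` and the toy model are hypothetical, and wrapping the forward pass in `torch.no_grad()` is an extra step the gist does not show.

```python
import torch
from torch import nn


def freeze_like_lightning(model: nn.Module) -> nn.Module:
    """Hypothetical plain-PyTorch approximation of LightningModule.freeze()."""
    for param in model.parameters():
        param.requires_grad = False  # no gradients will be computed for the weights
    model.eval()  # dropout / batch-norm layers switch to inference behaviour
    return model


model = freeze_like_lightning(nn.Sequential(nn.Linear(4, 2), nn.Dropout(0.5)))
x = torch.randn(3, 4)
with torch.no_grad():  # also skip building an autograd graph for the activations
    out = model(x)
```

Freezing stops gradient tracking for the weights; `no_grad()` additionally avoids graph bookkeeping during the forward pass.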
###########################################################################################