@williamFalcon
Created July 29, 2019 17:51
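from pytorch_lightning import LightningModule, Trainer
from torch.nn import RNN  # assumes RNN here is torch.nn.RNN


class MyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = RNN(...)
        self.decoder = RNN(...)

    def forward(self, x):
        # The modules are only actually moved on the first forward call;
        # afterwards .cuda() is effectively a no-op because they are
        # already on the correct GPUs
        self.encoder.cuda(0)
        self.decoder.cuda(1)
        # The Trainer does not move the batch, so put the input on GPU 0;
        # torch.nn.RNN returns (output, h_n), so keep only the output
        out, _ = self.encoder(x.cuda(0))
        out, _ = self.decoder(out.cuda(1))
        return out

    # training_step, configure_optimizers and train_dataloader omitted


# Don't pass GPUs to the Trainer: the module places its own layers
model = MyModule()
trainer = Trainer()
trainer.fit(model)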
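For comparison, here is a minimal plain-PyTorch sketch of the same encoder/decoder split across two devices (model parallelism). It does not use Lightning and assumes both halves are torch.nn.RNN layers. The class TwoDeviceSeq2Seq, the helper pick_devices, and the layer sizes are hypothetical, chosen only for illustration. Unlike the snippet above, each half is placed once in __init__ rather than inside forward. The sketch also falls back to the CPU when fewer than two GPUs are visible, so it can run anywhere.

```python
import torch
from torch import nn


class TwoDeviceSeq2Seq(nn.Module):
    """Hypothetical encoder/decoder split across two devices."""

    def __init__(self, input_size, hidden_size, dev0, dev1):
        super().__init__()
        self.dev0 = torch.device(dev0)
        self.dev1 = torch.device(dev1)
        # Place each half once, at construction time
        self.encoder = nn.RNN(input_size, hidden_size, batch_first=True).to(self.dev0)
        self.decoder = nn.RNN(hidden_size, hidden_size, batch_first=True).to(self.dev1)

    def forward(self, x):
        # nn.RNN returns (output, h_n); only the output sequence is passed on
        out, _ = self.encoder(x.to(self.dev0))
        # Hand the activations to the second device before the decoder runs
        out, _ = self.decoder(out.to(self.dev1))
        return out


def pick_devices():
    # Use two GPUs when available, otherwise put both halves on the CPU
    if torch.cuda.device_count() >= 2:
        return "cuda:0", "cuda:1"
    return "cpu", "cpu"


dev0, dev1 = pick_devices()
model = TwoDeviceSeq2Seq(input_size=8, hidden_size=16, dev0=dev0, dev1=dev1)
x = torch.randn(4, 10, 8)  # (batch, seq_len, features)
y = model(x)
y.sum().backward()  # gradients flow back across the device boundary
print(y.shape, y.device)
```

Autograd records each .to() call as an ordinary operation. As a result, the backward pass routes gradients from the decoder's device back to the encoder's device without any extra code.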