Last active November 9, 2018 10:39
Gist thomwolf/3359a5e6d534b97be0cf160fb4f6bbcb
PyTorch nn.DataParallel
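```python
import torch
# model, inputs, labels, loss_function and optimizer are defined elsewhere
model = model.to("cuda:0")                     # DataParallel expects parameters on device_ids[0] (cuda:0 by default)
parallel_model = torch.nn.DataParallel(model)  # Encapsulate the model (replicated on every visible GPU at each forward pass)

predictions = parallel_model(inputs)           # Forward pass on multi-GPUs: inputs split along dim 0, outputs gathered on cuda:0
loss = loss_function(predictions, labels)      # Compute loss function on the gathered predictions
optimizer.zero_grad()                          # Clear gradients left over from earlier backward passes
loss.mean().backward()                         # Average GPU-losses (no-op for a scalar loss) + backward pass
optimizer.step()                               # Optimizer step on the original model's parameters
predictions = parallel_model(inputs)           # Forward pass with new parameters (replicas re-created from the updated model)
```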
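A note on the `.mean()` call: `DataParallel` gathers the replicas' outputs along dimension 0, so if each replica computes and returns its own scalar loss (for example, when the loss is computed inside the model's `forward`), the gathered result is a vector with one loss per GPU, and `.mean()` averages them. The framework-free sketch below checks why that average is sound. It uses hypothetical helper names and a one-parameter linear model chosen purely for illustration, and it assumes a mean-reduced loss and equally sized shards: under those assumptions, the average of the per-shard gradients equals the full-batch gradient. With unequal shards (the last GPU can receive fewer samples), a plain average of per-shard means weights samples unevenly, so the sketch rejects batches that don't split evenly.

```python
# Framework-free sketch with hypothetical helpers: averaging per-device
# mean losses reproduces the full-batch gradient, ASSUMING a mean-reduced
# loss and equally sized shards.


def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)**2) for a toy one-parameter model y_hat = w*x."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def split_batch(xs, ys, num_devices):
    """Split a batch along its first dimension into equal shards, one per simulated device."""
    if len(xs) % num_devices != 0:
        raise ValueError("this sketch assumes the batch splits evenly across devices")
    size = len(xs) // num_devices
    return [(xs[i * size:(i + 1) * size], ys[i * size:(i + 1) * size])
            for i in range(num_devices)]


def data_parallel_grad(w, xs, ys, num_devices):
    """One gradient per simulated device, then their plain average (like loss.mean())."""
    shard_grads = [grad_mse(w, sx, sy) for sx, sy in split_batch(xs, ys, num_devices)]
    return sum(shard_grads) / num_devices


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5

full = grad_mse(w, xs, ys)                               # gradient over the whole batch
parallel = data_parallel_grad(w, xs, ys, num_devices=2)  # two simulated GPUs
print(full, parallel)  # -22.5 -22.5
```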