@thomwolf
Last active November 9, 2018 10:39
PyTorch nn.DataParallel
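```python
import torch

# model, optimizer, loss_function, inputs and labels are assumed to be defined elsewhere
parallel_model = torch.nn.DataParallel(model)  # Encapsulate the model

optimizer.zero_grad()                          # Clear gradients left over from the previous step
predictions = parallel_model(inputs)           # Forward pass on multi-GPUs (outputs gathered on the first device)
loss = loss_function(predictions, labels)      # Compute loss function
loss.mean().backward()                         # Average per-GPU losses (no-op if already a scalar) + backward pass
optimizer.step()                               # Optimizer step
predictions = parallel_model(inputs)           # Forward pass with new parameters
```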
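The `loss.mean()` call matters when each replica returns its own loss, so the gathered result holds one value per GPU. A framework-free sketch, assuming a hypothetical toy model `y_hat = w * x` with a mean-squared-error loss and a batch that splits evenly across devices, simulates one such step: the batch is scattered into equal chunks, each simulated replica computes its loss and gradient, and the results are averaged. The helper names (`mse_and_grad`, `scatter`, `data_parallel_step`) are illustrative only, not PyTorch APIs; the real `nn.DataParallel` replicates modules on GPUs and reduces gradients through autograd rather than by hand.

```python
# Pure-Python simulation of one data-parallel training step (no PyTorch needed).
# Hypothetical toy model: y_hat = w * x, trained with mean squared error (MSE).


def mse_and_grad(w, xs, ys):
    """Return (MSE loss, d loss / d w) for y_hat = w * x over one batch."""
    n = len(xs)
    loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
    grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
    return loss, grad


def scatter(xs, ys, n_devices):
    """Split the batch into n_devices equal, contiguous chunks along the batch axis."""
    if len(xs) % n_devices != 0:
        raise ValueError("this sketch assumes the batch divides evenly across devices")
    size = len(xs) // n_devices
    return [(xs[i * size:(i + 1) * size], ys[i * size:(i + 1) * size])
            for i in range(n_devices)]


def data_parallel_step(w, xs, ys, n_devices, lr):
    """One plain-SGD step where each simulated replica sees one chunk of the batch."""
    chunks = scatter(xs, ys, n_devices)
    # Every replica holds an identical copy of w and processes its own chunk.
    results = [mse_and_grad(w, cx, cy) for cx, cy in chunks]
    # Analogue of loss.mean(): average the per-replica losses...
    mean_loss = sum(loss for loss, _ in results) / n_devices
    # ...whose backward pass yields the average of the per-replica gradients.
    mean_grad = sum(grad for _, grad in results) / n_devices
    new_w = w - lr * mean_grad  # analogue of optimizer.step() with plain SGD
    return new_w, mean_loss, mean_grad


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
    ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
    w = 0.5
    new_w, dp_loss, dp_grad = data_parallel_step(w, xs, ys, n_devices=4, lr=0.01)
    full_loss, full_grad = mse_and_grad(w, xs, ys)
    print("data-parallel loss:", dp_loss, "full-batch loss:", full_loss)
    print("data-parallel grad:", dp_grad, "full-batch grad:", full_grad)
    print("updated w:", new_w)
```

With equal chunk sizes, the mean of the per-replica mean losses equals the full-batch mean loss, and the same holds for the gradient. With unequal chunks the two can differ, because every chunk's mean gets the same weight regardless of how many examples it contains.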