@mrdrozdov
Last active September 10, 2018 20:04
multi-fail.py
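import torch
import numpy as np


class Model(torch.nn.Module):
    def forward(self, x):
        # x arrives as a NumPy array; convert it to a float tensor on the current CUDA device.
        return torch.from_numpy(x).float().cuda()


ngpus = 2
x = np.ones((10, 10))
m = Model()
m.cuda()
# data_parallel splits only tensors along dim 0. A NumPy array is not a tensor,
# so every replica receives the full (10, 10) array and the gathered output doubles.
out = torch.nn.parallel.data_parallel(m, x, range(ngpus))
print(x.shape)  # (10, 10)
print(out.shape)  # (20, 10)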
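The doubled batch comes from how the scatter step treats its inputs. The sketch below is a simplified, CPU-only model of that step, not PyTorch's actual implementation: the helpers `toy_scatter` and `toy_data_parallel` are hypothetical and exist only to show the dispatch rule, assuming that tensors are chunked along dim 0 while any other object is copied to every replica before the outputs are concatenated.

```python
import numpy as np
import torch


def toy_scatter(obj, n):
    # Hypothetical stand-in for the scatter step: tensors are split along
    # dim 0, anything else (e.g. a NumPy array) is handed whole to each replica.
    if isinstance(obj, torch.Tensor):
        return list(torch.chunk(obj, n, dim=0))
    return [obj for _ in range(n)]


def toy_data_parallel(fn, x, n):
    # Scatter, run each "replica" sequentially on the CPU, then gather on dim 0.
    outs = [fn(chunk) for chunk in toy_scatter(x, n)]
    return torch.cat(outs, dim=0)


def forward(x):
    # Accept either a NumPy array or a tensor, mirroring Model.forward without .cuda().
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    return x.float()


x = np.ones((10, 10))
print(toy_data_parallel(forward, x, 2).shape)                    # torch.Size([20, 10])
print(toy_data_parallel(forward, torch.from_numpy(x), 2).shape)  # torch.Size([10, 10])
```

With the real API, the analogous change is to convert `x` to a tensor before calling `data_parallel` (and have `forward` accept a tensor), so the batch is split across the GPUs instead of being duplicated.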