@albanD
Created October 29, 2019 16:08
import torch
from torch import nn
from torch.nn import functional as F

class EasyDataParallel(nn.Module):
    def __init__(self, gpus):
        super().__init__()

        # Handle cpu / 1 gpu case better
        assert isinstance(gpus, list)
        assert all([isinstance(gpu, int) for gpu in gpus])

        self.n_gpus = len(gpus)
        self.main_gpu = gpus[0]
        self.gpus = gpus

        # Ignore bias here for simplicity
        self.lin = nn.Linear(1, 1, bias=False)
        # We manage the moving of the submodules
        self.lin.to(self.main_gpu)

    def forward(self, input_):
        if self.n_gpus > 1:
            # Split the batch into one chunk per GPU. torch.chunk makes
            # n_gpus pieces (torch.split would treat n_gpus as a chunk size).
            inputs = torch.chunk(input_, self.n_gpus, dim=0)
        else:
            inputs = [input_]

        all_outputs = []
        for inp, device in zip(inputs, self.gpus):
            # Note that if they are already on the right device, these are no-ops
            w_d = self.lin.weight.to(device)
            inp_d = inp.to(device)

            # Compute each chunk on its device and gather the result on the main GPU
            all_outputs.append(F.linear(inp_d, w_d).to(self.main_gpu))

        res = torch.cat(all_outputs, dim=0)
        return res
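
# Sanity check: run the same input through a single-GPU, a two-GPU, and a
# GPU-1-only configuration and compare the weight gradients. For y = w * x with
# no bias, d(sum(y))/dw = sum(x): the gradient depends only on the input, so it
# should match across configurations (and across different random inits of w).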
inp = torch.rand(2, 1)

# Single GPU: everything runs on device 0
mod = EasyDataParallel([0])
out = mod(inp)
out.sum().backward()
grad0 = mod.lin.weight.grad

# Two GPUs: the batch is split across devices 0 and 1, outputs gathered on device 0
mod = EasyDataParallel([0, 1])
out = mod(inp)
out.sum().backward()
grad1 = mod.lin.weight.grad

# Single GPU again, but on device 1; move the gradient to device 0 for comparison
mod = EasyDataParallel([1])
out = mod(inp)
out.sum().backward()
grad2 = mod.lin.weight.grad.to(0)

# Both differences should be (close to) zero
print((grad0 - grad1).abs().max())
print((grad0 - grad2).abs().max())
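
# For reference, a minimal sketch of the same check with PyTorch's built-in
# nn.DataParallel, which handles the scatter/replicate/gather automatically.
# Assumes at least two visible CUDA GPUs; the names ref, dp and grad_dp below
# are illustrative.
ref = nn.Linear(1, 1, bias=False).to(0)
dp = nn.DataParallel(ref, device_ids=[0, 1])
out = dp(inp.to(0))
out.sum().backward()
grad_dp = ref.weight.grad
# Should also be (close to) zero: the gradient again only depends on the input
print((grad0 - grad_dp).abs().max())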