import torch
from torch import nn
from torch.nn import functional as F


class EasyDataParallel(nn.Module):
    def __init__(self, gpus):
        super().__init__()
        # TODO: handle the cpu / single-gpu case better
        assert isinstance(gpus, list)
        assert all(isinstance(gpu, int) for gpu in gpus)
        self.n_gpus = len(gpus)
        self.main_gpu = gpus[0]
        self.gpus = gpus
        # Ignore the bias here for simplicity
        self.lin = nn.Linear(1, 1, bias=False)
        # We manage the moving of the submodules ourselves
        self.lin.to(self.main_gpu)

    def forward(self, input_):
        if self.n_gpus > 1:
            # Split the batch into one chunk per GPU
            inputs = torch.chunk(input_, self.n_gpus, dim=0)
        else:
            inputs = [input_]
        all_outputs = []
        for inp, device in zip(inputs, self.gpus):
            # Note that if they are already on the right device, these are no-ops
            w_d = self.lin.weight.to(device)
            inp_d = inp.to(device)
            # Compute on `device`, then gather the result back onto the main GPU
            all_outputs.append(F.linear(inp_d, w_d).to(self.main_gpu))
        res = torch.cat(all_outputs, dim=0)
        return res


inp = torch.rand(2, 1)

# Single GPU
mod = EasyDataParallel([0])
out = mod(inp)
out.sum().backward()
grad0 = mod.lin.weight.grad

# Two GPUs
mod = EasyDataParallel([0, 1])
out = mod(inp)
out.sum().backward()
grad1 = mod.lin.weight.grad

# A different single GPU
mod = EasyDataParallel([1])
out = mod(inp)
out.sum().backward()
grad3 = mod.lin.weight.grad.to(0)

# The gradients should match regardless of device placement
print((grad0 - grad1).abs().max())
print((grad0 - grad3).abs().max())
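
# --- A minimal reference check (an illustrative addition, not part of the original gist) ---
# Because out.sum() is linear in the weight, d(out.sum())/d(weight) is simply the sum of
# the inputs and does not depend on the (randomly initialized) weight values, so a plain
# CPU nn.Linear should reproduce the same gradient as the GPU variants above.
ref = nn.Linear(1, 1, bias=False)
ref(inp).sum().backward()
print((grad0.cpu() - ref.weight.grad).abs().max())  # expected to be ~0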