Bug in all to all
import torch
import torch.nn as nn
import sys
#sys.path.insert(0, '/usr/share/torch-xla-nightly/pytorch/xla/')
import torch_xla.distributed.xla_multiprocessing as xmp


def main(*a):
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    o = torch.randn((8, 1), dtype=torch.float, device=device).reshape((8, 1))
    o *= xm.get_ordinal()
    t = xm.all_to_all(o, split_dimension=0, concat_dimension=1, split_count=8, groups=None)
    xm.mark_step()
    print('INPUT', xm.get_ordinal(), o.T, flush=True)
    xm.rendezvous('hi')
    print('RESULT', xm.get_ordinal(), t, flush=True)
    xm.rendezvous('hi')


if __name__ == '__main__':
    xmp.spawn(main, args=(), nprocs=8)
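Before the output, here is a minimal CPU-only sketch of what this all_to_all should compute, assuming the split/concat semantics of the call above (split each worker's (8, 1) input along dim 0, send chunk i to worker i, concatenate the eight received chunks along dim 1). The simulation below is only an illustration, not part of the repro:

import torch

world_size = 8
base = torch.randn(world_size, 1)
# Mirror `o *= xm.get_ordinal()`: every worker holds the same base vector,
# scaled by its ordinal (which is what the INPUT lines below show).
inputs = [base * k for k in range(world_size)]

def simulated_all_to_all(tensors, split_dim=0, concat_dim=1):
    # Split each worker's tensor into len(tensors) chunks along split_dim;
    # worker d then concatenates chunk d from every source along concat_dim.
    chunks = [t.chunk(len(tensors), dim=split_dim) for t in tensors]
    return [torch.cat([chunks[s][d] for s in range(len(tensors))], dim=concat_dim)
            for d in range(len(tensors))]

for k, r in enumerate(simulated_all_to_all(inputs)):
    print('RESULT', k, r)  # expected: base[k] * [0, 1, ..., 7], shape (1, 8)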
# OUTPUT
$ ./alltoall.sh | sort -n
INPUT 0 tensor([[0., 0., 0., -0., -0., 0., 0., 0.]], device='xla:1')
INPUT 1 tensor([[ 0.3602, 0.6926, 1.2409, -0.4852, -0.2514, 1.0309, 1.5637, 0.2221]], device='xla:0')
INPUT 2 tensor([[ 0.7204, 1.3852, 2.4818, -0.9704, -0.5028, 2.0619, 3.1275, 0.4442]], device='xla:0')
INPUT 3 tensor([[ 1.0806, 2.0779, 3.7226, -1.4556, -0.7542, 3.0928, 4.6912, 0.6663]], device='xla:0')
INPUT 4 tensor([[ 1.4408, 2.7705, 4.9635, -1.9408, -1.0056, 4.1238, 6.2550, 0.8884]], device='xla:0')
INPUT 5 tensor([[ 1.8010, 3.4631, 6.2044, -2.4260, -1.2570, 5.1547, 7.8187, 1.1105]], device='xla:0')
INPUT 6 tensor([[ 2.1612, 4.1557, 7.4453, -2.9112, -1.5084, 6.1857, 9.3825, 1.3326]], device='xla:0')
INPUT 7 tensor([[ 2.5214, 4.8484, 8.6862, -3.3964, -1.7598, 7.2166, 10.9462, 1.5547]], device='xla:0')
RESULT 0 tensor([[0.0000, 0.0000, 0.7204, 1.0806, 1.4408, 1.8010, 2.1612, 2.5214]], device='xla:1')
RESULT 1 tensor([[0.0000e+00, 1.3852e+00, 2.0779e+00, 2.7705e+00, 3.4631e+00, 4.1557e+00, 4.8484e+00, 1.5236e-35]], device='xla:0')
RESULT 2 tensor([[0.0000, 0.0000, 1.2409, 3.7226, 4.9635, 6.2044, 7.4453, 8.6862]], device='xla:0')
RESULT 3 tensor([[-0.0000, 0.0000, -0.9704, -1.4556, -1.9408, -2.4260, -2.9112, -3.3964]], device='xla:0')
RESULT 4 tensor([[-0.0000, 0.0000, -0.5028, -0.7542, -1.0056, -1.2570, -1.5084, -1.7598]], device='xla:0')
RESULT 5 tensor([[0.0000, 0.0000, 1.0309, 3.0928, 4.1238, 5.1547, 6.1857, 7.2166]], device='xla:0')
RESULT 6 tensor([[ 0.0000, 0.0000, 3.1275, 4.6912, 6.2550, 7.8187, 9.3825, 10.9462]], device='xla:0')
RESULT 7 tensor([[0.0000, 0.0000, 0.4442, 0.6663, 0.8884, 1.1105, 1.3326, 1.5547]], device='xla:0')
Oh, actually that's not correct either. RESULT 0 is wonky (position 1 should be 0.3602, not 0), and RESULT 1 is missing its position-1 value, with the tail shifted left by one and a ~1.5e-35 garbage value appended at the end.
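For reference, RESULT 1 should hold the second element of each ordinal's input, i.e. 0.6926 * k for k = 0..7:

  expected: [0.0000, 0.6926, 1.3852, 2.0779, 2.7705, 3.4631, 4.1557, 4.8484]
  observed: [0.0000, 1.3852, 2.0779, 2.7705, 3.4631, 4.1557, 4.8484, 1.5236e-35]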
Getting the correct result with this input (commenting out `o *= xm.get_ordinal()`):
def main(*a):
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    o = torch.tensor([float(xm.get_ordinal()) for _ in range(8)],
                     dtype=torch.float, device=device).reshape((8, 1))
    #o = torch.randn((8, 1), dtype=torch.float, device=device).reshape((8, 1))
    #o *= xm.get_ordinal()
    t = xm.all_to_all(o, split_dimension=0, concat_dimension=1, split_count=8, groups=None)
    xm.mark_step()
    print('INPUT', xm.get_ordinal(), o.T, flush=True)
    xm.rendezvous('hi')
    print('RESULT', xm.get_ordinal(), t, flush=True)
    xm.rendezvous('hi')
INPUT 0 tensor([[0., 0., 0., 0., 0., 0., 0., 0.]], device='xla:1')
INPUT 1 tensor([[1., 1., 1., 1., 1., 1., 1., 1.]], device='xla:0')
INPUT 2 tensor([[2., 2., 2., 2., 2., 2., 2., 2.]], device='xla:0')
INPUT 3 tensor([[3., 3., 3., 3., 3., 3., 3., 3.]], device='xla:0')
INPUT 4 tensor([[4., 4., 4., 4., 4., 4., 4., 4.]], device='xla:0')
INPUT 5 tensor([[5., 5., 5., 5., 5., 5., 5., 5.]], device='xla:0')
INPUT 6 tensor([[6., 6., 6., 6., 6., 6., 6., 6.]], device='xla:0')
INPUT 7 tensor([[7., 7., 7., 7., 7., 7., 7., 7.]], device='xla:0')
RESULT 0 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:1')
RESULT 1 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')
RESULT 2 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')
RESULT 3 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')
RESULT 4 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')
RESULT 5 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')
RESULT 6 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')
RESULT 7 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')
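This is the expected transpose-like behavior: row i of ordinal k's input ends up at position k of ordinal i's result.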
Marking step before the all_to_all, the ordinal-squared inputs also give the correct result (as opposed to the failing run above):
def main(*a):
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    o = torch.tensor([float(xm.get_ordinal()) for _ in range(8)],
                     dtype=torch.float, device=device).reshape((8, 1))
    #o = torch.randn((8, 1), dtype=torch.float, device=device).reshape((8, 1))
    o *= xm.get_ordinal()
    xm.mark_step()  # <- ADDED THIS
    t = xm.all_to_all(o, split_dimension=0, concat_dimension=1, split_count=8, groups=None)
    xm.mark_step()
    print('INPUT', xm.get_ordinal(), o.T, flush=True)
    xm.rendezvous('hi')
    print('RESULT', xm.get_ordinal(), t, flush=True)
    xm.rendezvous('hi')
$ ./alltoall.sh | sort -n
INPUT 0 tensor([[0., 0., 0., 0., 0., 0., 0., 0.]], device='xla:1')
INPUT 1 tensor([[1., 1., 1., 1., 1., 1., 1., 1.]], device='xla:0')
INPUT 2 tensor([[4., 4., 4., 4., 4., 4., 4., 4.]], device='xla:0')
INPUT 3 tensor([[9., 9., 9., 9., 9., 9., 9., 9.]], device='xla:0')
INPUT 4 tensor([[16., 16., 16., 16., 16., 16., 16., 16.]], device='xla:0')
INPUT 5 tensor([[25., 25., 25., 25., 25., 25., 25., 25.]], device='xla:0')
INPUT 6 tensor([[36., 36., 36., 36., 36., 36., 36., 36.]], device='xla:0')
INPUT 7 tensor([[49., 49., 49., 49., 49., 49., 49., 49.]], device='xla:0')
RESULT 0 tensor([[ 0., 1., 4., 9., 16., 25., 36., 49.]], device='xla:1')
RESULT 1 tensor([[ 0., 1., 4., 9., 16., 25., 36., 49.]], device='xla:0')
RESULT 2 tensor([[ 0., 1., 4., 9., 16., 25., 36., 49.]], device='xla:0')
RESULT 3 tensor([[ 0., 1., 4., 9., 16., 25., 36., 49.]], device='xla:0')
RESULT 4 tensor([[ 0., 1., 4., 9., 16., 25., 36., 49.]], device='xla:0')
RESULT 5 tensor([[ 0., 1., 4., 9., 16., 25., 36., 49.]], device='xla:0')
RESULT 6 tensor([[ 0., 1., 4., 9., 16., 25., 36., 49.]], device='xla:0')
RESULT 7 tensor([[ 0., 1., 4., 9., 16., 25., 36., 49.]], device='xla:0')
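So the corruption seems to hinge on whether the `o *= xm.get_ordinal()` multiply is still pending in the lazy graph when the all_to_all is traced; flushing it with mark_step first gives the right answer.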
The result also looks correct with random input plus marking step before the all_to_all:
def main(*a):
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    #o = torch.tensor([float(xm.get_ordinal()) for _ in range(8)],
    #                 dtype=torch.float, device=device).reshape((8, 1))
    o = torch.randn((8, 1), dtype=torch.float, device=device).reshape((8, 1))
    o *= xm.get_ordinal()
    xm.mark_step()
    t = xm.all_to_all(o, split_dimension=0, concat_dimension=1, split_count=8, groups=None)
    xm.mark_step()
    print('INPUT', xm.get_ordinal(), o.T, flush=True)
    xm.rendezvous('hi')
    print('RESULT', xm.get_ordinal(), t, flush=True)
    xm.rendezvous('hi')
$ ./alltoall.sh | sort -n
INPUT 0 tensor([[0., 0., 0., -0., -0., 0., 0., 0.]], device='xla:1')
INPUT 1 tensor([[ 0.3602, 0.6926, 1.2409, -0.4852, -0.2514, 1.0309, 1.5637, 0.2221]], device='xla:0')
INPUT 2 tensor([[ 0.7204, 1.3852, 2.4818, -0.9704, -0.5028, 2.0619, 3.1275, 0.4442]], device='xla:0')
INPUT 3 tensor([[ 1.0806, 2.0779, 3.7226, -1.4556, -0.7542, 3.0928, 4.6912, 0.6663]], device='xla:0')
INPUT 4 tensor([[ 1.4408, 2.7705, 4.9635, -1.9408, -1.0056, 4.1238, 6.2550, 0.8884]], device='xla:0')
INPUT 5 tensor([[ 1.8010, 3.4631, 6.2044, -2.4260, -1.2570, 5.1547, 7.8187, 1.1105]], device='xla:0')
INPUT 6 tensor([[ 2.1612, 4.1557, 7.4453, -2.9112, -1.5084, 6.1857, 9.3825, 1.3326]], device='xla:0')
INPUT 7 tensor([[ 2.5214, 4.8484, 8.6862, -3.3964, -1.7598, 7.2166, 10.9462, 1.5547]], device='xla:0')
RESULT 0 tensor([[0.0000, 0.3602, 0.7204, 1.0806, 1.4408, 1.8010, 2.1612, 2.5214]], device='xla:1')
RESULT 1 tensor([[0.0000, 0.6926, 1.3852, 2.0779, 2.7705, 3.4631, 4.1557, 4.8484]], device='xla:0')
RESULT 2 tensor([[0.0000, 1.2409, 2.4818, 3.7226, 4.9635, 6.2044, 7.4453, 8.6862]], device='xla:0')
RESULT 3 tensor([[-0.0000, -0.4852, -0.9704, -1.4556, -1.9408, -2.4260, -2.9112, -3.3964]], device='xla:0')
RESULT 4 tensor([[-0.0000, -0.2514, -0.5028, -0.7542, -1.0056, -1.2570, -1.5084, -1.7598]], device='xla:0')
RESULT 5 tensor([[0.0000, 1.0309, 2.0619, 3.0928, 4.1238, 5.1547, 6.1857, 7.2166]], device='xla:0')
RESULT 6 tensor([[ 0.0000, 1.5637, 3.1275, 4.6912, 6.2550, 7.8187, 9.3825, 10.9462]], device='xla:0')
RESULT 7 tensor([[0.0000, 0.2221, 0.4442, 0.6663, 0.8884, 1.1105, 1.3326, 1.5547]], device='xla:0')
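To sum up: every failing run traced the pending `o *= xm.get_ordinal()` together with the collective, and every run that called xm.mark_step() in between came out correct. A minimal sketch of that workaround, using only the calls already shown above (the wrapper name is hypothetical, not a torch_xla API):

import torch_xla.core.xla_model as xm

def all_to_all_materialized(value, split_dimension, concat_dimension, split_count):
    # Hypothetical wrapper: flush the pending lazy graph so `value` is
    # materialized before the collective is traced; this is the workaround
    # that produced correct results in the runs above.
    xm.mark_step()
    return xm.all_to_all(value, split_dimension=split_dimension,
                         concat_dimension=concat_dimension,
                         split_count=split_count, groups=None)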