Skip to content

Instantly share code, notes, and snippets.

@taylanbil
Last active July 13, 2020 22:49
Show Gist options
  • Save taylanbil/b0fcc08b479b72d4df36a98b40cb1a25 to your computer and use it in GitHub Desktop.
Save taylanbil/b0fcc08b479b72d4df36a98b40cb1a25 to your computer and use it in GitHub Desktop.
Bug in all to all
import torch
import torch.nn as nn
import sys
#sys.path.insert(0, '/usr/share/torch-xla-nightly/pytorch/xla/')
import torch_xla.distributed.xla_multiprocessing as xmp
def main(*a):
import torch_xla.core.xla_model as xm
device = xm.xla_device()
o = torch.randn((8,1) , dtype=torch.float, device=device).reshape((8,1))
o *= xm.get_ordinal()
t = xm.all_to_all(o, split_dimension=0, concat_dimension=1, split_count=8, groups=None)
xm.mark_step()
print('INPUT', xm.get_ordinal(), o.T, flush=True)
xm.rendezvous('hi')
print('RESULT', xm.get_ordinal(), t, flush=True)
xm.rendezvous('hi')
if __name__ == '__main__':
xmp.spawn(main, args=(), nprocs=8)
# OUTPUT
./alltoall.sh | sort -n
device='xla:0')
device='xla:0')
device='xla:0')
device='xla:0')
device='xla:0')
device='xla:0')
device='xla:0')
device='xla:0')
device='xla:0')
device='xla:0')
device='xla:0')
device='xla:0')
device='xla:0')
device='xla:1')
INPUT 0 tensor([[0., 0., 0., -0., -0., 0., 0., 0.]], device='xla:1')
INPUT 1 tensor([[ 0.3602, 0.6926, 1.2409, -0.4852, -0.2514, 1.0309, 1.5637, 0.2221]],
INPUT 2 tensor([[ 0.7204, 1.3852, 2.4818, -0.9704, -0.5028, 2.0619, 3.1275, 0.4442]],
INPUT 3 tensor([[ 1.0806, 2.0779, 3.7226, -1.4556, -0.7542, 3.0928, 4.6912, 0.6663]],
INPUT 4 tensor([[ 1.4408, 2.7705, 4.9635, -1.9408, -1.0056, 4.1238, 6.2550, 0.8884]],
INPUT 5 tensor([[ 1.8010, 3.4631, 6.2044, -2.4260, -1.2570, 5.1547, 7.8187, 1.1105]],
INPUT 6 tensor([[ 2.1612, 4.1557, 7.4453, -2.9112, -1.5084, 6.1857, 9.3825, 1.3326]],
INPUT 7 tensor([[ 2.5214, 4.8484, 8.6862, -3.3964, -1.7598, 7.2166, 10.9462, 1.5547]],
RESULT 0 tensor([[0.0000, 0.0000, 0.7204, 1.0806, 1.4408, 1.8010, 2.1612, 2.5214]],
RESULT 1 tensor([[0.0000e+00, 1.3852e+00, 2.0779e+00, 2.7705e+00, 3.4631e+00, 4.1557e+00, 4.8484e+00, 1.5236e-35]], device='xla:0')
RESULT 2 tensor([[0.0000, 0.0000, 1.2409, 3.7226, 4.9635, 6.2044, 7.4453, 8.6862]],
RESULT 3 tensor([[-0.0000, 0.0000, -0.9704, -1.4556, -1.9408, -2.4260, -2.9112, -3.3964]],
RESULT 4 tensor([[-0.0000, 0.0000, -0.5028, -0.7542, -1.0056, -1.2570, -1.5084, -1.7598]],
RESULT 5 tensor([[0.0000, 0.0000, 1.0309, 3.0928, 4.1238, 5.1547, 6.1857, 7.2166]],
RESULT 6 tensor([[ 0.0000, 0.0000, 3.1275, 4.6912, 6.2550, 7.8187, 9.3825, 10.9462]],
RESULT 7 tensor([[0.0000, 0.0000, 0.4442, 0.6663, 0.8884, 1.1105, 1.3326, 1.5547]],
@taylanbil
Copy link
Author

Replacing the tensor o with

 36     o = torch.tensor([float(xm.get_ordinal()) for _ in range(8)] ,
 37             dtype=torch.float, device=device).reshape((8,1))

I get the correct result:

$ ./alltoall.sh | sort -n
INPUT 0 tensor([[0., 0., 0., 0., 0., 0., 0., 0.]], device='xla:1')
INPUT 1 tensor([[1., 1., 1., 1., 1., 1., 1., 1.]], device='xla:0')
INPUT 2 tensor([[4., 4., 4., 4., 4., 4., 4., 4.]], device='xla:0')
INPUT 3 tensor([[9., 9., 9., 9., 9., 9., 9., 9.]], device='xla:0')
INPUT 4 tensor([[16., 16., 16., 16., 16., 16., 16., 16.]], device='xla:0')
INPUT 5 tensor([[25., 25., 25., 25., 25., 25., 25., 25.]], device='xla:0')
INPUT 6 tensor([[36., 36., 36., 36., 36., 36., 36., 36.]], device='xla:0')
INPUT 7 tensor([[49., 49., 49., 49., 49., 49., 49., 49.]], device='xla:0')
RESULT 0 tensor([[0.0000e+00, 4.0000e+00, 9.0000e+00, 1.6000e+01, 2.5000e+01, 3.6000e+01, 4.9000e+01, 1.5236e-35]], device='xla:1')
RESULT 1 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')
RESULT 2 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')
RESULT 3 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')
RESULT 4 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')
RESULT 5 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')
RESULT 6 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')
RESULT 7 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')

@taylanbil
Copy link
Author

Oh actually that's not correct either, result 0 is wonky.... 1 is missing at position 1, and there is a e-35 appended at the end.

@taylanbil
Copy link
Author

Getting correct result with this input (commenting out o*=xm.get_ordinal()):

def main(*a):
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    o = torch.tensor([float(xm.get_ordinal()) for _ in range(8)] ,
            dtype=torch.float, device=device).reshape((8,1))
    #o = torch.randn((8,1) , dtype=torch.float, device=device).reshape((8,1))
    #o *= xm.get_ordinal()

    t = xm.all_to_all(o, split_dimension=0, concat_dimension=1, split_count=8, groups=None)
    xm.mark_step()
    print('INPUT', xm.get_ordinal(), o.T, flush=True)
    xm.rendezvous('hi')
    print('RESULT', xm.get_ordinal(), t, flush=True)
    xm.rendezvous('hi')
INPUT 0 tensor([[0., 0., 0., 0., 0., 0., 0., 0.]], device='xla:1')
INPUT 1 tensor([[1., 1., 1., 1., 1., 1., 1., 1.]], device='xla:0')
INPUT 2 tensor([[2., 2., 2., 2., 2., 2., 2., 2.]], device='xla:0')
INPUT 3 tensor([[3., 3., 3., 3., 3., 3., 3., 3.]], device='xla:0')
INPUT 4 tensor([[4., 4., 4., 4., 4., 4., 4., 4.]], device='xla:0')
INPUT 5 tensor([[5., 5., 5., 5., 5., 5., 5., 5.]], device='xla:0')
INPUT 6 tensor([[6., 6., 6., 6., 6., 6., 6., 6.]], device='xla:0')
INPUT 7 tensor([[7., 7., 7., 7., 7., 7., 7., 7.]], device='xla:0')
RESULT 0 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:1')
RESULT 1 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')
RESULT 2 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')
RESULT 3 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')
RESULT 4 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')
RESULT 5 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')
RESULT 6 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')
RESULT 7 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], device='xla:0')

@taylanbil
Copy link
Author

Marking step before all to all, ordinal squared inputs also give correct result (as opposed to comment above)

def main(*a):
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    o = torch.tensor([float(xm.get_ordinal()) for _ in range(8)] ,
            dtype=torch.float, device=device).reshape((8,1))
    #o = torch.randn((8,1) , dtype=torch.float, device=device).reshape((8,1))
    o *= xm.get_ordinal()

    xm.mark_step()  # <- ADDED THIS
    t = xm.all_to_all(o, split_dimension=0, concat_dimension=1, split_count=8, groups=None)
    xm.mark_step()
    print('INPUT', xm.get_ordinal(), o.T, flush=True)
    xm.rendezvous('hi')
    print('RESULT', xm.get_ordinal(), t, flush=True)
    xm.rendezvous('hi')
$ ./alltoall.sh | sort -n
INPUT 0 tensor([[0., 0., 0., 0., 0., 0., 0., 0.]], device='xla:1')
INPUT 1 tensor([[1., 1., 1., 1., 1., 1., 1., 1.]], device='xla:0')
INPUT 2 tensor([[4., 4., 4., 4., 4., 4., 4., 4.]], device='xla:0')
INPUT 3 tensor([[9., 9., 9., 9., 9., 9., 9., 9.]], device='xla:0')
INPUT 4 tensor([[16., 16., 16., 16., 16., 16., 16., 16.]], device='xla:0')
INPUT 5 tensor([[25., 25., 25., 25., 25., 25., 25., 25.]], device='xla:0')
INPUT 6 tensor([[36., 36., 36., 36., 36., 36., 36., 36.]], device='xla:0')
INPUT 7 tensor([[49., 49., 49., 49., 49., 49., 49., 49.]], device='xla:0')
RESULT 0 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:1')
RESULT 1 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')
RESULT 2 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')
RESULT 3 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')
RESULT 4 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')
RESULT 5 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')
RESULT 6 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')
RESULT 7 tensor([[ 0.,  1.,  4.,  9., 16., 25., 36., 49.]], device='xla:0')

@taylanbil
Copy link
Author

Result looks correct with random input + marking step before alltoall

def main(*a):
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    #o = torch.tensor([float(xm.get_ordinal()) for _ in range(8)] ,
    #        dtype=torch.float, device=device).reshape((8,1))
    o = torch.randn((8,1) , dtype=torch.float, device=device).reshape((8,1))
    o *= xm.get_ordinal()

    xm.mark_step()
    t = xm.all_to_all(o, split_dimension=0, concat_dimension=1, split_count=8, groups=None)
    xm.mark_step()
    print('INPUT', xm.get_ordinal(), o.T, flush=True)
    xm.rendezvous('hi')
    print('RESULT', xm.get_ordinal(), t, flush=True)
    xm.rendezvous('hi')
$ ./alltoall.sh | sort -n
       device='xla:0')
       device='xla:0')
       device='xla:0')
       device='xla:0')
       device='xla:0')
       device='xla:0')
       device='xla:0')
       device='xla:0')
       device='xla:0')
       device='xla:0')
       device='xla:0')
       device='xla:0')
       device='xla:0')
       device='xla:0')
       device='xla:1')
INPUT 0 tensor([[0., 0., 0., -0., -0., 0., 0., 0.]], device='xla:1')
INPUT 1 tensor([[ 0.3602,  0.6926,  1.2409, -0.4852, -0.2514,  1.0309,  1.5637,  0.2221]],
INPUT 2 tensor([[ 0.7204,  1.3852,  2.4818, -0.9704, -0.5028,  2.0619,  3.1275,  0.4442]],
INPUT 3 tensor([[ 1.0806,  2.0779,  3.7226, -1.4556, -0.7542,  3.0928,  4.6912,  0.6663]],
INPUT 4 tensor([[ 1.4408,  2.7705,  4.9635, -1.9408, -1.0056,  4.1238,  6.2550,  0.8884]],
INPUT 5 tensor([[ 1.8010,  3.4631,  6.2044, -2.4260, -1.2570,  5.1547,  7.8187,  1.1105]],
INPUT 6 tensor([[ 2.1612,  4.1557,  7.4453, -2.9112, -1.5084,  6.1857,  9.3825,  1.3326]],
INPUT 7 tensor([[ 2.5214,  4.8484,  8.6862, -3.3964, -1.7598,  7.2166, 10.9462,  1.5547]],
RESULT 0 tensor([[0.0000, 0.3602, 0.7204, 1.0806, 1.4408, 1.8010, 2.1612, 2.5214]],
RESULT 1 tensor([[0.0000, 0.6926, 1.3852, 2.0779, 2.7705, 3.4631, 4.1557, 4.8484]],
RESULT 2 tensor([[0.0000, 1.2409, 2.4818, 3.7226, 4.9635, 6.2044, 7.4453, 8.6862]],
RESULT 3 tensor([[-0.0000, -0.4852, -0.9704, -1.4556, -1.9408, -2.4260, -2.9112, -3.3964]],
RESULT 4 tensor([[-0.0000, -0.2514, -0.5028, -0.7542, -1.0056, -1.2570, -1.5084, -1.7598]],
RESULT 5 tensor([[0.0000, 1.0309, 2.0619, 3.0928, 4.1238, 5.1547, 6.1857, 7.2166]],
RESULT 6 tensor([[ 0.0000,  1.5637,  3.1275,  4.6912,  6.2550,  7.8187,  9.3825, 10.9462]],
RESULT 7 tensor([[0.0000, 0.2221, 0.4442, 0.6663, 0.8884, 1.1105, 1.3326, 1.5547]],

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment