
@takagi
Created July 22, 2019 09:10
Test code for cupy.cuda.nccl.NcclCommunicator's broadcast method.
import multiprocessing

import cupy
from cupy import cuda
from cupy.cuda import nccl
from cupy import testing


def f(n_devices, device, comm_id, rank):
    # Worker for ranks 1..n_devices-1: bind to the assigned GPU and join
    # the communicator identified by comm_id.
    device.use()
    comm = nccl.NcclCommunicator(n_devices, comm_id, rank)
    # Non-root ranks start from zeros; the broadcast from rank 0 should
    # overwrite the buffer with ones.
    x = cupy.zeros((2, 3, 4), dtype='float32')
    comm.broadcast(
        x.data.ptr, x.data.ptr, x.size, nccl.NCCL_FLOAT, 0,
        cuda.Stream.null.ptr)
    e = cupy.ones((2, 3, 4), dtype='float32')
    testing.assert_allclose(x, e)
    device.synchronize()
    print('Rank {} successfully finished.'.format(rank))


if __name__ == '__main__':
    n_devices = 4
    devices = [cuda.Device(i) for i in range(n_devices)]
    comm_id = nccl.get_unique_id()
    # Launch ranks 1..n_devices-1 in child processes.
    ps = []
    for i in range(1, n_devices):
        p = multiprocessing.Process(
            target=f, args=(n_devices, devices[i], comm_id, i))
        p.start()
        ps.append(p)
    # Rank 0 runs in the parent process and acts as the broadcast root.
    device = devices[0]
    device.use()
    comm = nccl.NcclCommunicator(n_devices, comm_id, 0)
    x = cupy.ones((2, 3, 4), dtype='float32')
    comm.broadcast(
        x.data.ptr, x.data.ptr, x.size, nccl.NCCL_FLOAT, 0,
        cuda.Stream.null.ptr)
    for p in ps:
        p.join()
    print('Rank 0 successfully finished.')
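
Note that the script expects at least four GPUs (n_devices = 4). For reference, CuPy's NcclCommunicator also exposes an in-place bcast method that takes a single buffer pointer instead of the separate send/receive pointers used above. As a minimal sketch, assuming your CuPy/NCCL build provides bcast, each rank would call:

# Assumes NcclCommunicator.bcast is available in the installed CuPy build.
# NCCL overwrites the buffer in place on every rank with rank 0's data.
comm.bcast(x.data.ptr, x.size, nccl.NCCL_FLOAT, 0, cuda.Stream.null.ptr)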
@whuLames

good job!
