
@williamFalcon
Created July 29, 2019 17:52
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST

nb_gpus = 8    # GPUs per machine (assumed value)
nb_nodes = 1   # number of machines (assumed value)

def tng_dataloader():
    d = MNIST(root='.', download=True)

    # 4: Add distributed sampler
    # sampler sends a portion of tng data to each machine
    dist_sampler = DistributedSampler(d)
    dataloader = DataLoader(d, shuffle=False, sampler=dist_sampler)
    return dataloader

def main_process_entrypoint(gpu_nb):
    # 2: set up connections between all gpus across all machines
    # all gpus connect to a single GPU "root"
    # the default uses env://
    world = nb_gpus * nb_nodes
    dist.init_process_group("nccl", rank=gpu_nb, world_size=world)

    # 3: wrap model in DDP
    torch.cuda.set_device(gpu_nb)
    model = YourModel()  # hypothetical placeholder; swap in your own model
    model.cuda(gpu_nb)
    model = DistributedDataParallel(model, device_ids=[gpu_nb])

    # train your model now...

if __name__ == '__main__':
    # 1: spawn one process per GPU on this machine
    # your cluster will call main for each machine
    mp.spawn(main_process_entrypoint, nprocs=8)
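The env:// default mentioned in step 2 means init_process_group reads the rendezvous address from environment variables. A minimal sketch of that setup, assuming a single-machine run with a hypothetical local address and port (on a real cluster, the scheduler normally exports these on every node):

import os

# assumed values for a single-machine run; a cluster scheduler (e.g. SLURM)
# would typically export these per node instead
os.environ['MASTER_ADDR'] = '127.0.0.1'  # address of the node hosting the rank-0 process
os.environ['MASTER_PORT'] = '29500'      # any free port, identical on every node

# because rank and world_size are passed to init_process_group explicitly above,
# these two variables are all the env:// rendezvous needs to find the root process

If shuffling is enabled on the DistributedSampler, it is also worth calling dataloader.sampler.set_epoch(epoch) at the start of each epoch in the "train your model now" section, so every process reshuffles consistently and draws a distinct shard.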