import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor


def tng_dataloader():
    d = MNIST('data', train=True, download=True, transform=ToTensor())

    # 4: add a distributed sampler
    # the sampler sends a portion of the tng data to each machine
    dist_sampler = DistributedSampler(d)
    dataloader = DataLoader(d, shuffle=False, sampler=dist_sampler)
    return dataloader
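
# Hedged usage sketch, not part of the original gist: DistributedSampler
# shuffles with a seed derived from the current epoch, so set_epoch() should
# be called once per epoch or every epoch reuses the same ordering.
# `nb_epochs` is a hypothetical name.
def train_epochs(dataloader, nb_epochs):
    for epoch in range(nb_epochs):
        dataloader.sampler.set_epoch(epoch)  # reshuffle shards each epoch
        for batch in dataloader:
            pass  # forward/backward/optimizer step would go here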
def main_process_entrypoint(gpu_nb):
    # nb_gpus, nb_nodes and model are assumed to be defined elsewhere,
    # as in the original snippet

    # 2: set up connections between all gpus across all machines
    # all gpus connect to a single GPU "root"
    # the default uses env://
    world = nb_gpus * nb_nodes
    dist.init_process_group("nccl", rank=gpu_nb, world_size=world)

    # 3: wrap the model in DDP
    torch.cuda.set_device(gpu_nb)
    model.cuda(gpu_nb)
    ddp_model = DistributedDataParallel(model, device_ids=[gpu_nb])

    # train your model now...
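
# Hedged sketch (assumption, not from the gist): with the default env://
# init method, every process reads the rendezvous address from environment
# variables before dist.init_process_group() runs. Since rank and world_size
# are passed explicitly above, only these two are required; the values below
# are illustrative.
import os
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')  # IP of the rank-0 machine
os.environ.setdefault('MASTER_PORT', '29500')      # any free port on it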
if __name__ == '__main__':
    # 1: spawn one process per GPU (8 here)
    # your cluster will call main for each machine
    mp.spawn(main_process_entrypoint, nprocs=8)
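
# Hedged multi-node sketch (assumption, not from the gist): on a real cluster
# each machine runs this script once, and the global rank folds in the node's
# index. `node_rank` would typically come from the scheduler (e.g. the
# SLURM_NODEID env var); this variant of the entrypoint is hypothetical.
def multi_node_entrypoint(gpu_nb, node_rank, gpus_per_node, nb_nodes):
    global_rank = node_rank * gpus_per_node + gpu_nb
    world = gpus_per_node * nb_nodes
    dist.init_process_group("nccl", rank=global_rank, world_size=world)

# usage on each node, e.g.:
# mp.spawn(multi_node_entrypoint, nprocs=gpus_per_node,
#          args=(node_rank, gpus_per_node, nb_nodes))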