-
-
Save sgraaf/5b0caa3a320f28c27c12b5efeb35aa4c to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from argparse import ArgumentParser

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers import BertForMaskedLM
# Seed handed to the CUDA random number generators in `main`.
SEED = 42
# Batch size of the DataLoader built in each process (i.e. per GPU, not global).
BATCH_SIZE = 8
# Number of passes over the dataset made by the training loop.
NUM_EPOCHS = 3
class YourDataset(Dataset):
    """Placeholder dataset; replace the body with your own data handling.

    This stub only exists so the example reads end to end. A real map-style
    dataset must also implement ``__len__`` (``DistributedSampler`` needs the
    dataset size to split indices across processes) and ``__getitem__``
    (called by the ``DataLoader`` to fetch each sample).
    """

    def __init__(self):
        """Create the (empty) placeholder dataset."""
        super().__init__()
def main():
    """Run a minimal DistributedDataParallel (DDP) training loop.

    One copy of this function is expected to run per GPU, started by a
    distributed launcher. The local rank is taken from the ``--local_rank`` /
    ``--local-rank`` command-line argument when the launcher passes one, and
    from the ``LOCAL_RANK`` environment variable otherwise.

    Exits with an argparse usage error when no local rank can be determined.
    """
    parser = ArgumentParser('DDP usage example')
    # you need this argument in your scripts for DDP to work; both spellings
    # are accepted because launchers differ in which one they pass
    parser.add_argument(
        '--local_rank', '--local-rank',
        dest='local_rank',
        type=int,
        default=-1,
        metavar='N',
        help='Local process rank.',
    )
    args = parser.parse_args()

    # launchers that do not pass the argument export `LOCAL_RANK` instead
    if args.local_rank < 0:
        args.local_rank = int(os.environ.get('LOCAL_RANK', args.local_rank))
    if args.local_rank < 0:
        parser.error(
            'could not determine the local rank: pass --local_rank or set '
            'the LOCAL_RANK environment variable (normally done by the launcher)'
        )

    # keep track of whether the current process is the `master` process
    # (totally optional, but useful for data loading, logging, etc.).
    # NOTE: this is based on the *local* rank, so it is true for one process
    # on every node, not just once globally.
    args.is_master = args.local_rank == 0

    # set the device: a real `torch.device` so it can be passed to `.to(...)`
    args.device = torch.device('cuda', args.local_rank)
    # bind this process to its GPU *before* creating the process group so
    # that NCCL communicators are set up on the right device
    torch.cuda.set_device(args.device)

    # initialize PyTorch distributed using environment variables (you could
    # also do this more explicitly by specifying `rank` and `world_size`, but
    # using environment variables lets the same script run on different
    # machines)
    dist.init_process_group(backend='nccl', init_method='env://')

    try:
        # set the seed for all GPUs (also make sure to set the seed for
        # random, numpy, etc.)
        torch.cuda.manual_seed_all(SEED)

        # initialize your model (BERT in this example)
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')

        # send your model to GPU
        model = model.to(args.device)

        # initialize distributed data parallel (DDP)
        model = DDP(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
        )

        # initialize your dataset
        dataset = YourDataset()

        # initialize the DistributedSampler
        sampler = DistributedSampler(dataset)

        # initialize the dataloader
        dataloader = DataLoader(
            dataset=dataset,
            sampler=sampler,
            batch_size=BATCH_SIZE,
        )

        # start your training!
        for epoch in range(NUM_EPOCHS):
            # put model in train mode
            model.train()

            # tell the sampler which epoch this is; without this call every
            # epoch iterates the data in the same shuffled order
            sampler.set_epoch(epoch)

            # let all processes sync up before starting a new epoch
            dist.barrier()

            for step, batch in enumerate(dataloader):
                # send batch to device
                batch = tuple(t.to(args.device) for t in batch)

                # forward pass
                outputs = model(*batch)

                # compute loss (assumes the batch includes labels, so that
                # the first model output is the loss - TODO confirm for your
                # dataset)
                loss = outputs[0]

                # etc.
    finally:
        # release the process group even if training raised
        dist.destroy_process_group()
# Only start training when executed as a script, not when imported.
if __name__ == '__main__':
    main()
#!/bin/bash
# Launch `ddp_example.py` with one process per GPU.
# This example uses a single node (`NUM_NODES=1`) w/ 4 GPUs (`NUM_GPUS_PER_NODE=4`).
#
# Any arguments given to this script are forwarded unchanged to
# `ddp_example.py` (they must be arguments that script's parser defines).

# stop on the first failing command, on unset variables, and on pipe failures
set -euo pipefail

export NUM_NODES=1
export NUM_GPUS_PER_NODE=4
export NODE_RANK=0
export WORLD_SIZE=$((NUM_NODES * NUM_GPUS_PER_NODE))

# launch your script w/ `torch.distributed.launch`
# NOTE(review): newer PyTorch releases recommend `torchrun` instead - confirm
# against the installed version before switching.
python -m torch.distributed.launch \
    --nproc_per_node="$NUM_GPUS_PER_NODE" \
    --nnodes="$NUM_NODES" \
    --node_rank="$NODE_RANK" \
    ddp_example.py "$@"
At line 50 (`output_device=args.local_rank`), I am getting the error below:
output_device=args.local_rank
File "/home/193112008/.conda/envs/XXX/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 285, in init
self.broadcast_bucket_size)
File "/home/193112008/.conda/envs/XXX/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 483, in _distributed_broadcast_coalesced
dist._broadcast_coalesced(self.process_group, tensors, buffer_size)
RuntimeError: NCCL error in: /opt/conda/conda-bld/pytorch_1587428091666/work/torch/lib/c10d/ProcessGroupNCCL.cpp:32, unhandled cuda error, NCCL version 2.4.8
Can you please help me understand this error?
Many Thanks.
Hey @1434AjaySingh,
I have updated the code above. Can you check the link above? In addition, if you need any help, we have a dedicated Discord server, PyTorch Community (unofficial), where a community helps people troubleshoot PyTorch-related problems, learn Machine Learning and Deep Learning, and discuss ML/DL-related topics. Feel free to join via the link below:
I have an updated example of this and PyTorch documentation, https://github.com/sudomaze/ttorch/blob/main/examples/ddp/run.py