
@sgraaf
Last active November 7, 2024 05:39
PyTorch Distributed Data Parallel (DDP) example
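```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from argparse import ArgumentParser

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers import BertForMaskedLM

SEED = 42
BATCH_SIZE = 8
NUM_EPOCHS = 3


class YourDataset(Dataset):
    def __init__(self):
        pass

    def __len__(self):
        # return the number of samples (DistributedSampler needs this to shard the data)
        raise NotImplementedError

    def __getitem__(self, idx):
        # return one sample as a tuple of tensors
        raise NotImplementedError


def main():
    parser = ArgumentParser('DDP usage example')
    # you need this argument in your scripts for DDP to work: `torch.distributed.launch` passes
    # `--local_rank` (PyTorch < 2.0) or `--local-rank` (PyTorch >= 2.0), while `torchrun` only
    # sets the `LOCAL_RANK` environment variable, hence the fallback default
    parser.add_argument('--local_rank', '--local-rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', -1)),
                        metavar='N', help='Local process rank.')
    args = parser.parse_args()

    # keep track of whether the current process is the `master` process (totally optional, but I find it
    # useful for data loading, logging, etc.); `local_rank` is 0 on every node, so on multiple nodes use
    # `dist.get_rank() == 0` (after `init_process_group`) to get a single global master
    args.is_master = args.local_rank == 0

    # set the device (`torch.cuda.device` is a context manager, not a device object)
    args.device = torch.device('cuda', args.local_rank)

    # initialize PyTorch distributed using environment variables (you could also do this more explicitly
    # by specifying `rank` and `world_size`, but using environment variables lets you run the same
    # script unchanged on different machines)
    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(args.local_rank)

    # set the seed on the CPU and all GPUs (also make sure to set the seed for random, numpy, etc.)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    # initialize your model (BERT in this example)
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    # send your model to GPU
    model = model.to(args.device)

    # initialize distributed data parallel (DDP)
    model = DDP(
        model,
        device_ids=[args.local_rank],
        output_device=args.local_rank
    )

    # initialize your dataset
    dataset = YourDataset()

    # initialize the DistributedSampler (it gives each process a different shard of the dataset)
    sampler = DistributedSampler(dataset)

    # initialize the dataloader (don't also pass `shuffle=True`; the sampler does the shuffling)
    dataloader = DataLoader(
        dataset=dataset,
        sampler=sampler,
        batch_size=BATCH_SIZE
    )

    # start your training!
    for epoch in range(NUM_EPOCHS):
        # put model in train mode
        model.train()

        # reseed the sampler's shuffle; without this, every epoch uses the same sample order
        sampler.set_epoch(epoch)

        # let all processes sync up before starting with a new epoch of training
        dist.barrier()

        for step, batch in enumerate(dataloader):
            # send batch to device
            batch = tuple(t.to(args.device) for t in batch)

            # forward pass (positional arguments map to `input_ids`, `attention_mask`,
            # `token_type_ids`, ...; pass `labels=` as a keyword argument)
            outputs = model(*batch)

            # compute loss (`outputs[0]` is the loss only when `labels` are passed)
            loss = outputs[0]

            # etc. (loss.backward(), optimizer.step(), ...)

    # clean up the process group
    dist.destroy_process_group()


if __name__ == '__main__':
    main()
```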
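The `DistributedSampler` is what gives each process its own slice of the data. The sketch below (an illustration, not part of the original gist) builds the samplers for a hypothetical 4-process run by hand, passing `num_replicas` and `rank` explicitly so that no process group is needed, and shows why calling `sampler.set_epoch(epoch)` at the start of each epoch matters: the shuffle is seeded with `seed + epoch`, so without it every epoch sees the same order. The 10-sample `TensorDataset`, `WORLD_SIZE = 4`, and the helper `shards_for_epoch` are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

WORLD_SIZE = 4  # hypothetical number of processes (e.g. one per GPU)
dataset = TensorDataset(torch.arange(10))  # 10 toy samples; 10 is not divisible by 4


def shards_for_epoch(epoch):
    """Return the sample indices each rank would draw in the given epoch."""
    shards = []
    for rank in range(WORLD_SIZE):
        # passing num_replicas and rank explicitly means no process group is required
        sampler = DistributedSampler(
            dataset, num_replicas=WORLD_SIZE, rank=rank, shuffle=True, seed=42
        )
        # the shuffle is seeded with seed + epoch, so this changes the order per epoch
        sampler.set_epoch(epoch)
        shards.append(list(sampler))
    return shards


for epoch in range(2):
    for rank, shard in enumerate(shards_for_epoch(epoch)):
        print(f"epoch {epoch}, rank {rank}: {shard}")
```

With 10 samples and 4 ranks, each rank gets 3 indices: by default the sampler pads the index list by repeating a few indices so that every rank runs the same number of steps (pass `drop_last=True` to trim the tail instead).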
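```bash
#!/bin/bash
# this example uses a single node (`NUM_NODES=1`) w/ 4 GPUs (`NUM_GPUS_PER_NODE=4`)
export NUM_NODES=1
export NUM_GPUS_PER_NODE=4
export NODE_RANK=0
# informational only: the launcher computes WORLD_SIZE itself and sets it for every process
export WORLD_SIZE=$(($NUM_NODES * $NUM_GPUS_PER_NODE))

# launch your script w/ `torch.distributed.launch`
# (deprecated in recent PyTorch releases in favor of `torchrun`, which accepts the same flags
# but sets the `LOCAL_RANK` environment variable instead of passing `--local_rank`)
python -m torch.distributed.launch \
    --nproc_per_node=$NUM_GPUS_PER_NODE \
    --nnodes=$NUM_NODES \
    --node_rank=$NODE_RANK \
    ddp_example.py \
    # include any arguments to your script, e.g.:
    # --seed 42
    # etc.
```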
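To check the launch and training plumbing without GPUs, here is a minimal single-process sketch (not from the original gist). It sets by hand the environment variables that `torch.distributed.launch`/`torchrun` normally set for each process (`MASTER_ADDR`, `MASTER_PORT`, `RANK`, `LOCAL_RANK`, `WORLD_SIZE`), initializes the CPU-capable `gloo` backend instead of `nccl`, and fills in the backward pass and optimizer step that the `# etc.` leaves out. The toy `nn.Linear` model, SGD optimizer, random data, and the helper names `find_free_port` and `smoke_test` are hypothetical stand-ins for BERT and your own dataset.

```python
import os
import socket
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def find_free_port():
    # ask the OS for an unused TCP port so the rendezvous does not collide with another job
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def smoke_test(checkpoint_path):
    # these are the variables the launcher sets for every process it starts;
    # here we fake a single process (WORLD_SIZE=1) so the snippet runs on one CPU
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(find_free_port())
    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"

    # `gloo` works on CPU; the GPU script uses `nccl`
    dist.init_process_group(backend="gloo", init_method="env://")
    try:
        torch.manual_seed(42)
        model = DDP(nn.Linear(4, 1))  # CPU module, so no device_ids
        optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

        inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
        losses = []
        for _ in range(5):
            optimizer.zero_grad()
            loss = nn.functional.mse_loss(model(inputs), targets)
            loss.backward()  # DDP averages gradients across processes during backward
            optimizer.step()
            losses.append(loss.item())

        # only one process should write the checkpoint; `model.module` is the unwrapped model
        if dist.get_rank() == 0:
            torch.save(model.module.state_dict(), checkpoint_path)
        dist.barrier()
        return dist.get_rank(), dist.get_world_size(), losses
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "model.pt")
    rank, world_size, losses = smoke_test(path)
    print(f"rank={rank} world_size={world_size} "
          f"first_loss={losses[0]:.4f} last_loss={losses[-1]:.4f}")
```

With several processes the same function body would run once per rank. Because `local_rank` is 0 on every node, checking the global `dist.get_rank() == 0` (as above) is what keeps a multi-node run from writing the checkpoint once per node.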
@ma7dev

ma7dev commented May 12, 2022

I have an updated version of this example, based on the PyTorch documentation: https://github.com/sudomaze/ttorch/blob/main/examples/ddp/run.py

@1434AjaySingh

At line 50 (`output_device=args.local_rank`), I am getting the error below:

```
    output_device=args.local_rank
  File "/home/193112008/.conda/envs/XXX/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 285, in __init__
    self.broadcast_bucket_size)
  File "/home/193112008/.conda/envs/XXX/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 483, in _distributed_broadcast_coalesced
    dist._broadcast_coalesced(self.process_group, tensors, buffer_size)
RuntimeError: NCCL error in: /opt/conda/conda-bld/pytorch_1587428091666/work/torch/lib/c10d/ProcessGroupNCCL.cpp:32, unhandled cuda error, NCCL version 2.4.8
```

Can you please help me understand this error?

Many thanks.

@ma7dev

ma7dev commented Dec 25, 2022

Hey @1434AjaySingh,

I have updated the code at the link above; can you check it? In addition, if you need any help, we have a dedicated Discord server, PyTorch Community (unofficial), where people help each other troubleshoot PyTorch-related problems, learn machine learning and deep learning, and discuss ML/DL topics. Feel free to join via the link below:

https://discord.gg/eNSRmh92XT
