DistributedDataParallel in PyTorch
import argparse

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Parse the local_rank argument passed to the current process on the command line.
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()

# Set up the distributed backend that manages the distributed training.
torch.distributed.init_process_group('nccl')

# Set up the distributed sampler so each GPU sees a distinct shard of the dataset.
dist_sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, sampler=dist_sampler)

# Set the CUDA device to the GPU allocated to the current process.
device = torch.device('cuda', args.local_rank)
model = model.to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                  output_device=args.local_rank)

# Train the model as usual; DDP averages the gradients across processes during backward().
for inputs, labels in dataloader:
    inputs = inputs.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()          # clear gradients from the previous step
    preds = model(inputs)
    loss = loss_fn(preds, labels)
    loss.backward()
    optimizer.step()

# To start training, run the following command in a terminal:
# python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_port=1234 distributedDataParallel.py <OTHER TRAINING ARGS>
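
The script above refers to dataset, model, loss_fn, and optimizer without defining them. A minimal sketch of how they could be created is shown below; the CIFAR-10 dataset, ResNet-18 model, and SGD settings are illustrative assumptions, not part of the original gist, and these lines would go before the DistributedSampler is built.

import torch
import torchvision
from torchvision import transforms

# Hypothetical placeholders for the objects the main script assumes to exist.
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                       transform=transforms.ToTensor())
model = torchvision.models.resnet18(num_classes=10)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

Note that torch.distributed.launch spawns one process per GPU and passes --local_rank to each of them, which is why the script parses that argument itself.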