Using DistributedDataParallel
import argparse

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# model, dataset, loss_function and optimizer are assumed to be defined elsewhere.

# Each process runs on 1 GPU device specified by the local_rank argument.
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int)
args = parser.parse_args()

# Initializes the distributed backend, which takes care of synchronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')

# Encapsulate the model on the GPU assigned to the current process
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda', args.local_rank)
model = model.to(device)
distrib_model = torch.nn.parallel.DistributedDataParallel(model,
                                                           device_ids=[args.local_rank],
                                                           output_device=args.local_rank)

# Restricts data loading to a subset of the dataset exclusive to the current process
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler)

for inputs, labels in dataloader:
    optimizer.zero_grad()                                  # Reset gradients from the previous step
    predictions = distrib_model(inputs.to(device))         # Forward pass
    loss = loss_function(predictions, labels.to(device))   # Compute loss function
    loss.backward()                                        # Backward pass (gradients are all-reduced across processes)
    optimizer.step()                                       # Optimizer step
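
For context (not part of the original snippet), here is a minimal sketch of how the script might be launched and wrapped in an epoch loop; the file name train.py, the GPU count, and num_epochs are assumptions.

# Minimal sketch, assuming the script above is saved as train.py on a single
# 8-GPU node and launched with the torch.distributed.launch utility, which
# spawns one process per GPU and passes the --local_rank argument parsed above:
#
#   python -m torch.distributed.launch --nproc_per_node=8 train.py
#
# When training for several epochs, the sampler's epoch should be set so that
# shuffling differs from epoch to epoch:
for epoch in range(num_epochs):          # num_epochs is assumed to be defined elsewhere
    sampler.set_epoch(epoch)             # reshuffle the per-process partition of the data
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        predictions = distrib_model(inputs.to(device))
        loss = loss_function(predictions, labels.to(device))
        loss.backward()
        optimizer.step()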
Hi Tom, should we use DistributedSampler if we have 1 node with 8 GPUs and 24 CPUs when using DDP, or is it only for when we have more than one node? If we should use it, what would happen if we used a normal DataLoader without a DistributedSampler? Thanks a lot for the code snippet.
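
For illustration only (a sketch, not a reply from the thread): DistributedSampler is what gives each of the 8 processes a disjoint roughly 1/8 slice of the dataset, whereas a plain DataLoader would feed every process the full dataset, so the GPUs would redundantly train on the same samples. On a single node the local rank coincides with the global rank, so the partitioning could be made explicit like this:

# Minimal sketch, assuming 8 processes (one per GPU) on a single node:
sampler = DistributedSampler(dataset, num_replicas=8, rank=args.local_rank)
per_rank_loader = DataLoader(dataset, sampler=sampler)    # each rank sees ~len(dataset)/8 samples
full_loader = DataLoader(dataset, shuffle=True)           # every rank would see all len(dataset) samples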