single_gpu_ddp.py
# python single_gpu_ddp.py
# https://discuss.pytorch.org/t/single-machine-single-gpu-distributed-best-practices/169243
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
    # Configure the distributed environment.
    # 'gloo' can be used in environments where 'nccl' is not supported, such as CPU-only machines.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    # Tear down the distributed environment.
    dist.destroy_process_group()
def example(rank, world_size, device="cpu"):
    setup(rank, world_size)

    # Create a simple model and move it to the target device.
    model = nn.Linear(10, 1)
    model.to(device)

    # Wrap the model in DistributedDataParallel. device_ids is only valid for
    # CUDA devices, so construct DDP without it when running on the CPU.
    if str(device) == "cpu":
        ddp_model = DDP(model)
    else:
        ddp_model = DDP(model, device_ids=[device])

    # Dummy input data matching the model dimensions, placed on the same device.
    inputs = torch.randn(64, 10).to(device)
    targets = torch.randn(64, 1).to(device)

    # Forward pass
    outputs = ddp_model(inputs)

    # Compute the loss
    loss_fn = nn.MSELoss()
    loss = loss_fn(outputs, targets)

    # Backward pass. Note that DDP already averages gradients across processes
    # during backward(); the explicit all_reduce below repeats that
    # synchronization to make it visible.
    loss.backward()

    # Aggregate gradients using all_reduce with the SUM operation.
    for param in ddp_model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)

    # Scale the gradients by the number of processes.
    for param in ddp_model.parameters():
        param.grad.data /= world_size

    # Update the model parameters with plain SGD (learning rate 0.01).
    with torch.no_grad():
        for param in ddp_model.parameters():
            param.data -= 0.01 * param.grad.data

    print(loss.item())
    cleanup()
def main():
    # Set the number of processes to spawn: with the gloo backend they can all
    # share one GPU, or run on the CPU if no GPU is available.
    world_size = 2  # Example: use 2 processes
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    mp.spawn(example, args=(world_size, device), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()
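For reference, in a typical DDP training step the manual all_reduce, gradient scaling, and in-place parameter update above are not required: DDP already averages gradients across processes during backward(), so a standard optimizer can be applied directly. Below is a minimal sketch of that pattern; example_with_optimizer is a hypothetical name, it reuses setup() and cleanup() from the gist above, and it would be spawned the same way as example() in main().

# Sketch of the conventional DDP update. Assumes setup() and cleanup() from
# the gist above. DDP averages gradients across ranks during backward(), so
# no manual all_reduce or scaling is needed before the optimizer step.
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def example_with_optimizer(rank, world_size, device="cpu"):
    setup(rank, world_size)

    model = nn.Linear(10, 1).to(device)
    ddp_model = DDP(model) if str(device) == "cpu" else DDP(model, device_ids=[device])

    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    inputs = torch.randn(64, 10).to(device)
    targets = torch.randn(64, 1).to(device)

    optimizer.zero_grad()
    loss = loss_fn(ddp_model(inputs), targets)
    loss.backward()   # gradients are already averaged across processes here
    optimizer.step()

    print(loss.item())
    cleanup()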