"""
To run the benchmark, you would use mpirun_rsh like this:
For single-node multi-GPU:
mpirun_rsh <ENV_PARAMS> -np 2 python distributed_benchmark.py --task text --parallel_mode ddp
and for multi-node:
mpirun_rsh <ENV_PARAMS> -hostfile hosts -np 4 python distributed_benchmark.py --task vision --parallel_mode fsdp_full
This benchmarking script supports:
Different parallelization strategies:
- single: Single GPU baseline
- ddp: DistributedDataParallel
- fsdp_os: FSDP with optimizer state sharding only
- fsdp_ogs: FSDP with optimizer state and gradient sharding
- fsdp_full: FSDP with full parameter sharding
Two tasks:
- text: Small GPT-2 model
- vision: ResNet-50
Configurable training parameters:
- batch_size
- iterations
- epochs
- learning_rate
For the hostfile, create a text file listing your nodes, with N copies of each host for N GPUs per node. For example, with 2 nodes and 2 GPUs per node:
node1
node1
node2
node2
"""
import os
import argparse
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
    ShardingStrategy,
)
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2Config, GPT2LMHeadModel
from torchvision.models import resnet50


def setup_distributed():
    # Launched under mpirun_rsh, the MPI backend supplies rank and world size,
    # so no MASTER_ADDR/MASTER_PORT setup is needed here.
    dist.init_process_group(backend='mpi')
    # Map each rank to one GPU on its node. This assumes ranks are packed per node,
    # matching the hostfile layout described in the module docstring.
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())


class DummyTextDataset(Dataset):
    """Random token IDs; inputs double as labels (GPT-2 shifts them internally)."""

    def __init__(self, vocab_size=50257, seq_length=128):
        self.data = torch.randint(0, vocab_size, (1000, seq_length))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.data[idx]


class DummyImageDataset(Dataset):
    """Random 224x224 RGB images with random 1000-class labels."""

    def __init__(self):
        self.data = torch.randn(1000, 3, 224, 224)
        self.labels = torch.randint(0, 1000, (1000,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


def get_model_and_data(args):
    if args.task == 'text':
        # Small GPT-2: 4 layers, 4 heads, 256-dim embeddings.
        config = GPT2Config(n_layer=4, n_head=4, n_embd=256)
        model = GPT2LMHeadModel(config)
        dataset = DummyTextDataset()
        # Unused for text (GPT2LMHeadModel computes the LM loss internally when
        # labels are passed), but kept so the return signature is uniform.
        criterion = nn.CrossEntropyLoss()
    else:  # vision
        model = resnet50()
        dataset = DummyImageDataset()
        criterion = nn.CrossEntropyLoss()
    return model, dataset, criterion


def train(args):
    setup_distributed()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    model, dataset, criterion = get_model_and_data(args)
    model = model.cuda()

    if args.parallel_mode == 'ddp':
        model = DDP(model, device_ids=[rank % torch.cuda.device_count()])
    elif args.parallel_mode.startswith('fsdp'):
        if args.parallel_mode == 'fsdp_os':  # optimizer states only
            # FSDP has no dedicated "optimizer states only" (ZeRO-1) strategy;
            # NO_SHARD is the closest stand-in, but note it replicates parameters,
            # gradients, and optimizer states like DDP does.
            sharding_strategy = ShardingStrategy.NO_SHARD
        elif args.parallel_mode == 'fsdp_ogs':  # optimizer states + gradients (ZeRO-2)
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
        else:  # full parameter, gradient, and optimizer state sharding (ZeRO-3)
            sharding_strategy = ShardingStrategy.FULL_SHARD
        model = FSDP(
            model,
            sharding_strategy=sharding_strategy,
            cpu_offload=CPUOffload(offload_params=False),
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        )
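        # Note: no auto_wrap_policy is given, so FSDP flattens the whole model into
        # a single unit, which limits the memory savings from parameter sharding.
        # For larger models one would typically pass an auto_wrap_policy
        # (e.g. transformer_auto_wrap_policy from torch.distributed.fsdp.wrap).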

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # Give each rank a distinct, non-overlapping slice of the dataset.
    sampler = torch.utils.data.DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=4,
    )

    # Training loop
    model.train()
    for epoch in range(args.epochs):
        sampler.set_epoch(epoch)  # reshuffle shards each epoch
        for i, (data, target) in enumerate(dataloader):
            if i >= args.iterations:
                break
            data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            # Text: the HF model returns the loss when labels are passed.
            # Vision: compute cross-entropy against the integer labels.
            output = model(data) if args.task == 'vision' else model(data, labels=target)
            loss = output.loss if args.task == 'text' else criterion(output, target)
            loss.backward()
            optimizer.step()
            if rank == 0 and i % 10 == 0:
                print(f"Epoch {epoch}, Iteration {i}, Loss: {loss.item():.4f}")
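
    # Tear down the process group so all ranks and mpirun_rsh exit cleanly
    # (assumes every rank reaches the end of training).
    dist.destroy_process_group()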


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, choices=['text', 'vision'], default='text')
    parser.add_argument('--parallel_mode', type=str,
                        choices=['single', 'ddp', 'fsdp_os', 'fsdp_ogs', 'fsdp_full'],
                        default='single')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--iterations', type=int, default=100)
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    args = parser.parse_args()
    # 'single' follows the same code path as the distributed modes, just launched
    # with -np 1, so a single call covers every mode.
    train(args)


if __name__ == "__main__":
    main()