""" | |
To run the benchmark, you would use mpirun_rsh like this: | |
For single-node multi-GPU: | |
mpirun_rsh <ENV_PARAMS> -np 2 python distributed_benchmark.py --task text --parallel_mode ddp | |
and for multi-node: | |
mpirun_rsh <ENV_PARAMS> -hostfile hosts -np 4 python distributed_benchmark.py --task vision --parallel_mode fsdp_full | |
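
The single-GPU baseline (no DDP/FSDP wrapping) can be launched the same way with a
single rank, for example (<ENV_PARAMS> again stands in for your MPI environment settings):
    mpirun_rsh <ENV_PARAMS> -np 1 python distributed_benchmark.py --task text --parallel_mode single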

This benchmarking script supports:

Different parallelization strategies (--parallel_mode):
- single: Single GPU baseline
- ddp: DistributedDataParallel
- fsdp_os: FSDP with optimizer state sharding only
- fsdp_ogs: FSDP with optimizer state and gradient sharding
- fsdp_full: FSDP with full parameter sharding

Two tasks (--task):
- text: Small GPT-2 model
- vision: ResNet-50

Configurable training parameters:
- batch_size
- iterations
- epochs
- learning_rate
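
For example, a multi-node run that overrides all of these defaults (the flag values
here are only illustrative; <ENV_PARAMS> and the hostfile depend on your cluster):
    mpirun_rsh <ENV_PARAMS> -hostfile hosts -np 4 python distributed_benchmark.py \
        --task text --parallel_mode fsdp_full --batch_size 16 --iterations 50 --epochs 2 --learning_rate 5e-5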

For the hostfile, create a text file listing your nodes, with N copies of each host
for N GPUs per node (one MPI rank runs per line, so -np should equal the total number
of lines). For example, with 2 nodes and 2 GPUs per node:
node1
node1
node2
node2
"""

import argparse
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
    ShardingStrategy,
)
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2Config, GPT2LMHeadModel
from torchvision.models import resnet50


def setup_distributed():
    # Initialize the MPI process group (the script is launched via mpirun_rsh) and
    # bind each rank to a local GPU: global rank modulo the number of GPUs per node.
    dist.init_process_group(backend='mpi')
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())


class DummyTextDataset(Dataset):
    def __init__(self, vocab_size=50257, seq_length=128):
        self.data = torch.randint(0, vocab_size, (1000, seq_length))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.data[idx]


class DummyImageDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(1000, 3, 224, 224)
        self.labels = torch.randint(0, 1000, (1000,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


def get_model_and_data(args):
    if args.task == 'text':
        config = GPT2Config(n_layer=4, n_head=4, n_embd=256)
        model = GPT2LMHeadModel(config)
        dataset = DummyTextDataset()
        criterion = nn.CrossEntropyLoss()
    else:  # vision
        model = resnet50()
        dataset = DummyImageDataset()
        criterion = nn.CrossEntropyLoss()
    return model, dataset, criterion


def train(args):
    setup_distributed()
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    model, dataset, criterion = get_model_and_data(args)
    model = model.cuda()

    if args.parallel_mode == 'ddp':
        model = DDP(model, device_ids=[rank % torch.cuda.device_count()])
    elif args.parallel_mode.startswith('fsdp'):
        if args.parallel_mode == 'fsdp_os':
            # NOTE: FSDP exposes no optimizer-state-only (ZeRO-1) strategy;
            # NO_SHARD replicates params, grads, and optimizer states (DDP-like).
            sharding_strategy = ShardingStrategy.NO_SHARD
        elif args.parallel_mode == 'fsdp_ogs':  # optimizer states + gradients (ZeRO-2)
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
        else:  # full parameter sharding (ZeRO-3)
            sharding_strategy = ShardingStrategy.FULL_SHARD
        model = FSDP(
            model,
            sharding_strategy=sharding_strategy,
            cpu_offload=CPUOffload(offload_params=False),
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        )
    # 'single' mode leaves the model unwrapped as the single-GPU baseline.

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    sampler = torch.utils.data.DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=4,
    )

    # Training loop
    model.train()
    for epoch in range(args.epochs):
        sampler.set_epoch(epoch)
        for i, (data, target) in enumerate(dataloader):
            if i >= args.iterations:
                break
            data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            # GPT-2 computes its own LM loss when labels are passed; the vision
            # task uses the external CrossEntropyLoss criterion.
            output = model(data) if args.task == 'vision' else model(data, labels=target)
            loss = output.loss if args.task == 'text' else criterion(output, target)
            loss.backward()
            optimizer.step()
            if rank == 0 and i % 10 == 0:
                print(f"Epoch {epoch}, Iteration {i}, Loss: {loss.item():.4f}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, choices=['text', 'vision'], default='text')
    parser.add_argument('--parallel_mode', type=str,
                        choices=['single', 'ddp', 'fsdp_os', 'fsdp_ogs', 'fsdp_full'],
                        default='single')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--iterations', type=int, default=100)
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    args = parser.parse_args()

    # All modes share the same entry point; 'single' simply skips the DDP/FSDP wrapping.
    train(args)


if __name__ == "__main__":
    main()