PyTorch Distributed Training (DDP) Demo
# encoding: utf8

import os
import time
import random
import contextlib

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def transform(tensors: list):
    # Collate function: TensorDataset yields 1-tuples, so stack the first element of each.
    return torch.stack([t[0] for t in tensors])


class Model(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(10, 100)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(100, 20)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        return self.fc2(x)


def ddp_demo(rank, world_size, accum_grad=4):
    assert dist.is_gloo_available(), "Gloo is not available!"
    print(f"world_size: {world_size}, rank: {rank}, is_gloo_available: {dist.is_gloo_available()}")

    # 1. Initialize the process group
    dist.init_process_group("gloo", world_size=world_size, rank=rank)

    # model = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 20))
    model = Model()

    # 2. Wrap the model with DistributedDataParallel
    ddp_model = DistributedDataParallel(model)

    criterion = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=1e-3)

    dataset = TensorDataset(torch.randn(1000, 10))
    # 3. Data parallelism (the sampler partitions the data by rank internally)
    sampler = DistributedSampler(dataset=dataset, num_replicas=world_size, shuffle=True)
    dataloader = DataLoader(dataset=dataset, batch_size=24, sampler=sampler, collate_fn=transform)

    for epoch in range(1):
        for step, batch in enumerate(dataloader):
            output = ddp_model(batch)
            label = torch.rand_like(output)

            if step % accum_grad == 0:
                # Synchronize gradients across ranks during backward
                context = contextlib.nullcontext
            else:
                # 4. Gradient accumulation: skip gradient synchronization
                context = ddp_model.no_sync

            with context():
                time.sleep(random.random())
                loss = criterion(output, label)
                loss.backward()

            if step % accum_grad == 0:
                optimizer.step()
                optimizer.zero_grad()
                print(f"epoch: {epoch}, step: {step}, rank: {rank} update parameters.")

    # 5. Destroy the process group (cleans up its global state)
    dist.destroy_process_group()


def main():
    world_size = 8
    mp.spawn(ddp_demo, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    main()
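
To try the demo, save the script locally and run it directly (the filename ddp_demo.py below is only an assumption for illustration). No torchrun launcher is needed: mp.spawn starts the 8 worker processes itself, and MASTER_ADDR/MASTER_PORT are already set in the __main__ block.

python ddp_demo.py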