PyTorch Distributed Training DDP Demo
# encoding: utf8

import os
import time
import random
import contextlib

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def transform(tensors: list):
    # TensorDataset yields each sample as a 1-tuple; stack the tensors into a mini-batch.
    return torch.stack([t[0] for t in tensors])


class Model(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(10, 100)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(100, 20)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        return self.fc2(x)


def ddp_demo(rank, world_size, accum_grad=4):
    assert dist.is_gloo_available(), "Gloo is not available!"
    print(f"world_size: {world_size}, rank: {rank}, is_gloo_available: {dist.is_gloo_available()}")

    # 1. Initialize the process group.
    dist.init_process_group("gloo", world_size=world_size, rank=rank)

    # model = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 20))
    model = Model()

    # 2. Wrap the model for distributed data parallelism.
    ddp_model = DistributedDataParallel(model)

    criterion = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=1e-3)

    dataset = TensorDataset(torch.randn(1000, 10))

    # 3. Distributed data sampling (each rank draws its own shard internally).
    sampler = DistributedSampler(dataset=dataset, num_replicas=world_size, shuffle=True)
    dataloader = DataLoader(dataset=dataset, batch_size=24, sampler=sampler, collate_fn=transform)
    for epoch in range(1):
        # Re-seed the sampler so shuffling differs per epoch across ranks.
        sampler.set_epoch(epoch)
        for step, batch in enumerate(dataloader):
            output = ddp_model(batch)
            label = torch.rand_like(output)
            if (step + 1) % accum_grad == 0:
                # Synchronize gradients across ranks on this backward pass.
                context = contextlib.nullcontext
            else:
                # 4. Gradient accumulation: skip the all-reduce on this step.
                context = ddp_model.no_sync
            with context():
                time.sleep(random.random())  # simulate extra per-step work
                loss = criterion(output, label)
                loss.backward()
            if (step + 1) % accum_grad == 0:
                optimizer.step()
                optimizer.zero_grad()
                print(f"epoch: {epoch}, step: {step}, rank: {rank} update parameters.")

    # 5. Destroy the process group context (clears some module-level globals).
    dist.destroy_process_group()


def main():
    world_size = 8
    mp.spawn(ddp_demo, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    main()
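As an aside (not part of the original gist): instead of `mp.spawn` with hand-set `MASTER_ADDR`/`MASTER_PORT`, the same demo could be launched with `torchrun`, which exports `RANK`, `WORLD_SIZE`, and the master address/port itself. A minimal sketch, assuming the training loop from `ddp_demo` is reused and the file name is hypothetical:

```python
# Launch with: torchrun --nproc_per_node=8 ddp_demo_torchrun.py
# (hypothetical file name; torchrun fills in RANK/WORLD_SIZE/MASTER_* itself)
import torch.distributed as dist


def main():
    dist.init_process_group("gloo")  # rank and world size come from the environment
    rank, world_size = dist.get_rank(), dist.get_world_size()
    print(f"rank {rank}/{world_size} initialized")
    # ... build the model/dataloader and train as in ddp_demo() above ...
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```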
Steps to convert a model to distributed training:

- Initialize the process group: `dist.init_process_group`.
- Wrap the model for distributed data parallelism: `DistributedDataParallel(model)`.
- Shard the data: split the dataset into `world_size` parts and sample by `rank` via `DistributedSampler(dataset=dataset, num_replicas=world_size, shuffle=True)`.
- Optionally accumulate gradients during training to lower the parameter-sync frequency between processes and improve communication efficiency (see the sketch after this list).
- Destroy the process group: `dist.destroy_process_group()`.
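The gradient-accumulation step boils down to toggling DDP's `no_sync()` context. The following is a minimal sketch, assuming `ddp_model`, `criterion`, `optimizer`, and `dataloader` have already been built as in the demo:

```python
import contextlib

import torch

accum_grad = 4  # micro-batches per optimizer step

for step, batch in enumerate(dataloader):
    sync_now = (step + 1) % accum_grad == 0
    # no_sync() suppresses DDP's gradient all-reduce; .grad buffers keep
    # accumulating locally until the next synchronized backward pass.
    context = contextlib.nullcontext if sync_now else ddp_model.no_sync
    with context():
        output = ddp_model(batch)
        loss = criterion(output, torch.rand_like(output))  # random target, as in the demo
        loss.backward()
    if sync_now:
        optimizer.step()       # one parameter update per accum_grad micro-batches
        optimizer.zero_grad()
```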
The `collate_fn` argument is used to merge individual samples into a mini-batch.
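Concretely, `TensorDataset` yields each sample as a 1-tuple, which is why the `transform` collate above stacks `t[0]`; a quick check:

```python
import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.randn(1000, 10))
samples = [dataset[i] for i in range(3)]      # each sample is a 1-tuple: (tensor,)
batch = torch.stack([t[0] for t in samples])  # what transform() does
print(batch.shape)                            # torch.Size([3, 10])
```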