Minimal example of gradient accumulation, allreducing only on step() iterations and interacting properly with torch.cuda.amp
# For single-node, run this script via
# python -m torch.distributed.launch --nproc_per_node=<ngpus this node> example.py
#
# For multinode, see https://pytorch.org/docs/stable/distributed.html#launch-utility
#
# Example showing native mixed precision tools
# (torch.cuda.amp.GradScaler and torch.cuda.amp.autocast)
# used along with native DistributedDataParallel to perform
# gradient accumulation with allreduces only when stepping.
#
# The key takeaway is, each of those tools is used orthogonally
# (just as it would be in the absence of the others).
# There are no gotchas combining them.

import torch
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()

args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

if args.distributed:
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')

# Fake data (different in each process)
torch.manual_seed(args.local_rank)
N, D_in, D_out = 64, 1024, 16
x = torch.randn(N, D_in, device='cuda')
y = torch.randn(N, D_out, device='cuda')

model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()

if args.distributed:
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[args.local_rank],
                                                      output_device=args.local_rank)

iters_to_accumulate = 4

# Helper to print gradient values for debugging
# torch.set_printoptions(precision=10)
# def debug_grad_info(t, should_match_across_ranks):
#     string = ""
#     for name, param in model.named_parameters():
#         string += "iter = {}, rank = {}, should match across ranks = {}, {}.grad sum = {}\n".format(
#             t, args.local_rank, should_match_across_ranks, name, param.grad.double().sum().item())
#     print(string, flush=True)

def run_fwd_bwd():
    # Runs the forward pass under autocast.
    with torch.cuda.amp.autocast():
        y_pred = model(x)
        # You may wish to divide loss by iters_to_accumulate to average
        # across the effective (accumulated) global batch.
        loss = loss_fn(y_pred, y)/iters_to_accumulate
    scaler.scale(loss).backward()

for t in range(20):
    if (t + 1) % iters_to_accumulate == 0:
        # We will step() this iteration, so don't run forward and backward under no_sync.
        # Allow allreduces to happen.
        run_fwd_bwd()
        # Grads DO match across ranks at this point, ready to step
        # debug_grad_info(t, True)
        scaler.step(optimizer)
        optimizer.zero_grad(set_to_none=True)
        # Only call scaler.update() for iterations where we actually step()ed, as in
        # https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation
        scaler.update()
    else:
        # We're not stepping this iteration, so use no_sync to prevent DDP allreduces.
        # It appears we need to run forward and backward under no_sync()
        # to get the right no-allreduce behavior.
        with model.no_sync():
            run_fwd_bwd()
        # Grads don't match across ranks at this point.
        # debug_grad_info(t, False)

# Double-check that param values are identical across ranks
string = ""
for name, param in model.named_parameters():
    string += "rank = {}, {} sum = {}\n".format(
        args.local_rank, name, param.double().sum().item())
print(string, flush=True)
Thanks for this gist!
Is the goal of using model.no_sync() to avoid the overhead of synchronizing gradients when we're not performing an update in this step? That is what I understood from the documentation here.
If so, this should not affect the synchronization of batchnorm statistics in the forward pass when using SyncBatchNorm, is that correct?
Thanks!
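For context on that question: as I understand it, no_sync() only suppresses DDP's gradient allreduce during backward, while SyncBatchNorm synchronizes batch statistics with its own collectives in the forward pass, so the two are independent. Below is a minimal, hypothetical sketch (not part of the original gist) combining convert_sync_batchnorm with no_sync(); it assumes the same torch.distributed.launch setup and args.local_rank as the script above, and the bn_model / images names are made up for illustration.

# Hypothetical sketch: SyncBatchNorm + no_sync(), assuming the distributed
# setup from the script above (process group initialized, args.local_rank set).
import torch

bn_model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
).cuda()

# Replace BatchNorm2d layers with SyncBatchNorm so forward-pass statistics
# are computed across all ranks.
bn_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bn_model)
bn_model = torch.nn.parallel.DistributedDataParallel(
    bn_model, device_ids=[args.local_rank], output_device=args.local_rank)

images = torch.randn(8, 3, 32, 32, device='cuda')

# no_sync() only skips DDP's gradient allreduce in backward.
# SyncBatchNorm still communicates batch statistics during this forward pass.
with bn_model.no_sync():
    out = bn_model(images)
    out.mean().backward()  # gradients accumulate locally, no allreduce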