@ptrblck
Created April 8, 2020 10:27
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from apex.parallel import SyncBatchNorm as ApexSyncBatchNorm
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--apex', action='store_true')
args = parser.parse_args()
torch.manual_seed(2809)
# Setup DDP
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda:{}'.format(args.local_rank))
torch.distributed.init_process_group(
    'nccl',
    init_method='env://',
    rank=args.local_rank,
)
# Setup model
if args.apex:
    model = nn.Sequential(
        nn.Conv2d(3, 6, 3, 1, 1),
        ApexSyncBatchNorm(6)
    )
else:
    model = nn.Sequential(
        nn.Conv2d(3, 6, 3, 1, 1),
        nn.SyncBatchNorm(6)
    )
# Setup reference model
model_reference = nn.Sequential(
    nn.Conv2d(3, 6, 3, 1, 1),
    nn.BatchNorm2d(6)
)
with torch.no_grad():
    model_reference[0].weight.copy_(model[0].weight)
    model_reference[0].bias.copy_(model[0].bias)
model_reference.to(device)
# Setup SyncBN
#if not args.apex:
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = model.to(device)
model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
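# Alternative setup (sketch, not used for the outputs posted below): keep plain
# BatchNorm2d layers in the model and let convert_sync_batchnorm swap them for
# SyncBatchNorm in place before wrapping in DDP, as hinted at by the
# commented-out lines above:
#
# model = nn.Sequential(
#     nn.Conv2d(3, 6, 3, 1, 1),
#     nn.BatchNorm2d(6)
# )
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
# model = model.to(device)
# model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)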
# Create random data
if args.local_rank == 0:
    data = torch.randn(16, 3, 24, 24, device=device) * 100
else:
    data = torch.randn(8, 3, 24, 24, device=device)
print('Input.sum() {}, .mean() {}, .std() {}, .min() {}, .max() {}, device {}'.format(
    data.sum(), data.mean(), data.std(), data.min(), data.max(), data.device))
# DDP forward/backward
output = model(data)
print('DDP output.sum() {}, .mean() {}, .std() {}, .min() {}, .max() {}, device {}'.format(
    output.sum(), output.mean(), output.std(), output.min(), output.max(), output.device))
output.sum().backward()
# Reference forward/backward
output_reference = model_reference(data)
print('Reference output.sum() {}, .mean() {}, .std() {}, .min() {}, .max() {}, device {}'.format(
    output_reference.sum(), output_reference.mean(), output_reference.std(), output_reference.min(),
    output_reference.max(), output_reference.device))
output_reference.sum().backward()
# Print stats
print('DDP stats ', model.module[1].running_mean, model.module[1].running_var)
print('Reference stats ', model_reference[1].running_mean, model_reference[1].running_var)
print('DDP grads ', model.module[0].weight.grad.abs().sum())
print('Reference grads ', model_reference[0].weight.grad.abs().sum())

ptrblck commented Apr 8, 2020

Vanilla output

root@cd8d754a6069:/workspace/src# python -m torch.distributed.launch --nproc_per_node=4 repro.py
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Input.sum() 22933.48828125, .mean() 0.8294809460639954, .std() 100.45414733886719, .min() -408.5179443359375, .max() 426.8013000488281, device cuda:0
Input.sum() -9.270572662353516, .mean() -0.0006706143612973392, .std() 0.992488443851471, .min() -3.8868343830108643, .max() 3.5236613750457764, device cuda:1
Input.sum() -9.270572662353516, .mean() -0.0006706143612973392, .std() 0.992488443851471, .min() -3.8868343830108643, .max() 3.5236613750457764, device cuda:2
Input.sum() -9.270572662353516, .mean() -0.0006706143612973392, .std() 0.992488443851471, .min() -3.8868343830108643, .max() 3.5236613750457764, device cuda:3
DDP output.sum() -51.18982696533203, .mean() -0.0009257419733330607, .std() 1.581018328666687, .min() -7.145139694213867, .max() 7.250021934509277, device cuda:0
DDP output.sum() 17.063417434692383, .mean() 0.0006171664572320879, .std() 0.016831880435347557, .min() -0.06054949387907982, .max() 0.0813681036233902, device cuda:1
DDP output.sum() 17.063417434692383, .mean() 0.0006171664572320879, .std() 0.016831880435347557, .min() -0.06054949387907982, .max() 0.0813681036233902, device cuda:2
DDP output.sum() 17.063417434692383, .mean() 0.0006171664572320879, .std() 0.016831880435347557, .min() -0.06054949387907982, .max() 0.0813681036233902, device cuda:3
Reference output.sum() 5.340576171875e-05, .mean() 9.65815982745255e-10, .std() 1.0000090599060059, .min() -4.512101173400879, .max() 4.594338417053223, device cuda:0
Reference output.sum() -0.000179290771484375, .mean() -6.4847647252008755e-09, .std() 1.0000003576278687, .min() -3.7496793270111084, .max() 4.66790246963501, device cuda:1
Reference output.sum() -0.000179290771484375, .mean() -6.4847647252008755e-09, .std() 1.0000003576278687, .min() -3.7496793270111084, .max() 4.66790246963501, device cuda:2
Reference output.sum() -0.000179290771484375, .mean() -6.4847647252008755e-09, .std() 1.0000003576278687, .min() -3.7496793270111084, .max() 4.66790246963501, device cuda:3
DDP stats  tensor([ 0.0255, -0.0082, -0.0093, -0.0104, -0.0158, -0.0472], device='cuda:0') tensor([111.1794,  87.8819, 112.0774, 130.4485, 135.5200, 140.0746],
       device='cuda:0')
DDP stats  tensor([ 0.0255, -0.0082, -0.0093, -0.0104, -0.0158, -0.0472], device='cuda:1') tensor([111.1794,  87.8819, 112.0774, 130.4485, 135.5200, 140.0746],
       device='cuda:1')
DDP stats  tensor([ 0.0255, -0.0082, -0.0093, -0.0104, -0.0158, -0.0472], device='cuda:2') tensor([111.1794,  87.8819, 112.0774, 130.4485, 135.5200, 140.0746],
       device='cuda:2')
DDP stats  tensor([ 0.0255, -0.0082, -0.0093, -0.0104, -0.0158, -0.0472], device='cuda:3') tensor([111.1794,  87.8819, 112.0774, 130.4485, 135.5200, 140.0746],
       device='cuda:3')
Reference stats  tensor([ 0.0710,  0.0012,  0.0002, -0.0112, -0.0646, -0.0910], device='cuda:0') tensor([276.5405, 218.3357, 278.8197, 324.7449, 337.3835, 348.7750],
       device='cuda:0')
Reference stats  tensor([-0.0049, -0.0144, -0.0156, -0.0099,  0.0168, -0.0180], device='cuda:1') tensor([0.9276, 0.9211, 0.9270, 0.9317, 0.9325, 0.9348], device='cuda:1')
Reference stats  tensor([-0.0049, -0.0144, -0.0156, -0.0099,  0.0168, -0.0180], device='cuda:2') tensor([0.9276, 0.9211, 0.9270, 0.9317, 0.9325, 0.9348], device='cuda:2')
Reference stats  tensor([-0.0049, -0.0144, -0.0156, -0.0099,  0.0168, -0.0180], device='cuda:3') tensor([0.9276, 0.9211, 0.9270, 0.9317, 0.9325, 0.9348], device='cuda:3')
DDP grads  tensor(512.7340, device='cuda:0')
DDP grads  tensor(512.7340, device='cuda:1')
DDP grads  tensor(512.7340, device='cuda:2')
DDP grads  tensor(512.7340, device='cuda:3')
Reference grads  tensor(0.0002, device='cuda:0')
Reference grads  tensor(0.0008, device='cuda:1')
Reference grads  tensor(0.0008, device='cuda:2')
Reference grads  tensor(0.0008, device='cuda:3')

Apex output

root@cd8d754a6069:/workspace/src# python -m torch.distributed.launch --nproc_per_node=4 repro.py --apex
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Input.sum() 22933.48828125, .mean() 0.8294809460639954, .std() 100.45414733886719, .min() -408.5179443359375, .max() 426.8013000488281, device cuda:0
Input.sum() -9.270572662353516, .mean() -0.0006706143612973392, .std() 0.992488443851471, .min() -3.8868343830108643, .max() 3.5236613750457764, device cuda:1
Input.sum() -9.270572662353516, .mean() -0.0006706143612973392, .std() 0.992488443851471, .min() -3.8868343830108643, .max() 3.5236613750457764, device cuda:2
Input.sum() -9.270572662353516, .mean() -0.0006706143612973392, .std() 0.992488443851471, .min() -3.8868343830108643, .max() 3.5236613750457764, device cuda:3
DDP output.sum() -80.93080139160156, .mean() -0.0014635922852903605, .std() 1.99970543384552, .min() -9.040897369384766, .max() 9.16562557220459, device cuda:0
DDP output.sum() 13.488600730895996, .mean() 0.0004878689651377499, .std() 0.02039550431072712, .min() -0.07519855350255966, .max() 0.09870573878288269, device cuda:1
DDP output.sum() 13.488600730895996, .mean() 0.0004878689651377499, .std() 0.02039550431072712, .min() -0.07519855350255966, .max() 0.09870573878288269, device cuda:2
DDP output.sum() 13.488600730895996, .mean() 0.0004878689651377499, .std() 0.02039550431072712, .min() -0.07519855350255966, .max() 0.09870573878288269, device cuda:3
Reference output.sum() 5.340576171875e-05, .mean() 9.65815982745255e-10, .std() 1.0000090599060059, .min() -4.512101173400879, .max() 4.594338417053223, device cuda:0
Reference output.sum() -0.000179290771484375, .mean() -6.4847647252008755e-09, .std() 1.0000003576278687, .min() -3.7496793270111084, .max() 4.66790246963501, device cuda:1
Reference output.sum() -0.000179290771484375, .mean() -6.4847647252008755e-09, .std() 1.0000003576278687, .min() -3.7496793270111084, .max() 4.66790246963501, device cuda:2
Reference output.sum() -0.000179290771484375, .mean() -6.4847647252008755e-09, .std() 1.0000003576278687, .min() -3.7496793270111084, .max() 4.66790246963501, device cuda:3
DDP stats  tensor([ 0.0141, -0.0105, -0.0117, -0.0102, -0.0036, -0.0363], device='cuda:0') tensor([69.8360, 55.2708, 70.3950, 81.8784, 85.0508, 87.8978], device='cuda:0')
DDP stats  tensor([ 0.0141, -0.0105, -0.0117, -0.0102, -0.0036, -0.0363], device='cuda:1') tensor([69.8378, 55.2722, 70.3969, 81.8806, 85.0531, 87.9001], device='cuda:1')
DDP stats  tensor([ 0.0141, -0.0105, -0.0117, -0.0102, -0.0036, -0.0363], device='cuda:2') tensor([69.8378, 55.2722, 70.3969, 81.8806, 85.0531, 87.9001], device='cuda:2')
DDP stats  tensor([ 0.0141, -0.0105, -0.0117, -0.0102, -0.0036, -0.0363], device='cuda:3') tensor([69.8378, 55.2722, 70.3969, 81.8806, 85.0531, 87.9001], device='cuda:3')
Reference stats  tensor([ 0.0710,  0.0012,  0.0002, -0.0112, -0.0646, -0.0910], device='cuda:0') tensor([276.5405, 218.3357, 278.8197, 324.7449, 337.3835, 348.7750],
       device='cuda:0')
Reference stats  tensor([-0.0049, -0.0144, -0.0156, -0.0099,  0.0168, -0.0180], device='cuda:1') tensor([0.9276, 0.9211, 0.9270, 0.9317, 0.9325, 0.9348], device='cuda:1')
Reference stats  tensor([-0.0049, -0.0144, -0.0156, -0.0099,  0.0168, -0.0180], device='cuda:2') tensor([0.9276, 0.9211, 0.9270, 0.9317, 0.9325, 0.9348], device='cuda:2')
Reference stats  tensor([-0.0049, -0.0144, -0.0156, -0.0099,  0.0168, -0.0180], device='cuda:3') tensor([0.9276, 0.9211, 0.9270, 0.9317, 0.9325, 0.9348], device='cuda:3')
DDP grads  tensor(0.0019, device='cuda:0')
DDP grads  tensor(0.0019, device='cuda:1')
DDP grads  tensor(0.0019, device='cuda:2')
DDP grads  tensor(0.0019, device='cuda:3')
Reference grads  tensor(0.0002, device='cuda:0')
Reference grads  tensor(0.0008, device='cuda:1')
Reference grads  tensor(0.0008, device='cuda:2')
Reference grads  tensor(0.0008, device='cuda:3')
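
To turn the visual comparison above into a hard check, one could append something like the following to repro.py (a sketch assuming the variable names from the script, not part of the original gist): gather each rank's SyncBatchNorm running stats and assert they agree across the process group.

import torch
import torch.distributed as dist

def assert_bn_stats_in_sync(bn, atol=1e-2):
    # Gather running_mean/running_var from every rank and compare against rank 0.
    # The loose tolerance accounts for the small per-rank differences visible in
    # the apex output above (e.g. 69.8360 vs. 69.8378).
    stats = torch.cat([bn.running_mean, bn.running_var])
    gathered = [torch.empty_like(stats) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, stats)
    for other in gathered[1:]:
        assert torch.allclose(gathered[0], other, atol=atol), 'running stats diverged across ranks'

# Usage at the end of repro.py, after the backward pass:
# assert_bn_stats_in_sync(model.module[1])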
