# Gist by @ptrblck, created July 11, 2019
# https://gist.github.com/ptrblck/ab45bfcde6df55ac28a7be18531f4718
# Tests nn.SyncBatchNorm under DistributedDataParallel with a half-precision input.
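#
# A typical launch on a single node with N GPUs (assuming the
# torch.distributed.launch helper, which sets the env:// rendezvous variables
# and passes --local_rank to each process; the script filename is a placeholder):
#   python -m torch.distributed.launch --nproc_per_node=N sync_bn_test.py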
import argparse

import torch
import torch.nn as nn


def print_types(input, bn, output):
    # Report the dtypes of the input, the batchnorm parameters and buffers, and the output.
    print('input.type(): {}'.format(input.type()))
    if bn.weight is not None:
        print('bn.weight.type(): {}'.format(bn.weight.type()))
    else:
        print('bn.weight is None')
    if bn.bias is not None:
        print('bn.bias.type(): {}'.format(bn.bias.type()))
    else:
        print('bn.bias is None')
    print('bn.running_mean.type(): {}'.format(bn.running_mean.type()))
    print('bn.running_var.type(): {}'.format(bn.running_var.type()))
    print('output.type(): {}'.format(output.type()))


device = 'cuda'

# Parse the local rank passed in by the launcher and bind this process to its GPU.
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
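
# The cases below feed a float16 input through (Sync)BatchNorm in four
# configurations, affine=True/False crossed with train/eval mode, and report
# the resulting dtypes or the raised RuntimeError.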
bn = nn.BatchNorm2d(3, affine=True).cuda()
# Add a dummy conv layer so that DDP gets some valid parameters
# (the batchnorm stays in float32 while the conv is converted to half).
sbn = nn.SyncBatchNorm.convert_sync_batchnorm(
    nn.Sequential(bn, nn.Conv2d(3, 3, 3, 1, 1).cuda().half()))
sbn = nn.parallel.DistributedDataParallel(
    sbn, device_ids=[args.local_rank], output_device=args.local_rank)
x = torch.randn(16, 3, 24, 24, dtype=torch.half, device='cuda')

print('\nsbn, affine=True, train')
try:
    output = sbn(x)
    print_types(x, sbn.module[0], output)
except RuntimeError as e:
    print('RuntimeError: ', e)

print('\nsbn, affine=True, eval')
sbn.eval()
try:
    output = sbn(x)
    print_types(x, sbn.module[0], output)
except RuntimeError as e:
    print('RuntimeError: ', e)

x = torch.randn(16, 3, 24, 24, dtype=torch.half, device='cuda')

print('\nsbn, affine=False, train')
bn = nn.BatchNorm2d(3, affine=False).cuda()
sbn = nn.SyncBatchNorm.convert_sync_batchnorm(
    nn.Sequential(bn, nn.Conv2d(3, 3, 3, 1, 1).cuda().half()))
sbn = nn.parallel.DistributedDataParallel(
    sbn, device_ids=[args.local_rank], output_device=args.local_rank)
try:
    output = sbn(x)
    print_types(x, sbn.module[0], output)
except RuntimeError as e:
    print('RuntimeError: ', e)

print('\nsbn, affine=False, eval')
sbn.eval()
try:
    output = sbn(x)
    print_types(x, sbn.module[0], output)
except RuntimeError as e:
    print('RuntimeError: ', e)
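
# Optional: tear down the NCCL process group so each launched process exits cleanly.
torch.distributed.destroy_process_group()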