Created
July 27, 2018 18:35
-
-
Save alsrgv/13d6f579cc5e1e3389219e7c72d9fd4a to your computer and use it in GitHub Desktop.
1-late SGD for PyTorch ImageNet example with Horovod
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import argparse | |
import torch.backends.cudnn as cudnn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
import torch.utils.data.distributed | |
from torchvision import datasets, transforms, models | |
import horovod.torch as hvd | |
import tensorboardX | |
import os | |
from tqdm import tqdm | |
# Training settings
parser = argparse.ArgumentParser(
    description='PyTorch ImageNet Example',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Locations of the input data and of the artifacts written during training.
parser.add_argument(
    '--train-dir',
    default=os.path.expanduser('~/imagenet/train'),
    help='path to training data',
)
parser.add_argument(
    '--val-dir',
    default=os.path.expanduser('~/imagenet/validation'),
    help='path to validation data',
)
parser.add_argument(
    '--log-dir',
    default='./logs',
    help='tensorboard log directory',
)
parser.add_argument(
    '--checkpoint-format',
    default='./checkpoint-{epoch}.pth.tar',
    help='checkpoint file format',
)

# Batch sizes and length of the run.
parser.add_argument(
    '--batch-size',
    type=int,
    default=32,
    help='input batch size for training',
)
parser.add_argument(
    '--val-batch-size',
    type=int,
    default=32,
    help='input batch size for validation',
)
parser.add_argument(
    '--epochs',
    type=int,
    default=90,
    help='number of epochs to train',
)

# Optimizer hyper-parameters.
parser.add_argument(
    '--base-lr',
    type=float,
    default=0.0125,
    help='learning rate for a single GPU',
)
parser.add_argument(
    '--warmup-epochs',
    type=float,
    default=5,
    help='number of warmup epochs',
)
parser.add_argument(
    '--momentum',
    type=float,
    default=0.9,
    help='SGD momentum',
)
parser.add_argument(
    '--wd',
    type=float,
    default=0.00005,
    help='weight decay',
)

# Execution environment.
parser.add_argument(
    '--no-cuda',
    action='store_true',
    default=False,
    help='disables CUDA training',
)
parser.add_argument(
    '--seed',
    type=int,
    default=42,
    help='random seed',
)

args = parser.parse_args()
# CUDA is used only when it has not been disabled on the command line and a
# device is actually available.
args.cuda = False if args.no_cuda else torch.cuda.is_available()

hvd.init()
torch.manual_seed(args.seed)

if args.cuda:
    # Horovod: pin GPU to local rank.
    torch.cuda.set_device(hvd.local_rank())
    torch.cuda.manual_seed(args.seed)

cudnn.benchmark = True

# If set > 0, will resume training from a given checkpoint. Epochs are probed
# from the last one downwards, so the most recent checkpoint on disk wins.
resume_from_epoch = next(
    (candidate for candidate in range(args.epochs, 0, -1)
     if os.path.exists(args.checkpoint_format.format(epoch=candidate))),
    0)

# Horovod: broadcast resume_from_epoch from rank 0 (which will have
# checkpoints) to other ranks.
resume_from_epoch = hvd.broadcast(
    torch.tensor(resume_from_epoch),
    root_rank=0,
    name='resume_from_epoch',
).item()

# Horovod: print logs on the first worker.
verbose = int(hvd.rank() == 0)

# Horovod: write TensorBoard logs on first worker.
if hvd.rank() == 0:
    log_writer = tensorboardX.SummaryWriter(args.log_dir)
else:
    log_writer = None

# Extra DataLoader options are only passed when training on GPU.
if args.cuda:
    kwargs = {'num_workers': 4, 'pin_memory': True}
else:
    kwargs = {}
def _with_tensor_normalization(*image_ops):
    """Compose `image_ops` followed by tensor conversion and normalization.

    Both pipelines below end with the same ToTensor + Normalize steps; only
    the leading crop/resize/flip operations differ.
    """
    return transforms.Compose(list(image_ops) + [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


# Training data: random resized crop and horizontal flip as augmentation.
train_dataset = datasets.ImageFolder(
    args.train_dir,
    transform=_with_tensor_normalization(
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ),
)

# Horovod: use DistributedSampler to partition data among workers. Manually specify
# `num_replicas=hvd.size()` and `rank=hvd.rank()`.
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=hvd.size(),
    rank=hvd.rank(),
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    sampler=train_sampler,
    **kwargs
)

# Validation data: deterministic resize followed by a center crop.
val_dataset = datasets.ImageFolder(
    args.val_dir,
    transform=_with_tensor_normalization(
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ),
)
val_sampler = torch.utils.data.distributed.DistributedSampler(
    val_dataset,
    num_replicas=hvd.size(),
    rank=hvd.rank(),
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args.val_batch_size,
    sampler=val_sampler,
    **kwargs
)
# Set up standard ResNet-50 model.
model = models.resnet50()

if args.cuda:
    # Move model to GPU.
    model.cuda()

# Horovod: scale learning rate by the number of GPUs.
optimizer = optim.SGD(
    model.parameters(),
    lr=args.base_lr * hvd.size(),
    momentum=args.momentum,
    weight_decay=args.wd,
)
# Custom Distributed Optimizer that does 1-late gradient.
class _DistributedOptimizer(torch.optim.Optimizer):
    """Optimizer methods that apply allreduced gradients one step late.

    The attributes of this class are copied onto a dynamically created
    optimizer subclass by `DistributedOptimizer()`. A hook on every trainable
    parameter starts an asynchronous allreduce of the freshly computed
    gradient and records the result of the allreduce started on the previous
    invocation; `step()` then applies that recorded (one-step-old) gradient.
    """

    def __init__(self, params, named_parameters=None):
        """Initialise the underlying optimizer and install gradient hooks.

        Arguments:
            params: Parameters or parameter groups handed to the base optimizer.
            named_parameters: Optional iterable of `(name, parameter)` tuples,
                used to name the allreduce operations.

        Raises:
            ValueError: If an element of `named_parameters` is not a tuple.
        """
        # `self.__class__` is the dynamically built class whose parent is the
        # real optimizer; this class itself is not in that class's MRO, so it
        # cannot be named in the `super()` call.
        super(self.__class__, self).__init__(params)

        named_parameters = (
            [] if named_parameters is None else list(named_parameters))

        # make sure that named_parameters are tuples
        if not all(isinstance(entry, tuple) for entry in named_parameters):
            raise ValueError('named_parameters should be a sequence of '
                             'tuples (name, parameter), usually produced by '
                             'model.named_parameters().')

        # Reverse lookup: parameter -> name.
        self._parameter_names = {
            param: name for name, param in sorted(named_parameters)}
        # parameter -> handle of the allreduce currently in flight.
        self._handles = {}
        # Hooked gradient accumulator nodes, kept referenced here.
        self._grad_accs = []
        # parameter -> averaged gradient from the previous hook invocation.
        self._last_grad = {}

        # With a single worker no hooks are installed, so `step()` uses the
        # locally computed gradients unchanged.
        if hvd.size() > 1:
            self._register_hooks()

    def _register_hooks(self):
        """Attach a hook to the gradient accumulator of each trainable parameter."""
        for group in self.param_groups:
            for param in group['params']:
                if not param.requires_grad:
                    continue
                # The expanded view has a grad_fn whose first input is taken
                # as the parameter's gradient accumulator node.
                view = param.expand_as(param)
                accumulator = view.grad_fn.next_functions[0][0]
                accumulator.register_hook(self._make_hook(param))
                self._grad_accs.append(accumulator)

    def _make_hook(self, p):
        """Return the hook that launches the allreduce for parameter `p`."""
        def hook(*ignore):
            assert not p.grad.requires_grad
            op_name = self._parameter_names.get(p)
            # Reduce a copy so that `p.grad` itself is not touched in place.
            fresh_grad = p.grad.data.clone()
            # Collect the allreduce started last time; before any has been
            # started, the "previous" gradient is all zeros.
            self._last_grad[p] = (
                hvd.synchronize(self._handles[p]) if p in self._handles
                else torch.zeros_like(fresh_grad))
            self._handles[p] = hvd.allreduce_async_(
                fresh_grad, average=True, name=op_name)
        return hook

    def step(self, closure=None):
        """Swap in the delayed gradients, then run the base optimizer step."""
        for param, delayed_grad in self._last_grad.items():
            param.grad.data.set_(delayed_grad)
        return super(self.__class__, self).step(closure)
def DistributedOptimizer(optimizer, named_parameters=None):
    """
    Wrap a torch.optim.Optimizer so that gradients are averaged across workers with an
    allreduce and applied one step late ("1-late" SGD).

    An asynchronous allreduce is started for each gradient from a hook that fires as the
    gradient is computed by `loss.backward()`. The `step()` method does NOT wait for the
    allreduce operations started during the current backward pass. It applies the averaged
    gradients collected from the allreduce operations started on the previous hook
    invocation; before any allreduce has been started for a parameter, zeros are used in
    place of its gradient.

    Notes:
      * This variant does not define a `synchronize()` method. Operations that modify the
        gradients in place between `loss.backward()` and `step()` (e.g. gradient clipping)
        act on the local gradients, which `step()` then replaces with the delayed ones.
      * When `hvd.size()` is 1 no hooks are registered and `step()` uses the locally
        computed gradients unchanged.
      * The returned optimizer is a new object built from `optimizer.param_groups`; it is
        not the instance that was passed in.

    Example:
    ```
    optimizer = DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
    output = model(data)
    loss = F.cross_entropy(output, target)
    loss.backward()
    optimizer.step()
    ```

    Arguments:
        optimizer: Optimizer whose class and parameter groups are used for applying updates.
        named_parameters: An iterable of `(name, parameter)` tuples. Used for naming of
                          allreduce operations. Typically just `model.named_parameters()`.

    Returns:
        An instance of a dynamically created subclass of `optimizer`'s class.
    """
    # We dynamically create a new class that inherits from the optimizer that was passed in.
    # The goal is to override the `step()` method with the delayed-allreduce implementation
    # by copying the attributes of `_DistributedOptimizer` onto the new class.
    cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
               dict(_DistributedOptimizer.__dict__))
    return cls(optimizer.param_groups, named_parameters)
# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
)

# Restore from a previous checkpoint, if resume_from_epoch is set (> 0).
# Horovod: restore on the first worker which will broadcast weights to other workers.
if hvd.rank() == 0 and resume_from_epoch > 0:
    filepath = args.checkpoint_format.format(epoch=resume_from_epoch)
    checkpoint = torch.load(filepath)
    for target, key in ((model, 'model'), (optimizer, 'optimizer')):
        target.load_state_dict(checkpoint[key])

# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
def train(epoch):
    """Train the model for one pass over `train_loader`.

    Arguments:
        epoch: Zero-based epoch index. It seeds the sampler's shuffling,
            drives the learning-rate schedule and is the step under which
            the epoch's average loss and accuracy are logged.
    """
    model.train()
    train_sampler.set_epoch(epoch)
    loss_meter = Metric('train_loss')
    accuracy_meter = Metric('train_accuracy')

    progress = tqdm(total=len(train_loader),
                    desc='Train Epoch #{}'.format(epoch + 1),
                    disable=not verbose)
    with progress:
        for batch_idx, (data, target) in enumerate(train_loader):
            # The learning rate is recomputed before every batch.
            adjust_learning_rate(epoch, batch_idx)

            if args.cuda:
                data, target = data.cuda(), target.cuda()

            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

            loss_meter.update(loss)
            accuracy_meter.update(accuracy(output, target))
            progress.set_postfix({'loss': loss_meter.avg.item(),
                                  'accuracy': 100. * accuracy_meter.avg.item()})
            progress.update(1)

    # Only the worker that owns a log writer records the epoch averages.
    if log_writer:
        for tag, meter in (('train/loss', loss_meter),
                           ('train/accuracy', accuracy_meter)):
            log_writer.add_scalar(tag, meter.avg, epoch)
def validate(epoch):
    """Evaluate the model on `val_loader` without tracking gradients.

    Arguments:
        epoch: Zero-based epoch index, used for the progress-bar title and
            as the step under which the averages are logged.
    """
    model.eval()
    loss_meter = Metric('val_loss')
    accuracy_meter = Metric('val_accuracy')

    progress = tqdm(total=len(val_loader),
                    desc='Validate Epoch #{}'.format(epoch + 1),
                    disable=not verbose)
    with progress, torch.no_grad():
        for data, target in val_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)

            loss_meter.update(F.cross_entropy(output, target))
            accuracy_meter.update(accuracy(output, target))
            progress.set_postfix({'loss': loss_meter.avg.item(),
                                  'accuracy': 100. * accuracy_meter.avg.item()})
            progress.update(1)

    # Only the worker that owns a log writer records the epoch averages.
    if log_writer:
        for tag, meter in (('val/loss', loss_meter),
                           ('val/accuracy', accuracy_meter)):
            log_writer.add_scalar(tag, meter.avg, epoch)
# Horovod: using `lr = base_lr * hvd.size()` from the very beginning leads to worse final
# accuracy. Scale the learning rate `lr = base_lr` ---> `lr = base_lr * hvd.size()` during
# the warmup epochs (`--warmup-epochs`). See https://arxiv.org/abs/1706.02677 for details.
# After the warmup reduce learning rate by 10 on the 30th, 60th and 80th epochs.
def adjust_learning_rate(epoch, batch_idx):
    """Set the learning rate of every parameter group for the current batch.

    Arguments:
        epoch: Zero-based index of the current epoch.
        batch_idx: Zero-based index of the current batch within the epoch;
            only used during warmup to interpolate within an epoch.
    """
    workers = hvd.size()

    if epoch < args.warmup_epochs:
        # Fractional epoch position, counting the current batch as done.
        position = epoch + float(batch_idx + 1) / len(train_loader)
        lr_adj = 1. / workers * (position * (workers - 1) / args.warmup_epochs + 1)
    else:
        # (exclusive upper epoch bound, multiplier); past the last bound the
        # smallest multiplier applies.
        for upper_bound, factor in ((30, 1.), (60, 1e-1), (80, 1e-2)):
            if epoch < upper_bound:
                lr_adj = factor
                break
        else:
            lr_adj = 1e-3

    new_lr = args.base_lr * workers * lr_adj
    for group in optimizer.param_groups:
        group['lr'] = new_lr
def accuracy(output, target):
    """Return the fraction of rows of `output` whose arg-max equals `target`.

    Arguments:
        output: Score tensor; the prediction is the index of the maximum
            along dimension 1.
        target: Tensor of class indices with one entry per row of `output`.

    Returns:
        A zero-dimensional float tensor on the CPU.
    """
    # Index of the highest score in each row.
    _, predicted = output.max(dim=1)
    hits = predicted.eq(target.view_as(predicted))
    return hits.cpu().float().mean()
def save_checkpoint(epoch):
    """Write model and optimizer state to disk from rank 0 only.

    Arguments:
        epoch: Zero-based index of the epoch just finished; the file name is
            formatted with `epoch + 1`.
    """
    if hvd.rank() != 0:
        return

    destination = args.checkpoint_format.format(epoch=epoch + 1)
    torch.save(
        {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        },
        destination,
    )
# Horovod: average metrics from distributed training.
class Metric(object):
    """Running mean of a scalar metric, combined across workers on each update."""

    def __init__(self, name):
        """Create an empty metric.

        Arguments:
            name: Name passed to the allreduce operation for this metric.
        """
        self.name = name
        self.sum = torch.tensor(0.)
        self.n = torch.tensor(0.)

    def update(self, val):
        """Allreduce `val` across workers and add the result to the running sum.

        Arguments:
            val: Tensor holding the metric value for the current batch.
        """
        # Detach first: the training loss still carries autograd history, and
        # accumulating it in place would chain every batch's graph into
        # `self.sum` and make `avg` require grad.
        self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
        self.n += 1

    @property
    def avg(self):
        """Mean of all values seen so far (NaN before the first update)."""
        return self.sum / self.n
# Run the remaining epochs; each one trains, validates and then checkpoints.
for epoch in range(resume_from_epoch, args.epochs):
    for run_phase in (train, validate, save_checkpoint):
        run_phase(epoch)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment