import argparse |
import os |
import shutil |
import time |
import torch |
import torch.nn as nn |
import torch.nn.parallel |
import torch.backends.cudnn as cudnn |
import torch.distributed as dist |
import torch.optim |
import torch.utils.data |
import torch.utils.data.distributed |
import torchvision.transforms as transforms |
import torchvision.datasets as datasets |
import torchvision.models as models |
import numpy as np |
try: |
from apex.parallel import DistributedDataParallel as DDP |
from apex.fp16_utils import * |
from apex import amp |
import apex_C |
except ImportError: |
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") |
# Sorted names of the lowercase, callable, non-dunder attributes of
# torchvision.models (presumably the model constructor functions); used as
# the allowed values for --arch.
model_names = sorted(
    attr_name
    for attr_name, attr in models.__dict__.items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr)
)
# Command-line interface. Parsed into the module-level ``args`` further down.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--fp16', action='store_true',
                    help='Run model fp16 mode.')
parser.add_argument('--prof', dest='prof', action='store_true',
                    help='Only run 10 iterations for profiling.')
# These two options previously had no help text at all.
parser.add_argument('--deterministic', action='store_true',
                    help='disable cudnn.benchmark, enable cudnn.deterministic '
                         'and seed torch with --local_rank')
parser.add_argument("--local_rank", default=0, type=int,
                    help='local process rank: selects the GPU (local_rank '
                         'modulo the GPU count); only rank 0 prints progress '
                         'and saves checkpoints (default: 0)')
parser.add_argument('--sync_bn', action='store_true',
                    help='enabling apex sync BN.')

# Let cuDNN autotune convolution algorithms; --deterministic switches this
# back off after the arguments are parsed.
cudnn.benchmark = True
def fast_collate(batch):
    """Collate ``(image, class_index)`` pairs into a uint8 NCHW batch.

    Skips ``ToTensor``/normalization on the CPU: pixels are copied as raw
    uint8 values, and float conversion plus normalization happen later on the
    GPU.

    Args:
        batch: sequence of ``(image, target)`` pairs. Each image must be
            convertible with ``np.asarray`` and expose a PIL-style ``size`` of
            ``(width, height)``. All images are assumed to share the first
            image's size (main() applies fixed-size crops before collation).

    Returns:
        ``(tensor, targets)``: ``tensor`` is ``uint8`` with shape
        ``(len(batch), 3, height, width)``; ``targets`` is an ``int64``
        tensor of the class indices.
    """
    imgs = [sample[0] for sample in batch]
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            # Single-channel image: add a channel axis. The resulting
            # (1, H, W) slice broadcasts over the 3 channels below.
            nump_array = np.expand_dims(nump_array, axis=-1)
        # HWC -> CHW
        nump_array = np.rollaxis(nump_array, 2)
        tensor[i] += torch.from_numpy(nump_array)
    return tensor, targets
# Best top-1 validation precision seen so far; main() rebinds it via `global`.
best_prec1 = 0

args = parser.parse_args()

if args.deterministic:
    # Trade cuDNN autotuning for reproducible kernels and a rank-dependent seed.
    cudnn.benchmark, cudnn.deterministic = False, True
    torch.manual_seed(args.local_rank)

# Initialize Amp; it is enabled only when --fp16 was given.
amp_handle = amp.init(enabled=args.fp16)
def main():
    """Build the model, flattened optimizer and data pipeline, then train.

    Module-level state touched here:
      * reads and mutates ``args``: adds ``distributed``, ``gpu`` and
        ``world_size``, and rescales ``lr``;
      * rebinds ``best_prec1``, including when restoring from --resume;
      * defines ``groups_to_flatten`` / ``flat_groups``, which ``train``
        reads when assembling flat gradients.
    """
    global best_prec1, args, groups_to_flatten, flat_groups

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()
    if args.distributed:
        # apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass by default (plain DDP(model)).
        # delay_allreduce=True instead delays all communication to the end
        # of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # For param flattening: Choose the groups into which we want to divide the params.
    # Params cannot be shared between flattened groups. Flat groups can only contain unique params.
    # Any "shared" params in an initially-desired set of groups must be given their own "group-of-one,"
    # and removed from the initially-desired group.
    #
    # In this particular script, since we're using Amp, all of the params will be fp32,
    # so fp16_params will be empty. It's included here as a placeholder/example of how
    # multiple groups can be handled.
    fp16_params = [param for param in model.parameters() if param.type() == "torch.cuda.HalfTensor" and param.requires_grad]
    fp32_params = [param for param in model.parameters() if param.type() == "torch.cuda.FloatTensor" and param.requires_grad]

    groups_to_flatten = [group for group in (fp16_params, fp32_params) if group]
    flat_groups = [apex_C.flatten([param.data for param in group]) for group in groups_to_flatten]
    # For param flattening: Reset the model params' .data members to point into sections of the
    # flat buffer.
    for group, flat_group in zip(groups_to_flatten, flat_groups):
        for param_ref, section_of_flat_group in zip(group, apex_C.unflatten(flat_group, group)):
            param_ref.data = section_of_flat_group.data

    # Scale learning rate based on global batch size
    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
    # For param flattening: Optimizer receives flat_groups
    optimizer = torch.optim.SGD(flat_groups, args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            # main()'s `global` declaration does not extend into nested
            # functions. Without this line, the assignment below would bind a
            # local of resume() and the restored best precision would be lost.
            global best_prec1
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
        resume()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if args.arch == "inception_v3":
        crop_size = 299
        val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    # No ToTensor()/normalize here: fast_collate builds uint8 tensors and
    # data_prefetcher converts and normalizes them on the GPU.
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
        ]))
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(val_size),
        transforms.CenterCrop(crop_size),
    ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler,
        collate_fn=fast_collate)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
class data_prefetcher():
    """Overlap host-to-GPU copies of the next batch with work on the current one.

    Wraps a loader yielding ``(input, target)`` pairs. The next pair is
    copied to the GPU on a side CUDA stream, the input is converted to float,
    and it is normalized per channel with the ``mean``/``std`` below. Those
    are shaped ``(1, 3, 1, 1)``, so inputs are expected to be
    ``(N, 3, H, W)``. ``next()`` returns ``(None, None)`` once the loader is
    exhausted.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        # Per-channel constants scaled by 255 because the inputs arrive as raw
        # 0-255 pixel values rather than [0, 1] floats.
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half:
        # type conversions are done on the fly within patched torch functions.
        self.preload()

    def preload(self):
        """Fetch the next batch and start its asynchronous copy and normalization."""
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            # `async` is a reserved word since Python 3.7, so `.cuda(async=True)`
            # is a SyntaxError there. `non_blocking` is the supported spelling.
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        """Return the prefetched ``(input, target)`` and start loading the following batch."""
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        # These tensors were allocated on self.stream but are consumed on the
        # current stream. Record that use so the caching allocator does not
        # hand their memory to a later side-stream allocation too early.
        if input is not None:
            input.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch over *train_loader*.

    Relies on the module-level ``args``, ``amp_handle``, ``groups_to_flatten``
    and ``flat_groups`` (the latter two are built in ``main``). The optimizer
    owns the flat buffers, so after each backward pass the per-parameter
    gradients are packed into a flat gradient matching each flat buffer.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()

    prefetcher = data_prefetcher(train_loader)
    input, target = prefetcher.next()
    i = -1
    while input is not None:
        i += 1

        adjust_learning_rate(optimizer, epoch, i, len(train_loader))

        # Profiling mode stops after the first few iterations.
        if args.prof and i > 10:
            break

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # compute gradient and do SGD step
        for param in model.parameters():
            param.grad = None
        # For param flattening: Remove existing flat param gradients
        for flat_group in flat_groups:
            flat_group.grad = None

        with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
            # For param flattening: Populate flat grads. This MUST be done within the context
            # manager, because Amp checks the gradients of the optimizer's owned parameters
            # (aka, the flat groups) for infs/nans on context manager exit.
            for group, flat_group in zip(groups_to_flatten, flat_groups):
                if not any(param.grad is not None for param in group):
                    # No parameter in this group received a gradient: leave
                    # the flat grad as None.
                    continue
                # The flat gradient must have exactly the flat buffer's layout.
                # Parameters without a gradient therefore contribute zeros.
                # Dropping them would shift every later section and make the
                # flat gradient the wrong size.
                flat_group.grad = apex_C.flatten(
                    [param.grad.data if param.grad is not None
                     else torch.zeros_like(param.data)
                     for param in group])
        optimizer.step()

        torch.cuda.synchronize()
        # measure elapsed time
        batch_time.update(time.time() - end)

        end = time.time()
        input, target = prefetcher.next()

        if args.local_rank == 0 and i % args.print_freq == 0 and i > 1:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {3:.3f} ({4:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader),
                   args.world_size * args.batch_size / batch_time.val,
                   args.world_size * args.batch_size / batch_time.avg,
                   batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))
def validate(val_loader, model, criterion):
    """Evaluate *model* on *val_loader* and return the average top-1 precision.

    With ``args.distributed`` set, loss and precisions are averaged across
    ranks through ``reduce_tensor``. Progress lines are printed only on local
    rank 0; the final summary line is printed by every process.
    """
    batch_time, losses, top1, top5 = (AverageMeter() for _ in range(4))

    # evaluation mode (affects layers such as dropout / batch norm)
    model.eval()
    end = time.time()

    prefetcher = data_prefetcher(val_loader)
    input, target = prefetcher.next()
    step = 0
    while input is not None:
        # forward pass without recording an autograd graph
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)

        # accuracy and loss bookkeeping
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        if args.distributed:
            reduced_loss, prec1, prec5 = [
                reduce_tensor(t) for t in (loss.data, prec1, prec5)]
        else:
            reduced_loss = loss.data

        n = input.size(0)
        for meter, value in ((losses, reduced_loss), (top1, prec1), (top5, prec5)):
            meter.update(to_python_float(value), n)

        # wall-clock time for this batch
        batch_time.update(time.time() - end)
        end = time.time()

        if args.local_rank == 0 and step % args.print_freq == 0:
            global_batch = args.world_size * args.batch_size
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      step, len(val_loader),
                      global_batch / batch_time.val,
                      global_batch / batch_time.avg,
                      batch_time=batch_time, loss=losses,
                      top1=top1, top5=top5))

        input, target = prefetcher.next()
        step += 1

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize *state* to *filename* with ``torch.save``.

    When *is_best* is true, the saved file is also copied to
    ``model_best.pth.tar`` in the current working directory.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter(object):
    """Keep the most recent value of a metric and its count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the running mean, the weighted sum and the count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* as observed *n* times and recompute the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """LR schedule that should yield 76% converged accuracy with batch size 256

    Starts from the module-level ``args.lr`` and divides it by 10 every 30
    epochs, with one extra division by 10 from epoch 80 onward. During the
    first 5 epochs the result is additionally ramped up linearly per
    iteration (*step* within an epoch of *len_epoch* iterations). The value is
    written into every param group of *optimizer*.
    """
    decays = epoch // 30 + (1 if epoch >= 80 else 0)
    lr = args.lr*(0.1**decays)

    # Warmup: linear ramp over the first 5 * len_epoch iterations.
    if epoch < 5:
        lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)

    for group in optimizer.param_groups:
        group['lr'] = lr
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: scores indexed as (sample, class); top-k is taken along dim 1.
        target: ground-truth class index per sample (flattened with
            ``view(1, -1)`` for the comparison).
        topk: the k values to evaluate.

    Returns:
        A list with one single-element tensor per k in *topk*: the percentage
        of samples whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    # Transpose so that row j holds every sample's j-th best prediction.
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` derives from a transposed tensor and may be non-contiguous.
        # `view(-1)` then raises on recent PyTorch, whereas `reshape` copies
        # only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def reduce_tensor(tensor):
    """Return the mean of *tensor* across all distributed processes.

    The input is cloned first, so the caller's tensor is not modified. Uses a
    SUM all-reduce followed by division by the module-level
    ``args.world_size``, and requires an initialized process group.
    """
    rt = tensor.clone()
    # `dist.reduce_op` is a deprecated alias that emits a deprecation warning.
    # Prefer `dist.ReduceOp`, falling back for PyTorch releases that predate it.
    reduce_ops = getattr(dist, 'ReduceOp', None) or dist.reduce_op
    dist.all_reduce(rt, op=reduce_ops.SUM)
    rt /= args.world_size
    return rt
if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()