import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import numpy as np

try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp, optimizers
    from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='Initial learning rate.  Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256.  A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')

parser.add_argument('--prof', default=-1, type=int,
                    help='Start profiling at this iteration, run 10 profiled iterations, then quit (default: -1, disabled).')
parser.add_argument('--deterministic', action='store_true')

parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true',
                    help='enabling apex sync BN.')

parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)

cudnn.benchmark = True

def fast_collate(batch):
    imgs = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            nump_array = np.expand_dims(nump_array, axis=-1)
        nump_array = np.rollaxis(nump_array, 2)
        tensor[i] += torch.from_numpy(nump_array)

    return tensor, targets

best_prec1 = 0
args = parser.parse_args()

print("opt_level = {}".format(args.opt_level))
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))

print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(args.local_rank)
    torch.set_printoptions(precision=10)

def main():
    global best_prec1, args

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            # Declare best_prec1 global so the restored value is visible outside this helper
            # instead of being dropped as a local variable.
            global best_prec1
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
        resume()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if args.arch == "inception_v3":
        raise RuntimeError("Currently, inception_v3 is not supported by this example.")
        # crop_size = 299
        # val_size = 320  # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            # transforms.ToTensor(), Too slow
            # normalize,
        ]))
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler,
        collate_fn=fast_collate)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1, 3, 1, 1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        # If record_stream() doesn't work, another option is to make sure device inputs are created
        # on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use by the main stream
        # at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # More code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if input is not None:
            input.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target


def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()

    prefetcher = data_prefetcher(train_loader)
    input, target = prefetcher.next()
    i = 0
    while input is not None:
        i += 1
        if args.prof >= 0 and i == args.prof:
            print("Profiling begun at iteration {}".format(i))
            torch.cuda.cudart().cudaProfilerStart()

        if args.prof >= 0: torch.cuda.nvtx.range_push("Body of iteration {}".format(i))

        adjust_learning_rate(optimizer, epoch, i, len(train_loader))

        # compute output
        if args.prof >= 0: torch.cuda.nvtx.range_push("forward")
        output = model(input)
        if args.prof >= 0: torch.cuda.nvtx.range_pop()
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()

        if args.prof >= 0: torch.cuda.nvtx.range_push("backward")
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        # for param in model.parameters():
        #     print(param.data.double().sum().item(), param.grad.data.double().sum().item())

        if args.prof >= 0: torch.cuda.nvtx.range_push("optimizer.step()")
        optimizer.step()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Measure accuracy
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
                prec1 = reduce_tensor(prec1)
                prec5 = reduce_tensor(prec5)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss), input.size(0))
            top1.update(to_python_float(prec1), input.size(0))
            top5.update(to_python_float(prec5), input.size(0))

            torch.cuda.synchronize()
            batch_time.update((time.time() - end)/args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       epoch, i, len(train_loader),
                       args.world_size*args.batch_size/batch_time.val,
                       args.world_size*args.batch_size/batch_time.avg,
                       batch_time=batch_time,
                       loss=losses, top1=top1, top5=top5))

        if args.prof >= 0: torch.cuda.nvtx.range_push("prefetcher.next()")
        input, target = prefetcher.next()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        # Pop range "Body of iteration {}".format(i)
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        if args.prof >= 0 and i == args.prof + 10:
            print("Profiling ended at iteration {}".format(i))
            torch.cuda.cudart().cudaProfilerStop()
            quit()


def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    prefetcher = data_prefetcher(val_loader)
    input, target = prefetcher.next()
    i = 0
    while input is not None:
        i += 1

        # compute output
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # TODO:  Change timings to mirror train().
        if args.local_rank == 0 and i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   i, len(val_loader),
                   args.world_size * args.batch_size / batch_time.val,
                   args.world_size * args.batch_size / batch_time.avg,
                   batch_time=batch_time, loss=losses,
                   top1=top1, top5=top5))

        input, target = prefetcher.next()

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """LR schedule that should yield 76% converged accuracy with batch size 256"""
    factor = epoch // 30

    if epoch >= 80:
        factor = factor + 1

    lr = args.lr*(0.1**factor)

    # Warmup
    if epoch < 5:
        lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)

    # if args.local_rank == 0:
    #     print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape() tolerates non-contiguous slices, unlike view()
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def reduce_tensor(tensor):
    rt = tensor.clone()
    # dist.reduce_op is deprecated; dist.ReduceOp is the current spelling.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt


if __name__ == '__main__':
    main()
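

# Example launch (a minimal sketch, not part of the original script): the argparser above
# expects one process per GPU with --local_rank supplied by the launcher, which is what
# torch.distributed.launch does.  The script filename (main_amp.py) and the dataset path
# are assumptions; substitute your own.  --batch-size is per process.
#
#   python -m torch.distributed.launch --nproc_per_node=2 main_amp.py \
#       -a resnet50 --opt-level O1 --batch-size 128 /path/to/imagenet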