@zhreshold, last active December 28, 2017
ImageNet validation
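Validates Gluon model-zoo networks on the ImageNet (ILSVRC2012) validation set. A training loop is included for reference, but running the script performs a single validation pass.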
import os
import argparse
import time
import logging
import numpy as np
import mxnet as mx
from mxnet import gluon
from mxnet import autograd
from mxnet.gluon.data.vision import ImageFolderDataset
from mxnet.gluon.data import DataLoader

logging.basicConfig(level=logging.INFO)

def parse_args():
    parser = argparse.ArgumentParser(description='Gluon ImageNet12 training/validation',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--data', required=True, type=str,
                        help='path to dataset')
    parser.add_argument('--model', required=True, type=str,
                        help='gluon model name, e.g. resnet18_v1')
    parser.add_argument('-j', '--workers', dest='num_workers', default=4, type=int,
                        help='number of preprocessing workers')
    parser.add_argument('--gpus', default='0', type=str,
                        help='gpus to use, multiple gpus supported as "0,1,2,3"')
    parser.add_argument('--epochs', default=120, type=int,
                        help='number of training epochs')
    parser.add_argument('--start-epoch', default=0, type=int,
                        help='starting epoch, 0 for fresh training, > 0 to resume')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        help='mini-batch size')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', dest='wd', default=1e-4, type=float,
                        help='weight decay (default: 1e-4)')
    parser.add_argument('--log-interval', '-p', default=10, type=int,
                        help='logging interval in batches (default: 10)')
    parser.add_argument('--prefix', default='', type=str,
                        help='directory prefix for saving checkpoints')
    parser.add_argument('--resume', default='', type=str,
                        help='path to a checkpoint to resume from')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--kvstore', default='device', type=str,
                        help='kvstore type')
    parser.add_argument('--lr-factor', default=0.1, type=float,
                        help='learning rate decay ratio')
    parser.add_argument('--lr-steps', default='30,60,90', type=str,
                        help='comma-separated list of learning rate decay epochs')
    parser.add_argument('--dtype', default='float32', type=str,
                        help='data type, float32 or float16 if applicable')
    parser.add_argument('--save-frequency', default=10, type=int,
                        help='save the model every N epochs; the best model is always saved')
    args = parser.parse_args()
    return args

def get_model(model, resume, pretrained, dtype='float32'):
    """Model initialization."""
    net = gluon.model_zoo.vision.get_model(model, pretrained=pretrained, classes=1000)
    if resume:
        net.load_params(resume)
    elif not pretrained:
        if model in ['alexnet']:
            net.initialize(mx.init.Normal())
        else:
            net.initialize(mx.init.Xavier(magnitude=2))
    net.cast(dtype)
    net.hybridize()
    return net

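# Note on get_model() above: cast() converts the parameters to the requested
# dtype (e.g. float16), and hybridize() lets Gluon compile the imperative
# network into a static graph for faster execution.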
def get_transform_function(dtype='float32'):
    def train_transform(image, label):
        image, _ = mx.image.random_size_crop(image, (224, 224), 0.08, (3/4., 4/3.))
        image = mx.nd.image.random_flip_left_right(image)
        image = mx.nd.image.to_tensor(image)
        image = mx.nd.image.normalize(image, mean=(0.485, 0.456, 0.406),
                                      std=(0.229, 0.224, 0.225))
        return mx.nd.cast(image, dtype), label

    def val_transform(image, label):
        image = mx.image.resize_short(image, 256)
        image, _ = mx.image.center_crop(image, (224, 224))
        image = mx.nd.image.to_tensor(image)
        image = mx.nd.image.normalize(image, mean=(0.485, 0.456, 0.406),
                                      std=(0.229, 0.224, 0.225))
        return mx.nd.cast(image, dtype), label

    return train_transform, val_transform

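# Sanity check of the validation transform (hypothetical dummy input, not
# part of the original script): to_tensor converts an HWC uint8 image into
# a CHW float tensor in [0, 1] before the ImageNet mean/std normalization.
#   img = mx.nd.random.uniform(0, 255, shape=(480, 640, 3)).astype('uint8')
#   out, _ = get_transform_function()[1](img, 0)
#   assert out.shape == (3, 224, 224)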
def get_dataloader(root, batch_size, num_workers, dtype='float32'):
    """Dataset loader with preprocessing."""
    train_transform, val_transform = get_transform_function(dtype)
    val_dir = os.path.join(root, 'val')
    logging.info("Loading image folder %s, this may take a while...", val_dir)
    val_dataset = ImageFolderDataset(val_dir, transform=val_transform)
    val_data = DataLoader(val_dataset, batch_size, last_batch='keep',
                          num_workers=num_workers)
    return val_data

def update_learning_rate(lr, trainer, epoch, ratio, steps):
    """Set the learning rate to the initial value decayed by ratio every N epochs."""
    new_lr = lr * (ratio ** int(np.sum(np.array(steps) < epoch)))
    trainer.set_learning_rate(new_lr)
    return trainer

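# Worked example of the step decay above: with lr=0.1, ratio=0.1 and
# steps=[30, 60, 90], epoch 65 has two boundaries (30 and 60) below it,
# so new_lr = 0.1 * 0.1**2 = 0.001.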
def validate(net, val_data, metrics, ctx):
    """Validation."""
    for m in metrics:
        m.reset()
    for batch in val_data:
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
        outputs = [net(x) for x in data]
        for m in metrics:
            m.update(label, outputs)
    msg = ','.join(['%s=%f' % m.get() for m in metrics])
    return msg, metrics[0].get()[1]

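# validate() returns a formatted metric string plus the raw value of the
# first metric (top-1 accuracy), which train() below uses to pick the best
# checkpoint.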
def train(net, train_data, val_data, ctx, args):
    """Training."""
    criterion = gluon.loss.SoftmaxCrossEntropyLoss()
    metrics = [mx.metric.Accuracy(), mx.metric.TopKAccuracy(5)]
    lr_steps = [int(x) for x in args.lr_steps.split(',') if x.strip()]
    net.collect_params().reset_ctx(ctx)
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': args.lr, 'wd': args.wd,
                             'momentum': args.momentum, 'multi_precision': True},
                            kvstore=args.kvstore)
    # start training
    best_acc = 0
    for epoch in range(args.start_epoch, args.epochs):
        trainer = update_learning_rate(args.lr, trainer, epoch, args.lr_factor, lr_steps)
        for m in metrics:
            m.reset()
        tic = time.time()
        btic = time.time()
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            outputs = []
            losses = []
            with autograd.record():
                for x, y in zip(data, label):
                    z = net(x)
                    L = criterion(z, y)
                    losses.append(L)
                    outputs.append(z)
            autograd.backward(losses)
            batch_size = args.batch_size
            trainer.step(batch_size)
            for m in metrics:
                m.update(label, outputs)
            if args.log_interval and (i + 1) % args.log_interval == 0:
                msg = ','.join(['%s=%f' % m.get() for m in metrics])
                # btic spans log_interval batches, so scale the sample count accordingly
                logging.info('Epoch[%d] Batch[%d]\tSpeed: %f samples/sec\t%s',
                             epoch, i, batch_size * args.log_interval / (time.time() - btic), msg)
                btic = time.time()
        msg = ','.join(['%s=%f' % m.get() for m in metrics])
        logging.info('[Epoch %d] Training: %s', epoch, msg)
        logging.info('[Epoch %d] Training time cost: %f', epoch, time.time() - tic)
        msg, top1 = validate(net, val_data, metrics, ctx)
        logging.info('[Epoch %d] Validation: %s', epoch, msg)
        if args.save_frequency and (epoch + 1) % args.save_frequency == 0:
            fname = os.path.join(args.prefix, '%s_%d_acc_%.4f.params' % (args.model, epoch, top1))
            net.save_params(fname)
            logging.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f',
                         epoch, fname, top1)
        if top1 > best_acc:
            best_acc = top1
            fname = os.path.join(args.prefix, '%s_best.params' % args.model)
            net.save_params(fname)
            logging.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f',
                         epoch, fname, top1)

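# Note: train() above is kept for reference but is never called; the
# __main__ block below only runs a single validation pass.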
if __name__ == '__main__':
    args = parse_args()
    logging.info(args)
    # get the network
    net = get_model(args.model, args.resume, args.pretrained, args.dtype)
    # get the dataset
    val_data = get_dataloader(args.data, args.batch_size, args.num_workers, args.dtype)
    # set up contexts
    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
    ctx = [mx.cpu()] if not ctx else ctx
    # run validation only
    net.collect_params().reset_ctx(ctx)
    metrics = [mx.metric.Accuracy(), mx.metric.TopKAccuracy(5)]
    msg, _ = validate(net, val_data, metrics, ctx)
    logging.info('[%s] %s', args.model, msg)
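# Example invocation (hypothetical script name and paths; pass --pretrained
# or --resume, otherwise randomly initialized weights are evaluated):
#   python imagenet_validation.py --data /path/to/imagenet --model resnet50_v1 \
#       --gpus 0,1 --batch-size 256 --pretrained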