CIFAR-10 training with a Gluon model (ResNet-101 v2, multi-GPU) in MXNet
import argparse
import logging
import random
import time

import mxnet as mx
from mxnet import gluon
from mxnet import autograd
import numpy as np
def parse_args():
    parser = argparse.ArgumentParser(description="Train CIFAR10.")
    parser.add_argument('--batch-size', type=int, default=128,
                        help='training batch size.')
    parser.add_argument('--num-gpus', type=int, default=1,
                        help='number of gpus to use.')
    parser.add_argument('--epochs', type=int, default=350,
                        help='number of training epochs.')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate. default is 0.1.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum value for optimizer. default is 0.9.')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='weight decay rate. default is 0.0001.')
    parser.add_argument('--seed', type=int, default=123,
                        help='random seed to use. default is 123.')
    parser.add_argument('--log-interval', type=int, default=50,
                        help='number of batches to wait before logging.')
    parser.add_argument('--kvstore', type=str, default='device',
                        help='kvstore to use for the trainer.')
    args = parser.parse_args()
    return args
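# Example invocation (hypothetical script name; assumes the listed GPUs exist):
#   python train_cifar10.py --num-gpus 2 --batch-size 256 --lr 0.1 --epochs 350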
def get_data_rec(batch_size):
    import os
    data_dir = "data"

    def download_cifar10():
        from mxnet.test_utils import download
        fnames = (os.path.join(data_dir, "cifar10_train.rec"),
                  os.path.join(data_dir, "cifar10_val.rec"))
        download('http://data.mxnet.io/data/cifar10/cifar10_val.rec', fnames[1])
        download('http://data.mxnet.io/data/cifar10/cifar10_train.rec', fnames[0])
        return fnames

    (train_fname, val_fname) = download_cifar10()
    train = mx.io.ImageRecordIter(
        path_imgrec         = train_fname,
        label_width         = 1,
        data_name           = 'data',
        label_name          = 'softmax_label',
        data_shape          = (3, 32, 32),
        batch_size          = batch_size,
        pad                 = 4,
        fill_value          = 127,  # only used when pad is valid
        rand_crop           = True,
        max_random_scale    = 1.0,  # 480 with imagenet, 32 with cifar10
        min_random_scale    = 1.0,  # 256.0/480.0
        max_aspect_ratio    = 0,
        random_h            = 0,
        random_s            = 0,
        random_l            = 0,
        max_rotate_angle    = 0,
        max_shear_ratio     = 0,
        rand_mirror         = True,
        shuffle             = True)
    val = mx.io.ImageRecordIter(
        path_imgrec         = val_fname,
        label_width         = 1,
        data_name           = 'data',
        label_name          = 'softmax_label',
        batch_size          = batch_size,
        data_shape          = (3, 32, 32),
        rand_crop           = False,
        rand_mirror         = False)
    return train, val
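# Note: pad=4 with a random 32x32 crop plus horizontal mirroring is the standard
# CIFAR-10 augmentation recipe popularized by the ResNet paper; the scale, color,
# rotation, and shear knobs above are deliberately left disabled.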
def get_data(batch_size):
    # In-memory alternative to get_data_rec, using gluon.data.vision.CIFAR10.
    def train_transform(data, label):
        data = mx.image.imresize(data, 40, 40)
        data = mx.image.RandomCropAug((32, 32))(data)
        data = mx.image.HorizontalFlipAug(0.5)(data)
        data = data.astype('float32') / 255
        data = (data - mx.nd.array([0.4914, 0.4822, 0.4465])) / mx.nd.array([0.2023, 0.1994, 0.2010])
        data = data.transpose((2, 0, 1))
        return data, label

    def val_transform(data, label):
        data = data.astype('float32') / 255
        data = (data - mx.nd.array([0.4914, 0.4822, 0.4465])) / mx.nd.array([0.2023, 0.1994, 0.2010])
        data = data.transpose((2, 0, 1))
        return data, label

    train_dataset = gluon.data.vision.CIFAR10(transform=train_transform)
    val_dataset = gluon.data.vision.CIFAR10(train=False, transform=val_transform)
    train_data = gluon.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, last_batch='keep')
    val_data = gluon.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, last_batch='keep')
    return train_data, val_data
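# The per-channel mean (0.4914, 0.4822, 0.4465) and std (0.2023, 0.1994, 0.2010)
# are the commonly used CIFAR-10 training-set statistics.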
def get_net():
    net = gluon.model_zoo.vision.get_model('resnet101_v2', thumbnail=True)
    return net
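# thumbnail=True selects the CIFAR-style stem (a single 3x3 conv in place of the
# 7x7 conv + max-pool used for ImageNet), which suits 32x32 inputs.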
def test(net, val_data, ctx):
    try:
        val_data.reset()  # ImageRecordIter must be reset; DataLoader has no reset()
    except AttributeError:
        pass
    metric = mx.metric.Accuracy()
    for batch in val_data:
        try:
            # DataLoader yields (data, label) pairs
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
        except (AttributeError, TypeError):
            # ImageRecordIter yields DataBatch objects
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = [net(x) for x in data]
        metric.update(label, outputs)
    return metric.get()
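# mx.metric.Accuracy accepts the lists of per-device labels/outputs directly,
# so no concatenation across GPUs is needed here.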
def train(net, train_data, val_data, epochs, lr, momentum, wd, ctx, kvstore, log_interval):
    net.initialize(mx.init.Xavier(magnitude=2), ctx=ctx)
    net.hybridize()
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr, 'wd': wd, 'momentum': momentum},
                            kvstore=kvstore)
    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    logging.info("Start training on {}.".format(str(ctx)))
    for epoch in range(epochs):
        try:
            train_data.reset()
        except AttributeError:
            pass
        if epoch in [150, 250]:
            trainer.set_learning_rate(trainer.learning_rate * 0.1)
            logging.info('reduced learning rate to {}'.format(trainer.learning_rate))
        tic = time.time()
        metric.reset()
        btic = time.time()
        for i, batch in enumerate(train_data):
            try:
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            except (AttributeError, TypeError):
                data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
                label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            Ls = []
            with autograd.record():
                for x, y in zip(data, label):
                    z = net(x)
                    L = loss(z, y)
                    # store the loss and do backward after we have done forward
                    # on all GPUs for better speed on multiple GPUs.
                    Ls.append(L)
                    outputs.append(z)
                autograd.backward(Ls)
            # total batch size is the sum (not the product) of the per-device shards
            batch_size = sum(d.shape[0] for d in data)
            trainer.step(batch_size, ignore_stale_grad=False)
            metric.update(label, outputs)
            if log_interval and not (i + 1) % log_interval:
                name, acc = metric.get()
                logging.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f' % (
                    epoch, i, log_interval * batch_size / (time.time() - btic), name, acc))
                btic = time.time()
        name, acc = metric.get()
        logging.info('[Epoch %d] training: %s=%f' % (epoch, name, acc))
        logging.info('[Epoch %d] time cost: %f' % (epoch, time.time() - tic))
        name, val_acc = test(net, val_data, ctx)
        logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))
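# The schedule above divides the learning rate by 10 at epochs 150 and 250,
# mirroring the step schedule commonly used for ResNets on CIFAR-10.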
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    args = parse_args()
    logging.info(args)
    # seed the python, numpy, and mxnet RNGs for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    mx.random.seed(args.seed)
    ctx = [mx.gpu(i) for i in range(args.num_gpus)] if args.num_gpus > 0 else [mx.cpu()]
    net = get_net()
    # get_data(args.batch_size) is the in-memory alternative loader
    train_data, val_data = get_data_rec(args.batch_size)
    train(net, train_data, val_data, args.epochs, args.lr, args.momentum, args.wd,
          ctx, args.kvstore, args.log_interval)
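# To keep the trained weights, one could append after train(...) something like
# (hypothetical filename; use save_params on older MXNet releases, or
# save_parameters on newer ones):
#   net.save_params('resnet101_v2_cifar10.params')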