Created
December 18, 2019 08:03
-
-
Save mjamroz/2b3a61e98e0ac9ef69f187d56a21e856 to your computer and use it in GitHub Desktop.
MXNet: evaluate a trained model on the test set
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse, time, logging, os,math | |
import numpy as np | |
import mxnet as mx | |
import gluoncv as gcv | |
from mxnet import gluon, nd | |
from mxnet import autograd as ag | |
from mxnet.gluon import nn | |
from mxnet.gluon.data.vision import transforms | |
from gluoncv.data import imagenet | |
from gluoncv.model_zoo import get_model | |
from gluoncv.utils import makedirs, LRSequential, LRScheduler, export_block | |
# CLI
def parse_args(argv=None):
    """Parse command-line options for evaluating a classifier on the test set.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``, exactly as before.

    Returns
    -------
    argparse.Namespace
        The parsed options.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate an image classification model on the test set.')
    parser.add_argument('--rec-val', type=str, default='~/.mxnet/datasets/imagenet/rec/val.rec',
                        help="RecordIO data file; 'val' in its file name is replaced by "
                             "'test' before loading.")
    parser.add_argument('--rec-val-idx', type=str, default='~/.mxnet/datasets/imagenet/rec/val.idx',
                        help="RecordIO index file; 'val' in its file name is replaced by "
                             "'test' before loading.")
    parser.add_argument('--use-rec', action='store_true',
                        help='use image record iter for data input. default is false.')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='evaluation batch size per device (CPU/GPU).')
    parser.add_argument('--dtype', type=str, default='float32',
                        help='data type for evaluation. default is float32')
    parser.add_argument('--num-gpus', type=int, default=0,
                        help='number of gpus to use.')
    parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int,
                        help='number of preprocessing workers')
    parser.add_argument('--model', type=str, required=True,
                        help='type of model to use. see vision_model for options.')
    parser.add_argument('--num-classes', type=int, default=2765,
                        help='number of output classes; must match the loaded parameters. '
                             'default is 2765')
    parser.add_argument('--input-size', type=int, default=224,
                        help='size of the input image size. default is 224')
    parser.add_argument('--crop-ratio', type=float, default=0.875,
                        help='Crop ratio during validation. default is 0.875')
    parser.add_argument('--use-pretrained', action='store_true',
                        help='enable using pretrained model from gluon.')
    # main() forwards this to get_model() for vgg* models; it was previously
    # read but never defined, which crashed every vgg run.
    parser.add_argument('--batch-norm', action='store_true',
                        help='enable batch normalization for vgg models. default is false.')
    # '--use_se' is kept for backward compatibility; '--use-se' matches the
    # hyphenated style of every other option.
    parser.add_argument('--use-se', '--use_se', dest='use_se', action='store_true',
                        help='use SE layers or not in resnext. default is false.')
    parser.add_argument('--resume-params', type=str, default='',
                        help='path of parameters to load from.')
    parser.add_argument('--resume-states', type=str, default='',
                        help='path of trainer state to load from (accepted but not used '
                             'for evaluation).')
    return parser.parse_args(argv)
def _to_test_path(path):
    """Return *path* with 'val' replaced by 'test' in its file name only.

    The script is given the validation RecordIO paths but evaluates the test
    file stored next to them (e.g. images_rec_val.rec -> images_rec_test.rec).
    Only the last path component is rewritten. Directories whose names contain
    'val' (such as an expanded home directory like /home/percival) are left
    untouched.
    """
    directory, filename = os.path.split(path)
    return os.path.join(directory, filename.replace('val', 'test'))


def _build_net(opt, context):
    """Create ``opt.model`` from the gluoncv model zoo and optionally load weights.

    Parameters
    ----------
    opt : argparse.Namespace
        Parsed command-line options.
    context : list of mx.Context
        Devices the parameters are placed on.

    Returns
    -------
    The network, cast to ``opt.dtype``.
    """
    kwargs = {
        'ctx': context,
        'pretrained': opt.use_pretrained,
        # NOTE(review): must match the class count of the checkpoint being
        # loaded; falls back to the previously hard-coded 2765.
        'classes': getattr(opt, 'num_classes', 2765),
    }
    model_name = opt.model
    if model_name.startswith('vgg'):
        # getattr: the parser may not define --batch-norm; default to False
        # instead of failing with AttributeError.
        kwargs['batch_norm'] = getattr(opt, 'batch_norm', False)
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se
    net = get_model(model_name, **kwargs)
    net.cast(opt.dtype)
    if opt.resume_params:  # empty string means "do not load"
        net.load_parameters(opt.resume_params, ctx=context)
    return net


def _get_test_data_rec(opt, batch_size, num_workers):
    """Build a non-shuffled ImageRecordIter over the test RecordIO files.

    Images are resized to ``ceil(input_size / crop_ratio)`` and normalized
    with fixed per-channel mean/std values. The final ``data_shape`` is
    ``(3, input_size, input_size)``.
    """
    rec_path = _to_test_path(os.path.expanduser(opt.rec_val))
    idx_path = _to_test_path(os.path.expanduser(opt.rec_val_idx))
    input_size = opt.input_size
    crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
    resize = int(math.ceil(input_size / crop_ratio))
    mean_rgb = [123.68, 116.779, 103.939]
    std_rgb = [58.393, 57.12, 57.375]
    return mx.io.ImageRecordIter(
        path_imgrec=rec_path,
        path_imgidx=idx_path,
        preprocess_threads=num_workers,
        shuffle=False,
        batch_size=batch_size,
        resize=resize,
        data_shape=(3, input_size, input_size),
        mean_r=mean_rgb[0],
        mean_g=mean_rgb[1],
        mean_b=mean_rgb[2],
        std_r=std_rgb[0],
        std_g=std_rgb[1],
        std_b=std_rgb[2],
    )


def _split_batch(batch, ctx):
    """Split a DataBatch's data/label across *ctx*, dropping padded samples.

    By default ImageRecordIter completes the final, partial batch with
    wrapped-around samples and reports how many in ``batch.pad``. Scoring
    those would count some test images twice, so they are sliced off first.
    """
    data = batch.data[0]
    label = batch.label[0]
    pad = batch.pad or 0
    if pad:
        keep = data.shape[0] - pad
        data = data[:keep]
        label = label[:keep]
        # Never hand a device an empty slice when few samples remain.
        ctx = ctx[:keep]
    # even_split=False: a trimmed last batch need not divide evenly across devices.
    data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0, even_split=False)
    label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0, even_split=False)
    return data, label


def _evaluate(net, test_data, ctx, dtype):
    """Run *net* over *test_data* and return ``(top1_accuracy, top5_accuracy)``."""
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)
    test_data.reset()
    for batch in test_data:
        data, label = _split_batch(batch, ctx)
        outputs = [net(x.astype(dtype, copy=False)) for x in data]
        acc_top1.update(label, outputs)
        acc_top5.update(label, outputs)
    _, top1 = acc_top1.get()
    _, top5 = acc_top5.get()
    return top1, top5


def main():
    """Evaluate a classification network on the RecordIO test set and log accuracy.

    The network named by ``--model`` is built with gluoncv's ``get_model``.
    It is optionally initialised from ``--resume-params`` and scored on the
    test RecordIO files derived from ``--rec-val`` / ``--rec-val-idx``.
    Top-1 and top-5 accuracy are logged.

    Raises
    ------
    SystemExit
        If ``--use-rec`` is not given; only RecordIO input is implemented.
    """
    opt = parse_args()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
    logger.info(opt)

    if not opt.use_rec:
        # Without this check the script crashed later with a NameError.
        raise SystemExit('error: only RecordIO input is supported; pass --use-rec')

    num_gpus = opt.num_gpus
    # --batch-size is per device; the iterator yields the global batch.
    batch_size = opt.batch_size * max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]

    net = _build_net(opt, context)
    test_data = _get_test_data_rec(opt, batch_size, opt.num_workers)
    acc_top1, acc_top5 = _evaluate(net, test_data, context, opt.dtype)
    logger.info('[Epoch %d] validation: acc1=%f acc5=%f', -1, acc_top1, acc_top5)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env bash
# Evaluate several saved checkpoints, one per epoch number, on the test set.
#
# For each epoch number, the single file matching `*<epoch>-best.params` in the
# checkpoint directory is passed to perform_tests.py. The matching `.states`
# file is forwarded too when exactly one exists. Epochs whose params file is
# missing or ambiguous are skipped with a message on stderr.
#
# NOTE(review): the model is SE_ResNext50_32x4d, but the checkpoint directory
# is named after se_resnext101_64x4d. Weights from a different architecture
# will not load; confirm which one is intended.
shopt -s nullglob  # unmatched globs expand to nothing instead of themselves

ckpt_dir=pre_params_se_resnext101_64x4d_best

for v in 70 68 65 64 61 60 59
do
    echo "$v"
    params=("$ckpt_dir"/*"$v"-best.params)
    states=("$ckpt_dir"/*"$v"-best.states)
    if (( ${#params[@]} != 1 )); then
        echo "skipping epoch $v: expected one params file, found ${#params[@]}" >&2
        continue
    fi
    state_args=()
    if (( ${#states[@]} == 1 )); then
        state_args=(--resume-states "${states[0]}")
    fi
    python3 perform_tests.py --rec-val images_rec_val.rec \
        --rec-val-idx images_rec_val.idx \
        --model SE_ResNext50_32x4d \
        --num-gpus 2 \
        --use-rec \
        --resume-params "${params[0]}" \
        "${state_args[@]}"
done
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment