Last active
August 16, 2019 06:17
-
-
Save vinx13/6f1eb1f9e2c0a8786149ee881bfcd6aa to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import logging | |
import argparse | |
import os | |
import mxnet as mx | |
from mxnet import gluon | |
from mxnet.gluon.model_zoo import vision | |
import tvm | |
import tvm.relay as relay | |
import tvm.relay.expr as _expr | |
import tvm.relay.transform as _transform | |
from tvm.contrib import graph_runtime | |
from scipy import stats | |
import pickle | |
import multiprocessing as mp | |
# Helpers for reading validation/calibration data from an MXNet record file
# (reading from raw images is not implemented in this script).
def get_val_data(args,
                 rec_val,
                 batch_size,
                 num_workers=4,
                 shuffle=False):
    """Create the record-file iterator and a helper that splits its batches.

    Parameters
    ----------
    args : namespace
        Only ``args.model`` is read: ``'inceptionv3'`` selects a 299x299
        input, every other model name selects 224x224.
    rec_val : str
        Path of the record file; a leading ``~`` is expanded.
    batch_size : int
        Batch size handed to the iterator.
    num_workers : int
        Forwarded to the iterator as ``preprocess_threads``.
    shuffle : bool
        Forwarded to the iterator as ``shuffle``.

    Returns
    -------
    tuple
        ``(val_data, batch_fn)`` where ``val_data`` is the
        ``mx.io.ImageRecordIter`` and ``batch_fn(batch, ctx)`` returns
        ``(data, label)`` split across the contexts in ``ctx``.
    """
    record_path = os.path.expanduser(rec_val)
    # Per-channel (r, g, b) normalisation constants handed to the iterator.
    channel_mean = (123.68, 116.779, 103.939)
    channel_std = (58.393, 57.12, 57.375)

    def batch_fn(batch, ctx):
        """Split the first data and label arrays of *batch* across *ctx*."""
        pieces = [
            gluon.utils.split_and_load(array, ctx_list=ctx, batch_axis=0)
            for array in (batch.data[0], batch.label[0])
        ]
        return pieces[0], pieces[1]

    side = 224
    if args.model == 'inceptionv3':
        side = 299

    red_mean, green_mean, blue_mean = channel_mean
    red_std, green_std, blue_std = channel_std
    val_data = mx.io.ImageRecordIter(
        path_imgrec=record_path,
        preprocess_threads=num_workers,
        shuffle=shuffle,
        batch_size=batch_size,
        resize=256,
        data_shape=(3, side, side),
        mean_r=red_mean,
        mean_g=green_mean,
        mean_b=blue_mean,
        std_r=red_std,
        std_g=green_std,
        std_b=blue_std,
    )
    return val_data, batch_fn
def calibration_dataset():
    """Yield calibration batches as ``{'data': ndarray}`` feed dictionaries.

    Reads the module-level ``args`` (``rec_val``, ``batch_size``,
    ``calibration_samples``) and iterates a shuffled record-file iterator
    obtained from ``get_val_data``.  Iteration stops as soon as at least
    ``args.calibration_samples`` samples have been yielded, or when the
    iterator is exhausted.

    Yields
    ------
    dict
        ``{'data': array}`` where ``array`` is the batch converted with
        ``asnumpy()``.
    """
    val_data, batch_fn = get_val_data(args, args.rec_val, args.batch_size, shuffle=True)
    val_data.reset()
    for i, batch in enumerate(val_data):
        # At this point i * batch_size samples have already been yielded.
        # Using `>=` (the original used `>`) avoids emitting one extra batch
        # when the limit is an exact multiple of the batch size.
        if i * args.batch_size >= args.calibration_samples:
            break
        data, _ = batch_fn(batch, [mx.cpu(0)])
        yield {'data': data[0].asnumpy()}
def evaluate(args, graph, lib, params, ctx):
    """Evaluate on the validation set.

    Runs the compiled module on every batch of the validation iterator,
    logs running top-1/top-5 accuracy every ``args.log_interval`` batches,
    logs the final accuracy, and appends it to ``args.record_file``.

    Parameters
    ----------
    args : namespace
        Reads ``batch_size``, ``rec_val``, ``num_classes``, ``log_interval``,
        ``record_file`` and ``model``.
    graph, lib, ctx
        Forwarded to ``graph_runtime.create``.
    params : dict
        Set as inputs of the runtime module before the first run.
    """
    # setup dataset.
    batch_size = args.batch_size
    val_data, batch_fn = get_val_data(args, args.rec_val, batch_size)
    # create runtime module
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(**params)
    oshape = (batch_size, args.num_classes)
    out_arr = tvm.nd.empty(oshape, "float32")
    # setup evaluation metric
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)
    val_data.reset()
    acc_top1.reset()
    acc_top5.reset()
    # Execute
    for i, batch in enumerate(val_data):
        data, label = batch_fn(batch, [mx.cpu(0)])
        m.run(data=data[0].asnumpy())
        m.get_output(0, out_arr)
        # Convert the output once and feed the same prediction to both metrics.
        preds = [mx.nd.array(out_arr.asnumpy())]
        acc_top1.update(label, preds)
        acc_top5.update(label, preds)
        if args.log_interval and not (i + 1) % args.log_interval:
            _, top1 = acc_top1.get()
            _, top5 = acc_top5.get()
            nsamples = (i + 1) * batch_size
            logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5)
    # Read the metrics after the loop: previously top1/top5 were only bound
    # inside the logging branch, so the final report was either stale or
    # raised UnboundLocalError when no interval log had happened.
    _, top1 = acc_top1.get()
    _, top5 = acc_top5.get()
    logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5)
    with open(args.record_file, "a") as f:
        f.write('{}, {} / {}\n'.format(
            args.model, top1, top5))
def calibrate_on_dataset(qgraph):
    """Compute one scale per profiled output from the calibration batches.

    Builds the graph returned by ``relay.quantize.collect_stats(qgraph)``
    for ``args.target``, runs it on every batch produced by
    ``calibration_dataset()``, gathers each output across batches into a
    flat array, and maps ``kl_divergence_scale`` over those arrays in a
    process pool.

    Parameters
    ----------
    qgraph
        Expression handed to ``relay.quantize.collect_stats``.

    Returns
    -------
    list
        One scale per output of the profiling graph, in output order.
    """
    stats_expr = relay.quantize.collect_stats(qgraph)
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(relay.Module.from_expr(stats_expr), target=args.target)

    device = tvm.context(args.target, args.device_id)
    runtime = graph_runtime.create(graph, lib, device)
    runtime.set_input(**params)

    output_count = runtime.get_num_outputs()
    collected = [[] for _ in range(output_count)]
    for batch_id, feed in enumerate(calibration_dataset()):
        print('batch {}..'.format(batch_id))
        runtime.set_input(**feed)
        runtime.run()
        for index, bucket in enumerate(collected):
            bucket.append(runtime.get_output(index).asnumpy())

    # Flatten in place so each per-batch list can be released as soon as
    # its concatenated copy exists.
    for index in range(output_count):
        collected[index] = np.concatenate(collected[index]).reshape(-1)

    scale_fn = relay.quantize.kl_divergence.kl_divergence_scale
    with mp.Pool() as pool:
        scales = list(pool.map(scale_fn, collected))
    return scales
def build_model(gluon_model, original):
    """Build with relay.

    Converts ``gluon_model`` with ``relay.frontend.from_mxnet`` and builds it
    for ``args.target``, either unmodified or after annotate / calibrate /
    realize quantization.

    Reads the module-level ``args``: ``model``, ``batch_size``, ``target``,
    ``device_id`` and ``eval_power2``.

    Parameters
    ----------
    gluon_model
        Model handed to ``relay.frontend.from_mxnet``.
    original : bool
        When true, build the converted model at ``opt_level=3`` without
        quantization.

    Returns
    -------
    tuple
        ``(graph, lib, params, ctx)`` as produced by ``relay.build`` plus the
        ``tvm.context`` for the selected target/device.
    """
    import math
    import tvm
    from tvm import relay

    img_size = 299 if args.model == 'inceptionv3' else 224
    data_shape = (args.batch_size, 3, img_size, img_size)
    mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
    target = args.target
    ctx = tvm.context(target, args.device_id)

    if original:
        # run original model
        with relay.build_config(opt_level=3):
            # The original passed the undefined name `net` here, which raised
            # NameError; the converted module is `mod`.
            graph, lib, params = relay.build(mod, target, params=params)
        return graph, lib, params, ctx

    # NOTE(review): the nesting of the `with` blocks below was reconstructed
    # from a copy of this script whose indentation had been stripped --
    # confirm against the upstream version.
    skip_conv_layers = [0]
    with relay.quantize.qconfig(store_lowbit_output=False, skip_conv_layers=skip_conv_layers):
        from tvm.relay.quantize.quantize import _bind_params
        graph = _bind_params(mod['main'], params)
        mod = relay.Module.from_expr(graph)
        optimize = _transform.Sequential([_transform.SimplifyInference(),
                                          _transform.FoldConstant(),
                                          _transform.FoldScaleAxis(),
                                          _transform.CanonicalizeOps(),
                                          _transform.FoldConstant()])
        with relay.build_config(opt_level=2):
            mod = optimize(mod)
            mod = relay.quantize.annotate()(mod)

            # Calibration is expensive, so the scales are cached per model.
            cache_file = '{}_scales.pkl'.format(args.model)
            if os.path.exists(cache_file):
                # NOTE: unpickling can execute arbitrary code; only load
                # cache files that this script wrote itself.
                with open(cache_file, 'rb') as f:
                    scales = pickle.load(f)
            else:
                scales = calibrate_on_dataset(mod['main'])
                with open(cache_file, 'wb') as f:
                    pickle.dump(scales, f)

            if args.eval_power2:
                # Round every positive scale up to the next power of two.
                # Uses stdlib `math` (the `np.math` alias is gone from recent
                # NumPy) and `log2`, which, unlike `log(x, 2)`, does not
                # overshoot on exact powers of two.
                scales = [2 ** math.ceil(math.log2(scale)) if scale > 0 else 1.0
                          for scale in scales]
                weight_scales = 'power2'
            else:
                weight_scales = 'max'

            mod['main'] = relay.quantize.calibrate(mod['main'], weight_scales=weight_scales,
                                                   scales=scales)
            mod = relay.quantize.realize()(mod)
            mod = relay.transform.FoldConstant()(mod)

        graph, lib, params = relay.build(mod, target=args.target)
    return graph, lib, params, ctx
def save_model(name, graph, lib, params):
    """Write the build artifacts next to each other under the prefix *name*.

    Produces ``name + '.json'`` (the text in ``graph``), ``name + '.so'``
    (via ``lib.export_library``) and ``name + '.bin'`` (the bytes returned by
    ``relay.save_param_dict(params)``).
    """
    graph_path = name + '.json'
    with open(graph_path, 'w') as graph_file:
        graph_file.write(graph)

    library_path = name + '.so'
    lib.export_library(library_path)

    params_path = name + '.bin'
    with open(params_path, 'wb') as params_file:
        serialized = relay.save_param_dict(params)
        params_file.write(serialized)
def main():
    """Fetch the pretrained model, build it, optionally save it, and evaluate.

    Reads the module-level ``args``: ``model``, ``original`` and
    ``save_model``.
    """
    gluon_model = vision.get_model(args.model, pretrained=True)
    graph, lib, params, ctx = build_model(gluon_model, args.original)
    logging.info("Finish building model %s...", args.model)
    # `--save_model` was parsed but never acted on; honour it by writing the
    # artifacts under the given prefix. The default (None) keeps the old
    # behaviour of not saving anything.
    if args.save_model:
        save_model(args.save_model, graph, lib, params)
    evaluate(args, graph, lib, params, ctx)
# Script entry point: parse the command line into the module-level `args`
# (read by the functions above), configure logging, and run main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate ImageNet validation accuracy")
    parser.add_argument("--rec-val", type=str, default="~/.mxnet/datasets/imagenet/rec/val.rec",
                        help="the validation data")
    # Help text previously read "batch size", copied from --batch-size.
    parser.add_argument("--num-classes", type=int, default=1000,
                        help="number of classes")
    parser.add_argument("--model", type=str, default="resnet50_v2",
                        help="Name of the model")
    parser.add_argument("--log-interval", type=int, default=100,
                        help="log interval")
    parser.add_argument("--batch-size", type=int, default=1,
                        help="batch size")
    parser.add_argument("--target", type=str, default="cuda",
                        help="target option")
    parser.add_argument("--original", action="store_true",
                        help='whether to use original graph')
    parser.add_argument('--save_model', type=str, default=None)
    parser.add_argument('--calibration_samples', type=int, default=100,
                        help='minimum number of samples drawn for calibration')
    parser.add_argument('--device-id', type=int, default=0,
                        help='device id passed to tvm.context')
    # The two halves of this help string used to join as "power2scale".
    parser.add_argument('--eval-power2', action='store_true',
                        help='in this mode, scales are restricted to power-of-2 (weight: power2 '
                             'scale, activation: round kld to power2)')
    parser.add_argument('--record-file', type=str, default='record.csv',
                        help='file to save eval result')
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    logging.info(args)
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@mingwayzhang I have updated line 149 and it works fine locally. There are some accuracy issues after #3543 is merged. I'm working on it