Skip to content

Instantly share code, notes, and snippets.

@cyrusbehr
Last active September 12, 2019 21:57
Show Gist options
  • Save cyrusbehr/6ce760e223b4eba22ebcd4c0f44d28c0 to your computer and use it in GitHub Desktop.
Save cyrusbehr/6ce760e223b4eba22ebcd4c0f44d28c0 to your computer and use it in GitHub Desktop.
Auto-tune an MXNet model and compile it with TVM
import os
import numpy as np
import nnvm.testing
import nnvm.compiler
import tvm
import mxnet as mx
from tvm import autotvm
import tvm.relay as relay
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
import tvm.contrib.graph_runtime as runtime
def get_network(batch_size, prefix="/[path_to_model_dir_here]/model", epoch=0):
    """Load an MXNet checkpoint and convert it to an NNVM graph.

    Args:
        batch_size: Leading dimension of the reported input/output shapes.
        prefix: Checkpoint path prefix passed to ``mx.model.load_checkpoint``.
            The default is a placeholder that must be replaced with a real
            model directory.
        epoch: Checkpoint epoch passed to ``mx.model.load_checkpoint``.

    Returns:
        A tuple ``(nnvm_sym, nnvm_params, input_shape, output_shape)`` where
        ``input_shape`` is ``(batch_size, 3, 112, 112)`` and ``output_shape``
        is ``(batch_size, 512)``.
    """
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    nnvm_sym, nnvm_params = nnvm.frontend.from_mxnet(sym, arg_params, aux_params)
    # NOTE(review): these shapes are hard-coded and presumably match the
    # checkpoint's 'data' input and embedding output -- confirm for your model.
    input_shape = (batch_size, 3, 112, 112)
    output_shape = (batch_size, 512)
    return nnvm_sym, nnvm_params, input_shape, output_shape
# --- Tuning / compilation settings -----------------------------------------
# Target string handed to AutoTVM and to the NNVM compiler.
target = "llvm -mcpu=skylake"
dtype = "float32"
batch_size = 1

# Tuning records are written to "<model_name>.log".
model_name = "resnet-100"
log_file = "{}.log".format(model_name)

# Export the thread count through TVM_NUM_THREADS before any measurement
# objects are constructed below.
num_threads = 1
os.environ["TVM_NUM_THREADS"] = str(num_threads)

# Keyword arguments that are forwarded verbatim to tune_kernels().
tuning_option = dict(
    log_filename=log_file,
    tuner='random',
    early_stopping=None,
    measure_option=autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=10, repeat=1, min_repeat_ms=1000),
    ),
)
def _create_tuner(tuner, task):
    """Instantiate the AutoTVM tuner selected by the name ``tuner`` for ``task``."""
    if tuner in ('xgb', 'xgb-rank'):
        return XGBTuner(task, loss_type='rank')
    if tuner == 'ga':
        return GATuner(task, pop_size=50)
    if tuner == 'random':
        return RandomTuner(task)
    if tuner == 'gridsearch':
        return GridSearchTuner(task)
    raise ValueError("Invalid tuner: " + tuner)


def tune_kernels(tasks,
                 measure_option,
                 tuner='gridsearch',
                 early_stopping=None,
                 log_filename='tuning.log',
                 n_trial=None):
    """Tune every extracted task and append the records to ``log_filename``.

    Each incoming task is re-created as its x86 NCHWc template before tuning,
    while keeping the original workload so the records can be matched later.

    Args:
        tasks: Sequence of extracted AutoTVM tasks; each must expose
            ``workload`` and ``args``.
        measure_option: Measurement configuration passed to ``tuner.tune``.
        tuner: One of ``'xgb'``, ``'xgb-rank'``, ``'ga'``, ``'random'`` or
            ``'gridsearch'``.
        early_stopping: Passed through to ``tuner.tune``.
        log_filename: File that ``autotvm.callback.log_to_file`` writes to.
        n_trial: Optional upper bound on the number of trials per task.
            ``None`` (the default) explores the task's whole config space;
            otherwise the bound is clipped to the config-space size.

    Raises:
        ValueError: If a task's operator has no x86 template here, or if
            ``tuner`` is not a recognised tuner name.
    """
    # Operator name -> x86 NCHWc template used to re-create the task.
    x86_templates = {
        'conv2d': 'topi_x86_conv2d_NCHWc',
        'depthwise_conv2d_nchw': 'topi_x86_depthwise_conv2d_NCHWc_from_nchw',
    }
    total = len(tasks)

    for i, tsk in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i + 1, total)

        # converting conv2d tasks to conv2d_NCHWc tasks
        op_name = tsk.workload[0]
        func_create = x86_templates.get(op_name)
        if func_create is None:
            raise ValueError("Tuning {} is not supported on x86".format(op_name))

        task = autotvm.task.create(func_create, args=tsk.args,
                                   target=target, template_key='direct')
        # Keep the original workload on the re-created task.
        task.workload = tsk.workload

        tuner_obj = _create_tuner(tuner, task)

        # Never ask for more trials than there are configurations.
        space_size = len(task.config_space)
        task_trials = space_size if n_trial is None else min(n_trial, space_size)

        tuner_obj.tune(n_trial=task_trials,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(task_trials, prefix=prefix),
                           autotvm.callback.log_to_file(log_filename)])
def tune_and_evaluate(tuning_opt):
    """Tune the network's conv2d kernels, compile it, time it and export it.

    Steps: extract conv2d tasks from the NNVM graph, tune them, build the
    graph with the best records from the tuning log, time the ``run``
    function on CPU, then write ``./tvm_lib.so``, ``./tvm.json`` and
    ``./tvm.params``.

    Args:
        tuning_opt: Mapping of keyword arguments forwarded to
            ``tune_kernels``. If it has no ``'log_filename'`` entry, the
            module-level ``log_file`` is used.
    """
    # Resolve the log path once so tuning writes to, and compilation reads
    # from, the same file. Work on a copy to leave the caller's mapping alone.
    tuning_opt = dict(tuning_opt)
    tuning_opt.setdefault('log_filename', log_file)
    history_log = tuning_opt['log_filename']

    # extract workloads from nnvm graph
    print("Extract tasks...")
    net, params, data_shape, out_shape = get_network(batch_size)
    tasks = autotvm.task.extract_from_graph(net, target=target,
                                            shape={'data': data_shape}, dtype=dtype,
                                            symbols=(nnvm.sym.conv2d,))

    # run tuning tasks
    print("Tuning...")
    tune_kernels(tasks, **tuning_opt)

    # compile kernels with history best records
    with autotvm.apply_history_best(history_log):
        print("Compile...")
        with nnvm.compiler.build_config(opt_level=3):
            graph, lib, params = nnvm.compiler.build(
                net, target=target, shape={'data': data_shape}, params=params, dtype=dtype)

        # upload parameters to device
        ctx = tvm.cpu()
        data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype))
        module = runtime.create(graph, lib, ctx)
        module.set_input('data', data_tvm)
        module.set_input(**params)

        # evaluate
        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", ctx, number=100, repeat=3)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))

        # Persist the compiled library, graph definition and parameters.
        lib.export_library("./tvm_lib.so")
        print('lib export succeefully')
        with open("./tvm.json", "w") as fo:
            fo.write(graph.json())
        with open("./tvm.params", "wb") as fo:
            fo.write(nnvm.compiler.save_param_dict(params))
# Only start the (long-running) tuning job when executed as a script, so that
# importing this module does not kick off tuning as a side effect.
if __name__ == "__main__":
    tune_and_evaluate(tuning_option)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment