Last active
November 2, 2018 03:56
-
-
Save yzhliu/fe28f0e9add3cbd70dda4c7cd30c1275 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import numpy as np | |
import nnvm.testing | |
import nnvm.compiler | |
import tvm | |
from tvm import autotvm | |
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner | |
from tvm.contrib.util import tempdir | |
import tvm.contrib.graph_runtime as runtime | |
import logging | |
logging.basicConfig() | |
def get_network(name, batch_size):
    """Get the symbol definition and random weights of a network.

    Parameters
    ----------
    name : str
        Network name: 'resnet-N', 'vgg-N', 'mobilenet', 'squeezenet_v1.1',
        'inception_v3', 'custom', or 'mxnet'.
    batch_size : int
        Batch dimension of the input tensor.

    Returns
    -------
    (net, params, input_shape, output_shape)
        NNVM symbol, parameter dict, and the NCHW input/output shapes.

    Raises
    ------
    ValueError
        If `name` matches no supported network.
    """
    input_shape = (batch_size, 3, 224, 224)
    output_shape = (batch_size, 1000)

    if "resnet" in name:
        n_layer = int(name.split('-')[1])
        net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size)
    elif "vgg" in name:
        n_layer = int(name.split('-')[1])
        net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size)
    elif name == 'mobilenet':
        net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
    elif name == 'squeezenet_v1.1':
        net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1')
    elif name == 'inception_v3':
        # Inception-v3 takes 299x299 inputs. Fix: respect the requested
        # batch_size instead of hard-coding batch 1, so the returned
        # input_shape matches the workload for any batch size.
        input_shape = (batch_size, 3, 299, 299)
        net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size)
    elif name == 'custom':
        # an example for custom network
        from nnvm.testing import utils
        net = nnvm.sym.Variable('data')
        net = nnvm.sym.conv2d(net, channels=4, kernel_size=(3, 3), padding=(1, 1))
        net = nnvm.sym.flatten(net)
        net = nnvm.sym.dense(net, units=1000)
        net, params = utils.create_workload(net, batch_size, (3, 224, 224))
    elif name == 'mxnet':
        # an example for mxnet model
        from mxnet.gluon.model_zoo.vision import get_model
        block = get_model('resnet18_v1', pretrained=True)
        net, params = nnvm.frontend.from_mxnet(block)
        net = nnvm.sym.softmax(net)
    else:
        raise ValueError("Unsupported network: " + name)

    return net, params, input_shape, output_shape
#### DEVICE CONFIG ####
# Target used for compilation. Here we compile natively for an x86 server
# CPU (Skylake with AVX-512). For cross compilation to a board, replace the
# target string with your device's target (query it with `gcc -v` there).
target = tvm.target.create("llvm -mcpu=skylake-avx512")

#### TUNING OPTION ####
network = 'mobilenet'
log_file = "%s.log" % network  # best tuning records end up here
dtype = 'float32'

# Pin the TVM worker thread count via the environment (read by the runtime).
num_threads = 18
os.environ["TVM_NUM_THREADS"] = str(num_threads)

# Keyword arguments forwarded to tune_tasks() below.
tuning_option = {
    'log_filename': log_file,
    'tuner': 'xgb',            # tuner algorithm: 'xgb', 'ga', 'random', or 'gridsearch'
    'n_trial': 1000,           # max trials per task
    'early_stopping': 400,     # presumably stops a task after this many trials without improvement — see autotvm docs
    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=10, repeat=1,
                                   min_repeat_ms=1000),
    ),
}
# You can skip the implementation of this function for this tutorial. | |
def tune_tasks(tasks,
               measure_option,
               tuner='xgb',
               n_trial=1000,
               early_stopping=None,
               log_filename='tuning.log',
               use_transfer_learning=True,
               try_winograd=True):
    """Tune every task in `tasks` and record the best configs in `log_filename`.

    Parameters
    ----------
    tasks : list of autotvm.task.Task
        Tasks extracted from the graph. Mutated in place when a winograd
        variant replaces a direct conv2d task.
    measure_option : dict
        Result of autotvm.measure_option(); controls how trials are built/run.
    tuner : str
        One of 'xgb', 'xgb-rank', 'ga', 'random', 'gridsearch'.
    n_trial : int
        Max trials per task (capped by the size of the task's config space).
    early_stopping : int or None
        Passed through to the tuner's early-stopping knob.
    log_filename : str
        Destination file for the distilled best records.
    use_transfer_learning : bool
        Seed each tuner with the records gathered so far in this run.
    try_winograd : bool
        Attempt to swap each task for its winograd template (best effort).

    Raises
    ------
    ValueError
        If `tuner` is not a recognized tuner name.
    """
    if try_winograd:
        for i in range(len(tasks)):
            try:  # try winograd template; keep the original task on failure
                tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
                                          tasks[i].target, tasks[i].target_host, 'winograd')
                input_channel = tsk.workload[1][1]
                if input_channel >= 64:
                    tasks[i] = tsk
            except Exception as err:
                # Best effort: not every op has a winograd template. Log at
                # debug level instead of swallowing silently, so failures
                # remain diagnosable.
                logging.debug("no winograd template for task %d: %s", i, err)

    # Tune into a temporary log, then distill the best records at the end.
    tmp_log_file = log_filename + ".tmp"
    if os.path.isfile(tmp_log_file):
        os.remove(tmp_log_file)

    for i, tsk in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))

        # create tuner
        if tuner == 'xgb' or tuner == 'xgb-rank':
            tuner_obj = XGBTuner(tsk, loss_type='rank', feature_type='knob')
        elif tuner == 'ga':
            tuner_obj = GATuner(tsk, pop_size=50)
        elif tuner == 'random':
            tuner_obj = RandomTuner(tsk)
        elif tuner == 'gridsearch':
            tuner_obj = GridSearchTuner(tsk)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        # transfer learning: warm-start from records already collected
        if use_transfer_learning:
            if os.path.isfile(tmp_log_file):
                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

        # do tuning
        tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(n_trial, prefix=prefix),
                           autotvm.callback.log_to_file(tmp_log_file)])

    # pick best records to a cache file; guard so an empty task list does
    # not crash on a missing tmp log
    if os.path.isfile(tmp_log_file):
        autotvm.record.pick_best(tmp_log_file, log_filename)
        os.remove(tmp_log_file)
def tune_and_evaluate(tuning_opt):
    """Extract tuning tasks from the configured network, tune them, then
    compile with the best records and benchmark inference on the local CPU.

    `tuning_opt` is a dict of keyword arguments for tune_tasks(). Uses the
    module-level `network`, `target`, `dtype`, and `log_file` settings.
    """
    # extract workloads from nnvm graph
    print("Extract tasks...")
    sym, weights, in_shape, _ = get_network(network, batch_size=1)
    tasks = autotvm.task.extract_from_graph(sym, target=target,
                                            shape={'data': in_shape}, dtype=dtype,
                                            symbols=(nnvm.sym.conv2d,))

    # run tuning tasks
    print("Tuning...")
    tune_tasks(tasks, **tuning_opt)

    # compile kernels with history best records
    with autotvm.apply_history_best(log_file):
        print("Compile...")
        with nnvm.compiler.build_config(opt_level=3):
            graph, lib, weights = nnvm.compiler.build(
                sym, target=target, shape={'data': in_shape}, params=weights, dtype=dtype)

        # load the compiled module and feed it a random input + the weights
        ctx = tvm.cpu()
        random_input = np.random.uniform(size=in_shape).astype(dtype)
        module = runtime.create(graph, lib, ctx)
        module.set_input('data', tvm.nd.array(random_input))
        module.set_input(**weights)

        # evaluate
        print("Evaluate inference time cost...")
        timer = module.module.time_evaluator("run", ctx, number=100, repeat=3)
        times_ms = np.array(timer().results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(times_ms), np.std(times_ms)))
# NOTE: tuning starts immediately when this script runs and can take hours.
# Comment out the following line if you only want the definitions.
tune_and_evaluate(tuning_option)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment