Last active
June 12, 2019 01:33
-
-
Save yoku001/044794f4fa6fec2ea7160de2f6445e5c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%matplotlib inline\n", | |
"\n", | |
"import numpy as np\n", | |
"\n", | |
"def download(url, path, overwrite=False):\n", | |
" import os\n", | |
" if os.path.isfile(path) and not overwrite:\n", | |
" print('File {} exists, skip.'.format(path))\n", | |
" return\n", | |
" print('Downloading from url {} to {}'.format(url, path))\n", | |
" try:\n", | |
" import urllib.request\n", | |
" urllib.request.urlretrieve(url, path)\n", | |
" except:\n", | |
" import urllib\n", | |
" urllib.urlretrieve(url, path)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import keras\n", | |
"model = keras.applications.resnet50.ResNet50(include_top=True, weights='imagenet', input_tensor=None, input_shape=None, pooling=None, classes=1000)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from PIL import Image\n", | |
"from matplotlib import pyplot as plt\n", | |
"from keras.applications.resnet50 import preprocess_input\n", | |
"\n", | |
"# prepare data\n", | |
"img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'\n", | |
"download(img_url, 'cat.png')\n", | |
"img = Image.open('cat.png').resize((224, 224))\n", | |
"plt.imshow(img)\n", | |
"plt.show()\n", | |
"\n", | |
"# input preprocess\n", | |
"data = np.array(img)[np.newaxis, :].astype('float32')\n", | |
"data = preprocess_input(data).transpose([0, 3, 1, 2])\n", | |
"print('input_1', data.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import tvm\n", | |
"import tvm.relay as relay\n", | |
"\n", | |
"input_name = 'input_1'\n", | |
"shape_dict = {input_name: data.shape}\n", | |
"func, params = relay.frontend.from_keras(model, shape_dict)\n", | |
"\n", | |
"with relay.quantize.qconfig(global_scale=8.0):\n", | |
" func_quant = relay.quantize.quantize(func, params)\n", | |
" print(str(func_quant))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"target = 'llvm'\n", | |
"\n", | |
"with relay.build_config(opt_level=0):\n", | |
" modules = relay.build_module.build(func, target, params=params)\n", | |
" modules_quant = relay.build_module.build(func_quant, target, params=params)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def predict_tvm(modules, data):\n", | |
" # create module\n", | |
" graph, lib, params = modules\n", | |
" module = tvm.contrib.graph_runtime.create(graph, lib, tvm.cpu())\n", | |
"\n", | |
" # set input and parameters\n", | |
" module.set_input(input_name, tvm.nd.array(data.astype(np.float32)))\n", | |
" module.set_input(**params)\n", | |
" \n", | |
" # get output\n", | |
" module.run()\n", | |
" return module.get_output(0, tvm.nd.empty((1, 1000))).asnumpy()\n", | |
"\n", | |
"def show_top5_accuracy(y_pred, synset):\n", | |
" top5_ids = y_pred.flatten().argsort()[::-1][:5]\n", | |
"\n", | |
" for i, c in enumerate(top5_ids):\n", | |
" print(f'{i+1} : {c} {synset[c]}')\n", | |
"\n", | |
"# get ImageNet synset dictionary\n", | |
"synset_url = 'https://gist.githubusercontent.com/zhreshold/' \\\n", | |
" '4d0b62f3d01426887599d4f7ede23ee5/raw/' \\\n", | |
" '596b27d23537e5a1b5751d2b0481ef172f58b539/' \\\n", | |
" 'imagenet1000_clsid_to_human.txt'\n", | |
"synset_name = 'synset.txt'\n", | |
"download(synset_url, synset_name)\n", | |
"\n", | |
"with open(synset_name) as f:\n", | |
" synset = eval(f.read()) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"y_pred_tvm = predict_tvm(modules, data)\n", | |
"y_pred_tvm_quant = predict_tvm(modules_quant, data)\n", | |
"\n", | |
"print('\\nw/o quantization')\n", | |
"show_top5_accuracy(y_pred_tvm, synset)\n", | |
"\n", | |
"print('\\nwith quantization')\n", | |
"show_top5_accuracy(y_pred_tvm_quant, synset)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# save weight\n", | |
"model.save_weights('weight_keras.h5')\n", | |
"\n", | |
"params = modules[2]\n", | |
"with open(\"deploy_param.params\", \"wb\") as f:\n", | |
" f.write(relay.save_param_dict(params))\n", | |
"\n", | |
"params_quant = modules_quant[2]\n", | |
"with open(\"deploy_param_quant.params\", \"wb\") as f:\n", | |
" f.write(relay.save_param_dict(params_quant))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"graph, lib, params = modules\n", | |
"module = tvm.contrib.graph_runtime.create(graph, lib, tvm.cpu())\n", | |
"\n", | |
"# set input and parameters\n", | |
"module.set_input(input_name, tvm.nd.array(data.astype(np.float32)))\n", | |
"module.set_input(**params)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"graph, lib, params = modules_quant\n", | |
"module_quant = tvm.contrib.graph_runtime.create(graph, lib, tvm.cpu())\n", | |
"\n", | |
"# set input and parameters\n", | |
"module_quant.set_input(input_name, tvm.nd.array(data.astype(np.float32)))\n", | |
"module_quant.set_input(**params)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Time only graph execution: inputs and params were already set on `module` and\n", | |
"# `module_quant` in the two cells above. -n10 fixes the loop count per timing run.\n", | |
"# NOTE(review): both modules were built with opt_level=0, so these timings may not\n", | |
"# reflect optimized performance -- confirm that is the intended comparison.\n", | |
"print('w/o quantization')\n", | |
"%timeit -n10 module.run()\n", | |
"\n", | |
"print('\\nwith quantization')\n", | |
"%timeit -n10 module_quant.run()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment