Created August 23, 2021 21:19.
Save tiandiao123/de52ef96f574645c2dccb3544b291487 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# FP16 TRT command to run: TVM_TENSORRT_USE_FP16=1 python test_trt.py
# INT8 TRT command to run: TVM_TENSORRT_USE_INT8=1 TENSORRT_NUM_CALI_INT8=10 python test_trt.py
# use tvm branch: https://github.com/tiandiao123/tvm/tree/pr_trt_int8
import tvm | |
from tvm import relay | |
import os | |
from tvm import te | |
import tvm.relay as relay | |
from tvm.contrib.download import download_testdata | |
import onnx | |
import numpy as np | |
# Tensorflow imports | |
import tensorflow as tf | |
import numpy as np | |
# TF 2.x exposes the 1.x API under tf.compat.v1; fall back to the bare module
# on old releases where that namespace is missing.  A missing attribute raises
# AttributeError (not ImportError), so the original `except ImportError` clause
# could never fire -- catch both so the fallback actually works.
try:
    tf_compat_v1 = tf.compat.v1
except (ImportError, AttributeError):
    tf_compat_v1 = tf
# Tensorflow utility functions | |
import tvm.relay.testing.tf as tf_testing | |
# Fetch the pre-trained super-resolution ONNX model into the local testdata
# cache, then load it for conversion to Relay.
model_url = (
    "https://gist.github.com/zhreshold/"
    "bcda4716699ac97ea44f791c24310193/raw/"
    "93672b029103648953c4e5ad3ac3aadf346a4cdc/"
    "super_resolution_0.2.onnx"
)
model_path = download_testdata(model_url, "super_resolution.onnx", module="onnx")
# super_resolution.onnx is now on disk at model_path
onnx_model = onnx.load(model_path)
from PIL import Image | |
# Download the test image, resize it to the model's expected 224x224, and
# keep only the luma (Y) channel of its YCbCr representation as an NCHW
# input array of shape (1, 1, 224, 224).
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")

resized = Image.open(img_path).resize((224, 224))
luma, _cb, _cr = resized.convert("YCbCr").split()
x = np.asarray(luma)[np.newaxis, np.newaxis, :, :]

# The ONNX graph names its single input "1"; feed shapes keyed by that name.
input_name = "1"
shape_dict = {input_name: x.shape}
print("shape: ")
print(x.shape)
input_shape = x.shape
# Convert the ONNX model to Relay, hand TensorRT-compatible subgraphs to the
# TRT codegen, build the remainder for CUDA, then round-trip the compiled
# module through a shared library to exercise the deployable artifact.
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
target = "cuda"
dev = tvm.cuda(0)
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
mod, config = partition_for_tensorrt(mod, params)
print("python script building --------------")
# The partitioning config must be forwarded through the PassContext so the
# TRT codegen honours the FP16/INT8 environment options.
with tvm.transform.PassContext(opt_level=3, config={'relay.ext.tensorrt.options': config}):
    lib = relay.build(mod, target=target, params=params)
print("python script finished building -------------------")
dtype = "float32"
lib.export_library('compiled.so')
loaded_lib = tvm.runtime.load_module('compiled.so')
# Import the submodule explicitly: `tvm.contrib.graph_executor` is not
# reliably reachable as a bare attribute access on the top-level `tvm` import.
from tvm.contrib import graph_executor
gen_module = graph_executor.GraphModule(loaded_lib['default'](dev))
# Number of INT8 calibration inference passes, taken from the
# TENSORRT_NUM_CALI_INT8 environment variable (unset/0 means no calibration,
# i.e. FP16/FP32 runs skip this step).  Use .get() with a sentinel instead of
# the original bare `except:` clause, which would also have swallowed
# SystemExit and KeyboardInterrupt.
cali_env = os.environ.get("TENSORRT_NUM_CALI_INT8")
if cali_env is not None:
    print("we are going to set {} times calibration in this case".format(cali_env))
else:
    print("no TENSORRT_NUM_CALI_INT8 found in this case ... ")
num_cali_int8 = int(cali_env or 0)
if num_cali_int8 != 0:
    print("calibration steps ... ")
    # Each pass feeds the input through the engine so TensorRT can collect
    # INT8 quantization ranges before the timed run.
    for _ in range(num_cali_int8):
        gen_module.run(data=x)
        print("finished calibration step")
# One smoke-test inference (printing the output), then a timed loop.
# time.perf_counter() is the correct clock for interval timing: it is
# monotonic and high resolution, whereas time.time() is wall-clock and can
# jump under NTP adjustments.
# NOTE(review): this assumes gen_module.run() blocks until the GPU work
# completes -- confirm, otherwise timings only measure launch overhead.
import time

print("test run ... ")
gen_module.run(data=x)
out = gen_module.get_output(0)
print(out)

epochs = 100
total_time = 0.0
for _ in range(epochs):
    start = time.perf_counter()
    gen_module.run(data=x)
    total_time += time.perf_counter() - start
print("the average time is {} ms".format(str(total_time/epochs * 1000)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.