Last active
August 23, 2021 21:15
-
-
Save tiandiao123/67adb11ab3d73df8e83a1469707d7db4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# FP16 TRT command to run : TVM_TENSORRT_USE_FP16=1 python test_trt.py
# INT8 TRT command to run : TVM_TENSORRT_USE_INT8=1 TENSORRT_NUM_CALI_INT8=10 python test_trt.py
# use tvm branch: https://github.com/tiandiao123/tvm/tree/pr_trt_int8
# TVM imports come first, as in the original script; duplicate imports removed.
import tvm
from tvm import relay
import numpy as np
from tvm.contrib.download import download_testdata
import os

# PyTorch imports
import torch
import torchvision
import cv2  # NOTE(review): not used anywhere in this script; kept to avoid changing dependencies
# Load a pretrained torchvision classifier and put it in inference mode.
model_name = "resnet18"
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()

# We grab the TorchScripted model via tracing. The random tensor is only used to
# record the graph at the fixed input shape (batch 1, 3 channels, 224x224).
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()
from PIL import Image
from torchvision import transforms

# Fetch the sample image used for both calibration and the test/benchmark runs.
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
# Open in a context manager so the file handle is closed, and force RGB so that
# ToTensor() always produces 3 channels (the model input is 3-channel).
with Image.open(img_path) as raw_img:
    img = raw_img.convert("RGB").resize((224, 224))

# Preprocess the image and convert to tensor (ImageNet mean/std normalization).
# NOTE(review): the image is already 224x224 here, so Resize(256) upsamples it and
# CenterCrop(224) trims the border; kept as-is to match the original preprocessing.
my_preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
img = my_preprocess(img)
# ToTensor yields a (C, H, W) float tensor; add the batch axis and hand TVM a
# NumPy array of shape (1, 3, 224, 224).
img = np.expand_dims(img.numpy(), 0)
from tvm.contrib import graph_executor
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

# Import the TorchScript model into Relay. `input_name` is the graph input's name
# and is the key that must be used when feeding data to the graph executor.
input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

# compile the model
target = "cuda"
dev = tvm.cuda(0)
# Offload supported subgraphs to TensorRT; the returned `config` holds the TensorRT
# codegen options and is handed to the build through the PassContext.
mod, config = partition_for_tensorrt(mod, params)
print("python script started building --------------")
with tvm.transform.PassContext(opt_level=3, config={'relay.ext.tensorrt.options': config}):
    lib = relay.build(mod, target=target, params=params)
print("python script finished building -------------------")
dtype = "float32"  # NOTE(review): not used anywhere in this script

# Export the compiled module to a shared library and reload it, so the runs below
# execute the exported artifact.
lib.export_library('compiled.so')
loaded_lib = tvm.runtime.load_module('compiled.so')
gen_module = graph_executor.GraphModule(loaded_lib['default'](dev))
# Optional INT8 calibration. TENSORRT_NUM_CALI_INT8 gives the number of warm-up
# runs; unset means 0. Presumably the TensorRT runtime on the pr_trt_int8 branch
# consumes these runs as calibration passes when TVM_TENSORRT_USE_INT8=1 -- TODO confirm.
cali_env = os.environ.get("TENSORRT_NUM_CALI_INT8")
if cali_env is None:
    print("no TENSORRT_NUM_CALI_INT8 found in this case ... ")
    num_cali_int8 = 0
else:
    print("we are going to set {} times calibration in this case".format(cali_env))
    num_cali_int8 = int(cali_env)

# Bind the image under the graph's real input name. set_input(name, value) raises if
# the name is wrong, whereas run(data=img) is silently skipped because the graph has
# no input called "data".
gen_module.set_input(input_name, img)
if num_cali_int8 > 0:
    print("calibration steps ... ")
    for _ in range(num_cali_int8):
        gen_module.run()
    print("finished calibration step")
import time

print("test run ... ")
# Feed the image under the graph's real input name (`input_name`, "input0"); passing
# it as run(data=img) would be skipped because no graph input is named "data".
gen_module.set_input(input_name, img)
gen_module.run()
out = gen_module.get_output(0)
print(out)

# Benchmark: average wall-clock latency over `epochs` runs with the input bound above.
# dev.sync() waits for queued device work to finish, so each sample covers execution
# rather than only kernel launch; perf_counter() is the clock meant for intervals.
epochs = 100
total_time = 0
for _ in range(epochs):
    start = time.perf_counter()
    gen_module.run()
    dev.sync()
    end = time.perf_counter()
    total_time += end - start
print("the average time is {} ms".format(str(total_time / epochs * 1000)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment