import tvm
from tvm import relay
import tvm.contrib.graph_runtime as runtime
import numpy as np
from tvm.contrib.download import download_testdata
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

# PyTorch imports
import torch
import torchvision
# Load a pretrained ResNet-18 from torchvision and put it in eval mode
model_name = "resnet18"
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()

# We grab the TorchScripted model via tracing
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()
# Download a test image
from PIL import Image

img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))

# Preprocess the image and convert to tensor
from torchvision import transforms

my_preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
img = my_preprocess(img)
img = np.expand_dims(img, 0)

# Import the traced PyTorch model into Relay
input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
# Build for CUDA; optionally offload supported subgraphs to TensorRT
target = "cuda"
print(mod["main"])

use_trt = False
if use_trt:
    mod, config = partition_for_tensorrt(mod, params, remove_no_mac_subgraphs=True)
    print("after partition using trt ... ")
    print(mod["main"])
    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
        lib = relay.build(mod, target=target, params=params)
else:
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
# Create the graph runtime module on the GPU and feed the input
ctx = tvm.context(str(target), 0)
module = runtime.GraphModule(lib["default"](ctx))
module.set_input(input_name, tvm.nd.array(img, ctx))

# Benchmark the compiled module
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", ctx, repeat=10, min_repeat_ms=500)
prof_res = np.array(ftimer().results) * 1e3  # convert to milliseconds
message = "Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))
print(message)
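
# --- Optional extension (a sketch, not part of the original gist) ---
# After benchmarking, it can be useful to run a single inference and read the
# logits back, to confirm the compiled module actually classifies the image.
# This assumes the module was built and the input set as above; the top-1
# index is an ImageNet class id, which you would map to a label list yourself.
module.run()
tvm_output = module.get_output(0).asnumpy()  # shape (1, 1000) ImageNet logits
top1 = int(np.argmax(tvm_output[0]))
print("TVM top-1 class id:", top1)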