@comaniac
Created November 18, 2020 23:44
"""BYOC Demo using TensorRT."""
# pylint: disable=invalid-name,redefined-outer-name,missing-function-docstring
# config.cmake:
#   set(USE_TENSORRT_CODEGEN ON)
#   set(USE_TENSORRT_RUNTIME ON)
# If TensorRT was installed from a tarball, add its libraries to
# LD_LIBRARY_PATH:
#   export LD_LIBRARY_PATH=/path/to/tensorrt/lib:$LD_LIBRARY_PATH
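#
# A minimal rebuild sketch after editing config.cmake (assuming the standard
# out-of-tree cmake build of TVM; paths are illustrative):
#   mkdir -p build && cp cmake/config.cmake build/
#   cd build && cmake .. && make -j$(nproc)
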
import time
from gluoncv import data as gcv_data, model_zoo
from matplotlib import pyplot as plt
import mxnet as mx
import numpy as np
import tvm
from tvm import relay
from tvm.relay.backend import compile_engine
from tvm.contrib import graph_runtime
from tvm.contrib.download import download_testdata
from tvm.relay.op.contrib import tensorrt

# Prepare the input image
im_fname = download_testdata(
    "https://github.com/dmlc/web-data/blob/master/gluoncv/detection/street_small.jpg?raw=true",
    "street_small.jpg",
    module="data",
)


# Prepare the SSD model
def get_ssd_model(model_name, image_size=512):
    # Set up the model input
    input_name = "data"
    input_shape = (1, 3, image_size, image_size)

    # Prepare model input data
    data, img = gcv_data.transforms.presets.ssd.load_test(im_fname, short=image_size)

    # Prepare the SSD model: hybridize, run one forward pass to trace the
    # graph, and export the symbol/params files for the Relay frontend
    block = model_zoo.get_model(model_name, pretrained=True)
    block.hybridize()
    block.forward(data)
    block.export("temp")
    model_json = mx.symbol.load("temp-symbol.json")
    save_dict = mx.ndarray.load("temp-0000.params")
    arg_params = {}
    aux_params = {}
    for param, val in save_dict.items():
        param_type, param_name = param.split(":", 1)
        if param_type == "arg":
            arg_params[param_name] = val
        elif param_type == "aux":
            aux_params[param_name] = val

    # Convert the MXNet SSD model to a Relay module
    mod, params = relay.frontend.from_mxnet(
        model_json, {input_name: input_shape}, arg_params=arg_params, aux_params=aux_params
    )
    return mod, params, block.classes, data.asnumpy(), img
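
# A shorter alternative sketch, assuming relay.frontend.from_mxnet's support
# for gluon HybridBlock inputs: pass the hybridized block directly and skip
# the export/reload round trip above.
#   mod, params = relay.frontend.from_mxnet(block, {"data": input_shape})
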
# Get the SSD model
mod, params, class_names, data, img = get_ssd_model("ssd_512_resnet50_v1_coco")

# Show the original Relay module
print(mod.astext(show_meta_data=False))


# Build the module and perform inference
def build_and_run(mod, data, params, build_config=None):
    # Clear the compile engine cache so each build starts from scratch
    compile_engine.get().clear()

    with tvm.transform.PassContext(opt_level=3, config=build_config):
        lib = relay.build(mod, target="cuda", params=params)

    # Create the runtime module
    mod = graph_runtime.GraphModule(lib["default"](tvm.gpu(0)))

    # Run inference 10 times; the median absorbs first-run warm-up overhead
    times = []
    for _ in range(10):
        start = time.time()
        mod.run(data=data)
        times.append(time.time() - start)
    print("Median inference latency %.2f ms" % (1000 * np.median(times)))
    return mod


# Reference: Build and run the model on GPU without TensorRT
build_and_run(mod, data, params)
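
# A more rigorous timing sketch, assuming TVM's time_evaluator API (it
# synchronizes the device between runs). Capture the returned GraphModule,
# e.g. gpu_mod = build_and_run(mod, data, params), then:
#   ftimer = gpu_mod.module.time_evaluator("run", tvm.gpu(0), number=10)
#   print("Median latency: %.2f ms" % (1000 * np.median(ftimer().results)))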

# Partition the module for TensorRT (no pruning yet)
trt_mod, config = tensorrt.partition_for_tensorrt(mod, params)

# Show the partitioned Relay main function with 10 subgraphs
print(trt_mod["main"].astext(show_meta_data=False))

# Show a subgraph that contains no compute-intensive (MAC) ops
print(trt_mod["tensorrt_350"].astext(show_meta_data=False))

# Prune small subgraphs (an accelerator-specific optimization): drop
# partitions with no multiply-accumulate ops, since offloading them to
# TensorRT is not worth the overhead
config["remove_no_mac_subgraphs"] = True
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
    trt_mod = tensorrt.prune_tensorrt_subgraphs(trt_mod)

# Show the partitioned Relay main function, now with only 1 subgraph
print(trt_mod["main"].astext(show_meta_data=False))

# Show the subgraph function offloaded to TensorRT
print(trt_mod["tensorrt_0"].astext(show_meta_data=False))

# Show the TensorRT compiler options used at build time
print(config)

# Build and run with TensorRT
runtime_mod = build_and_run(
    trt_mod, data, params, build_config={"relay.ext.tensorrt.options": config}
)

# Collect the three SSD outputs: class IDs, confidence scores, and bounding
# boxes (the order assumed by the plot_bbox call below)
results = [runtime_mod.get_output(i).asnumpy() for i in range(runtime_mod.get_num_outputs())]
# Plot the detections (left commented out for headless runs; note that
# plot_bbox lives in gluoncv.utils, which is not imported above):
# plt.switch_backend("agg")
# ax = utils.viz.plot_bbox(
#     img, results[2][0], results[1][0], results[0][0], class_names=class_names
# )
# plt.show()
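
# A runnable headless variant of the plot above (a sketch assuming gluoncv's
# utils.viz.plot_bbox API; saves the figure instead of opening a window)
from gluoncv import utils as gcv_utils  # local import to keep the sketch self-contained

plt.switch_backend("agg")
ax = gcv_utils.viz.plot_bbox(
    img, results[2][0], results[1][0], results[0][0], class_names=class_names
)
plt.savefig("ssd_trt_result.png")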