Last active
June 16, 2024 14:56
-
-
Save aliencaocao/9a2ff385a3ebe15885f20f93eab04542 to your computer and use it in GitHub Desktop.
TensorRT Inference
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple, OrderedDict

import numpy as np
import pycuda.driver as cuda
import tensorrt as trt

# BUGFIX: compare the major version numerically. The original compared strings,
# and lexicographically '9' >= '10' is True, so TensorRT 9.x wrongly passed.
assert int(trt.__version__.split('.')[0]) >= 10, 'TensorRT version >= 10 is required.'
class TRTInference:
    """Synchronous inference wrapper around a serialized TensorRT (>= 10) engine.

    Device buffers are allocated once at construction and reused for every call;
    outputs are copied into page-locked host arrays and returned as numpy arrays
    keyed by the (optionally remapped) engine output name.
    """

    def __init__(self, engine_path: str, output_names_mapping: dict = None, verbose: bool = False):
        """Load the engine, create an execution context and allocate I/O buffers.

        :param engine_path: path to the serialized ``.engine`` file
        :param output_names_mapping: optional {engine_output_name: public_name} remap
        :param verbose: enable TensorRT verbose logging
        """
        cuda.init()
        # NOTE(review): always uses device 0 — confirm this matches deployment.
        self.device_ctx = cuda.Device(0).make_context()
        self.engine_path = engine_path
        self.output_names_mapping = output_names_mapping or {}
        self.logger = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger(trt.Logger.INFO)
        self.engine = None
        self.load_engine()
        assert self.engine is not None, 'Failed to load TensorRT engine.'
        self.context = self.engine.create_execution_context()
        self.stream = cuda.Stream()
        # BUGFIX: allocate device buffers exactly once. The original exposed the
        # allocation through a @property, so every `self.bindings` access leaked
        # fresh device memory AND re-pointed the execution context at new buffers
        # while `bindings_addr` still held the old pointers — input copies then
        # targeted buffers the engine no longer read.
        self._bindings = self._allocate_bindings()
        self.bindings_addr = OrderedDict((n, v.ptr) for n, v in self._bindings.items())

    def load_engine(self):
        """Deserialize the engine from ``self.engine_path`` into ``self.engine``."""
        with open(self.engine_path, 'rb') as f, trt.Runtime(self.logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())

    @property
    def input_names(self):
        """Names of all engine input tensors, in engine order."""
        return [name for name in self.engine
                if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT]

    @property
    def output_names(self):
        """Names of all engine output tensors, in engine order."""
        return [name for name in self.engine
                if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT]

    @property
    def bindings(self):
        """Cached name -> Binding(name, dtype, shape, data, ptr) mapping (allocated once)."""
        return self._bindings

    def _allocate_bindings(self):
        """Allocate one device buffer per engine tensor and register it with the context."""
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
        bindings = OrderedDict()
        for name in self.engine:
            shape = self.engine.get_tensor_shape(name)
            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                # Host array is only used to size the device allocation.
                data = np.empty(shape, dtype=dtype)
                ptr = cuda.mem_alloc(data.nbytes)  # TODO: fix compat with numpy 2.0
                self.context.set_input_shape(name, shape)
                self.context.set_tensor_address(name, ptr)  # input buffer
            else:
                # Page-locked host memory so the async device-to-host copy is truly async.
                data = cuda.pagelocked_empty(trt.volume(shape), dtype)
                ptr = cuda.mem_alloc(data.nbytes)
                self.context.set_tensor_address(name, ptr)  # output buffer
            bindings[name] = Binding(name, dtype, shape, data, ptr)
        return bindings

    def __call__(self, inputs: dict[str, np.ndarray]):
        """Run one inference pass.

        :param inputs: {input_name: array}; extra keys are ignored, arrays are
            made contiguous before the host-to-device copy.
        :return: {output_name (remapped if configured): numpy array}
        """
        inputs = {n: np.ascontiguousarray(v) for n, v in inputs.items() if n in self.input_names}
        for n in self.input_names:
            cuda.memcpy_htod_async(self.bindings_addr[n], inputs[n], self.stream)
        assert self.context.all_binding_shapes_specified
        self.context.execute_async_v3(stream_handle=self.stream.handle)
        outputs = {}
        for n in self.output_names:
            binding = self._bindings[n]
            cuda.memcpy_dtoh_async(binding.data, self.bindings_addr[n], self.stream)
            o = binding.data
            # pagelocked_empty yields a flat array; restore the engine's output shape.
            if o.shape != binding.shape:
                o = o.reshape(binding.shape)
            outputs[self.output_names_mapping.get(n, n)] = o
        # Wait for compute and copies to finish before handing host arrays to the caller.
        self.stream.synchronize()
        return outputs

    def predict(self, inputs: dict[str, np.ndarray]):
        """Alias for :meth:`__call__`."""
        return self(inputs)

    def warmup(self, inputs: dict[str, np.ndarray] = None, n: int = 50):
        """Run inference for n iterations to warmup the engine. If no sample input is provided, random inputs are used."""
        # Comprehension variable named `name` so it cannot be confused with the
        # iteration-count parameter `n`.
        inputs = inputs or {name: np.random.randn(*self.bindings[name].shape).astype(self.bindings[name].dtype)
                            for name in self.input_names}
        for _ in range(n):
            self(inputs)

    def __del__(self):
        # Best-effort: the context may already have been popped/destroyed at
        # interpreter shutdown.
        try:
            self.device_ctx.pop()
        except cuda.LogicError:
            pass
if __name__ == '__main__':
    # Smoke test: load an engine, warm it up, then run once on random data.
    engine = TRTInference('model.engine', verbose=True)
    sample = {}
    for name in engine.input_names:
        binding = engine.bindings[name]
        sample[name] = np.random.randn(*binding.shape).astype(binding.dtype)
    engine.warmup()
    print(engine(sample))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.