# bbox_tensorrt_engine: run a serialized TensorRT detection engine on a video
# and draw the predicted bounding boxes.
import argparse
import os
import random
from collections import namedtuple
from pathlib import Path
from typing import List, Optional, Tuple, Union

import cv2
import numpy as np
import tensorrt as trt
import torch
from numpy import ndarray
from torch import Tensor

# Detection model classes.
CLASSES = ('outer', 'inner')

# One random BGR color per class.
COLORS = {cls: [random.randint(0, 255) for _ in range(3)] for cls in CLASSES}

# Defer CUDA module loading until first use.
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'


class TRTModule(torch.nn.Module):
    dtypeMapping = {
        trt.bool: torch.bool,
        trt.int8: torch.int8,
        trt.int32: torch.int32,
        trt.float16: torch.float16,
        trt.float32: torch.float32
    }

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]) -> None:
        super(TRTModule, self).__init__()
        self.weight = Path(weight) if isinstance(weight, str) else weight
        self.device = device if device is not None else torch.device('cuda:0')
        # Create the stream on the resolved device so a `None` argument still
        # lands on cuda:0.
        self.stream = torch.cuda.Stream(device=self.device)
        self.__init_engine()
        self.__init_bindings()

    def __init_engine(self) -> None:
        logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(logger, namespace='')
        with trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
        context = model.create_execution_context()
        num_bindings = model.num_bindings
        names = [model.get_binding_name(i) for i in range(num_bindings)]
        self.bindings: List[int] = [0] * num_bindings
        num_inputs, num_outputs = 0, 0
        for i in range(num_bindings):
            if model.binding_is_input(i):
                num_inputs += 1
            else:
                num_outputs += 1
        self.num_bindings = num_bindings
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.model = model
        self.context = context
        self.input_names = names[:num_inputs]
        self.output_names = names[num_inputs:]
        self.idx = list(range(self.num_outputs))

    def __init_bindings(self) -> None:
        idynamic = odynamic = False
        Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape'))
        inp_info = []
        out_info = []
        for i, name in enumerate(self.input_names):
            assert self.model.get_binding_name(i) == name
            dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            if -1 in shape:
                idynamic |= True
            inp_info.append(Tensor(name, dtype, shape))
        for i, name in enumerate(self.output_names):
            i += self.num_inputs
            assert self.model.get_binding_name(i) == name
            dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            if -1 in shape:
                odynamic |= True
            out_info.append(Tensor(name, dtype, shape))
        # With static output shapes, the output tensors can be allocated once
        # and reused on every forward pass.
        if not odynamic:
            self.output_tensor = [
                torch.empty(info.shape, dtype=info.dtype, device=self.device)
                for info in out_info
            ]
        self.idynamic = idynamic
        self.odynamic = odynamic
        self.inp_info = inp_info
        self.out_info = out_info

    def set_profiler(self, profiler: Optional[trt.IProfiler]):
        self.context.profiler = profiler \
            if profiler is not None else trt.Profiler()

    def set_desired(self, desired: Optional[Union[List, Tuple]]):
        if isinstance(desired,
                      (list, tuple)) and len(desired) == self.num_outputs:
            self.idx = [self.output_names.index(i) for i in desired]

    def forward(self, *inputs) -> Union[Tuple, torch.Tensor]:
        assert len(inputs) == self.num_inputs
        contiguous_inputs: List[torch.Tensor] = [
            i.contiguous() for i in inputs
        ]
        # Bind input device pointers (and shapes, for dynamic-input engines).
        for i in range(self.num_inputs):
            self.bindings[i] = contiguous_inputs[i].data_ptr()
            if self.idynamic:
                self.context.set_binding_shape(
                    i, tuple(contiguous_inputs[i].shape))
        # Allocate (or reuse) output tensors and bind their pointers.
        outputs: List[torch.Tensor] = []
        for i in range(self.num_outputs):
            j = i + self.num_inputs
            if self.odynamic:
                shape = tuple(self.context.get_binding_shape(j))
                output = torch.empty(size=shape,
                                     dtype=self.out_info[i].dtype,
                                     device=self.device)
            else:
                output = self.output_tensor[i]
            self.bindings[j] = output.data_ptr()
            outputs.append(output)
        self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
        self.stream.synchronize()
        return tuple(outputs[i]
                     for i in self.idx) if len(outputs) > 1 else outputs[0]

    def letterbox(self, im: ndarray,
                  new_shape: Union[Tuple, List] = (640, 640),
                  color: Union[Tuple, List] = (114, 114, 114)) \
            -> Tuple[ndarray, float, Tuple[float, float]]:
        # Resize with unchanged aspect ratio, then pad to new_shape
        # (width, height) with the border color.
        shape = im.shape[:2]  # current (height, width)
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)
        r = min(new_shape[0] / shape[1], new_shape[1] / shape[0])
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw = (new_shape[0] - new_unpad[0]) / 2
        dh = (new_shape[1] - new_unpad[1]) / 2
        if shape[::-1] != new_unpad:
            im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        im = cv2.copyMakeBorder(im, top, bottom, left, right,
                                cv2.BORDER_CONSTANT, value=color)
        return im, r, (dw, dh)
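    # Worked example (numbers are illustrative, not from the source): a
    # 1080x1920 frame letterboxed to (640, 640) is scaled by r = 1/3 to
    # 640x360, then padded with 140 px of gray above and below, so the
    # returned offsets are (dw, dh) = (0.0, 140.0).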

    def im_to_blob(self, im: ndarray) -> Union[ndarray, Tuple]:
        # HWC uint8 image -> contiguous NCHW float32 blob in [0, 1], e.g.
        # (640, 640, 3) -> (1, 3, 640, 640).
        im = im.transpose([2, 0, 1])
        im = im[np.newaxis, ...]
        im = np.ascontiguousarray(im).astype(np.float32) / 255
        return im

    def det_postprocess(self, data: Tuple[Tensor, Tensor, Tensor, Tensor]):
        assert len(data) == 4
        # Engine outputs (batch size 1): num_dets, bboxes, scores, labels.
        num_dets, bboxes, scores, labels = (d[0] for d in data)
        nums = num_dets.item()
        if nums == 0:
            return (bboxes.new_zeros((0, 4)), scores.new_zeros((0, )),
                    labels.new_zeros((0, )))
        # Keep only the valid detections.
        bboxes = bboxes[:nums]
        scores = scores[:nums]
        labels = labels[:nums]
        return bboxes, scores, labels
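

# Minimal usage sketch for TRTModule on its own (the engine path and the input
# shape below are assumptions for illustration, not taken from this gist):
#
#   module = TRTModule('model.engine', torch.device('cuda:0'))
#   dummy = torch.zeros((1, 3, 640, 640), device='cuda:0')
#   outputs = module(dummy)  # tuple of output tensors, in set_desired() order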


def main(args: argparse.Namespace) -> None:
    device = torch.device('cuda:0')
    engine = TRTModule(args.engine, device)
    H, W = engine.inp_info[0].shape[-2:]
    # Set desired output names order.
    engine.set_desired(['num_dets', 'bboxes', 'scores', 'labels'])
    # Open the video file.
    cap = cv2.VideoCapture(args.video)
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Preprocess: letterbox to the engine input size, BGR -> RGB,
        # HWC uint8 -> NCHW float32 blob.
        draw = frame.copy()
        bgr, ratio, dwdh = engine.letterbox(frame, (W, H))
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        tensor = engine.im_to_blob(rgb)
        # (dw, dh, dw, dh): padding offsets to undo on xyxy boxes later.
        dwdh = torch.asarray(dwdh * 2, dtype=torch.float32, device=device)
        tensor = torch.asarray(tensor, device=device)
        # Inference.
        data = engine(tensor)
        bboxes, scores, labels = engine.det_postprocess(data)
        if bboxes.numel() > 0:
            # Map boxes back to original-frame coordinates.
            bboxes -= dwdh
            bboxes /= ratio
            for (bbox, score, label) in zip(bboxes, scores, labels):
                bbox = bbox.round().int().tolist()
                cls_id = int(label)
                cls = CLASSES[cls_id]
                color = COLORS[cls]
                cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
                cv2.putText(draw,
                            f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.75, [225, 255, 255],
                            thickness=2)
        cv2.imshow('result', draw)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument('--engine', type=str, required=True,
                        help='Engine file')
    parser.add_argument('--video', type=str, required=True,
                        help='Path to video file')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    main(args)
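
# Example invocation (file names are illustrative; save this script as
# bbox_tensorrt_engine.py or similar):
#
#   python bbox_tensorrt_engine.py --engine model.engine --video input.mp4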