bbox_tensorrt_engine
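A standalone TensorRT inference script for an end-to-end detection engine: it deserializes the engine, reads frames from a video, letterboxes each frame to the engine's input size, runs inference, decodes the num_dets/bboxes/scores/labels outputs, and draws the detected 'outer'/'inner' boxes with OpenCV. It is invoked with --engine and --video pointing at the serialized engine file and the input video.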
import os
from collections import namedtuple
from pathlib import Path
from typing import List, Optional, Tuple, Union
import argparse
from numpy import ndarray
from torch import Tensor
import numpy as np
import random
import tensorrt as trt
import torch
import cv2

# detection model classes
CLASSES = ('outer', 'inner')

# one random BGR color per class
COLORS = {cls: [random.randint(0, 255) for _ in range(3)] for cls in CLASSES}

os.environ['CUDA_MODULE_LOADING'] = 'LAZY'

class TRTModule(torch.nn.Module):
    dtypeMapping = {
        trt.bool: torch.bool,
        trt.int8: torch.int8,
        trt.int32: torch.int32,
        trt.float16: torch.float16,
        trt.float32: torch.float32
    }

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]) -> None:
        super(TRTModule, self).__init__()
        self.weight = Path(weight) if isinstance(weight, str) else weight
        self.device = device if device is not None else torch.device('cuda:0')
        self.stream = torch.cuda.Stream(device=device)
        self.__init_engine()
        self.__init_bindings()

    def __init_engine(self) -> None:
        # deserialize the engine and create an execution context
        logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(logger, namespace='')
        with trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
        context = model.create_execution_context()
        num_bindings = model.num_bindings
        names = [model.get_binding_name(i) for i in range(num_bindings)]

        self.bindings: List[int] = [0] * num_bindings
        num_inputs, num_outputs = 0, 0
        for i in range(num_bindings):
            if model.binding_is_input(i):
                num_inputs += 1
            else:
                num_outputs += 1

        self.num_bindings = num_bindings
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.model = model
        self.context = context
        self.input_names = names[:num_inputs]
        self.output_names = names[num_inputs:]
        self.idx = list(range(self.num_outputs))

    def __init_bindings(self) -> None:
        # record dtype/shape info for every input and output binding
        idynamic = odynamic = False
        Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape'))
        inp_info = []
        out_info = []
        for i, name in enumerate(self.input_names):
            assert self.model.get_binding_name(i) == name
            dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            if -1 in shape:
                idynamic |= True
            inp_info.append(Tensor(name, dtype, shape))
        for i, name in enumerate(self.output_names):
            i += self.num_inputs
            assert self.model.get_binding_name(i) == name
            dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            if -1 in shape:
                odynamic |= True
            out_info.append(Tensor(name, dtype, shape))

        # pre-allocate output buffers when all output shapes are static
        if not odynamic:
            self.output_tensor = [
                torch.empty(info.shape, dtype=info.dtype, device=self.device)
                for info in out_info
            ]
        self.idynamic = idynamic
        self.odynamic = odynamic
        self.inp_info = inp_info
        self.out_info = out_info

    def set_profiler(self, profiler: Optional[trt.IProfiler]):
        self.context.profiler = profiler \
            if profiler is not None else trt.Profiler()

    def set_desired(self, desired: Optional[Union[List, Tuple]]):
        # reorder forward() outputs to match the given binding names
        if isinstance(desired,
                      (list, tuple)) and len(desired) == self.num_outputs:
            self.idx = [self.output_names.index(i) for i in desired]

    def forward(self, *inputs) -> Union[Tuple, torch.Tensor]:
        assert len(inputs) == self.num_inputs
        contiguous_inputs: List[torch.Tensor] = [
            i.contiguous() for i in inputs
        ]

        # bind input device pointers (and shapes, for dynamic inputs)
        for i in range(self.num_inputs):
            self.bindings[i] = contiguous_inputs[i].data_ptr()
            if self.idynamic:
                self.context.set_binding_shape(
                    i, tuple(contiguous_inputs[i].shape))

        # bind output buffers, allocating per call if shapes are dynamic
        outputs: List[torch.Tensor] = []
        for i in range(self.num_outputs):
            j = i + self.num_inputs
            if self.odynamic:
                shape = tuple(self.context.get_binding_shape(j))
                output = torch.empty(size=shape,
                                     dtype=self.out_info[i].dtype,
                                     device=self.device)
            else:
                output = self.output_tensor[i]
            self.bindings[j] = output.data_ptr()
            outputs.append(output)

        self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
        self.stream.synchronize()

        return tuple(outputs[i]
                     for i in self.idx) if len(outputs) > 1 else outputs[0]

    def letterbox(self, im: ndarray,
                  new_shape: Union[Tuple, List] = (640, 640),
                  color: Union[Tuple, List] = (114, 114, 114)) \
            -> Tuple[ndarray, float, Tuple[float, float]]:
        # resize while keeping the aspect ratio, then pad to new_shape
        shape = im.shape[:2]
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)
        r = min(new_shape[0] / shape[1], new_shape[1] / shape[0])
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = (new_shape[0] - new_unpad[0]) / 2, (new_shape[1] - new_unpad[1]) / 2
        if shape[::-1] != new_unpad:
            im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        im = cv2.copyMakeBorder(im, top, bottom, left, right,
                                cv2.BORDER_CONSTANT, value=color)
        return im, r, (dw, dh)

    def im_to_blob(self, im: ndarray) -> Union[ndarray, Tuple]:
        # HWC uint8 image -> NCHW float32 blob in [0, 1]
        im = im.transpose([2, 0, 1])
        im = im[np.newaxis, ...]
        im = np.ascontiguousarray(im).astype(np.float32) / 255
        return im

    def det_postprocess(self, data: Tuple[Tensor, Tensor, Tensor, Tensor]):
        assert len(data) == 4
        num_dets, bboxes, scores, labels = \
            data[0][0], data[1][0], data[2][0], data[3][0]
        nums = num_dets.item()
        if nums == 0:
            return bboxes.new_zeros((0, 4)), scores.new_zeros(
                (0, )), labels.new_zeros((0, ))
        bboxes = bboxes[:nums]
        scores = scores[:nums]
        labels = labels[:nums]
        return bboxes, scores, labels


def main(args: argparse.Namespace) -> None:
    device = torch.device('cuda:0')
    engine = TRTModule(args.engine, device)
    H, W = engine.inp_info[0].shape[-2:]

    # set desired output names order
    engine.set_desired(['num_dets', 'bboxes', 'scores', 'labels'])

    # open the video file
    cap = cv2.VideoCapture(args.video)
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # preprocess the frame: letterbox, BGR->RGB, HWC->NCHW blob
        draw = frame.copy()
        bgr, ratio, dwdh = engine.letterbox(frame, (W, H))
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        tensor = engine.im_to_blob(rgb)
        dwdh = torch.asarray(dwdh * 2, dtype=torch.float32, device=device)
        tensor = torch.asarray(tensor, device=device)

        # inference
        data = engine(tensor)

        bboxes, scores, labels = engine.det_postprocess(data)
        if bboxes.numel() > 0:
            # map boxes from letterboxed coordinates back to the original frame
            bboxes -= dwdh
            bboxes /= ratio
            for (bbox, score, label) in zip(bboxes, scores, labels):
                bbox = bbox.round().int().tolist()
                cls_id = int(label)
                cls = CLASSES[cls_id]
                color = COLORS[cls]
                cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
                cv2.putText(draw,
                            f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.75, [225, 255, 255],
                            thickness=2)

        cv2.imshow('result', draw)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument('--engine', type=str, help='Engine file')
    parser.add_argument('--video', type=str, help='Path to video file')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    main(args)
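
As a minimal sketch of using TRTModule outside the video loop, the snippet below runs the engine on a single image. 'model.engine' and 'frame.jpg' are placeholder paths (not part of the gist), and the engine is assumed to expose the same num_dets/bboxes/scores/labels bindings the script expects; TRTModule, cv2 and torch come from the file above.

device = torch.device('cuda:0')
engine = TRTModule('model.engine', device)  # placeholder engine path
engine.set_desired(['num_dets', 'bboxes', 'scores', 'labels'])

H, W = engine.inp_info[0].shape[-2:]
frame = cv2.imread('frame.jpg')             # placeholder test image
bgr, ratio, dwdh = engine.letterbox(frame, (W, H))
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
tensor = torch.asarray(engine.im_to_blob(rgb), device=device)

bboxes, scores, labels = engine.det_postprocess(engine(tensor))
# map boxes back from letterboxed coordinates to the original image
bboxes -= torch.asarray(dwdh * 2, dtype=torch.float32, device=device)
bboxes /= ratio
print(bboxes.tolist(), scores.tolist(), labels.tolist())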