Last active
January 10, 2022 10:28
-
-
Save tamnguyenvan/c3addaab08b8bfe6a2dcd9f3386e0d63 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """trt_yolo.py | |
| This script demonstrates how to do real-time object detection with | |
| TensorRT optimized YOLO engine. | |
| """ | |
| import os | |
| import time | |
| import argparse | |
| import cv2 | |
| import pycuda.autoinit # This is needed for initializing CUDA driver | |
| from utils.yolo_classes import get_cls_dict | |
| from utils.camera import add_camera_args, Camera | |
| from utils.display import open_window, set_display, show_fps | |
| from utils.visualization import BBoxVisualization | |
| from utils.yolo_with_plugins import TrtYOLO | |
# Identifier/title for the OpenCV display window. Only the commented-out
# display code in this script refers to it.
WINDOW_NAME = "TrtYOLODemo"
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Camera/video-source options are contributed by ``add_camera_args``;
    this function adds the YOLO-specific ones: the number of object
    categories, the model name (with its input dimension suffix) and the
    letterbox switch.
    """
    description = (
        'Capture and display live camera video, while doing '
        'real-time object detection with TensorRT optimized '
        'YOLO model on Jetson'
    )
    model_help = (
        '[yolov3-tiny|yolov3|yolov3-spp|yolov4-tiny|yolov4|'
        'yolov4-csp|yolov4x-mish]-[{dimension}], where '
        '{dimension} could be either a single number (e.g. '
        '288, 416, 608) or 2 numbers, WxH (e.g. 416x256)'
    )
    parser = add_camera_args(argparse.ArgumentParser(description=description))
    parser.add_argument('-c', '--category_num', type=int, default=80,
                        help='number of object categories [80]')
    parser.add_argument('-m', '--model', type=str, required=True,
                        help=model_help)
    parser.add_argument('-l', '--letter_box', action='store_true',
                        help='inference with letterboxed image [False]')
    return parser.parse_args()
def loop_and_detect(cam, trt_yolo, conf_th, vis, max_frames=1,
                    output_path='output.png'):
    """Capture frames, run object detection and save the annotated result.

    The on-screen display of the original demo is disabled in this script;
    each annotated frame is written to `output_path` instead, overwriting
    the previous one.

    # Arguments
      cam: the camera instance (video source). `cam.read()` returning None
           ends the loop.
      trt_yolo: the TRT YOLO object detector instance.
      conf_th: confidence/score threshold for object detection.
      vis: for visualization.
      max_frames: maximum number of frames to process. The default of 1
                  reproduces the previous single-snapshot behavior; pass
                  None to keep going until the camera returns no frame.
      output_path: image file the annotated frame is written to.

    # Returns
      The number of frames processed.
    """
    fps = 0.0
    n_frames = 0
    tic = time.time()
    while max_frames is None or n_frames < max_frames:
        img = cam.read()
        if img is None:
            break
        boxes, confs, clss = trt_yolo.detect(img, conf_th)
        img = vis.draw_bboxes(img, boxes, confs, clss)
        # The overlay shows the running average from previous frames, so
        # the very first frame is stamped with 0.0 fps.
        img = show_fps(img, fps)
        # cv2.imwrite signals failure (bad path, unsupported extension)
        # through its return value rather than by raising.
        if not cv2.imwrite(output_path, img):
            print('WARNING: failed to write %s' % output_path)
        n_frames += 1
        toc = time.time()
        elapsed = toc - tic
        if elapsed > 0:  # guard against a zero interval from a coarse clock
            curr_fps = 1.0 / elapsed
            # exponentially decaying average of the fps number
            fps = curr_fps if fps == 0.0 else (fps * 0.95 + curr_fps * 0.05)
        tic = toc
    return n_frames
def _parse_input_shape(model):
    """Return the (height, width) network input size encoded in `model`.

    `model` must end in '-{dimension}', where {dimension} is either a single
    number for a square input (e.g. 'yolov4-416' -> (416, 416)) or 'WxH'
    (e.g. 'yolov4-416x256' -> (256, 416)).

    Raises:
      SystemExit: if the dimension suffix is missing or malformed.
    """
    dim = model.split('-')[-1]
    try:
        sizes = [int(s) for s in dim.split('x')]
    except ValueError:
        sizes = []
    if not sizes or len(sizes) > 2 or min(sizes) <= 0:
        raise SystemExit('ERROR: bad input dimension (%s) in model name (%s)!'
                         % (dim, model))
    if len(sizes) == 1:
        return sizes[0], sizes[0]
    width, height = sizes
    return height, width


def main():
    """Run the TensorRT YOLO demo on the configured camera.

    Raises:
      SystemExit: on a bad category count, a missing 'yolo/<model>.trt'
        engine file, a malformed model dimension, or a camera that fails
        to open.
    """
    args = parse_args()
    if args.category_num <= 0:
        raise SystemExit('ERROR: bad category_num (%d)!' % args.category_num)
    if not os.path.isfile('yolo/%s.trt' % args.model):
        raise SystemExit('ERROR: file (yolo/%s.trt) not found!' % args.model)
    # Validate the model name before the camera is opened, so a typo does
    # not leave the device open.
    input_shape = _parse_input_shape(args.model)

    cam = Camera(args)
    if not cam.isOpened():
        raise SystemExit('ERROR: failed to open camera!')
    # No display window is opened (the display code is disabled), so
    # cv2.destroyAllWindows() is not needed; it can also raise on headless
    # OpenCV builds.
    try:
        cls_dict = get_cls_dict(args.category_num)
        vis = BBoxVisualization(cls_dict)
        # NOTE(review): the model name encodes WxH, but input_shape is passed
        # as (H, W) to match the upstream tensorrt_demos TrtYOLO convention;
        # confirm against utils/yolo_with_plugins.py.
        trt_yolo = TrtYOLO(
            args.model, input_shape, args.category_num, args.letter_box)
        loop_and_detect(cam, trt_yolo, conf_th=0.3, vis=vis)
    finally:
        # Release the camera even if engine loading or detection fails.
        cam.release()
if __name__ == "__main__":
    # Run the demo only when executed as a script, never on import.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment