Quick script to get full performance using the NN accelerator on the iMX8MP
import os
import argparse
import cv2
import numpy as np
import signal
import sys
import time
from threading import Thread
import importlib.util
class VideoGet:
    """
    Class that continuously gets frames from a VideoCapture object
    with a dedicated thread.
    """
    def v4l2_gstreamer_pipeline(
        self,
        src="/dev/video0",
        display_width=640,
        display_height=480,
    ):
        return (
            "v4l2src device=%s ! "
            "decodebin ! "
            "imxvideoconvert_g2d ! "
            "video/x-raw, width=(int)%d, height=(int)%d, format=(string)BGRx ! "
            "videoconvert n-threads=3 ! "
            "video/x-raw, format=(string)BGR ! "
            "appsink "
            % (
                src,
                display_width,
                display_height,
            )
        )

    def video_gstreamer_pipeline(
        self,
        src="video.mp4",
        display_width=720,
        display_height=480,
    ):
        return (
            "filesrc location=%s ! "
            "decodebin ! "
            "imxvideoconvert_g2d ! "
            "video/x-raw, width=(int)%d, height=(int)%d, format=(string)BGRx ! "
            "videoconvert n-threads=3 ! "
            "video/x-raw, format=(string)BGR ! "
            "appsink "
            % (
                src,
                display_width,
                display_height,
            )
        )

    def process_frame(self):
        (self.grabbed, tmpframe) = self.stream.read()
        if np.shape(tmpframe) != ():
            frame_rgb = cv2.cvtColor(tmpframe, cv2.COLOR_BGR2RGB)
            self.tensor_frame = cv2.resize(frame_rgb, (self.input_width, self.input_height))
            self.frame = tmpframe
    def __init__(self, src, width, height, input_width, input_height):
        self.input_width = input_width
        self.input_height = input_height
        self.grabbed = 0
        # Start with empty frames so attribute access is safe before the first successful read
        self.frame = None
        self.tensor_frame = None
        # Initialize video stream
        if src.startswith("/dev/video"):
            self.stream = cv2.VideoCapture(self.v4l2_gstreamer_pipeline(src, width, height), cv2.CAP_GSTREAMER)
        else:
            self.stream = cv2.VideoCapture(self.video_gstreamer_pipeline(src, width, height), cv2.CAP_GSTREAMER)
        self.process_frame()
        self.stopped = False
    def start(self):
        if self.stream.isOpened():
            Thread(target=self.get, args=()).start()
        return self

    def get(self):
        while not self.stopped:
            if not self.grabbed:
                self.stop()
            else:
                self.process_frame()

    def stop(self):
        self.stopped = True
# Define and parse input arguments
parser = argparse.ArgumentParser()
parser.add_argument('--modeldir', help='Folder the .tflite file is located in',
                    required=True)
parser.add_argument('--graph', help='Name of the .tflite file, if different than detect.tflite',
                    default='detect.tflite')
parser.add_argument('--labels', help='Name of the labelmap file, if different than labelmap.txt',
                    default='labelmap.txt')
parser.add_argument('--threshold', help='Minimum confidence threshold for displaying detected objects',
                    default=0.5)
parser.add_argument('--resolution', help='Desired webcam resolution in WxH. If the webcam does not support the resolution entered, errors may occur.',
                    default='720x480')
parser.add_argument('--openvx', help='Use the OpenVX (VX) delegate for NPU/GPU acceleration to speed up detection',
                    action='store_true')
parser.add_argument('--video', help='Input source to analyze and display',
                    default='video.mp4')
parser.add_argument('--fullscreen', help='Display output fullscreen',
                    action='store_true')
args = parser.parse_args()

MODEL_NAME = args.modeldir
GRAPH_NAME = args.graph
LABELMAP_NAME = args.labels
min_conf_threshold = float(args.threshold)
resW, resH = args.resolution.split('x')
imW, imH = int(resW), int(resH)
use_OPENVX = args.openvx
OUTPUT_FULLSCREEN = args.fullscreen
VIDEO_FILE = args.video
# Get path to current working directory
CWD_PATH = os.getcwd()

# Import TensorFlow libraries
# If tflite_runtime is installed, import the interpreter from tflite_runtime, else import it from regular tensorflow
# If using the OpenVX delegate, also import the load_delegate helper
pkg = importlib.util.find_spec('tflite_runtime')
if pkg:
    from tflite_runtime.interpreter import Interpreter
    if use_OPENVX:
        from tflite_runtime.interpreter import load_delegate
else:
    from tensorflow.lite.python.interpreter import Interpreter
    if use_OPENVX:
        from tensorflow.lite.python.interpreter import load_delegate

# If using the OpenVX delegate, assign the filename for the OpenVX model
if use_OPENVX:
    # If the user has specified the name of the .tflite file, use that name, otherwise use the default 'openvx.tflite'
    if (GRAPH_NAME == 'detect.tflite'):
        GRAPH_NAME = 'openvx.tflite'
    # Cache the compiled graph binary in /tmp so later runs skip the warm-up compilation
    os.environ["VIV_VX_ENABLE_CACHE_GRAPH_BINARY"] = "1"
    os.environ["VIV_VX_CACHE_BINARY_GRAPH_DIR"] = "/tmp"

# Path to .tflite file, which contains the model that is used for object detection
PATH_TO_CKPT = os.path.join(CWD_PATH, MODEL_NAME, GRAPH_NAME)

# Path to label map file
PATH_TO_LABELS = os.path.join(CWD_PATH, MODEL_NAME, LABELMAP_NAME)
# Load the label map
with open(PATH_TO_LABELS, 'r') as f:
    labels = [line.strip() for line in f.readlines()]

# Have to do a weird fix for label map if using the COCO "starter model" from
# https://www.tensorflow.org/lite/models/object_detection/overview
# First label is '???', which has to be removed.
if labels[0] == '???':
    del(labels[0])
# Load the TensorFlow Lite model.
# If using the OpenVX delegate, pass it in via the experimental_delegates argument
if use_OPENVX:
    interpreter = Interpreter(model_path=PATH_TO_CKPT,
                              experimental_delegates=[load_delegate('/usr/lib/libvx_delegate.so')])
    print(PATH_TO_CKPT)
else:
    interpreter = Interpreter(model_path=PATH_TO_CKPT)
interpreter.allocate_tensors()
# Get model details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
height = input_details[0]['shape'][1]
width = input_details[0]['shape'][2]

floating_model = (input_details[0]['dtype'] == np.float32)

input_mean = 127.5
input_std = 127.5

video_getter = VideoGet(VIDEO_FILE, imW, imH, width, height).start()
def sigint_handler(signum, frame):
    print('Stop pressing the CTRL+C!')
    video_getter.stop()
    cv2.destroyAllWindows()
    sys.exit(0)
signal.signal(signal.SIGINT, sigint_handler)
# Initialize frame rate calculation
frame_rate_calc = 1
freq = cv2.getTickFrequency()

if not video_getter.stopped:
    if OUTPUT_FULLSCREEN:
        window_handle = cv2.namedWindow("Object detector", cv2.WND_PROP_FULLSCREEN)
        cv2.setWindowProperty("Object detector", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
    else:
        window_handle = cv2.namedWindow("Object detector", cv2.WINDOW_AUTOSIZE | cv2.WINDOW_NORMAL)
while cv2.getWindowProperty("Object detector", 0) >= 0 and not video_getter.stopped:
    # Start timer (for calculating frame rate)
    t = cv2.getTickCount()

    # Grab frame from video stream; skip this iteration if no frame is available yet
    frame = video_getter.frame
    if np.shape(frame) == ():
        continue

    input_data = np.expand_dims(video_getter.tensor_frame, axis=0)

    # Normalize pixel values if using a floating model (i.e. if model is non-quantized)
    if floating_model:
        input_data = (np.float32(input_data) - input_mean) / input_std

    # Perform the actual detection by running the model with the image as input
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()

    # Retrieve detection results
    boxes = interpreter.get_tensor(output_details[0]['index'])[0]    # Bounding box coordinates of detected objects
    classes = interpreter.get_tensor(output_details[1]['index'])[0]  # Class index of detected objects
    scores = interpreter.get_tensor(output_details[2]['index'])[0]   # Confidence of detected objects
    # num = interpreter.get_tensor(output_details[3]['index'])       # Total number of detected objects (inaccurate and not needed)

    # Loop over all detections and draw detection box if confidence is above minimum threshold
    for i in range(len(scores)):
        if ((scores[i] > min_conf_threshold) and (scores[i] <= 1.0)):
            # Get bounding box coordinates and draw box
            # Interpreter can return coordinates that are outside of image dimensions, need to force them to be within image using max() and min()
            ymin = int(max(1, (boxes[i][0] * imH)))
            xmin = int(max(1, (boxes[i][1] * imW)))
            ymax = int(min(imH, (boxes[i][2] * imH)))
            xmax = int(min(imW, (boxes[i][3] * imW)))
            cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), (10, 255, 0), 2)

            # Draw label
            if int(classes[i]) < len(labels):
                object_name = labels[int(classes[i])]  # Look up object name from "labels" array using class index
            else:
                object_name = "unknown"
            label = '%s: %d%%' % (object_name, int(scores[i] * 100))  # Example: 'person: 72%'
            labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)  # Get font size
            label_ymin = max(ymin, labelSize[1] + 10)  # Make sure not to draw label too close to top of window
            cv2.rectangle(frame, (xmin, label_ymin - labelSize[1] - 10), (xmin + labelSize[0], label_ymin + baseLine - 10), (255, 255, 255), cv2.FILLED)  # Draw white box to put label text in
            cv2.putText(frame, label, (xmin, label_ymin - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)  # Draw label text

    # Draw framerate in corner of frame
    cv2.putText(frame, 'FPS: {0:.2f}'.format(frame_rate_calc), (30, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2, cv2.LINE_AA)

    # All the results have been drawn on the frame, so it's time to display it.
    # video_shower.frame = frame
    cv2.imshow('Object detector', frame)

    # Calculate framerate
    t = cv2.getTickCount() - t
    frame_rate_calc = cv2.getTickFrequency() / t

    # Press 'q' to quit
    if cv2.waitKey(1) == ord('q'):
        video_getter.stop()

# Make sure the capture thread is asked to stop before tearing down the window
video_getter.stop()
cv2.destroyAllWindows()
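For reference, a minimal sketch of how the script might be invoked. The script filename (detect_imx8mp.py) and the model folder name are assumptions; the option names and the detect.tflite / openvx.tflite / labelmap.txt defaults come from the argument parser above:

python3 detect_imx8mp.py --modeldir=model --video=/dev/video0 --resolution=720x480 --openvx

With --openvx the model is loaded through the VX delegate at /usr/lib/libvx_delegate.so so inference runs on the i.MX8MP NPU, and the VIV_VX_* cache variables set above let later runs reuse the graph binary cached in /tmp instead of recompiling it at startup.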