Created
November 7, 2021 15:26
-
-
Save michaellee8/2667922534e3589f1629efa616674c54 to your computer and use it in GitHub Desktop.
jetson-inference run_ssd_example_annotate_video.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from vision.ssd.vgg_ssd import create_vgg_ssd, create_vgg_ssd_predictor | |
from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd, create_mobilenetv1_ssd_predictor | |
from vision.ssd.mobilenetv1_ssd_lite import create_mobilenetv1_ssd_lite, create_mobilenetv1_ssd_lite_predictor | |
from vision.ssd.squeezenet_ssd_lite import create_squeezenet_ssd_lite, create_squeezenet_ssd_lite_predictor | |
from vision.ssd.mobilenet_v2_ssd_lite import create_mobilenetv2_ssd_lite, create_mobilenetv2_ssd_lite_predictor | |
from vision.utils.misc import Timer | |
import cv2 | |
import sys | |
# Usage text printed when the command line is incomplete.
_USAGE = ('Usage: python run_ssd_example_annotate_video.py '
          '<net type> <model path> <label path> <input path> <output path>')

# Detection settings passed to the predictor.
_CANDIDATE_SIZE = 200
_TOP_K = 10
_PROB_THRESHOLD = 0.4

# Drawing settings. Frames are annotated in the capture's native channel
# order (the script converts BGR -> RGB only for the copy fed to the model).
_BOX_COLOR = (255, 255, 0)
_BOX_THICKNESS = 4
_TEXT_COLOR = (255, 0, 255)
_TEXT_SCALE = 1
_TEXT_THICKNESS = 2
_TEXT_OFFSET = (20, 40)  # label position relative to the box's first corner

# A progress line is printed every this many frames.
_PROGRESS_INTERVAL = 100


def _parse_args(argv):
    """Return (net_type, model_path, label_path, input_path, output_path).

    Prints the usage text to stderr and exits with status 2 when fewer than
    five arguments follow the program name.
    """
    # argv[0] is the program name, so five user arguments need len(argv) >= 6.
    if len(argv) < 6:
        print(_USAGE, file=sys.stderr)
        sys.exit(2)
    return tuple(argv[1:6])


def _read_class_names(label_path):
    """Read one class name per line from *label_path*, stripped of whitespace."""
    with open(label_path) as label_file:
        return [line.strip() for line in label_file]


def _build_predictor(net_type, model_path, num_classes):
    """Create the network for *net_type*, load its weights, return a predictor.

    Exits with status 1 when *net_type* is not one of the supported names.
    """
    # Built inside the function so the lookup of the factory names happens
    # only when a predictor is actually requested.
    factories = {
        'vgg16-ssd': (create_vgg_ssd, create_vgg_ssd_predictor),
        'mb1-ssd': (create_mobilenetv1_ssd, create_mobilenetv1_ssd_predictor),
        'mb1-ssd-lite': (create_mobilenetv1_ssd_lite,
                         create_mobilenetv1_ssd_lite_predictor),
        'mb2-ssd-lite': (create_mobilenetv2_ssd_lite,
                         create_mobilenetv2_ssd_lite_predictor),
        'sq-ssd-lite': (create_squeezenet_ssd_lite,
                        create_squeezenet_ssd_lite_predictor),
    }
    if net_type not in factories:
        supported = ', '.join(factories)
        print(f"The net type is wrong. It should be one of {supported}.",
              file=sys.stderr)
        sys.exit(1)

    create_net, create_predictor = factories[net_type]
    net = create_net(num_classes, is_test=True)
    net.load(model_path)
    return create_predictor(net, candidate_size=_CANDIDATE_SIZE)


def _annotate_frame(frame, predictor, class_names):
    """Draw a box and a "<class>: <prob>" label on *frame* for each detection.

    *frame* is modified in place.
    """
    image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    boxes, labels, probs = predictor.predict(image, _TOP_K, _PROB_THRESHOLD)
    for i in range(boxes.size(0)):
        box = boxes[i, :]
        x1, y1, x2, y2 = (int(box[k]) for k in range(4))
        cv2.rectangle(frame, (x1, y1), (x2, y2), _BOX_COLOR, _BOX_THICKNESS)
        label = f"{class_names[labels[i]]}: {probs[i]:.2f}"
        cv2.putText(frame, label,
                    (x1 + _TEXT_OFFSET[0], y1 + _TEXT_OFFSET[1]),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    _TEXT_SCALE,
                    _TEXT_COLOR,
                    _TEXT_THICKNESS)


def main(argv=None):
    """Annotate every frame of the input video and write the result.

    *argv* defaults to ``sys.argv`` and must hold the program name followed
    by: net type, model path, label path, input video path, output video
    path. Exits with a non-zero status on a bad command line, an unknown net
    type, or a video that cannot be opened for reading or writing.
    """
    if argv is None:
        argv = sys.argv
    net_type, model_path, label_path, input_path, output_path = _parse_args(argv)

    class_names = _read_class_names(label_path)
    predictor = _build_predictor(net_type, model_path, len(class_names))

    vid_capture = cv2.VideoCapture(input_path)
    output = None
    try:
        if not vid_capture.isOpened():
            print("Fatal: Error opening the video file", file=sys.stderr)
            sys.exit(1)

        fps = vid_capture.get(cv2.CAP_PROP_FPS)
        frame_count = vid_capture.get(cv2.CAP_PROP_FRAME_COUNT)
        print('Frames per second : ', fps, 'FPS')
        print('Frame count : ', frame_count)

        # The writer is created with the same frame size and rate as the input.
        frame_size = (int(vid_capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
                      int(vid_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        output = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'XVID'),
                                 fps, frame_size)
        # Without this check a writer that failed to open would drop every
        # frame silently.
        if not output.isOpened():
            print("Fatal: Error opening the output video file", file=sys.stderr)
            sys.exit(1)

        count = 0
        while vid_capture.isOpened():
            if count % _PROGRESS_INTERVAL == 0:
                print("now working on frame", count)
            ret, frame = vid_capture.read()
            if not ret:
                break
            _annotate_frame(frame, predictor, class_names)
            output.write(frame)
            count += 1
    finally:
        # Release both handles even when inference or writing raises.
        vid_capture.release()
        if output is not None:
            output.release()


if __name__ == "__main__":
    main(sys.argv)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment