Yolo_v3_object_tracker
import os
import time

import cv2
import numpy as np
import tensorflow as tf

from yolov3.yolov3 import Create_Yolov3
from yolov3.utils import load_yolo_weights, image_preprocess, postprocess_boxes, nms, draw_bbox, read_class_names
from yolov3.configs import *
from deep_sort import nn_matching
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker
from deep_sort import generate_detections as gdet

input_size = YOLO_INPUT_SIZE
Darknet_weights = YOLO_DARKNET_WEIGHTS
if TRAIN_YOLO_TINY:
    Darknet_weights = YOLO_DARKNET_TINY_WEIGHTS

video_path = "./IMAGES/test.mp4"

yolo = Create_Yolov3(input_size=input_size)
load_yolo_weights(yolo, Darknet_weights)  # use Darknet weights

def Object_tracking(YoloV3, video_path, output_path, input_size=416, show=False, CLASSES=YOLO_COCO_CLASSES,
                    score_threshold=0.3, iou_threshold=0.45, rectangle_colors='', Track_only=[]):
    # Deep SORT association parameters
    max_cosine_distance = 0.5
    nn_budget = None
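    # max_cosine_distance is Deep SORT's appearance-matching threshold: detections whose embedding
    # cosine distance to a track's stored features exceeds it are not associated with that track.
    # nn_budget caps how many past appearance features are kept per track (None = keep them all).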

    # Initialize the Deep SORT appearance encoder, distance metric and tracker
    model_filename = 'model_data/mars-small128.pb'
    encoder = gdet.create_box_encoder(model_filename, batch_size=1)
    metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
    tracker = Tracker(metric)

    times = []

    if video_path:
        vid = cv2.VideoCapture(video_path)  # detect on video
    else:
        vid = cv2.VideoCapture(0)  # detect from webcam

    # by default VideoCapture returns float instead of int
    width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(vid.get(cv2.CAP_PROP_FPS))
    codec = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(output_path, codec, fps, (width, height))  # output_path must be .mp4
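
    # read_class_names returns a dict mapping class index -> class name (for COCO: {0: 'person', 1: 'bicycle', ...});
    # key_list/val_list are used further below to map a tracked class name back to its index.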
    NUM_CLASS = read_class_names(CLASSES)
    key_list = list(NUM_CLASS.keys())
    val_list = list(NUM_CLASS.values())

    while True:
        ret, img = vid.read()
        if not ret or img is None:  # end of stream or failed frame grab
            break
        # OpenCV reads frames as BGR; convert once to RGB for the detector
        original_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        image_data = image_preprocess(np.copy(original_image), [input_size, input_size])
        image_data = tf.expand_dims(image_data, 0)

        t1 = time.time()
        pred_bbox = YoloV3.predict(image_data)
        t2 = time.time()
        times.append(t2 - t1)
        times = times[-20:]  # moving window over the last 20 inference times

        pred_bbox = [tf.reshape(x, (-1, tf.shape(x)[-1])) for x in pred_bbox]
        pred_bbox = tf.concat(pred_bbox, axis=0)

        bboxes = postprocess_boxes(pred_bbox, original_image, input_size, score_threshold)
        bboxes = nms(bboxes, iou_threshold, method='nms')

        # extract bboxes to boxes (x, y, width, height), scores and names
        boxes, scores, names = [], [], []
        for bbox in bboxes:
            if len(Track_only) == 0 or NUM_CLASS[int(bbox[5])] in Track_only:
                boxes.append([bbox[0].astype(int), bbox[1].astype(int),
                              bbox[2].astype(int) - bbox[0].astype(int),
                              bbox[3].astype(int) - bbox[1].astype(int)])
                scores.append(bbox[4])
                names.append(NUM_CLASS[int(bbox[5])])

        # Obtain all the detections for the given frame.
        boxes = np.array(boxes)
        names = np.array(names)
        scores = np.array(scores)
        features = np.array(encoder(original_image, boxes))
        detections = [Detection(bbox, score, class_name, feature)
                      for bbox, score, class_name, feature in zip(boxes, scores, names, features)]

        # Pass the detections to the Deep SORT tracker and update the track states.
        tracker.predict()
        tracker.update(detections)

        # Obtain info from the tracks
        tracked_bboxes = []
        for track in tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            bbox = track.to_tlbr()          # Get the corrected/predicted bounding box
            class_name = track.get_class()  # Get the class name of the particular object
            tracking_id = track.track_id    # Get the ID of the particular track
            index = key_list[val_list.index(class_name)]  # Get the predicted object index by object name
            tracked_bboxes.append(bbox.tolist() + [tracking_id, index])  # Structure the data so it can be used with draw_bbox
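        # Each tracked_bboxes row is [x1, y1, x2, y2, track_id, class_index], the layout
        # draw_bbox consumes when called with tracking=True.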

        ms = sum(times) / len(times) * 1000
        fps = 1000 / ms

        # draw the tracked detections on the frame
        image = draw_bbox(original_image, tracked_bboxes, CLASSES=CLASSES, tracking=True)
        image = cv2.putText(image, "FPS: {:.1f}".format(fps), (0, 30), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (0, 0, 255), 2)

        # draw the original yolo detections instead
        #image = draw_bbox(image, bboxes, CLASSES=CLASSES, show_label=False, rectangle_colors=rectangle_colors, tracking=True)
        #print("Time: {:.2f}ms, {:.1f} FPS".format(ms, fps))

        # convert back to BGR before handing the frame to OpenCV for writing/display
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        if output_path != '':
            out.write(image)
        if show:
            cv2.imshow('output', image)
            if cv2.waitKey(25) & 0xFF == ord("q"):
                cv2.destroyAllWindows()
                break

    # release the capture and writer so the output file is finalized, then close any windows
    vid.release()
    out.release()
    cv2.destroyAllWindows()


Object_tracking(yolo, video_path, '', input_size=input_size, show=True, iou_threshold=0.1, rectangle_colors=(255, 0, 0), Track_only=[])
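
# A usage sketch (illustrative, not part of the original gist): write the tracked video to disk
# instead of only displaying it, and restrict tracking to a single COCO class. "detection.mp4"
# and "person" are example values; Track_only is matched against the names from read_class_names(CLASSES).
#Object_tracking(yolo, video_path, "detection.mp4", input_size=input_size, show=False,
#                iou_threshold=0.1, rectangle_colors=(255, 0, 0), Track_only=["person"])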