# Python YOLO Car Tracking
# Gist by @a-rbsn, created February 11, 2024
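#
# A rough DIY speed-camera script: YOLOv8 with ByteTrack follows vehicles in a webcam
# feed, each track is timed as it crosses the quadrilateral `area_1` (assumed to span
# `dinm` metres of road), the transit time is converted to mph, and annotated frames
# of anything estimated above `min_mph` are saved into a dated folder.
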
from collections import defaultdict
import os
import torch
import numpy as np
import cv2
import supervision as sv
import time
from ultralytics import YOLO
from datetime import datetime
from decimal import Decimal, getcontext
getcontext().prec = 28

# Per-track state: entry timestamps, measured transit times, and saved snapshot counts
vehicles = {}
vehicles_elapsed = {}
pics_taken = {}
dinm = 5       # length of the measured zone in metres (shown on the overlay)
min_mph = 42   # lower speed threshold (mph) for saving snapshots
# Quadrilateral (pixel coordinates) marking the measured stretch of road
area_1 = np.array([(238, 320), (652, 239), (1920, 443), (1920, 681)], np.int32)

class YOLOObjectDetector:
    def __init__(self, capture_index):
        self.last_saved_time = time.time()
        self.capture_index = capture_index
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Using Device: ", self.device)
        self.model = self.load_model().to(self.device)
        self.CLASS_NAMES_DICT = self.model.model.names
        # Recent (x, y) centre points for each track id
        self.track_history = defaultdict(list)

    def load_model(self):
        try:
            model = YOLO("yolov8m.pt")  # load a pretrained YOLOv8 medium model
            return model
        except Exception as e:
            print("Error loading the model:", str(e))
            return None

    def predict(self, frame):
        results = self.model(frame)
        return results

    def __call__(self):
        cap = cv2.VideoCapture(self.capture_index, cv2.CAP_DSHOW)
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)
        cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'))
        assert cap.isOpened()
        try:
            while True:
                start_time = time.time()
                ret, frame = cap.read()
                if not ret:
                    break
                # Track bicycles, cars and motorcycles (COCO classes 1, 2, 3) with ByteTrack
                results = self.model.track(frame, persist=True, conf=0.6, classes=[1, 2, 3], tracker="bytetrack.yaml")
                end_time = time.time()
                # Draw the measurement zone
                cv2.polylines(frame, [area_1], True, (15, 220, 10), 2)
                # Guard against a zero interval on very fast frames
                fps = 1 / max(end_time - start_time, 1e-6)
                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
                cv2.putText(frame, f'FPS: {int(fps)}', (20, 70), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 255, 0), 2)
                cv2.putText(frame, timestamp, (20, frame.shape[0] - 20), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 255, 0), 2)
                cv2.putText(frame, str(dinm) + "m", (20, 140), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 255, 0), 2)
                if results[0].boxes.id is not None:
                    # Get the boxes and track IDs
                    boxes = results[0].boxes.xywh.cpu()
                    track_ids = results[0].boxes.id.int().cpu().tolist()
                    frame = results[0].plot(line_width=2)
                    # Plot the tracks
                    for box, track_id in zip(boxes, track_ids):
                        x, y, w, h = box
                        track = self.track_history[track_id]
                        track.append((float(x), float(y)))  # x, y centre point
                        if len(track) > 30:  # retain the centre points of the last 30 frames
                            track.pop(0)
                        # Horizontal displacement between the oldest and newest points
                        diff = np.array(track[-1]) - np.array(track[0])
                        if abs(diff[0]) > 30:
                            # Draw the tracking lines
                            points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
                            cv2.polylines(frame, [points], isClosed=False, color=(230, 230, 230), thickness=2)
                            # Use putText() for diff
                            # cv2.putText(frame, str(diff), (int(x), int(y) + 10), 0, 0.75, (255, 255, 255), 3, cv2.LINE_AA)
                        area_1_result = cv2.pointPolygonTest(area_1, (float(x), float(y)), False)
                        # 1 = inside, 0 = on the line, -1 = outside
                        # Start timing when a vehicle enters the zone...
                        if area_1_result >= 0 and track_id not in vehicles:
                            vehicles[track_id] = time.time()
                        # ...and record the transit time once it leaves
                        if track_id in vehicles and area_1_result < 0:
                            elapsed_time = time.time() - vehicles[track_id]
                            vehicles_elapsed[track_id] = elapsed_time
                            pics_taken[track_id] = 0
                            del vehicles[track_id]
                        if track_id in vehicles_elapsed:
                            elapsed_time = Decimal(vehicles_elapsed[track_id])
                            print("Elapsed Time: ", elapsed_time)
                            distance_in_meters = Decimal(dinm)
                            # m/s to mph: multiply by 3600 s per hour, divide by 1609.34 m per mile
                            speed_mph = (distance_in_meters / elapsed_time) * (Decimal('3600') / Decimal('1609.34'))
                            cv2.putText(frame, str(int(speed_mph)) + "mph", (int(x), int(y) - 10), 0, 0.75, (0, 0, 0), 3, cv2.LINE_AA)
                            cv2.putText(frame, str(int(speed_mph)) + "mph", (int(x), int(y) - 10), 0, 0.75, (255, 255, 255), 2, cv2.LINE_AA)
                            # Save up to four snapshots of any vehicle over the speed threshold
                            if min_mph < speed_mph < 90 and pics_taken[track_id] < 4:
                                now = datetime.now()
                                folder_name = now.strftime("%Y-%m-%d")
                                os.makedirs(folder_name, exist_ok=True)
                                filename = now.strftime("%Y-%m-%d-%H-%M-%S-%f")[:-3] + ".png"
                                # Rate-limit saves to at most one frame every 0.1 s
                                current_time = time.time()
                                time_diff = current_time - self.last_saved_time
                                if time_diff > 0.1:
                                    cv2.imwrite(os.path.join(folder_name, filename), frame)
                                    self.last_saved_time = current_time  # update last saved time
                                    pics_taken[track_id] += 1
                                    if pics_taken[track_id] == 4:
                                        del vehicles_elapsed[track_id]
                                        del pics_taken[track_id]
                cv2.imshow('YOLOv8 Detection', frame)
                # Press 'q' to save the current frame and quit
                if cv2.waitKey(5) & 0xFF == ord('q'):
                    cv2.imwrite('frame.jpg', frame)
                    break
        finally:
            cap.release()
            cv2.destroyAllWindows()


if __name__ == "__main__":
    detector = YOLOObjectDetector(capture_index=0)
    detector()