Created
February 11, 2024 14:23
-
-
Save a-rbsn/e8d87e941826cb5d02302cf7fbde83c2 to your computer and use it in GitHub Desktop.
Python YOLO Car Tracking
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
import os | |
import torch | |
import numpy as np | |
import cv2 | |
import supervision as sv | |
import time | |
from ultralytics import YOLO | |
from datetime import datetime | |
from decimal import Decimal, getcontext | |
# Decimal precision used for the speed arithmetic (28 significant digits).
getcontext().prec = 28

# Per-track bookkeeping shared with YOLOObjectDetector, keyed by tracker ID.
vehicles = {}          # track_id -> time the track's centre entered area_1
vehicles_elapsed = {}  # track_id -> seconds the track took to cross area_1
pics_taken = {}        # track_id -> number of snapshots saved for that track

# Length in metres of the road stretch covered by area_1 (used for speed).
dinm = 5
# Snapshots are only saved for vehicles faster than this many mph.
min_mph = 42

# Measurement zone as an int32 polygon of (x, y) pixel vertices.
area_1 = np.array(
    [(238, 320), (652, 239), (1920, 443), (1920, 681)],
    dtype=np.int32,
)
class YOLOObjectDetector:
    """Track vehicles from a camera feed and estimate their speed.

    Every tracked box centre is tested against the module-level ``area_1``
    polygon. The time between a track entering and leaving that polygon is
    converted into a speed using the zone length ``dinm`` (metres). While a
    track's speed is known it is drawn next to the box, and frames showing a
    vehicle faster than ``min_mph`` but slower than ``MAX_MPH`` are saved as
    PNG files into a folder named after the current date (``YYYY-MM-DD``).

    Per-track state lives in the module-level dicts ``vehicles``,
    ``vehicles_elapsed`` and ``pics_taken``.
    """

    # Class ids passed to the tracker. For the COCO-pretrained yolov8m.pt these
    # are 1 = bicycle, 2 = car, 3 = motorcycle (bus=5 / truck=7 not tracked).
    TRACKED_CLASSES = [1, 2, 3]
    # Number of recent centre points kept per track for drawing its trail.
    TRACK_HISTORY_LEN = 30
    # A trail is only drawn once it spans more than this many pixels in x.
    MIN_TRAIL_DX = 30
    # Snapshots are only saved for speeds strictly below this (mph).
    MAX_MPH = 90
    # Snapshots saved per vehicle before its speed record is discarded.
    MAX_PICS = 4
    # Minimum gap in seconds between two saved snapshots (across all tracks).
    MIN_SAVE_INTERVAL = 0.1
    # Conversion factor from metres per second to miles per hour.
    MPS_TO_MPH = Decimal('3600') / Decimal('1609.34')

    def __init__(self, capture_index):
        """Load the model onto CUDA if available, otherwise the CPU.

        Args:
            capture_index: Video source passed to ``cv2.VideoCapture``
                (e.g. ``0`` for the first camera).
        """
        # perf_counter: monotonic and high resolution, unlike time.time().
        self.last_saved_time = time.perf_counter()
        self.capture_index = capture_index
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Using Device: ", self.device)
        self.model = self.load_model().to(self.device)
        self.CLASS_NAMES_DICT = self.model.model.names
        # track_id -> list of recent (x, y) box centres, oldest first.
        self.track_history = defaultdict(list)

    def load_model(self):
        """Load the pretrained YOLOv8m weights.

        Returns:
            The ultralytics ``YOLO`` model.

        Raises:
            Exception: whatever ``YOLO`` raised. The error is printed and then
                re-raised rather than returning ``None``, because the caller
                immediately calls ``.to()`` on the result.
        """
        try:
            return YOLO("yolov8m.pt")
        except Exception as e:
            print("Error loading the model:", str(e))
            raise

    def predict(self, frame):
        """Run plain (untracked) detection on a single frame and return the results."""
        results = self.model(frame)
        return results

    def __call__(self):
        """Run the capture / track / measure loop.

        The loop ends when the stream stops delivering frames or 'q' is
        pressed (the last frame is then written to ``frame.jpg``). The capture
        device and OpenCV windows are always released.

        Raises:
            RuntimeError: if the capture device cannot be opened.
        """
        cap = cv2.VideoCapture(self.capture_index, cv2.CAP_DSHOW)
        try:
            cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
            cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)
            cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'))
            # Explicit check instead of `assert`, which is stripped under -O.
            if not cap.isOpened():
                raise RuntimeError(f"Could not open capture device {self.capture_index!r}")

            while True:
                start_time = time.perf_counter()
                ret, frame = cap.read()
                if not ret:
                    break
                # Timestamp of this frame, used for zone entry/exit timing so
                # inference latency does not skew the speed estimate.
                frame_time = time.perf_counter()

                results = self.model.track(frame, persist=True, conf=0.6,
                                           classes=self.TRACKED_CLASSES,
                                           tracker="bytetrack.yaml")
                end_time = time.perf_counter()
                result = results[0]
                has_tracks = result.boxes.id is not None

                if has_tracks:
                    # plot() returns a new annotated image instead of drawing
                    # on `frame`, so do it before adding our own overlays.
                    frame = result.plot(line_width=2)
                self._draw_overlays(frame, end_time - start_time)

                if has_tracks:
                    boxes = result.boxes.xywh.cpu()  # centre-x, centre-y, w, h
                    track_ids = result.boxes.id.int().cpu().tolist()
                    for box, track_id in zip(boxes, track_ids):
                        x, y = float(box[0]), float(box[1])
                        self._update_track(frame, track_id, x, y)
                        self._update_speed(frame, track_id, x, y, frame_time)

                cv2.imshow('YOLOv8 Detection', frame)
                if cv2.waitKey(5) & 0xFF == ord('q'):
                    cv2.imwrite('frame.jpg', frame)
                    break
        finally:
            cap.release()
            cv2.destroyAllWindows()

    @staticmethod
    def _fps(loop_seconds):
        """Return frames per second for one loop iteration, or 0.0 if the duration is not positive.

        Guards against the division by zero that previously produced ``inf``
        (and an OverflowError in ``int(fps)``) on very fast iterations.
        """
        return 1.0 / loop_seconds if loop_seconds > 0 else 0.0

    @classmethod
    def _speed_mph(cls, distance_m, elapsed_s):
        """Return the speed in mph (as a Decimal) for *distance_m* metres covered in *elapsed_s* seconds.

        *elapsed_s* must be > 0.
        """
        return (Decimal(distance_m) / Decimal(elapsed_s)) * cls.MPS_TO_MPH

    def _draw_overlays(self, frame, loop_seconds):
        """Draw the measurement zone, FPS, current timestamp and zone length onto *frame* in place."""
        cv2.polylines(frame, [area_1], True, (15, 220, 10), 2)
        fps = self._fps(loop_seconds)
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
        font = cv2.FONT_HERSHEY_SIMPLEX
        cv2.putText(frame, f'FPS: {int(fps)}', (20, 70), font, 1.5, (0, 255, 0), 2)
        cv2.putText(frame, timestamp, (20, frame.shape[0] - 20), font, 1.5, (0, 255, 0), 2)
        cv2.putText(frame, str(dinm) + "m", (20, 140), font, 1.5, (0, 255, 0), 2)

    def _update_track(self, frame, track_id, x, y):
        """Append (x, y) to the track's trail and draw the trail once it moves far enough horizontally."""
        track = self.track_history[track_id]
        track.append((x, y))
        if len(track) > self.TRACK_HISTORY_LEN:  # keep only the newest points
            track.pop(0)
        # Horizontal displacement between the oldest and newest kept point.
        dx = track[-1][0] - track[0][0]
        if abs(dx) > self.MIN_TRAIL_DX:
            points = np.array(track).astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(frame, [points], isClosed=False, color=(230, 230, 230), thickness=2)

    def _update_speed(self, frame, track_id, x, y, now):
        """Update zone entry/exit timing for a track, label its speed and save snapshots.

        Args:
            frame: Image to draw the speed label on (modified in place).
            track_id: Tracker id of the box.
            x, y: Box centre in pixels.
            now: ``time.perf_counter()`` value for the current frame.
        """
        # pointPolygonTest (measureDist=False): 1 inside, 0 on the edge, -1 outside.
        inside = cv2.pointPolygonTest(area_1, (x, y), False) >= 0
        if inside and track_id not in vehicles:
            vehicles[track_id] = now
        elif not inside and track_id in vehicles:
            elapsed = now - vehicles.pop(track_id)
            # A non-positive duration would make the Decimal division raise.
            if elapsed > 0:
                print("Elapsed Time: ", elapsed)
                vehicles_elapsed[track_id] = elapsed
                pics_taken[track_id] = 0

        if track_id not in vehicles_elapsed:
            return

        speed_mph = self._speed_mph(dinm, vehicles_elapsed[track_id])
        label = f"{int(speed_mph)}mph"
        position = (int(x), int(y) - 10)
        # Thick black text underneath thinner white text for an outlined label.
        cv2.putText(frame, label, position, 0, 0.75, (0, 0, 0), 3, cv2.LINE_AA)
        cv2.putText(frame, label, position, 0, 0.75, (255, 255, 255), 2, cv2.LINE_AA)

        if min_mph < speed_mph < self.MAX_MPH and pics_taken[track_id] < self.MAX_PICS:
            self._save_snapshot(frame, track_id)

    def _save_snapshot(self, frame, track_id):
        """Save *frame* as a PNG in a per-day folder, rate-limited by MIN_SAVE_INTERVAL.

        After MAX_PICS successful saves for *track_id*, its speed record is
        dropped so no further snapshots are taken for it.
        """
        current_time = time.perf_counter()
        if current_time - self.last_saved_time <= self.MIN_SAVE_INTERVAL:
            return
        now = datetime.now()
        folder_name = now.strftime("%Y-%m-%d")
        os.makedirs(folder_name, exist_ok=True)
        filename = now.strftime("%Y-%m-%d-%H-%M-%S-%f")[:-3] + ".png"
        path = os.path.join(folder_name, filename)
        if not cv2.imwrite(path, frame):
            # Do not count a failed write as a snapshot.
            print("Failed to save snapshot:", path)
            return
        self.last_saved_time = current_time
        pics_taken[track_id] += 1
        if pics_taken[track_id] == self.MAX_PICS:
            del vehicles_elapsed[track_id]
            del pics_taken[track_id]
if __name__ == "__main__":
    # Only open the camera and start the tracking loop when run as a script,
    # not when this module is imported.
    detector = YOLOObjectDetector(capture_index=0)
    detector()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment