Created
July 20, 2021 11:44
-
-
Save shoaibmehedi7/62e5d51134d4850eb70c3fe53703cb3b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# centroidtracker.py
from scipy.spatial import distance as dist | |
from collections import OrderedDict | |
import numpy as np | |
class CentroidTracker:
    """Track objects across frames by matching bounding-box centroids.

    Every tracked object is identified by an integer ID handed out in
    increasing order.  For each ID the tracker keeps the latest centroid
    (``objects``), the latest bounding box (``bbox``) and the number of
    consecutive updates in which the object was not matched
    (``disappeared``).
    """

    def __init__(self, maxDisappeared=50, maxDistance=50):
        """Create an empty tracker.

        Args:
            maxDisappeared: an object is dropped once its count of
                consecutive unmatched updates exceeds this value.
            maxDistance: a tracked centroid and a new centroid further
                apart than this (Euclidean distance) are never matched.
        """
        self.nextObjectID = 0
        self.objects = OrderedDict()
        self.disappeared = OrderedDict()
        self.bbox = OrderedDict()
        self.maxDisappeared = maxDisappeared
        self.maxDistance = maxDistance

    def register(self, centroid, inputRect):
        """Start tracking a new object under the next available ID."""
        self.objects[self.nextObjectID] = centroid
        self.bbox[self.nextObjectID] = inputRect
        self.disappeared[self.nextObjectID] = 0
        self.nextObjectID += 1

    def deregister(self, objectID):
        """Forget everything stored for ``objectID``."""
        del self.objects[objectID]
        del self.disappeared[objectID]
        del self.bbox[objectID]

    def _mark_disappeared(self, objectID):
        """Count one more missed update and drop the object if over the limit."""
        self.disappeared[objectID] += 1
        if self.disappeared[objectID] > self.maxDisappeared:
            self.deregister(objectID)

    def update(self, rects):
        """Feed the bounding boxes of one frame into the tracker.

        Args:
            rects: sized, indexable collection of
                ``(startX, startY, endX, endY)`` boxes.

        Returns:
            The tracker's own ``bbox`` mapping (object ID -> latest box).
            It is the live internal dict, not a copy.
        """
        # No detections at all: every tracked object missed this frame.
        if len(rects) == 0:
            for objectID in list(self.disappeared.keys()):
                self._mark_disappeared(objectID)
            return self.bbox

        # Derive one integer centroid per input box.
        inputCentroids = np.zeros((len(rects), 2), dtype="int")
        inputRects = []
        for i, (startX, startY, endX, endY) in enumerate(rects):
            cX = int((startX + endX) / 2.0)
            cY = int((startY + endY) / 2.0)
            inputCentroids[i] = (cX, cY)
            inputRects.append(rects[i])

        # Nothing tracked yet: every detection becomes a new object.
        if len(self.objects) == 0:
            for centroid, rect in zip(inputCentroids, inputRects):
                self.register(centroid, rect)
            return self.bbox

        objectIDs = list(self.objects.keys())
        objectCentroids = list(self.objects.values())

        # D[r, c] is the distance between tracked object r and detection c.
        D = dist.cdist(np.array(objectCentroids), inputCentroids)

        # Greedy matching: visit tracked objects in order of how close
        # their nearest detection is, pairing each with that detection.
        rows = D.min(axis=1).argsort()
        cols = D.argmin(axis=1)[rows]

        usedRows = set()
        usedCols = set()
        for row, col in zip(rows, cols):
            if row in usedRows or col in usedCols:
                continue
            # Too far apart to be the same object.
            if D[row, col] > self.maxDistance:
                continue
            objectID = objectIDs[row]
            self.objects[objectID] = inputCentroids[col]
            self.bbox[objectID] = inputRects[col]
            self.disappeared[objectID] = 0
            usedRows.add(row)
            usedCols.add(col)

        # Because of the distance gate (and the greedy pairing) a frame
        # can leave BOTH tracked objects and detections unmatched,
        # whatever their relative counts, so both sides are always
        # processed rather than choosing one based on D's shape.
        for row in sorted(set(range(D.shape[0])) - usedRows):
            self._mark_disappeared(objectIDs[row])
        for col in sorted(set(range(D.shape[1])) - usedCols):
            self.register(inputCentroids[col], inputRects[col])

        return self.bbox
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment