Last active
March 30, 2021 16:56
-
-
Save amoshyc/5891d2b8a76ddac43c39b62cd4666047 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _xywh_to_xyxy_float(boxes):
    """Return a float copy of ``boxes`` with columns 2-3 turned into the far corner."""
    # np.array (unlike np.asarray) always copies, so the caller's data is never mutated.
    arr = np.array(boxes, dtype=float)
    if arr.size == 0:
        # Let an empty input ([] or shape (0,)) still behave like an (0, 4) box array.
        arr = arr.reshape(0, 4)
    arr[:, 2:] += arr[:, :2]
    return arr


def boxes_iou(boxesA, boxesB):
    """Return the pairwise intersection-over-union matrix of two box sets.

    Each box is a row ``(x, y, w, h)``: ``(x, y)`` is the minimum corner and the
    opposite corner is ``(x + w, y + h)``. Negative widths/heights count as an
    empty box.

    Args:
        boxesA: array-like of shape (N, 4).
        boxesB: array-like of shape (M, 4).

    Returns:
        float ndarray of shape (N, M); entry ``[i, j]`` is the IoU of
        ``boxesA[i]`` and ``boxesB[j]``. Pairs whose union has zero area
        (both boxes degenerate) get IoU 0 instead of NaN.
    """
    # (N, 1, 4) against (1, M, 4): every elementwise op below broadcasts to (N, M).
    a = _xywh_to_xyxy_float(boxesA)[:, None, :]
    b = _xywh_to_xyxy_float(boxesB)[None, :, :]

    x1 = np.maximum(a[..., 0], b[..., 0])
    y1 = np.maximum(a[..., 1], b[..., 1])
    x2 = np.minimum(a[..., 2], b[..., 2])
    y2 = np.minimum(a[..., 3], b[..., 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)

    area_a = np.clip(a[..., 2] - a[..., 0], 0, None) * np.clip(a[..., 3] - a[..., 1], 0, None)
    area_b = np.clip(b[..., 2] - b[..., 0], 0, None) * np.clip(b[..., 3] - b[..., 1], 0, None)
    union = area_a + area_b - inter

    # Only divide where the union is positive; 0/0 pairs stay at 0 without warnings.
    return np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)
def find_in_sorted(sorted, value):
    """Locate ``value`` in the ascending sequence ``sorted`` via binary search.

    Returns the index of a matching element, or -1 when ``value`` is absent.
    (The parameter name shadows the builtin but is kept for caller compatibility.)
    """
    pos = np.searchsorted(sorted, value)
    present = pos < len(sorted) and sorted[pos] == value
    if not present:
        return -1
    return pos
def _evaluate_frame(gt_group, pd_group, prev_match, iou_threshold):
    """Score one frame; return ``(metric, curr_match)`` with ``curr_match`` gt_tag -> pd_tag."""
    num_gt = 0 if gt_group is None else len(gt_group)
    num_pd = 0 if pd_group is None else len(pd_group)
    metric = {
        'num_objects': num_gt,
        'dist_error': 0.0,
        'num_misses': 0,
        'false_dets': 0,
        'num_match': 0,
        'num_mismatch': 0,
    }
    curr_match = dict()
    if num_gt == 0 or num_pd == 0:
        # Nothing can be matched: every GT box is a miss, every prediction a FP.
        metric['num_misses'] = num_gt
        metric['false_dets'] = num_pd
        return metric, curr_match

    # Tags are sorted within each group (inputs were sorted by fid, tag),
    # which find_in_sorted relies on.
    gt_tags = gt_group['tag'].values
    pd_tags = pd_group['tag'].values
    iou = boxes_iou(gt_group[['x', 'y', 'w', 'h']].values,
                    pd_group[['x', 'y', 'w', 'h']].values)
    gt_match_mask = np.zeros(num_gt, dtype=bool)
    pd_match_mask = np.zeros(num_pd, dtype=bool)

    # Propagate existing matches that still overlap enough. Several GT tracks
    # may share the same previous prediction tag, so a prediction already
    # claimed in this frame is skipped to avoid a double assignment.
    for gt_tag, pd_tag in prev_match.items():
        gt_idx = find_in_sorted(gt_tags, gt_tag)
        pd_idx = find_in_sorted(pd_tags, pd_tag)
        if gt_idx == -1 or pd_idx == -1 or pd_match_mask[pd_idx]:
            continue
        if iou[gt_idx, pd_idx] >= iou_threshold:
            curr_match[gt_tag] = pd_tag
            metric['dist_error'] += 1 - iou[gt_idx, pd_idx]
            metric['num_match'] += 1
            gt_match_mask[gt_idx] = True
            pd_match_mask[pd_idx] = True

    # Find new matches among the objects still unmatched.
    gt_tags = gt_tags[~gt_match_mask]
    pd_tags = pd_tags[~pd_match_mask]
    iou = iou[np.ix_(~gt_match_mask, ~pd_match_mask)]  # fancy indexing -> copy
    if iou.size > 0:
        # solve_dense is expected to treat NaN entries as forbidden pairings.
        iou[iou < iou_threshold] = np.nan
        rr, cc = solve_dense(-iou)  # maximise total IoU
    else:
        rr, cc = [], []
    for r, c in zip(rr, cc):
        gt_tag, pd_tag = gt_tags[r], pd_tags[c]
        curr_match[gt_tag] = pd_tag
        metric['dist_error'] += 1 - iou[r, c]
        # Identity switch only when the GT track was last assigned to a
        # *different* prediction track.
        if gt_tag in prev_match and prev_match[gt_tag] != pd_tag:
            metric['num_mismatch'] += 1
        else:
            metric['num_match'] += 1
    metric['num_misses'] += len(gt_tags) - len(rr)
    metric['false_dets'] += len(pd_tags) - len(cc)
    return metric, curr_match


def evaluate_iou_mota(df_pred, df_true, iou_threshold=0.5):
    """Compute CLEAR-MOT style MOTA/MOTP using box IoU as the similarity.

    Both frames need columns ``fid`` (frame id), ``tag`` (track id) and
    ``x``, ``y``, ``w``, ``h`` (box as minimum corner plus size, as in
    ``boxes_iou``). Frames are visited in ascending ``fid`` order over every
    frame present in either table. In each frame, GT/prediction pairs matched
    earlier are kept if their IoU is still >= ``iou_threshold``; remaining
    objects are assigned with ``solve_dense`` on negated IoU, ignoring pairs
    below the threshold.

    Args:
        df_pred: predicted boxes.
        df_true: ground-truth boxes.
        iou_threshold: minimum IoU for a pair to count as a match.

    Returns:
        dict with keys 'mota', 'motp', 'ids' (identity switches),
        'num_matches', 'num_objects', 'num_misses', 'num_fp'. 'mota' and
        'motp' are NaN when their denominator is zero.
    """
    df_pred = df_pred.sort_values(by=['fid', 'tag'])
    df_true = df_true.sort_values(by=['fid', 'tag'])
    pred_data = {fid: group for fid, group in df_pred.groupby('fid')}
    true_data = {fid: group for fid, group in df_true.groupby('fid')}
    # Frames that only contain predictions still contribute false positives.
    frame_ids = sorted(set(true_data) | set(pred_data))

    prev_match = dict()  # gt_tag -> pd_tag, most recent partner of each GT track
    metrics = []
    for fid in tqdm(frame_ids):
        metric, curr_match = _evaluate_frame(
            true_data.get(fid), pred_data.get(fid), prev_match, iou_threshold)
        # Accumulate rather than replace, so a GT track that disappears and
        # reappears is still compared against its last partner.
        prev_match.update(curr_match)
        metrics.append(metric)

    total_g = sum(m['num_objects'] for m in metrics)
    total_m = sum(m['num_misses'] for m in metrics)
    total_d = sum(m['dist_error'] for m in metrics)
    total_fp = sum(m['false_dets'] for m in metrics)
    total_c = sum(m['num_match'] for m in metrics)
    total_mme = sum(m['num_mismatch'] for m in metrics)

    mota = 1 - (total_m + total_fp + total_mme) / total_g if total_g else float('nan')
    # motmetrics convention; the original paper divides by total_c only.
    num_detections = total_c + total_mme
    motp = total_d / num_detections if num_detections else float('nan')
    return {
        'mota': mota,
        'motp': motp,
        'ids': total_mme,
        'num_matches': total_c,
        'num_objects': total_g,
        'num_misses': total_m,
        'num_fp': total_fp,
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _xywh_to_clipped_xyxy(boxes, imgW, imgH):
    """In place: turn columns 0-3 from (x, y, w, h) into image-clipped (x1, y1, x2, y2); return ``boxes``."""
    boxes[:, 2:4] += boxes[:, 0:2]
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, imgW - 1)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, imgH - 1)
    return boxes


def qs_dist(gt_data, pd_data, img_size):
    """Return the spatial cost matrix between ground-truth and predicted boxes.

    Boxes are read from columns ``x, y, w, h`` and clipped to the image. If
    ``pd_data`` has ``sx1, sy1, sx2, sy2`` columns, predictions are built as
    ``PBoxDetInst`` with diagonal corner covariances; otherwise as
    ``BBoxDetInst``. The costs come from ``_gen_cost_tables`` (presumably the
    PDQ evaluation toolkit — confirm).

    Args:
        gt_data: ground-truth boxes.
        pd_data: predicted boxes.
        img_size: ``(width, height)`` of the image.

    Returns:
        ndarray of shape (N, M): spatial cost for each GT/prediction pair.
    """
    imgW, imgH = img_size
    # to_numpy(copy=True): the arrays are modified in place below, and under
    # pandas Copy-on-Write `.values` may be a read-only view.
    gt_boxes = _xywh_to_clipped_xyxy(
        gt_data[['x', 'y', 'w', 'h']].to_numpy(copy=True), imgW, imgH)
    gt_insts = []
    for x1, y1, x2, y2 in gt_boxes:
        # Mask covers both corner pixels (inclusive slices).
        # NOTE(review): slicing needs integer coordinates; float boxes raise here.
        seg = np.zeros((imgH, imgW), dtype=bool)
        seg[y1 : y2 + 1, x1 : x2 + 1] = True
        gt_insts.append(GroundTruthInstance(seg, 0, [x1, y1, x2, y2]))

    if 'sx1' in pd_data.columns:
        cols = ['x', 'y', 'w', 'h', 'sx1', 'sy1', 'sx2', 'sy2']
        pd_boxes = _xywh_to_clipped_xyxy(pd_data[cols].to_numpy(copy=True), imgW, imgH)
        pd_insts = []
        for x1, y1, x2, y2, sx1, sy1, sx2, sy2 in pd_boxes:
            # Diagonal 2x2 covariances for the top-left and bottom-right corners.
            tl_cov = [[sx1, 0], [0, sy1]]
            br_cov = [[sx2, 0], [0, sy2]]
            pd_insts.append(PBoxDetInst([1.0], [x1, y1, x2, y2], [tl_cov, br_cov]))
    else:
        pd_boxes = _xywh_to_clipped_xyxy(
            pd_data[['x', 'y', 'w', 'h']].to_numpy(copy=True), imgW, imgH)
        pd_insts = [BBoxDetInst([1.0], [x1, y1, x2, y2]) for x1, y1, x2, y2 in pd_boxes]

    N, M = len(gt_insts), len(pd_insts)
    costs = _gen_cost_tables(gt_insts, pd_insts, False)
    # The table is reshaped to a square max(N, M) matrix (presumably padded by
    # _gen_cost_tables); crop it to the real GT x prediction pairs.
    dist = costs['spatial']
    dist = dist.reshape(max(N, M), max(N, M))
    dist = dist[:N, :M]
    return dist
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment