Last active
April 22, 2021 20:50
-
-
Save bparaj/d3093eb88a43193a91b51cda28261208 to your computer and use it in GitHub Desktop.
Script to compute and visualize intersection over union (IoU) for example pairs of rectangles. Uses matplotlib.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
from matplotlib.patches import Rectangle | |
def get_iou(b1, b2):
    """Return the intersection over union (IoU) of two axis-aligned boxes.

    Each box is a 4-tuple ``(x, y, width, height)`` where ``(x, y)`` is the
    bottom-left corner. The two boxes may be passed in either order.

    Args:
        b1: First box as ``(x, y, width, height)``.
        b2: Second box as ``(x, y, width, height)``.

    Returns:
        The IoU as a float: intersection area divided by union area.
        Boxes that do not overlap, or that only touch along an edge, yield
        0.0. If the union has zero area (both boxes degenerate), 0.0 is
        returned instead of dividing by zero.
    """
    b1_x, b1_y, b1_w, b1_h = b1
    b2_x, b2_y, b2_w, b2_h = b2

    # Overlap extent along each axis: the smaller right/top edge minus the
    # larger left/bottom edge. A negative extent means the boxes are
    # disjoint on that axis, so clamp it at zero. This form is independent
    # of which box is "bottom-left-most" and covers the containment,
    # partial-overlap, identical and disjoint cases uniformly.
    xing_w = max(0, min(b1_x + b1_w, b2_x + b2_w) - max(b1_x, b2_x))
    xing_h = max(0, min(b1_y + b1_h, b2_y + b2_h) - max(b1_y, b2_y))
    xing_area = xing_w * xing_h

    union_area = b1_w * b1_h + b2_w * b2_h - xing_area
    if union_area <= 0:
        # Both boxes have zero area; IoU is undefined, report no overlap.
        return 0.0
    return xing_area / union_area
def get_overlapping_boxes(degree):
    """Return a pair of example boxes whose overlap matches ``degree``.

    Each box is a 4-int tuple: (bot_x, bot_y, width, height), where
    (bot_x, bot_y) is the bottom-left corner.

    Args:
        degree: Name of the example pair: "none", "bad", "good", "wow"
            or "perfect".

    Returns:
        A ``(b1, b2)`` tuple of boxes.

    Raises:
        ValueError: If ``degree`` is not one of the known example names.
    """
    examples = {
        "none": ((6, 8, 40, 50), (53, 30, 40, 50)),
        "bad": ((6, 8, 40, 50), (30, 25, 40, 50)),
        "good": ((6, 8, 40, 50), (13, 11, 40, 50)),
        "wow": ((6, 8, 40, 50), (8, 9, 40, 50)),
        "perfect": ((3, 4, 40, 50), (3, 4, 40, 50)),
    }
    try:
        return examples[degree]
    except KeyError:
        # Fail with a clear message instead of an UnboundLocalError from
        # returning names that were never assigned.
        raise ValueError(
            f"unknown overlap degree {degree!r}; expected one of {sorted(examples)}"
        ) from None
if __name__ == "__main__":
    # Example overlap levels, drawn row by row into a 2x2 grid.
    # ("perfect" is also defined by get_overlapping_boxes but is not drawn.)
    iou_type = ["none", "bad", "good", "wow"]

    # One figure holding a 2x2 grid of axes.
    fig, ax = plt.subplots(nrows=2, ncols=2)
    for panel, degree in zip(ax.flat, iou_type):
        first, second = get_overlapping_boxes(degree)
        score = get_iou(first, second)

        # Outline the first box in red and the second in blue, unfilled.
        for box, colour in ((first, "red"), (second, "blue")):
            x, y, w, h = box
            panel.add_patch(Rectangle((x, y), w, h, ec=colour, lw=2, fill=False))

        # Same fixed 100x100 viewport for every panel so they are comparable.
        panel.set_xlim(0, 100)
        panel.set_ylim(0, 100)
        panel.set_title(f"iou = {score:.2}")
        # Hide both axes; only the boxes and the title matter.
        panel.get_xaxis().set_visible(False)
        panel.get_yaxis().set_visible(False)

    plt.tight_layout()
    fig.savefig("iou_visualizations.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment