Skip to content

Instantly share code, notes, and snippets.

@ggand0
Last active February 10, 2025 17:37
Show Gist options
  • Select an option

  • Save ggand0/8d1e21b4606c8f79c54f5348e2c04950 to your computer and use it in GitHub Desktop.

Select an option

Save ggand0/8d1e21b4606c8f79c54f5348e2c04950 to your computer and use it in GitHub Desktop.
precision-recall curve visualization for information retrieval
from manim import *
import numpy as np
# Output frame size in pixels.
config.pixel_height = 720
config.pixel_width = 1280

# One list of relevance flags per ranked retrieval run: 1 marks a relevant
# item, 0 an irrelevant one, ordered from rank 1 downwards.
retrieval_sequences = [
    # High Precision, High Recall: twelve hits followed by eight misses.
    [1] * 12 + [0] * 8,
    # Balanced Retrieval
    [1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1],
    # Low Precision, Low Recall: a hit at every third rank.
    [1, 0, 0] * 6 + [1, 0],
]
labels = ["High Precision & Recall", "Balanced Retrieval", "Low Precision & Recall"]

# Stroke/fill colour of each sequence's PR curve, in the same order as above.
pr_curve_colors = [BLUE, GREEN, RED]
# Fill opacity of the shaded area under each curve.
auc_opacity = 0.2
# Hand-picked total number of relevant items; used as the recall denominator.
R = 12

# Palette
green_correct = rgb_to_color((0.2, 0.8, 0.5))  # relevant items
grey_incorrect = rgb_to_color((0.4, 0.4, 0.4))  # non-relevant items
pale_grey = rgb_to_color((0.7, 0.7, 0.7))  # items not yet visited
pale_auc = rgb_to_color((0.3, 0.85, 0.55))  # softer green intended for AUC shading
# Compute Precision-Recall values
def compute_precision_recall(sequence, R):
    """Return precision and recall after each rank of a retrieval sequence.

    Args:
        sequence: Iterable of relevance flags in rank order; an item counts
            as relevant when it compares equal to ``1``.
        R: Total number of relevant items, used as the recall denominator.
            Must be positive.

    Returns:
        A ``(precisions, recalls)`` pair of float arrays with one entry per
        item in ``sequence`` (both empty for an empty sequence).

    Raises:
        ValueError: If ``R`` is not positive.
    """
    if R <= 0:
        # Recall is undefined without at least one relevant item; fail with
        # a clear message instead of a bare ZeroDivisionError.
        raise ValueError(f"R must be a positive number of relevant items, got {R!r}")

    # fromiter accepts any iterable (including generators) and keeps the
    # exact ``rel == 1`` relevance test.
    is_relevant = np.fromiter((rel == 1 for rel in sequence), dtype=bool)
    hits = np.cumsum(is_relevant)  # relevant items retrieved up to each rank
    ranks = np.arange(1, hits.size + 1)
    return hits / ranks, hits / R
class IR_AP_Visualization(Scene):
    """Animate precision-recall curves for the module's retrieval sequences.

    A ranked list of cells is drawn on the left and a set of axes on the
    right. For each entry of ``retrieval_sequences`` the scene walks down the
    ranks, recolours each cell by relevance, and extends that sequence's
    precision-recall curve and shaded area one point at a time. Curves of
    earlier sequences are left on screen.
    """

    @staticmethod
    def _trapezoid_auc(precisions, recalls):
        """Return the trapezoidal area under a PR curve anchored at (0, 1).

        This is the area of the polygon the scene shades: from the anchor
        point through every (recall, precision) point down to the x-axis.
        """
        p = np.concatenate(([1.0], precisions))
        r = np.concatenate(([0.0], recalls))
        return float(np.sum(np.diff(r) * (p[1:] + p[:-1]) / 2.0))

    def construct(self):
        """Build the static layout, then animate every retrieval sequence."""
        font_name = "Inter"

        # ---- Shared Elements (Axes for PR Curve) ----
        axes = Axes(
            x_range=[0, 1.1, 0.2], y_range=[0, 1.1, 0.2],
            x_length=6.0,
            y_length=6.0,
            axis_config={"color": WHITE, "include_tip": False}
        ).add_coordinates()
        axes.scale(0.75).shift(DOWN * 0.2)
        labels_axes = axes.get_axis_labels("Recall", "Precision")
        labels_axes.scale(0.5)  # Shrink the axis labels to half size
        labels_axes[0].next_to(axes.x_axis, RIGHT, buff=0.1)  # "Recall" closer to x-axis
        labels_axes[1].next_to(axes.y_axis, UP, buff=0.2)  # "Precision" closer to y-axis

        # ---- Ranked list (one cell per rank, recoloured per sequence) ----
        num_items = len(retrieval_sequences[0])
        rankings = list(range(1, num_items + 1))
        cell_width, cell_height = 2.8, 0.18
        vertical_spacing = 0.05
        total_height = num_items * (cell_height + vertical_spacing) - vertical_spacing
        start_y = (total_height / 2) + 0.3
        rects, rank_labels = VGroup(), VGroup()
        for i, rank in enumerate(rankings):
            y = start_y - i * (cell_height + vertical_spacing)
            rect = Rectangle(
                width=cell_width, height=cell_height,
                color=pale_grey, fill_opacity=0.8, stroke_width=0
            )
            rect.move_to([0, y, 0])
            rects.add(rect)
            rank_text = Text(str(rank), font_size=16, font=font_name, color=WHITE)
            rank_text.next_to(rect, LEFT, buff=0.2)
            rank_labels.add(rank_text)
        RSG = VGroup(rects, rank_labels)
        RSG.shift(LEFT * 3.5)  # Keep fixed position

        # Axes and their labels move as one group, placed right of the list.
        shared_prc_group = VGroup(axes, labels_axes)
        shared_prc_group.next_to(RSG, RIGHT, buff=1.0)
        self.add(RSG, shared_prc_group)

        # Legend: (Purple Dot = Anchor Point by Convention)
        # NOTE(review): the legend takes part in the layout below but is never
        # added to the scene, so it is not rendered -- confirm whether a
        # ``self.add(legend)`` is intended.
        legend_anchor_dot = Dot(radius=0.06, color=PURPLE)  # Smaller than main dot
        legend_text = Text("Anchor Point by Convention", font_size=18, color=WHITE)
        legend = VGroup(legend_anchor_dot, legend_text).arrange(RIGHT, buff=0.3)
        legend.next_to(axes, DOWN, buff=0.5)

        retrieved_label = Text("Retrieved Sequence", font_size=24)
        retrieved_label.next_to(RSG, UP, buff=0.5)
        pr_curve_label = Text("Precision-Recall Curve", font_size=24)
        pr_curve_label.next_to(retrieved_label, RIGHT, buff=2.0)
        r_symbol = MathTex(r"\#\text{Relevant items} = " + str(R), font_size=30)
        r_symbol.next_to(RSG, DOWN, buff=0.5)
        self.add(retrieved_label, pr_curve_label, r_symbol)

        # Precision/recall per sequence, computed once and reused below.
        pr_values = [compute_precision_recall(seq, R) for seq in retrieval_sequences]

        # AUC labels, stacked vertically so they do not overlap. The value is
        # derived from the data so it stays correct if the sequences change.
        ap_texts = []
        for seq_idx, retrieved_sequence in enumerate(retrieval_sequences):
            print(f"Sequence {seq_idx}: {retrieved_sequence}")
            precisions, recalls = pr_values[seq_idx]
            auc = self._trapezoid_auc(precisions, recalls)
            # Use "=" only when the three-decimal value is exact.
            relation = "=" if np.isclose(auc, round(auc, 3)) else r"\approx"
            ap_text = MathTex(
                rf"AUC {relation} {auc:.3f}",
                color=pr_curve_colors[seq_idx], font_size=36
            )
            ap_text.next_to(axes, RIGHT, buff=0.3).shift(UP * 1.0).shift(DOWN * seq_idx * 0.8)
            ap_texts.append(ap_text)

        # Centre the whole layout. ``ap_texts`` is a plain list, so it is
        # unpacked: older Manim versions reject non-mobject arguments here.
        all_entities = VGroup(
            RSG, shared_prc_group, retrieved_label, r_symbol,
            pr_curve_label, legend, *ap_texts
        )
        all_entities.move_to(ORIGIN)

        # Single marker for the conventional (0, 1) anchor point. It must be
        # created after the layout has been centred so c2p is final.
        anchor_dot = Dot(axes.c2p(0, 1), color=PURPLE, radius=0.08)

        # ---- Start Animation ----
        for seq_idx, retrieved_sequence in enumerate(retrieval_sequences):
            # 1: Fresh, fully transparent copies of the current cells
            old_rects = rects
            new_rects = VGroup(
                *[rect.copy().set_fill(pale_grey, opacity=0) for rect in old_rects]
            )
            # 2: Add the new rects to the scene (invisible at first)
            self.add(new_rects)
            # 3: Cross-fade: old rects fade out while new rects fade in
            self.play(
                *[old.animate.set_fill(pale_grey, opacity=0) for old in old_rects],
                *[new.animate.set_fill(pale_grey, opacity=1) for new in new_rects],
                run_time=0.5
            )
            # 4: Drop the now-invisible old rects and continue with the new ones
            self.remove(old_rects)
            rects = new_rects

            precisions, recalls = pr_values[seq_idx]
            pr_curve_color = pr_curve_colors[seq_idx]

            # Empty placeholders; replaced by the real curve/area on each step.
            pr_curve = VMobject().set_color(pr_curve_color).set_stroke(width=4)
            auc_area = VMobject()
            # NOTE: (0, 1) is used as the first point of the curve. This is a
            # common convention for PR curves (e.g., sklearn does this).
            pr_curve_points = [axes.c2p(0, 1)]

            # Animation: Show PR Curve Step-by-Step
            self.play(FadeIn(pr_curve))
            for i, relevance in enumerate(retrieved_sequence):
                highlight_color = green_correct if relevance == 1 else grey_incorrect
                pr_curve_points.append(axes.c2p(recalls[i], precisions[i]))

                new_curve = VMobject().set_color(pr_curve_color).set_stroke(width=3)
                new_curve.set_points_as_corners(pr_curve_points)
                # Area under the curve so far, closed along the x-axis.
                new_auc = Polygon(
                    axes.c2p(0, 0),
                    *pr_curve_points,
                    axes.c2p(recalls[i], 0),
                    color=pr_curve_color, fill_opacity=auc_opacity
                )

                self.remove(pr_curve, auc_area)
                pr_curve = new_curve
                auc_area = new_auc
                self.add(pr_curve, auc_area)

                rects[i].set_fill(highlight_color)
                pr_dot = Dot(axes.c2p(recalls[i], precisions[i]), color=pr_curve_color)
                self.add(pr_dot)
                # Re-adding the same dot keeps it drawn above the fresh curve
                # without stacking a new Dot on every step.
                self.add(anchor_dot)
                self.wait(0.2)

            ap_text = ap_texts[seq_idx]
            self.play(Write(ap_text), run_time=0.3)
            self.wait(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment