Last active
February 10, 2025 17:37
-
-
Save ggand0/8d1e21b4606c8f79c54f5348e2c04950 to your computer and use it in GitHub Desktop.
precision-recall curve visualization for information retrieval
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from manim import * | |
| import numpy as np | |
# Render resolution (1280x720).
config.pixel_width, config.pixel_height = 1280, 720

# Ranked retrieval results, one list per scenario.
# Each entry is the relevance of the item at that rank: 1 = relevant, 0 = not relevant.
retrieval_sequences = [
    [1] * 12 + [0] * 8,  # High Precision, High Recall
    [1, 1, 1, 1, 0, 1, 0, 1, 0, 1,
     0, 0, 1, 0, 1, 0, 0, 1, 0, 1],  # Balanced Retrieval
    [1, 0, 0] * 6 + [1, 0],  # Low Precision, Low Recall
]
labels = [
    "High Precision & Recall",
    "Balanced Retrieval",
    "Low Precision & Recall",
]
pr_curve_colors = [BLUE, GREEN, RED]  # One color per PR curve, in sequence order
auc_opacity = 0.2  # Fill opacity of the shaded area under each curve
R = 12  # Total number of relevant items (recall denominator); chosen by hand

# Colors
green_correct = rgb_to_color((0.2, 0.8, 0.5))    # Relevant item highlight
grey_incorrect = rgb_to_color((0.4, 0.4, 0.4))   # Non-relevant item highlight
pale_grey = rgb_to_color((0.7, 0.7, 0.7))        # Not-yet-visited item
pale_auc = rgb_to_color((0.3, 0.85, 0.55))       # Soft green intended for AUC shading
def compute_precision_recall(sequence, R):
    """Return precision@k and recall@k for every rank k of a ranked result list.

    Args:
        sequence: Iterable of relevance flags in rank order; an entry equal
            to 1 counts as a relevant hit, anything else as a miss.
        R: Total number of relevant items, used as the recall denominator.

    Returns:
        A ``(precisions, recalls)`` pair of NumPy arrays with one entry per
        item in ``sequence``.

    Raises:
        ZeroDivisionError: If ``R`` is 0 and ``sequence`` is not empty.
    """
    hits_so_far = 0
    pr_pairs = []
    for rank, flag in enumerate(sequence, start=1):
        if flag == 1:
            hits_so_far += 1
        # precision = hits / items seen so far, recall = hits / all relevant
        pr_pairs.append((hits_so_far / rank, hits_so_far / R))
    precisions = [pair[0] for pair in pr_pairs]
    recalls = [pair[1] for pair in pr_pairs]
    return np.array(precisions), np.array(recalls)
class IR_AP_Visualization(Scene):
    """Animate precision-recall curves for several ranked retrieval results.

    The left side shows the ranked list as a column of cells that are
    highlighted one by one (relevant / not relevant); the right side grows
    the matching precision-recall curve and its shaded area step by step.
    After each sequence the area under its curve is written next to the plot.
    Reads the module-level ``retrieval_sequences``, ``R``, ``pr_curve_colors``,
    ``auc_opacity`` and color constants.
    """

    @staticmethod
    def _anchored_auc(precisions, recalls):
        """Return the trapezoidal area under the PR polyline anchored at (0, 1)."""
        xs = np.concatenate(([0.0], recalls))
        ys = np.concatenate(([1.0], precisions))
        return float(np.sum(np.diff(xs) * (ys[1:] + ys[:-1]) / 2.0))

    def construct(self):
        """Build the static layout, then animate every retrieval sequence."""
        font_name = "Inter"

        # ---- Shared Elements (Axes for PR Curve) ----
        axes = Axes(
            x_range=[0, 1.1, 0.2], y_range=[0, 1.1, 0.2],
            x_length=6.0,
            y_length=6.0,
            axis_config={"color": WHITE, "include_tip": False},
        ).add_coordinates()
        axes.scale(0.75).shift(DOWN * 0.2)
        labels_axes = axes.get_axis_labels("Recall", "Precision")
        labels_axes.scale(0.5)  # Shrink the axis labels to half size
        labels_axes[0].next_to(axes.x_axis, RIGHT, buff=0.1)  # "Recall" closer to x-axis
        labels_axes[1].next_to(axes.y_axis, UP, buff=0.2)  # "Precision" closer to y-axis

        # ---- Ranked list ("RSG"): one cell plus rank number per item ----
        num_items = len(retrieval_sequences[0])
        cell_width, cell_height = 2.8, 0.18
        vertical_spacing = 0.05
        row_pitch = cell_height + vertical_spacing
        total_height = num_items * row_pitch - vertical_spacing
        start_y = (total_height / 2) + 0.3
        rects, rank_labels = VGroup(), VGroup()
        for row in range(num_items):
            rect = Rectangle(
                width=cell_width, height=cell_height,
                color=pale_grey, fill_opacity=0.8, stroke_width=0,
            )
            rect.move_to([0, start_y - row * row_pitch, 0])
            rects.add(rect)
            rank_text = Text(str(row + 1), font_size=16, font=font_name, color=WHITE)
            rank_text.next_to(rect, LEFT, buff=0.2)
            rank_labels.add(rank_text)
        RSG = VGroup(rects, rank_labels)
        RSG.shift(LEFT * 3.5)  # Keep fixed position

        # Axes and their labels move together, to the right of the ranked list
        shared_prc_group = VGroup(axes, labels_axes)
        shared_prc_group.next_to(RSG, RIGHT, buff=1.0)
        self.add(RSG, shared_prc_group)

        # Legend: (Purple Dot = Anchor Point by Convention)
        # NOTE(review): the legend is positioned and included in the layout
        # group below, but it is never added to the scene, so it does not
        # render. Confirm whether a `self.add(legend)` is intended.
        legend_anchor_dot = Dot(radius=0.06, color=PURPLE)  # Smaller than main dot
        legend_text = Text("Anchor Point by Convention", font_size=18, color=WHITE)
        legend = VGroup(legend_anchor_dot, legend_text).arrange(RIGHT, buff=0.3)
        legend.next_to(axes, DOWN, buff=0.5)

        retrieved_label = Text("Retrieved Sequence", font_size=24)
        retrieved_label.next_to(RSG, UP, buff=0.5)
        pr_curve_label = Text("Precision-Recall Curve", font_size=24)
        pr_curve_label.next_to(retrieved_label, RIGHT, buff=2.0)
        r_symbol = MathTex(r"\#\text{Relevant items} = " + str(R), font_size=30)
        r_symbol.next_to(RSG, DOWN, buff=0.5)
        self.add(retrieved_label, pr_curve_label, r_symbol)

        # ---- Per-sequence PR values and AUC captions ----
        # The PR values are computed once and reused by the animation below.
        # The AUC is derived from them (instead of being hard-coded) so the
        # captions stay correct when the sequences or R are edited.
        pr_values = [compute_precision_recall(seq, R) for seq in retrieval_sequences]
        ap_texts = []
        for seq_idx, retrieved_sequence in enumerate(retrieval_sequences):
            print(f"Sequence {seq_idx}: {retrieved_sequence}")
            auc = self._anchored_auc(*pr_values[seq_idx])
            # Use "=" only when the 3-decimal display is (numerically) exact
            is_exact = abs(auc - round(auc, 3)) < 1e-9
            relation = "=" if is_exact else r"\approx"
            ap_text = MathTex(
                r"AUC {} {:.3f}".format(relation, auc),
                color=pr_curve_colors[seq_idx], font_size=36,
            )
            # Stack the captions vertically to the right of the plot
            ap_text.next_to(axes, RIGHT, buff=0.3).shift(UP * 1.0).shift(DOWN * seq_idx * 0.8)
            ap_texts.append(ap_text)

        # Center the whole layout. The captions are unpacked because older
        # Manim releases reject a plain list passed to VGroup.
        all_entities = VGroup(
            RSG, shared_prc_group, retrieved_label, r_symbol,
            pr_curve_label, legend, *ap_texts,
        )
        all_entities.move_to(ORIGIN)

        # ---- Start Animation ----
        # One anchor marker, created after the layout has been centered. It
        # is re-added on every step so that it is drawn on top of the curves.
        anchor_dot = Dot(axes.c2p(0, 1), color=PURPLE, radius=0.08)

        for seq_idx, retrieved_sequence in enumerate(retrieval_sequences):
            # Cross-fade to a fresh, un-highlighted copy of the cells
            new_rects = VGroup(*[
                rect.copy().set_fill(pale_grey, opacity=0)  # Start fully transparent
                for rect in rects
            ])
            self.add(new_rects)
            self.play(
                *[old.animate.set_fill(pale_grey, opacity=0) for old in rects],  # Fade out old
                *[new.animate.set_fill(pale_grey, opacity=1) for new in new_rects],  # Fade in new
                run_time=0.5,
            )
            rects = new_rects

            precisions, recalls = pr_values[seq_idx]
            pr_curve_color = pr_curve_colors[seq_idx]

            # Empty placeholders; each step below replaces them with a
            # longer curve and a larger shaded area.
            pr_curve = VMobject().set_color(pr_curve_color).set_stroke(width=4)
            auc_area = VMobject()
            self.play(FadeIn(pr_curve))

            # NOTE: (0, 1) is used as the first point of the curve. The
            # original author notes this is a common PR-curve convention
            # (e.g. sklearn does this).
            curve_points = [axes.c2p(0, 1)]
            for i, relevance in enumerate(retrieved_sequence):
                curve_points.append(axes.c2p(recalls[i], precisions[i]))

                new_curve = VMobject().set_color(pr_curve_color).set_stroke(width=3)
                new_curve.set_points_as_corners(curve_points)
                # Close the area down to the recall axis: origin -> curve -> (recall, 0)
                new_auc = Polygon(
                    axes.c2p(0, 0), *curve_points, axes.c2p(recalls[i], 0),
                    color=pr_curve_color, fill_opacity=auc_opacity,
                )

                self.remove(pr_curve, auc_area)
                pr_curve, auc_area = new_curve, new_auc
                self.add(pr_curve, auc_area)

                # Highlight the cell for this rank and mark its PR point
                rects[i].set_fill(green_correct if relevance == 1 else grey_incorrect)
                self.add(Dot(axes.c2p(recalls[i], precisions[i]), color=pr_curve_color))
                self.add(anchor_dot)
                self.wait(0.2)

            self.play(Write(ap_texts[seq_idx]), run_time=0.3)
            self.wait(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment