Skip to content

Instantly share code, notes, and snippets.

@remi-or
Created May 12, 2026 10:47
Show Gist options
  • Select an option

  • Save remi-or/8de44738629c4d3c72451aa01df1a2ab to your computer and use it in GitHub Desktop.

Select an option

Save remi-or/8de44738629c4d3c72451aa01df1a2ab to your computer and use it in GitHub Desktop.
Script used to generate the timeline figures showing CPU and GPU activity.
#!/usr/bin/env python3
"""Visualize continuous batching timing data as a timeline."""
import json
import sys
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
def load_timing_data(path: str = "cb_timing.json") -> dict:
    """Load the timing data from a JSON file.

    The file is expected to have this format:

        {
            "generation_start": float,
            "generation_end": float,
            "generation_duration": float,
            "total_gpu_time": float,
            "total_cpu_time": float,
            "records": [
                {
                    "phase": str,  # "gpu" or "cpu"
                    "start": float,
                    "end": float,
                    "duration": float
                }, ...
            ]
        }

    Args:
        path: Path of the JSON file to read.

    Returns:
        The parsed JSON document; no validation of the format above is done.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON is UTF-8 by specification, so do not depend on the platform's
    # default text encoding (which is not UTF-8 everywhere).
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def visualize_timeline(data: dict, output_path: str = "cb_phases.png") -> None:
    """Draw the GPU and CPU phases as a two-lane timeline and save it as an image.

    Each record becomes a horizontal bar starting at ``record["start"]``
    (shifted so that ``data["generation_start"]`` is t=0) and spanning
    ``record["duration"]``. GPU bars are drawn on the upper lane and CPU bars
    on the lower lane. A legend and a one-line summary with the share of the
    total generation time spent in each phase are added to the figure.

    Args:
        data: Timing data. Reads the keys ``records``, ``generation_start``,
            ``generation_end``, ``total_gpu_time`` and ``total_cpu_time``; each
            record must provide ``phase``, ``start`` and ``duration``.
        output_path: Destination path of the saved figure.

    Raises:
        KeyError: If a required key is missing or a record's ``phase`` is
            neither ``"gpu"`` nor ``"cpu"``.
    """
    records = data["records"]
    if not records:
        print("No timing records found.")
        return

    # Normalize times relative to generation start
    gen_start = data["generation_start"]
    gen_duration = data["generation_end"] - gen_start
    total_gpu = data["total_gpu_time"]
    total_cpu = data["total_cpu_time"]

    # Per-phase styling; the two lanes use separate y positions to avoid overlap
    colors = {"gpu": "#e63946", "cpu": "#457b9d"}
    y_pos = {"gpu": 0.25, "cpu": -0.25}
    bar_height = 0.4

    # Collect the bars of each lane up front so a lane is drawn with a single
    # barh call instead of one call per record. Doing this before the figure
    # exists also means an unknown phase fails without leaving a figure open.
    spans = {phase: ([], []) for phase in y_pos}
    for record in records:
        lefts, widths = spans[record["phase"]]
        lefts.append(record["start"] - gen_start)
        widths.append(record["duration"])

    # A zero-length generation window would make the percentages divide by zero
    if gen_duration > 0:
        gpu_pct = 100 * total_gpu / gen_duration
        cpu_pct = 100 * total_cpu / gen_duration
    else:
        gpu_pct = cpu_pct = 0.0

    fig, ax = plt.subplots(figsize=(14, 4), dpi=150)
    try:
        for phase, (lefts, widths) in spans.items():
            if not lefts:
                continue
            ax.barh(
                y=y_pos[phase],
                width=widths,
                left=lefts,
                height=bar_height,
                color=colors[phase],
                edgecolor="none",
                alpha=1.0,
            )

        # Styling
        if gen_duration > 0:
            # Only pin the x range when it is non-degenerate; otherwise keep
            # matplotlib's automatic limits.
            ax.set_xlim(0, gen_duration)
        ax.set_ylim(-0.6, 0.6)
        ax.set_yticks([y_pos["gpu"], y_pos["cpu"]])
        ax.set_yticklabels(["GPU", "CPU"], fontsize=10)
        ax.set_xlabel("Time (seconds)", fontsize=12)
        ax.set_title("Continuous Batching: GPU vs CPU Timeline", fontsize=14, fontweight="bold")

        # Legend
        gpu_patch = mpatches.Patch(color=colors["gpu"], label=f"GPU ({total_gpu:.3f}s)")
        cpu_patch = mpatches.Patch(color=colors["cpu"], label=f"CPU ({total_cpu:.3f}s)")
        ax.legend(handles=[gpu_patch, cpu_patch], loc="upper right", fontsize=10)

        # Summary text, placed below the x axis (axes coordinates)
        summary = (
            f"Total: {gen_duration:.3f}s"
            f" | GPU: {total_gpu:.3f}s ({gpu_pct:.1f}%)"
            f" | CPU: {total_cpu:.3f}s ({cpu_pct:.1f}%)"
        )
        ax.text(
            0.5, -0.25, summary,
            transform=ax.transAxes,
            ha="center",
            fontsize=10,
            color="#333333",
        )

        # Clean up spines
        for side in ("top", "right", "left"):
            ax.spines[side].set_visible(False)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white")
        print(f"Saved timeline to {output_path}")
    finally:
        # Close this specific figure (not just whichever one is current), and
        # do so even when drawing or saving fails.
        plt.close(fig)
if __name__ == "__main__":
    # Usage: script.py [input_json] [output_png]
    # Both positional arguments are optional and fall back to the defaults below.
    cli_args = sys.argv[1:]
    input_path = cli_args[0] if len(cli_args) >= 1 else "cb_timing.json"
    output_path = cli_args[1] if len(cli_args) >= 2 else "cb_phases.png"
    data = load_timing_data(input_path)
    visualize_timeline(data, output_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment