Skip to content

Instantly share code, notes, and snippets.

@Helw150
Created April 26, 2026 19:29
Show Gist options
  • Select an option

  • Save Helw150/0aafdae2e5adad9aff01272b44669631 to your computer and use it in GitHub Desktop.

Select an option

Save Helw150/0aafdae2e5adad9aff01272b44669631 to your computer and use it in GitHub Desktop.
Delphi blog: export_figures.py — plotly figure generator from cached W&B data
#!/usr/bin/env python3
"""Export Delphi blog-post plotly figures from the cached W&B data.
This is the plotting-side half of the Delphi pipeline. It reads a JSON
cache written by `fetch_data.py`, fits per-compute-bucket parabolas
plus an asymptotic scaling law per optimizer, and writes the resulting
figures as JSON into `static/assets/images/blog/delphi/`. The blog
template hydrates them via the `{{plotly: <name>.json}}` shortcode and
applies the site-wide Open Athena plotly theme, so this script strips
`template`/`width`/`height`/`font` before writing.
Usage:
uv run scripts/delphi/export_figures.py [--in PATH]
If the cache is missing, run `fetch_data.py` first (it needs
`WANDB_API_KEY`).
Source notebook: https://gist.github.com/Helw150/a1dd8a4fa0437154747c73a2bd35417b
"""
from __future__ import annotations
import argparse
import json
import math
import re
import sys
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from oa_theme import BRAND_BLACK, BRAND_BLUE, BRAND_TERRACOTTA, OA_PALETTE # noqa: E402
# Warm data color ("gold-ish") for single-series figures, e.g. the
# per-bucket parabolas in the overtraining-forecast figure. Pulled from
# the shared OA palette (index 1) rather than hard-coded so it tracks
# the theme module.
BRAND_GOLD = OA_PALETTE[1] # Burnt Umber #8F6B38
# All cached W&B pulls live next to this script under ``data/``;
# ``fetch_data.py`` writes them and this script only reads them.
_DATA_DIR = Path(__file__).resolve().parent / "data"

DEFAULT_CACHE = _DATA_DIR / "runs.json"
DEFAULT_SWEEP_CACHE = _DATA_DIR / "hparam_sweep.json"
DEFAULT_MMLU_CACHE = _DATA_DIR / "mmlu.json"
DEFAULT_MMLU_EXTERNAL_CACHE = _DATA_DIR / "mmlu_external.json"
DEFAULT_HUMANEVAL_CACHE = _DATA_DIR / "humaneval.json"
DEFAULT_HUMANEVAL_EXTERNAL_CACHE = _DATA_DIR / "humaneval_external.json"
DEFAULT_GSM8K_CACHE = _DATA_DIR / "gsm8k.json"
DEFAULT_GSM8K_EXTERNAL_CACHE = _DATA_DIR / "gsm8k_external.json"
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Run-name substring -> display name used in every figure.
# Lookup walks these pairs in order and stops at the first match, so
# order matters: AdamH optimal runs are named
# `adamh-scaling-ladder-nemotron-optimal-*`, which also contains
# "nemotron". Both `adamh` keys therefore MUST precede the `nemo` key,
# or those rows would be misclassified as Cautious AdamC.
QUANTIZER_KEYWORDS = dict(
    [
        ("adamh_scaling_v6", "Delphi Scaling Suite (AdamH)"),
        ("adamh-scaling-ladder", "Delphi Scaling Suite (AdamH)"),
        ("nemo", "Initial Scaling Suite (Cautious AdamC)"),
    ]
)
# Figures are written into the site's static tree, three directories up
# from this script, where the blog's `{{plotly: ...}}` shortcode finds them.
OUT_DIR = Path(__file__).resolve().parent.parent.parent.joinpath(
    "static", "assets", "images", "blog", "delphi"
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def parse_quantizer(name: str) -> str:
    """Return the scaling-suite display name for a W&B run name.

    The run name is lower-cased and checked against the keys of
    ``QUANTIZER_KEYWORDS`` in insertion order; the first substring hit
    wins. Names that match no keyword map to ``"Unknown"``.
    """
    haystack = name.lower()
    hits = (
        label for needle, label in QUANTIZER_KEYWORDS.items() if needle in haystack
    )
    return next(hits, "Unknown")
def parse_seed(name: str) -> str:
    """Return a short seed label for a held-out run.

    Seeded held-out runs are named ``...optimal-1e+22-v5-seed42-HEXHASH``,
    and the seed digits between ``-seed`` and the next ``-`` are returned
    (``"42"``). The original default-seeded run has no seed segment
    (``...optimal-1e+22-v5-HEXHASH``) and yields ``"default"``. The
    lucky-seeds figure uses this to label the AdamH seeds at each budget.
    """
    found = re.search(r"-seed(\d+)-", name)
    if found is None:
        return "default"
    return found.group(1)
def round_sig(x: float, sig: int = 2) -> float:
    """Round ``x`` to ``sig`` significant figures and return a float.

    Zero and NaN are returned unchanged (after ``float()`` conversion).
    The rounding goes through the ``g`` format spec, so e.g.
    ``round_sig(123456) == 120000.0``.
    """
    value = float(x)
    if value != 0 and not np.isnan(value):
        value = float(f"{value:.{sig}g}")
    return value
def fmt_sci(x: float) -> str:
    """Format ``x`` in compact scientific notation, e.g. ``3e+20``, ``2.5e+19``.

    The mantissa keeps up to 3 significant digits, with trailing zeros
    (and a then-bare ".") stripped. The exponent is signed and unpadded
    (``1.23e+9``, ``-5e-3``). Zero, NaN and ±inf are returned as
    ``str(x)``.

    Rounding happens inside the ``e`` format, so a carry out of the
    mantissa (``9.9999e20`` -> ``1e+21``) also bumps the exponent.
    Rounding the mantissa on its own would produce "10", and stripping
    its zero would then give the wrong ``1e+20``.
    """
    if x == 0 or np.isnan(x) or np.isinf(x):
        return str(x)
    mant_str, exp_str = f"{x:.2e}".split("e")
    # ".2e" always yields "d.dd" (with an optional sign), so stripping
    # zeros and then the dot only removes insignificant digits.
    mant_str = mant_str.rstrip("0").rstrip(".")
    return f"{mant_str}e{int(exp_str):+d}"
def fmt_pct_err(pct: float) -> str:
    """Render a percent residual as ``"X.X% worse"`` or ``"X.X% better"``.

    Non-negative values mean the observed value sits on or above the fit
    (worse); negative values mean it sits below the fit (better). NaN
    renders as an empty string so callers can skip the label. Used for
    the held-out callouts on the scaling-law and downstream-eval figures.
    """
    if np.isnan(pct):
        return ""
    verdict = "better" if pct < 0 else "worse"
    return f"{abs(pct):.1f}% {verdict}"
@dataclass
class Fit:
    """Result of one per-bucket IsoFLOP parabola fit (see ``fit_parabola``)."""

    # Quadratic coefficients of loss = a·x² + b·x + c, with x = log10(tokens).
    a: float
    b: float
    c: float
    # Location of the fitted minimum in log10(tokens), and the token count itself.
    logT_star: float
    T_star: float
    # Fitted loss at logT_star.
    loss_star: float
    # Parameter count at T_star, from a log-log linear fit of params vs tokens.
    P_star: float
    # One-sided curvatures below / above logT_star; each falls back to `a`
    # when its side has no points or fits a non-positive curvature.
    a_under: float
    a_over: float


def fit_parabola(sub: pd.DataFrame) -> Fit:
    """Fit ``loss = a (log10 T)^2 + b (log10 T) + c`` for one compute bucket.

    The optimum is the minimizer of the fitted quadratic over the observed
    log10-token range:

    * Convex fit (``a > 0``): the vertex ``-b / 2a``, clipped to the range.
    * Concave or flat fit (``a <= 0``): the vertex is a maximum (or does
      not exist), so the observed endpoint with the lower fitted loss is
      used instead.

    Also fits piecewise curvatures ``a_under`` / ``a_over`` around the
    optimum. Each side is fit separately as
    ``loss = loss_star + a_side · (logT - logT_star)²``, so callers that
    care about over- vs undertraining asymmetry (e.g. the
    overtraining-forecast figure) can skip the symmetric assumption.

    Args:
        sub: rows of one (quantizer, compute bucket) with ``tokens``,
            ``loss`` and ``params`` columns. ``np.polyfit`` needs at least
            3 rows for the quadratic; callers skip smaller buckets.

    Returns:
        A populated ``Fit``.
    """
    T = sub["tokens"].to_numpy(float)
    y = sub["loss"].to_numpy(float)
    P = sub["params"].to_numpy(float)
    logT = np.log10(T)
    a, b, c = np.polyfit(logT, y, 2)
    lo, hi = float(logT.min()), float(logT.max())
    if a > 0:
        logT_star = float(np.clip(-b / (2 * a), lo, hi))
    else:
        # Degenerate fit: the vertex is not a minimum, so take whichever
        # end of the observed range the fitted curve puts lower.
        ends = np.array([lo, hi])
        logT_star = float(ends[np.argmin(a * ends**2 + b * ends + c)])
    T_star = float(10**logT_star)
    loss_star = float(a * logT_star**2 + b * logT_star + c)
    # Parameter count at the optimum: interpolate log10(params) linearly
    # in log10(tokens) across the bucket's runs.
    m, k = np.polyfit(logT, np.log10(P), 1)
    P_star = float(10 ** (m * logT_star + k))
    dx = logT - logT_star
    dy = y - loss_star

    def _side_a(mask: np.ndarray) -> float:
        # Least-squares curvature through the fixed vertex, using only the
        # points selected by `mask`; fall back to the global `a` when the
        # side is empty or the one-sided fit is non-positive.
        if not mask.any():
            return float(a)
        r = dx[mask] ** 2
        num = float(r @ dy[mask])
        den = float(r @ r)
        side_a = num / den if den > 0 else float(a)
        return side_a if side_a > 0 else float(a)

    a_under = _side_a(dx < 0)
    a_over = _side_a(dx > 0)
    return Fit(
        a=float(a),
        b=float(b),
        c=float(c),
        logT_star=logT_star,
        T_star=T_star,
        loss_star=loss_star,
        P_star=P_star,
        a_under=a_under,
        a_over=a_over,
    )
_FAIL = {
"ok": False,
"L_inf": 0.0,
"A": float("nan"),
"alpha": float("nan"),
"sse_log": float("inf"),
}
def _fit_power_at_linf(C: np.ndarray, y: np.ndarray, Linf: float) -> dict | None:
"""Given a fixed L_inf, fit `log(y - L_inf) = log A - alpha * log C`."""
z = y - Linf
if np.any(z <= 0):
return None
logC = np.log(C)
logz = np.log(z)
b1, b0 = np.polyfit(logC, logz, 1)
alpha = -b1
if not np.isfinite(alpha) or alpha <= 0:
return None
pred = b0 + b1 * logC
sse = float(np.mean((logz - pred) ** 2))
return {
"ok": True,
"L_inf": float(Linf),
"A": float(np.exp(b0)),
"alpha": float(alpha),
"sse_log": sse,
}
def fit_asymptotic_powerlaw_per_q(
    per_q: dict[str, tuple[np.ndarray, np.ndarray]],
    n_grid: int = 400,
) -> dict[str, dict]:
    """Independent per-optimizer fit of ``loss_q(C) = L_inf_q + A_q * C^{-alpha_q}``.

    This uses the same inner solver as the shared-L_inf variant, but each
    optimizer gets its own L_inf. That L_inf is chosen by a grid search
    over ``n_grid`` evenly spaced values on ``[0, 0.95·min(y_q)]``, keeping
    the first value with the lowest ``sse_log``. Use this when you
    explicitly do NOT want to tie the asymptote across optimizers, e.g.
    when diagnosing whether a shared asymptote is artificially creating a
    per-optimizer α gap.

    Optimizers with fewer than 2 points, a non-positive search interval,
    or no valid fit anywhere on the grid get a copy of ``_FAIL``.
    """
    results: dict[str, dict] = {}
    for q, (C_raw, y_raw) in per_q.items():
        C_arr = np.asarray(C_raw, float)
        y_arr = np.asarray(y_raw, float)
        by_compute = np.argsort(C_arr)
        C_arr, y_arr = C_arr[by_compute], y_arr[by_compute]
        if len(C_arr) < 2:
            results[q] = dict(_FAIL)
            continue
        upper = 0.95 * float(np.min(y_arr))
        if upper <= 0.0:
            results[q] = dict(_FAIL)
            continue
        trial_fits = (
            _fit_power_at_linf(C_arr, y_arr, linf)
            for linf in np.linspace(0.0, upper, n_grid)
        )
        valid = [fit for fit in trial_fits if fit is not None]
        # min() keeps the first of equally good candidates, matching a
        # strict "<" scan over the grid.
        results[q] = (
            min(valid, key=lambda fit: fit["sse_log"]) if valid else dict(_FAIL)
        )
    return results
def fit_asymptotic_powerlaw_shared_linf(
    per_q: dict[str, tuple[np.ndarray, np.ndarray]],
    n_grid: int = 400,
) -> dict[str, dict]:
    """Joint fit of ``loss_q(C) = L_inf + A_q * C^{-alpha_q}`` with a
    single L_inf shared across all optimizers.

    Motivation: L_inf represents the irreducible loss floor of the task
    and data, not a property of the optimizer, so it should be tied
    across optimizers when comparing scaling curves. Each optimizer
    still gets its own (A, alpha) describing how it approaches that
    shared floor.

    The fit is a grid search over ``n_grid`` evenly spaced L_inf values
    on ``[0, 0.95·min(y)]``. Here ``min(y)`` is taken over every optimizer
    that has at least two points. For each trial L_inf, (log A_q, alpha_q)
    is solved in closed form per optimizer via log-space linear
    regression. The L_inf kept is the one minimizing the sum of the
    per-optimizer ``sse_log`` values. Each of those is a *mean* squared
    log residual, so every optimizer is weighted equally regardless of
    its point count. A trial L_inf is skipped if any fittable optimizer
    has no valid power law there (some ``y <= L_inf``, or non-positive
    alpha).

    Optimizers with fewer than two points cannot constrain a power law.
    They are excluded from the joint search and reported as a copy of
    ``_FAIL``, rather than making the search fail for every optimizer.

    Returns:
        ``{q: fit_dict}`` for every key of ``per_q``, in the same order.
        Every entry is a copy of ``_FAIL`` when no trial L_inf works.
    """
    sorted_per_q: dict[str, tuple[np.ndarray, np.ndarray]] = {}
    for q, (C, y) in per_q.items():
        C = np.asarray(C, float)
        y = np.asarray(y, float)
        order = np.argsort(C)
        sorted_per_q[q] = (C[order], y[order])

    # Only optimizers with >= 2 points can pin down (log A, alpha).
    fittable = {q: Cy for q, Cy in sorted_per_q.items() if len(Cy[0]) >= 2}
    if not fittable:
        return {q: dict(_FAIL) for q in per_q}

    y_min = float(np.min(np.concatenate([y for _, y in fittable.values()])))
    Linf_lo, Linf_hi = 0.0, 0.95 * y_min
    if Linf_hi <= Linf_lo:
        return {q: dict(_FAIL) for q in per_q}

    best_total = float("inf")
    best_per_q: dict[str, dict] | None = None
    for Linf in np.linspace(Linf_lo, Linf_hi, n_grid):
        per_q_fit: dict[str, dict] = {}
        total = 0.0
        for q, (C, y) in fittable.items():
            fit = _fit_power_at_linf(C, y, Linf)
            if fit is None:
                break
            per_q_fit[q] = fit
            total += fit["sse_log"]
        else:
            # Every fittable optimizer has a valid power law at this L_inf.
            if total < best_total:
                best_total = total
                best_per_q = per_q_fit

    if best_per_q is None:
        return {q: dict(_FAIL) for q in per_q}
    return {q: best_per_q[q] if q in best_per_q else dict(_FAIL) for q in per_q}
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def _run_type(name: str) -> str:
"""IsoFLOP sweep runs vs. held-out optimal/validation runs."""
if "optimal" in name.lower():
return "optimal"
return "isoflop"
def load_runs(cache_path: Path) -> pd.DataFrame:
    """Load cached W&B run data, attach display names + run type,
    and bucket by GFLOPs.

    The cache stores `total_gflops` in raw W&B units (GFLOPs). They are
    kept in GFLOPs here; downstream code converts to FLOPs (×1e9)
    wherever it computes `C_flops`.

    Args:
        cache_path: JSON cache file written by `fetch_data.py`. It is
            loaded straight into ``pd.DataFrame``, and rows need at least
            ``name``, ``loss``, ``tokens``, ``total_gflops`` and
            ``params``.

    Returns:
        The rows with no missing ``loss``/``tokens``/``total_gflops``/
        ``params``, plus derived ``quantizer``, ``run_type`` and
        ``gflops_q`` columns.

    Raises:
        SystemExit: if ``cache_path`` does not exist.
    """
    if not cache_path.exists():
        raise SystemExit(
            f"cache not found at {cache_path}\n"
            "Run `uv run scripts/delphi/fetch_data.py` first "
            "(needs WANDB_API_KEY)."
        )
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default.
    rows = json.loads(cache_path.read_text(encoding="utf-8"))
    # Drop incomplete rows first so derived columns are only computed for
    # runs that are kept.
    df = pd.DataFrame(rows).dropna(
        subset=["loss", "tokens", "total_gflops", "params"]
    ).copy()
    df["quantizer"] = df["name"].apply(parse_quantizer)
    df["run_type"] = df["name"].apply(_run_type)
    # `gflops_q` is a *bucket* of GFLOPs, rounded to 2 sig figs so
    # near-identical runs collapse into the same group.
    df["gflops_q"] = df["total_gflops"].apply(lambda x: round_sig(x, 2))
    return df
def subsample_cadamc_to_adamh_grid(df: pd.DataFrame) -> pd.DataFrame:
    """Drop C-AdamC isoflop points at (params, gflops_q) not present in AdamH.

    Both suites share almost all of their (N, D) sweep grid, but C-AdamC
    has a few extra small-N points at the larger compute buckets that
    AdamH skipped. Dropping them makes the per-bucket parabolas and the
    scaling-law fit apples-to-apples between the two ladders. Optimal and
    held-out runs, and rows of any other quantizer, are left untouched.

    Returns:
        A filtered copy with a fresh RangeIndex. The input is not modified.
    """
    df = df.copy()
    # Match on parameter count rounded to the nearest million, so small
    # float differences in reported param counts still line up.
    df["_pkey"] = (df["params"] / 1e6).round()
    is_iso = df["run_type"] == "isoflop"
    adamh_iso = df[is_iso & (df["quantizer"] == ADAMH_LABEL)]
    adamh_keys = set(zip(adamh_iso["_pkey"], adamh_iso["gflops_q"]))
    is_cadamc_iso = is_iso & (df["quantizer"] == C_ADAMC_LABEL)
    # Build the membership mask column-wise; this is faster than a
    # row-wise apply and stays a boolean Series on an empty frame.
    in_grid = pd.Series(
        [key in adamh_keys for key in zip(df["_pkey"], df["gflops_q"])],
        index=df.index,
        dtype=bool,
    )
    keep = ~is_cadamc_iso | in_grid
    return df[keep].drop(columns=["_pkey"]).reset_index(drop=True)
def compute_optima(df: pd.DataFrame) -> tuple[dict, pd.DataFrame, list, list]:
    """Fit one IsoFLOP parabola per (quantizer, compute bucket).

    Only ``run_type == "isoflop"`` rows are used. Buckets with fewer than
    3 runs are skipped, since a quadratic needs 3 points.

    Returns:
        fits: ``{(quantizer, gflops_q): Fit}`` for every fitted bucket.
        optima: one row per fitted bucket with ``quantizer``,
            ``gflops_bucket`` (GFLOPs), ``C_flops`` (FLOPs), ``T_star``,
            ``P_star`` and ``loss_star``, sorted by quantizer then bucket.
            It is empty, with the same columns, when nothing was fitted.
        gflop_groups: sorted unique ``gflops_q`` of the isoflop rows.
        unique_quantizers: sorted unique quantizer labels of those rows.
    """
    optima_columns = [
        "quantizer",
        "gflops_bucket",
        "C_flops",
        "T_star",
        "P_star",
        "loss_star",
    ]
    # Parabolic fits are driven by the IsoFLOP sweep only — held-out
    # optimal/validation runs are not part of the curve fitting.
    df = df[df["run_type"] == "isoflop"]
    gflop_groups = sorted(df["gflops_q"].unique())
    unique_quantizers = sorted(df["quantizer"].unique().tolist())
    fits: dict = {}
    optima_rows = []
    for q in unique_quantizers:
        for g in gflop_groups:
            sub = df[(df["quantizer"] == q) & (df["gflops_q"] == g)]
            if len(sub) < 3:
                continue
            f = fit_parabola(sub)
            fits[(q, g)] = f
            optima_rows.append(
                {
                    "quantizer": q,
                    "gflops_bucket": g,
                    "C_flops": g * 1e9,
                    "T_star": f.T_star,
                    "P_star": f.P_star,
                    "loss_star": f.loss_star,
                }
            )
    # Explicit columns keep the schema, and the sort keys, valid even
    # when no bucket had enough runs to fit.
    optima = (
        pd.DataFrame(optima_rows, columns=optima_columns)
        .sort_values(["quantizer", "gflops_bucket"])
        .reset_index(drop=True)
    )
    return fits, optima, gflop_groups, unique_quantizers
# ---------------------------------------------------------------------------
# Figure builders
# ---------------------------------------------------------------------------
# Palette imported from scripts/oa_theme.py — keep all OA plotting colors
# in one place so future notebooks/scripts produce on-brand charts.
def build_isoflop_figure(
    df: pd.DataFrame, fits: dict, gflop_groups: list, unique_quantizers: list
) -> go.Figure:
    """Grid of IsoFLOP subplots (one per optimizer) with per-bucket parabolas.

    Each subplot scatters Paloma macro loss against training tokens for
    every compute bucket of one optimizer. It overlays that bucket's
    fitted parabola across the bucket's observed token range and marks
    the fitted optimum with an ✕. A bucket keeps one color (from
    ``OA_PALETTE``, by position in ``gflop_groups``) and one legend group
    across subplots, so a legend click toggles it everywhere.

    Args:
        df: run table; only ``run_type == "isoflop"`` rows are plotted.
        fits: ``{(quantizer, gflops_q): Fit}``, as from ``compute_optima``.
        gflop_groups: compute buckets in GFLOPs, in color/legend order.
        unique_quantizers: optimizer labels, one subplot each, 2 per row.

    Returns:
        The figure, or an empty ``go.Figure()`` when there is no
        optimizer to plot.
    """
    if not unique_quantizers:
        # make_subplots needs at least one column; nothing to draw anyway.
        return go.Figure()
    # Keep the isoflop subplot grid strictly to IsoFLOP sweep runs —
    # held-out optimal/validation rows would otherwise land as extra
    # fake "compute buckets".
    df = df[df["run_type"] == "isoflop"]
    g2c = {g: OA_PALETTE[i % len(OA_PALETTE)] for i, g in enumerate(gflop_groups)}
    cols = min(len(unique_quantizers), 2)
    rows = (len(unique_quantizers) + cols - 1) // cols
    fig = make_subplots(
        rows=rows,
        cols=cols,
        shared_yaxes=True,
        subplot_titles=unique_quantizers,
        horizontal_spacing=0.05,
        vertical_spacing=0.12,
    )
    # Each bucket gets exactly one legend entry, on its first appearance in
    # *any* subplot, so a bucket absent from the first subplot still shows.
    shown_labels: set[str] = set()
    for i, q in enumerate(unique_quantizers):
        row = i // cols + 1
        col = i % cols + 1
        for bucket_idx, g in enumerate(gflop_groups):
            sub = df[(df["quantizer"] == q) & (df["gflops_q"] == g)]
            if len(sub) == 0:
                continue
            color = g2c[g]
            # g is in GFLOPs; the legend title says "FLOPs", so convert.
            label = fmt_sci(g * 1e9)
            show_in_legend = label not in shown_labels
            shown_labels.add(label)
            fig.add_trace(
                go.Scatter(
                    x=sub["tokens"],
                    y=sub["loss"],
                    mode="markers",
                    marker=dict(size=7, color=color),
                    name=label,
                    legendgroup=label,
                    showlegend=show_in_legend,
                    # Legend order follows compute order no matter which
                    # subplot introduced the bucket.
                    legendrank=bucket_idx,
                    hovertemplate=(
                        "tokens=%{x:.3e}<br>loss=%{y:.4f}<br>"
                        f"bucket={label}<br>{q}<extra></extra>"
                    ),
                ),
                row=row,
                col=col,
            )
            key = (q, g)
            if key in fits:
                f = fits[key]
                logT = np.log10(sub["tokens"].to_numpy(float))
                grid = np.linspace(logT.min(), logT.max(), 200)
                yhat = f.a * grid**2 + f.b * grid + f.c
                fig.add_trace(
                    go.Scatter(
                        x=10**grid,
                        y=yhat,
                        mode="lines",
                        line=dict(width=3, color=color),
                        name=label,
                        legendgroup=label,
                        showlegend=False,
                        hoverinfo="skip",
                    ),
                    row=row,
                    col=col,
                )
                fig.add_trace(
                    go.Scatter(
                        x=[f.T_star],
                        y=[f.loss_star],
                        mode="markers",
                        marker=dict(
                            symbol="x", size=12, color=color, line=dict(width=2)
                        ),
                        name=label,
                        legendgroup=label,
                        showlegend=False,
                        hovertemplate=(
                            f"optimum<br>T*={f.T_star:.3e}<br>"
                            f"loss*={f.loss_star:.4f}<extra></extra>"
                        ),
                    ),
                    row=row,
                    col=col,
                )
    fig.update_xaxes(type="log", title_text="Training tokens")
    for r in range(1, rows + 1):
        fig.update_yaxes(title_text="Paloma macro loss", row=r, col=1)
    # Subplot titles already label each optimizer, and the outer frame
    # names the figure, so there is no in-plot title.
    # Compute-bucket legend goes below the subplots, transparent, with no
    # border. `entrywidth` in `fraction` mode ties each entry's minimum
    # width to a fraction of the plot area. 0.22 gives 4 entries per row,
    # so a 7-bucket set lands as 4+3, and the legend can never be wider
    # than the plot (4 × 0.22 = 0.88 of the plot width on a full row).
    fig.update_layout(
        legend_title_text="Compute (FLOPs)",
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.2,
            xanchor="center",
            x=0.5,
            bgcolor="rgba(0,0,0,0)",
            borderwidth=0,
            tracegroupgap=10,
            entrywidth=0.22,
            entrywidthmode="fraction",
        ),
        margin=dict(t=40, r=40, b=10, l=60),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
    )
    return fig
def build_scaling_law_figure(
    optima: pd.DataFrame, optimal_runs: pd.DataFrame | None = None
) -> go.Figure:
    """Cross-optimizer scaling-law figure: IsoFLOP optima, fits, held-out runs.

    For each quantizer in ``optima`` (sorted by name, colored from
    ``OA_PALETTE`` in that order) this draws:

    * the IsoFLOP optima (``C_flops`` vs ``loss_star``) as circles;
    * the independent per-optimizer fit ``L_inf + A * C^-alpha`` from
      ``fit_asymptotic_powerlaw_per_q``, as a dashed line spanning
      ``[min C_flops, max(max C_flops, 1e23)]``, drawn only if the fit
      succeeded;
    * if ``optimal_runs`` is given, the held-out runs. Non-crashed runs
      are averaged per (quantizer, nearest power-of-ten compute) and drawn
      as larger circles. Rows with ``state == "crashed"`` are drawn as ✕
      markers labelled "Run Diverged".

    A dotted vertical rule at 5e20 FLOPs separates the fit region from
    the extrapolation region.

    Args:
        optima: per-bucket optima with ``quantizer``, ``C_flops`` (FLOPs)
            and ``loss_star`` columns, as produced by ``compute_optima``.
        optimal_runs: optional held-out runs with ``quantizer``, ``loss``
            and ``total_gflops`` (GFLOPs) columns, and optionally
            ``state``. Quantizers absent from ``optima`` are not drawn.

    Returns:
        The plotly figure.
    """
    q_order = sorted(optima["quantizer"].unique().tolist())
    # Use distinct, high-contrast colors so the two optimizers are obvious.
    q_color = {q: OA_PALETTE[i % len(OA_PALETTE)] for i, q in enumerate(q_order)}
    # All IsoFLOP optima are training points for the scaling-law fit.
    # The only held-out set is the 1e21+ `optimal` validation runs
    # (plotted as larger circles further down); those are never fed to
    # the fit.
    train_points: dict[str, pd.DataFrame] = {}
    per_q_train: dict[str, tuple[np.ndarray, np.ndarray]] = {}
    for q in q_order:
        sub_all = optima[optima["quantizer"] == q].sort_values("C_flops").copy()
        train_points[q] = sub_all
        per_q_train[q] = (
            sub_all["C_flops"].to_numpy(float),
            sub_all["loss_star"].to_numpy(float),
        )
    # Independent per-optimizer L_inf. Previously we tied L_inf across
    # optimizers on the assumption that the asymptote is a property of
    # the task and not the optimizer — but the held-out 1e21/1e22 runs
    # showed both curves converging to essentially the same loss, so the
    # shared-L_inf fit was inflating AdamC's α just to compensate for its
    # worse early-compute starting point. Letting each optimizer pick its
    # own L_inf removes that artifact.
    asym_fit = fit_asymptotic_powerlaw_per_q(per_q_train)
    # The fit curve spans the observed optima and always extends to at
    # least 1e23 FLOPs, so the extrapolation reaches the held-out budgets.
    C_min = float(optima["C_flops"].min())
    C_max = max(float(optima["C_flops"].max()), 1e23)
    C_grid = np.logspace(np.log10(C_min), np.log10(C_max), 400)
    # Pin the x-axis range explicitly rather than autoranging. Use a small
    # log-space pad on each side so the outermost markers aren't flush
    # against the axis line.
    log_pad = 0.12
    x_lo = 10 ** (np.log10(C_min) - log_pad)
    x_hi = 10 ** (np.log10(C_max) + log_pad)
    # Fit/extrapolation boundary: sits at 5e20, just past the last
    # IsoFLOP optimum at 3e20 but before the first held-out optimal run
    # at 1e21. Lands cleanly between the two real data regions without
    # overlapping any marker.
    held_lo = 5e20
    fig = go.Figure()
    # Thin vertical rule at `held_lo` (5e20 FLOPs) separates the fit
    # region (left) from the held-out extrapolation region (right). A
    # subtle dotted rule reads as "here be dragons" without the visual
    # weight of a shaded rectangle clashing with the page background.
    fig.add_shape(
        type="line",
        x0=held_lo,
        x1=held_lo,
        xref="x",
        y0=0,
        y1=1,
        yref="paper",
        line=dict(color="#1f1e1b", width=1, dash="dot"),
        layer="below",
        opacity=0.45,
    )
    # Plotly log-axis quirk: shapes take raw data (x=5e20) but
    # annotations take the log10 of the data (x=20.7). Passing 5e20 to an
    # annotation puts it at log position 5e20 — miles off the right edge.
    # The two labels are anchored to the divider line (one right-aligned,
    # one left-aligned) so the gap stays consistent under page resizing.
    # A single centered annotation with a fixed offset drifts as the plot
    # width changes.
    held_lo_log = float(np.log10(held_lo))
    fig.add_annotation(
        x=held_lo_log,
        y=0.98,
        xref="x",
        yref="paper",
        text="fit ← ",
        showarrow=False,
        xanchor="right",
        yanchor="top",
        xshift=-4,
        font=dict(size=12, color="#1f1e1b"),
    )
    fig.add_annotation(
        x=held_lo_log,
        y=0.98,
        xref="x",
        yref="paper",
        text=" → extrapolation",
        showarrow=False,
        xanchor="left",
        yanchor="top",
        xshift=4,
        font=dict(size=12, color="#1f1e1b"),
    )
    for q in q_order:
        color = q_color[q]
        d = asym_fit[q]
        # IsoFLOP optima: the only traces that carry the optimizer's
        # legend entry; the fit and held-out traces share its legendgroup.
        if len(train_points[q]):
            fig.add_trace(
                go.Scatter(
                    x=train_points[q]["C_flops"],
                    y=train_points[q]["loss_star"],
                    mode="markers",
                    marker=dict(size=9, color=color, symbol="circle"),
                    name=q,
                    legendgroup=q,
                    hovertemplate=(
                        f"{q} (IsoFLOP optimum)<br>C=%{{x:.3e}} FLOPs<br>"
                        "loss*=%{y:.4f}<extra></extra>"
                    ),
                )
            )
        if d["ok"]:
            yhat = d["L_inf"] + d["A"] * (C_grid ** (-d["alpha"]))
            fig.add_trace(
                go.Scatter(
                    x=C_grid,
                    y=yhat,
                    mode="lines",
                    line=dict(width=2, color=color, dash="dash"),
                    showlegend=False,
                    legendgroup=q,
                    hovertemplate=(
                        f"{q} fit<br>C=%{{x:.3e}} FLOPs<br>"
                        "loss=%{y:.4f}<extra></extra>"
                    ),
                )
            )
    # Held-out optimal validation runs (1e21/1e22/1e23). Finished runs are
    # averaged across seeds per (quantizer, nominal compute bucket) and
    # drawn as larger (size-11) circles. Crashed runs are drawn as a
    # separate ✕ marker so the loss-spike divergence (e.g. C-AdamC at 1e23)
    # is visible on the plot rather than silently dropped.
    if optimal_runs is not None and len(optimal_runs):
        opt_all = optimal_runs.dropna(subset=["loss", "total_gflops"]).copy()
        opt_all["C_flops"] = opt_all["total_gflops"] * 1e9
        # Nominal compute = nearest power of ten in log space, so seeds of
        # the same budget share one x position.
        opt_all["C_nom"] = opt_all["C_flops"].apply(
            lambda c: 10.0 ** round(float(np.log10(c)))
        )
        # Without a `state` column every run is treated as finished.
        has_state = "state" in opt_all.columns
        opt_finished = opt_all[opt_all["state"] != "crashed"] if has_state else opt_all
        opt_crashed = (
            opt_all[opt_all["state"] == "crashed"] if has_state else opt_all.iloc[0:0]
        )
        agg = (
            opt_finished.groupby(["quantizer", "C_nom"])
            .agg(loss_mean=("loss", "mean"), n=("loss", "size"))
            .reset_index()
        )
        for q in q_order:
            color = q_color[q]
            sub = agg[agg["quantizer"] == q].sort_values("C_nom")
            if len(sub) == 0:
                continue
            fig.add_trace(
                go.Scatter(
                    x=sub["C_nom"],
                    y=sub["loss_mean"],
                    mode="markers",
                    marker=dict(size=11, color=color, symbol="circle"),
                    showlegend=False,
                    legendgroup=q,
                    hovertemplate=(
                        f"{q} (held-out optimal)<br>C=%{{x:.3e}} FLOPs<br>"
                        "mean loss=%{y:.4f}<extra></extra>"
                    ),
                )
            )
        for q in q_order:
            color = q_color[q]
            sub = opt_crashed[opt_crashed["quantizer"] == q]
            if len(sub) == 0:
                continue
            fig.add_trace(
                go.Scatter(
                    x=sub["C_nom"],
                    y=sub["loss"],
                    mode="markers+text",
                    text="Run Diverged",
                    textfont=dict(color="#1F1E1B"),
                    textposition="middle left",
                    marker=dict(
                        size=14,
                        color=color,
                        symbol="x-thin",
                        line=dict(width=3, color=color),
                    ),
                    showlegend=False,
                    legendgroup=q,
                    hovertemplate=(
                        f"{q} (diverged mid-run)<br>C=%{{x:.3e}} FLOPs<br>"
                        "final loss=%{y:.4f}<extra></extra>"
                    ),
                )
            )
    fig.update_xaxes(
        type="log",
        title_text="Compute (FLOPs)",
        range=[np.log10(x_lo), np.log10(x_hi)],
    )
    fig.update_yaxes(title_text="Paloma macro loss at IsoFLOP optimum")
    # No in-plot title — the outer frame (set via the shortcode) already
    # names the figure, so we'd just dupe the title.
    fig.update_layout(
        legend=dict(
            orientation="v",
            x=1,
            xanchor="right",
            y=0.925,
            yanchor="top",
            # Match the inner plotly card background from build.py so the
            # legend reads as a subtle inset on the page rather than a
            # floating box.
            bgcolor="#C4B9AE",
            borderwidth=0,
            tracegroupgap=4,
        ),
        margin=dict(t=10, r=40, b=10, l=40),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
    )
    return fig
# Display labels for the two scaling suites. These must equal the values
# in QUANTIZER_KEYWORDS, because parse_quantizer emits those strings and
# the figure builders filter on these constants.
C_ADAMC_LABEL, ADAMH_LABEL = (
    "Initial Scaling Suite (Cautious AdamC)",
    "Delphi Scaling Suite (AdamH)",
)
def build_ladder_figure(
    df: pd.DataFrame,
    fits: dict,
    gflop_groups: list,
    optima: pd.DataFrame,
    optimal_runs: pd.DataFrame | None,
    quantizer_label: str = C_ADAMC_LABEL,
) -> go.Figure:
    """Two-panel ladder figure for a single scaling suite.

    Left: per-compute-bucket IsoFLOP parabolas with optima marked — how
    each compute-optimal point is derived.

    Right: scaling-law fit through those optima, dashed extrapolation
    past the fit region, held-out 1e21/1e22/1e23 validation circles
    (each annotated with its percent deviation from the fit when the fit
    succeeded), and an ✕ marker on any diverged run.

    Used for both section 2 (C-AdamC, where 1e23 diverged) and section 4
    (AdamH/Delphi, where all three held-out runs land on forecast).

    Args:
        df: run table; only isoflop rows of ``quantizer_label`` are drawn
            in the left panel.
        fits: ``{(quantizer, gflops_q): Fit}`` per-bucket parabola fits.
        gflop_groups: compute buckets in GFLOPs; sets left-panel colors.
        optima: per-bucket optima (``quantizer``, ``C_flops``,
            ``loss_star``) for all suites. Only this suite's rows are fit,
            but all quantizers fix the curve color.
        optimal_runs: optional held-out runs (``quantizer``, ``loss``,
            ``total_gflops`` in GFLOPs, optional ``state``); filtered to
            this suite.
        quantizer_label: display name of the suite to plot.

    Returns:
        The two-panel plotly figure.
    """
    q = quantizer_label
    df_iso = df[(df["quantizer"] == q) & (df["run_type"] == "isoflop")]
    optima_q = optima[optima["quantizer"] == q].sort_values("C_flops").copy()
    fig = make_subplots(
        rows=1,
        cols=2,
        horizontal_spacing=0.10,
        subplot_titles=(
            "IsoFLOP parabolas per compute bucket",
            "Scaling law + held-out validation",
        ),
    )
    # --- Left panel: per-bucket parabolas -------------------------------
    g2c = {g: OA_PALETTE[i % len(OA_PALETTE)] for i, g in enumerate(gflop_groups)}
    for g in gflop_groups:
        sub = df_iso[df_iso["gflops_q"] == g]
        if len(sub) == 0:
            continue
        color = g2c[g]
        # g is in GFLOPs; legend labels are in FLOPs.
        label = fmt_sci(g * 1e9)
        fig.add_trace(
            go.Scatter(
                x=sub["tokens"],
                y=sub["loss"],
                mode="markers",
                marker=dict(size=7, color=color),
                name=label,
                legendgroup=label,
                showlegend=True,
                hovertemplate=(
                    "tokens=%{x:.3e}<br>loss=%{y:.4f}<br>"
                    f"bucket={label}<extra></extra>"
                ),
            ),
            row=1,
            col=1,
        )
        key = (q, g)
        if key in fits:
            f = fits[key]
            logT = np.log10(sub["tokens"].to_numpy(float))
            grid = np.linspace(logT.min(), logT.max(), 200)
            yhat = f.a * grid**2 + f.b * grid + f.c
            fig.add_trace(
                go.Scatter(
                    x=10**grid,
                    y=yhat,
                    mode="lines",
                    line=dict(width=2.5, color=color, dash="dash"),
                    name=label,
                    legendgroup=label,
                    showlegend=False,
                    hoverinfo="skip",
                ),
                row=1,
                col=1,
            )
            fig.add_trace(
                go.Scatter(
                    x=[f.T_star],
                    y=[f.loss_star],
                    mode="markers",
                    marker=dict(symbol="x", size=12, color=color, line=dict(width=2)),
                    name=label,
                    legendgroup=label,
                    showlegend=False,
                    hovertemplate=(
                        f"optimum<br>T*={f.T_star:.3e}<br>"
                        f"loss*={f.loss_star:.4f}<extra></extra>"
                    ),
                ),
                row=1,
                col=1,
            )
    # --- Right panel: scaling law + held-out ---------------------------
    per_q_train = {
        q: (
            optima_q["C_flops"].to_numpy(float),
            optima_q["loss_star"].to_numpy(float),
        )
    }
    asym_fit = fit_asymptotic_powerlaw_per_q(per_q_train)
    d = asym_fit[q]
    # Color the curve by q's index in the sorted list of every quantizer in
    # `optima`, the same assignment the cross-optimizer scaling-law figure
    # uses, so each suite keeps one color across figures (C-AdamC sorts
    # second, after the AdamH suite).
    q_order_all = sorted(optima["quantizer"].unique().tolist())
    curve_color = OA_PALETTE[q_order_all.index(q) % len(OA_PALETTE)]
    # Curve spans this suite's optima and always extends to >= 1e23 FLOPs;
    # the x-axis is pinned to that span plus a small log-space pad.
    C_min = float(optima_q["C_flops"].min())
    C_max = max(float(optima_q["C_flops"].max()), 1e23)
    C_grid = np.logspace(np.log10(C_min), np.log10(C_max), 400)
    log_pad = 0.12
    x_lo = 10 ** (np.log10(C_min) - log_pad)
    x_hi = 10 ** (np.log10(C_max) + log_pad)
    # Fit/extrapolation divider at 5e20 FLOPs, drawn on the right panel.
    held_lo = 5e20
    fig.add_shape(
        type="line",
        x0=held_lo,
        x1=held_lo,
        xref="x2",
        y0=0,
        y1=1,
        yref="y2 domain",
        line=dict(color="#1f1e1b", width=1, dash="dot"),
        layer="below",
        opacity=0.45,
    )
    # Annotations on a log axis take log10(data); shapes take raw data.
    held_lo_log = float(np.log10(held_lo))
    fig.add_annotation(
        x=held_lo_log,
        y=0.98,
        xref="x2",
        yref="y2 domain",
        text="fit ← ",
        showarrow=False,
        xanchor="right",
        yanchor="top",
        xshift=-4,
        font=dict(size=12, color="#1f1e1b"),
    )
    fig.add_annotation(
        x=held_lo_log,
        y=0.98,
        xref="x2",
        yref="y2 domain",
        text=" → extrapolation",
        showarrow=False,
        xanchor="left",
        yanchor="top",
        xshift=4,
        font=dict(size=12, color="#1f1e1b"),
    )
    fig.add_trace(
        go.Scatter(
            x=optima_q["C_flops"],
            y=optima_q["loss_star"],
            mode="markers",
            marker=dict(size=9, color=curve_color, symbol="circle"),
            name="IsoFLOP optimum",
            showlegend=False,
            hovertemplate=(
                "IsoFLOP optimum<br>C=%{x:.3e} FLOPs<br>loss*=%{y:.4f}<extra></extra>"
            ),
        ),
        row=1,
        col=2,
    )
    if d["ok"]:
        yhat = d["L_inf"] + d["A"] * (C_grid ** (-d["alpha"]))
        fig.add_trace(
            go.Scatter(
                x=C_grid,
                y=yhat,
                mode="lines",
                line=dict(width=2, color=curve_color, dash="dash"),
                showlegend=False,
                hovertemplate=(
                    "scaling-law fit<br>C=%{x:.3e} FLOPs<br>"
                    "loss=%{y:.4f}<extra></extra>"
                ),
            ),
            row=1,
            col=2,
        )
    # Held-out optimal validation runs (1e21/1e22/1e23) for the selected
    # suite `q` only.
    if optimal_runs is not None and len(optimal_runs):
        opt_all = optimal_runs[optimal_runs["quantizer"] == q].copy()
        opt_all = opt_all.dropna(subset=["loss", "total_gflops"])
        opt_all["C_flops"] = opt_all["total_gflops"] * 1e9
        # Nominal compute = nearest power of ten in log space, so seeds of
        # the same budget collapse onto one point.
        opt_all["C_nom"] = opt_all["C_flops"].apply(
            lambda c: 10.0 ** round(float(np.log10(c)))
        )
        # Without a `state` column every run is treated as finished.
        has_state = "state" in opt_all.columns
        opt_finished = opt_all[opt_all["state"] != "crashed"] if has_state else opt_all
        opt_crashed = (
            opt_all[opt_all["state"] == "crashed"] if has_state else opt_all.iloc[0:0]
        )
        agg = (
            opt_finished.groupby("C_nom")
            .agg(loss_mean=("loss", "mean"), n=("loss", "size"))
            .reset_index()
            .sort_values("C_nom")
        )
        # Percent deviation of each held-out mean from the fit's prediction
        # at its nominal compute (positive = worse than forecast).
        if len(agg) and d["ok"]:
            agg["loss_pred"] = d["L_inf"] + d["A"] * (
                agg["C_nom"].to_numpy(float) ** (-d["alpha"])
            )
            agg["residual"] = agg["loss_mean"] - agg["loss_pred"]
            agg["pct_err"] = 100.0 * agg["residual"] / agg["loss_pred"]
        if len(agg):
            fig.add_trace(
                go.Scatter(
                    x=agg["C_nom"],
                    y=agg["loss_mean"],
                    mode="markers",
                    marker=dict(size=11, color=curve_color, symbol="circle"),
                    showlegend=False,
                    hovertemplate=(
                        "held-out run<br>C=%{x:.3e} FLOPs<br>"
                        "mean loss=%{y:.4f}<extra></extra>"
                    ),
                ),
                row=1,
                col=2,
            )
            if "pct_err" in agg.columns:
                for _, row in agg.iterrows():
                    # Keep the 1e21 label up-right to avoid colliding with
                    # the fit/extrapolation divider; push 1e22/1e23 labels
                    # down-left so they stop overlapping the fit line.
                    if row["C_nom"] <= 1.5e21:
                        ax, ay = 4, -4
                        xanchor, yanchor = "left", "bottom"
                    else:
                        ax, ay = -4, 4
                        xanchor, yanchor = "right", "top"
                    # Both right-panel axes are log, so annotation
                    # coordinates are log10 of the data values.
                    fig.add_annotation(
                        x=float(np.log10(row["C_nom"])),
                        y=float(np.log10(row["loss_mean"])),
                        xref="x2",
                        yref="y2",
                        text=fmt_pct_err(float(row["pct_err"])),
                        showarrow=True,
                        arrowhead=0,
                        arrowwidth=1,
                        arrowcolor="rgba(80,80,80,0.5)",
                        ax=ax,
                        ay=ay,
                        xanchor=xanchor,
                        yanchor=yanchor,
                        font=dict(size=11, color="#1f1e1b"),
                        bgcolor="rgba(0,0,0,0)",
                        borderpad=0,
                    )
        # Crashed runs: ✕ at their last logged loss, plus a "Run Diverged"
        # callout per run.
        if len(opt_crashed):
            fig.add_trace(
                go.Scatter(
                    x=opt_crashed["C_nom"],
                    y=opt_crashed["loss"],
                    mode="markers",
                    marker=dict(
                        size=14,
                        color=curve_color,
                        symbol="x-thin",
                        line=dict(width=3, color=curve_color),
                    ),
                    showlegend=False,
                    hovertemplate=(
                        "diverged mid-run<br>C=%{x:.3e} FLOPs<br>"
                        "final loss=%{y:.4f}<extra></extra>"
                    ),
                ),
                row=1,
                col=2,
            )
            for _, row in opt_crashed.iterrows():
                fig.add_annotation(
                    x=float(np.log10(row["C_nom"])),
                    y=float(np.log10(row["loss"])),
                    xref="x2",
                    yref="y2",
                    text="Run Diverged",
                    showarrow=True,
                    arrowhead=0,
                    arrowwidth=1,
                    arrowcolor="rgba(80,80,80,0.5)",
                    ax=-30,
                    ay=-22,
                    xanchor="center",
                    yanchor="bottom",
                    font=dict(size=11, color="#1f1e1b"),
                    bgcolor="rgba(0,0,0,0)",
                    borderpad=0,
                )
    fig.update_xaxes(type="log", title_text="Training tokens", row=1, col=1)
    fig.update_xaxes(
        type="log",
        title_text="Compute (FLOPs)",
        range=[np.log10(x_lo), np.log10(x_hi)],
        row=1,
        col=2,
    )
    fig.update_yaxes(type="log", title_text="Paloma macro loss", row=1, col=1)
    fig.update_yaxes(
        type="log", title_text="Paloma macro loss at IsoFLOP optimum", row=1, col=2
    )
    # Compute-bucket legend (left-panel buckets only) below the plots,
    # right-aligned, transparent.
    fig.update_layout(
        legend_title_text="Compute (FLOPs)",
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.2,
            xanchor="right",
            x=1.0,
            bgcolor="rgba(0,0,0,0)",
            borderwidth=0,
            tracegroupgap=10,
            entrywidth=0.22,
            entrywidthmode="fraction",
        ),
        margin=dict(t=40, r=30, b=10, l=60),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
    )
    return fig
# ---------------------------------------------------------------------------
# Overtraining-forecast figure (section 5: "How much does under-overtraining hurt?")
# ---------------------------------------------------------------------------
# Token-multiplier steps (T / T*) for the overtraining-forecast slider.
# 1.0 (compute-optimal) is intentionally absent: every step is off-optimum.
DEFAULT_SLIDER_SCALES: tuple[float, ...] = (
    *(0.2, 1 / 3, 0.5, 2 / 3),  # undertrained: T < T*
    *(2.0, 3.0, 4.0, 5.0, 10.0),  # overtrained: T > T*
)
def _overtraining_labels(scale: float) -> tuple[str, str]:
"""Return (slider_step_label, midpoint_label) for a given scale.
Both use the same multiplicative convention so the slider and the
annotation read consistently across the full range. ``:.3g`` keeps
fractional scales like 1/3 or 2/3 readable (0.333×, 0.667×).
"""
step_label = f"{scale:.3g}×"
if scale > 1:
midpoint_label = f"{scale:.3g}× overtrained"
elif scale < 1:
midpoint_label = f"{1 / scale:.3g}× undertrained"
else:
midpoint_label = "compute-optimal"
return step_label, midpoint_label
def build_overtraining_forecast_figure(
    df: pd.DataFrame,
    fits: dict,
    gflop_groups: list,
    optima: pd.DataFrame,
    quantizer_label: str = ADAMH_LABEL,
    forecast_C: float = 1e23,
    default_scale: float = 2.0,
    slider_scales: tuple[float, ...] = DEFAULT_SLIDER_SCALES,
) -> go.Figure:
    """Per-bucket parabolas with a slider that shifts the ○ off-optimum
    marker and re-projects the sub-optimal forecast.
    Static pieces: IsoFLOP data points, per-bucket parabola curves, ×
    compute-optimum markers, the solid optimum regression (drawn out to
    ``forecast_C``), and the × forecast endpoint at ``forecast_C``.
    Animated pieces: the ○ shifted marker on each bucket's parabola, the
    dashed sub-optimal regression (``loss_opt(C) + Δ(C)`` with a per-scale
    power-law penalty fit), the ○ compute-equivalent marker where that
    dashed curve falls back to the forecast loss, the dotted horizontal
    connector from the × forecast endpoint to that marker, and the
    midpoint label. The slider walks ``scale = T/T*`` across
    ``slider_scales``; the initial view uses ``default_scale``.
    Args:
        df: run table; rows with ``quantizer == quantizer_label`` and
            ``run_type == "isoflop"`` provide the scatter points (the
            ``gflops_q``, ``tokens`` and ``loss`` columns are read).
        fits: mapping ``(quantizer, bucket) -> fit``; each fit exposes
            ``T_star``, ``logT_star``, ``loss_star``, ``a_under`` and
            ``a_over`` (a parabola in log10 tokens around the optimum).
        gflop_groups: bucket keys to draw; ``g * 1e9`` is shown as the
            bucket label.
        optima: per-bucket optima with ``quantizer``, ``gflops_bucket`` and
            ``C_flops`` columns; supplies each bucket's compute budget.
        quantizer_label: optimizer to plot.
        forecast_C: compute budget (FLOPs) of the forecast endpoint.
        default_scale: ``T/T*`` shown in the initial (pre-slider) view. If
            it is not one of ``slider_scales``, the slider's active step
            falls back to index 0.
        slider_scales: ``T/T*`` values offered as slider steps / frames.
    Returns:
        The figure, or an empty ``go.Figure()`` when no bucket has both
        IsoFLOP data and a fit, or when the compute-optimal power-law fit
        reports failure.
    """
    q = quantizer_label
    # Only this optimizer's IsoFLOP runs and optima are used below.
    df_iso = df[(df["quantizer"] == q) & (df["run_type"] == "isoflop")]
    optima_q = optima[optima["quantizer"] == q].sort_values("C_flops").copy()
    parabola_color = BRAND_GOLD
    trend_color = BRAND_BLUE
    # --- Per-bucket static data and the compute-axis fits (shared across frames).
    bucket_records: list[dict] = []
    C_list, T_opt_list, loss_opt_list = [], [], []
    a_under_list, a_over_list = [], []
    for g in gflop_groups:
        sub = df_iso[df_iso["gflops_q"] == g]
        # Skip buckets with no IsoFLOP runs or no parabola fit.
        if len(sub) == 0 or (q, g) not in fits:
            continue
        f = fits[(q, g)]
        # NOTE(review): `.iloc[0]` raises IndexError if `optima` lacks a row
        # for a bucket that `fits` does have.
        C = float(optima_q.loc[optima_q["gflops_bucket"] == g, "C_flops"].iloc[0])
        bucket_records.append({"g": g, "f": f, "sub": sub, "C": C})
        C_list.append(C)
        T_opt_list.append(f.T_star)
        loss_opt_list.append(f.loss_star)
        a_under_list.append(f.a_under)
        a_over_list.append(f.a_over)
    if not bucket_records:
        return go.Figure()
    C_arr = np.array(C_list)
    T_opt_arr = np.array(T_opt_list)
    loss_opt_arr = np.array(loss_opt_list)
    a_under_arr = np.array(a_under_list)
    a_over_arr = np.array(a_over_list)
    # Straight-line fit in log-log space: log10 T* ≈ mT · log10 C + kT. It
    # maps any compute budget to its compute-optimal token count.
    mT, kT = np.polyfit(np.log10(C_arr), np.log10(T_opt_arr), 1)
    # Compute-optimal loss law: loss_opt(C) = L_inf + A · C^(-alpha).
    fit_opt = fit_asymptotic_powerlaw_per_q({q: (C_arr, loss_opt_arr)})[q]
    if not fit_opt["ok"]:
        return go.Figure()
    # The compute-optimal curve is drawn only out to forecast_C.
    C_grid = np.logspace(np.log10(C_arr.min()), np.log10(forecast_C), 400)
    T_grid = 10 ** (mT * np.log10(C_grid) + kT)
    loss_opt_grid = fit_opt["L_inf"] + fit_opt["A"] * (C_grid ** (-fit_opt["alpha"]))
    # Tokens and compute-optimal loss at the forecast budget (the × endpoint).
    T_opt_end = float(10 ** (mT * np.log10(forecast_C) + kT))
    y_opt_end = float(
        fit_opt["L_inf"] + fit_opt["A"] * (forecast_C ** (-fit_opt["alpha"]))
    )
    # Extended axis, used only to project the sub-optimal dashed line
    # out to the compute C_equiv where its loss matches y_opt_end.
    # It runs 3 decades past forecast_C.
    C_grid_ext = np.logspace(
        np.log10(C_arr.min()), np.log10(forecast_C) + 3.0, 800
    )
    T_grid_ext = 10 ** (mT * np.log10(C_grid_ext) + kT)
    loss_opt_grid_ext = fit_opt["L_inf"] + fit_opt["A"] * (
        C_grid_ext ** (-fit_opt["alpha"])
    )
    # First extended-grid index whose compute is >= forecast_C.
    forecast_idx_ext = int(np.searchsorted(C_grid_ext, forecast_C))
    # --- Per-scale computation (everything that the slider animates).
    def compute_for_scale(s: float) -> dict:
        # Piecewise curvature around each bucket's vertex —
        # loss = loss_star + a_side · (log10 s)² — so overtraining
        # (s > 1) and undertraining (s < 1) can hurt differently.
        a_side_arr = a_over_arr if s > 1 else a_under_arr
        penalty_arr = a_side_arr * (np.log10(s) ** 2)
        loss_shift_arr = loss_opt_arr + penalty_arr
        # Fit the per-bucket penalty Δ as a power law in compute so it can be
        # extrapolated. The log fit needs every penalty strictly positive.
        # Otherwise (e.g. s == 1) Δ is treated as zero everywhere.
        if np.all(penalty_arr > 1e-9):
            p_slope, p_int = np.polyfit(
                np.log10(C_arr), np.log10(penalty_arr), 1
            )
            penalty_grid_ext = 10 ** (p_slope * np.log10(C_grid_ext) + p_int)
            penalty_end = float(10 ** (p_slope * np.log10(forecast_C) + p_int))
        else:
            penalty_grid_ext = np.zeros_like(C_grid_ext)
            penalty_end = 0.0
        # The sub-optimal run trains on s× the compute-optimal token count.
        T_shift_grid_full = T_grid_ext * s
        loss_shift_grid_full = loss_opt_grid_ext + penalty_grid_ext
        # NOTE(review): T_shift_end / y_shift_end are returned below, but no
        # trace or frame in this function reads them.
        T_shift_end = T_opt_end * s
        y_shift_end = y_opt_end + penalty_end
        # Walk the extended dashed line from forecast_C onward until its
        # loss falls to y_opt_end; the gap between forecast_C and that
        # crossing is the compute overhead. If the sub-optimal penalty
        # grows faster than the compute-optimal curve shrinks, no
        # crossing exists — fall back to the end of the extended grid.
        beyond = np.arange(len(C_grid_ext)) >= forecast_idx_ext
        crossed = beyond & (loss_shift_grid_full <= y_opt_end)
        if crossed.any():
            # argmax on a bool mask returns the first True: the first crossing.
            cross_idx = int(np.argmax(crossed))
            if cross_idx > 0:
                # Linear interpolation in log10 C between the two grid points
                # bracketing the crossing. `lo` can sit just below forecast_C
                # when the crossing is the first point past it, so
                # overhead_pct below can come out slightly negative.
                lo, hi = cross_idx - 1, cross_idx
                y_lo = loss_shift_grid_full[lo]
                y_hi = loss_shift_grid_full[hi]
                if y_lo != y_hi:
                    frac = (y_lo - y_opt_end) / (y_lo - y_hi)
                else:
                    frac = 0.0
                logC_cross = (
                    np.log10(C_grid_ext[lo])
                    + frac * (np.log10(C_grid_ext[hi]) - np.log10(C_grid_ext[lo]))
                )
                C_equiv = float(10**logC_cross)
                T_cross = float(10 ** (mT * logC_cross + kT) * s)
            else:
                # Crossing at grid index 0: no left neighbour to interpolate from.
                C_equiv = float(C_grid_ext[cross_idx])
                T_cross = float(T_shift_grid_full[cross_idx])
            # The dashed line is drawn up to and including the crossing point.
            end_idx = cross_idx + 1
        else:
            C_equiv = float(C_grid_ext[-1])
            T_cross = float(T_shift_grid_full[-1])
            end_idx = len(C_grid_ext)
        # Prefixes of the extended grid, so index i still pairs with
        # C_grid_ext[i] (the dashed trace's customdata relies on this).
        T_shift_grid = T_shift_grid_full[:end_idx]
        loss_shift_grid = loss_shift_grid_full[:end_idx]
        overhead_pct = 100.0 * (C_equiv - forecast_C) / forecast_C
        # Negative values already carry "-" from the format; add "+" otherwise.
        sign = "+" if overhead_pct >= 0 else ""
        excess_label = f"{sign}{overhead_pct:.1f}% compute to match"
        _, overtrain_label = _overtraining_labels(s)
        # Midpoint for the label sits along the horizontal compute-gap
        # connector at y_opt_end, between the × at forecast_C and the
        # crossing marker at C_equiv.
        # The geometric mean gives the visual midpoint on the log x-axis.
        x_mid = float(np.sqrt(T_opt_end * max(T_cross, T_opt_end)))
        y_mid = y_opt_end
        return dict(
            shifted_xs=(T_opt_arr * s).tolist(),
            shifted_ys=loss_shift_arr.tolist(),
            T_shift_grid=T_shift_grid,
            loss_shift_grid=loss_shift_grid,
            T_shift_end=T_shift_end,
            y_shift_end=y_shift_end,
            T_cross=T_cross,
            C_equiv=C_equiv,
            overtrain_label=overtrain_label,
            excess_label=excess_label,
            x_mid=x_mid,
            y_mid=y_mid,
        )
    # Precompute every slider stop plus the default view (which may not be a stop).
    all_scales = sorted(set(slider_scales) | {default_scale})
    frames_data = {s: compute_for_scale(s) for s in all_scales}
    base = frames_data[default_scale]
    fig = go.Figure()
    # --- Static traces ------------------------------------------------
    for rec in bucket_records:
        f = rec["f"]
        sub = rec["sub"]
        label = fmt_sci(rec["g"] * 1e9)
        fig.add_trace(
            go.Scatter(
                x=sub["tokens"],
                y=sub["loss"],
                mode="markers",
                marker=dict(size=6, color=parabola_color, opacity=0.55),
                showlegend=False,
                hovertemplate=(
                    "tokens=%{x:.3e}<br>loss=%{y:.4f}<br>"
                    f"bucket={label}<extra></extra>"
                ),
            ),
        )
        # Piecewise parabola curve only within the observed data range.
        # Separate curvature on each side of the vertex so the drawn
        # curve matches the asymmetric penalty used by the slider. The
        # ○ marker can float past the curve at extreme slider values —
        # quadratic extrapolation is unreliable out there and would
        # swamp the plot's y-range if drawn.
        logT = np.log10(sub["tokens"].to_numpy(float))
        grid_lo = logT.min() - 0.02
        grid_hi = logT.max() + 0.02
        g_vals = np.linspace(grid_lo, grid_hi, 200)
        dx = g_vals - f.logT_star
        y_vals = f.loss_star + np.where(dx < 0, f.a_under, f.a_over) * dx**2
        fig.add_trace(
            go.Scatter(
                x=10**g_vals,
                y=y_vals,
                mode="lines",
                line=dict(width=1.5, color=parabola_color, dash="dash"),
                showlegend=False,
                hoverinfo="skip",
            ),
        )
        fig.add_trace(
            go.Scatter(
                x=[f.T_star],
                y=[f.loss_star],
                mode="markers",
                marker=dict(
                    symbol="x", size=12, color=parabola_color, line=dict(width=2)
                ),
                showlegend=False,
                hovertemplate=(
                    f"optimum<br>bucket={label}<br>T*={f.T_star:.3e}<br>"
                    f"loss*={f.loss_star:.4f}<extra></extra>"
                ),
            ),
        )
    fig.add_trace(
        go.Scatter(
            x=T_grid,
            y=loss_opt_grid,
            mode="lines",
            line=dict(width=2, color=trend_color),
            name="Fit through optima",
            customdata=C_grid,
            hovertemplate=(
                "compute-optimal regression<br>"
                "C=%{customdata:.2e} FLOPs<br>"
                "T=%{x:.3e}<br>loss=%{y:.4f}<extra></extra>"
            ),
        ),
    )
    fig.add_trace(
        go.Scatter(
            x=[T_opt_end],
            y=[y_opt_end],
            mode="markers",
            marker=dict(
                symbol="x", size=12, color=trend_color, line=dict(width=2)
            ),
            showlegend=False,
            hovertemplate=(
                f"forecast @ {fmt_sci(forecast_C)} FLOPs<br>"
                "T=%{x:.3e}<br>loss=%{y:.4f}<extra></extra>"
            ),
        ),
    )
    # --- Animated traces (recorded in a known order for frame updates) -
    # Each frame's `data` list below must follow this exact order, because
    # go.Frame(traces=animated_indices) matches them positionally.
    animated_indices: list[int] = []
    # ○ shifted marker per bucket
    for i, rec in enumerate(bucket_records):
        animated_indices.append(len(fig.data))
        label = fmt_sci(rec["g"] * 1e9)
        fig.add_trace(
            go.Scatter(
                x=[base["shifted_xs"][i]],
                y=[base["shifted_ys"][i]],
                mode="markers",
                marker=dict(
                    symbol="circle-open",
                    size=11,
                    color=parabola_color,
                    line=dict(width=2.5, color=parabola_color),
                ),
                showlegend=False,
                hovertemplate=(
                    f"shifted<br>bucket={label}<br>"
                    "T=%{x:.3e}<br>loss=%{y:.4f}<extra></extra>"
                ),
            ),
        )
    # Dashed sub-optimal regression
    animated_indices.append(len(fig.data))
    fig.add_trace(
        go.Scatter(
            x=base["T_shift_grid"],
            y=base["loss_shift_grid"],
            mode="lines",
            line=dict(width=2, color=trend_color, dash="dash"),
            name="Sub-optimal fit",
            # Full extended grid: plotly pairs customdata by index, and x/y
            # are prefixes of the same grid, so the extra tail is ignored.
            customdata=C_grid_ext,
            hovertemplate=(
                "sub-optimal regression<br>"
                "C=%{customdata:.2e} FLOPs<br>"
                "T=%{x:.3e}<br>loss=%{y:.4f}<extra></extra>"
            ),
        ),
    )
    # ○ compute-equivalent endpoint — where the extended sub-optimal
    # forecast reaches the compute-optimal loss at forecast_C. Horizontal
    # position encodes the extra tokens (and, via the power-law T↔C
    # mapping, the extra compute) required to match.
    animated_indices.append(len(fig.data))
    fig.add_trace(
        go.Scatter(
            x=[base["T_cross"]],
            y=[y_opt_end],
            mode="markers",
            marker=dict(
                symbol="circle-open",
                size=11,
                color=trend_color,
                line=dict(width=2.5, color=trend_color),
            ),
            showlegend=False,
            customdata=[[base["C_equiv"]]],
            hovertemplate=(
                "compute to match forecast loss<br>"
                "C=%{customdata[0]:.2e} FLOPs<br>"
                "T=%{x:.3e}<extra></extra>"
            ),
        ),
    )
    # Horizontal dotted connector at y_opt_end from × (forecast_C) out to
    # the compute-equivalent crossing — the visual span is the compute
    # overhead.
    animated_indices.append(len(fig.data))
    fig.add_trace(
        go.Scatter(
            x=[T_opt_end, base["T_cross"]],
            y=[y_opt_end, y_opt_end],
            mode="lines",
            line=dict(width=1.5, color=trend_color, dash="dot"),
            showlegend=False,
            hoverinfo="skip",
        ),
    )
    # --- Midpoint annotation (rebuilt per frame) ---------------------
    def build_annotation(d: dict) -> dict:
        # Both axes are log-typed, so data-referenced annotation coordinates
        # are given in log10 units.
        return dict(
            x=float(np.log10(d["x_mid"])),
            y=float(np.log10(d["y_mid"])),
            xref="x",
            yref="y",
            text=f"{d['overtrain_label']}<br>{d['excess_label']}",
            showarrow=False,
            xanchor="left",
            yanchor="top",
            xshift=6,
            yshift=-6,
            font=dict(size=11, color="#1f1e1b"),
            bgcolor="rgba(0,0,0,0)",
            borderwidth=0,
            align="left",
        )
    # --- Build frames -------------------------------------------------
    # The 10× overtrained crossing lands well past the default 10T upper
    # bound — widen the x-axis only on that frame so the label and marker
    # stay in view.
    # Ranges are log10(tokens) because the x-axis is log-typed.
    default_x_range = [8.5, 13.0]
    wide_x_range = [8.5, np.log10(5e13)]
    def x_range_for(s: float) -> list[float]:
        return wide_x_range if s >= 10.0 else default_x_range
    frames: list[go.Frame] = []
    for s in slider_scales:
        d = frames_data[s]
        frame_data = []
        # ○ shifted markers per bucket
        for i in range(len(bucket_records)):
            frame_data.append(dict(x=[d["shifted_xs"][i]], y=[d["shifted_ys"][i]]))
        # Dashed sub-optimal regression (extended to the crossing)
        frame_data.append(dict(x=d["T_shift_grid"], y=d["loss_shift_grid"]))
        # ○ compute-equivalent endpoint
        frame_data.append(
            dict(x=[d["T_cross"]], y=[y_opt_end], customdata=[[d["C_equiv"]]])
        )
        # Horizontal dotted connector at y_opt_end
        frame_data.append(
            dict(x=[T_opt_end, d["T_cross"]], y=[y_opt_end, y_opt_end])
        )
        frames.append(
            go.Frame(
                name=str(s),
                data=frame_data,
                traces=animated_indices,
                layout=go.Layout(
                    annotations=[build_annotation(d)],
                    xaxis=dict(range=x_range_for(s)),
                ),
            )
        )
    fig.frames = frames
    # --- Slider + layout ---------------------------------------------
    # If default_scale is not a slider stop, the slider starts at step 0
    # even though the initial traces show default_scale.
    try:
        active_idx = list(slider_scales).index(default_scale)
    except ValueError:
        active_idx = 0
    slider_steps = []
    for s in slider_scales:
        step_label, _ = _overtraining_labels(s)
        slider_steps.append(
            dict(
                method="animate",
                # Frame names are str(s), matching go.Frame(name=str(s)) above.
                args=[
                    [str(s)],
                    dict(
                        frame=dict(duration=0, redraw=True),
                        mode="immediate",
                        transition=dict(duration=0),
                    ),
                ],
                label=step_label,
            )
        )
    fig.update_xaxes(
        type="log", title_text="Training tokens", range=x_range_for(default_scale)
    )
    fig.update_yaxes(type="log", title_text="Paloma macro loss")
    fig.update_layout(
        sliders=[
            dict(
                active=active_idx,
                currentvalue=dict(prefix="Overtraining: ", font=dict(size=12)),
                pad=dict(t=30, l=30, r=30),
                steps=slider_steps,
                len=0.9,
                x=0.05,
                y=-0.12,
            )
        ],
        annotations=[build_annotation(base)],
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1.0,
            bgcolor="rgba(0,0,0,0)",
            borderwidth=0,
            tracegroupgap=10,
        ),
        margin=dict(t=60, r=30, b=100, l=60),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
    )
    return fig
# ---------------------------------------------------------------------------
# Lucky-seeds figure (section 5: "Did we get lucky?")
# ---------------------------------------------------------------------------
def _hex_to_rgba(hex_color: str, alpha: float) -> str:
    """Convert a hex colour to a plotly ``rgba(r,g,b,alpha)`` string.
    Accepts ``#RRGGBB`` and the CSS shorthand ``#RGB``; the leading ``#``
    is optional. For ``#RRGGBBAA`` the embedded alpha byte is ignored and
    the explicit ``alpha`` argument wins, as before.
    Args:
        hex_color: colour such as ``"#8F6B38"`` or ``"#fff"``.
        alpha: opacity written verbatim into the output string.
    Returns:
        e.g. ``"rgba(143,107,56,0.35)"``.
    Raises:
        ValueError: if the colour is not 3, 6 or 8 hex digits. Previously,
            5- or 7-digit inputs were silently mis-parsed.
    """
    digits = hex_color.strip().lstrip("#")
    if len(digits) == 3:
        # CSS shorthand: each nibble is doubled ("f" -> "ff").
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(
            f"expected a #RGB, #RRGGBB or #RRGGBBAA colour, got {hex_color!r}"
        )
    r, g, b = (int(digits[i : i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r},{g},{b},{alpha})"
def _bootstrap_scaling_fit_ci(
    C: np.ndarray,
    y: np.ndarray,
    target_flops: list[float],
    n_boot: int = 1000,
    rng_seed: int = 0,
    ci_levels: tuple[float, float] = (2.5, 97.5),
) -> dict[float, tuple[float, float, float]]:
    """Bootstrap the scaling-law forecast at each target budget.
    Resamples the per-bucket AdamH IsoFLOP optima with replacement,
    refits the shared (L_inf, A, alpha) form used by the point forecast,
    and returns the (low, median, high) percentiles of the forecast at
    each target budget in absolute Paloma macro loss. Bucket-level
    resampling is the right unit: each bucket contributes one measured
    compute-optimal point, while the within-bucket (N, D) sweep is a
    deliberate experimental design, not an IID sample. Resampling
    inside a bucket would fake degrees of freedom.
    A resample is skipped if it has fewer than three distinct budgets or
    if its refit reports ``ok == False``. A target whose every resample
    was skipped maps to ``(nan, nan, nan)``. When the input itself has
    fewer than three distinct budgets, no resample can ever qualify, so
    that all-NaN result is returned immediately. This also covers empty
    input.
    Args:
        C: compute budget (FLOPs) of each bucket optimum.
        y: loss at each optimum; must be the same length as ``C``.
        target_flops: budgets at which to evaluate the forecast.
        n_boot: number of bootstrap resamples.
        rng_seed: seed for ``np.random.default_rng``.
        ci_levels: lower/upper percentiles to report.
    Raises:
        ValueError: if ``C`` and ``y`` differ in length.
    """
    C = np.asarray(C, dtype=float)
    y = np.asarray(y, dtype=float)
    if len(C) != len(y):
        raise ValueError(
            f"C and y must have the same length, got {len(C)} and {len(y)}"
        )
    nan_triple = (float("nan"), float("nan"), float("nan"))
    # A resample can never have more distinct budgets than its source, so
    # the per-resample check below would reject every draw. Short-circuit
    # (this also avoids rng.integers(0, 0) raising on empty input).
    if len(np.unique(C)) < 3:
        return {t: nan_triple for t in target_flops}
    rng = np.random.default_rng(rng_seed)
    n = len(C)
    forecasts_by_target: dict[float, list[float]] = {t: [] for t in target_flops}
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)
        Cb, yb = C[idx], y[idx]
        # The three-parameter refit needs at least three distinct budgets.
        if len(np.unique(Cb)) < 3:
            continue
        fit = fit_asymptotic_powerlaw_per_q({"b": (Cb, yb)}, n_grid=100)["b"]
        if not fit["ok"]:
            continue
        for t in target_flops:
            forecasts_by_target[t].append(
                fit["L_inf"] + fit["A"] * (t ** (-fit["alpha"]))
            )
    result: dict[float, tuple[float, float, float]] = {}
    for t, forecasts in forecasts_by_target.items():
        if not forecasts:
            result[t] = nan_triple
        else:
            a = np.array(forecasts)
            result[t] = (
                float(np.percentile(a, ci_levels[0])),
                float(np.percentile(a, 50)),
                float(np.percentile(a, ci_levels[1])),
            )
    return result
def build_lucky_seeds_figure(
    df: pd.DataFrame, optima: pd.DataFrame
) -> go.Figure:
    """Plot bootstrap forecast uncertainty against the spread of real runs.
    There is one horizontal row per held-out budget (1e21, 1e22 and 1e23
    FLOPs). Each row shades the 2.5–97.5% bootstrap interval of the
    ``ADAMH_LABEL`` scaling-law forecast, from ``_bootstrap_scaling_fit_ci``
    over the per-bucket optima. On top it draws every
    ``run_type == "optimal"`` run at that budget as an open circle. The
    x-axis is percent deviation from the point forecast, so the three
    budgets share one scale. 1e23 is kept even with a single seed: the
    bootstrap-CI comparison does not depend on seed replication.
    Returns an empty ``go.Figure()`` when the point-forecast fit fails.
    """
    quantizer = ADAMH_LABEL
    train = optima[optima["quantizer"] == quantizer].sort_values("C_flops")
    C_train = train["C_flops"].to_numpy(float)
    y_train = train["loss_star"].to_numpy(float)
    fit = fit_asymptotic_powerlaw_per_q({quantizer: (C_train, y_train)})[quantizer]
    if not fit["ok"]:
        return go.Figure()
    def forecast(budget: float) -> float:
        # Point forecast L_inf + A * C^(-alpha) at one budget.
        return fit["L_inf"] + fit["A"] * (budget ** (-fit["alpha"]))
    def rel_pct(value, ref: float):
        # Percent deviation of value from ref; works for scalars and arrays.
        return 100.0 * (value - ref) / ref
    def nominal_budget(flops: float) -> float:
        # Snap measured FLOPs to the nearest power of ten (1.3e22 -> 1e22).
        return float(f"1e{round(float(np.log10(flops)))}")
    budgets = [1e21, 1e22, 1e23]
    ci = _bootstrap_scaling_fit_ci(C_train, y_train, budgets, n_boot=1000)
    is_wanted = (df["quantizer"] == quantizer) & (df["run_type"] == "optimal")
    runs = df[is_wanted].copy().dropna(subset=["loss", "total_gflops"])
    runs["C_flops"] = runs["total_gflops"] * 1e9
    runs["C_nom"] = runs["C_flops"].apply(nominal_budget)
    runs["seed"] = runs["name"].apply(parse_seed)
    # Colours follow this optimizer's position in the sorted optimizer list.
    palette_slot = sorted(optima["quantizer"].unique().tolist()).index(quantizer)
    seed_color = OA_PALETTE[palette_slot % len(OA_PALETTE)]
    ci_color = OA_PALETTE[(palette_slot + 1) % len(OA_PALETTE)]
    def seed_marker() -> dict:
        # Fresh dict per trace: open circle in the seed colour.
        return dict(
            size=7,
            color=seed_color,
            symbol="circle-open",
            line=dict(color=seed_color, width=1.5),
        )
    rows = []
    for budget in budgets:
        centre = forecast(budget)
        ci_lo, _, ci_hi = ci[budget]
        at_budget = runs[runs["C_nom"] == budget].sort_values("seed")
        observed = at_budget["loss"].to_numpy(float)
        rows.append(
            dict(
                C_nom=budget,
                point_pred=centre,
                ci_low_pct=rel_pct(ci_lo, centre),
                ci_high_pct=rel_pct(ci_hi, centre),
                seed_pcts=rel_pct(observed, centre).tolist(),
                seed_labels=at_budget["seed"].tolist(),
            )
        )
    n_rows = len(rows)
    top_edge = n_rows - 0.5
    fig = go.Figure()
    # Dashed vertical reference at 0%: the point forecast itself.
    fig.add_shape(
        type="line",
        x0=0,
        x1=0,
        y0=-0.5,
        y1=top_edge,
        line=dict(color=BRAND_BLACK, dash="dash", width=1.5),
        layer="below",
    )
    fig.add_annotation(
        x=0,
        y=top_edge,
        text="point forecast",
        showarrow=False,
        xanchor="center",
        yanchor="bottom",
        yshift=4,
        font=dict(size=11, color="#1f1e1b"),
    )
    band_fill = _hex_to_rgba(ci_color, 0.35)
    for y_pos, row in enumerate(rows):
        fig.add_shape(
            type="rect",
            x0=row["ci_low_pct"],
            x1=row["ci_high_pct"],
            y0=y_pos - 0.3,
            y1=y_pos + 0.3,
            fillcolor=band_fill,
            line=dict(width=0),
            layer="below",
        )
        deviations = row["seed_pcts"]
        if not deviations:
            continue
        fig.add_trace(
            go.Scatter(
                x=deviations,
                y=[y_pos] * len(deviations),
                mode="markers",
                marker=seed_marker(),
                showlegend=False,
                text=row["seed_labels"],
                hovertemplate=(
                    f"C={fmt_sci(row['C_nom'])}<br>"
                    "seed=%{text}<br>"
                    "deviation=%{x:.2f}%<extra></extra>"
                ),
            )
        )
    # Invisible proxy traces that exist only to populate the legend.
    legend_entries = (
        (
            dict(size=14, color=_hex_to_rgba(ci_color, 0.55), symbol="square"),
            "Forecast 2.5–97.5% CI (bootstrap over bucket optima)",
        ),
        (seed_marker(), "Independent Training Runs"),
    )
    for proxy_marker, proxy_name in legend_entries:
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                marker=proxy_marker,
                name=proxy_name,
                showlegend=True,
            )
        )
    fig.update_yaxes(
        tickmode="array",
        tickvals=list(range(n_rows)),
        ticktext=[f"C = {fmt_sci(row['C_nom'])} FLOPs" for row in rows],
        range=[-0.5, top_edge],
        showgrid=False,
    )
    fig.update_xaxes(
        title_text="Loss − point forecast (% of point forecast)",
        ticksuffix="%",
        zeroline=False,
    )
    fig.update_layout(
        margin=dict(t=40, r=30, b=90, l=90),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.35,
            xanchor="center",
            x=0.5,
            bgcolor="rgba(0,0,0,0)",
            borderwidth=0,
            tracegroupgap=10,
        ),
    )
    return fig
# ---------------------------------------------------------------------------
# Hyperparameter sweep (section 3 sanity-check figure)
# ---------------------------------------------------------------------------
# Sweep eval metric matches the IsoFLOP ladder (paloma macro loss) so the
# section-3 figure shares its y-space with the scaling-law figure.
SWEEP_METRIC = "eval/paloma/macro_loss"
# Scaling reference point from experiments/scaling_law_sweeps/completed_adamh.py.
# Keep these in sync with CompletedAdamHHeuristic if the recipe is retuned.
SWEEP_SEQ_LEN = 4096
REF_BATCH = 64
REF_TOKENS = 2.5e9
LR_BASE = 0.00630
ADAM_LR_BASE = 0.000656
EPSILON_BASE = 1.85e-08
BETA2_BASE = 0.9999
BETA2_MIN = 0.9
BETA2_MAX = 0.9999
# Sweep grid: the hparam figure fits one GP per
# (batch size, hidden dim, token horizon) cell.
SWEEP_BATCH_SIZES = (8, 16, 32, 64)
SWEEP_HIDDEN_DIMS = (256, 512)
# (nominal token count, bucket label); labels must match bucket_tokens().
SWEEP_TOKEN_BUCKETS = ((1e9, "1B"), (2.5e9, "2.5B"), (5e9, "5B"))
# (column key, human label, natural-log x?, pre-log transform)
# The boolean controls the hparam (x) axis, not the loss axis: when set,
# values are mapped through ln(·) before GP fitting and plotting, and the
# tick labels show the untransformed value. The optional transform is
# applied before the log — use it for quantities like β₁/β₂ where the
# natural scale is ln(1 − β).
SWEEP_HPARAMS = (
    ("lr", "Hyperball LR", True, None),
    ("adam_lr", "Adam LR", True, None),
    ("epsilon", "AdamH ε", True, None),
    ("beta1", "β₁", True, lambda x: 1.0 - x),
    ("beta2", "β₂", True, lambda x: 1.0 - x),
)
def bucket_tokens(t: float) -> str:
    """Snap a run's token count to its nominal sweep-horizon label.
    Counts below 1.75e9 map to ``"1B"`` and counts below 3.0e9 map to
    ``"2.5B"``. Everything else, including NaN, maps to ``"5B"``.
    """
    tokens = float(t)
    for upper_bound, label in ((1.75e9, "1B"), (3.0e9, "2.5B")):
        if tokens < upper_bound:
            return label
    return "5B"
def load_sweep_runs(cache_path: Path) -> pd.DataFrame | None:
"""Load the AdamH Vizier-sweep cache. Returns None if absent.
Missing cache is not fatal — the section-3 sanity-check figure is
gated on this file but the IsoFLOP figures should still build.
"""
if not cache_path.exists():
print(
f"sweep cache not found at {cache_path} — skipping hparam-scaling figure. "
"Run `uv run scripts/delphi/fetch_data.py` to populate it."
)
return None
rows = json.loads(cache_path.read_text())
if not rows:
print(f"sweep cache {cache_path} is empty — skipping hparam-scaling figure.")
return None
df = pd.DataFrame(rows)
need = ["state", SWEEP_METRIC, "batch_size", "hidden_dim", "tokens"]
hparam_cols = [c for c, *_ in SWEEP_HPARAMS]
missing = [c for c in need + hparam_cols if c not in df.columns]
if missing:
print(
f"sweep cache missing columns {missing} — skipping hparam-scaling figure."
)
return None
df = df[df["state"] == "finished"].dropna(subset=need + hparam_cols).copy()
if df.empty:
print("sweep cache has no finished runs — skipping hparam-scaling figure.")
return None
df["batch_size"] = df["batch_size"].round().astype(int)
df["hidden_dim"] = df["hidden_dim"].round().astype(int)
df["token_bucket"] = df["tokens"].apply(bucket_tokens)
return df
def compute_sweep_optima(df: pd.DataFrame) -> pd.DataFrame:
    """Keep only the best (lowest ``SWEEP_METRIC``) run of each sweep cell.
    A cell is one ``(batch_size, hidden_dim, token_bucket)`` combination.
    The returned frame has one row per cell and a fresh 0..n-1 index.
    """
    cell_keys = ["batch_size", "hidden_dim", "token_bucket"]
    winner_labels = df.groupby(cell_keys)[SWEEP_METRIC].idxmin()
    winners = df.loc[winner_labels]
    return winners.reset_index(drop=True)
def theory_hparam(key: str, batch_size: int, tokens: float) -> float:
    """Return the CompletedAdamHeuristic-prescribed value of ``key``.
    Formulas mirror completed_adamh.py. They are kept as a local copy (not
    an import) so the export script stays free of marin/levanter deps.
    Raises:
        ValueError: for an unrecognised ``key``.
    """
    # r/r0 = (B * T0) / (B0 * T); seq_len cancels. Computed up front for
    # every key, exactly as in the reference recipe.
    ratio = (batch_size * REF_TOKENS) / (REF_BATCH * tokens)
    rules = {
        "lr": lambda: (
            LR_BASE * math.sqrt(batch_size / REF_BATCH) * (REF_TOKENS / tokens) ** 0.3
        ),
        "adam_lr": lambda: ADAM_LR_BASE * math.sqrt(ratio),
        "epsilon": lambda: EPSILON_BASE * math.sqrt(1.0 / ratio),
        # beta1 is fixed, not scaled.
        "beta1": lambda: 0.9,
        "beta2": lambda: max(
            BETA2_MIN, min(BETA2_MAX, BETA2_BASE ** (batch_size / REF_BATCH))
        ),
    }
    if key not in rules:
        raise ValueError(f"unknown hparam key: {key}")
    return rules[key]()
def _gp_pdp(
gp: object,
X: np.ndarray,
feature_idx: int,
grid_resolution: int = 80,
include_point: float | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Compute PDP with confidence band from a fitted GP.
For each grid value of ``feature_idx``, replaces that column in every
row of X, predicts with the GP, and averages across the rows —
following the Monte Carlo PDP definition from
Friedman (2001) / Moosbauer et al. (2021).
If ``include_point`` is provided, expand the grid so it also covers
that value. This is useful when we want to show a theory point even
when it lies outside the observed sweep support.
Returns (x_grid, pdp_mean, pdp_std).
"""
col = X[:, feature_idx]
x_lo = float(col.min())
x_hi = float(col.max())
if include_point is not None:
x_lo = min(x_lo, float(include_point))
x_hi = max(x_hi, float(include_point))
x_grid = np.linspace(x_lo, x_hi, grid_resolution)
means = np.empty((grid_resolution, X.shape[0]))
pdp_var = np.empty(grid_resolution)
for g_idx, g_val in enumerate(x_grid):
X_mod = X.copy()
X_mod[:, feature_idx] = g_val
mu, cov = gp.predict(X_mod, return_cov=True)
means[g_idx] = mu
# Var(mean(f)) = 1^T Cov(f) 1 / n^2, which is the mean entry of Cov(f).
pdp_var[g_idx] = max(float(cov.mean()), 0.0)
pdp_mean = means.mean(axis=1)
pdp_std = np.sqrt(pdp_var)
return x_grid, pdp_mean, pdp_std
def _format_axis_value(x: float) -> str:
    """Format a tick value compactly for an axis label.
    Zero prints as ``"0"``. Magnitudes strictly between 0.9 and 1 (e.g.
    β values like 0.9995) keep five significant digits so neighbouring
    ticks stay distinguishable. Magnitudes in [0.01, 1000) use three
    significant digits, and anything else is delegated to ``fmt_sci``.
    """
    value = float(x)
    if value == 0:
        return "0"
    magnitude = abs(value)
    if 0.9 < magnitude < 1:
        spec = ".5g"
    elif 1e-2 <= magnitude < 1e3:
        spec = ".3g"
    else:
        return fmt_sci(value)
    return format(value, spec)
def _log_axis_ticks(x_lo: float, x_hi: float) -> tuple[list[float], list[str]]:
    """Pick readable ticks for an axis plotted in natural-log space.
    ``x_lo``/``x_hi`` are ln(value) bounds. Ticks are preferably drawn
    from the 1-2-5 ladder (1, 2 and 5 × 10^k) inside the value range. If
    fewer than three ladder values fit, four geometrically spaced values
    are used instead. At most five ticks are kept, spread evenly over the
    candidates.
    Returns ``(positions in ln space, labels of the untransformed values)``.
    Returns two empty lists when either bound maps to a non-finite or
    non-positive value.
    """
    lo_val = float(np.exp(x_lo))
    hi_val = float(np.exp(x_hi))
    usable = (
        np.isfinite(lo_val) and np.isfinite(hi_val) and lo_val > 0 and hi_val > 0
    )
    if not usable:
        return [], []
    max_ticks = 5
    first_decade = int(np.floor(np.log10(lo_val))) - 1
    last_decade = int(np.ceil(np.log10(hi_val))) + 1
    ladder = [
        mant * (10.0**decade)
        for decade in range(first_decade, last_decade + 1)
        for mant in (1.0, 2.0, 5.0)
    ]
    values = [v for v in ladder if lo_val <= v <= hi_val]
    if len(values) < 3:
        values = np.geomspace(lo_val, hi_val, 4).tolist()
    values = sorted({float(v) for v in values})
    if len(values) > max_ticks:
        slots = np.linspace(0, len(values) - 1, max_ticks)
        values = sorted({float(values[int(round(slot))]) for slot in slots})
    positions = [float(np.log(v)) for v in values]
    labels = [_format_axis_value(v) for v in values]
    return positions, labels
def _beta_axis_ticks(x_lo: float, x_hi: float) -> tuple[list[float], list[str]]:
    """Pick readable β ticks for an axis plotted as ln(1 − β).
    ``x_lo``/``x_hi`` are ln(1 − β) bounds; labels show β itself. Ticks
    come from a fixed list of familiar β values (0.5 … 0.99998). If fewer
    than three fall in range, four values evenly spaced in log(1 − β) are
    used instead. At most five are kept.
    Returns ``(positions, labels)``, or two empty lists for non-finite
    bounds.
    """
    # ln(1 - β) falls as β rises, so the upper x bound gives the lower β.
    beta_lo = float(1.0 - np.exp(x_hi))
    beta_hi = float(1.0 - np.exp(x_lo))
    if not (np.isfinite(beta_lo) and np.isfinite(beta_hi)):
        return [], []
    beta_cap = 1.0 - 1e-12
    beta_lo = max(0.0, min(beta_lo, beta_cap))
    beta_hi = max(0.0, min(beta_hi, beta_cap))
    familiar_betas = (
        0.5,
        0.8,
        0.9,
        0.95,
        0.98,
        0.99,
        0.995,
        0.998,
        0.999,
        0.9995,
        0.9998,
        0.9999,
        0.99995,
        0.99998,
    )
    ticks = [b for b in familiar_betas if beta_lo <= b <= beta_hi]
    if len(ticks) < 3:
        gaps = np.geomspace(
            max(1.0 - beta_hi, 1e-12), max(1.0 - beta_lo, 1e-12), 4
        )
        # gaps ascend in (1 - β); reversing yields ascending β.
        ticks = [float(1.0 - gap) for gap in reversed(gaps)]
    max_ticks = 5
    if len(ticks) > max_ticks:
        slots = np.linspace(0, len(ticks) - 1, max_ticks)
        ticks = [ticks[int(round(slot))] for slot in slots]
    ticks = sorted({float(b) for b in ticks})
    positions = [float(np.log(max(1.0 - b, 1e-12))) for b in ticks]
    labels = [_format_axis_value(b) for b in ticks]
    return positions, labels
def build_hparam_sweep_figure(sweep_df: pd.DataFrame) -> go.Figure:
"""Partial-dependence figure: per-hparam sensitivity via GP + PDP.
For each (batch_size, hidden_dim, token_bucket) cell, fits a Gaussian
process on the transformed hparams → loss, then computes PDP with
confidence bands following Moosbauer et al. (NeurIPS 2021). Each
subplot shows the PDP curve ± 1 σ band, the raw Vizier scatter, and
the CompletedAdamH-prescribed value marked when it lies within the
sweep support used to fit the GP.
"""
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, WhiteKernel
from sklearn.preprocessing import StandardScaler
n_hparams = len(SWEEP_HPARAMS)
hparam_specs: list[tuple[str, str, bool, object, str]] = []
for key, label, log_scale, transform in SWEEP_HPARAMS:
axis_kind = "beta" if key in {"beta1", "beta2"} else "positive"
hparam_specs.append((key, label, log_scale, transform, axis_kind))
n_cols = 3
n_rows = (n_hparams + n_cols - 1) // n_cols
fig = make_subplots(
rows=n_rows,
cols=n_cols,
horizontal_spacing=0.10,
vertical_spacing=0.16,
)
# 4 traces per hparam per cell: rug, PDP curve, confidence band, theory marker.
traces_per_cell = 4 * n_hparams
trace_meta: list[tuple[int, int, str]] = []
extrapolated_theory_markers: list[tuple[int, int, str, str]] = []
rug_trace_meta: list[tuple[int, int]] = []
subplot_x_bounds = [
{"min": float("inf"), "max": float("-inf")} for _ in range(n_hparams)
]
subplot_y_bounds = [
{"min": float("inf"), "max": float("-inf")} for _ in range(n_hparams)
]
for B in SWEEP_BATCH_SIZES:
for hidden in SWEEP_HIDDEN_DIMS:
for nominal_tokens, bucket_label in SWEEP_TOKEN_BUCKETS:
cell = sweep_df[
(sweep_df["batch_size"] == B)
& (sweep_df["hidden_dim"] == hidden)
& (sweep_df["token_bucket"] == bucket_label)
].copy()
if cell.empty:
for _ in range(traces_per_cell):
trace_meta.append((B, hidden, bucket_label))
fig.add_trace(
go.Scatter(x=[], y=[], mode="markers", visible=False),
row=1,
col=1,
)
continue
# Drop diverged trials so the GP focuses on the well-behaved
# region of the loss surface.
loss_cap = float(cell[SWEEP_METRIC].quantile(0.95))
cell = cell[cell[SWEEP_METRIC] <= loss_cap].copy()
losses = cell[SWEEP_METRIC].to_numpy(float)
# Transform hparams to their natural scales.
transformed: dict[str, np.ndarray] = {}
theory_vals: dict[str, float] = {}
for key, _label, log_scale, transform, _axis_kind in hparam_specs:
vals = cell[key].to_numpy(float)
tv = theory_hparam(key, B, nominal_tokens)
if transform is not None:
vals = np.vectorize(transform)(vals)
tv = transform(tv)
if log_scale:
vals = np.log(np.clip(vals, 1e-20, None))
tv = float(np.log(max(tv, 1e-20)))
transformed[key] = vals
theory_vals[key] = tv
keys = [spec[0] for spec in hparam_specs]
X_raw = np.column_stack([transformed[k] for k in keys])
# Standardize features for stable GP fitting.
scaler = StandardScaler()
X = scaler.fit_transform(X_raw)
# Transform theory values into scaled space.
theory_scaled = scaler.transform(
np.array([[theory_vals[k] for k in keys]])
)[0]
# Fit GP with Matérn-5/2 kernel + noise term.
kernel = Matern(nu=2.5) + WhiteKernel(noise_level=0.01)
gp = GaussianProcessRegressor(
kernel=kernel,
n_restarts_optimizer=3,
alpha=1e-6,
random_state=42,
)
gp.fit(X, losses)
for hp_idx in range(n_hparams):
sp_row = hp_idx // n_cols + 1
sp_col = hp_idx % n_cols + 1
key = keys[hp_idx]
label = hparam_specs[hp_idx][1]
log_scale = hparam_specs[hp_idx][2]
# Theory: only show the rule value when it lies within the
# sweep support for this feature; otherwise the GP value is
# an extrapolation rather than a supported PDP evaluation.
tv_s = theory_scaled[hp_idx]
tv_raw = theory_vals[key]
x_data_raw = transformed[key]
x_data_scaled = X[:, hp_idx]
theory_in_support = (
float(x_data_scaled.min()) <= tv_s <= float(x_data_scaled.max())
)
if not theory_in_support:
extrapolated_theory_markers.append(
(B, hidden, bucket_label, key)
)
# PDP with confidence band. Expand the grid to include the
# theory point so the marker is always visible on-axis.
x_grid_s, pdp_mean, pdp_std = _gp_pdp(
gp, X, hp_idx, include_point=tv_s
)
# Map grid back to original (unscaled) space for plotting.
x_grid_raw = x_grid_s * scaler.scale_[hp_idx] + scaler.mean_[hp_idx]
subplot_x_bounds[hp_idx]["min"] = min(
subplot_x_bounds[hp_idx]["min"],
float(x_grid_raw.min()),
float(x_data_raw.min()),
tv_raw,
)
subplot_x_bounds[hp_idx]["max"] = max(
subplot_x_bounds[hp_idx]["max"],
float(x_grid_raw.max()),
float(x_data_raw.max()),
tv_raw,
)
upper = (pdp_mean + pdp_std).tolist()
lower = (pdp_mean - pdp_std).tolist()
X_theory = X.copy()
X_theory[:, hp_idx] = tv_s
theory_pd = float(gp.predict(X_theory).mean())
subplot_y_bounds[hp_idx]["min"] = min(
subplot_y_bounds[hp_idx]["min"],
float(min(lower)),
theory_pd,
)
subplot_y_bounds[hp_idx]["max"] = max(
subplot_y_bounds[hp_idx]["max"],
float(max(upper)),
theory_pd,
)
# 1) Vizier support rug. Re-anchor all rug traces after we
# have seen every cell so the rug sits on the subplot floor
# for every dropdown state, not at a data-derived y value.
fig.add_trace(
go.Scatter(
x=x_data_raw.tolist(),
y=[0.0] * len(x_data_raw),
mode="markers",
marker=dict(
color=BRAND_BLUE,
size=10,
opacity=0.35,
symbol="line-ns-open",
line=dict(color=BRAND_BLUE, width=1.5),
),
showlegend=False,
visible=False,
hoverinfo="skip",
),
row=sp_row,
col=sp_col,
)
rug_trace_meta.append((len(fig.data) - 1, hp_idx))
trace_meta.append((B, hidden, bucket_label))
# 2) PDP curve.
fig.add_trace(
go.Scatter(
x=x_grid_raw.tolist(),
y=pdp_mean.tolist(),
mode="lines",
line=dict(color=OA_PALETTE[0], width=2.5),
showlegend=False,
visible=False,
hoverinfo="skip",
),
row=sp_row,
col=sp_col,
)
trace_meta.append((B, hidden, bucket_label))
# 3) ±1σ confidence band (filled area).
fig.add_trace(
go.Scatter(
x=x_grid_raw.tolist() + x_grid_raw[::-1].tolist(),
y=upper + lower[::-1],
fill="toself",
fillcolor="rgba(56,92,143,0.15)",
line=dict(width=0),
showlegend=False,
visible=False,
hoverinfo="skip",
),
row=sp_row,
col=sp_col,
)
trace_meta.append((B, hidden, bucket_label))
# 4) Theory marker on the PDP curve.
fig.add_trace(
go.Scatter(
x=[tv_raw],
y=[theory_pd],
mode="markers",
marker=dict(
symbol="x",
size=11,
color=BRAND_BLACK,
line=dict(width=3, color=BRAND_BLACK),
),
showlegend=False,
visible=False,
customdata=[
[
float(1.0 - np.exp(tv_raw))
if key in {"beta1", "beta2"} and log_scale
else float(np.exp(tv_raw))
if log_scale
else tv_raw
]
],
hovertemplate=(
f"CompletedAdamH"
f"{' (extrapolated)' if not theory_in_support else ''}"
f"<br>{label}=%{{customdata[0]:.3g}}<br>"
"PDP loss=%{y:.4f}<extra></extra>"
),
),
row=sp_row,
col=sp_col,
)
trace_meta.append((B, hidden, bucket_label))
if extrapolated_theory_markers:
extrap_summary = ", ".join(
f"B={B} H={hidden} T={bucket} {key}"
for B, hidden, bucket, key in extrapolated_theory_markers[:6]
)
if len(extrapolated_theory_markers) > 6:
extrap_summary += ", ..."
print(
"showed "
f"{len(extrapolated_theory_markers)} CompletedAdamH markers outside "
f"sweep support; those points are GP extrapolations: {extrap_summary}"
)
for trace_idx, hp_idx in rug_trace_meta:
y_floor = subplot_y_bounds[hp_idx]["min"]
fig.data[trace_idx].y = [y_floor] * len(fig.data[trace_idx].x)
# Axis labels.
for hp_idx, (key, label, log_scale, _tr, axis_kind) in enumerate(hparam_specs):
sp_row = hp_idx // n_cols + 1
sp_col = hp_idx % n_cols + 1
fig.update_xaxes(title_text=label, row=sp_row, col=sp_col)
x_lo = subplot_x_bounds[hp_idx]["min"]
x_hi = subplot_x_bounds[hp_idx]["max"]
if log_scale:
if axis_kind == "beta":
tickvals, ticktext = _beta_axis_ticks(x_lo, x_hi)
else:
tickvals, ticktext = _log_axis_ticks(x_lo, x_hi)
fig.update_xaxes(
range=[x_lo, x_hi],
tickmode="array",
tickvals=tickvals,
ticktext=ticktext,
row=sp_row,
col=sp_col,
)
else:
fig.update_xaxes(range=[x_lo, x_hi], row=sp_row, col=sp_col)
y_floor = subplot_y_bounds[hp_idx]["min"]
y_ceil = subplot_y_bounds[hp_idx]["max"]
y_span = max(y_ceil - y_floor, 1e-6)
fig.update_yaxes(
range=[y_floor, y_ceil + 0.04 * y_span],
row=sp_row,
col=sp_col,
)
for r in range(1, n_rows + 1):
fig.update_yaxes(title_text="Paloma loss", row=r, col=1)
# Hide unused subplot cells (e.g. 5 hparams in a 2×3 grid leaves one empty).
for idx in range(n_hparams, n_rows * n_cols):
r = idx // n_cols + 1
c = idx % n_cols + 1
fig.update_xaxes(visible=False, row=r, col=c)
fig.update_yaxes(visible=False, row=r, col=c)
# Always-visible legend entries.
fig.add_trace(
go.Scatter(
x=[None],
y=[None],
mode="markers",
marker=dict(
color=BRAND_BLUE,
size=10,
opacity=0.35,
symbol="line-ns-open",
line=dict(color=BRAND_BLUE, width=1.5),
),
name="Sampled runs",
showlegend=True,
)
)
fig.add_trace(
go.Scatter(
x=[None],
y=[None],
mode="lines",
line=dict(color=OA_PALETTE[0], width=2.5),
name="Sensitivity curve",
showlegend=True,
)
)
fig.add_trace(
go.Scatter(
x=[None],
y=[None],
mode="markers",
marker=dict(
symbol="x",
size=11,
color=BRAND_BLACK,
line=dict(width=3, color=BRAND_BLACK),
),
name="Our Heuristic",
showlegend=True,
)
)
n_legend = 3
default_batch = SWEEP_BATCH_SIZES[-1]
default_hidden = SWEEP_HIDDEN_DIMS[-1]
default_bucket = SWEEP_TOKEN_BUCKETS[1][1]
default_batch_idx = SWEEP_BATCH_SIZES.index(default_batch)
default_hidden_idx = SWEEP_HIDDEN_DIMS.index(default_hidden)
default_bucket_idx = [b for _, b in SWEEP_TOKEN_BUCKETS].index(default_bucket)
def visibility_for(batch: int, hid: int, b: str) -> list[bool]:
vis = [
(meta[0] == batch) and (meta[1] == hid) and (meta[2] == b)
for meta in trace_meta
]
vis.extend([True] * n_legend)
return vis
for i, v in enumerate(
visibility_for(default_batch, default_hidden, default_bucket)
):
fig.data[i].visible = v
batch_buttons = [
dict(
label=f"B = {bs}",
method="update",
args=[{"visible": visibility_for(bs, default_hidden, default_bucket)}],
)
for bs in SWEEP_BATCH_SIZES
]
hidden_buttons = [
dict(
label=f"H = {h}",
method="update",
args=[{"visible": visibility_for(default_batch, h, default_bucket)}],
)
for h in SWEEP_HIDDEN_DIMS
]
bucket_buttons = [
dict(
label=f"T = {b}",
method="update",
args=[{"visible": visibility_for(default_batch, default_hidden, b)}],
)
for _, b in SWEEP_TOKEN_BUCKETS
]
fig.update_layout(
updatemenus=[
dict(
buttons=batch_buttons,
direction="down",
showactive=True,
active=default_batch_idx,
bgcolor="rgba(196,185,174,0.9)",
borderwidth=0,
x=0.00,
xanchor="left",
y=1.11,
yanchor="top",
),
dict(
buttons=hidden_buttons,
direction="down",
showactive=True,
active=default_hidden_idx,
bgcolor="rgba(196,185,174,0.9)",
borderwidth=0,
x=0.20,
xanchor="left",
y=1.11,
yanchor="top",
),
dict(
buttons=bucket_buttons,
direction="down",
showactive=True,
active=default_bucket_idx,
bgcolor="rgba(196,185,174,0.9)",
borderwidth=0,
x=0.42,
xanchor="left",
y=1.11,
yanchor="top",
),
],
annotations=[
dict(
text="Batch size",
x=0.00,
xref="paper",
xanchor="left",
y=1.1,
yref="paper",
yanchor="bottom",
showarrow=False,
font=dict(size=12),
),
dict(
text="Hidden dim",
x=0.20,
xref="paper",
xanchor="left",
y=1.1,
yref="paper",
yanchor="bottom",
showarrow=False,
font=dict(size=12),
),
dict(
text="Token horizon",
x=0.42,
xref="paper",
xanchor="left",
y=1.1,
yref="paper",
yanchor="bottom",
showarrow=False,
font=dict(size=12),
),
],
# Center the legend in panel 6 (the empty cell). In the desktop
# 2×3 grid that cell spans xaxis6 [0.733, 1.0] × yaxis6 [0, 0.42],
# so its center is at paper (0.867, 0.21). The mobile 3×2 layout
# shifts the empty cell, so `to_mobile_3x2` overrides this
# position.
legend=dict(
x=0.867,
xanchor="center",
y=0.21,
yanchor="middle",
bgcolor="rgba(196,185,174,0.8)",
borderwidth=0,
),
margin=dict(t=80, r=30, b=10, l=60),
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
)
return fig
# ---------------------------------------------------------------------------
# MMLU downstream eval (section 4: "Does this work for downstream tasks?")
# ---------------------------------------------------------------------------
# MMLU is a 4-way multiple choice benchmark, so random chance sits at
# 0.25 accuracy and the random-chance floor on mean log probability of
# the correct choice is log(1/4) ≈ -1.386.
MMLU_RANDOM_ACC = 0.25
MMLU_RANDOM_LOGPROB = math.log(1.0 / 4.0)
# Dropdown sentinels marking the downstream-eval views. These are NOT
# real shot counts — they are integers chosen to be distinct from the
# real MMLU shot counts (0, 5) so a single dict can key every dropdown
# view. Each downstream view has a different data schema (macro_bpb vs
# accuracy/choice_logprob), so the figure branches on these keys when
# choosing metric, axis title, and panel layout. They are defined before
# the lookup tables below so the tables key on these names instead of
# repeating the magic numbers (which would silently go stale — and then
# raise KeyError at lookup — if a sentinel value ever changed).
HUMANEVAL_VIEW = 10
GSM8K_VIEW = 15
# Per-view colour and dropdown label. MMLU 0-shot / 5-shot are paired
# runs per compute budget (same accuracy + choice-logprob schema). The
# HumanEval and GSM8K views plot bits/byte vs compute on panel 1 and
# their generation metric (pass@1 / exact-match) vs bits/byte on panel
# 2. Their keys are the sentinels above, so the human-readable shot
# count for those evals lives only in the label string.
VIEW_COLORS = {
    0: BRAND_TERRACOTTA,
    5: BRAND_TERRACOTTA,
    HUMANEVAL_VIEW: BRAND_TERRACOTTA,
    GSM8K_VIEW: BRAND_TERRACOTTA,
}
VIEW_LABELS = {
    0: "MMLU 0-Shot",
    5: "MMLU 5-Shot",
    HUMANEVAL_VIEW: "HumanEval 10-Shot",
    GSM8K_VIEW: "GSM8K 5-Shot",
}
# Per-rung model identity for the Delphi ladder. IsoFLOP rungs (3e18 →
# 3e20) carry the AdamH-optimal architecture chosen for that budget;
# the validation rungs (1e21 → 1e23) are single optimal models named
# `adamh-scaling-ladder-nemotron-optimal-{budget}`. Used to enrich the
# Delphi-marker hover labels on the downstream-evals figure so a
# reader can see which checkpoint each point came from.
_DELPHI_MODEL_LABELS = {
    3e18: "isoflop d1024 / L11 / B8",
    9e18: "isoflop d1152 / L12 / B16",
    2e19: "isoflop d1408 / L15 / B16",
    3e19: "isoflop d1536 / L16 / B32",
    9e19: "isoflop d1792 / L18 / B64",
    2e20: "isoflop d2048 / L21 / B64",
    3e20: "isoflop d2304 / L23 / B128",
    1e21: "Delphi optimal 1e21",
    1e22: "Delphi optimal 1e22",
    1e23: "Delphi optimal 1e23",
}
def _delphi_model_label(budget: float) -> str:
    """Return the hover-label architecture string for a Delphi ladder rung.

    ``budget`` is matched against the keys of ``_DELPHI_MODEL_LABELS``
    with a 5% relative tolerance; the first matching rung (dict order)
    wins. Budgets that match no rung fall back to a generic
    ``"C=<budget> FLOPs"`` label in one-significant-figure scientific
    notation.
    """
    matched = next(
        (
            label
            for rung, label in _DELPHI_MODEL_LABELS.items()
            if abs(budget - rung) / rung < 0.05
        ),
        None,
    )
    if matched is not None:
        return matched
    return f"C={budget:.0e} FLOPs"
def load_mmlu_runs(cache_path: Path) -> pd.DataFrame | None:
    """Read the exp1337 MMLU eval cache into a tidy DataFrame.

    Returns None, after printing the reason, when the cache file does not
    exist, decodes to an empty list, or lacks any of the ``budget``,
    ``shot``, ``accuracy`` or ``choice_logprob`` columns. That is a soft
    failure: only the downstream-evals figure depends on this cache, so
    the remaining figures still build.

    Rows with a null in any of those four columns are dropped, ``shot``
    is coerced to ``int``, and the frame is sorted by (``shot``,
    ``budget``) with a fresh 0..n-1 index.
    """

    def _skip(reason: str) -> None:
        # Report why the figure is being skipped and signal "no data".
        print(reason)
        return None

    required = ["budget", "shot", "accuracy", "choice_logprob"]
    if not cache_path.exists():
        return _skip(
            f"mmlu cache not found at {cache_path} — skipping downstream-evals figure. "
            "Run `uv run scripts/delphi/fetch_data.py` to populate it."
        )
    records = json.loads(cache_path.read_text())
    if not records:
        return _skip(
            f"mmlu cache {cache_path} is empty — skipping downstream-evals figure."
        )
    frame = pd.DataFrame(records)
    absent = [col for col in required if col not in frame]
    if absent:
        return _skip(
            f"mmlu cache missing columns {absent} — skipping downstream-evals figure."
        )
    frame = frame.dropna(subset=required).copy()
    frame["shot"] = frame["shot"].astype(int)
    ordered = frame.sort_values(["shot", "budget"])
    return ordered.reset_index(drop=True)
def load_humaneval_runs(cache_path: Path) -> pd.DataFrame | None:
    """Read the exp1337 HumanEval logprob eval cache into a DataFrame.

    Returns None, after printing the reason, when the cache file is
    absent, holds no rows, or lacks the ``budget`` / ``macro_bpb``
    columns. None is a soft failure: the downstream-evals figure keeps
    its MMLU views and only loses the HumanEval 10-shot dropdown entry.

    Rows with a null ``budget`` or ``macro_bpb`` are dropped.
    ``pass_at_1`` (the generation eval) is optional. Runs fetched before
    it landed leave it NaN, and when the column is missing entirely it is
    added as all-NaN so downstream code never has to branch on its
    presence. The result is sorted by ``budget`` with a fresh index.
    """

    def _skip(reason: str) -> None:
        # Explain why the HumanEval view is being dropped.
        print(reason)
        return None

    required = ["budget", "macro_bpb"]
    if not cache_path.exists():
        return _skip(
            f"humaneval cache not found at {cache_path} — HumanEval dropdown "
            "will be dropped. Run `uv run scripts/delphi/fetch_data.py` to populate."
        )
    records = json.loads(cache_path.read_text())
    if not records:
        return _skip(
            f"humaneval cache {cache_path} is empty — skipping HumanEval dropdown."
        )
    frame = pd.DataFrame(records)
    absent = [col for col in required if col not in frame]
    if absent:
        return _skip(
            f"humaneval cache missing columns {absent} — skipping HumanEval dropdown."
        )
    frame = frame.dropna(subset=required).copy()
    if "pass_at_1" not in frame:
        frame["pass_at_1"] = float("nan")
    ordered = frame.sort_values("budget")
    return ordered.reset_index(drop=True)
def load_humaneval_external(cache_path: Path) -> pd.DataFrame | None:
    """Read the supplemental pool of non-Marin HumanEval evaluations.

    Mirrors `load_mmlu_external`: one row per externally released base
    model, carrying ``macro_bpb`` (logprob eval), ``pass_at_1``
    (generation eval) and a hover-friendly ``display_name``.

    Returns None, after printing the reason, when the cache is missing,
    empty, or lacks any of those three columns. In that case the
    HumanEval view falls back to the pinned-L∞ asymptote fit and omits
    the grey scatter on panel 2. Otherwise rows whose ``macro_bpb`` is
    null are dropped (``pass_at_1`` may stay NaN), input order is
    preserved, and the index is reset.
    """

    def _skip(reason: str) -> None:
        # Explain why the external HumanEval pool is unavailable.
        print(reason)
        return None

    required = ["display_name", "macro_bpb", "pass_at_1"]
    if not cache_path.exists():
        return _skip(
            f"humaneval-external cache not found at {cache_path} — "
            "HumanEval panel will use the pinned-L∞ asymptote fit and "
            "skip the grey external scatter. Run "
            "`uv run scripts/delphi/fetch_data.py` to populate."
        )
    records = json.loads(cache_path.read_text())
    if not records:
        return _skip(f"humaneval-external cache {cache_path} is empty.")
    frame = pd.DataFrame(records)
    absent = [col for col in required if col not in frame]
    if absent:
        return _skip(
            f"humaneval-external cache missing columns {absent} — skipping."
        )
    has_bpb = frame["macro_bpb"].notna()
    return frame.loc[has_bpb].reset_index(drop=True)
def load_gsm8k_external(cache_path: Path) -> pd.DataFrame | None:
    """Read the supplemental pool of non-Marin GSM8K evaluations.

    Mirrors `load_humaneval_external`: one row per externally released
    base model, carrying ``macro_bpb`` (logprob eval),
    ``exact_match_flex`` (generation eval) and a ``display_name``.

    Returns None, after printing the reason, when the cache is missing,
    empty, or lacks any of those three columns. In that case the GSM8K
    view falls back to the pinned-L∞ asymptote fit and omits the grey
    scatter on panel 2. Otherwise rows whose ``macro_bpb`` is null are
    dropped, input order is preserved, and the index is reset.
    """

    def _skip(reason: str) -> None:
        # Explain why the external GSM8K pool is unavailable.
        print(reason)
        return None

    required = ["display_name", "macro_bpb", "exact_match_flex"]
    if not cache_path.exists():
        return _skip(
            f"gsm8k-external cache not found at {cache_path} — "
            "GSM8K panel will use the pinned-L∞ asymptote fit and "
            "skip the grey external scatter. Run "
            "`uv run scripts/delphi/fetch_data.py` to populate."
        )
    records = json.loads(cache_path.read_text())
    if not records:
        return _skip(f"gsm8k-external cache {cache_path} is empty.")
    frame = pd.DataFrame(records)
    absent = [col for col in required if col not in frame]
    if absent:
        return _skip(
            f"gsm8k-external cache missing columns {absent} — skipping."
        )
    has_bpb = frame["macro_bpb"].notna()
    return frame.loc[has_bpb].reset_index(drop=True)
def load_gsm8k_runs(cache_path: Path) -> pd.DataFrame | None:
    """Read the exp1337 GSM8K logprob + generation eval cache.

    Same layout as the HumanEval cache: one row per budget with
    ``macro_bpb`` (logprob eval) and ``exact_match_flex`` / ``_strict``
    (generation eval).

    Returns None, after printing the reason, when the cache file is
    absent, holds no rows, or lacks the ``budget`` / ``macro_bpb``
    columns. The figure still renders its other views in that case.
    Rows with a null ``budget`` or ``macro_bpb`` are dropped. An
    all-NaN ``exact_match_flex`` column is added if the cache has none,
    and the result is sorted by ``budget`` with a fresh index.
    """

    def _skip(reason: str) -> None:
        # Explain why the GSM8K view is being dropped.
        print(reason)
        return None

    required = ["budget", "macro_bpb"]
    if not cache_path.exists():
        return _skip(
            f"gsm8k cache not found at {cache_path} — GSM8K dropdown "
            "will be dropped. Run `uv run scripts/delphi/fetch_data.py` to populate."
        )
    records = json.loads(cache_path.read_text())
    if not records:
        return _skip(
            f"gsm8k cache {cache_path} is empty — skipping GSM8K dropdown."
        )
    frame = pd.DataFrame(records)
    absent = [col for col in required if col not in frame]
    if absent:
        return _skip(
            f"gsm8k cache missing columns {absent} — skipping GSM8K dropdown."
        )
    frame = frame.dropna(subset=required).copy()
    if "exact_match_flex" not in frame:
        frame["exact_match_flex"] = float("nan")
    ordered = frame.sort_values("budget")
    return ordered.reset_index(drop=True)
def _prettify_external_name(name: str) -> str:
"""Turn a W&B eval-harness run name into a reader-friendly label.
Drops the eval runner's region / branch / hash metadata and returns
the model identity as something a reader would recognise from a
model card ("Qwen 3 Base 14B", "Llama 3.2 1B", "OLMo 2 13B"). Run
names from the comma-mix IsoFLOP scan are labelled by budget since
no user-facing name exists for those checkpoints.
Dashes in HF names double as both version separators and
decimal-point stand-ins (``Llama-3-2-1B`` means Llama 3.2 at 1B,
``Qwen3-0-6B`` means Qwen3 at 0.6B); we only restore the decimal
dot for the families where the pattern applies.
"""
m = re.match(r"^isoflop-(\d+e[+\-]\d+)-", name)
if m:
budget = m.group(1).replace("+", "")
return f"comma-mix @ {budget} FLOPs"
s = re.sub(r"^marin-us-\w+-", "", name)
s = re.sub(r"_lmeval.*$", "", s)
parts = s.split("--", 2)
if len(parts) < 2:
return name
model = parts[1]
# Family-specific rewrites — chosen by leading token of the model id.
if model.startswith("Qwen3-"):
rest = model[len("Qwen3-"):]
rest = re.sub(r"^(\d)-(\d+)B", r"\1.\2B", rest)
rest = re.sub(r"-Base$", "", rest)
# Promote "Qwen3" (no space in HF id) to "Qwen 3" and append
# "Base" so the label mirrors how the model is marketed.
size = rest.split("-")[0]
return f"Qwen 3 Base {size}"
if model.startswith("Llama-"):
# "Llama-3-2-1B" → ("3", "2", "1B"); "Llama-2-7b-hf" → ("2", None, "7b")
mver = re.match(
r"^Llama-(\d+)(?:-(\d+))?-(\d+[A-Za-z]+)", model
)
if mver:
major, minor, size = mver.groups()
size = size.upper() # "7b" → "7B"
version = f"{major}.{minor}" if minor else major
return f"Llama {version} {size}"
if model.startswith("OLMo-"):
# "OLMo-2-1124-13B-…" — the 4-digit chunk is an internal date
# code (Nov 2024) that readers don't need. Keep version + size.
mver = re.match(r"^OLMo-(\d+)-\d{4}-(\d+[A-Za-z]+)", model)
if mver:
return f"OLMo {mver.group(1)} {mver.group(2).upper()}"
if model.startswith("marin-"):
# "marin-8b-base" → "Marin 8B Base"
tokens = [t.capitalize() if t != "base" else "Base" for t in model.split("-")]
tokens = ["Marin" if t.lower() == "marin" else t for t in tokens]
tokens = [t.upper() if re.fullmatch(r"\d+b", t, re.I) else t for t in tokens]
return " ".join(tokens)
# Fallback: strip suffixes and return as-is.
return re.sub(r"-(?:Base|hf)$", "", model)
def load_mmlu_external(cache_path: Path) -> pd.DataFrame | None:
    """Load the supplemental non-Marin MMLU cache. Returns None if unusable.

    A missing, empty, or schema-incomplete cache is tolerated: the figure
    just drops the grey points. Each early-out prints a one-line reason,
    consistent with the sibling ``load_*`` helpers, so a schema drift in
    the fetch step does not silently remove the external pool.

    Rows with a null ``accuracy`` or ``choice_logprob`` are dropped. The
    cache also includes ~85 comma-mix IsoFLOP runs that are filtered out
    here. Those are Marin-internal scaling runs whose budgets overlap our
    own ladder, so they add no real "external pool" information to the
    asymptote constraint or the sigmoid fit. The index is reset on return.
    """
    if not cache_path.exists():
        print(
            f"mmlu-external cache not found at {cache_path} — MMLU view "
            "will skip the grey external scatter. Run "
            "`uv run scripts/delphi/fetch_data.py` to populate."
        )
        return None
    rows = json.loads(cache_path.read_text())
    if not rows:
        print(f"mmlu-external cache {cache_path} is empty.")
        return None
    df = pd.DataFrame(rows)
    need = ["accuracy", "choice_logprob"]
    missing = [c for c in need if c not in df.columns]
    if missing:
        print(f"mmlu-external cache missing columns {missing} — skipping.")
        return None
    df = df.dropna(subset=need).copy()
    if "run_name" in df.columns:
        # astype(str) keeps the filter working when run_name holds
        # non-string values (numeric, or an all-null float column), where
        # the ``.str`` accessor would raise AttributeError. Nulls become
        # "nan"/"None", which never match, i.e. they are kept — the same
        # as ``na=False`` on string columns.
        is_comma = df["run_name"].astype(str).str.contains("comma", case=False)
        df = df[~is_comma]
    return df.reset_index(drop=True)
def _fit_richards(
x: np.ndarray,
y: np.ndarray,
*,
floor: float,
ceiling: float | None = None,
) -> tuple[float, float, float, float] | None:
"""Fit ``y = floor + (C − floor) · σ(k·(x − x0))^v``. Returns (k, x0, v, C).
The ``v`` parameter is the Richards shape term — a plain symmetric
logistic (v=1) systematically over-shoots the long flat region near
the floor and under-shoots the steep rise that follows.
``floor`` is pinned by the caller (random-chance probability for
the metric — 0.25 for MMLU's 4-way MC, 0 for pass@1 / exact-match).
``ceiling`` is fit freely in [floor + 0.05, 1.0] by default so each
task picks up its own empirical saturation point. Callers can pin
it by passing ``ceiling=<float>``.
Returns None if scipy is unavailable or the optimiser fails (the
old MMLU-specific grid-search fallback is dropped — scipy is a
required dep for every other figure in this file).
"""
x = np.asarray(x, float)
y = np.asarray(y, float)
x_lo, x_hi = float(x.min()), float(x.max())
try:
from scipy.optimize import curve_fit # noqa: PLC0415
except Exception:
return None
if ceiling is None:
def model(xv, k, x0, v, C):
return floor + (C - floor) / (1.0 + np.exp(-k * (xv - x0))) ** v
p0 = [5.0, float(np.median(x)), 1.0, 0.95]
bounds = (
[0.1, x_lo - 2.0, 0.1, floor + 0.05],
[50.0, x_hi + 2.0, 20.0, 1.0],
)
try:
popt, _ = curve_fit(model, x, y, p0=p0, bounds=bounds, maxfev=20000)
return (
float(popt[0]),
float(popt[1]),
float(popt[2]),
float(popt[3]),
)
except Exception:
return None
def model_pinned(xv, k, x0, v):
return floor + (ceiling - floor) / (1.0 + np.exp(-k * (xv - x0))) ** v
p0 = [5.0, float(np.median(x)), 1.0]
bounds = ([0.1, x_lo - 2.0, 0.1], [50.0, x_hi + 2.0, 20.0])
try:
popt, _ = curve_fit(
model_pinned, x, y, p0=p0, bounds=bounds, maxfev=20000
)
return (
float(popt[0]),
float(popt[1]),
float(popt[2]),
float(ceiling),
)
except Exception:
return None
def _invert_richards(
k: float,
x0: float,
v: float,
*,
target: float,
floor: float,
ceiling: float,
) -> float:
"""Solve for ``x`` such that the Richards fit equals ``target``.
Inverse of `_fit_richards`'s forward model. Used to estimate where
the fitted curve crosses a near-ceiling threshold (typically
``ceiling − 1/n_problems``, the per-problem resolution of the
metric) so we can derive an asymptote estimate in x-space.
"""
s = ((target - floor) / (ceiling - floor)) ** (1.0 / v)
return float(x0 - np.log(1.0 / s - 1.0) / k)
def build_mmlu_emergence_figure(
df: pd.DataFrame,
external_df: pd.DataFrame | None = None,
humaneval_df: pd.DataFrame | None = None,
gsm8k_df: pd.DataFrame | None = None,
humaneval_external_df: pd.DataFrame | None = None,
gsm8k_external_df: pd.DataFrame | None = None,
) -> go.Figure:
"""Two panels: MMLU log-prob vs compute, MMLU accuracy vs log-prob.
Argument of the figure: the per-choice log probability improves
smoothly with compute at every scale, but accuracy (the thresholded
version of that signal) only leaves random-chance noise past a few
1e20 FLOPs — the classic "emergence" shape is a sigmoidal readout
of a smooth underlying signal.
Train/test split matches the rest of the blog: fit on the seven
IsoFLOP bucket evals at <1e21 FLOPs, and the 1e21/1e22/1e23 evals
are held-out extrapolation targets. Both panels carry a
"fit ← | → extrapolation" divider so the fit region is visible.
``external_df`` is the supplemental pool of non-Marin model scores
(Qwen / Llama / OLMo / Comma at 5-shot). They're plotted as grey
circles on panel 2, and enter the panel-1 asymptote pipeline
through the shared Richards fit (see `_fit_richards` +
`_invert_richards` — the same pipeline used by HumanEval/GSM8K).
The shot-count dropdown toggles between the 0-shot and 5-shot views;
external points only appear on the 5-shot view because that's the
only shot they logged.
``humaneval_df`` adds a third dropdown view (HumanEval 10-shot) that
replots panel 1 as ``macro_bpb`` vs compute. Panel 2 is empty on this
view until the pending pass@1 evals land; the button flips panel
titles, axis labels, and the MMLU-specific shape/annotation
visibility to keep the figure self-describing across views.
"""
# Fit on IsoFLOP-scale points, extrapolate to the held-out budgets.
# 5e20 sits between the last train point (3e20) and the first held
# point (1e21) — same divider as the scaling-law-asymptote figure.
held_lo = 5e20
train = df[df["budget"] < held_lo].copy()
test = df[df["budget"] >= held_lo].copy()
fig = make_subplots(
rows=1,
cols=2,
shared_yaxes=False,
subplot_titles=(
"Log-prob improves smoothly with compute",
"Accuracy emerges at larger scales",
),
horizontal_spacing=0.14,
)
shots = sorted(df["shot"].unique().tolist())
# Panel 1 asymptote fit — same pipeline as the downstream-eval
# views: fit a Richards sigmoid on the (external pool + Delphi
# train) accuracy-vs-logprob pairs, invert at `ceiling − 1/n` to
# get the logprob value where the curve saturates to within the
# metric's per-problem resolution, then plug the resulting L∞ into
# `_fit_power_at_linf` on each shot's (C, −logprob) train points.
# The sigmoid is a single shared fit (logprob↔accuracy is a
# property of the eval, not the shot count) so L∞ is shared across
# shots too — no more bespoke `_fit_asymptote_pool_linf` grid.
# MMLU test set is 14042 questions across 57 subjects.
MMLU_N_PROBLEMS = 14042
train_per_shot: dict[int, tuple[np.ndarray, np.ndarray]] = {}
for shot in shots:
sub = train[train["shot"] == shot]
train_per_shot[int(shot)] = (
sub["budget"].to_numpy(float),
-sub["choice_logprob"].to_numpy(float),
)
if external_df is not None and len(external_df):
mmlu_richards_x = external_df["choice_logprob"].to_numpy(float)
mmlu_richards_y = external_df["accuracy"].to_numpy(float)
else:
mmlu_richards_x = train["choice_logprob"].to_numpy(float)
mmlu_richards_y = train["accuracy"].to_numpy(float)
mmlu_richards_fit = _fit_richards(
mmlu_richards_x, mmlu_richards_y,
floor=MMLU_RANDOM_ACC, ceiling=None,
)
pool_neg_lp = [-float(v) for v in df["choice_logprob"].tolist()]
if external_df is not None and len(external_df):
pool_neg_lp.extend(-float(v) for v in external_df["choice_logprob"].tolist())
pool_y_min = float(min(pool_neg_lp)) if pool_neg_lp else 0.0
if mmlu_richards_fit is not None:
k_m, x0_m, v_m, C_m = mmlu_richards_fit
threshold_m = C_m - 1.0 / MMLU_N_PROBLEMS
x_at_m = _invert_richards(
k_m, x0_m, v_m,
target=threshold_m, floor=MMLU_RANDOM_ACC, ceiling=C_m,
)
# Sigmoid fit is in logprob space (negative values); panel 1
# plots `-logprob` (positive). The asymptote in panel-1 space
# is therefore `-x_at`. Clamp into (0, pool_min] — the curve
# can't sit below zero (perfect log p = 0) and shouldn't sit
# above any observed model's -logprob.
linf_neg_lp_candidate = -x_at_m
if 0.0 < linf_neg_lp_candidate < pool_y_min:
linf_neg_lp = float(linf_neg_lp_candidate)
else:
linf_neg_lp = pool_y_min
else:
linf_neg_lp = pool_y_min
asym_fit: dict[int, dict] = {}
for shot in shots:
C_tr, y_tr = train_per_shot[int(shot)]
fit = _fit_power_at_linf(C_tr, y_tr, linf_neg_lp)
asym_fit[int(shot)] = fit if fit is not None else dict(_FAIL)
# Extend the compute axis out past 1e25 so the asymptote fit can
# flatten visibly past the last Delphi held-out run at 1e23 AND the
# 1e25-forecast marker (plus its text callout) isn't cropped at the
# right edge.
budgets_all = df["budget"].to_numpy(float)
c_min = float(budgets_all.min()) / 10**0.1
c_max = 3e25
log_c_grid = np.linspace(np.log10(c_min), np.log10(c_max), 400)
c_grid = 10**log_c_grid
# Divider + fit/extrapolation labels on panel 1 — mirrors the
# scaling-law-asymptote figure. Shapes take raw x; annotations on a
# log axis take log10(x).
fig.add_shape(
type="line",
x0=held_lo,
x1=held_lo,
xref="x",
y0=0,
y1=1,
yref="y domain",
line=dict(color="#1f1e1b", width=1, dash="dot"),
layer="below",
opacity=0.45,
)
held_lo_log = float(np.log10(held_lo))
fig.add_annotation(
x=held_lo_log,
y=0.98,
xref="x",
yref="y domain",
text="fit ← ",
showarrow=False,
xanchor="right",
yanchor="top",
xshift=-4,
font=dict(size=12, color="#1f1e1b"),
)
fig.add_annotation(
x=held_lo_log,
y=0.98,
xref="x",
yref="y domain",
text=" → extrapolation",
showarrow=False,
xanchor="left",
yanchor="top",
xshift=4,
font=dict(size=12, color="#1f1e1b"),
)
# Panel 2 — accuracy vs choice logprob. Divider lands at the midpoint
# between the max train logprob and the min test logprob, so the
# panel's fit/extrapolation labels line up with panel 1's compute
# divider. External grey points live to the right of the divider
# (they're bigger models than anything in train), so their addition
# only extends the extrapolation region.
lp_vals = df["choice_logprob"].to_numpy(float).tolist()
if external_df is not None and len(external_df):
lp_vals.extend(external_df["choice_logprob"].tolist())
x_lo_all = float(min(lp_vals))
x_hi_all = float(max(lp_vals))
pad = 0.08 * (x_hi_all - x_lo_all) if x_hi_all > x_lo_all else 0.2
x_curve = np.linspace(x_lo_all - pad, x_hi_all + pad, 400)
# Panel 2 has no fit/extrapolation divider: the sigmoid is fit on
# the external pool alone (where there is no clean train/test
# boundary), and Marin points are shown as observations rather than
# fit targets.
# Build traces per view, tagged so the dropdown can flip visibility.
# ``view_traces`` records the trace index for each dropdown view.
# Downstream-eval traces (HumanEval, GSM8K) are added later in the
# function but registered here so the dict has a stable key set.
view_traces: dict[int, list[int]] = {int(s): [] for s in shots}
if humaneval_df is not None and len(humaneval_df):
view_traces[HUMANEVAL_VIEW] = []
if gsm8k_df is not None and len(gsm8k_df):
view_traces[GSM8K_VIEW] = []
def _add(trace: go.Scatter, view_id: int, *, row: int, col: int) -> None:
fig.add_trace(trace, row=row, col=col)
view_traces[int(view_id)].append(len(fig.data) - 1)
default_shot = 5 if 5 in [int(s) for s in shots] else int(shots[0])
for i, shot in enumerate(shots):
color = VIEW_COLORS[int(shot)]
label = VIEW_LABELS[int(shot)]
is_default = int(shot) == default_shot
d = asym_fit.get(int(shot), dict(_FAIL))
if d["ok"]:
# Asymptote in -logprob space (positive, decreasing toward
# L_inf ≥ 0). Plot the positive value directly on the log
# axis; ticks negate for display.
y_neg = d["L_inf"] + d["A"] * (c_grid ** (-d["alpha"]))
_add(
go.Scatter(
x=c_grid,
y=y_neg,
mode="lines",
line=dict(color=color, width=1.5, dash="dash"),
name=f"{label} asymptote (α={d['alpha']:.3f}, L∞={d['L_inf']:.3f})",
legendgroup=label,
visible=is_default,
showlegend=False,
customdata=(-y_neg).reshape(-1, 1),
hovertemplate=(
f"{label} asymptote<br>C=%{{x:.3e}} FLOPs<br>"
"log p(correct)=%{customdata[0]:.3f}<extra></extra>"
),
),
int(shot),
row=1,
col=1,
)
train_sub = train[train["shot"] == shot].sort_values("budget")
train_neg_lp = -train_sub["choice_logprob"].to_numpy(float)
train_models = [
_delphi_model_label(b) for b in train_sub["budget"].to_numpy(float)
]
_add(
go.Scatter(
x=train_sub["budget"],
y=train_neg_lp,
mode="markers",
marker=dict(size=9, color=color, symbol="circle"),
name=f"{label} (Delphi fit)",
legendgroup=label,
visible=is_default,
customdata=np.column_stack([-train_neg_lp, np.array(train_models)]),
hovertemplate=(
f"{label} (fit)<br>"
"%{customdata[1]}<br>"
"C=%{x:.3e} FLOPs<br>"
"log p(correct)=%{customdata[0]:.3f}<extra></extra>"
),
),
int(shot),
row=1,
col=1,
)
test_sub = test[test["shot"] == shot].sort_values("budget")
if len(test_sub):
test_neg_lp = -test_sub["choice_logprob"].to_numpy(float)
test_C = test_sub["budget"].to_numpy(float)
test_models = [_delphi_model_label(b) for b in test_C]
# Forecast residual at each held-out budget, expressed as
# signed % deviation of the observed -logprob from what
# the per-shot asymptote fit predicts. Positive = observed
# log-prob is below (more negative than) the fit predicts;
# negative = the model beat the fit.
if d["ok"]:
pred = d["L_inf"] + d["A"] * (test_C ** (-d["alpha"]))
pct_err = (test_neg_lp - pred) / pred * 100.0
err_labels = [fmt_pct_err(float(e)) for e in pct_err]
else:
err_labels = [""] * len(test_sub)
_add(
go.Scatter(
x=test_C,
y=test_neg_lp,
mode="markers",
marker=dict(size=11, color=color, symbol="circle"),
name=f"{label} (Delphi held-out)",
legendgroup=label,
visible=is_default,
showlegend=False,
customdata=np.column_stack([-test_neg_lp, np.array(test_models)]),
hovertemplate=(
f"{label} (held-out)<br>"
"%{customdata[1]}<br>"
"C=%{x:.3e} FLOPs<br>"
"log p(correct)=%{customdata[0]:.3f}<extra></extra>"
),
),
int(shot),
row=1,
col=1,
)
# Leader lines + err labels, one segment per held-out
# point. Below-and-right offset on the log axes (multiply
# C by ~3.5, multiply -logprob by ~1.18). `None` separators
# break the polyline into independent segments so a single
# trace can carry all three leaders.
x_off, y_off = 1.7, 1.08
seg_x: list = []
seg_y: list = []
seg_text: list = []
for C_i, y_i, lbl in zip(test_C, test_neg_lp, err_labels):
seg_x += [float(C_i), float(C_i) * x_off, None]
seg_y += [float(y_i), float(y_i) * y_off, None]
seg_text += ["", lbl, ""]
_add(
go.Scatter(
x=seg_x,
y=seg_y,
mode="lines+text",
line=dict(color="rgba(80,80,80,0.55)", width=1),
text=seg_text,
textposition="middle right",
textfont=dict(size=11, color=BRAND_BLACK),
showlegend=False,
legendgroup=label,
visible=is_default,
hoverinfo="skip",
),
int(shot),
row=1,
col=1,
)
# Panel 2 Richards curve — reuses the shared sigmoid fit
# computed above. The logprob↔accuracy mapping is a property
# of the eval, not the shot count, so both shots plot the same
# curve (only the Delphi marker positions differ between shots).
if mmlu_richards_fit is None:
# No viable fit (typically means scipy failed or the pool
# was empty). Skip the curve — panel-2 markers alone still
# tell the emergence story.
k, x0, v, C = 0.0, 0.0, 1.0, 1.0
y_sig = np.full_like(x_curve, np.nan)
else:
k, x0, v, C = mmlu_richards_fit
y_sig = MMLU_RANDOM_ACC + (C - MMLU_RANDOM_ACC) / (
1.0 + np.exp(-k * (x_curve - x0))
) ** v
# Plot in -logprob (positive) space so the log-scale x-axis
# works. ``x_curve`` is in logprob space (negative) for the
# math; negate when handing to plotly.
_add(
go.Scatter(
x=-x_curve,
y=y_sig,
mode="lines",
line=dict(color=color, width=1.5, dash="dash"),
name=(
f"{label} Richards fit "
f"(k={k:.2f}, v={v:.2f}, C={C:.2f}, L∞={d['L_inf']:.3f})"
),
legendgroup=label,
visible=is_default,
showlegend=False,
customdata=x_curve.reshape(-1, 1),
hovertemplate=(
f"{label} Richards fit<br>"
"log p(correct)=%{customdata[0]:.3f}<br>"
"predicted acc=%{y:.3f}<extra></extra>"
),
),
int(shot),
row=1,
col=2,
)
if external_df is not None and len(external_df):
pretty_labels = [
_prettify_external_name(str(n))
for n in external_df["run_name"].tolist()
]
ext_lp = external_df["choice_logprob"].to_numpy(float)
_add(
go.Scatter(
x=-ext_lp,
y=external_df["accuracy"],
mode="markers",
marker=dict(
size=7,
color="#7a736b",
symbol="circle",
opacity=0.55,
line=dict(width=0),
),
name="External Models Used for Regression",
legendgroup="external",
visible=is_default,
customdata=np.column_stack([
np.array(pretty_labels), ext_lp,
]),
hovertemplate=(
"%{customdata[0]}<br>"
"log p(correct)=%{customdata[1]:.3f}<br>"
"accuracy=%{y:.3f}<extra></extra>"
),
),
int(shot),
row=1,
col=2,
)
train_sub_lp = train_sub.sort_values("choice_logprob")
train_lp_vals = train_sub_lp["choice_logprob"].to_numpy(float)
train_lp_models = [
_delphi_model_label(b) for b in train_sub_lp["budget"].to_numpy(float)
]
_add(
go.Scatter(
x=-train_lp_vals,
y=train_sub_lp["accuracy"],
mode="markers",
marker=dict(size=9, color=color, symbol="circle"),
name=f"{label} (Delphi fit)",
legendgroup=label,
visible=is_default,
showlegend=False,
customdata=np.column_stack([train_lp_vals, np.array(train_lp_models)]),
hovertemplate=(
f"{label} (fit)<br>"
"%{customdata[1]}<br>"
"log p(correct)=%{customdata[0]:.3f}<br>"
"accuracy=%{y:.3f}<extra></extra>"
),
),
int(shot),
row=1,
col=2,
)
if len(test_sub):
test_sub_lp = test_sub.sort_values("choice_logprob")
test_lp_vals = test_sub_lp["choice_logprob"].to_numpy(float)
test_lp_models = [
_delphi_model_label(b) for b in test_sub_lp["budget"].to_numpy(float)
]
_add(
go.Scatter(
x=-test_lp_vals,
y=test_sub_lp["accuracy"],
mode="markers",
marker=dict(size=11, color=color, symbol="circle"),
name=f"{label} (Delphi held-out)",
legendgroup=label,
visible=is_default,
showlegend=False,
customdata=np.column_stack([test_lp_vals, np.array(test_lp_models)]),
hovertemplate=(
f"{label} (held-out)<br>"
"%{customdata[1]}<br>"
"log p(correct)=%{customdata[0]:.3f}<br>"
"accuracy=%{y:.3f}<extra></extra>"
),
),
int(shot),
row=1,
col=2,
)
# 1e25 two-step forecast. Step 1: panel-1 power law extrapolates
# -logprob to C=1e25. Step 2: panel-2 Richards sigmoid maps the
# predicted logprob to accuracy. Shown as an open diamond on each
# panel so the reader can tell it apart from the held-out runs.
if d["ok"] and mmlu_richards_fit is not None:
C_forecast = 1e25
pred_neg_lp = float(
d["L_inf"] + d["A"] * (C_forecast ** (-d["alpha"]))
)
pred_logprob = -pred_neg_lp
k_m_f, x0_m_f, v_m_f, C_m_f = mmlu_richards_fit
pred_acc = float(
MMLU_RANDOM_ACC + (C_m_f - MMLU_RANDOM_ACC) / (
1.0 + np.exp(-k_m_f * (pred_logprob - x0_m_f))
) ** v_m_f
)
forecast_hover = (
f"{label} 1e25 forecast<br>C=1e25 FLOPs<br>"
f"log p(correct)={pred_logprob:.3f}<br>"
f"predicted acc={pred_acc:.3f}<extra></extra>"
)
_add(
go.Scatter(
x=[C_forecast],
y=[pred_neg_lp],
mode="markers",
marker=dict(
size=14, color=color, symbol="x-open",
line=dict(width=2, color=color),
),
name=f"{label} (1e25 forecast)",
legendgroup=label,
visible=is_default,
showlegend=False,
hovertemplate=forecast_hover,
),
int(shot),
row=1,
col=1,
)
_add(
go.Scatter(
x=[pred_neg_lp],
y=[pred_acc],
mode="markers+text",
marker=dict(
size=14, color=color, symbol="x-open",
line=dict(width=2, color=color),
),
text=[f"1e25 forecast: {pred_acc:.1%}"],
textposition="top left",
textfont=dict(size=11, color=BRAND_BLACK),
name=f"{label} (1e25 forecast)",
legendgroup=label,
visible=is_default,
showlegend=False,
hovertemplate=forecast_hover,
),
int(shot),
row=1,
col=2,
)
# Downstream-eval views (HumanEval, GSM8K) share the same two-panel
# shape: panel 1 is bits/byte vs compute with an asymptote fit pinned
# at a hand-picked L∞ (no external pool yet at these task scales),
# panel 2 is the generation accuracy (pass@1 / exact-match) plotted
# against bits/byte as a raw scatter (no sigmoid fit — the emergence
# shape is the thing the reader should see). The MMLU view still
# has its own bespoke logic above because its panel 2 sigmoid is fit
# on a large external pool, which is qualitatively different.
def _add_downstream_view(spec: dict) -> dict:
ds_df = spec["df"]
view_id = spec["view_id"]
color = VIEW_COLORS[view_id]
label = VIEW_LABELS[view_id]
bpb_col = spec["bpb_col"]
acc_col = spec["acc_col"]
acc_hover = spec["acc_hover"]
bpb_title = spec["bpb_title"]
acc_title = spec["acc_title"]
linf_pinned = spec["linf_pinned"]
ds_external_df = spec.get("external_df")
tr = ds_df[ds_df["budget"] < held_lo].sort_values("budget")
te = ds_df[ds_df["budget"] >= held_lo].sort_values("budget")
# Panel 1 — asymptote fit on train rows, scatter for train + test,
# leader lines + per-point percent-error callouts on the held-out
# rows. L∞ picking, in order of preference:
# 1. If we have both an external pool AND the hard-metric
# problem count (n_problems), fit a sigmoid on (bpb, acc)
# across Delphi train + pool with the accuracy ceiling
# pinned at 1, and invert at acc = 1 − 1/n_problems to get
# the bpb at which the fitted emergence curve saturates to
# within per-problem resolution. That's our asymptote
# estimate — principled via the hard-metric ceiling.
# 2. If we only have a pool, pin L∞ to the best observed bpb
# (empirical ceiling — asymptote can't exceed an already-
# achieved model).
# 3. Otherwise fall back to the hand-picked linf_pinned.
# Same pipeline as MMLU's panel 1 (`_fit_richards` +
# `_invert_richards` + `_fit_power_at_linf`) — each view just
# supplies its own floor, x-axis transform, and n_problems.
C_tr = tr["budget"].to_numpy(float)
y_tr = tr[bpb_col].to_numpy(float)
n_problems = spec.get("n_problems")
linf_sigmoid_fit: tuple[float, float, float] | None = None
linf_sigmoid_estimate: float | None = None
if ds_external_df is not None and len(ds_external_df):
sig_bpb = list(y_tr)
sig_acc = list(tr[acc_col].to_numpy(float))
sig_bpb.extend(ds_external_df[bpb_col].to_numpy(float))
sig_acc.extend(ds_external_df[acc_col].to_numpy(float))
sig_bpb_arr = np.asarray(sig_bpb, float)
sig_acc_arr = np.asarray(sig_acc, float)
mask = np.isfinite(sig_bpb_arr) & np.isfinite(sig_acc_arr)
if n_problems and mask.sum() >= 4:
linf_sigmoid_fit = _fit_richards(
-sig_bpb_arr[mask], sig_acc_arr[mask], floor=0.0, ceiling=None
)
pool_min_bpb = float(np.min(sig_bpb_arr[mask])) if mask.any() else float("nan")
if linf_sigmoid_fit is not None and np.isfinite(pool_min_bpb):
k_s, x0_s, v_s, C_s = linf_sigmoid_fit
# Per-problem resolution of the learned ceiling — i.e.
# "within 1/n_problems of where the fitted curve
# saturates." Using fitted C rather than 1.0 matters
# because the pool may cap below 100% (HumanEval's
# Qwen3-14B is at 52% pass@1, so C < 1 is typical).
threshold = C_s - 1.0 / float(n_problems)
x_at = _invert_richards(
k_s, x0_s, v_s, target=threshold, floor=0.0, ceiling=C_s
)
# x = -bpb in the sigmoid fit, so bpb_saturation = -x.
# Valid estimate only if it lands in (0, pool_min): the
# asymptote is physically bounded below by 0 and above
# by any model we've actually observed achieving that
# bpb. If the sigmoid extrapolates to an unphysical
# point (typical when the pool's accuracy max is far
# below the near-ceiling threshold), fall back to
# pool_min — the empirical ceiling — rather than
# collapsing to 0.
bpb_sat = -x_at
if 0.0 < bpb_sat < pool_min_bpb:
linf_sigmoid_estimate = float(bpb_sat)
else:
linf_sigmoid_estimate = pool_min_bpb
linf_use = linf_sigmoid_estimate
else:
linf_use = (
pool_min_bpb if np.isfinite(pool_min_bpb) else linf_pinned
)
else:
linf_use = linf_pinned
fit = _fit_power_at_linf(C_tr, y_tr, linf_use) or dict(_FAIL)
if fit["ok"]:
fit_y = fit["L_inf"] + fit["A"] * (c_grid ** (-fit["alpha"]))
_add(
go.Scatter(
x=c_grid,
y=fit_y,
mode="lines",
line=dict(color=color, width=1.5, dash="dash"),
name=(
f"{label} asymptote "
f"(α={fit['alpha']:.3f}, L∞={fit['L_inf']:.3f})"
),
legendgroup=label,
visible=False,
showlegend=False,
hovertemplate=(
f"{label} fit<br>C=%{{x:.3e}} FLOPs<br>"
"bits/byte=%{y:.3f}<extra></extra>"
),
),
view_id,
row=1,
col=1,
)
tr_models = [_delphi_model_label(b) for b in tr["budget"].to_numpy(float)]
_add(
go.Scatter(
x=tr["budget"],
y=y_tr,
mode="markers",
marker=dict(size=9, color=color, symbol="circle"),
name=f"{label} (Delphi fit)",
legendgroup=label,
visible=False,
customdata=np.array(tr_models).reshape(-1, 1),
hovertemplate=(
f"{label} (fit)<br>"
"%{customdata[0]}<br>"
"C=%{x:.3e} FLOPs<br>"
"bits/byte=%{y:.3f}<extra></extra>"
),
),
view_id,
row=1,
col=1,
)
if len(te):
C_te = te["budget"].to_numpy(float)
y_te = te[bpb_col].to_numpy(float)
te_models = [_delphi_model_label(b) for b in C_te]
if fit["ok"]:
pred = fit["L_inf"] + fit["A"] * (C_te ** (-fit["alpha"]))
pct_err = (y_te - pred) / pred * 100.0
err_labels = [fmt_pct_err(float(e)) for e in pct_err]
else:
err_labels = [""] * len(te)
_add(
go.Scatter(
x=C_te,
y=y_te,
mode="markers",
marker=dict(size=11, color=color, symbol="circle"),
name=f"{label} (Delphi held-out)",
legendgroup=label,
visible=False,
showlegend=False,
customdata=np.array(te_models).reshape(-1, 1),
hovertemplate=(
f"{label} (held-out)<br>"
"%{customdata[0]}<br>"
"C=%{x:.3e} FLOPs<br>"
"bits/byte=%{y:.3f}<extra></extra>"
),
),
view_id,
row=1,
col=1,
)
# Leader lines + err labels, same offsets as MMLU panel 1's
# callouts (both axes are log) so the arrows read consistently
# between dropdown views.
x_off, y_off = 1.7, 1.08
seg_x: list = []
seg_y: list = []
seg_text: list = []
for C_i, y_i, lbl in zip(C_te, y_te, err_labels):
seg_x += [float(C_i), float(C_i) * x_off, None]
seg_y += [float(y_i), float(y_i) * y_off, None]
seg_text += ["", lbl, ""]
_add(
go.Scatter(
x=seg_x,
y=seg_y,
mode="lines+text",
line=dict(color="rgba(80,80,80,0.55)", width=1),
text=seg_text,
textposition="middle right",
textfont=dict(size=11, color=BRAND_BLACK),
showlegend=False,
legendgroup=label,
visible=False,
hoverinfo="skip",
),
view_id,
row=1,
col=1,
)
# Panel 1 y-axis envelope (log scale, reversed so better/lower is
# toward the bottom — same orientation as the original HumanEval
# code). Panel 1's range only covers Delphi points; panel 2's
# range extends to cover external pool points too. Shared
# tickval candidate list; each axis filters to what lands inside
# its own range. The top of the axis is extended to the fitted
# L∞ so the asymptote tail at C→∞ stays visible — mirrors how
# the MMLU panel pins its top to 0.45 past the best observed
# -logprob. Log-space padding (0.06) matches MMLU's panel for
# visual consistency between the views.
bpb_all = ds_df[bpb_col].to_numpy(float)
bpb_min = float(bpb_all.min())
bpb_max = float(bpb_all.max())
bpb_low = min(bpb_min, float(linf_use)) if linf_use is not None else bpb_min
log_pad = 0.06
axis_low = bpb_low / (10 ** log_pad)
axis_high = bpb_max * (10 ** log_pad)
bpb_candidates = [0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 8.0]
panel1_tickvals = [
v for v in bpb_candidates if v >= axis_low and v <= axis_high
] or bpb_candidates
panel1_yaxis = dict(
title=dict(text=bpb_title),
range=[
float(np.log10(bpb_max)) + log_pad,
float(np.log10(bpb_low)) - log_pad,
],
tickvals=panel1_tickvals,
ticktext=[f"{v:g}" for v in panel1_tickvals],
)
# Panel 2's bpb envelope needs to cover the 1e25 forecast too,
# otherwise the forecast marker lands outside the x-axis range.
# The forecast sits between linf_use (C→∞) and the best observed
# Delphi/pool bpb, so using linf_use as a lower floor is safe.
panel2_bpb_min = min(bpb_min, float(linf_use)) if linf_use is not None else bpb_min
panel2_bpb_max = bpb_max
if ds_external_df is not None and len(ds_external_df):
ext_bpb_vals = ds_external_df[bpb_col].to_numpy(float)
panel2_bpb_min = min(panel2_bpb_min, float(ext_bpb_vals.min()))
panel2_bpb_max = max(panel2_bpb_max, float(ext_bpb_vals.max()))
panel2_tickvals = [
v for v in bpb_candidates
if v >= panel2_bpb_min * 0.85 and v <= panel2_bpb_max * 1.15
] or bpb_candidates
# Panel 2 — accuracy vs bits/byte. When a sigmoid fit on the
# (train + external) pool is available (from the L∞ estimation
# step above), draw the Richards curve as a dashed line so the
# reader can see the shape that produced the asymptote; then
# overlay train, held-out, and grey external points.
panel2_xaxis = None
panel2_yaxis = None
has_acc = False
acc_all = ds_df[acc_col].to_numpy(float)
acc_valid = list(acc_all[np.isfinite(acc_all)])
if ds_external_df is not None and len(ds_external_df):
ext_acc_vals = ds_external_df[acc_col].to_numpy(float)
acc_valid.extend(list(ext_acc_vals[np.isfinite(ext_acc_vals)]))
if len(acc_valid) > 0:
has_acc = True
tr_p = tr[tr[acc_col].notna()].sort_values(bpb_col)
te_p = te[te[acc_col].notna()].sort_values(bpb_col)
if linf_sigmoid_fit is not None:
k_s, x0_s, v_s, C_s = linf_sigmoid_fit
# Span a bit past the panel-2 x-axis envelope so the
# curve clearly approaches both asymptotes (0 on the
# left tail, learned ceiling C_s on the right).
curve_lo = min(panel2_bpb_min * 0.6, float(linf_use) * 0.5)
curve_hi = panel2_bpb_max * 1.30
curve_bpb = np.logspace(
np.log10(max(curve_lo, 1e-4)),
np.log10(curve_hi),
400,
)
curve_acc = C_s / (
1.0 + np.exp(-k_s * (-curve_bpb - x0_s))
) ** v_s
_add(
go.Scatter(
x=curve_bpb,
y=curve_acc,
mode="lines",
line=dict(color=color, width=1.5, dash="dash"),
name=(
f"{label} Richards fit "
f"(k={k_s:.2f}, v={v_s:.2f}, C={C_s:.2f}, L∞={linf_use:.3f})"
),
legendgroup=label,
visible=False,
showlegend=False,
customdata=curve_bpb.reshape(-1, 1),
hovertemplate=(
f"{label} Richards fit<br>"
"bits/byte=%{customdata[0]:.3f}<br>"
f"predicted {acc_hover}=%{{y:.3f}}<extra></extra>"
),
),
view_id,
row=1,
col=2,
)
if ds_external_df is not None and len(ds_external_df):
ext_p = ds_external_df[
ds_external_df[bpb_col].notna() & ds_external_df[acc_col].notna()
].sort_values(bpb_col)
if len(ext_p):
ext_bpb = ext_p[bpb_col].to_numpy(float)
ext_acc = ext_p[acc_col].to_numpy(float)
ext_names = ext_p["display_name"].astype(str).tolist()
_add(
go.Scatter(
x=ext_bpb,
y=ext_acc,
mode="markers",
marker=dict(
size=7,
color="#7a736b",
symbol="circle",
opacity=0.55,
line=dict(width=0),
),
name="External Models Used for Regression",
legendgroup="external",
visible=False,
customdata=np.column_stack([ext_bpb, np.array(ext_names)]),
hovertemplate=(
"%{customdata[1]}<br>"
"bits/byte=%{customdata[0]:.3f}<br>"
f"{acc_hover}=%{{y:.3f}}<extra></extra>"
),
),
view_id,
row=1,
col=2,
)
if len(tr_p):
tr_bpb = tr_p[bpb_col].to_numpy(float)
tr_acc = tr_p[acc_col].to_numpy(float)
tr_models_p2 = [_delphi_model_label(b) for b in tr_p["budget"].to_numpy(float)]
_add(
go.Scatter(
x=tr_bpb,
y=tr_acc,
mode="markers",
marker=dict(size=9, color=color, symbol="circle"),
name=f"{label} (Delphi fit)",
legendgroup=label,
visible=False,
showlegend=False,
customdata=np.column_stack([tr_bpb, np.array(tr_models_p2)]),
hovertemplate=(
f"{label} (fit)<br>"
"%{customdata[1]}<br>"
"bits/byte=%{customdata[0]:.3f}<br>"
f"{acc_hover}=%{{y:.3f}}<extra></extra>"
),
),
view_id,
row=1,
col=2,
)
if len(te_p):
te_bpb = te_p[bpb_col].to_numpy(float)
te_acc = te_p[acc_col].to_numpy(float)
te_models_p2 = [_delphi_model_label(b) for b in te_p["budget"].to_numpy(float)]
_add(
go.Scatter(
x=te_bpb,
y=te_acc,
mode="markers",
marker=dict(size=11, color=color, symbol="circle"),
name=f"{label} (Delphi held-out)",
legendgroup=label,
visible=False,
showlegend=False,
customdata=np.column_stack([te_bpb, np.array(te_models_p2)]),
hovertemplate=(
f"{label} (held-out)<br>"
"%{customdata[1]}<br>"
"bits/byte=%{customdata[0]:.3f}<br>"
f"{acc_hover}=%{{y:.3f}}<extra></extra>"
),
),
view_id,
row=1,
col=2,
)
# 1e25 two-step forecast — same style as the MMLU views.
# Step 1: panel-1 power law extrapolates bits/byte to C=1e25.
# Step 2: panel-2 Richards sigmoid maps the predicted bpb to
# the hard metric (pass@1 / exact-match).
if fit["ok"] and linf_sigmoid_fit is not None:
C_forecast = 1e25
pred_bpb = float(
fit["L_inf"] + fit["A"] * (C_forecast ** (-fit["alpha"]))
)
k_s_f, x0_s_f, v_s_f, C_s_f = linf_sigmoid_fit
pred_acc = float(
C_s_f / (
1.0 + np.exp(-k_s_f * (-pred_bpb - x0_s_f))
) ** v_s_f
)
forecast_hover = (
f"{label} 1e25 forecast<br>C=1e25 FLOPs<br>"
f"bits/byte={pred_bpb:.3f}<br>"
f"predicted {acc_hover}={pred_acc:.3f}<extra></extra>"
)
_add(
go.Scatter(
x=[C_forecast],
y=[pred_bpb],
mode="markers",
marker=dict(
size=14, color=color, symbol="x-open",
line=dict(width=2, color=color),
),
name=f"{label} (1e25 forecast)",
legendgroup=label,
visible=False,
showlegend=False,
hovertemplate=forecast_hover,
),
view_id,
row=1,
col=1,
)
_add(
go.Scatter(
x=[pred_bpb],
y=[pred_acc],
mode="markers+text",
marker=dict(
size=14, color=color, symbol="x-open",
line=dict(width=2, color=color),
),
text=[f"1e25 forecast: {pred_acc:.1%}"],
textposition="top left",
textfont=dict(size=11, color=BRAND_BLACK),
name=f"{label} (1e25 forecast)",
legendgroup=label,
visible=False,
showlegend=False,
hovertemplate=forecast_hover,
),
view_id,
row=1,
col=2,
)
# Panel 2 mirrors panel 1's bits/byte axis orientation on x,
# but its envelope extends to include external pool points
# (which typically sit past Delphi's best on the x axis and
# push accuracy higher on y).
panel2_xaxis = dict(
visible=True,
type="log",
title=dict(text=bpb_title),
range=[
float(np.log10(panel2_bpb_max * 1.15)),
float(np.log10(panel2_bpb_min * 0.85)),
],
tickvals=panel2_tickvals,
ticktext=[f"{v:g}" for v in panel2_tickvals],
)
# Y-axis floors at the hard-metric ceiling (1.0 for pass@1
# / exact-match) so the full emergence arc is visible —
# otherwise the axis would clip at ~75% and hide the
# headroom the sigmoid predicts. Matches MMLU's y-axis
# convention (floors at 1.0 even when observed max is ~55%).
acc_ceiling = float(spec.get("acc_ceiling", 1.0))
acc_top = float(max(max(acc_valid) * 1.08, acc_ceiling))
panel2_yaxis = dict(
visible=True,
type="linear",
title=dict(text=acc_title),
range=[-0.01, acc_top],
tickformat=".0%",
)
return {
"view_id": view_id,
"panel1_yaxis": panel1_yaxis,
"panel2_xaxis": panel2_xaxis,
"panel2_yaxis": panel2_yaxis,
"has_acc": has_acc,
"panel1_annot": spec["panel1_annot"],
"panel2_annot": (
spec["panel2_annot_with_acc"]
if has_acc
else spec["panel2_annot_without_acc"]
),
}
downstream_view_states: list[dict] = []
if humaneval_df is not None and len(humaneval_df):
downstream_view_states.append(_add_downstream_view({
"df": humaneval_df,
"external_df": humaneval_external_df,
"view_id": HUMANEVAL_VIEW,
"bpb_col": "macro_bpb",
"acc_col": "pass_at_1",
"acc_hover": "pass@1",
"bpb_title": "HumanEval bits/byte",
"acc_title": "HumanEval pass@1",
# HumanEval test set size — used to compute the "per-problem
# resolution" threshold (1 − 1/n_problems) for inverting the
# sigmoid fit on (bpb, pass@1) when estimating L∞ from the
# hard-metric ceiling (pass@1 ≤ 1).
"n_problems": 164,
# Fallback when no external pool is available. With the
# pool + n_problems present, L∞ is estimated from the
# sigmoid-ceiling inversion instead.
"linf_pinned": 0.2,
"panel1_annot": "Bits/byte improves smoothly with compute",
"panel2_annot_with_acc": "Pass@1 emerges at larger scales",
"panel2_annot_without_acc": "Pass@1 emergence (pending)",
}))
if gsm8k_df is not None and len(gsm8k_df):
downstream_view_states.append(_add_downstream_view({
"df": gsm8k_df,
"external_df": gsm8k_external_df,
"view_id": GSM8K_VIEW,
"bpb_col": "macro_bpb",
"acc_col": "exact_match_flex",
"acc_hover": "exact-match",
"bpb_title": "GSM8K bits/byte",
"acc_title": "GSM8K exact-match",
# GSM8K test set size (1319 problems) — controls the
# sigmoid-ceiling inversion threshold used to estimate L∞
# when the external pool is present.
"n_problems": 1319,
# Fallback L∞ for when the external pool cache is absent.
"linf_pinned": 0.2,
"panel1_annot": "Bits/byte improves smoothly with compute",
"panel2_annot_with_acc": "Exact-match emerges at larger scales",
"panel2_annot_without_acc": "Exact-match emergence (pending)",
}))
# Dropdown that shows one view at a time (MMLU 0-shot / 5-shot /
# HumanEval 10-shot). Trace visibility is the core flip; the
# HumanEval view additionally retitles panel 1, swaps its y-axis to
# bits/byte, hides panel 2's axes, and hides the MMLU-specific
# random-chance callout (annotations[4], shapes[1] — indices match
# the creation order in this function).
total_traces = len(fig.data)
# Panel-1 y-axis settings per view. MMLU range and ticks match the
# defaults applied via update_yaxes below; the button re-states them
# so flipping back from HumanEval restores the full configuration.
# Top extends to the fitted L∞ so the asymptote tail at C→∞ stays
# visible — same rule as the downstream (HumanEval / GSM8K) views.
mmlu_neg_lp_max = -float(df["choice_logprob"].min())
mmlu_neg_lp_min = min(-float(df["choice_logprob"].max()), float(linf_neg_lp))
mmlu_logprob_tickvals = [3.0, 2.0, 1.5, 1.0, 0.7, 0.5, 0.3]
mmlu_panel1_yaxis = dict(
title=dict(text="MMLU choice log-prob"),
range=[
float(np.log10(mmlu_neg_lp_max) + 0.06),
float(np.log10(mmlu_neg_lp_min) - 0.06),
],
tickvals=mmlu_logprob_tickvals,
ticktext=[f"−{v:g}" for v in mmlu_logprob_tickvals],
)
# Downstream-eval axis state is computed inside _add_downstream_view
# and keyed on view_id so the dropdown button loop can look up the
# right panel-1 y-axis / panel-2 x+y-axes per view.
downstream_state_by_view = {s["view_id"]: s for s in downstream_view_states}
# Panel-2 axis state on the MMLU views. Any downstream view mutates
# panel 2 from log(-logprob) + linear(accuracy) to log(bits/byte) +
# linear(accuracy), so the MMLU buttons have to restore both axes
# when the user flips back. Values here are recomputed to match the
# fig.update_xaxes / update_yaxes calls applied below.
# Extend the low (better) end to the fitted L∞ so the 1e25 forecast
# marker (which sits between the pool's best and the asymptote) stays
# within the axis range.
panel2_neg_lp_min_btn = min(-max(lp_vals), float(linf_neg_lp))
panel2_neg_lp_max_btn = -min(lp_vals)
mmlu_panel2_xaxis = dict(
visible=True,
type="log",
title=dict(text="MMLU choice log-prob"),
range=[
float(np.log10(panel2_neg_lp_max_btn) + 0.06),
float(np.log10(panel2_neg_lp_min_btn) - 0.06),
],
tickvals=mmlu_logprob_tickvals,
ticktext=[f"−{v:g}" for v in mmlu_logprob_tickvals],
)
acc_vals = df["accuracy"].to_numpy(float)
mmlu_panel2_yaxis = dict(
visible=True,
type="linear",
title=dict(text="MMLU accuracy"),
range=[0.2, float(max(acc_vals.max() * 1.08, 1.0))],
tickformat=".0%",
)
dropdown_views = list(shots) + [s["view_id"] for s in downstream_view_states]
buttons = []
for view in dropdown_views:
vis = [False] * total_traces
for idx in view_traces[int(view)]:
vis[idx] = True
ds_state = downstream_state_by_view.get(int(view))
if ds_state is not None:
layout_update = {
"annotations[0].text": ds_state["panel1_annot"],
"annotations[1].text": ds_state["panel2_annot"],
"annotations[4].visible": False,
"shapes[1].visible": False,
"yaxis.title.text": ds_state["panel1_yaxis"]["title"]["text"],
"yaxis.range": ds_state["panel1_yaxis"]["range"],
"yaxis.tickvals": ds_state["panel1_yaxis"]["tickvals"],
"yaxis.ticktext": ds_state["panel1_yaxis"]["ticktext"],
}
if ds_state["has_acc"]:
p2x = ds_state["panel2_xaxis"]
p2y = ds_state["panel2_yaxis"]
layout_update.update({
"xaxis2.visible": True,
"xaxis2.type": p2x["type"],
"xaxis2.title.text": p2x["title"]["text"],
"xaxis2.range": p2x["range"],
"xaxis2.tickvals": p2x["tickvals"],
"xaxis2.ticktext": p2x["ticktext"],
"yaxis2.visible": True,
"yaxis2.type": p2y["type"],
"yaxis2.title.text": p2y["title"]["text"],
"yaxis2.range": p2y["range"],
"yaxis2.tickformat": p2y["tickformat"],
})
else:
layout_update["xaxis2.visible"] = False
layout_update["yaxis2.visible"] = False
else:
layout_update = {
"annotations[0].text": "Log-prob improves smoothly with compute",
"annotations[1].text": "Accuracy emerges at larger scales",
"annotations[4].visible": True,
"shapes[1].visible": True,
"yaxis.title.text": mmlu_panel1_yaxis["title"]["text"],
"yaxis.range": mmlu_panel1_yaxis["range"],
"yaxis.tickvals": mmlu_panel1_yaxis["tickvals"],
"yaxis.ticktext": mmlu_panel1_yaxis["ticktext"],
"xaxis2.visible": True,
"xaxis2.type": mmlu_panel2_xaxis["type"],
"xaxis2.title.text": mmlu_panel2_xaxis["title"]["text"],
"xaxis2.range": mmlu_panel2_xaxis["range"],
"xaxis2.tickvals": mmlu_panel2_xaxis["tickvals"],
"xaxis2.ticktext": mmlu_panel2_xaxis["ticktext"],
"yaxis2.visible": True,
"yaxis2.type": mmlu_panel2_yaxis["type"],
"yaxis2.title.text": mmlu_panel2_yaxis["title"]["text"],
"yaxis2.range": mmlu_panel2_yaxis["range"],
"yaxis2.tickformat": mmlu_panel2_yaxis["tickformat"],
}
buttons.append(
dict(
label=VIEW_LABELS[int(view)],
method="update",
args=[{"visible": vis}, layout_update],
)
)
default_button_idx = next(
(i for i, s in enumerate(dropdown_views) if int(s) == default_shot), 0
)
fig.update_layout(
updatemenus=[
dict(
buttons=buttons,
# Expand upward so the options don't run off the bottom
# of the SVG and trigger plotly's auto-scroll fallback.
direction="up",
showactive=True,
active=default_button_idx,
x=-0.16,
xanchor="left",
y=-0.16,
yanchor="top",
bgcolor="rgba(196,185,174,0.9)",
borderwidth=0,
pad=dict(l=6, r=6, t=4, b=4),
)
]
)
# Random-chance accuracy floor on panel 2.
fig.add_hline(
y=MMLU_RANDOM_ACC,
line=dict(color="#1F1E1B", width=1, dash="dot"),
opacity=0.45,
row=1,
col=2,
)
fig.add_annotation(
xref="x2 domain",
yref="y2",
x=0.98,
y=MMLU_RANDOM_ACC,
text="random chance",
showarrow=False,
xanchor="right",
yanchor="top",
font=dict(size=11, color="#1F1E1B"),
)
fig.update_xaxes(
type="log",
title_text="Compute (FLOPs)",
range=[np.log10(c_min), np.log10(c_max)],
row=1,
col=1,
)
# Both log-prob axes plot -logprob (positive) on a log scale, with
# tick labels negated back so the axis still reads as the original
# log p(correct) values. Reversed orientation keeps "better"
# (closer to 0) at the high end of the axis.
logprob_tickvals = [3.0, 2.0, 1.5, 1.0, 0.7, 0.5, 0.3]
logprob_ticktext = [f"−{v:g}" for v in logprob_tickvals]
# Explicit log10-space range for both log-prob axes — plotly's
# autorange interaction with reversed log axes can drift, so we
# clamp to the data envelope. The high end of the original log-prob
# value (closest to 0, ie. small -logprob) sits at the high end of
# the displayed axis; the low end (very negative log-prob, ie.
# large -logprob) sits at the bottom.
panel1_neg_lp_min = -float(df["choice_logprob"].max()) # best Delphi
panel1_neg_lp_max = -float(df["choice_logprob"].min()) # worst Delphi
# Extend the top of the axis to the fitted L∞ so the asymptote tail
# at C→∞ stays visible — same rule as the downstream views.
panel1_neg_lp_min = min(panel1_neg_lp_min, float(linf_neg_lp))
# Mild log-space padding so markers don't sit flush against the
# axis edges and the leader-line+label callouts have headroom.
log_pad = 0.06
fig.update_yaxes(
type="log",
title_text="MMLU choice log-prob",
range=[
np.log10(panel1_neg_lp_max) + log_pad,
np.log10(panel1_neg_lp_min) - log_pad,
],
tickvals=logprob_tickvals,
ticktext=logprob_ticktext,
row=1,
col=1,
)
# Extend the low (better) end to the fitted L∞ so the 1e25 forecast
# marker stays within the axis range.
panel2_neg_lp_min = min(-max(lp_vals), float(linf_neg_lp)) # best in pool / L∞
panel2_neg_lp_max = -min(lp_vals) # worst in pool
fig.update_xaxes(
type="log",
title_text="MMLU choice log-prob",
range=[
np.log10(panel2_neg_lp_max) + log_pad,
np.log10(panel2_neg_lp_min) - log_pad,
],
tickvals=logprob_tickvals,
ticktext=logprob_ticktext,
row=1,
col=2,
)
fig.update_yaxes(
title_text="MMLU accuracy",
tickformat=".0%",
range=[0.2, float(max(acc_vals.max() * 1.08, 1.0))],
row=1,
col=2,
)
fig.update_layout(
legend=dict(
orientation="h",
x=0.62,
xanchor="center",
y=-0.18,
yanchor="top",
bgcolor="rgba(0,0,0,0)",
borderwidth=0,
tracegroupgap=12,
),
margin=dict(t=40, r=30, b=60, l=60),
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
)
return fig
def to_mobile_mmlu_stack(fig: go.Figure) -> go.Figure:
    """Return a single-column copy of the 1×2 MMLU figure for narrow screens.

    Both x-axes are widened to the full paper width. Panel 1 moves to the
    upper band (y in [0.60, 0.98]) and panel 2 to the lower band
    (y in [0.0, 0.38]). The first two paper-referenced annotations (the
    panel titles) are recentred over their panels. The legend is centred
    below the plot, and every dropdown menu is moved beneath it. An
    explicit 760px height is set, and *fig* itself is left unchanged.
    """
    stacked = go.Figure(fig)
    centred_legend = dict(
        orientation="h",
        x=0.5,
        xanchor="center",
        y=-0.08,
        yanchor="top",
        bgcolor="rgba(0,0,0,0)",
        borderwidth=0,
    )
    stacked.update_layout(
        xaxis=dict(domain=[0.0, 1.0]),
        xaxis2=dict(domain=[0.0, 1.0]),
        yaxis=dict(domain=[0.60, 0.98]),
        yaxis2=dict(domain=[0.0, 0.38]),
        height=760,
        # A narrow viewport cannot fit the legend and the dropdown side by
        # side, so the taller bottom margin stacks them on top of each other.
        margin=dict(t=32, r=24, b=110, l=56),
        legend=centred_legend,
    )

    # Only the first two paper-anchored annotations are panel titles. Any
    # later ones are left where they are.
    annotations = list(stacked.layout.annotations or [])
    title_heights = iter((0.99, 0.43))
    for ann in annotations:
        if ann.xref != "paper" or ann.yref != "paper":
            continue
        height = next(title_heights, None)
        if height is None:
            break
        ann.update(x=0.5, y=height, xanchor="center", yanchor="middle")
    stacked.layout.annotations = tuple(annotations)

    # On desktop the dropdown sits beside the legend. At mobile widths the
    # wrapped legend rows would run into it, so it goes underneath instead.
    menus = list(stacked.layout.updatemenus or [])
    for menu in menus:
        menu.update(x=0.5, xanchor="center", y=-0.22, yanchor="top")
    stacked.layout.updatemenus = tuple(menus)
    return stacked
# ---------------------------------------------------------------------------
# Export
# ---------------------------------------------------------------------------
def strip_for_blog(fig: go.Figure, *, keep_height: bool = False) -> dict:
    """Serialize *fig* to a plain dict, minus the layout keys the blog owns.

    The blog shortcode applies the site-wide theme and sizes figures
    responsively. This function therefore removes ``layout.template``,
    ``layout.width`` and ``layout.font``. It also removes
    ``layout.height`` unless *keep_height* is set.

    Every other layout key is kept, including ``margin``,
    ``plot_bgcolor`` and ``paper_bgcolor``. Figures can therefore still
    override those per plot. For example, the builders set transparent
    backgrounds so plots sit directly on the page without a cream frame.

    Args:
        fig: Figure to export.
        keep_height: Keep ``layout.height``. Mobile variants pass True so
            their tall stacked-layout height survives the strip.

    Returns:
        The figure as JSON-compatible Python objects. It is round-tripped
        through plotly's JSON encoder, so it can be passed straight to
        ``json.dumps``.
    """
    d = json.loads(pio.to_json(fig))
    layout = d.get("layout") or {}
    drop = {"template", "width", "font"}
    if not keep_height:
        drop.add("height")
    for key in drop:
        layout.pop(key, None)
    d["layout"] = layout
    return d
def write_figure(fig: go.Figure, name: str, *, keep_height: bool = False) -> Path:
    """Write *fig* to ``OUT_DIR/<name>.json`` and return the written path.

    Before writing, the figure is passed through ``strip_for_blog``, with
    *keep_height* forwarded. The JSON is pretty-printed with a two-space
    indent. The output directory is created if it does not exist. The
    path printed to stdout is relative to the directory three levels above
    this script.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    target = OUT_DIR / f"{name}.json"
    payload = strip_for_blog(fig, keep_height=keep_height)
    target.write_text(json.dumps(payload, indent=2))
    report_root = Path(__file__).resolve().parent.parent.parent
    print(f"wrote {target.relative_to(report_root)}")
    return target
def to_mobile_stack(fig: go.Figure) -> go.Figure:
    """Return a copy of a 1×2 ladder figure reflowed into a 2×1 column.

    Both x-axes are widened to the full paper width. The y-axes are split
    into a top band ([0.60, 0.98]) and a bottom band ([0.0, 0.38]). The
    gap between the bands holds panel 1's x-axis ticks and title (upper
    half) and panel 2's title (lower half).

    The first two paper-referenced annotations (the panel titles) are
    recentred. The legend moves below panel 2 as a three-per-row grid,
    and an explicit 660px height is set. Write the result with
    ``write_figure(..., keep_height=True)`` so that height survives the
    blog strip. *fig* itself is not modified.
    """
    mobile = go.Figure(fig)
    bottom_legend = dict(
        # No legend title on mobile. Panel 2's x-axis already carries the
        # "Compute (FLOPs)" label, and a repeated title would collide with it.
        title=dict(text=""),
        orientation="h",
        yanchor="top",
        y=-0.12,
        xanchor="right",
        x=1.0,
        # A fraction of 1/3 should fit three entries per row, but plotly
        # reserves padding per entry, so 0.34 wraps at two. 0.30 reliably
        # fits three.
        entrywidth=0.30,
        entrywidthmode="fraction",
        tracegroupgap=4,
        bgcolor="rgba(0,0,0,0)",
        borderwidth=0,
    )
    mobile.update_layout(
        xaxis=dict(domain=[0.0, 1.0]),
        xaxis2=dict(domain=[0.0, 1.0]),
        yaxis=dict(domain=[0.60, 0.98]),
        yaxis2=dict(domain=[0.0, 0.38]),
        height=660,
        # The bottom margin sits under panel 2's x-axis label, where the
        # three-column legend goes.
        margin=dict(t=24, r=24, b=32, l=56),
        legend=bottom_legend,
    )

    # Panel 1's title goes in the thin strip above the top band. Panel 2's
    # title goes in the lower half of the inter-panel gap, clear of
    # panel 1's x-axis title.
    annotations = list(mobile.layout.annotations or [])
    paper_titles = (
        ann for ann in annotations if ann.xref == "paper" and ann.yref == "paper"
    )
    for ann, title_y in zip(paper_titles, (0.99, 0.43)):
        ann.update(x=0.5, y=title_y, xanchor="center", yanchor="middle")
    mobile.layout.annotations = tuple(annotations)
    return mobile
def to_mobile_3x2(fig: go.Figure) -> go.Figure:
    """Reshape a 2×3 hparam grid into a 3×2 mobile layout.

    The desktop grid has two rows of three panels. On mobile, each panel
    would shrink to about 90px wide, so the grid is reflowed into three
    rows of two. Reading order (row-major) is preserved: (x1, y1) stays
    top-left, (x2, y2) top-right, (x3, y3) middle-left, …, and (x6, y6)
    bottom-right.

    Beyond the domain reshuffle, this function also:
      - Titles the left-column y-axes (yaxis, yaxis3, yaxis5) "Paloma
        loss" and clears the title on the right column (yaxis2, yaxis4,
        yaxis6). yaxis4 was a left-column axis in the 2×3 layout and
        would otherwise keep its title.
      - Pins every x-axis tick label to a 45° angle.
      - Lines the three dropdowns up in one row across the top of the
        figure, resting on the plot area's top edge.
      - Drops every annotation positioned in paper coordinates. On
        desktop these label the dropdowns; presumably they are the only
        paper-referenced annotations, but confirm this.
      - Moves the legend to the lower-right corner (anchored at x=1.02,
        y=0.10) on a translucent background.
      - Sets an explicit 900px height. Write the result with
        ``write_figure(..., keep_height=True)`` so the height survives.

    *fig* itself is not modified.
    """
    mfig = go.Figure(fig)
    # Horizontal: two columns with a 0.10 paper gap between them.
    col_left = [0.0, 0.45]
    col_right = [0.55, 1.0]
    # Vertical: three rows with a 0.08 paper gap between rows for axis
    # ticks + axis titles.
    row_top = [0.72, 1.0]
    row_mid = [0.36, 0.64]
    row_bot = [0.0, 0.28]
    # Force a uniform diagonal tick angle on every x-axis. At mobile
    # panel widths Plotly auto-picks per-axis (some 45°, some 90°) and
    # the mismatch looks scruffy; pinning to 45° keeps them consistent.
    tick_ax = dict(tickangle=45)
    mfig.update_layout(
        xaxis=dict(domain=col_left, **tick_ax),
        xaxis2=dict(domain=col_right, **tick_ax),
        xaxis3=dict(domain=col_left, **tick_ax),
        xaxis4=dict(domain=col_right, **tick_ax),
        xaxis5=dict(domain=col_left, **tick_ax),
        xaxis6=dict(domain=col_right, **tick_ax),
        yaxis=dict(domain=row_top, title=dict(text="Paloma loss")),
        yaxis2=dict(domain=row_top, title=dict(text=None)),
        yaxis3=dict(domain=row_mid, title=dict(text="Paloma loss")),
        yaxis4=dict(domain=row_mid, title=dict(text=None)),
        yaxis5=dict(domain=row_bot, title=dict(text="Paloma loss")),
        yaxis6=dict(domain=row_bot, title=dict(text=None)),
        height=900,
        # Small top margin holds the dropdown row just above the plot.
        margin=dict(t=36, r=0, b=0, l=0),
        # The legend sits low on the right, nudged slightly past the right
        # edge (x=1.02 with a right anchor), so it lands in the bottom-right
        # cell of the 3×2 reflow rather than over the dropdown row.
        legend=dict(
            x=1.02,
            xanchor="right",
            y=0.10,
            yanchor="bottom",
            bgcolor="rgba(196,185,174,0.8)",
            borderwidth=0,
        ),
    )
    # Row the three dropdowns across the top of the figure. Their
    # selected-option text is already self-describing (B = 64, H = 512,
    # T = 2.5B), so the separate paper-ref label annotations that the
    # desktop layout puts over each dropdown are dropped on mobile —
    # they only crowd the top margin at this width.
    dropdown_xs = (0.0, 0.36, 0.70)
    menus = list(mfig.layout.updatemenus or [])
    for menu, x in zip(menus, dropdown_xs):
        # yanchor="bottom" with y=1.0 puts the dropdown's bottom edge
        # flush against the plot-area top, so the body sits entirely in
        # the top margin touching the chart.
        menu.update(x=x, xanchor="left", y=1.0, yanchor="bottom")
    mfig.layout.updatemenus = tuple(menus)
    anns = [
        a for a in (mfig.layout.annotations or [])
        if not (a.xref == "paper" and a.yref == "paper")
    ]
    mfig.layout.annotations = tuple(anns)
    return mfig
def _export_isoflop_figures(df: pd.DataFrame) -> None:
    """Fit per-bucket parabolas on *df* and write the core scaling figures.

    Writes ``isoflop-parabolas``, ``scaling-law-asymptote``,
    ``first-ladder`` / ``first-ladder-mobile``,
    ``delphi-ladder`` / ``delphi-ladder-mobile``, ``lucky-seeds`` and
    ``overtraining-forecast``, in that order.
    """
    fits, optima, gflop_groups, unique_quantizers = compute_optima(df)
    print(f"fit {len(fits)} parabolas; {len(optima)} optima rows")
    optimal_runs = df[df["run_type"] == "optimal"].copy()
    write_figure(
        build_isoflop_figure(df, fits, gflop_groups, unique_quantizers),
        "isoflop-parabolas",
    )
    write_figure(
        build_scaling_law_figure(optima, optimal_runs), "scaling-law-asymptote"
    )
    # The two ladder figures differ only in which optimizer label is passed
    # to build_ladder_figure; each also gets a stacked mobile variant
    # (written with keep_height=True).
    ladders = (
        (C_ADAMC_LABEL, "first-ladder"),
        (ADAMH_LABEL, "delphi-ladder"),
    )
    for label, name in ladders:
        ladder = build_ladder_figure(
            df, fits, gflop_groups, optima, optimal_runs, label
        )
        write_figure(ladder, name)
        write_figure(to_mobile_stack(ladder), f"{name}-mobile", keep_height=True)
    write_figure(build_lucky_seeds_figure(df, optima), "lucky-seeds")
    write_figure(
        build_overtraining_forecast_figure(df, fits, gflop_groups, optima),
        "overtraining-forecast",
    )


def _export_downstream_figures() -> None:
    """Write the downstream-eval emergence figure (section 4), desktop + mobile.

    Gated on the MMLU cache: when ``load_mmlu_runs`` returns None, no
    figure is built and the optional caches are not loaded. The non-Marin
    (external) pools and the HumanEval / GSM8K caches are optional and are
    passed to ``build_mmlu_emergence_figure`` as-is, possibly None.
    """
    mmlu_df = load_mmlu_runs(DEFAULT_MMLU_CACHE)
    if mmlu_df is None:
        # Previously the figure was built even without MMLU data, handing
        # None to build_mmlu_emergence_figure; skip instead.
        print("MMLU cache unavailable; skipping mmlu-emergence figures")
        return
    print(
        f"loaded {len(mmlu_df)} MMLU evals "
        f"({mmlu_df['shot'].nunique()} shot conditions × "
        f"{mmlu_df['budget'].nunique()} budgets)"
    )
    external_df = load_mmlu_external(DEFAULT_MMLU_EXTERNAL_CACHE)
    humaneval_df = load_humaneval_runs(DEFAULT_HUMANEVAL_CACHE)
    humaneval_external_df = load_humaneval_external(DEFAULT_HUMANEVAL_EXTERNAL_CACHE)
    gsm8k_df = load_gsm8k_runs(DEFAULT_GSM8K_CACHE)
    gsm8k_external_df = load_gsm8k_external(DEFAULT_GSM8K_EXTERNAL_CACHE)
    # Report each optional cache that actually loaded.
    optional = (
        (external_df, "external MMLU runs for panel 2"),
        (humaneval_df, "HumanEval 10-shot logprob evals"),
        (humaneval_external_df, "external HumanEval runs for panel 2"),
        (gsm8k_df, "GSM8K 5-shot logprob+gen evals"),
        (gsm8k_external_df, "external GSM8K runs for panel 2"),
    )
    for frame, what in optional:
        if frame is not None:
            print(f"loaded {len(frame)} {what}")
    mmlu_fig = build_mmlu_emergence_figure(
        mmlu_df, external_df, humaneval_df, gsm8k_df,
        humaneval_external_df=humaneval_external_df,
        gsm8k_external_df=gsm8k_external_df,
    )
    write_figure(mmlu_fig, "mmlu-emergence")
    write_figure(
        to_mobile_mmlu_stack(mmlu_fig), "mmlu-emergence-mobile", keep_height=True
    )


def _export_hparam_figures(sweep_cache: Path) -> None:
    """Write the hyperparameter-scaling figure (section 3), desktop + mobile.

    Gated on the sweep cache: if ``load_sweep_runs`` returns None, nothing
    is written. Per the caller's notes, ``load_sweep_runs`` prints a hint
    in that case.
    """
    sweep_df = load_sweep_runs(sweep_cache)
    if sweep_df is None:
        return
    n_slices = sweep_df.groupby(["hidden_dim", "token_bucket"]).ngroups
    print(f"loaded {len(sweep_df)} sweep runs across {n_slices} slices")
    hparam = build_hparam_sweep_figure(sweep_df)
    write_figure(hparam, "hparam-scaling")
    write_figure(to_mobile_3x2(hparam), "hparam-scaling-mobile", keep_height=True)


def main() -> None:
    """Parse CLI flags and export every Delphi blog figure.

    The IsoFLOP run cache (``--in``) is always loaded via ``load_runs`` and
    grid-matched with ``subsample_cadamc_to_adamh_grid``. It feeds the core
    scaling figures. The downstream-eval figure is written only when the
    MMLU cache loads. The hyperparameter-scaling figure is written only when
    the sweep cache (``--sweep-in``) loads.
    """
    # __doc__ is None when docstrings are stripped (python -OO).
    description = __doc__.splitlines()[0] if __doc__ else None
    ap = argparse.ArgumentParser(description=description)
    ap.add_argument(
        "--in",
        dest="cache",
        type=Path,
        default=DEFAULT_CACHE,
        help=f"IsoFLOP cache path (default: {DEFAULT_CACHE})",
    )
    ap.add_argument(
        "--sweep-in",
        dest="sweep_cache",
        type=Path,
        default=DEFAULT_SWEEP_CACHE,
        help=f"hparam-sweep cache path (default: {DEFAULT_SWEEP_CACHE})",
    )
    args = ap.parse_args()
    df = load_runs(args.cache)
    print(f"loaded {len(df)} runs across {df['quantizer'].nunique()} optimizers")
    df = subsample_cadamc_to_adamh_grid(df)
    print(f"after grid-matching: {len(df)} runs")
    _export_isoflop_figures(df)
    _export_downstream_figures()
    # Hyperparameter-scaling figure goes last because it's the slowest:
    # the other figures land on disk first so the watcher can pick them
    # up while this one is still building.
    _export_hparam_figures(args.sweep_cache)
# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment