Skip to content

Instantly share code, notes, and snippets.

@Blaizzy
Last active April 14, 2026 12:23
Show Gist options
  • Select an option

  • Save Blaizzy/008df4f0a2f6df88db6f36569f06ea25 to your computer and use it in GitHub Desktop.

Select an option

Save Blaizzy/008df4f0a2f6df88db6f36569f06ea25 to your computer and use it in GitHub Desktop.
"""
Benchmark TriAttention on MATH 500 — matching the paper's evaluation protocol.
Paper settings: max_tokens=32768, temp=0.6, top_p=0.95, budget=512/1024/2048
We use max_tokens=4096 for practical runtime on Apple Silicon.
USAGE
python bench_triattention_math.py \
--model /tmp/gemma-4-26b-a4b-it-5bit \
--calib /tmp/gemma4_26b_5bit_calib.safetensors \
--num-samples 50
# Full 500 (takes ~3 hours at 85 tok/s)
python bench_triattention_math.py \
--model /tmp/gemma-4-26b-a4b-it-5bit \
--calib /tmp/gemma4_26b_5bit_calib.safetensors \
--num-samples 500
"""
import argparse
import importlib
import re
import time
import mlx.core as mx
from datasets import load_dataset
from mlx_vlm import load
from mlx_vlm.models.cache import make_prompt_cache
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.triattention import maybe_apply_triattention
# Import lazily by name so we pick up whichever mlx_vlm.generate is installed.
mod = importlib.import_module("mlx_vlm.generate")
def extract_answer(text: str) -> str:
    """Extract the final answer from model output.

    Tries, in order:
    1. the last complete ``\\boxed{...}`` expression — matched with a
       brace-balanced scan so nested braces such as ``\\boxed{\\frac{1}{2}}``
       are captured whole (a plain ``[^}]*`` regex truncates at the first
       ``}`` and would return ``\\frac{1``),
    2. "the answer is X" style patterns,
    3. the last number in the text,
    4. the last 20 characters of the text as a final fallback.
    """
    boxed = _last_boxed(text)
    if boxed is not None:
        return boxed.strip()
    # Look for "the answer is X" patterns
    patterns = [
        r"[Tt]he (?:final )?answer is[:\s]*\$?([^$\n.]+)",
        r"[Aa]nswer[:\s]*\$?([^$\n.]+)",
        r"= \$?\\?boxed\{?([^}$\n]+)",
    ]
    for pat in patterns:
        m = re.search(pat, text)
        if m:
            return m.group(1).strip()
    # Last resort: last number in the text
    numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
    if numbers:
        return numbers[-1]
    return text.strip()[-20:] if text.strip() else ""


def _last_boxed(text: str):
    """Return the contents of the last complete ``\\boxed{...}``, or None.

    Scans forward counting brace depth so nested braces are handled;
    an unterminated ``\\boxed{`` ends the scan.
    """
    marker = "\\boxed{"
    found = []
    start = text.find(marker)
    while start != -1:
        depth = 1
        i = start + len(marker)
        while i < len(text) and depth > 0:
            if text[i] == "{":
                depth += 1
            elif text[i] == "}":
                depth -= 1
            i += 1
        if depth != 0:
            break  # unterminated \boxed{ — ignore it and stop
        found.append(text[start + len(marker) : i - 1])
        start = text.find(marker, i)
    return found[-1] if found else None
def normalize_answer(ans: str) -> str:
    """Canonicalize an answer string for fuzzy comparison.

    Lowercases, removes whitespace/commas, unwraps common LaTeX text
    wrappers, and strips remaining LaTeX commands, braces and dollar signs.
    """
    cleaned = ans.strip().lower().replace(" ", "").replace(",", "")
    # Unwrap \text{...} / \mathrm{...} / \mathbf{...}
    cleaned = re.sub(r"\\(?:text|mathrm|mathbf)\{([^}]*)\}", r"\1", cleaned)
    # Drop \frac, stray backslashes, and grouping/math characters
    cleaned = cleaned.replace("\\frac", "").replace("\\", "")
    for junk in ("{", "}", "$"):
        cleaned = cleaned.replace(junk, "")
    return cleaned


def answers_match(predicted: str, gold: str) -> bool:
    """Return True when the predicted answer agrees with the gold answer.

    Comparison order: exact match after normalization, then numeric
    equality within 1e-6 (authoritative when both parse as floats),
    then substring containment of the gold answer.
    """
    pred_norm = normalize_answer(predicted)
    gold_norm = normalize_answer(gold)
    if pred_norm == gold_norm:
        return True
    # If both normalize to numbers, the numeric verdict is final
    try:
        return abs(float(pred_norm) - float(gold_norm)) < 1e-6
    except (ValueError, TypeError):
        pass
    # Otherwise accept the gold answer appearing inside the prediction
    return gold_norm in pred_norm
def run_problem(input_ids, model, kv_args, max_tokens):
    """Generate a solution for one problem; return (token_ids, tokens/sec).

    ``kv_args`` may carry the private ``_calib_path`` / ``_budget`` keys,
    which are popped here and, when both present, enable TriAttention
    on the freshly built prompt cache.
    """
    cache = make_prompt_cache(model.language_model)
    # Pull the TriAttention settings (if any) out of kv_args
    calib = kv_args.pop("_calib_path", None)
    kv_budget = kv_args.pop("_budget", None)
    if calib and kv_budget:
        maybe_apply_triattention(cache, model, calib, budget=kv_budget)
    stepper = mod.generate_step(
        input_ids,
        model,
        None,  # pixel_values
        None,  # mask
        max_tokens=max_tokens,
        temperature=0.6,
        top_p=0.95,
        prompt_cache=cache,
    )
    start = time.perf_counter()
    generated = [
        tok.item() if isinstance(tok, mx.array) else tok
        for tok, _ in stepper
    ]
    wall = time.perf_counter() - start
    rate = len(generated) / wall if wall > 0 else 0
    return generated, rate
def main():
    """Run MATH-500 under baseline and TriAttention modes; print per-problem and summary stats."""
    parser = argparse.ArgumentParser(description="TriAttention MATH 500 benchmark")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--calib", type=str, required=True)
    parser.add_argument("--num-samples", type=int, default=50)
    parser.add_argument(
        "--max-tokens", type=int, default=4096,
        help="Max generation tokens per problem (paper uses 32768)"
    )
    parser.add_argument(
        "--budgets", type=int, nargs="+", default=[2048, 1024, 512],
    )
    args = parser.parse_args()
    print(f"Loading model: {args.model}")
    model, processor = load(args.model)
    # Some processors wrap the tokenizer; fall back to the processor itself
    tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
    config = model.config
    print(f"Loaded. Decode speed ~85 tok/s\n")
    # Load MATH 500
    print("Loading MATH 500 dataset...")
    ds = load_dataset("HuggingFaceH4/MATH-500", split="test")
    indices = list(range(min(args.num_samples, len(ds))))
    print(f"Running {len(indices)} problems, max_tokens={args.max_tokens}\n")
    # Baseline first, then one TriAttention mode per KV budget
    modes = [("Baseline", None)]
    for b in args.budgets:
        modes.append((f"TA-{b}", b))
    results = {name: [] for name, _ in modes}
    for i, idx in enumerate(indices):
        sample = ds[idx]
        problem = sample["problem"]
        gold = sample["answer"]
        level = sample.get("level", "?")
        subject = sample.get("subject", "?")
        # Format prompt
        prompt_text = (
            f"Solve the following math problem step by step. "
            f"Put your final answer in \\boxed{{}}.\n\n{problem}"
        )
        prompt = apply_chat_template(processor, config, prompt_text, num_images=0)
        input_ids = mx.array(tokenizer.encode(prompt)).reshape(1, -1)
        for mode_name, budget in modes:
            mx.clear_cache()
            # Private keys are consumed (popped) inside run_problem
            kv_args = {}
            if budget is not None:
                kv_args["_calib_path"] = args.calib
                kv_args["_budget"] = budget
            toks, gen_tps = run_problem(input_ids, model, kv_args, args.max_tokens)
            text = tokenizer.decode(toks)
            predicted = extract_answer(text)
            correct = answers_match(predicted, gold)
            results[mode_name].append({
                "idx": idx,
                "correct": correct,
                "predicted": predicted,
                "gold": gold,
                "gen_tokens": len(toks),
                "gen_tps": gen_tps,
            })
            mark = "Y" if correct else "N"
            print(
                f" [{i+1:3d}/{len(indices)}] {mode_name:<10s} | "
                f"{mark} | {len(toks):4d} tok @ {gen_tps:5.1f} t/s | "
                f"gold={gold[:12]:>12} pred={predicted[:12]:>12} | "
                f"{level} {subject[:15]}",
                flush=True,
            )
        print(flush=True)
    # Summary
    print(f"\n{'=' * 80}")
    print(f"MATH 500 Results ({len(indices)} problems, max_tokens={args.max_tokens})")
    print(f"Model: {args.model}")
    print(f"{'=' * 80}\n")
    print(f"{'Mode':<12} | {'Accuracy':>10} | {'Avg tok/s':>10} | {'Avg gen tokens':>15}")
    print("-" * 60)
    for mode_name, _ in modes:
        r = results[mode_name]
        n = len(r)
        correct = sum(1 for x in r if x["correct"])
        avg_tps = sum(x["gen_tps"] for x in r) / n if n else 0
        avg_tok = sum(x["gen_tokens"] for x in r) / n if n else 0
        print(
            f"{mode_name:<12} | {correct:>4}/{n} ({correct/n*100:4.1f}%) | "
            f"{avg_tps:>10.1f} | {avg_tok:>15.0f}"
        )
    # Per-level breakdown for baseline
    print(f"\n--- Per-level breakdown (Baseline) ---")
    bl = results["Baseline"]
    levels = sorted(set(ds[r["idx"]].get("level", "?") for r in bl))
    for level in levels:
        level_results = [r for r in bl if ds[r["idx"]].get("level", "?") == level]
        correct = sum(1 for r in level_results if r["correct"])
        n = len(level_results)
        if n > 0:
            print(f" {level}: {correct}/{n} ({correct/n*100:.0f}%)")


if __name__ == "__main__":
    main()
"""Perplexity benchmark: Baseline vs TriAttention at various budgets.
Runs up to 50K context on Gemma4-31B base."""
import math
import time
import mlx.core as mx
import mlx.nn as nn
from datasets import load_dataset
from mlx_vlm import load
from mlx_vlm.models.cache import make_prompt_cache
from mlx_vlm.triattention import maybe_apply_triattention
# Model repo to benchmark and the TriAttention calibration file produced for it.
MODEL = "google/gemma-4-31b"
CALIB = "/tmp/gemma4_31b_base_calib.safetensors"
def compute_ppl(model, tokenizer, text, calib_path=None, budget=None, max_tokens=2048):
    """Compute perplexity on a text sequence using teacher-forced decoding.

    The sequence is processed in chunks through one persistent KV cache.
    Each chunk feeds tokens ``[start, end)`` and scores the model's
    predictions against tokens ``[start+1, end+1)``, so every token enters
    the cache exactly once and every next-token prediction is scored
    exactly once.  (The previous implementation fed overlapping
    ``[start, end]`` chunks while stepping by ``chunk_size``, re-feeding
    each boundary token into the cache and shifting all subsequent cache
    positions by one per chunk.)

    Returns ``(perplexity, number_of_scored_tokens)``.
    """
    tokens = tokenizer.encode(text)[:max_tokens]
    input_ids = mx.array(tokens).reshape(1, -1)
    n = len(tokens)
    # Unwrap the language model when `model` is a multimodal wrapper
    lm = model
    if hasattr(model, "language_model"):
        lm_prop = model.language_model
        if lm_prop is not model:
            lm = lm_prop
    cache = make_prompt_cache(lm)
    if calib_path and budget:
        maybe_apply_triattention(cache, model, calib_path, budget=budget)
    total_nll = 0.0
    total_tokens = 0
    chunk_size = 512
    for start in range(0, n - 1, chunk_size):
        end = min(start + chunk_size, n - 1)
        # Feed [start, end); each fed position predicts its successor.
        chunk_ids = input_ids[:, start:end]
        if hasattr(model, "get_input_embeddings"):
            emb_out = model.get_input_embeddings(chunk_ids, None, mask=None)
            inputs_embeds = emb_out.inputs_embeds
            out = lm(chunk_ids, inputs_embeds=inputs_embeds, cache=cache)
        else:
            out = lm(chunk_ids, cache=cache)
        logits = out.logits  # one prediction per fed position — none discarded
        targets = input_ids[:, start + 1 : end + 1]
        log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
        target_log_probs = mx.take_along_axis(
            log_probs, targets[:, :, None], axis=-1
        ).squeeze(-1)
        total_nll += -mx.sum(target_log_probs).item()
        total_tokens += targets.shape[1]
        # NOTE(review): evaluates only the first layer's cache state —
        # presumably enough to materialize the shared lazy graph; confirm.
        mx.eval(cache[0].state)
    ppl = math.exp(total_nll / total_tokens) if total_tokens > 0 else float("inf")
    return ppl, total_tokens
def main():
    """Sweep context lengths and print a perplexity table: baseline vs TriAttention budgets."""
    print(f"Loading model: {MODEL}")
    model, processor = load(MODEL)
    # Some processors wrap the tokenizer; fall back to the processor itself
    tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
    print("Model loaded.\n")
    # Load wikitext — concatenate into one long string
    print("Loading wikitext-2...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    text = "\n\n".join([t for t in ds["text"] if t.strip()])
    total_available = len(tokenizer.encode(text))
    print(f"Total tokens available: {total_available}\n")
    # (name, calibration path, KV budget); budget None means baseline
    configs = [
        ("Baseline", None, None),
        ("TA-2048", CALIB, 2048),
        ("TA-1024", CALIB, 1024),
        ("TA-512", CALIB, 512),
        ("TA-256", CALIB, 256),
    ]
    context_lengths = [1024, 2048, 4096, 8192, 16384, 32768, 50000]
    # Header
    print(f"{'Context':>8}", end="")
    for name, _, _ in configs:
        print(f" | {name:>10}", end="")
    print()
    print("-" * (10 + 13 * len(configs)))
    for max_tokens in context_lengths:
        if max_tokens > total_available:
            print(f"{max_tokens:>8} | (not enough tokens, only {total_available} available)")
            break
        row = f"{max_tokens:>8}"
        baseline_ppl = None
        for name, calib_path, budget in configs:
            mx.clear_cache()
            t0 = time.perf_counter()
            ppl, n_tok = compute_ppl(
                model, tokenizer, text,
                calib_path=calib_path,
                budget=budget,
                max_tokens=max_tokens,
            )
            elapsed = time.perf_counter() - t0
            # First config is the baseline; later columns show their delta vs it
            if baseline_ppl is None:
                baseline_ppl = ppl
                row += f" | {ppl:>10.2f}"
            else:
                delta = ppl - baseline_ppl
                sign = "+" if delta >= 0 else ""
                row += f" | {ppl:>6.2f}({sign}{delta:.2f})"
            mx.clear_cache()
        print(row, flush=True)
    print("\nDone.")


if __name__ == "__main__":
    main()
"""
Benchmark TriAttention vs baseline on MM-NIAH (Multimodal Needle-in-a-Haystack).
INSTALL
pip install -U mlx-vlm
SETUP — Extract images (one-time)
huggingface-cli download OpenGVLab/MM-NIAH mm_niah_val/images.tar.gz --repo-type dataset
mkdir -p /tmp/mm_niah_val_images
tar xzf ~/.cache/huggingface/hub/datasets--OpenGVLab--MM-NIAH/snapshots/*/mm_niah_val/images.tar.gz \\
-C /tmp/mm_niah_val_images
CALIBRATION (one-time per model)
python -m mlx_vlm.triattention_calibrate \\
--model google/gemma-4-26b-a4b-it \\
--output triattention_calib.safetensors
USAGE
# Full benchmark (baseline vs TriAttention budgets)
python bench_triattention.py --model gg-hf-gg/gemma-4-31b-it \\
--calib /tmp/gemma4_triattention_calib.safetensors
# Quick smoke test (2 samples per bucket)
python bench_triattention.py --model gg-hf-gg/gemma-4-31b-it \\
--calib /tmp/gemma4_triattention_calib.safetensors --num-samples 2
# Custom budgets
python bench_triattention.py --model gg-hf-gg/gemma-4-31b-it \\
--calib /tmp/gemma4_triattention_calib.safetensors --budgets 1024 512 256
"""
import argparse
import importlib
import os
import time
import mlx.core as mx
import numpy as np
from datasets import load_dataset
from PIL import Image
from mlx_vlm import load
from mlx_vlm.models.cache import make_prompt_cache
from mlx_vlm.triattention import (
TriAttentionKVCache,
extract_rope_config,
load_calibration,
maybe_apply_triattention,
)
# Import lazily by name so we pick up whichever mlx_vlm.generate is installed.
mod = importlib.import_module("mlx_vlm.generate")
# Default location of the extracted MM-NIAH validation images (see docstring).
DEFAULT_IMAGE_ROOT = "/tmp/mm_niah_val_images/mm_niah_dev/images"
# Context-length buckets: (lower bound inclusive, upper bound exclusive, label)
BUCKETS = [
    (0, 2000, "~1K"),
    (2000, 5000, "~3K"),
    (5000, 10000, "~7K"),
    (10000, 20000, "~15K"),
    (20000, 40000, "~30K"),
    (40000, 100000, "~60K"),
]


def select_samples(ds, tokenizer, num_per_bucket):
    """Pick up to `num_per_bucket` samples from each context-length bucket.

    Token counts are measured on each sample's "context" field.  Within a
    bucket, samples are sorted by token count and picked at an even stride,
    so the selection spans the bucket's length range.

    Returns a list of (dataset_index, token_count, bucket_label) tuples.
    """
    counted = [
        (i, len(tokenizer.encode(sample["context"])))
        for i, sample in enumerate(ds)
    ]
    chosen = []
    for lo, hi, label in BUCKETS:
        in_bucket = sorted(
            ((i, cnt) for i, cnt in counted if lo <= cnt < hi),
            key=lambda pair: pair[1],
        )
        if len(in_bucket) >= num_per_bucket:
            stride = max(1, len(in_bucket) // num_per_bucket)
            picked = in_bucket[::stride][:num_per_bucket]
        else:
            # Fewer samples than requested: take everything in the bucket
            picked = in_bucket
        chosen.extend((i, cnt, label) for i, cnt in picked)
    return chosen
def load_images(sample, image_root):
    """Open every image from the sample's images_list that exists under image_root.

    Missing files are skipped silently — some samples reference images that
    were not extracted.  Returns a list of RGB PIL images.
    """
    candidates = (
        os.path.join(image_root, rel) for rel in sample["images_list"]
    )
    return [
        Image.open(path).convert("RGB")
        for path in candidates
        if os.path.exists(path)
    ]
def measure_cache_bytes(prompt_cache) -> int:
    """Sum the actual `.nbytes` of every layer entry in the prompt cache."""
    return sum(entry.nbytes for entry in prompt_cache)
def run_sample(input_ids, model, pv, mask, kv_args, calib_path=None, budget=None):
    """Prefill + decode one sample.

    Returns (token_ids, prefill_tok_per_s, decode_tok_per_s, kv_cache_bytes,
    peak_memory_GiB).

    Timing protocol: the first next() on the generator performs the prompt
    prefill, so its wall time gives prefill tokens/sec; the remaining
    iterations are pure decode steps and are timed separately.
    """
    n = input_ids.shape[1]
    prompt_cache = make_prompt_cache(model.language_model)
    # Apply TriAttention if requested
    if calib_path is not None and budget is not None:
        maybe_apply_triattention(
            prompt_cache, model, calib_path, budget=budget
        )
    gen = mod.generate_step(
        input_ids,
        model,
        pv,
        mask,
        max_tokens=20,
        temperature=0.0,  # greedy decoding for deterministic answers
        prompt_cache=prompt_cache,
        **kv_args,
    )
    # First token forces the prefill computation (mlx evaluates lazily)
    t0 = time.perf_counter()
    token, _ = next(gen)
    mx.eval(token if isinstance(token, mx.array) else mx.array(token))
    t_prefill = time.perf_counter() - t0
    prefill_tps = n / t_prefill
    # Remaining tokens are individual decode steps
    t0 = time.perf_counter()
    toks = [token.item() if isinstance(token, mx.array) else token]
    count = 0
    for tok, _ in gen:
        mx.eval(tok if isinstance(tok, mx.array) else tok)
        toks.append(tok.item() if isinstance(tok, mx.array) else tok)
        count += 1
    t_decode = time.perf_counter() - t0
    decode_tps = count / t_decode if t_decode > 0 else 0
    kv_bytes = measure_cache_bytes(prompt_cache)
    peak_mem = mx.metal.get_peak_memory() / (1 << 30)  # bytes -> GiB
    return toks, prefill_tps, decode_tps, kv_bytes, peak_mem
def main():
    """Run the MM-NIAH benchmark: baseline vs TriAttention at several KV budgets."""
    parser = argparse.ArgumentParser(
        description="Benchmark TriAttention on MM-NIAH"
    )
    parser.add_argument(
        "--model", type=str, default="gg-hf-gg/gemma-4-31b-it"
    )
    parser.add_argument(
        "--calib", type=str, required=True, help="Path to calibration .safetensors"
    )
    parser.add_argument(
        "--num-samples", type=int, default=1, help="Samples per bucket"
    )
    parser.add_argument(
        "--budgets",
        type=int,
        nargs="+",
        default=[512, 256, 128],
        help="TriAttention KV budgets to test",
    )
    parser.add_argument(
        "--image-root", type=str, default=DEFAULT_IMAGE_ROOT
    )
    args = parser.parse_args()
    if not os.path.exists(args.image_root):
        print(f"Image root not found: {args.image_root}")
        print("Extract images first (see script docstring for instructions).")
        return
    ds = load_dataset("OpenGVLab/MM-NIAH", split="val")
    model, processor = load(args.model)
    indices = select_samples(ds, processor.tokenizer, args.num_samples)
    print(f"Model: {args.model}")
    print(f"Calib: {args.calib}")
    print(f"Samples: {len(indices)} ({args.num_samples} per bucket)")
    print(f"Budgets: {args.budgets}\n")
    # Build mode list: Baseline + each budget
    modes = [("BL", None)]
    for b in args.budgets:
        modes.append((f"TA-{b}", b))
    results = []
    for idx, ctx_len, label in indices:
        s = ds[idx]
        images = load_images(s, args.image_root)
        if not images:
            print(f" Skipping {label} idx={idx}, no images found", flush=True)
            continue
        prompt = f"{s['context']}\n\nQuestion: {s['question']}\nAnswer briefly:"
        # Image placeholders go before the text for the chat template
        content = [{"type": "image"} for _ in images]
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]
        text = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        try:
            inputs = processor(text=text, images=images, return_tensors="np")
        except Exception as e:
            # Best-effort benchmark: log and skip samples the processor rejects
            print(
                f" Skipping {label} idx={idx}, processor error: {e}",
                flush=True,
            )
            continue
        input_ids = mx.array(inputs["input_ids"])
        pv = inputs.get("pixel_values", None)
        if pv is not None and not isinstance(pv, (list, mx.array)):
            pv = mx.array(np.asarray(pv))
        mask = (
            mx.array(inputs["attention_mask"])
            if "attention_mask" in inputs
            else None
        )
        n = input_ids.shape[1]
        gold = s["answer"]
        for mode_name, budget in modes:
            mx.clear_cache()
            # Baseline runs without a calibration file
            calib_path = args.calib if budget is not None else None
            toks, prefill_tps, decode_tps, kv_bytes, peak_mem = run_sample(
                input_ids, model, pv, mask, {}, calib_path, budget
            )
            kv_gb = kv_bytes / (1 << 30)
            answer = (
                processor.tokenizer.decode(toks)
                .strip()
                .replace("\n", " ")[:60]
            )
            # Loose accuracy: gold appears (case-insensitively) in the output
            correct = gold.lower() in answer.lower()
            results.append(
                {
                    "idx": idx,
                    "n": n,
                    "mode": mode_name,
                    "budget": budget,
                    "prefill": prefill_tps,
                    "decode": decode_tps,
                    "kv_gb": kv_gb,
                    "peak_gb": peak_mem,
                    "correct": correct,
                    "gold": gold,
                    "answer": answer,
                    "label": label,
                    "num_images": len(images),
                }
            )
            mark = "Y" if correct else "N"
            print(
                f"{label:>5} {n:>6} ({len(images):>2} img) | {mode_name:<6} | "
                f"pf {prefill_tps:>7.1f} | dec {decode_tps:>5.1f} | "
                f"KV {kv_gb:>6.3f}G | peak {peak_mem:>5.1f}G | {mark} | "
                f"gold={gold[:15]:>15} | {answer[:40]}",
                flush=True,
            )
        print(flush=True)
    # ── Summary table ──
    bl_results = [r for r in results if r["mode"] == "BL"]
    if not bl_results:
        print("No completed baseline results.")
        return
    print(f"\n{'=' * 120}")
    print("SUMMARY")
    print(f"{'=' * 120}")
    header = (
        f"{'Bucket':>5} {'Tok':>6} {'Img':>3} | {'BL pf':>6} {'BL dec':>6} "
        f"{'KV BL':>6} {'Peak':>5}"
    )
    for b in args.budgets:
        header += f" | {'pf':>6} {'dec':>6} {'KV':>6} {'Save':>5} {'Pk':>5} {'Ok':>2}"
    header = header  # label comes from the column header
    print(header)
    print("-" * 120)
    for bl in bl_results:
        row = (
            f"{bl['label']:>5} {bl['n']:>6} {bl['num_images']:>3} | "
            f"{bl['prefill']:>6.0f} {bl['decode']:>6.1f} "
            f"{bl['kv_gb']:>5.3f}G {bl['peak_gb']:>5.1f}G"
        )
        for b in args.budgets:
            # Find this sample's TriAttention result for budget b, if it ran
            ta = next(
                (
                    r
                    for r in results
                    if r["idx"] == bl["idx"]
                    and r["mode"] == f"TA-{b}"
                ),
                None,
            )
            if ta:
                save = (
                    (1 - ta["kv_gb"] / bl["kv_gb"]) * 100
                    if bl["kv_gb"] > 0
                    else 0
                )
                bm = "Y" if bl["correct"] else "N"
                tm = "Y" if ta["correct"] else "N"
                row += (
                    f" | {ta['prefill']:>6.0f} {ta['decode']:>6.1f} "
                    f"{ta['kv_gb']:>5.3f}G {save:>4.0f}% "
                    f"{ta['peak_gb']:>5.1f}G {tm:>2}"
                )
            else:
                row += " | - - - - - -"
        print(row)
    # Aggregate stats
    print(f"\n{'─' * 80}")
    for mode_name, budget in modes:
        mode_results = [r for r in results if r["mode"] == mode_name]
        if not mode_results:
            continue
        n_total = len(mode_results)
        n_correct = sum(1 for r in mode_results if r["correct"])
        avg_prefill = sum(r["prefill"] for r in mode_results) / n_total
        avg_decode = sum(r["decode"] for r in mode_results) / n_total
        avg_kv = sum(r["kv_gb"] for r in mode_results) / n_total
        avg_peak = sum(r["peak_gb"] for r in mode_results) / n_total
        if budget is None:
            kv_save = " - "
        else:
            # KV saved relative to the average baseline cache size
            bl_avg_kv = sum(r["kv_gb"] for r in bl_results) / len(bl_results)
            kv_save = f"{(1 - avg_kv / bl_avg_kv) * 100:4.0f}%"
        print(
            f" {mode_name:<8} | acc={n_correct}/{n_total} ({n_correct/n_total*100:4.0f}%) | "
            f"prefill={avg_prefill:7.1f} t/s | decode={avg_decode:5.1f} t/s | "
            f"KV={avg_kv:.3f}G | peak={avg_peak:.1f}G | KV saved: {kv_save}"
        )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment