Skip to content

Instantly share code, notes, and snippets.

@Blaizzy
Last active April 5, 2026 19:33
Show Gist options
  • Select an option

  • Save Blaizzy/e703abfaa41b878b9bd297b9953e9330 to your computer and use it in GitHub Desktop.

Select an option

Save Blaizzy/e703abfaa41b878b9bd297b9953e9330 to your computer and use it in GitHub Desktop.
"""Benchmark TurboQuant vs baseline on LongBench-v2.
Usage:
python scripts/bench_longbench_v2.py --model google/gemma-4-e4b-it --num-samples 10 --max-tokens-ctx 260000
python scripts/bench_longbench_v2.py --model google/gemma-4-26b-a4b-it --num-samples 5 --max-tokens-ctx 128000 --kv-bits 4
"""
import argparse
import importlib
import time
import mlx.core as mx
from datasets import load_dataset
from mlx_lm.models.cache import make_prompt_cache
from mlx_vlm import load
mod = importlib.import_module("mlx_vlm.generate")
def build_prompt(sample):
    """Format one LongBench-v2 sample as a multiple-choice prompt string.

    Only choices present (truthy) in the sample are included; each is
    rendered as "<letter>. <text>".
    """
    option_lines = [
        f"{letter}. {sample[f'choice_{letter}']}"
        for letter in "ABCD"
        if sample.get(f"choice_{letter}")
    ]
    pieces = [
        "Read the following context and answer the multiple choice question. "
        "Reply with ONLY the letter (A, B, C, or D).\n\n",
        f"Context:\n{sample['context']}\n\n",
        f"Question: {sample['question']}\n\n",
        "\n".join(option_lines) + "\n\n",
        "Answer:",
    ]
    return "".join(pieces)
def extract_letter(text):
    """Return the first A/B/C/D character found in *text* (case-insensitive).

    Returns "" when no choice letter appears.
    """
    return next((ch for ch in text.upper() if ch in "ABCD"), "")
def select_samples(ds, processor, num_samples, max_tokens_ctx):
    """Pick up to *num_samples* dataset indices whose context fits the budget.

    Contexts are token-counted with the processor's tokenizer; candidates
    that fit are spread evenly across the length distribution (shortest to
    longest) so the benchmark covers a range of context sizes. Returns a
    list of dataset indices sorted by ascending token count.
    """
    encode = processor.tokenizer.encode
    sized = []
    for idx, sample in enumerate(ds):
        context = sample.get("context", "") or ""
        n_tokens = len(encode(context))
        if n_tokens <= max_tokens_ctx:
            sized.append((idx, n_tokens))
    sized.sort(key=lambda pair: pair[1])
    if len(sized) <= num_samples:
        return [idx for idx, _ in sized]
    # Stride through the length-sorted candidates to spread the selection.
    stride = max(1, len(sized) // num_samples)
    picked = sized[::stride][:num_samples]
    picked.sort(key=lambda pair: pair[1])
    return [idx for idx, _ in picked]
def measure_cache_bytes(prompt_cache) -> int:
    """Sum actual .nbytes across all layers of the prompt cache."""
    return sum(layer.nbytes for layer in prompt_cache)
def run_sample(input_ids, model, kv_args):
    """Generate up to 10 tokens for one prompt and collect timing stats.

    Args:
        input_ids: (1, seq_len) array of prompt token ids.
        model: loaded mlx-vlm model (its .language_model seeds the cache).
        kv_args: extra kwargs for generate_step (e.g. kv_bits / scheme).

    Returns:
        (tokens, prefill_tps, decode_tps, kv_bytes) where prefill_tps is
        prompt tokens / time-to-first-token, decode_tps is generated
        tokens / decode time, and kv_bytes is the cache size after
        generation.
    """
    n = input_ids.shape[1]
    prompt_cache = make_prompt_cache(model.language_model)
    gen = mod.generate_step(
        input_ids, model, None, None, max_tokens=10, temperature=0.0,
        prompt_cache=prompt_cache, **kv_args
    )
    # Prefill: time until the first token is materialized.
    t0 = time.perf_counter()
    token, _ = next(gen)
    mx.eval(token if isinstance(token, mx.array) else mx.array(token))
    t_prefill = time.perf_counter() - t0
    prefill_tps = n / t_prefill
    # Decode: time the remaining tokens.
    t0 = time.perf_counter()
    toks = [token.item() if isinstance(token, mx.array) else token]
    count = 0
    for tok, _ in gen:
        # Bug fix: the original passed the raw token in both branches of the
        # conditional; wrap non-array tokens like the prefill path above does.
        mx.eval(tok if isinstance(tok, mx.array) else mx.array(tok))
        toks.append(tok.item() if isinstance(tok, mx.array) else tok)
        count += 1
    t_decode = time.perf_counter() - t0
    decode_tps = count / t_decode if t_decode > 0 else 0
    # Measure actual KV cache bytes after generation
    kv_bytes = measure_cache_bytes(prompt_cache)
    return toks, prefill_tps, decode_tps, kv_bytes
def main():
    """Benchmark baseline vs TurboQuant KV-cache quantization on LongBench-v2.

    For each selected sample, runs both modes back to back and prints
    per-sample prefill/decode throughput, KV cache size, and the predicted
    vs gold answer letter, then a summary of accuracy, agreement, and
    average KV savings.
    """
    parser = argparse.ArgumentParser(description="Benchmark TurboQuant on LongBench-v2")
    parser.add_argument("--model", type=str, default="google/gemma-4-e4b-it")
    parser.add_argument("--num-samples", type=int, default=10)
    parser.add_argument("--max-tokens-ctx", type=int, default=260000)
    parser.add_argument("--kv-bits", type=float, default=3.5)
    args = parser.parse_args()
    ds = load_dataset("zai-org/LongBench-v2", split="train")
    model, processor = load(args.model)
    indices = select_samples(ds, processor, args.num_samples, args.max_tokens_ctx)
    print(f"Model: {args.model}")
    print(f"Samples: {len(indices)}, max context: {args.max_tokens_ctx} tokens")
    print(f"TBQ: {args.kv_bits}-bit\n")
    # Warm up Metal shaders to avoid penalizing the first mode
    dummy_ids = mx.array([[1, 2, 3]])
    for _, kv_args in [("BL", {}), ("TBQ", {"kv_bits": args.kv_bits, "kv_quant_scheme": "turboquant"})]:
        gen = mod.generate_step(dummy_ids, model, None, None, max_tokens=1, temperature=0.0, **kv_args)
        tok, _ = next(gen)
        mx.eval(tok if isinstance(tok, mx.array) else mx.array(tok))
        del gen
    # NOTE(review): source indentation was lost in this paste — clear_cache is
    # placed after the warm-up loop (clearing once before measurement); confirm
    # against the original whether it belonged inside the loop body.
    mx.clear_cache()
    # (label, extra kwargs for generate_step) — BL is the unquantized baseline.
    modes = [
        ("BL", {}),
        ("TBQ", {"kv_bits": args.kv_bits, "kv_quant_scheme": "turboquant"}),
    ]
    results = []
    for sample_idx in indices:
        s = ds[sample_idx]
        prompt = build_prompt(s)
        gold = s["answer"].strip().upper()
        messages = [{"role": "user", "content": prompt}]
        text = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        input_ids = mx.array([processor.tokenizer.encode(text)])
        n = input_ids.shape[1]
        for mode_name, kv_args in modes:
            # Drop cached buffers so each mode's memory numbers are comparable.
            mx.clear_cache()
            toks, prefill_tps, decode_tps, kv_bytes = run_sample(
                input_ids, model, kv_args
            )
            kv_gb = kv_bytes / (1 << 30)
            # Only the first ~20 chars matter: we just need the answer letter.
            answer_text = processor.tokenizer.decode(toks).strip()[:20]
            pred = extract_letter(answer_text)
            correct = pred == gold
            results.append(
                {
                    "idx": sample_idx,
                    "n": n,
                    "mode": mode_name,
                    "prefill": prefill_tps,
                    "decode": decode_tps,
                    "kv": kv_gb,
                    "pred": pred,
                    "gold": gold,
                    "correct": correct,
                }
            )
            mark = "Y" if correct else "N"
            print(
                f"{n:>7} | {mode_name:<3} | pf {prefill_tps:>7.1f} | "
                f"dec {decode_tps:>5.1f} | KV {kv_gb:>5.2f}G | "
                f"{pred}/{gold} {mark} | {answer_text}",
                flush=True,
            )
        # Blank line between samples.
        print(flush=True)
    # Summary
    bl_results = [r for r in results if r["mode"] == "BL"]
    tbq_results = [r for r in results if r["mode"] == "TBQ"]
    bl_correct = sum(1 for r in bl_results if r["correct"])
    tbq_correct = sum(1 for r in tbq_results if r["correct"])
    total = len(indices)
    # bl_results/tbq_results are appended in the same sample order, so
    # zipping pairs up the two modes' runs of the same sample.
    agree = sum(
        1
        for b, t in zip(bl_results, tbq_results)
        if b["pred"] == t["pred"]
    )
    tbq_wins = sum(
        1 for b, t in zip(bl_results, tbq_results) if t["correct"] and not b["correct"]
    )
    bl_wins = sum(
        1 for b, t in zip(bl_results, tbq_results) if b["correct"] and not t["correct"]
    )
    # Each mode runs exactly once per selected sample, so dividing by
    # total == len(indices) gives the per-sample mean KV size.
    bl_kv_avg = sum(r["kv"] for r in bl_results) / total if total else 0
    tbq_kv_avg = sum(r["kv"] for r in tbq_results) / total if total else 0
    kv_save = (1 - tbq_kv_avg / bl_kv_avg) * 100 if bl_kv_avg > 0 else 0
    print(f"\nAccuracy: BL={bl_correct}/{total} ({bl_correct/total*100:.0f}%) "
          f"TBQ={tbq_correct}/{total} ({tbq_correct/total*100:.0f}%)")
    print(f"Agreement: {agree}/{total} ({agree/total*100:.0f}%)")
    print(f"TBQ wins: {tbq_wins} | BL wins: {bl_wins} | Net: TBQ +{tbq_wins - bl_wins}")
    print(f"KV cache: BL={bl_kv_avg:.2f}G TBQ={tbq_kv_avg:.2f}G ({kv_save:.0f}% savings)")
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment