Last active
April 5, 2026 19:33
-
-
Save Blaizzy/e703abfaa41b878b9bd297b9953e9330 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Benchmark TurboQuant vs baseline on LongBench-v2. | |
| Usage: | |
| python scripts/bench_longbench_v2.py --model google/gemma-4-e4b-it --num-samples 10 --max-tokens-ctx 260000 | |
| python scripts/bench_longbench_v2.py --model google/gemma-4-26b-a4b-it --num-samples 5 --max-tokens-ctx 128000 --kv-bits 4 | |
| """ | |
import argparse
import importlib
import re
import time

import mlx.core as mx
from datasets import load_dataset
from mlx_lm.models.cache import make_prompt_cache
from mlx_vlm import load
# mlx_vlm.generate module; provides the generate_step() generator used by
# both the warmup loop and run_sample() below.
mod = importlib.import_module("mlx_vlm.generate")
def build_prompt(sample):
    """Format one LongBench-v2 sample as a multiple-choice prompt.

    Uses the sample's context, question, and whichever of choice_A..choice_D
    are present (falsy choices are skipped). The model is instructed to reply
    with only the letter.
    """
    options = []
    for key in ("choice_A", "choice_B", "choice_C", "choice_D"):
        value = sample.get(key)
        if value:
            # key[-1] is the option letter: "choice_A" -> "A".
            options.append(f"{key[-1]}. {value}")
    header = (
        "Read the following context and answer the multiple choice question. "
        "Reply with ONLY the letter (A, B, C, or D).\n\n"
    )
    body = (
        f"Context:\n{sample['context']}\n\n"
        f"Question: {sample['question']}\n\n"
    )
    return header + body + "\n".join(options) + "\n\nAnswer:"
def extract_letter(text):
    """Extract the predicted multiple-choice letter (A-D) from model output.

    Prefers a *standalone* letter so that letters embedded in words are not
    picked up — e.g. "The answer is B" must yield "B", not the "A" inside
    "ANSWER" (the original first-character scan got this wrong). Falls back
    to the first A/B/C/D character anywhere in the text for backward
    compatibility with terse outputs. Returns "" when nothing matches.
    """
    up = text.upper()
    # A-D bounded by non-word characters (or string edges).
    match = re.search(r"\b([ABCD])\b", up)
    if match:
        return match.group(1)
    # Fallback: original behavior — first A-D character anywhere.
    for c in up:
        if c in "ABCD":
            return c
    return ""
def select_samples(ds, processor, num_samples, max_tokens_ctx):
    """Choose up to num_samples dataset indices whose context fits the budget.

    Tokenizes each sample's context, keeps those at or under max_tokens_ctx,
    and strides through the length-sorted candidates so the picks span short
    to long contexts. Returned indices are ordered by context token count.
    """
    tokenizer = processor.tokenizer
    eligible = []
    for idx, sample in enumerate(ds):
        context = sample.get("context", "") or ""
        n_tokens = len(tokenizer.encode(context))
        if n_tokens <= max_tokens_ctx:
            eligible.append((idx, n_tokens))
    eligible.sort(key=lambda pair: pair[1])
    # Fewer candidates than requested: take them all.
    if len(eligible) <= num_samples:
        return [idx for idx, _ in eligible]
    # Evenly stride across the sorted candidates to cover the length range.
    stride = max(1, len(eligible) // num_samples)
    picked = eligible[::stride][:num_samples]
    picked.sort(key=lambda pair: pair[1])
    return [idx for idx, _ in picked]
def measure_cache_bytes(prompt_cache) -> int:
    """Sum actual .nbytes across all layers of the prompt cache."""
    return sum(entry.nbytes for entry in prompt_cache)
def run_sample(input_ids, model, kv_args):
    """Generate up to 10 tokens for one prompt and collect timing/memory stats.

    Returns (tokens, prefill_tps, decode_tps, kv_bytes) where prefill_tps is
    prompt tokens per second to first token, decode_tps is generated tokens
    per second after the first, and kv_bytes is the prompt cache's actual
    byte footprint after generation.
    """
    n = input_ids.shape[1]
    prompt_cache = make_prompt_cache(model.language_model)
    gen = mod.generate_step(
        input_ids, model, None, None, max_tokens=10, temperature=0.0,
        prompt_cache=prompt_cache, **kv_args
    )
    # Prefill: time until the first token materializes.
    t0 = time.perf_counter()
    token, _ = next(gen)
    mx.eval(token if isinstance(token, mx.array) else mx.array(token))
    t_prefill = time.perf_counter() - t0
    prefill_tps = n / t_prefill
    # Decode: time the remaining tokens.
    t0 = time.perf_counter()
    toks = [token.item() if isinstance(token, mx.array) else token]
    count = 0
    for tok, _ in gen:
        # BUG FIX: the original passed `tok` unchanged in both branches; wrap
        # plain ints in mx.array before eval, matching the prefill path above.
        mx.eval(tok if isinstance(tok, mx.array) else mx.array(tok))
        toks.append(tok.item() if isinstance(tok, mx.array) else tok)
        count += 1
    t_decode = time.perf_counter() - t0
    # Guard against a zero-duration decode (e.g. generation stopped at 1 token).
    decode_tps = count / t_decode if t_decode > 0 else 0
    # Measure actual KV cache bytes after generation
    kv_bytes = measure_cache_bytes(prompt_cache)
    return toks, prefill_tps, decode_tps, kv_bytes
def main():
    """Benchmark baseline (BL) vs TurboQuant (TBQ) KV-cache modes on LongBench-v2."""
    parser = argparse.ArgumentParser(description="Benchmark TurboQuant on LongBench-v2")
    parser.add_argument("--model", type=str, default="google/gemma-4-e4b-it")
    parser.add_argument("--num-samples", type=int, default=10)
    parser.add_argument("--max-tokens-ctx", type=int, default=260000)
    # float, not int — presumably fractional bit-widths (e.g. the 3.5 default)
    # are supported by the turboquant scheme; TODO confirm.
    parser.add_argument("--kv-bits", type=float, default=3.5)
    args = parser.parse_args()
    ds = load_dataset("zai-org/LongBench-v2", split="train")
    model, processor = load(args.model)
    # Indices of samples whose context fits the token budget, spread by length.
    indices = select_samples(ds, processor, args.num_samples, args.max_tokens_ctx)
    print(f"Model: {args.model}")
    print(f"Samples: {len(indices)}, max context: {args.max_tokens_ctx} tokens")
    print(f"TBQ: {args.kv_bits}-bit\n")
    # Warm up Metal shaders to avoid penalizing the first mode
    dummy_ids = mx.array([[1, 2, 3]])
    for _, kv_args in [("BL", {}), ("TBQ", {"kv_bits": args.kv_bits, "kv_quant_scheme": "turboquant"})]:
        gen = mod.generate_step(dummy_ids, model, None, None, max_tokens=1, temperature=0.0, **kv_args)
        tok, _ = next(gen)
        mx.eval(tok if isinstance(tok, mx.array) else mx.array(tok))
        del gen
        mx.clear_cache()
    # Each sample is run once per mode; results are appended interleaved BL/TBQ.
    modes = [
        ("BL", {}),
        ("TBQ", {"kv_bits": args.kv_bits, "kv_quant_scheme": "turboquant"}),
    ]
    results = []
    for sample_idx in indices:
        s = ds[sample_idx]
        prompt = build_prompt(s)
        gold = s["answer"].strip().upper()
        messages = [{"role": "user", "content": prompt}]
        # Tokenize once; the same input_ids are reused for both modes.
        text = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        input_ids = mx.array([processor.tokenizer.encode(text)])
        n = input_ids.shape[1]
        for mode_name, kv_args in modes:
            # Drop MLX's buffer cache so modes don't share allocator state.
            mx.clear_cache()
            toks, prefill_tps, decode_tps, kv_bytes = run_sample(
                input_ids, model, kv_args
            )
            kv_gb = kv_bytes / (1 << 30)
            # Only the first 20 chars matter — we just need the letter.
            answer_text = processor.tokenizer.decode(toks).strip()[:20]
            pred = extract_letter(answer_text)
            correct = pred == gold
            results.append(
                {
                    "idx": sample_idx,
                    "n": n,
                    "mode": mode_name,
                    "prefill": prefill_tps,
                    "decode": decode_tps,
                    "kv": kv_gb,
                    "pred": pred,
                    "gold": gold,
                    "correct": correct,
                }
            )
            mark = "Y" if correct else "N"
            print(
                f"{n:>7} | {mode_name:<3} | pf {prefill_tps:>7.1f} | "
                f"dec {decode_tps:>5.1f} | KV {kv_gb:>5.2f}G | "
                f"{pred}/{gold} {mark} | {answer_text}",
                flush=True,
            )
        print(flush=True)
    # Summary
    # bl_results and tbq_results are in the same per-sample order (both built
    # from the interleaved results list), so zip() below pairs per sample.
    bl_results = [r for r in results if r["mode"] == "BL"]
    tbq_results = [r for r in results if r["mode"] == "TBQ"]
    bl_correct = sum(1 for r in bl_results if r["correct"])
    tbq_correct = sum(1 for r in tbq_results if r["correct"])
    total = len(indices)
    agree = sum(
        1
        for b, t in zip(bl_results, tbq_results)
        if b["pred"] == t["pred"]
    )
    tbq_wins = sum(
        1 for b, t in zip(bl_results, tbq_results) if t["correct"] and not b["correct"]
    )
    bl_wins = sum(
        1 for b, t in zip(bl_results, tbq_results) if b["correct"] and not t["correct"]
    )
    bl_kv_avg = sum(r["kv"] for r in bl_results) / total if total else 0
    tbq_kv_avg = sum(r["kv"] for r in tbq_results) / total if total else 0
    kv_save = (1 - tbq_kv_avg / bl_kv_avg) * 100 if bl_kv_avg > 0 else 0
    print(f"\nAccuracy: BL={bl_correct}/{total} ({bl_correct/total*100:.0f}%) "
          f"TBQ={tbq_correct}/{total} ({tbq_correct/total*100:.0f}%)")
    print(f"Agreement: {agree}/{total} ({agree/total*100:.0f}%)")
    print(f"TBQ wins: {tbq_wins} | BL wins: {bl_wins} | Net: TBQ +{tbq_wins - bl_wins}")
    print(f"KV cache: BL={bl_kv_avg:.2f}G TBQ={tbq_kv_avg:.2f}G ({kv_save:.0f}% savings)")
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment