Last active
April 14, 2026 12:23
-
-
Save Blaizzy/008df4f0a2f6df88db6f36569f06ea25 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Benchmark TriAttention on MATH 500 — matching the paper's evaluation protocol. | |
| Paper settings: max_tokens=32768, temp=0.6, top_p=0.95, budget=512/1024/2048 | |
| We use max_tokens=4096 for practical runtime on Apple Silicon. | |
| USAGE | |
| python bench_triattention_math.py \ | |
| --model /tmp/gemma-4-26b-a4b-it-5bit \ | |
| --calib /tmp/gemma4_26b_5bit_calib.safetensors \ | |
| --num-samples 50 | |
| # Full 500 (takes ~3 hours at 85 tok/s) | |
| python bench_triattention_math.py \ | |
| --model /tmp/gemma-4-26b-a4b-it-5bit \ | |
| --calib /tmp/gemma4_26b_5bit_calib.safetensors \ | |
| --num-samples 500 | |
| """ | |
| import argparse | |
| import importlib | |
| import re | |
| import time | |
| import mlx.core as mx | |
| from datasets import load_dataset | |
| from mlx_vlm import load | |
| from mlx_vlm.models.cache import make_prompt_cache | |
| from mlx_vlm.prompt_utils import apply_chat_template | |
| from mlx_vlm.triattention import maybe_apply_triattention | |
| mod = importlib.import_module("mlx_vlm.generate") | |
def extract_answer(text: str) -> str:
    """Extract the final answer from model output.

    Tries, in order:
      1. the last ``\\boxed{...}`` — matched with balanced braces, so nested
         LaTeX such as ``\\boxed{\\frac{1}{2}}`` is captured whole (a plain
         ``[^}]*`` regex would stop at the first ``}``);
      2. "the answer is X" style patterns;
      3. the last number in the text;
      4. the last 20 characters, as a last resort (empty string if blank).
    """
    boxed = _find_boxed(text)
    if boxed:
        return boxed[-1].strip()
    # Look for "the answer is X" patterns
    patterns = [
        r"[Tt]he (?:final )?answer is[:\s]*\$?([^$\n.]+)",
        r"[Aa]nswer[:\s]*\$?([^$\n.]+)",
        r"= \$?\\?boxed\{?([^}$\n]+)",
    ]
    for pat in patterns:
        m = re.search(pat, text)
        if m:
            return m.group(1).strip()
    # Last resort: last number in the text
    numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
    if numbers:
        return numbers[-1]
    return text.strip()[-20:] if text.strip() else ""


def _find_boxed(text: str) -> list:
    """Return the contents of every \\boxed{...} in *text*, balancing braces."""
    results = []
    for m in re.finditer(r"\\boxed\{", text):
        depth = 1
        i = m.end()
        start = i
        # Scan forward until the opening brace is balanced.
        while i < len(text) and depth:
            ch = text[i]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
            i += 1
        if depth == 0:
            results.append(text[start : i - 1])
    return results
def normalize_answer(ans: str) -> str:
    """Canonicalize an answer string so equivalent spellings compare equal."""
    out = ans.strip().lower()
    for ch in (" ", ","):
        out = out.replace(ch, "")
    # Unwrap \text{...} / \mathrm{...} / \mathbf{...} to their contents.
    out = re.sub(r"\\(?:text|mathrm|mathbf)\{([^}]*)\}", r"\1", out)
    out = out.replace("\\frac", "")
    for ch in ("\\", "{", "}", "$"):
        out = out.replace(ch, "")
    return out


def answers_match(predicted: str, gold: str) -> bool:
    """Return True when *predicted* agrees with *gold* after normalization.

    Comparison order: exact normalized match, then numeric closeness
    (this branch is decisive when both sides parse as floats), and finally
    a substring containment fallback.
    """
    pred_norm = normalize_answer(predicted)
    gold_norm = normalize_answer(gold)
    if pred_norm == gold_norm:
        return True
    try:
        # If both parse as numbers, the numeric verdict is final.
        return abs(float(pred_norm) - float(gold_norm)) < 1e-6
    except (ValueError, TypeError):
        pass
    # Check if gold is contained in prediction
    return gold_norm in pred_norm
def run_problem(input_ids, model, kv_args, max_tokens):
    """Generate a solution for one problem.

    Returns (token_id_list, generation tokens/sec). If *kv_args* carries
    "_calib_path" and "_budget" entries they are popped and used to enable
    TriAttention on a fresh prompt cache.
    """
    prompt_cache = make_prompt_cache(model.language_model)
    # Pull TriAttention settings (if any) out of kv_args.
    calib_path = kv_args.pop("_calib_path", None)
    budget = kv_args.pop("_budget", None)
    if calib_path and budget:
        maybe_apply_triattention(prompt_cache, model, calib_path, budget=budget)
    stream = mod.generate_step(
        input_ids,
        model,
        None,  # pixel_values
        None,  # mask
        max_tokens=max_tokens,
        temperature=0.6,
        top_p=0.95,
        prompt_cache=prompt_cache,
    )
    started = time.perf_counter()
    out_tokens = []
    for tok, _ in stream:
        out_tokens.append(tok.item() if isinstance(tok, mx.array) else tok)
    duration = time.perf_counter() - started
    tokens_per_sec = len(out_tokens) / duration if duration > 0 else 0
    return out_tokens, tokens_per_sec
def main():
    """CLI entry point: run MATH-500 under baseline and each TriAttention budget.

    For every selected problem the script generates once per mode (Baseline
    plus one mode per --budgets value), scores the extracted answer against
    the gold answer, and prints a per-mode summary and a per-level breakdown.
    """
    parser = argparse.ArgumentParser(description="TriAttention MATH 500 benchmark")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--calib", type=str, required=True)
    parser.add_argument("--num-samples", type=int, default=50)
    parser.add_argument(
        "--max-tokens", type=int, default=4096,
        help="Max generation tokens per problem (paper uses 32768)"
    )
    parser.add_argument(
        "--budgets", type=int, nargs="+", default=[2048, 1024, 512],
    )
    args = parser.parse_args()
    print(f"Loading model: {args.model}")
    model, processor = load(args.model)
    # Some processors wrap the tokenizer; fall back to the processor itself.
    tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
    config = model.config
    print(f"Loaded. Decode speed ~85 tok/s\n")
    # Load MATH 500
    print("Loading MATH 500 dataset...")
    ds = load_dataset("HuggingFaceH4/MATH-500", split="test")
    # Take the first N problems in dataset order (no shuffling).
    indices = list(range(min(args.num_samples, len(ds))))
    print(f"Running {len(indices)} problems, max_tokens={args.max_tokens}\n")
    # Modes: ("Baseline", None) plus one ("TA-<b>", b) per requested budget.
    modes = [("Baseline", None)]
    for b in args.budgets:
        modes.append((f"TA-{b}", b))
    results = {name: [] for name, _ in modes}
    for i, idx in enumerate(indices):
        sample = ds[idx]
        problem = sample["problem"]
        gold = sample["answer"]
        level = sample.get("level", "?")
        subject = sample.get("subject", "?")
        # Format prompt
        prompt_text = (
            f"Solve the following math problem step by step. "
            f"Put your final answer in \\boxed{{}}.\n\n{problem}"
        )
        prompt = apply_chat_template(processor, config, prompt_text, num_images=0)
        input_ids = mx.array(tokenizer.encode(prompt)).reshape(1, -1)
        # Run the same problem once per mode (fresh cache each time).
        for mode_name, budget in modes:
            mx.clear_cache()
            kv_args = {}
            if budget is not None:
                # run_problem pops these sentinel keys to enable TriAttention.
                kv_args["_calib_path"] = args.calib
                kv_args["_budget"] = budget
            toks, gen_tps = run_problem(input_ids, model, kv_args, args.max_tokens)
            text = tokenizer.decode(toks)
            predicted = extract_answer(text)
            correct = answers_match(predicted, gold)
            results[mode_name].append({
                "idx": idx,
                "correct": correct,
                "predicted": predicted,
                "gold": gold,
                "gen_tokens": len(toks),
                "gen_tps": gen_tps,
            })
            mark = "Y" if correct else "N"
            print(
                f" [{i+1:3d}/{len(indices)}] {mode_name:<10s} | "
                f"{mark} | {len(toks):4d} tok @ {gen_tps:5.1f} t/s | "
                f"gold={gold[:12]:>12} pred={predicted[:12]:>12} | "
                f"{level} {subject[:15]}",
                flush=True,
            )
        # Blank line between problems.
        print(flush=True)
    # Summary
    print(f"\n{'=' * 80}")
    print(f"MATH 500 Results ({len(indices)} problems, max_tokens={args.max_tokens})")
    print(f"Model: {args.model}")
    print(f"{'=' * 80}\n")
    print(f"{'Mode':<12} | {'Accuracy':>10} | {'Avg tok/s':>10} | {'Avg gen tokens':>15}")
    print("-" * 60)
    for mode_name, _ in modes:
        r = results[mode_name]
        n = len(r)
        correct = sum(1 for x in r if x["correct"])
        avg_tps = sum(x["gen_tps"] for x in r) / n if n else 0
        avg_tok = sum(x["gen_tokens"] for x in r) / n if n else 0
        print(
            f"{mode_name:<12} | {correct:>4}/{n} ({correct/n*100:4.1f}%) | "
            f"{avg_tps:>10.1f} | {avg_tok:>15.0f}"
        )
    # Per-level breakdown for baseline
    print(f"\n--- Per-level breakdown (Baseline) ---")
    bl = results["Baseline"]
    # Group baseline results by the dataset's difficulty level field.
    levels = sorted(set(ds[r["idx"]].get("level", "?") for r in bl))
    for level in levels:
        level_results = [r for r in bl if ds[r["idx"]].get("level", "?") == level]
        correct = sum(1 for r in level_results if r["correct"])
        n = len(level_results)
        if n > 0:
            print(f" {level}: {correct}/{n} ({correct/n*100:.0f}%)")


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Perplexity benchmark: Baseline vs TriAttention at various budgets. | |
| Runs up to 50K context on Gemma4-31B base.""" | |
| import math | |
| import time | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
| from datasets import load_dataset | |
| from mlx_vlm import load | |
| from mlx_vlm.models.cache import make_prompt_cache | |
| from mlx_vlm.triattention import maybe_apply_triattention | |
| MODEL = "google/gemma-4-31b" | |
| CALIB = "/tmp/gemma4_31b_base_calib.safetensors" | |
def compute_ppl(model, tokenizer, text, calib_path=None, budget=None, max_tokens=2048):
    """Compute perplexity on a text sequence using teacher-forced decoding.

    The text is truncated to *max_tokens* tokens and fed through the language
    model in chunks of 512, accumulating per-token negative log-likelihood.
    If *calib_path* and *budget* are given, TriAttention is applied to the
    prompt cache first. Returns (perplexity, token_count_scored).
    """
    tokens = tokenizer.encode(text)[:max_tokens]
    input_ids = mx.array(tokens).reshape(1, -1)
    n = len(tokens)
    # Unwrap the language model from a multimodal wrapper, if present.
    lm = model
    if hasattr(model, "language_model"):
        lm_prop = model.language_model
        if lm_prop is not model:
            lm = lm_prop
    cache = make_prompt_cache(lm)
    if calib_path and budget:
        maybe_apply_triattention(cache, model, calib_path, budget=budget)
    total_nll = 0.0
    total_tokens = 0
    chunk_size = 512
    # NOTE(review): each chunk is the inclusive slice [start, end] while the
    # next chunk begins at `end`, so the boundary token looks like it is fed
    # to the model (and appended to the KV cache) twice per chunk boundary —
    # confirm whether the cache handles this or chunks should be half-open.
    for start in range(0, n - 1, chunk_size):
        end = min(start + chunk_size, n - 1)
        chunk_ids = input_ids[:, start : end + 1]
        # Multimodal models need embeddings computed via the wrapper.
        if hasattr(model, "get_input_embeddings"):
            emb_out = model.get_input_embeddings(chunk_ids, None, mask=None)
            inputs_embeds = emb_out.inputs_embeds
            out = lm(chunk_ids, inputs_embeds=inputs_embeds, cache=cache)
        else:
            out = lm(chunk_ids, cache=cache)
        # Predict token t+1 from position t: drop the last logit, shift targets.
        logits = out.logits[:, :-1, :]
        targets = chunk_ids[:, 1:]
        # Log-softmax via logsumexp for numerical stability.
        log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
        target_log_probs = mx.take_along_axis(
            log_probs, targets[:, :, None], axis=-1
        ).squeeze(-1)
        nll = -mx.sum(target_log_probs).item()
        count = targets.shape[1]
        total_nll += nll
        total_tokens += count
        # Force evaluation of the lazily-built cache before the next chunk.
        mx.eval(cache[0].state)
    ppl = math.exp(total_nll / total_tokens) if total_tokens > 0 else float("inf")
    return ppl, total_tokens
def main():
    """Run the perplexity sweep: each context length x each config column.

    Prints one table row per context length; the first config ("Baseline")
    anchors the row and TriAttention columns show ppl plus delta vs baseline.
    """
    print(f"Loading model: {MODEL}")
    model, processor = load(MODEL)
    # Some processors wrap the tokenizer; fall back to the processor itself.
    tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
    print("Model loaded.\n")
    # Load wikitext — concatenate into one long string
    print("Loading wikitext-2...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    text = "\n\n".join([t for t in ds["text"] if t.strip()])
    total_available = len(tokenizer.encode(text))
    print(f"Total tokens available: {total_available}\n")
    # (column label, calibration path, TriAttention budget); budget None = off.
    configs = [
        ("Baseline", None, None),
        ("TA-2048", CALIB, 2048),
        ("TA-1024", CALIB, 1024),
        ("TA-512", CALIB, 512),
        ("TA-256", CALIB, 256),
    ]
    context_lengths = [1024, 2048, 4096, 8192, 16384, 32768, 50000]
    # Header
    print(f"{'Context':>8}", end="")
    for name, _, _ in configs:
        print(f" | {name:>10}", end="")
    print()
    print("-" * (10 + 13 * len(configs)))
    for max_tokens in context_lengths:
        # Context lengths are ascending, so stop at the first that is too long.
        if max_tokens > total_available:
            print(f"{max_tokens:>8} | (not enough tokens, only {total_available} available)")
            break
        row = f"{max_tokens:>8}"
        baseline_ppl = None
        for name, calib_path, budget in configs:
            mx.clear_cache()
            t0 = time.perf_counter()
            ppl, n_tok = compute_ppl(
                model, tokenizer, text,
                calib_path=calib_path,
                budget=budget,
                max_tokens=max_tokens,
            )
            # NOTE(review): elapsed is computed but never reported.
            elapsed = time.perf_counter() - t0
            # First config in the list serves as the baseline for deltas.
            if baseline_ppl is None:
                baseline_ppl = ppl
                row += f" | {ppl:>10.2f}"
            else:
                delta = ppl - baseline_ppl
                sign = "+" if delta >= 0 else ""
                row += f" | {ppl:>6.2f}({sign}{delta:.2f})"
            mx.clear_cache()
        print(row, flush=True)
    print("\nDone.")


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Benchmark TriAttention vs baseline on MM-NIAH (Multimodal Needle-in-a-Haystack). | |
| INSTALL | |
| pip install -U mlx-vlm | |
| SETUP — Extract images (one-time) | |
| huggingface-cli download OpenGVLab/MM-NIAH mm_niah_val/images.tar.gz --repo-type dataset | |
| mkdir -p /tmp/mm_niah_val_images | |
| tar xzf ~/.cache/huggingface/hub/datasets--OpenGVLab--MM-NIAH/snapshots/*/mm_niah_val/images.tar.gz \\ | |
| -C /tmp/mm_niah_val_images | |
| CALIBRATION (one-time per model) | |
| python -m mlx_vlm.triattention_calibrate \\ | |
| --model google/gemma-4-26b-a4b-it \\ | |
| --output triattention_calib.safetensors | |
| USAGE | |
| # Full benchmark (baseline vs TriAttention budgets) | |
| python bench_triattention.py --model gg-hf-gg/gemma-4-31b-it \\ | |
| --calib /tmp/gemma4_triattention_calib.safetensors | |
| # Quick smoke test (2 samples per bucket) | |
| python bench_triattention.py --model gg-hf-gg/gemma-4-31b-it \\ | |
| --calib /tmp/gemma4_triattention_calib.safetensors --num-samples 2 | |
| # Custom budgets | |
| python bench_triattention.py --model gg-hf-gg/gemma-4-31b-it \\ | |
| --calib /tmp/gemma4_triattention_calib.safetensors --budgets 1024 512 256 | |
| """ | |
| import argparse | |
| import importlib | |
| import os | |
| import time | |
| import mlx.core as mx | |
| import numpy as np | |
| from datasets import load_dataset | |
| from PIL import Image | |
| from mlx_vlm import load | |
| from mlx_vlm.models.cache import make_prompt_cache | |
| from mlx_vlm.triattention import ( | |
| TriAttentionKVCache, | |
| extract_rope_config, | |
| load_calibration, | |
| maybe_apply_triattention, | |
| ) | |
| mod = importlib.import_module("mlx_vlm.generate") | |
# Root directory of the pre-extracted MM-NIAH validation images.
DEFAULT_IMAGE_ROOT = "/tmp/mm_niah_val_images/mm_niah_dev/images"
# Context-length buckets: (min tokens inclusive, max tokens exclusive, label).
BUCKETS = [
    (0, 2000, "~1K"),
    (2000, 5000, "~3K"),
    (5000, 10000, "~7K"),
    (10000, 20000, "~15K"),
    (20000, 40000, "~30K"),
    (40000, 100000, "~60K"),
]


def select_samples(ds, tokenizer, num_per_bucket):
    """Pick up to *num_per_bucket* samples from each context-length bucket.

    Every sample's context is tokenized once to measure its length; within a
    bucket, samples are sorted by length and taken at an even stride so the
    selection spans the bucket. Returns (dataset_index, token_count, label)
    tuples in bucket order.
    """
    sized = [
        (i, len(tokenizer.encode(s["context"]))) for i, s in enumerate(ds)
    ]
    chosen = []
    for lo, hi, label in BUCKETS:
        in_bucket = sorted(
            ((i, n) for i, n in sized if lo <= n < hi), key=lambda x: x[1]
        )
        if len(in_bucket) >= num_per_bucket:
            stride = max(1, len(in_bucket) // num_per_bucket)
            in_bucket = in_bucket[::stride][:num_per_bucket]
        chosen.extend((i, n, label) for i, n in in_bucket)
    return chosen
def load_images(sample, image_root):
    """Open (as RGB) every image listed in *sample* that exists on disk.

    Paths in sample["images_list"] are resolved relative to *image_root*;
    missing files are silently skipped.
    """
    candidates = (
        os.path.join(image_root, rel) for rel in sample["images_list"]
    )
    return [
        Image.open(path).convert("RGB")
        for path in candidates
        if os.path.exists(path)
    ]
def measure_cache_bytes(prompt_cache) -> int:
    """Total KV-cache size in bytes, summed from each layer's .nbytes."""
    return sum(entry.nbytes for entry in prompt_cache)
def run_sample(input_ids, model, pv, mask, kv_args, calib_path=None, budget=None):
    """Run one MM-NIAH sample and measure generation performance.

    Returns (token_ids, prefill tok/s, decode tok/s, KV-cache bytes,
    peak GPU memory in GiB). TriAttention is enabled on the fresh prompt
    cache when both *calib_path* and *budget* are provided.
    """
    n = input_ids.shape[1]
    prompt_cache = make_prompt_cache(model.language_model)
    # Apply TriAttention if requested
    if calib_path is not None and budget is not None:
        maybe_apply_triattention(
            prompt_cache, model, calib_path, budget=budget
        )
    gen = mod.generate_step(
        input_ids,
        model,
        pv,
        mask,
        max_tokens=20,
        temperature=0.0,
        prompt_cache=prompt_cache,
        **kv_args,
    )
    # Time to the first token approximates prefill time for the whole prompt.
    t0 = time.perf_counter()
    token, _ = next(gen)
    mx.eval(token if isinstance(token, mx.array) else mx.array(token))
    t_prefill = time.perf_counter() - t0
    prefill_tps = n / t_prefill
    # Remaining tokens measure steady-state decode speed.
    t0 = time.perf_counter()
    toks = [token.item() if isinstance(token, mx.array) else token]
    count = 0
    for tok, _ in gen:
        # NOTE(review): both branches of this ternary pass `tok` unchanged,
        # so the isinstance check here is a no-op — confirm intent (compare
        # the first-token eval above, which wraps with mx.array).
        mx.eval(tok if isinstance(tok, mx.array) else tok)
        toks.append(tok.item() if isinstance(tok, mx.array) else tok)
        count += 1
    t_decode = time.perf_counter() - t0
    decode_tps = count / t_decode if t_decode > 0 else 0
    kv_bytes = measure_cache_bytes(prompt_cache)
    # Peak GPU memory in GiB via the mx.metal API.
    peak_mem = mx.metal.get_peak_memory() / (1 << 30)
    return toks, prefill_tps, decode_tps, kv_bytes, peak_mem
def main():
    """CLI entry point: benchmark baseline vs TriAttention on MM-NIAH.

    For each bucketed sample the model runs once per mode ("BL" baseline plus
    one "TA-<budget>" per --budgets value), recording speed, KV-cache size,
    peak memory, and a substring-match correctness flag; then prints a
    per-sample table and aggregate stats.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark TriAttention on MM-NIAH"
    )
    parser.add_argument(
        "--model", type=str, default="gg-hf-gg/gemma-4-31b-it"
    )
    parser.add_argument(
        "--calib", type=str, required=True, help="Path to calibration .safetensors"
    )
    parser.add_argument(
        "--num-samples", type=int, default=1, help="Samples per bucket"
    )
    parser.add_argument(
        "--budgets",
        type=int,
        nargs="+",
        default=[512, 256, 128],
        help="TriAttention KV budgets to test",
    )
    parser.add_argument(
        "--image-root", type=str, default=DEFAULT_IMAGE_ROOT
    )
    args = parser.parse_args()
    # Images must be pre-extracted (see the module docstring).
    if not os.path.exists(args.image_root):
        print(f"Image root not found: {args.image_root}")
        print("Extract images first (see script docstring for instructions).")
        return
    ds = load_dataset("OpenGVLab/MM-NIAH", split="val")
    model, processor = load(args.model)
    indices = select_samples(ds, processor.tokenizer, args.num_samples)
    print(f"Model: {args.model}")
    print(f"Calib: {args.calib}")
    print(f"Samples: {len(indices)} ({args.num_samples} per bucket)")
    print(f"Budgets: {args.budgets}\n")
    # Build mode list: Baseline + each budget
    modes = [("BL", None)]
    for b in args.budgets:
        modes.append((f"TA-{b}", b))
    results = []
    for idx, ctx_len, label in indices:
        s = ds[idx]
        images = load_images(s, args.image_root)
        if not images:
            print(f" Skipping {label} idx={idx}, no images found", flush=True)
            continue
        # Interleave one image placeholder per image, then the text prompt.
        prompt = f"{s['context']}\n\nQuestion: {s['question']}\nAnswer briefly:"
        content = [{"type": "image"} for _ in images]
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]
        text = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        try:
            inputs = processor(text=text, images=images, return_tensors="np")
        except Exception as e:
            # Best-effort benchmark: skip samples the processor can't handle.
            print(
                f" Skipping {label} idx={idx}, processor error: {e}",
                flush=True,
            )
            continue
        input_ids = mx.array(inputs["input_ids"])
        pv = inputs.get("pixel_values", None)
        # Convert pixel values to mx.array unless already a list/mx.array.
        if pv is not None and not isinstance(pv, (list, mx.array)):
            pv = mx.array(np.asarray(pv))
        mask = (
            mx.array(inputs["attention_mask"])
            if "attention_mask" in inputs
            else None
        )
        n = input_ids.shape[1]
        gold = s["answer"]
        # Same sample under each mode, fresh cache per run.
        for mode_name, budget in modes:
            mx.clear_cache()
            calib_path = args.calib if budget is not None else None
            toks, prefill_tps, decode_tps, kv_bytes, peak_mem = run_sample(
                input_ids, model, pv, mask, {}, calib_path, budget
            )
            kv_gb = kv_bytes / (1 << 30)
            answer = (
                processor.tokenizer.decode(toks)
                .strip()
                .replace("\n", " ")[:60]
            )
            # Correctness = case-insensitive substring match of the gold answer.
            correct = gold.lower() in answer.lower()
            results.append(
                {
                    "idx": idx,
                    "n": n,
                    "mode": mode_name,
                    "budget": budget,
                    "prefill": prefill_tps,
                    "decode": decode_tps,
                    "kv_gb": kv_gb,
                    "peak_gb": peak_mem,
                    "correct": correct,
                    "gold": gold,
                    "answer": answer,
                    "label": label,
                    "num_images": len(images),
                }
            )
            mark = "Y" if correct else "N"
            print(
                f"{label:>5} {n:>6} ({len(images):>2} img) | {mode_name:<6} | "
                f"pf {prefill_tps:>7.1f} | dec {decode_tps:>5.1f} | "
                f"KV {kv_gb:>6.3f}G | peak {peak_mem:>5.1f}G | {mark} | "
                f"gold={gold[:15]:>15} | {answer[:40]}",
                flush=True,
            )
        print(flush=True)
    # ── Summary table ──
    bl_results = [r for r in results if r["mode"] == "BL"]
    if not bl_results:
        print("No completed baseline results.")
        return
    print(f"\n{'=' * 120}")
    print("SUMMARY")
    print(f"{'=' * 120}")
    header = (
        f"{'Bucket':>5} {'Tok':>6} {'Img':>3} | {'BL pf':>6} {'BL dec':>6} "
        f"{'KV BL':>6} {'Peak':>5}"
    )
    for b in args.budgets:
        header += f" | {'pf':>6} {'dec':>6} {'KV':>6} {'Save':>5} {'Pk':>5} {'Ok':>2}"
    # NOTE(review): this self-assignment is a no-op left in place.
    header = header  # label comes from the column header
    print(header)
    print("-" * 120)
    # One row per baseline run; pair it with each budget's run on the same idx.
    for bl in bl_results:
        row = (
            f"{bl['label']:>5} {bl['n']:>6} {bl['num_images']:>3} | "
            f"{bl['prefill']:>6.0f} {bl['decode']:>6.1f} "
            f"{bl['kv_gb']:>5.3f}G {bl['peak_gb']:>5.1f}G"
        )
        for b in args.budgets:
            ta = next(
                (
                    r
                    for r in results
                    if r["idx"] == bl["idx"]
                    and r["mode"] == f"TA-{b}"
                ),
                None,
            )
            if ta:
                # KV memory saved relative to the baseline run.
                save = (
                    (1 - ta["kv_gb"] / bl["kv_gb"]) * 100
                    if bl["kv_gb"] > 0
                    else 0
                )
                # NOTE(review): bm is computed but never printed.
                bm = "Y" if bl["correct"] else "N"
                tm = "Y" if ta["correct"] else "N"
                row += (
                    f" | {ta['prefill']:>6.0f} {ta['decode']:>6.1f} "
                    f"{ta['kv_gb']:>5.3f}G {save:>4.0f}% "
                    f"{ta['peak_gb']:>5.1f}G {tm:>2}"
                )
            else:
                row += " | - - - - - -"
        print(row)
    # Aggregate stats
    print(f"\n{'─' * 80}")
    for mode_name, budget in modes:
        mode_results = [r for r in results if r["mode"] == mode_name]
        if not mode_results:
            continue
        n_total = len(mode_results)
        n_correct = sum(1 for r in mode_results if r["correct"])
        avg_prefill = sum(r["prefill"] for r in mode_results) / n_total
        avg_decode = sum(r["decode"] for r in mode_results) / n_total
        avg_kv = sum(r["kv_gb"] for r in mode_results) / n_total
        avg_peak = sum(r["peak_gb"] for r in mode_results) / n_total
        if budget is None:
            kv_save = " - "
        else:
            bl_avg_kv = sum(r["kv_gb"] for r in bl_results) / len(bl_results)
            kv_save = f"{(1 - avg_kv / bl_avg_kv) * 100:4.0f}%"
        print(
            f" {mode_name:<8} | acc={n_correct}/{n_total} ({n_correct/n_total*100:4.0f}%) | "
            f"prefill={avg_prefill:7.1f} t/s | decode={avg_decode:5.1f} t/s | "
            f"KV={avg_kv:.3f}G | peak={avg_peak:.1f}G | KV saved: {kv_save}"
        )


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment