Created
May 20, 2026 20:26
-
-
Save Micky774/70367aa54c88ae39b77200e8a0284fc8 to your computer and use it in GitHub Desktop.
Triton RMSNorm benchmark file
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. | |
| # License for AMD contributions = MIT. See LICENSE for more information | |
| """ | |
| Isolated benchmark for the Triton RMSNorm fwd/bwd kernels. | |
| Workflow for before/after comparison: | |
| # On baseline (before optimization) | |
| python benchmarks/bench_rmsnorm_triton.py --tag baseline --out /tmp/rms_baseline.tsv | |
| # After applying changes | |
| python benchmarks/bench_rmsnorm_triton.py --tag opt --out /tmp/rms_opt.tsv | |
| # Diff | |
| python benchmarks/bench_rmsnorm_triton.py --compare /tmp/rms_baseline.tsv /tmp/rms_opt.tsv | |
| The script measures the public Python entry points | |
| (`te_rmsnorm_fwd_triton`, `te_rmsnorm_bwd_triton`) so dispatch overhead and | |
| post-kernel quantize calls (e.g. MXFP8 path) are included exactly as | |
| production sees them. | |
| Reports median latency (ms) and an "effective" memory bandwidth (GB/s) | |
| computed from the minimal useful I/O (input read, output write, gamma, | |
| rsigma, transpose). Internal scratch buffers (e.g. `dg_tmp`) are not | |
| counted toward useful bytes -- so a kernel that wastes HBM on a large | |
| scratch will show as lower achieved GB/s. That is intentional. | |
| """ | |
| import argparse | |
| import csv | |
| import inspect | |
| import os | |
| import statistics | |
| import sys | |
| import time | |
| import traceback | |
| import torch | |
| import triton | |
# Transformer Engine is imported lazily (only inside _import_te) so that
# ``--compare`` keeps working on machines without a TE / GPU build.
def _import_te():
    """Import the Transformer Engine symbols the benchmark uses and bundle them.

    Returns:
        dict with keys ``tex``, ``torch_dtype_to_te_dtype``,
        ``te_rmsnorm_fwd_triton``, ``te_rmsnorm_bwd_triton``,
        ``Float8Quantizer`` and ``bwd_accepts_autotune`` -- a bool telling
        whether the installed bwd launcher's signature has an ``autotune``
        parameter.
    """
    from transformer_engine.pytorch import cpp_extensions as tex
    from transformer_engine.pytorch.triton_kernels.common import torch_dtype_to_te_dtype
    from transformer_engine.pytorch.triton_kernels.norms_common import (
        te_rmsnorm_fwd_triton,
        te_rmsnorm_bwd_triton,
    )
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

    # Pre-optimization baselines (and upstream versions we compare against)
    # have no `autotune` kwarg on the bwd launcher; passing it would raise
    # TypeError, so probe the signature instead of the call.
    def _has_autotune_param(launcher):
        try:
            params = inspect.signature(launcher).parameters
        except (TypeError, ValueError):
            # inspect.signature raises these when no signature is available.
            return False
        return "autotune" in params

    return {
        "tex": tex,
        "torch_dtype_to_te_dtype": torch_dtype_to_te_dtype,
        "te_rmsnorm_fwd_triton": te_rmsnorm_fwd_triton,
        "te_rmsnorm_bwd_triton": te_rmsnorm_bwd_triton,
        "Float8Quantizer": Float8Quantizer,
        "bwd_accepts_autotune": _has_autotune_param(te_rmsnorm_bwd_triton),
    }
# Representative (M, N) shapes, with M = batch * seqlen and N = hidden: a mix of
# common LLM sizes plus a few stress / edge cases.
#
# Kernel path selection (transformer_engine/pytorch/triton_kernels/utils.py):
#   fwd uses USE_BLOCKED when N > 65536 / itemsize   (bf16 -> N > 32768)
#   bwd uses USE_BLOCKED when N > 16384 / itemsize   (bf16 -> N > 8192)
#
# Shapes are chosen deliberately on both sides of those thresholds.
DEFAULT_SHAPES = [
    # tiny; non-blocked for fwd and bwd
    (4096, 1024),
    # small, llama-ish
    (8192, 2048),
    # llama-7B / common
    (8192, 4096),
    # llama-70B class; bwd exactly at its threshold
    (8192, 8192),
    # large batch
    (16384, 4096),
    # wider hidden; bwd blocked
    (4096, 12288),
    # wide
    (2048, 16384),
    # fwd exactly at its threshold; bwd blocked
    (1024, 32768),
    # extreme wide; fwd blocked
    (256, 65536),
    # odd shape that exercises the masked tail
    (29, 17389),
]

DEFAULT_DTYPES = ("bf16", "fp16")

# plain output, FP8 rowwise, FP8 rowwise + columnwise (transpose)
DEFAULT_QUANT_MODES = ("none", "fp8", "fp8_t")

# CLI dtype name -> torch dtype.
_DTYPE_MAP = dict(
    bf16=torch.bfloat16,
    fp16=torch.float16,
    fp32=torch.float32,
)
def _parse_shapes(spec):
    """Parse a ``--shapes`` spec such as ``"8192x4096,2048x16384"``.

    Args:
        spec: comma-separated ``MxN`` items (the ``x`` is case-insensitive and
            surrounding whitespace is ignored), or a falsy value.

    Returns:
        A new list of ``(M, N)`` int tuples in spec order. A falsy ``spec``
        yields a copy of ``DEFAULT_SHAPES``, so callers can never mutate the
        module-level default. Blank items (e.g. a trailing comma) are skipped.

    Raises:
        ValueError: naming the offending item when it is not ``MxN`` with two
            positive integers.
    """
    if not spec:
        return list(DEFAULT_SHAPES)
    shapes = []
    for item in spec.split(","):
        item = item.strip()
        if not item:
            continue
        parts = item.lower().split("x")
        if len(parts) != 2:
            raise ValueError(f"invalid shape {item!r}: expected MxN, e.g. 8192x4096")
        try:
            m, n = int(parts[0]), int(parts[1])
        except ValueError:
            raise ValueError(
                f"invalid shape {item!r}: M and N must be integers"
            ) from None
        if m <= 0 or n <= 0:
            raise ValueError(f"invalid shape {item!r}: M and N must be positive")
        shapes.append((m, n))
    return shapes
def _device_info():
    """Collect version and device metadata for the result header.

    torch/triton versions are recorded best-effort (silently skipped if
    lookup fails). When a GPU is visible, also records its name, the
    ``multi_processor_count`` as ``cu_count``, and an arch string from
    Triton's active driver target, falling back to ``sm_<major><minor>``.
    """
    info = {}
    try:
        info["torch"] = torch.__version__
        info["triton"] = triton.__version__
    except Exception:
        pass

    if not torch.cuda.is_available():
        return info

    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    info["device"] = props.name
    info["cu_count"] = props.multi_processor_count
    try:
        arch = triton.runtime.driver.active.get_current_target().arch
    except Exception:
        arch = f"sm_{props.major}{props.minor}"
    info["arch"] = arch
    return info
def _make_fp8_quantizer(te_mod, columnwise):
    """Build an FP8 E4M3 ``Float8Quantizer`` from the TE bundle ``te_mod``.

    Uses a unit scale and a zeroed amax, each a single-element float32 CUDA
    tensor. ``columnwise`` is forwarded unchanged as a keyword argument.
    """
    unit_scale = torch.ones(1, dtype=torch.float32, device="cuda")
    zero_amax = torch.zeros(1, dtype=torch.float32, device="cuda")
    quantizer_cls = te_mod["Float8Quantizer"]
    e4m3 = te_mod["tex"].DType.kFloat8E4M3
    return quantizer_cls(unit_scale, zero_amax, e4m3, columnwise=columnwise)
def _useful_fwd_bytes(M, N, itemsize, quant_mode):
    """Minimum HBM bytes a forward pass *should* move.

    Counts: read x, write the rowwise output (1 byte/elem when quantized,
    else ``itemsize``), an extra equally sized write for the transpose in
    ``"fp8_t"`` mode, one read of gamma, and a 4-byte rsigma per row.
    """
    quantized = quant_mode != "none"
    out_bytes_per_elem = 1 if quantized else itemsize
    output_planes = 2 if quant_mode == "fp8_t" else 1
    traffic = (
        M * N * itemsize,                                # x
        output_planes * M * N * out_bytes_per_elem,      # rowwise (+ transpose)
        N * itemsize,                                    # gamma, counted once
        M * 4,                                           # rsigma
    )
    return sum(traffic)
def _useful_bwd_bytes(M, N, itemsize):
    """Minimum HBM bytes a backward pass *should* move.

    Three M x N planes (read x, read dz, write dx), gamma read plus final
    dgamma write (N elements each), and a 4-byte rsigma read per row.
    Internal scratch buffers are intentionally excluded.
    """
    plane = M * N * itemsize
    vector = N * itemsize
    return 3 * plane + 2 * vector + 4 * M
def _do_bench(fn, warmup_ms, rep_ms):
    """Time ``fn`` with ``triton.testing.do_bench`` at quantile 0.5.

    ``fn`` is first called five times and the device synchronized, so that
    any autotuning completes before timing starts. Then Triton runs its own
    ``warmup_ms`` / ``rep_ms`` time-based warmup and repetition. Returns
    whatever ``do_bench`` returns for ``quantiles=[0.5]``.
    """
    pre_runs = 5
    for _ in range(pre_runs):
        fn()
    torch.cuda.synchronize()
    result = triton.testing.do_bench(
        fn,
        warmup=warmup_ms,
        rep=rep_ms,
        quantiles=[0.5],
    )
    return result
def bench_fwd(te_mod, M, N, dtype, quant_mode, autotune, warmup_ms, rep_ms):
    """Benchmark one ``te_rmsnorm_fwd_triton`` configuration.

    Inputs are ``x = 0.1 * randn(M, N)`` and ``gamma = randn(N)`` in ``dtype``
    on the GPU. With ``quant_mode == "none"`` the output type is the TE
    equivalent of ``dtype``. Any other mode uses an FP8 E4M3 quantizer, which
    also produces column-wise output when the mode is ``"fp8_t"``.

    Returns:
        ``(ms, gbs)``: the time reported by ``_do_bench``, and the effective
        bandwidth derived from ``_useful_fwd_bytes``.
    """
    x = torch.randn(M, N, dtype=dtype, device="cuda") * 0.1
    g = torch.randn(N, dtype=dtype, device="cuda")

    if quant_mode != "none":
        quantizer = _make_fp8_quantizer(te_mod, columnwise=(quant_mode == "fp8_t"))
        otype = te_mod["tex"].DType.kFloat8E4M3
    else:
        quantizer = None
        otype = te_mod["torch_dtype_to_te_dtype"](dtype)

    fwd = te_mod["te_rmsnorm_fwd_triton"]
    fwd_kwargs = dict(
        input=x,
        weight=g,
        eps=1e-5,
        ln_out=None,
        quantizer=quantizer,
        otype=otype,
        sm_margin=0,
        zero_centered_gamma=False,
        autotune=autotune,
    )

    def call():
        return fwd(**fwd_kwargs)

    latency_ms = _do_bench(call, warmup_ms, rep_ms)
    moved_bytes = _useful_fwd_bytes(M, N, x.element_size(), quant_mode)
    seconds = latency_ms * 1e-3
    return latency_ms, moved_bytes / seconds / 1e9
def bench_bwd(te_mod, M, N, dtype, autotune, warmup_ms, rep_ms):
    """Benchmark one ``te_rmsnorm_bwd_triton`` configuration.

    ``x`` and ``dz`` are ``0.1 * randn(M, N)`` and gamma is ``randn(N)``, all
    in ``dtype`` on the GPU. ``rsigma`` comes from a real (non-autotuned)
    forward call, so the backward runs on realistic values.

    Returns:
        ``(ms, gbs)``: the time reported by ``_do_bench``, and the effective
        bandwidth derived from ``_useful_bwd_bytes``.
    """
    x = torch.randn(M, N, dtype=dtype, device="cuda") * 0.1
    g = torch.randn(N, dtype=dtype, device="cuda")
    dz = torch.randn(M, N, dtype=dtype, device="cuda") * 0.1

    _, _, rsigma = te_mod["te_rmsnorm_fwd_triton"](
        input=x,
        weight=g,
        eps=1e-5,
        ln_out=None,
        quantizer=None,
        otype=te_mod["torch_dtype_to_te_dtype"](dtype),
        sm_margin=0,
        zero_centered_gamma=False,
        autotune=False,
    )

    bwd = te_mod["te_rmsnorm_bwd_triton"]
    # Baseline launchers have no `autotune` kwarg: forward it only when it is
    # accepted. Otherwise --autotune / --no-autotune silently no-ops for bwd.
    extra_kwargs = {"autotune": autotune} if te_mod["bwd_accepts_autotune"] else {}

    def call():
        return bwd(dz, x, rsigma, g, 0, False, **extra_kwargs)

    latency_ms = _do_bench(call, warmup_ms, rep_ms)
    moved_bytes = _useful_bwd_bytes(M, N, x.element_size())
    seconds = latency_ms * 1e-3
    return latency_ms, moved_bytes / seconds / 1e9
def _enumerate_configs(args):
    """Build the full (kind, dtype, quant, M, N) task list.

    fwd tasks cover every dtype x quant mode x shape. bwd tasks cover every
    dtype x shape with quant ``"none"``, since the bwd kernels don't support
    quantization. ``--no-fp8`` keeps only the ``"none"`` quant mode, and
    ``--no-fwd`` / ``--no-bwd`` drop the respective kind. Entries in
    ``--dtypes`` / ``--quant`` are whitespace-stripped and blank ones are
    ignored.

    Raises:
        ValueError: if ``--dtypes`` or ``--quant`` names an unsupported entry.
    """
    shapes = _parse_shapes(args.shapes)

    dtype_names = [d.strip() for d in args.dtypes.split(",") if d.strip()]
    unknown_dtypes = [d for d in dtype_names if d not in _DTYPE_MAP]
    if unknown_dtypes:
        raise ValueError(
            f"unknown --dtypes value(s) {unknown_dtypes}; "
            f"choose from {list(_DTYPE_MAP)}"
        )
    dtypes = [_DTYPE_MAP[d] for d in dtype_names]

    # Reject anything else. bench_fwd treats every non-"none" mode as FP8
    # rowwise, so a typo would be benchmarked as FP8 and then mislabelled.
    valid_quant = ("none", "fp8", "fp8_t")
    quant_modes = [q.strip() for q in args.quant.split(",") if q.strip()]
    unknown_quant = [q for q in quant_modes if q not in valid_quant]
    if unknown_quant:
        raise ValueError(
            f"unknown --quant value(s) {unknown_quant}; "
            f"choose from {list(valid_quant)}"
        )
    if args.no_fp8:
        quant_modes = [q for q in quant_modes if q == "none"]

    configs = []
    if not args.no_fwd:
        configs += [
            ("fwd", dtype, q, M, N)
            for dtype in dtypes
            for q in quant_modes
            for (M, N) in shapes
        ]
    if not args.no_bwd:
        # bwd kernels don't support quantization
        configs += [
            ("bwd", dtype, "none", M, N)
            for dtype in dtypes
            for (M, N) in shapes
        ]
    return configs
def run_suite(args):
    """Run every configured benchmark and return one result row per config.

    Each row is ``(kind, tag, dtype, quant, M, N, ms, gbs, note)``. A config
    that raises is recorded with NaN ``ms`` / ``gbs`` and a single-line
    ``note`` containing the exception type and the first 80 characters of its
    message. With ``args.fail_fast`` the exception is re-raised (after
    printing its traceback) instead of continuing. Per-config progress and a
    final summary go to stderr when ``args.verbose`` is set.
    """
    te_mod = _import_te()
    configs = _enumerate_configs(args)
    total = len(configs)
    idx_w = max(len(str(total)), 1)
    if args.verbose:
        print(f"# running {total} configurations "
              f"(warmup={args.warmup_ms}ms, rep={args.rep_ms}ms, "
              f"autotune={args.autotune})",
              file=sys.stderr, flush=True)
        if not te_mod["bwd_accepts_autotune"]:
            print("# note: installed te_rmsnorm_bwd_triton has no `autotune` kwarg "
                  "(baseline build); --autotune/--no-autotune affects fwd only.",
                  file=sys.stderr, flush=True)

    rows = []
    suite_t0 = time.perf_counter()
    for i, (kind, dtype, q, M, N) in enumerate(configs, 1):
        dtype_s = str(dtype).split(".")[-1]
        shape_s = f"{M}x{N}"
        if args.verbose:
            print(f"[{i:>{idx_w}}/{total}] {kind} {dtype_s:<8} {q:<6} "
                  f"{shape_s:<14} ... ",
                  end="", file=sys.stderr, flush=True)
        cfg_t0 = time.perf_counter()
        try:
            if kind == "fwd":
                ms, gbs = bench_fwd(
                    te_mod, M, N, dtype, q,
                    autotune=args.autotune,
                    warmup_ms=args.warmup_ms,
                    rep_ms=args.rep_ms,
                )
            else:
                ms, gbs = bench_bwd(
                    te_mod, M, N, dtype,
                    autotune=args.autotune,
                    warmup_ms=args.warmup_ms,
                    rep_ms=args.rep_ms,
                )
            elapsed = time.perf_counter() - cfg_t0
            rows.append((kind, args.tag, dtype_s, q, M, N, ms, gbs, ""))
            if args.verbose:
                print(f"{ms:>9.4f} ms {gbs:>7.1f} GB/s ({elapsed:>5.1f}s)",
                      file=sys.stderr, flush=True)
        except Exception as e:
            elapsed = time.perf_counter() - cfg_t0
            # Runtime error text (e.g. HIP/CUDA errors) is often multi-line.
            # Collapse all whitespace so the note stays on one table row.
            msg = " ".join(str(e).split())
            note = type(e).__name__ + ": " + msg[:80]
            rows.append((kind, args.tag, dtype_s, q, M, N,
                         float("nan"), float("nan"), note))
            if args.verbose:
                print(f"FAIL ({elapsed:>5.1f}s) {note}",
                      file=sys.stderr, flush=True)
            if args.fail_fast:
                traceback.print_exc()
                raise

    if args.verbose:
        wall = time.perf_counter() - suite_t0
        ok = sum(1 for r in rows if r[6] == r[6])  # NaN != NaN: counts successes
        print(f"# done: {ok}/{total} successful in {wall:.1f}s "
              f"({wall / max(total, 1):.2f}s/config avg)",
              file=sys.stderr, flush=True)
    return rows
def _print_table(rows, info, stream=None):
    """Print result rows as an aligned, human-readable table.

    Args:
        rows: ``(kind, tag, dtype, quant, M, N, ms, gbs, note)`` tuples.
            Rows whose ``ms`` is NaN (failed runs) are printed with blank
            timing columns. Output is sorted by (kind, dtype, quant, N, M).
        info: metadata dict printed first as one ``# k=v | ...`` line. It is
            skipped when empty.
        stream: text stream to write to. ``None`` (the default) means the
            *current* ``sys.stdout``, resolved at call time, so later
            redirection such as ``contextlib.redirect_stdout`` is honoured. A
            ``sys.stdout`` default would be frozen at definition time.
    """
    if stream is None:
        stream = sys.stdout
    if info:
        print("# " + " | ".join(f"{k}={v}" for k, v in info.items()), file=stream)
    hdr = f"{'kind':<4} {'tag':<10} {'dtype':<8} {'quant':<6} {'M':>6} {'N':>6} {'ms':>10} {'GB/s':>9} notes"
    print(hdr, file=stream)
    print("-" * len(hdr), file=stream)
    for r in sorted(rows, key=lambda r: (r[0], r[2], r[3], r[5], r[4])):
        kind, tag, dtype, q, M, N, ms, gbs, note = r
        if ms != ms:  # NaN marks a failed run
            print(f"{kind:<4} {tag:<10} {dtype:<8} {q:<6} {M:>6} {N:>6} {'':>10} {'':>9} {note}", file=stream)
        else:
            print(f"{kind:<4} {tag:<10} {dtype:<8} {q:<6} {M:>6} {N:>6} {ms:>10.4f} {gbs:>9.1f} {note}", file=stream)
def _write_tsv(rows, info, path):
    """Write results to ``path`` as a tab-separated file.

    Layout: one ``# key=value`` line per ``info`` entry, then a header row,
    then one row per result tuple, both written with ``csv.writer``
    (tab-delimited).
    """
    column_names = ["kind", "tag", "dtype", "quant", "M", "N", "ms", "GBs", "note"]
    with open(path, "w", newline="") as fh:
        fh.writelines(f"# {key}={value}\n" for key, value in info.items())
        tsv = csv.writer(fh, delimiter="\t")
        tsv.writerow(column_names)
        tsv.writerows(rows)
def _load_tsv(path):
    """Read a result file in the layout written by ``_write_tsv``.

    Returns:
        ``(rows, info)``. ``rows`` is a list of tuples of raw *string* fields
        in file column order, with the header line skipped. ``info`` is the
        dict parsed from the leading ``# key=value`` comment lines.
    """
    info = {}
    with open(path) as f:
        for line in f:
            if not line.startswith("#"):
                break  # metadata lines only precede the header
            kv = line[1:].strip().split("=", 1)
            if len(kv) == 2:
                info[kv[0].strip()] = kv[1].strip()

    rows = []
    with open(path) as f:
        header = None
        for row in csv.reader(f, delimiter="\t"):
            if not row or row[0].startswith("#"):
                continue
            if header is None:
                header = row  # column names; fields are consumed positionally
                continue
            rows.append(tuple(row))
    return rows, info


def _index_rows(rows, label):
    """Map ``(kind, dtype, quant, M, N)`` -> row for the rows of one file.

    Rows that cannot be keyed (too few fields, or non-integer M/N) are
    reported on stderr and skipped, instead of aborting the whole
    comparison. When keys are duplicated, the later row wins.
    """
    indexed = {}
    for r in rows:
        try:
            k = (r[0], r[2], r[3], int(r[4]), int(r[5]))
        except (IndexError, ValueError):
            print(f"# warning: skipping malformed row in {label}: {r!r}",
                  file=sys.stderr)
            continue
        indexed[k] = r
    return indexed


def compare(path_a, path_b):
    """Print a per-config comparison of two result TSVs plus a summary line.

    Speedup is ``A ms / B ms``, so > 1 means B is faster. Rows are flagged
    ``+`` at >= 1.05x and ``-`` at <= 0.95x. Configs present in only one file
    are reported as missing. Rows whose speedup is not a positive finite
    number (a failed run stored as ``nan``, or a zero timing) are still
    printed but are excluded from the geomean/min/max. A single NaN would
    otherwise turn the whole summary into ``nan``.
    """
    rows_a, info_a = _load_tsv(path_a)
    rows_b, info_b = _load_tsv(path_b)
    a_map = _index_rows(rows_a, path_a)
    b_map = _index_rows(rows_b, path_b)
    keys = sorted(set(a_map) | set(b_map))

    print(f"# A: {path_a} ({info_a})")
    print(f"# B: {path_b} ({info_b})")
    hdr = (f"{'kind':<4} {'dtype':<8} {'quant':<6} {'M':>6} {'N':>6} "
           f"{'A ms':>10} {'B ms':>10} {'A GB/s':>9} {'B GB/s':>9} "
           f"{'speedup':>9}")
    print(hdr)
    print("-" * len(hdr))

    speedups = []
    n_excluded = 0
    for k in keys:
        kind, dtype, q, M, N = k
        label = f"{kind:<4} {dtype:<8} {q:<6} {M:>6} {N:>6}"
        ra = a_map.get(k)
        rb = b_map.get(k)
        if ra is None:
            print(f"{label} (A missing)")
            continue
        if rb is None:
            print(f"{label} (B missing)")
            continue
        try:
            a_ms, a_gbs = float(ra[6]), float(ra[7])
            b_ms, b_gbs = float(rb[6]), float(rb[7])
        except (ValueError, IndexError):
            print(f"{label} (parse error)")
            continue

        speedup = a_ms / b_ms if b_ms > 0 else float("nan")
        if 0.0 < speedup < float("inf"):  # comparisons are False for NaN
            speedups.append(speedup)
            if speedup >= 1.05:
                marker = " +"
            elif speedup <= 0.95:
                marker = " -"
            else:
                marker = ""
        else:
            n_excluded += 1
            marker = " (excluded: failed/invalid timing)"
        print(f"{label} "
              f"{a_ms:>10.4f} {b_ms:>10.4f} {a_gbs:>9.1f} {b_gbs:>9.1f} "
              f"{speedup:>9.2f}x{marker}")

    if speedups:
        print()
        print(f"# Geomean speedup (B vs A): "
              f"{statistics.geometric_mean(speedups):.3f}x "
              f"min={min(speedups):.3f}x max={max(speedups):.3f}x "
              f"n={len(speedups)}")
    if n_excluded:
        print(f"# {n_excluded} config(s) with failed/invalid timings "
              f"excluded from the summary")
def main():
    """Command-line entry point; returns the process exit status.

    ``--compare A.tsv B.tsv`` prints a comparison and returns 0 before the
    GPU check. Otherwise a CUDA/ROCm device is required (returns 2 if none
    is available). The suite is then run, the table is printed to stdout,
    and, when ``--out`` is given, the rows are also written as TSV.
    """
    parser = argparse.ArgumentParser(
        description="Standalone benchmark for Triton RMSNorm kernels.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    opt = parser.add_argument
    opt("--tag", default="run", help="Label for this run.")
    opt("--out", default=None, help="Write results as TSV to this path.")
    opt("--shapes", default=None,
        help="Comma-separated MxN list, e.g. 8192x4096,2048x16384")
    opt("--dtypes", default=",".join(DEFAULT_DTYPES),
        help="Comma-separated dtypes from {bf16,fp16,fp32}.")
    opt("--quant", default=",".join(DEFAULT_QUANT_MODES),
        help="Comma-separated quant modes from {none,fp8,fp8_t}.")
    opt("--no-fp8", action="store_true", help="Drop fp8/fp8_t entries.")
    opt("--no-fwd", action="store_true")
    opt("--no-bwd", action="store_true")
    opt("--autotune", dest="autotune", action="store_true", default=True,
        help="Enable Triton autotune (default).")
    opt("--no-autotune", dest="autotune", action="store_false")
    opt("--warmup-ms", type=int, default=100)
    opt("--rep-ms", type=int, default=300)
    opt("--fail-fast", action="store_true",
        help="Re-raise first benchmark exception with traceback.")
    opt("-v", "--verbose", action="store_true",
        help=("Print per-configuration progress (one line each) to stderr, "
              "including wall-clock time per config."))
    opt("--compare", nargs=2, metavar=("A.tsv", "B.tsv"),
        help="Compare two TSV result files (no GPU needed).")
    args = parser.parse_args()

    if args.compare:
        path_a, path_b = args.compare
        compare(path_a, path_b)
        return 0

    if not torch.cuda.is_available():
        print("error: CUDA/ROCm device required", file=sys.stderr)
        return 2

    info = _device_info()
    info["tag"] = args.tag
    rows = run_suite(args)
    _print_table(rows, info)
    if not args.out:
        return 0
    _write_tsv(rows, info, args.out)
    print(f"\n# wrote {len(rows)} rows to {args.out}", file=sys.stderr)
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment