Skip to content

Instantly share code, notes, and snippets.

@Micky774
Created May 20, 2026 20:26
Show Gist options
  • Select an option

  • Save Micky774/70367aa54c88ae39b77200e8a0284fc8 to your computer and use it in GitHub Desktop.

Select an option

Save Micky774/70367aa54c88ae39b77200e8a0284fc8 to your computer and use it in GitHub Desktop.
Triton RMSNorm benchmark file
#!/usr/bin/env python3
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# License for AMD contributions = MIT. See LICENSE for more information
"""
Isolated benchmark for the Triton RMSNorm fwd/bwd kernels.
Workflow for before/after comparison:
# On baseline (before optimization)
python benchmarks/bench_rmsnorm_triton.py --tag baseline --out /tmp/rms_baseline.tsv
# After applying changes
python benchmarks/bench_rmsnorm_triton.py --tag opt --out /tmp/rms_opt.tsv
# Diff
python benchmarks/bench_rmsnorm_triton.py --compare /tmp/rms_baseline.tsv /tmp/rms_opt.tsv
The script measures the public Python entry points
(`te_rmsnorm_fwd_triton`, `te_rmsnorm_bwd_triton`) so dispatch overhead and
post-kernel quantize calls (e.g. MXFP8 path) are included exactly as
production sees them.
Reports median latency (ms) and an "effective" memory bandwidth (GB/s)
computed from the minimal useful I/O (input read, output write, gamma,
rsigma, transpose). Internal scratch buffers (e.g. `dg_tmp`) are not
counted toward useful bytes -- so a kernel that wastes HBM on a large
scratch will show as lower achieved GB/s. That is intentional.
"""
import argparse
import csv
import inspect
import os
import statistics
import sys
import time
import traceback
import torch
import triton
# Lazy-import TE bits to keep --compare usable without a CUDA build.
def _import_te():
from transformer_engine.pytorch import cpp_extensions as tex
from transformer_engine.pytorch.triton_kernels.common import torch_dtype_to_te_dtype
from transformer_engine.pytorch.triton_kernels.norms_common import (
te_rmsnorm_fwd_triton,
te_rmsnorm_bwd_triton,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
# Detect whether the installed bwd launcher accepts an `autotune` kwarg.
# Pre-optimization baselines (and any upstream version we want to compare
# against) don't, so passing it would TypeError.
try:
bwd_accepts_autotune = (
"autotune" in inspect.signature(te_rmsnorm_bwd_triton).parameters
)
except (TypeError, ValueError):
bwd_accepts_autotune = False
return dict(
tex=tex,
torch_dtype_to_te_dtype=torch_dtype_to_te_dtype,
te_rmsnorm_fwd_triton=te_rmsnorm_fwd_triton,
te_rmsnorm_bwd_triton=te_rmsnorm_bwd_triton,
Float8Quantizer=Float8Quantizer,
bwd_accepts_autotune=bwd_accepts_autotune,
)
# Representative (M, N) shapes. Mix of common LLM shapes plus a few
# stress / edge cases. M = batch * seqlen, N = hidden.
#
# Path selector (in transformer_engine/pytorch/triton_kernels/utils.py):
# fwd USE_BLOCKED when N > 65536 / itemsize (bf16 -> N > 32768)
# bwd USE_BLOCKED when N > 16384 / itemsize (bf16 -> N > 8192)
#
# We deliberately include both sides of those boundaries.
DEFAULT_SHAPES = [
(4096, 1024), # tiny, non-blocked everywhere
(8192, 2048), # small llama-ish
(8192, 4096), # llama-7B / common
(8192, 8192), # llama-70B class, bwd at boundary
(16384, 4096), # large batch
(4096, 12288), # wider hidden, bwd blocked
(2048, 16384), # wide
(1024, 32768), # fwd at boundary, bwd blocked
(256, 65536), # extreme wide, fwd blocked
(29, 17389), # odd shape, exercises masked tail
]
DEFAULT_DTYPES = ("bf16", "fp16")
DEFAULT_QUANT_MODES = ("none", "fp8", "fp8_t") # plain, FP8 rowwise, FP8 row+col
_DTYPE_MAP = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}
def _parse_shapes(spec):
if not spec:
return DEFAULT_SHAPES
out = []
for item in spec.split(","):
item = item.strip()
if not item:
continue
m_str, n_str = item.lower().split("x")
out.append((int(m_str), int(n_str)))
return out
def _device_info():
info = {}
try:
info["torch"] = torch.__version__
info["triton"] = triton.__version__
except Exception:
pass
if torch.cuda.is_available():
idx = torch.cuda.current_device()
props = torch.cuda.get_device_properties(idx)
info["device"] = props.name
info["cu_count"] = props.multi_processor_count
try:
info["arch"] = triton.runtime.driver.active.get_current_target().arch
except Exception:
info["arch"] = f"sm_{props.major}{props.minor}"
return info
def _make_fp8_quantizer(te_mod, columnwise):
    """Build an E4M3 ``Float8Quantizer`` with scale 1 and zeroed amax.

    Args:
        te_mod: dict returned by ``_import_te``.
        columnwise: forwarded to the quantizer; True for the "fp8_t" mode.
    """
    unit_scale = torch.ones(1, dtype=torch.float32, device="cuda")
    zero_amax = torch.zeros(1, dtype=torch.float32, device="cuda")
    quantizer_cls = te_mod["Float8Quantizer"]
    fp8_e4m3 = te_mod["tex"].DType.kFloat8E4M3
    return quantizer_cls(unit_scale, zero_amax, fp8_e4m3, columnwise=columnwise)
def _useful_fwd_bytes(M, N, itemsize, quant_mode):
# Minimum HBM traffic the kernel *should* do, in bytes.
out_elem = 1 if quant_mode != "none" else itemsize
b = M * N * itemsize # read x
b += M * N * out_elem # write rowwise out
if quant_mode == "fp8_t":
b += M * N * out_elem # write transpose
b += N * itemsize # read gamma (counted once)
b += M * 4 # write rsigma
return b
def _useful_bwd_bytes(M, N, itemsize):
b = 2 * M * N * itemsize # read x + dz
b += M * N * itemsize # write dx
b += N * itemsize # read gamma
b += N * itemsize # write final dgamma
b += M * 4 # read rsigma
return b
def _do_bench(fn, warmup_ms, rep_ms, warmup_iters=5):
    """Return the median runtime of ``fn`` in milliseconds.

    Args:
        fn: zero-argument callable that launches the work to time.
        warmup_ms: time-based warmup forwarded to ``triton.testing.do_bench``.
        rep_ms: measurement window forwarded to ``triton.testing.do_bench``.
        warmup_iters: number of untimed calls made first (default 5).

    Returns:
        The 0.5 quantile (median) reported by ``triton.testing.do_bench``,
        as a float.
    """
    # Untimed calls plus a device sync, so one-time work such as Triton
    # autotuning / JIT compilation finishes before timing starts.
    for _ in range(warmup_iters):
        fn()
    torch.cuda.synchronize()
    result = triton.testing.do_bench(
        fn, warmup=warmup_ms, rep=rep_ms, quantiles=[0.5]
    )
    # NOTE(review): depending on the Triton version, a single requested
    # quantile may come back as a bare float or as a one-element list;
    # normalize so callers can always do arithmetic on the result.
    if isinstance(result, (list, tuple)):
        result = result[0]
    return float(result)
def bench_fwd(te_mod, M, N, dtype, quant_mode, autotune, warmup_ms, rep_ms):
    """Time ``te_rmsnorm_fwd_triton`` for one (M, N, dtype, quant_mode) config.

    Inputs are random tensors on "cuda". For quant_mode "none" the output
    type is the input dtype. Otherwise an E4M3 FP8 quantizer is used, with
    columnwise output requested only for "fp8_t".

    Returns:
        (median_ms, effective_GB_per_s). Bandwidth is computed from
        ``_useful_fwd_bytes``.
    """
    x = torch.randn(M, N, dtype=dtype, device="cuda") * 0.1
    g = torch.randn(N, dtype=dtype, device="cuda")

    if quant_mode != "none":
        quantizer = _make_fp8_quantizer(te_mod, columnwise=(quant_mode == "fp8_t"))
        otype = te_mod["tex"].DType.kFloat8E4M3
    else:
        quantizer = None
        otype = te_mod["torch_dtype_to_te_dtype"](dtype)

    fwd = te_mod["te_rmsnorm_fwd_triton"]
    fwd_kwargs = {
        "input": x,
        "weight": g,
        "eps": 1e-5,
        "ln_out": None,
        "quantizer": quantizer,
        "otype": otype,
        "sm_margin": 0,
        "zero_centered_gamma": False,
        "autotune": autotune,
    }
    ms = _do_bench(lambda: fwd(**fwd_kwargs), warmup_ms, rep_ms)
    useful_bytes = _useful_fwd_bytes(M, N, x.element_size(), quant_mode)
    return ms, useful_bytes / (ms * 1e-3) / 1e9
def bench_bwd(te_mod, M, N, dtype, autotune, warmup_ms, rep_ms):
    """Time ``te_rmsnorm_bwd_triton`` for one (M, N, dtype) config.

    rsigma is produced by a real (non-autotuned) forward call on the same
    inputs, so the backward runs on the same kind of data it gets in
    production. ``autotune`` is forwarded only when the installed launcher
    accepts it (``te_mod["bwd_accepts_autotune"]``). Otherwise it is ignored.

    Returns:
        (median_ms, effective_GB_per_s). Bandwidth is computed from
        ``_useful_bwd_bytes``.
    """
    x = torch.randn(M, N, dtype=dtype, device="cuda") * 0.1
    g = torch.randn(N, dtype=dtype, device="cuda")
    dz = torch.randn(M, N, dtype=dtype, device="cuda") * 0.1

    _, _, rsigma = te_mod["te_rmsnorm_fwd_triton"](
        input=x,
        weight=g,
        eps=1e-5,
        ln_out=None,
        quantizer=None,
        otype=te_mod["torch_dtype_to_te_dtype"](dtype),
        sm_margin=0,
        zero_centered_gamma=False,
        autotune=False,
    )

    bwd = te_mod["te_rmsnorm_bwd_triton"]
    # Baseline launchers lack the kwarg, so only pass it when supported.
    extra = {"autotune": autotune} if te_mod["bwd_accepts_autotune"] else {}

    def call():
        return bwd(dz, x, rsigma, g, 0, False, **extra)

    ms = _do_bench(call, warmup_ms, rep_ms)
    useful_bytes = _useful_bwd_bytes(M, N, x.element_size())
    return ms, useful_bytes / (ms * 1e-3) / 1e9
def _enumerate_configs(args):
    """Build the full (kind, dtype, quant, M, N) task list from parsed CLI args.

    Forward configs are the cross product dtypes x quant modes x shapes.
    Backward configs are dtypes x shapes with quant fixed to "none", because
    the bwd kernels are benchmarked unquantized.

    Comma-separated ``--dtypes`` / ``--quant`` items are whitespace-stripped.
    Empty items (e.g. from a trailing comma) are ignored. Duplicates, including
    duplicate shapes, are dropped while keeping first-occurrence order.
    ``compare`` matches rows on (kind, dtype, quant, M, N), so duplicate
    configs would overwrite each other there.

    Raises:
        ValueError: on an unknown dtype name or quant mode.
    """
    # bench_fwd treats every mode other than "none" as FP8, so an unknown
    # name (e.g. a typo like "fp8t") would otherwise run silently under a
    # misleading label.
    known_quant = ("none", "fp8", "fp8_t")

    def split_csv(spec):
        # Strip, drop empties, and de-duplicate while preserving order.
        stripped = (item.strip() for item in spec.split(","))
        return list(dict.fromkeys(item for item in stripped if item))

    shapes = list(dict.fromkeys(_parse_shapes(args.shapes)))

    dtype_names = split_csv(args.dtypes)
    bad_dtypes = [d for d in dtype_names if d not in _DTYPE_MAP]
    if bad_dtypes:
        raise ValueError(
            f"unknown dtype(s) {bad_dtypes}; expected a subset of {sorted(_DTYPE_MAP)}"
        )
    dtypes = [_DTYPE_MAP[d] for d in dtype_names]

    quant_modes = split_csv(args.quant)
    bad_quant = [q for q in quant_modes if q not in known_quant]
    if bad_quant:
        raise ValueError(
            f"unknown quant mode(s) {bad_quant}; expected a subset of {list(known_quant)}"
        )
    if args.no_fp8:
        quant_modes = [q for q in quant_modes if q == "none"]

    configs = []
    if not args.no_fwd:
        configs.extend(
            ("fwd", dtype, q, M, N)
            for dtype in dtypes
            for q in quant_modes
            for (M, N) in shapes
        )
    if not args.no_bwd:
        # bwd kernels don't support quantization
        configs.extend(
            ("bwd", dtype, "none", M, N)
            for dtype in dtypes
            for (M, N) in shapes
        )
    return configs
def _bench_config(te_mod, args, kind, dtype, q, M, N):
    """Run one configuration through bench_fwd / bench_bwd; return (ms, GB/s)."""
    timing = {
        "autotune": args.autotune,
        "warmup_ms": args.warmup_ms,
        "rep_ms": args.rep_ms,
    }
    if kind == "fwd":
        return bench_fwd(te_mod, M, N, dtype, q, **timing)
    return bench_bwd(te_mod, M, N, dtype, **timing)


def run_suite(args):
    """Benchmark every configuration selected by ``args``; return result rows.

    Each row is ``(kind, tag, dtype_name, quant, M, N, ms, gbs, note)``.

    A configuration that raises does not stop the suite. It is recorded with
    NaN ms / GB/s and a one-line ``"ExcType: message"`` note, keeping the first
    80 characters of the whitespace-collapsed message. With
    ``args.fail_fast`` set, the traceback is printed and the exception is
    re-raised instead.

    With ``args.verbose``, per-config progress and a final summary go to
    stderr. A notice that the installed bwd launcher ignores ``--autotune``
    is always printed to stderr, because it affects how results are read.
    """
    import math

    te_mod = _import_te()
    configs = _enumerate_configs(args)
    total = len(configs)
    idx_w = max(len(str(total)), 1)
    if args.verbose:
        print(f"# running {total} configurations "
              f"(warmup={args.warmup_ms}ms, rep={args.rep_ms}ms, "
              f"autotune={args.autotune})",
              file=sys.stderr, flush=True)
    if not te_mod["bwd_accepts_autotune"]:
        print("# note: installed te_rmsnorm_bwd_triton has no `autotune` kwarg "
              "(baseline build); --autotune/--no-autotune affects fwd only.",
              file=sys.stderr, flush=True)

    rows = []
    suite_t0 = time.perf_counter()
    for i, (kind, dtype, q, M, N) in enumerate(configs, 1):
        dtype_s = str(dtype).split(".")[-1]
        shape_s = f"{M}x{N}"
        if args.verbose:
            print(f"[{i:>{idx_w}}/{total}] {kind} {dtype_s:<8} {q:<6} "
                  f"{shape_s:<14} ... ",
                  end="", file=sys.stderr, flush=True)
        cfg_t0 = time.perf_counter()
        try:
            ms, gbs = _bench_config(te_mod, args, kind, dtype, q, M, N)
        except Exception as e:
            elapsed = time.perf_counter() - cfg_t0
            # Exception text (e.g. Triton compile errors) is often multi-line
            # or contains tabs; collapse it so the note stays on one table row.
            msg = " ".join(str(e).split())
            note = type(e).__name__ + ": " + msg[:80]
            rows.append((kind, args.tag, dtype_s, q, M, N,
                         float("nan"), float("nan"), note))
            if args.verbose:
                print(f"FAIL ({elapsed:>5.1f}s) {note}",
                      file=sys.stderr, flush=True)
            if args.fail_fast:
                traceback.print_exc()
                raise
            continue
        elapsed = time.perf_counter() - cfg_t0
        rows.append((kind, args.tag, dtype_s, q, M, N, ms, gbs, ""))
        if args.verbose:
            print(f"{ms:>9.4f} ms {gbs:>7.1f} GB/s ({elapsed:>5.1f}s)",
                  file=sys.stderr, flush=True)

    if args.verbose:
        wall = time.perf_counter() - suite_t0
        ok = sum(1 for r in rows if not math.isnan(r[6]))
        print(f"# done: {ok}/{total} successful in {wall:.1f}s "
              f"({wall / max(total, 1):.2f}s/config avg)",
              file=sys.stderr, flush=True)
    return rows
def _print_table(rows, info, stream=sys.stdout):
if info:
print("# " + " | ".join(f"{k}={v}" for k, v in info.items()), file=stream)
hdr = f"{'kind':<4} {'tag':<10} {'dtype':<8} {'quant':<6} {'M':>6} {'N':>6} {'ms':>10} {'GB/s':>9} notes"
print(hdr, file=stream)
print("-" * len(hdr), file=stream)
for r in sorted(rows, key=lambda r: (r[0], r[2], r[3], r[5], r[4])):
kind, tag, dtype, q, M, N, ms, gbs, note = r
if ms != ms: # NaN
print(f"{kind:<4} {tag:<10} {dtype:<8} {q:<6} {M:>6} {N:>6} {'':>10} {'':>9} {note}", file=stream)
else:
print(f"{kind:<4} {tag:<10} {dtype:<8} {q:<6} {M:>6} {N:>6} {ms:>10.4f} {gbs:>9.1f} {note}", file=stream)
def _write_tsv(rows, info, path):
with open(path, "w", newline="") as f:
w = csv.writer(f, delimiter="\t")
for k, v in info.items():
f.write(f"# {k}={v}\n")
w.writerow(["kind", "tag", "dtype", "quant", "M", "N", "ms", "GBs", "note"])
for r in rows:
w.writerow(r)
def _load_tsv(path):
rows = []
info = {}
with open(path) as f:
for line in f:
if line.startswith("#"):
kv = line[1:].strip().split("=", 1)
if len(kv) == 2:
info[kv[0].strip()] = kv[1].strip()
continue
break
with open(path) as f:
r = csv.reader(f, delimiter="\t")
header = None
for row in r:
if not row or row[0].startswith("#"):
continue
if header is None:
header = row
continue
rows.append(tuple(row))
return rows, info
def compare(path_a, path_b):
    """Print a side-by-side latency comparison of two TSV result files.

    Rows are matched on (kind, dtype, quant, M, N). Speedup is ``A_ms / B_ms``,
    so values above 1 mean B is faster. Rows at >= 1.05x are marked ``+`` and
    rows at <= 0.95x are marked ``-``.

    Some configs are listed but left out of the geomean/min/max summary:
    those present in only one file, those with unparsable numbers, and those
    whose latency is NaN or non-positive in either file (e.g. a failed run).
    Including the last group would turn the summary into NaN.
    """
    import math

    rows_a, info_a = _load_tsv(path_a)
    rows_b, info_b = _load_tsv(path_b)

    def key(r):
        return (r[0], r[2], r[3], int(r[4]), int(r[5]))

    a_map = {key(r): r for r in rows_a}
    b_map = {key(r): r for r in rows_b}
    keys = sorted(set(a_map) | set(b_map))
    print(f"# A: {path_a} ({info_a})")
    print(f"# B: {path_b} ({info_b})")
    hdr = (f"{'kind':<4} {'dtype':<8} {'quant':<6} {'M':>6} {'N':>6} "
           f"{'A ms':>10} {'B ms':>10} {'A GB/s':>9} {'B GB/s':>9} "
           f"{'speedup':>9}")
    print(hdr)
    print("-" * len(hdr))
    speedups = []
    excluded = 0
    for k in keys:
        ra = a_map.get(k)
        rb = b_map.get(k)
        kind, dtype, q, M, N = k
        prefix = f"{kind:<4} {dtype:<8} {q:<6} {M:>6} {N:>6}"
        if ra is None:
            print(f"{prefix} (A missing)")
            continue
        if rb is None:
            print(f"{prefix} (B missing)")
            continue
        try:
            a_ms, a_gbs = float(ra[6]), float(ra[7])
            b_ms, b_gbs = float(rb[6]), float(rb[7])
        except (ValueError, IndexError):
            print(f"{prefix} (parse error)")
            continue
        timings = f"{a_ms:>10.4f} {b_ms:>10.4f} {a_gbs:>9.1f} {b_gbs:>9.1f}"
        # A failed run is stored with NaN latency; keep it out of the stats.
        failed = [name for name, ms in (("A", a_ms), ("B", b_ms))
                  if not (math.isfinite(ms) and ms > 0)]
        if failed:
            excluded += 1
            print(f"{prefix} {timings} {'n/a':>9} ({'/'.join(failed)} failed)")
            continue
        speedup = a_ms / b_ms
        speedups.append(speedup)
        if speedup >= 1.05:
            marker = " +"
        elif speedup <= 0.95:
            marker = " -"
        else:
            marker = ""
        print(f"{prefix} {timings} {speedup:>9.2f}x{marker}")
    if speedups:
        print()
        print(f"# Geomean speedup (B vs A): "
              f"{statistics.geometric_mean(speedups):.3f}x "
              f"min={min(speedups):.3f}x max={max(speedups):.3f}x "
              f"n={len(speedups)}")
    if excluded:
        print(f"# {excluded} config(s) with failed/invalid latency "
              f"excluded from the summary")
def main():
    """Command-line entry point: run the benchmark suite or compare two TSVs.

    Returns:
        Process exit status. 0 on success; 2 when no CUDA/ROCm device is
        visible. Usage errors exit through argparse, also with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Standalone benchmark for Triton RMSNorm kernels.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("--tag", default="run", help="Label for this run.")
    parser.add_argument("--out", default=None,
                        help="Write results as TSV to this path.")
    parser.add_argument("--shapes", default=None,
                        help="Comma-separated MxN list, e.g. 8192x4096,2048x16384")
    parser.add_argument("--dtypes", default=",".join(DEFAULT_DTYPES),
                        help="Comma-separated dtypes from {bf16,fp16,fp32}.")
    parser.add_argument("--quant", default=",".join(DEFAULT_QUANT_MODES),
                        help="Comma-separated quant modes from {none,fp8,fp8_t}.")
    parser.add_argument("--no-fp8", action="store_true",
                        help="Drop fp8/fp8_t entries.")
    parser.add_argument("--no-fwd", action="store_true",
                        help="Skip the forward-kernel configurations.")
    parser.add_argument("--no-bwd", action="store_true",
                        help="Skip the backward-kernel configurations.")
    parser.add_argument("--autotune", dest="autotune", action="store_true",
                        default=True, help="Enable Triton autotune (default).")
    parser.add_argument("--no-autotune", dest="autotune", action="store_false",
                        help="Disable Triton autotune (bwd only if the "
                             "installed launcher supports it).")
    parser.add_argument("--warmup-ms", type=int, default=100,
                        help="Time-based warmup per config for "
                             "triton.testing.do_bench (default: %(default)s).")
    parser.add_argument("--rep-ms", type=int, default=300,
                        help="Measurement window per config for "
                             "triton.testing.do_bench (default: %(default)s).")
    parser.add_argument("--fail-fast", action="store_true",
                        help="Re-raise first benchmark exception with traceback.")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Print per-configuration progress (one line each) "
                             "to stderr, including wall-clock time per config.")
    parser.add_argument("--compare", nargs=2, metavar=("A.tsv", "B.tsv"),
                        help="Compare two TSV result files (no GPU needed).")
    args = parser.parse_args()
    if args.compare:
        compare(*args.compare)
        return 0
    if args.no_fwd and args.no_bwd:
        # Otherwise the run "succeeds" and prints an empty table.
        parser.error("--no-fwd and --no-bwd together leave nothing to benchmark")
    if not torch.cuda.is_available():
        print("error: CUDA/ROCm device required", file=sys.stderr)
        return 2
    info = _device_info()
    info["tag"] = args.tag
    rows = run_suite(args)
    _print_table(rows, info)
    if args.out:
        _write_tsv(rows, info, args.out)
        print(f"\n# wrote {len(rows)} rows to {args.out}", file=sys.stderr)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment