@yberreby
Last active April 20, 2025 01:16
Quick JAX vs Triton comparison on a toy kernel. Outputs from runs on an RTX 4060 Mobile.
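Both scripts time the same fused op, D = GELU(alpha * A @ B) + beta * C in float16 (with the matmul accumulated at higher precision): bench_jax.py expresses it as jitted jax.numpy ops and leaves the fusion to XLA, while bench_triton.py writes the whole thing by hand as a single autotuned Triton kernel. The console output from both runs follows the two scripts.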
import functools, time, jax, jax.numpy as jnp
jax.config.update("jax_default_matmul_precision", "tensorfloat32")
SQRT2_OVER_PI = 0.7978845608028654
# ----------------------------------------------------------------------
def gelu_fast(x):
    u = SQRT2_OVER_PI * (x + 0.044715 * x * x * x)
    return 0.5 * x * (1. + jnp.tanh(u))
def build_fused(variant: str):
    """Return a JIT-compiled fused op for the chosen GELU variant."""
    @functools.partial(jax.jit, static_argnames="alpha_beta")
    def fused(a, b, c, *, alpha_beta=(1.0, 1.0)):
        alpha, beta = alpha_beta
        z = jnp.matmul(a, b, precision=jax.lax.Precision.HIGH)
        z = (z * alpha).astype(jnp.float16)
        if variant == "fast":
            z = gelu_fast(z)
        elif variant == "erf":
            z = 0.5 * z * (1. + jax.scipy.special.erf(z / jnp.sqrt(2.0)))
        elif variant == "nn_exact":
            z = jax.nn.gelu(z, approximate=False)
        elif variant == "nn_approx":
            z = jax.nn.gelu(z, approximate=True)
        return (z + beta * c).astype(jnp.float16)
    return fused
# ----------------------------------------------------------------------
def bench_jax(N: int, variant: str, rep=10) -> float:
    key = jax.random.PRNGKey(0)
    a = jax.random.normal(key, (N, N), dtype=jnp.float16); key, _ = jax.random.split(key)
    b = jax.random.normal(key, (N, N), dtype=jnp.float16); key, _ = jax.random.split(key)
    c = jax.random.normal(key, (N, N), dtype=jnp.float16)
    fused = build_fused(variant)
    fused(a, b, c).block_until_ready()  # compile + warm-up
    t0 = time.perf_counter()
    for _ in range(rep):
        fused(a, b, c).block_until_ready()
    t1 = time.perf_counter()
    return (t1 - t0)*1e3/rep  # ms
# ----------------------------------------------------------------------
if __name__ == "__main__":
    variants = ["fast", "erf", "nn_exact", "nn_approx"]
    sizes = [512, 1024, 2048, 4096, 8192]
    rows = {v: [] for v in variants}
    for N in sizes:
        for v in variants:
            rows[v].append(f"{bench_jax(N, v):7.3f}")
    header = "N " + " ".join(f"{v:>7}" for v in variants)
    print(header)
    print("-" * len(header))
    for i, N in enumerate(sizes):
        print(f"{N:<4}" + " ".join(rows[v][i] for v in variants))
# bench_triton.py: Triton >=3.3, PyTorch >=2.3
import torch
import triton
import triton.language as tl
# ----------------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------------
# Set verbose=True to see the autotuned config for each run
VERBOSE_AUTOTUNE = True
# ----------------------------------------------------------------------
# Constants (constexpr for Triton kernels, not tunable)
# ----------------------------------------------------------------------
SQRT2_OVER_PI = tl.constexpr(0.7978845608028654)  # sqrt(2/pi)
INV_SQRT2 = tl.constexpr(0.7071067811865476)      # 1/sqrt(2)
# ----------------------------------------------------------------------
# Triton kernel: GEMM + GELU (fast or erf) + residual
# ----------------------------------------------------------------------
@triton.autotune(
    configs=[
        # A few reasonable configurations, balancing block sizes and K-dimension chunk size
        # Good general purpose, medium block, larger K
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'num_warps': 4, 'num_stages': 4}, num_ctas=1),
        # Larger M dimension preference
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'num_warps': 8, 'num_stages': 3}, num_ctas=1),
        # Larger N dimension preference
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'num_warps': 8, 'num_stages': 3}, num_ctas=1),
        # Smaller K dimension chunk size, balanced blocks
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'num_warps': 4, 'num_stages': 4}, num_ctas=1),
    ],
    key=['M', 'N', 'K', 'USE_ERF'],  # Arguments that determine the kernel variant
    # cache_results=True  # Set to True to cache autotuning results to disk across script runs
)
@triton.jit
def fused_kernel(
    A, B, C, D,
    alpha_ptr, beta_ptr,
    M, N, K,
    sAm, sAk, sBk, sBn, sCm, sCn, sDm, sDn,
    # Tunable parameters are now constexpr arguments
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    # Other kernel parameters (like USE_ERF)
    USE_ERF: tl.constexpr,
    # num_warps and num_stages are used internally by Triton based on the config
):
    pid = tl.program_id(0)
    # Calculate grid dimensions based on tuned block sizes
    num_m = tl.cdiv(M, BLOCK_M)
    num_n = tl.cdiv(N, BLOCK_N)
    # Program ID calculation (assuming grid is 1D: num_m * num_n)
    pm = pid // num_n
    pn = pid % num_n
    # Define ranges for the current block
    rm = pm * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pn * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # Initialize accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32)
    # Pointers to the input matrices A and B for the current block
    a_ptr = A + rm[:, None] * sAm + rk[None, :] * sAk
    b_ptr = B + rk[:, None] * sBk + rn[None, :] * sBn
    # Matrix multiply main loop
    k_left = K
    while k_left > 0:
        # Correctly mask M, N, and K dimensions for loads
        mask_a = (rm[:, None] < M) & (rk[None, :] < k_left)
        mask_b = (rk[:, None] < k_left) & (rn[None, :] < N)
        a = tl.load(a_ptr, mask=mask_a, other=0.0)
        b = tl.load(b_ptr, mask=mask_b, other=0.0)
        acc = tl.dot(a, b, acc)
        # Advance pointers for the next iteration
        a_ptr += BLOCK_K * sAk
        b_ptr += BLOCK_K * sBk
        k_left -= BLOCK_K
    # --- Post-processing ---
    alpha = tl.load(alpha_ptr)
    beta = tl.load(beta_ptr)
    # Scale accumulator
    out16 = (acc * alpha).to(tl.float16)
    # Inline GELU computation
    if USE_ERF:
        # erf-based GELU
        tmp = out16.to(tl.float32) * INV_SQRT2
        e = tl.erf(tmp)
        gelu16 = 0.5 * out16 * (1.0 + e.to(out16.dtype))
    else:
        # Fast tanh-based GELU
        x32 = out16.to(tl.float32)
        u = SQRT2_OVER_PI * (x32 + 0.044715 * x32 * x32 * x32)
        t = (2.0 * tl.sigmoid(2.0 * u) - 1.0).to(out16.dtype)
        gelu16 = 0.5 * out16 * (1.0 + t)
    # Define mask for C/D access (depends on M and N)
    mask_cd = (rm[:, None] < M) & (rn[None, :] < N)
    # Load residual C
    c_ptr = C + rm[:, None]*sCm + rn[None, :]*sCn
    c16 = tl.load(c_ptr, mask=mask_cd, other=0.)
    # Compute final output D = GELU(alpha * A @ B) + beta * C
    output = gelu16 + beta * c16
    # Store the result D
    d_ptr = D + rm[:, None]*sDm + rn[None, :]*sDn
    tl.store(d_ptr, output, mask=mask_cd)
# ----------------------------------------------------------------------
# Host wrapper & benchmark
# ----------------------------------------------------------------------
def fused_triton(A, B, C, variant: str, alpha=1.0, beta=1.0):
    """Host wrapper for the fused kernel."""
    use_erf = 1 if variant == 'erf' else 0
    M, K = A.shape
    K_B, N = B.shape  # Use K_B to check compatibility
    assert K == K_B, f"Incompatible dimensions: A({M},{K}), B({K_B},{N})"
    assert C.shape == (M, N), f"Incompatible dimensions: C({C.shape}), expected ({M},{N})"
    D = torch.empty((M, N), device='cuda', dtype=torch.float16)
    # Grid calculation is a lambda function passed to the kernel launcher
    # It uses the 'meta' dictionary containing the tuned block sizes
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']) * triton.cdiv(N, meta['BLOCK_N']),)
    # Ensure scalar alpha/beta are tensors on the correct device
    alpha_t = torch.tensor([alpha], device='cuda', dtype=torch.float16)
    beta_t = torch.tensor([beta], device='cuda', dtype=torch.float16)
    # Launch the autotuned kernel
    fused_kernel[grid](
        A, B, C, D,
        alpha_t, beta_t,
        M, N, K,
        A.stride(0), A.stride(1),  # sAm, sAk
        B.stride(0), B.stride(1),  # sBk, sBn
        C.stride(0), C.stride(1),  # sCm, sCn
        D.stride(0), D.stride(1),  # sDm, sDn
        USE_ERF=use_erf
    )
    return D
def bench_triton(N: int, variant: str, rep: int = 10) -> float:
    """Benchmarks the fused Triton kernel for a given size and variant."""
    torch.manual_seed(0)
    # Ensure tensors are contiguous for predictable strides
    A = torch.randn((N, N), device='cuda', dtype=torch.float16).contiguous()
    B = torch.randn((N, N), device='cuda', dtype=torch.float16).contiguous()
    C = torch.randn((N, N), device='cuda', dtype=torch.float16).contiguous()
    D_dummy = torch.empty_like(C)  # For warmup
    M, K = A.shape
    _, N_dim = B.shape
    alpha_t = torch.tensor([1.0], device='cuda', dtype=torch.float16)
    beta_t = torch.tensor([1.0], device='cuda', dtype=torch.float16)
    use_erf = 1 if variant == 'erf' else 0
    # Define grid lambda function needed for both warmup and actual call
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']) * triton.cdiv(N_dim, meta['BLOCK_N']),)
    # Optional: Print config chosen by autotuner after warm-up
    if VERBOSE_AUTOTUNE:
        print(f"\n--- Autotuning for N={N}, variant='{variant}' ---")
        # Pass the grid lambda and all required args to warmup
        fused_kernel.warmup(
            A, B, C, D_dummy,
            alpha_t, beta_t,
            M, N_dim, K,
            A.stride(0), A.stride(1),
            B.stride(0), B.stride(1),
            C.stride(0), C.stride(1),
            D_dummy.stride(0), D_dummy.stride(1),
            grid=grid,  # Pass grid lambda
            USE_ERF=use_erf
        )
        torch.cuda.synchronize()
        # Ensure the best_config is available after warmup before printing
        if hasattr(fused_kernel, 'best_config'):
            print(f"Best config: {fused_kernel.best_config}")
        else:
            print("Warmup did not produce a best_config (check Triton version or autotuning setup).")
        print("------------------------------------------")
    # Warm-up call regardless of verbosity to trigger autotuning/compilation
    _ = fused_triton(A, B, C, variant)
    torch.cuda.synchronize()
    # Timing loop using CUDA events for accuracy
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(rep):
        _ = fused_triton(A, B, C, variant)
    end_event.record()
    torch.cuda.synchronize()
    elapsed_ms = start_event.elapsed_time(end_event)
    return elapsed_ms / rep  # Average time in ms
if __name__ == '__main__':
    variants = ['fast', 'erf']
    sizes = [512, 1024, 2048, 4096, 8192]
    rows = {v: [] for v in variants}
    print("Running benchmarks (first run for each size/variant includes autotuning)...")
    for N_size in sizes:
        print(f"Benchmarking N={N_size}...")
        for v in variants:
            avg_time_ms = bench_triton(N_size, v, rep=20)  # Use sufficient reps for stable timing
            rows[v].append(f"{avg_time_ms:7.3f}")
        # time.sleep(0.5)  # Optional delay between benchmarks
    # Print results
    print("\n--- Benchmark Results (avg ms per call) ---")
    header = 'N ' + ' '.join(f"{v:>7}" for v in variants)
    print(header)
    print('-'*len(header))
    for i, N_size in enumerate(sizes):
        print(f"{N_size:<4}" + ' '.join(rows[v][i] for v in variants))
gpu-exp on  master [!] is 📦 v0.1.0 via 🐍 v3.13.2
❯ uv run bench_jax.py
N         fast      erf   nn_exact   nn_approx
----------------------------------------------
512      0.099    0.064      0.087       0.077
1024     0.210    0.170      0.191       0.188
2048     0.917    0.942      0.907       0.891
4096     6.629    5.651      5.646       5.619
8192    47.132   48.638     47.175      47.903
❯ uv run bench_triton.py
Running benchmarks (first run for each size/variant includes autotuning)...
Benchmarking N=512...
--- Autotuning for N=512, variant='fast' ---
Warmup did not produce a best_config (check Triton version or autotuning setup).
------------------------------------------
--- Autotuning for N=512, variant='erf' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------
Benchmarking N=1024...
--- Autotuning for N=1024, variant='fast' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------
--- Autotuning for N=1024, variant='erf' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------
Benchmarking N=2048...
--- Autotuning for N=2048, variant='fast' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------
--- Autotuning for N=2048, variant='erf' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------
Benchmarking N=4096...
--- Autotuning for N=4096, variant='fast' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------
--- Autotuning for N=4096, variant='erf' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------
Benchmarking N=8192...
--- Autotuning for N=8192, variant='fast' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------
--- Autotuning for N=8192, variant='erf' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------
--- Benchmark Results (avg ms per call) ---
N         fast      erf
-----------------------
512      0.088    0.087
1024     0.242    0.225
2048     1.152    1.137
4096     9.448    9.368
8192    79.302   78.870