Quick JAX vs Triton comparison on a toy kernel (fp16 GEMM + GELU + residual add). Outputs from runs on an RTX 4060 Mobile.
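Both scripts time the same fused op, D = GELU(alpha * (A @ B)) + beta * C, on square fp16 matrices: bench_jax.py (first listing) relies on jax.jit/XLA to optimize the GELU/residual epilogue around the matmul, while bench_triton.py (second listing) fuses GEMM, GELU and residual add by hand in a single Triton kernel. For orientation, here is a minimal unfused PyTorch-eager sketch of that op; it is not part of the gist, and fused_reference is an illustrative name only.

import torch
import torch.nn.functional as F

def fused_reference(A, B, C, variant="fast", alpha=1.0, beta=1.0):
    """Unfused eager reference for D = GELU(alpha * (A @ B)) + beta * C (sketch only)."""
    z = (alpha * (A @ B)).to(torch.float16)
    approx = "tanh" if variant == "fast" else "none"  # tanh-approx vs erf GELU
    return (F.gelu(z, approximate=approx) + beta * C).to(torch.float16)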
# bench_jax.py
import functools, time, jax, jax.numpy as jnp

jax.config.update("jax_default_matmul_precision", "tensorfloat32")

SQRT2_OVER_PI = 0.7978845608028654

# ----------------------------------------------------------------------
def gelu_fast(x):
    u = SQRT2_OVER_PI * (x + 0.044715 * x * x * x)
    return 0.5 * x * (1. + jnp.tanh(u))


def build_fused(variant: str):
    """Return a JIT-compiled fused op for the chosen GELU variant."""
    @functools.partial(jax.jit, static_argnames="alpha_beta")
    def fused(a, b, c, *, alpha_beta=(1.0, 1.0)):
        alpha, beta = alpha_beta
        z = jnp.matmul(a, b, precision=jax.lax.Precision.HIGH)
        z = (z * alpha).astype(jnp.float16)
        if variant == "fast":
            z = gelu_fast(z)
        elif variant == "erf":
            z = 0.5 * z * (1. + jax.scipy.special.erf(z / jnp.sqrt(2.0)))
        elif variant == "nn_exact":
            z = jax.nn.gelu(z, approximate=False)
        elif variant == "nn_approx":
            z = jax.nn.gelu(z, approximate=True)
        return (z + beta * c).astype(jnp.float16)
    return fused

# ----------------------------------------------------------------------
def bench_jax(N: int, variant: str, rep=10) -> float:
    key = jax.random.PRNGKey(0)
    a = jax.random.normal(key, (N, N), dtype=jnp.float16); key, _ = jax.random.split(key)
    b = jax.random.normal(key, (N, N), dtype=jnp.float16); key, _ = jax.random.split(key)
    c = jax.random.normal(key, (N, N), dtype=jnp.float16)
    fused = build_fused(variant)
    fused(a, b, c).block_until_ready()  # compile + warm-up
    t0 = time.perf_counter()
    for _ in range(rep):
        fused(a, b, c).block_until_ready()
    t1 = time.perf_counter()
    return (t1 - t0) * 1e3 / rep  # ms

# ----------------------------------------------------------------------
if __name__ == "__main__":
    variants = ["fast", "erf", "nn_exact", "nn_approx"]
    sizes = [512, 1024, 2048, 4096, 8192]
    rows = {v: [] for v in variants}
    for N in sizes:
        for v in variants:
            rows[v].append(f"{bench_jax(N, v):7.3f}")
    header = "N " + " ".join(f"{v:>7}" for v in variants)
    print(header)
    print("-" * len(header))
    for i, N in enumerate(sizes):
        print(f"{N:<4}" + " ".join(rows[v][i] for v in variants))
# bench_triton.py: Triton >=3.3, PyTorch >=2.3
import torch
import triton
import triton.language as tl

# ----------------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------------
# Set VERBOSE_AUTOTUNE = True to see the autotuned config for each run
VERBOSE_AUTOTUNE = True

# ----------------------------------------------------------------------
# Constants (constexpr for Triton kernels, not tunable)
# ----------------------------------------------------------------------
SQRT2_OVER_PI = tl.constexpr(0.7978845608028654)  # sqrt(2/pi)
INV_SQRT2 = tl.constexpr(0.7071067811865476)      # 1/sqrt(2)

# ----------------------------------------------------------------------
# Triton kernel: GEMM + GELU (fast or erf) + residual
# ----------------------------------------------------------------------
@triton.autotune(
    configs=[
        # A few reasonable configurations, balancing block sizes and K-dimension chunk size
        # Good general purpose, medium block, larger K
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'num_warps': 4, 'num_stages': 4}, num_ctas=1),
        # Larger M dimension preference
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'num_warps': 8, 'num_stages': 3}, num_ctas=1),
        # Larger N dimension preference
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'num_warps': 8, 'num_stages': 3}, num_ctas=1),
        # Smaller K dimension chunk size, balanced blocks
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'num_warps': 4, 'num_stages': 4}, num_ctas=1),
    ],
    key=['M', 'N', 'K', 'USE_ERF'],  # Arguments that determine the kernel variant
    # cache_results=True  # Set to True to cache autotuning results to disk across script runs
)
@triton.jit
def fused_kernel(
    A, B, C, D,
    alpha_ptr, beta_ptr,
    M, N, K,
    sAm, sAk, sBk, sBn, sCm, sCn, sDm, sDn,
    # Tunable parameters are constexpr arguments
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    # Other kernel parameters (like USE_ERF)
    USE_ERF: tl.constexpr,
    # num_warps and num_stages are used internally by Triton based on the config
):
    pid = tl.program_id(0)
    # Calculate grid dimensions based on tuned block sizes
    num_m = tl.cdiv(M, BLOCK_M)
    num_n = tl.cdiv(N, BLOCK_N)
    # Program ID calculation (assuming grid is 1D: num_m * num_n)
    pm = pid // num_n
    pn = pid % num_n
    # Define ranges for the current block
    rm = pm * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pn * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # Initialize accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32)
    # Pointers to the input matrices A and B for the current block
    a_ptr = A + rm[:, None] * sAm + rk[None, :] * sAk
    b_ptr = B + rk[:, None] * sBk + rn[None, :] * sBn
    # Matrix multiply main loop
    k_left = K
    while k_left > 0:
        # Mask M, N, and K dimensions for loads
        mask_a = (rm[:, None] < M) & (rk[None, :] < k_left)
        mask_b = (rk[:, None] < k_left) & (rn[None, :] < N)
        a = tl.load(a_ptr, mask=mask_a, other=0.0)
        b = tl.load(b_ptr, mask=mask_b, other=0.0)
        acc = tl.dot(a, b, acc)
        # Advance pointers for the next iteration
        a_ptr += BLOCK_K * sAk
        b_ptr += BLOCK_K * sBk
        k_left -= BLOCK_K
    # --- Post-processing ---
    alpha = tl.load(alpha_ptr)
    beta = tl.load(beta_ptr)
    # Scale accumulator
    out16 = (acc * alpha).to(tl.float16)
    # Inline GELU computation
    if USE_ERF:
        # erf-based GELU
        tmp = out16.to(tl.float32) * INV_SQRT2
        e = tl.erf(tmp)
        gelu16 = 0.5 * out16 * (1.0 + e.to(out16.dtype))
    else:
        # Fast tanh-based GELU
        x32 = out16.to(tl.float32)
        u = SQRT2_OVER_PI * (x32 + 0.044715 * x32 * x32 * x32)
        t = (2.0 * tl.sigmoid(2.0 * u) - 1.0).to(out16.dtype)
        gelu16 = 0.5 * out16 * (1.0 + t)
    # Define mask for C/D access (depends on M and N)
    mask_cd = (rm[:, None] < M) & (rn[None, :] < N)
    # Load residual C
    c_ptr = C + rm[:, None] * sCm + rn[None, :] * sCn
    c16 = tl.load(c_ptr, mask=mask_cd, other=0.)
    # Compute final output D = GELU(alpha * A @ B) + beta * C
    output = gelu16 + beta * c16
    # Store the result D
    d_ptr = D + rm[:, None] * sDm + rn[None, :] * sDn
    tl.store(d_ptr, output, mask=mask_cd)

# ----------------------------------------------------------------------
# Host wrapper & benchmark
# ----------------------------------------------------------------------
def fused_triton(A, B, C, variant: str, alpha=1.0, beta=1.0):
    """Host wrapper for the fused kernel."""
    use_erf = 1 if variant == 'erf' else 0
    M, K = A.shape
    K_B, N = B.shape  # Use K_B to check compatibility
    assert K == K_B, f"Incompatible dimensions: A({M},{K}), B({K_B},{N})"
    assert C.shape == (M, N), f"Incompatible dimensions: C({C.shape}), expected ({M},{N})"
    D = torch.empty((M, N), device='cuda', dtype=torch.float16)
    # Grid calculation is a lambda function passed to the kernel launcher
    # It uses the 'meta' dictionary containing the tuned block sizes
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']) * triton.cdiv(N, meta['BLOCK_N']),)
    # Ensure scalar alpha/beta are tensors on the correct device
    alpha_t = torch.tensor([alpha], device='cuda', dtype=torch.float16)
    beta_t = torch.tensor([beta], device='cuda', dtype=torch.float16)
    # Launch the autotuned kernel
    fused_kernel[grid](
        A, B, C, D,
        alpha_t, beta_t,
        M, N, K,
        A.stride(0), A.stride(1),  # sAm, sAk
        B.stride(0), B.stride(1),  # sBk, sBn
        C.stride(0), C.stride(1),  # sCm, sCn
        D.stride(0), D.stride(1),  # sDm, sDn
        USE_ERF=use_erf
    )
    return D
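# ----------------------------------------------------------------------
# Optional sanity check (added sketch, not in the original gist): compare the
# fused Triton output against an unfused PyTorch reference. Tolerances are
# loose because the kernel rounds through fp16 after the fp32 accumulation.
# ----------------------------------------------------------------------
def check_against_torch(N: int = 512, variant: str = 'fast', alpha=1.0, beta=1.0):
    torch.manual_seed(0)
    A = torch.randn((N, N), device='cuda', dtype=torch.float16)
    B = torch.randn((N, N), device='cuda', dtype=torch.float16)
    C = torch.randn((N, N), device='cuda', dtype=torch.float16)
    D = fused_triton(A, B, C, variant, alpha, beta)
    # Reference: fp32 GEMM, scale, GELU (tanh or erf form), residual add
    z = (alpha * (A.float() @ B.float())).to(torch.float16)
    approx = 'tanh' if variant == 'fast' else 'none'
    ref = (torch.nn.functional.gelu(z, approximate=approx) + beta * C).to(torch.float16)
    torch.testing.assert_close(D, ref, atol=1e-2, rtol=1e-2)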
def bench_triton(N: int, variant: str, rep: int = 10) -> float:
    """Benchmarks the fused Triton kernel for a given size and variant."""
    torch.manual_seed(0)
    # Ensure tensors are contiguous for predictable strides
    A = torch.randn((N, N), device='cuda', dtype=torch.float16).contiguous()
    B = torch.randn((N, N), device='cuda', dtype=torch.float16).contiguous()
    C = torch.randn((N, N), device='cuda', dtype=torch.float16).contiguous()
    D_dummy = torch.empty_like(C)  # For warmup
    M, K = A.shape
    _, N_dim = B.shape
    alpha_t = torch.tensor([1.0], device='cuda', dtype=torch.float16)
    beta_t = torch.tensor([1.0], device='cuda', dtype=torch.float16)
    use_erf = 1 if variant == 'erf' else 0
    # Define grid lambda function needed for both warmup and actual call
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']) * triton.cdiv(N_dim, meta['BLOCK_N']),)
    # Optional: print the config chosen by the autotuner after warm-up
    if VERBOSE_AUTOTUNE:
        print(f"\n--- Autotuning for N={N}, variant='{variant}' ---")
        # Pass the grid lambda and all required args to warmup
        fused_kernel.warmup(
            A, B, C, D_dummy,
            alpha_t, beta_t,
            M, N_dim, K,
            A.stride(0), A.stride(1),
            B.stride(0), B.stride(1),
            C.stride(0), C.stride(1),
            D_dummy.stride(0), D_dummy.stride(1),
            grid=grid,  # Pass grid lambda
            USE_ERF=use_erf
        )
        torch.cuda.synchronize()
        # Ensure the best_config is available after warmup before printing
        if hasattr(fused_kernel, 'best_config'):
            print(f"Best config: {fused_kernel.best_config}")
        else:
            print("Warmup did not produce a best_config (check Triton version or autotuning setup).")
        print("------------------------------------------")
    # Warm-up call regardless of verbosity to trigger autotuning/compilation
    _ = fused_triton(A, B, C, variant)
    torch.cuda.synchronize()
    # Timing loop using CUDA events for accuracy
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(rep):
        _ = fused_triton(A, B, C, variant)
    end_event.record()
    torch.cuda.synchronize()
    elapsed_ms = start_event.elapsed_time(end_event)
    return elapsed_ms / rep  # Average time in ms
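# ----------------------------------------------------------------------
# Optional cross-check (added sketch, not in the original gist): Triton's own
# triton.testing.do_bench handles warm-up and repetition itself, which is a
# convenient way to validate the manual CUDA-event loop above.
# ----------------------------------------------------------------------
def bench_triton_do_bench(N: int, variant: str) -> float:
    from triton.testing import do_bench
    torch.manual_seed(0)
    A = torch.randn((N, N), device='cuda', dtype=torch.float16)
    B = torch.randn((N, N), device='cuda', dtype=torch.float16)
    C = torch.randn((N, N), device='cuda', dtype=torch.float16)
    # Returns a time per call in milliseconds (exact statistic depends on the Triton version).
    return do_bench(lambda: fused_triton(A, B, C, variant))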
if __name__ == '__main__':
    variants = ['fast', 'erf']
    sizes = [512, 1024, 2048, 4096, 8192]
    rows = {v: [] for v in variants}
    print("Running benchmarks (first run for each size/variant includes autotuning)...")
    for N_size in sizes:
        print(f"Benchmarking N={N_size}...")
        for v in variants:
            avg_time_ms = bench_triton(N_size, v, rep=20)  # Use sufficient reps for stable timing
            rows[v].append(f"{avg_time_ms:7.3f}")
        # time.sleep(0.5)  # Optional delay between benchmarks
    # Print results
    print("\n--- Benchmark Results (avg ms per call) ---")
    header = 'N ' + ' '.join(f"{v:>7}" for v in variants)
    print(header)
    print('-' * len(header))
    for i, N_size in enumerate(sizes):
        print(f"{N_size:<4}" + ' '.join(rows[v][i] for v in variants))
gpu-exp on master [!] is 📦 v0.1.0 via 🐍 v3.13.2
❯ uv run bench_jax.py
N fast erf nn_exact nn_approx
----------------------------------------
512 0.099 0.064 0.087 0.077
1024 0.210 0.170 0.191 0.188
2048 0.917 0.942 0.907 0.891
4096 6.629 5.651 5.646 5.619
8192 47.132 48.638 47.175 47.903
❯ uv run bench_triton.py
Running benchmarks (first run for each size/variant includes autotuning)...
Benchmarking N=512...

--- Autotuning for N=512, variant='fast' ---
Warmup did not produce a best_config (check Triton version or autotuning setup).
------------------------------------------

--- Autotuning for N=512, variant='erf' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------
Benchmarking N=1024...

--- Autotuning for N=1024, variant='fast' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------

--- Autotuning for N=1024, variant='erf' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------
Benchmarking N=2048...

--- Autotuning for N=2048, variant='fast' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------

--- Autotuning for N=2048, variant='erf' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------
Benchmarking N=4096...

--- Autotuning for N=4096, variant='fast' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------

--- Autotuning for N=4096, variant='erf' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------
Benchmarking N=8192...

--- Autotuning for N=8192, variant='fast' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------

--- Autotuning for N=8192, variant='erf' ---
Best config: BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, num_warps: 4, num_stages: 4, num_warps: 4, num_ctas: 1, num_stages: 3, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None
------------------------------------------

--- Benchmark Results (avg ms per call) ---
N fast erf
-------------------
512 0.088 0.087
1024 0.242 0.225
2048 1.152 1.137
4096 9.448 9.368
8192 79.302 78.870
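For scale, the per-call times in both tables can be turned into effective GEMM throughput, since the GELU/residual epilogue is negligible next to the 2·N³ FLOPs of the matmul. A small helper, not part of the gist:

def tflops(n: int, ms: float) -> float:
    """Effective TFLOP/s for an n x n x n matmul that took `ms` milliseconds per call."""
    return 2 * n**3 / (ms * 1e-3) / 1e12

# e.g. the N=8192 'fast' rows above: JAX ~23 TFLOP/s, Triton ~14 TFLOP/s
print(tflops(8192, 47.132), tflops(8192, 79.302))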