Attempt to make fused LayerNorm + Linear + Activation
import pathlib

import torch
import torch._dynamo.config
import triton
import triton.language as tl

torch._dynamo.config.cache_size_limit = 10000


class Block(torch.nn.Module):
    def __init__(self, embedding_dim: int, mult: int = 4):
        super().__init__()
        hidden_dim = embedding_dim * mult
        self.norm = torch.nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        self.linear = torch.nn.Linear(embedding_dim, hidden_dim)
        self.act = torch.nn.GELU(approximate="tanh")

    def forward(self, x: torch.Tensor, *args) -> torch.Tensor:
        x = self.norm(x)
        x = self.linear(x)
        x = self.act(x)
        return x


def flattened_layernorm_linear_act(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    x = torch.nn.functional.layer_norm(x, x.shape[-1:], eps=eps)
    x = torch.nn.functional.linear(x, weight, bias)
    x = torch.nn.functional.gelu(x, approximate="tanh")
    return x


# ===== Operations =====
# 1. LayerNorm
#    n_ij = (x_ij - mean_i) / sqrt(var_i + eps)
#    mean_i = (∑_p x_ip) / n
#    var_i = (∑_p (x_ip - mean_i)^2) / n
#
# 2. Linear
#    y_ik = ∑_j n_ij * w_jk + b_k
#
# 3. GELU
#    z_ik = GELU(y_ik)
#
# Overall: z_ik = GELU(∑_j (x_ij - mean_i) / sqrt(var_i + eps) * w_jk + b_k)

# ===== Rewrite =====
# z_ik = GELU(∑_j (x_ij - mean_i) / sqrt(var_i + eps) * w_jk + b_k)
#     => GELU((∑_j x_ij * w_jk - ∑_j mean_i * w_jk) / sqrt(var_i + eps) + b_k)
#     => GELU((∑_j x_ij * w_jk - mean_i * ∑_j w_jk) / sqrt(var_i + eps) + b_k)
#
# The term ∑_j x_ij * w_jk is just x @ W on the raw input, so the matmul no longer
# depends on the per-row statistics.
#
# Also, var_i = (∑_p x_ip^2) / n - mean_i^2 (recall E[X^2] = E[X]^2 + Var[X]),
# so mean and variance can be computed in a single pass over x.
def reordered_layernorm_linear_act(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    # Note: `weight` here is the transposed Linear weight, with shape (embedding_dim, hidden_dim).
    x_mean = x.mean(dim=-1, keepdim=True)
    x_square_mean = torch.mean(x * x, dim=-1, keepdim=True)
    x_w = torch.matmul(x, weight)
    w_sum = weight.sum(dim=0, keepdim=True)
    nr = x_w - x_mean * w_sum
    rstd = torch.rsqrt(x_square_mean - x_mean * x_mean + eps)
    y = nr * rstd + bias if bias is not None else nr * rstd
    y = torch.nn.functional.gelu(y, approximate="tanh")
    return y
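

# --- Illustrative sanity check (added for clarity; not part of the original benchmark) ---
# A minimal float64 CPU sketch verifying that the reordered formulation matches the naive
# LayerNorm -> Linear -> GELU composition. `_check_rewrite` is a hypothetical helper name
# and is not called anywhere below.
def _check_rewrite(rows: int = 4, dim: int = 8, hidden: int = 16) -> None:
    x = torch.randn(rows, dim, dtype=torch.float64)
    w = torch.randn(hidden, dim, dtype=torch.float64)  # Linear-style (out_features, in_features) layout
    b = torch.randn(hidden, dtype=torch.float64)
    ref = flattened_layernorm_linear_act(x, w, b)
    out = reordered_layernorm_linear_act(x, w.t().contiguous(), b)
    assert torch.allclose(ref, out, atol=1e-9), "rewrite should match the reference"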


batch_size = 1
embedding_dim = 3072
mult = 4
sequence_lengths = [4096 + 128 * i for i in range(1, 11)]
warmups = 5
repeats = 20

torch.manual_seed(42)
block = Block(embedding_dim, mult).cuda().bfloat16()
weight = block.linear.weight
# The reordered variant expects the weight in (embedding_dim, hidden_dim) layout.
weight_t = block.linear.weight.t().contiguous()
bias = block.linear.bias

block_compiled_d = torch.compile(block, fullgraph=True, mode="default", dynamic=False)
block_compiled_ma = torch.compile(block, fullgraph=True, mode="max-autotune", dynamic=False)
flat_compiled_d = torch.compile(flattened_layernorm_linear_act, fullgraph=True, mode="default", dynamic=False)
flat_compiled_ma = torch.compile(flattened_layernorm_linear_act, fullgraph=True, mode="max-autotune", dynamic=False)
reordered_compiled_d = torch.compile(reordered_layernorm_linear_act, fullgraph=True, mode="default", dynamic=False)
reordered_compiled_ma = torch.compile(reordered_layernorm_linear_act, fullgraph=True, mode="max-autotune", dynamic=False)

ops = {
    "block": block,
    "block_compiled_d": block_compiled_d,
    "block_compiled_ma": block_compiled_ma,
    "flat": flattened_layernorm_linear_act,
    "flat_compiled_d": flat_compiled_d,
    "flat_compiled_ma": flat_compiled_ma,
    "reordered": reordered_layernorm_linear_act,
    "reordered_compiled_d": reordered_compiled_d,
    "reordered_compiled_ma": reordered_compiled_ma,
}


def get_color_and_linestyle(n: int) -> list[tuple[str, str]]:
    colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"]
    line_styles = ["-", ":", "-.", "--"]
    if n > len(colors) * len(line_styles):
        raise ValueError(f"Required {n=} styles but maximum is {len(colors) * len(line_styles)}")
    styles = []
    for i in range(n):
        color = colors[i % len(colors)]
        linestyle = line_styles[i // len(colors)]
        styles.append((color, linestyle))
    return styles


output_dir = pathlib.Path("dump_layernorm_linear_gelu_benchmark")
output_dir.mkdir(parents=True, exist_ok=True)


def correctness():
    torch.manual_seed(0)
    for seq_len in sequence_lengths:
        print(f"\n\n===== seq_len: {seq_len} =====")
        shape = (batch_size, seq_len, embedding_dim)
        x = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
        outputs = {}
        for name, op in ops.items():
            w = weight_t if name.startswith("reordered") else weight
            outputs[name] = op(x, w, bias).clone()
        golden_truth = outputs["block"]
        for name, output in outputs.items():
            absdiff = torch.abs(output - golden_truth)
            absmax = absdiff.max()
            mae = absdiff.mean()
            mse = (absdiff * absdiff).mean()
            print(f"{name:<20}: absmax={absmax:.5f}, mae={mae:.5f}, mse={mse:.5f}")


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=sequence_lengths,
        x_log=False,
        line_arg="provider",
        line_vals=list(ops.keys()),
        line_names=list(ops.keys()),
        ylabel="Time (ms)",
        styles=get_color_and_linestyle(len(ops)),
        plot_name="layernorm_linear_gelu benchmark",
        args={},
    )
)
def benchmark_fn(seq_len: int, provider: str):
    torch._dynamo.reset()
    torch.compiler.reset()
    torch.manual_seed(42)
    shape = (batch_size, seq_len, embedding_dim)
    x = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
    w = weight_t if provider.startswith("reordered") else weight
    fn = ops[provider]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(x, w, bias),
        warmup=3,
        rep=10,
        quantiles=[0.5, 0.2, 0.8],
    )
    return ms, max_ms, min_ms


with torch.inference_mode():
    correctness()
    benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())
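

# --- Illustrative fused-kernel sketch (added for clarity; not part of the original gist) ---
# The script above relies on torch.compile for fusion. Below is a minimal sketch of what a
# hand-written Triton kernel fusing LayerNorm + Linear + tanh-GELU could look like. It assigns
# one row and one BLOCK_N-wide output tile per program, recomputes the row statistics per tile,
# and does the reduction without tl.dot, so it is written for clarity rather than speed.
# All names below (_fused_ln_linear_gelu_kernel, fused_layernorm_linear_gelu) are hypothetical
# additions, not the gist author's code.
@triton.jit
def _fused_ln_linear_gelu_kernel(
    x_ptr, w_ptr, b_ptr, out_ptr,
    N, K, eps,
    BLOCK_K: tl.constexpr, BLOCK_N: tl.constexpr,
):
    row = tl.program_id(0)
    n_offs = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    n_mask = n_offs < N
    # Pass 1: accumulate sum and sum-of-squares of the row to get mean / rstd.
    acc_s = tl.zeros((BLOCK_K,), dtype=tl.float32)
    acc_ss = tl.zeros((BLOCK_K,), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        k_offs = k0 + tl.arange(0, BLOCK_K)
        xk = tl.load(x_ptr + row * K + k_offs, mask=k_offs < K, other=0.0).to(tl.float32)
        acc_s += xk
        acc_ss += xk * xk
    mean = tl.sum(acc_s, axis=0) / K
    rstd = 1.0 / tl.sqrt(tl.sum(acc_ss, axis=0) / K - mean * mean + eps)
    # Pass 2: normalized row times a (K, BLOCK_N) tile of W, accumulated in fp32.
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        k_offs = k0 + tl.arange(0, BLOCK_K)
        k_mask = k_offs < K
        xk = tl.load(x_ptr + row * K + k_offs, mask=k_mask, other=0.0).to(tl.float32)
        xn = tl.where(k_mask, (xk - mean) * rstd, 0.0)
        w = tl.load(
            w_ptr + k_offs[:, None] * N + n_offs[None, :],
            mask=k_mask[:, None] & n_mask[None, :], other=0.0,
        ).to(tl.float32)
        acc += tl.sum(xn[:, None] * w, axis=0)
    y = acc + tl.load(b_ptr + n_offs, mask=n_mask, other=0.0).to(tl.float32)
    # tanh-approximate GELU, with tanh(v) written as 2 * sigmoid(2v) - 1.
    inner = 0.7978845608028654 * (y + 0.044715 * y * y * y)
    tanh_inner = 2.0 * tl.sigmoid(2.0 * inner) - 1.0
    y = 0.5 * y * (1.0 + tanh_inner)
    tl.store(out_ptr + row * N + n_offs, y.to(out_ptr.dtype.element_ty), mask=n_mask)


def fused_layernorm_linear_gelu(
    x: torch.Tensor,
    weight_t: torch.Tensor,  # (embedding_dim, hidden_dim), i.e. the transposed Linear weight
    bias: torch.Tensor,
    eps: float = 1e-6,
    BLOCK_K: int = 128,
    BLOCK_N: int = 64,
) -> torch.Tensor:
    *lead, K = x.shape
    N = weight_t.shape[1]
    x2 = x.reshape(-1, K).contiguous()
    out = torch.empty((x2.shape[0], N), device=x.device, dtype=x.dtype)
    grid = (x2.shape[0], triton.cdiv(N, BLOCK_N))
    _fused_ln_linear_gelu_kernel[grid](x2, weight_t, bias, out, N, K, eps, BLOCK_K=BLOCK_K, BLOCK_N=BLOCK_N)
    return out.reshape(*lead, N)

# Hypothetical usage on the benchmark shapes, e.g.:
#   out = fused_layernorm_linear_gelu(x, weight_t, bias)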
Benchmark hardware: H100