Attempt at a fused LayerNorm + Linear + Activation
import pathlib
import torch
import torch._dynamo.config
import triton
import triton.language as tl
torch._dynamo.config.cache_size_limit = 10000
class Block(torch.nn.Module):
    def __init__(self, embedding_dim: int, mult: int = 4):
        super().__init__()
        hidden_dim = embedding_dim * mult
        self.norm = torch.nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        self.linear = torch.nn.Linear(embedding_dim, hidden_dim)
        self.act = torch.nn.GELU(approximate="tanh")

    def forward(self, x: torch.Tensor, *args) -> torch.Tensor:
        x = self.norm(x)
        x = self.linear(x)
        x = self.act(x)
        return x

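# Illustrative sketch (the helper name below is hypothetical, not part of the benchmark):
# Block maps (batch, seq, embedding_dim) -> (batch, seq, embedding_dim * mult).
# It is defined only; call it manually if you want a quick CPU/fp32 shape check.
def _check_block_shape(dim: int = 64, mult: int = 4) -> None:
    blk = Block(dim, mult)
    x = torch.randn(2, 16, dim)
    assert blk(x).shape == (2, 16, dim * mult)
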
def flattened_layernorm_linear_act(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    x = torch.nn.functional.layer_norm(x, x.shape[-1:], eps=eps)
    x = torch.nn.functional.linear(x, weight, bias)
    x = torch.nn.functional.gelu(x, approximate="tanh")
    return x

# ===== Operations =====
# 1. LayerNorm
# n_ij = (x_ij - mean_i) / sqrt(var_i + eps)
# mean_i = (∑_p x_ip) / n
# var_i = (∑_p (x_ip - mean_i)^2) / n
#
# 2. Linear
# y_ik = ∑_j n_ij * w_jk + b_k
#
# 3. GELU
# z_ik = GELU(y_ik)
#
# Overall: z_ik = GELU(∑_j (x_ij - mean_i) / sqrt(var_i + eps) * w_jk + b_k)
# ===== Rewrite =====
# z_ik = GELU(∑_j (x_ij - mean_i) / sqrt(var_i + eps) * w_jk + b_k)
# => GELU((∑_j x_ij * w_jk - ∑_j mean_i * w_jk) / sqrt(var_i + eps) + b_k)
# => GELU((∑_j x_ij * w_jk - mean_i * (∑_j w_jk)) / sqrt(var_i + eps) + b_k)
#
# Also, var_i = (∑_p x_ip^2) / n - mean_i^2 (recall E[X^2] = (E[X])^2 + Var[X]),
# so the mean and variance can both be computed from a single pass over x.
def reordered_layernorm_linear_act(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    # Note: `weight` is expected in (in_features, out_features) layout, i.e. the transpose
    # of `torch.nn.Linear.weight`, so that `x @ weight` matches `F.linear(x, linear.weight)`.
    x_mean = x.mean(dim=-1, keepdim=True)
    x_square_mean = torch.mean(x * x, dim=-1, keepdim=True)
    x_w = torch.matmul(x, weight)
    w_sum = weight.sum(dim=0, keepdim=True)
    nr = x_w - x_mean * w_sum
    rstd = torch.rsqrt(x_square_mean - x_mean * x_mean + eps)
    y = nr * rstd + bias if bias is not None else nr * rstd
    y = torch.nn.functional.gelu(y, approximate="tanh")
    return y

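# Illustrative sanity check for the rewrite above (hypothetical helper, not in the benchmark;
# tolerances are a guess for fp32 on CPU): compares the reordered path against the plain
# LayerNorm -> Linear -> GELU path on a small random tensor. Defined only; call manually.
def _check_reorder_equivalence(seq_len: int = 64, dim: int = 128) -> None:
    lin = torch.nn.Linear(dim, dim * 4)
    x = torch.randn(1, seq_len, dim)
    ref = flattened_layernorm_linear_act(x, lin.weight, lin.bias)
    out = reordered_layernorm_linear_act(x, lin.weight.t().contiguous(), lin.bias)
    torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
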
batch_size = 1
embedding_dim = 3072
mult = 4
sequence_lengths = [4096 + 128 * i for i in range(1, 11)]
warmups = 5
repeats = 20
torch.manual_seed(42)
block = Block(embedding_dim, mult).cuda().bfloat16()
weight = block.linear.weight
weight_t = block.linear.weight.t().contiguous()
bias = block.linear.bias
block_compiled_d = torch.compile(block, fullgraph=True, mode="default", dynamic=False)
block_compiled_ma = torch.compile(block, fullgraph=True, mode="max-autotune", dynamic=False)
flat_compiled_d = torch.compile(flattened_layernorm_linear_act, fullgraph=True, mode="default", dynamic=False)
flat_compiled_ma = torch.compile(flattened_layernorm_linear_act, fullgraph=True, mode="max-autotune", dynamic=False)
reordered_compiled_d = torch.compile(reordered_layernorm_linear_act, fullgraph=True, mode="default", dynamic=False)
reordered_compiled_ma = torch.compile(reordered_layernorm_linear_act, fullgraph=True, mode="max-autotune", dynamic=False)
ops = {
    "block": block,
    "block_compiled_d": block_compiled_d,
    "block_compiled_ma": block_compiled_ma,
    "flat": flattened_layernorm_linear_act,
    "flat_compiled_d": flat_compiled_d,
    "flat_compiled_ma": flat_compiled_ma,
    "reordered": reordered_layernorm_linear_act,
    "reordered_compiled_d": reordered_compiled_d,
    "reordered_compiled_ma": reordered_compiled_ma,
}

def get_color_and_linestyle(n: int) -> list[tuple[str, str]]:
    colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"]
    line_styles = ["-", ":", "-.", "--"]
    if n > len(colors) * len(line_styles):
        raise ValueError(f"Required {n=} styles but maximum is {len(colors) * len(line_styles)}")
    styles = []
    for i in range(n):
        color = colors[i % len(colors)]
        linestyle = line_styles[i // len(colors)]
        styles.append((color, linestyle))
    return styles

output_dir = pathlib.Path("dump_layernorm_linear_gelu_benchmark")
output_dir.mkdir(parents=True, exist_ok=True)
def correctness():
    torch.manual_seed(0)
    for seq_len in sequence_lengths:
        print(f"\n\n===== seq_len: {seq_len} =====")
        shape = (batch_size, seq_len, embedding_dim)
        x = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
        outputs = {}
        for name, op in ops.items():
            w = weight_t if name.startswith("reordered") else weight
            outputs[name] = op(x, w, bias).clone()
        golden_truth = outputs["block"]
        for name, output in outputs.items():
            absdiff = torch.abs(output - golden_truth)
            absmax = absdiff.max()
            mae = absdiff.mean()
            mse = (absdiff * absdiff).mean()
            print(f"{name:<20}: absmax={absmax:.5f}, mae={mae:.5f}, mse={mse:.5f}")

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=sequence_lengths,
        x_log=False,
        line_arg="provider",
        line_vals=list(ops.keys()),
        line_names=list(ops.keys()),
        ylabel="Time (ms)",
        styles=get_color_and_linestyle(len(ops)),
        plot_name="layernorm_linear_gelu benchmark",
        args={},
    )
)
def benchmark_fn(seq_len: int, provider: str):
    torch._dynamo.reset()
    torch.compiler.reset()
    torch.manual_seed(42)
    shape = (batch_size, seq_len, embedding_dim)
    x = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
    w = weight_t if provider.startswith("reordered") else weight
    fn = ops[provider]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(x, w, bias),
        warmup=3,
        rep=10,
        quantiles=[0.5, 0.2, 0.8],
    )
    return ms, max_ms, min_ms

with torch.inference_mode():
    correctness()
    benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())
@a-r-r-o-w (Author): H100 benchmark results [image]