sdpa and MOE of OpenAI OSS
"""OpenAI OSS sdpa and moe implementations that are suitable for both training and inference.""" | |
from typing import Final | |
import torch | |
import torch.nn.functional as F | |
from einops import einsum, rearrange, repeat | |
from torch import Tensor, nn | |
__all__ = ["sdpa", "MOEBlock"] | |
def sdpa( | |
query: Tensor, | |
key: Tensor, | |
value: Tensor, | |
sink_logits: Tensor | None, | |
*, | |
sliding_window: int = 0, | |
attn_dropout_p: float = 0.0, | |
training: bool | None = None, | |
) -> Tensor: | |
"""Scaled dot-product attention | |
with grouped queries, causal masking, optional sliding window, and an optional per-head *sink* logit. | |
Parameters | |
---------- | |
query : torch.Tensor | |
Query tensor of shape ``[batch, query_seq_len, heads_kv, group_size, hidden_dim]`` | |
or ``[query_seq_len, heads_kv, group_size, hidden_dim]``. | |
``group_size`` is the group size used in GQA, i.e. ``heads_queries = heads_kv * group_size``. | |
key : torch.Tensor | |
Key tensor of shape ``[batch, query_seq_len, heads_kv, hidden_dim]`` or ``[query_seq_len, heads_kv, hidden_dim]``. | |
value : torch.Tensor | |
Value tensor of shape ``[batch, query_seq_len, heads_kv, hidden_dim]`` or ``[query_seq_len, heads_kv, hidden_dim]``. | |
sink_logits : torch.Tensor or None | |
Optional per-attention-head sink attention_scores of shape ``[heads_queries] == [heads_kv * group_size]``. | |
When provided, a sink column is appended to the attention attention_scores; it | |
draws probability mass but is discarded before applying values. | |
sliding_window : int, default=0 | |
If ``> 0``, token ``t`` may only attend to keys in | |
``[t - sliding_window, ..., t]`` (inclusive). | |
attn_dropout_p : float, default=0.0 | |
Dropout probability applied to attention probabilities *after* softmax. | |
training : bool or None, default=None | |
If training (or prefilling), set to True to apply mask | |
Returns | |
------- | |
torch.Tensor | |
Attention output of shape ``[batch, query_seq_len, heads_queries * hidden_dim]`` | |
(or ``[query_seq_len, heads_queries * hidden_dim]`` if the input did not have a batch dimension). | |
Notes | |
----- | |
- Einsum index legend used below: | |
``b``=batch, ``t``=query sequence, ``s``=key value sequence, ``h``=KV head, | |
``g``=group size, ``d``=head dim. | |
""" | |
    # --- normalize shapes to 5D with an explicit batch dimension ---
    added_batch_dim: Final[bool] = query.ndim == 4
    if added_batch_dim:
        query = query.unsqueeze(0)
        key = key.unsqueeze(0)
        value = value.unsqueeze(0)
    batch, query_seq_len, heads_kv, group_size, hidden_dim = query.shape
    device = query.device
    dtype = query.dtype
    heads_queries: Final[int] = heads_kv * group_size
    # hidden_dim ** -0.25; applied to both query and key so their product is scaled by 1 / sqrt(hidden_dim)
    scale = torch.rsqrt(torch.sqrt(torch.as_tensor(hidden_dim, dtype=dtype, device=device)))
    attention_scores = einsum(query * scale, key * scale, "b t h g d, b s h d -> b h g t s")
    if training:  # or prefilling
        # causal mask -- upper triangular matrix with -inf
        mask_shape = attention_scores.shape[-2:]
        causal_mask = torch.triu(
            torch.full(mask_shape, fill_value=-float("inf"), dtype=dtype, device=device),
            diagonal=1,
        )
        # sliding window mask -- lower triangular matrix with -inf
        if sliding_window and sliding_window > 0:
            sliding_mask = torch.tril(
                torch.full(mask_shape, fill_value=-float("inf"), dtype=dtype, device=device),
                diagonal=-sliding_window,
            )
            mask = torch.minimum(causal_mask, sliding_mask)
        else:
            mask = causal_mask
        attention_scores = attention_scores + rearrange(
            mask, "t s -> 1 1 1 t s"
        )  # [batch, heads_kv, group_size, query_seq_len, kv_seq_len]
    # --- optional sink column (per head) ---
    if sink_logits is not None:
        if sink_logits.numel() != heads_queries:
            msg = f"sink_logits must have shape [heads_queries]={heads_queries}, got {tuple(sink_logits.shape)}"
            raise ValueError(msg)
        sink = rearrange(
            sink_logits.to(dtype),
            "(h g) -> 1 h g 1 1",
            h=heads_kv,
            g=group_size,
        )
        sink = repeat(
            sink, "1 h g 1 1 -> b h g t 1", b=batch, t=query_seq_len
        )  # [batch, heads_kv, group_size, query_seq_len, 1]
        attention_scores = torch.cat([attention_scores, sink], dim=-1)  # append sink column
        # softmax over keys+sink, then drop the sink column
        attn_prob = torch.softmax(attention_scores, dim=-1)[..., :-1]
    else:
        attn_prob = torch.softmax(attention_scores, dim=-1)
    if attn_dropout_p and training:
        attn_prob = F.dropout(attn_prob, p=attn_dropout_p, training=True)
    attn_out = einsum(attn_prob, value, "b h g t s, b s h d -> b t h g d")
    out = rearrange(attn_out, "b t h g d -> b t (h g d)")
    if added_batch_dim:
        out = out.squeeze(0)
    return out


def test_sdpa():
    torch.manual_seed(0)
    b, t, h, g, d = 1, 8, 4, 2, 16
    query, key, value = torch.randn((b, t, h, g, d)), torch.randn((b, t, h, d)), torch.randn((b, t, h, d))
    sink = torch.ones((h * g))
    out = sdpa(query, key, value, sink, sliding_window=3, training=True)  # training=True so the masks are exercised
    assert out.shape == (b, t, h * g * d), f"Expected output shape {(b, t, h * g * d)}, got {out.shape}"
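

# Added cross-check (a sketch, not part of the original gist): with sink_logits=None,
# no sliding window, and training=True, the sdpa above should agree with
# torch.nn.functional.scaled_dot_product_attention. KV heads are expanded with einops
# `repeat` rather than any GQA-specific flag so the reference works on older PyTorch versions.
def test_sdpa_matches_torch_reference():
    torch.manual_seed(0)
    b, t, h, g, d = 1, 8, 4, 2, 16
    query = torch.randn((b, t, h, g, d))
    key, value = torch.randn((b, t, h, d)), torch.randn((b, t, h, d))
    out = sdpa(query, key, value, sink_logits=None, training=True)
    # Reference path uses the [batch, heads, seq, dim] layout with KV heads repeated per group.
    q_ref = rearrange(query, "b t h g d -> b (h g) t d")
    k_ref = repeat(key, "b s h d -> b (h g) s d", g=g)
    v_ref = repeat(value, "b s h d -> b (h g) s d", g=g)
    ref = F.scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=True)
    ref = rearrange(ref, "b hq t d -> b t (hq d)")
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)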


def swiglu(x: Tensor, alpha: float = 1.702, clamp_limit: float = 7.0) -> Tensor:
    """OpenAI's unconventional SwiGLU activation function.

    The last dimension is split into interleaved gate (``x[..., ::2]``) and linear (``x[..., 1::2]``) halves.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``[B, T, D]`` or ``[T, D]``.
    alpha : float, default=1.702
        Slope of the sigmoid gate, i.e. ``x_glu * sigmoid(alpha * x_glu)``.
    clamp_limit : float, default=7.0
        Clamp the gate branch to ``(-inf, clamp_limit]`` and the linear branch to ``[-clamp_limit, clamp_limit]``.

    Returns
    -------
    torch.Tensor
        Output tensor of shape ``[B, T, D//2]`` or ``[T, D//2]``, where the last dimension is halved.
    """
    x_glu, x_linear = x[..., ::2], x[..., 1::2]
    if clamp_limit:
        x_glu = x_glu.clamp(min=None, max=clamp_limit)
        x_linear = x_linear.clamp(min=-clamp_limit, max=clamp_limit)
    return (x_linear + 1) * x_glu * (alpha * x_glu).sigmoid()
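

# Added illustration (a sketch, not part of the original gist): with alpha=1 and an
# effectively-infinite clamp, the gate branch reduces to F.silu, making the relationship
# to conventional SwiGLU (silu(gate) * linear) explicit; the "+1" bias on the linear
# branch and the interleaved split are the OpenAI-specific deviations.
def test_swiglu_matches_silu_gate():
    torch.manual_seed(0)
    x = torch.randn((2, 5, 8))
    out = swiglu(x, alpha=1.0, clamp_limit=1e9)
    x_glu, x_linear = x[..., ::2], x[..., 1::2]
    torch.testing.assert_close(out, F.silu(x_glu) * (x_linear + 1))
    assert out.shape == (2, 5, 4)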


class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable scale, computed in float32."""

    def __init__(self, hidden_dim: int, eps: float = 1e-05, device: torch.device | None = None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(hidden_dim, device=device, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[-1] == self.hidden_dim
        x, dtype = x.float(), x.dtype
        x = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)  # x / root_mean_square(x)
        return (x * self.scale).to(dtype)
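

# Added cross-check (a sketch, not part of the original gist): compares the RMSNorm above
# against torch.nn.RMSNorm, which is assumed to be available (PyTorch >= 2.4). Both start
# with a unit scale, so freshly constructed modules should agree on float32 inputs.
def test_rmsnorm_matches_torch_reference():
    torch.manual_seed(0)
    hidden_dim = 16
    x = torch.randn((2, 5, hidden_dim))
    norm = RMSNorm(hidden_dim)
    ref = nn.RMSNorm(hidden_dim, eps=norm.eps)
    torch.testing.assert_close(norm(x), ref(x), rtol=1e-4, atol=1e-5)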


class MOEBlock(nn.Module):
    """Mixture-of-Experts two-layer MLP with top-``experts_per_token`` routing (SwiGLU).

    The block applies a pre-layer RMSNorm, routes each token to its top ``experts_per_token`` experts,
    executes expert-specific MLPs, sums the expert outputs with softmax weights, and
    adds a residual connection.

    Parameters
    ----------
    hidden_dim : int
        Hidden dimension of the input states.
    intermediate_size : int
        Intermediate dimension of the two-layer MLPs. Can be set different from ``hidden_dim``:
        if ``> hidden_dim``, the MLP up-projects the input to a larger intermediate space;
        if ``< hidden_dim``, the MLP down-projects the input to a smaller intermediate space.
    num_experts : int
        Total number of experts.
    experts_per_token : int, default=4
        Number of experts selected per token (top-``experts_per_token`` routing).

    Shapes
    ------
    Input:
        ``x`` has shape ``[B, T, hidden_dim]`` or ``[T, hidden_dim]``.
    Parameters:
        ``w1``: ``[num_experts, 2*intermediate_size, hidden_dim]``,
        ``b1``: ``[num_experts, 2*intermediate_size]`` (up-projection)
        ``w2``: ``[num_experts, hidden_dim, intermediate_size]``,
        ``b2``: ``[num_experts, hidden_dim]`` (down-projection)
    Output:
        Same shape as input, with residual connection applied.

    Notes
    -----
    - Einsum index legend: ``b``=batch, ``t``=sequence, ``k``=selected experts per token,
      ``i``=intermediate size, ``h``=hidden dim.
    """

    def __init__(
        self,
        hidden_dim: int = 2880,
        intermediate_size: int = 2880,
        num_experts: int = 32,
        experts_per_token: int = 4,
        *,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        assert experts_per_token >= 1, "experts_per_token must be >= 1"
        assert num_experts >= experts_per_token, "num_experts must be >= experts_per_token"
        self.hidden_dim: Final[int] = hidden_dim
        self.intermediate_size: Final[int] = intermediate_size
        self.num_experts: Final[int] = num_experts
        self.experts_per_token: Final[int] = experts_per_token
        self.norm = RMSNorm(hidden_dim, device=device)
        self.router = nn.Linear(self.hidden_dim, self.num_experts, device=device)
        # Experts' weights and biases
        # (num_experts stacks of Linear(hidden_dim -> 2*intermediate_size) and Linear(intermediate_size -> hidden_dim))
        w1 = torch.empty(self.num_experts, 2 * self.intermediate_size, self.hidden_dim, device=device)
        b1 = torch.empty(self.num_experts, 2 * self.intermediate_size, device=device)
        w2 = torch.empty(self.num_experts, self.hidden_dim, self.intermediate_size, device=device)
        b2 = torch.empty(self.num_experts, self.hidden_dim, device=device)
        # Xavier init for weights, zeros for biases
        nn.init.xavier_uniform_(w1)
        nn.init.xavier_uniform_(w2)
        nn.init.zeros_(b1)
        nn.init.zeros_(b2)
        self.w1 = nn.Parameter(w1)
        self.b1 = nn.Parameter(b1)
        self.w2 = nn.Parameter(w2)
        self.b2 = nn.Parameter(b2)

    def forward(self, x: Tensor) -> Tensor:
        """Apply top-``experts_per_token`` routed expert MLP with residual.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``[B, T, hidden_dim]`` or ``[T, hidden_dim]``.

        Returns
        -------
        torch.Tensor
            Output tensor with the same shape as ``x``.
        """
        # --- normalize input shape to [B, T, hidden_dim] ---
        added_batch_dim: Final[bool] = x.ndim == 2
        if added_batch_dim:
            x = x.unsqueeze(0)
        B, T, hidden_dim = x.shape
        assert hidden_dim == self.hidden_dim, f"Expected last dim {self.hidden_dim}, got {hidden_dim}"
        # Pre-layer-RMSNorm
        x_norm = self.norm(x)  # [B, T, hidden_dim]
        router_logits = self.router(x_norm)  # project to [B, T, num_experts]
        expert_scores, expert_indices = torch.topk(
            router_logits, k=self.experts_per_token, dim=-1, sorted=True
        )  # [B, T, experts_per_token]
        expert_weights = torch.softmax(expert_scores, dim=-1)  # [B, T, experts_per_token]
        # Gather per-token expert parameters -> [B, T, experts_per_token, ...]
        experts_w1 = self.w1[expert_indices, ...]  # [B, T, experts_per_token, 2*intermediate_size, hidden_dim]
        experts_b1 = self.b1[expert_indices, ...]  # [B, T, experts_per_token, 2*intermediate_size]
        experts_w2 = self.w2[expert_indices, ...]  # [B, T, experts_per_token, hidden_dim, intermediate_size]
        experts_b2 = self.b2[expert_indices, ...]  # [B, T, experts_per_token, hidden_dim]
        # Expert FFN-1 (up-proj) + SwiGLU
        up_projected = einsum(x_norm, experts_w1, "b t h, b t k double_i h -> b t k double_i") + experts_b1
        intermediate = swiglu(up_projected)  # [B, T, experts_per_token, intermediate_size]
        # Expert FFN-2 (down-proj)
        down_projected = einsum(intermediate, experts_w2, "b t k i, b t k i h -> b t k h") + experts_b2
        # Weighted mixture over the selected experts_per_token experts
        y = einsum(down_projected, expert_weights, "b t k h, b t k -> b t h")
        # Residual connection and shape restore
        out = x + y
        if added_batch_dim:
            out = out.squeeze(0)
        return out
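

# Added cross-check (a sketch, not part of the original gist): validates the gather + einsum
# routing in MOEBlock.forward against a naive per-token, per-expert Python loop on a small,
# arbitrarily chosen configuration.
def test_moe_block_matches_loop_reference():
    torch.manual_seed(0)
    moe = MOEBlock(hidden_dim=32, intermediate_size=16, num_experts=8, experts_per_token=2)
    x = torch.randn((2, 5, moe.hidden_dim))
    with torch.no_grad():
        out = moe(x)
        # Re-run the routing explicitly and accumulate each selected expert's contribution.
        x_norm = moe.norm(x)
        scores, indices = torch.topk(moe.router(x_norm), k=moe.experts_per_token, dim=-1, sorted=True)
        weights = torch.softmax(scores, dim=-1)
        ref = x.clone()
        for b in range(x.shape[0]):
            for t in range(x.shape[1]):
                for k in range(moe.experts_per_token):
                    e = indices[b, t, k]
                    up = moe.w1[e] @ x_norm[b, t] + moe.b1[e]
                    down = moe.w2[e] @ swiglu(up) + moe.b2[e]
                    ref[b, t] += weights[b, t, k] * down
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)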


def test_moe_block():
    moe = MOEBlock()
    torch.manual_seed(0)
    x = torch.randn((2, 10, moe.hidden_dim))
    with torch.inference_mode():
        output = moe(x)
    assert output.shape == (
        2,
        10,
        moe.hidden_dim,
    ), f"Expected output shape {(2, 10, moe.hidden_dim)}, got {output.shape}"


if __name__ == "__main__":
    test_sdpa()
    test_moe_block()
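    # Added: also run the cross-checks sketched above (they assume a recent PyTorch with
    # torch.nn.RMSNorm and torch.testing.assert_close available).
    test_sdpa_matches_torch_reference()
    test_swiglu_matches_silu_gate()
    test_rmsnorm_matches_torch_reference()
    test_moe_block_matches_loop_reference()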