Created
March 11, 2025 16:04
import torch
import torch.nn as nn


class MultiQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Independent query projections per head, but a single shared key/value projection
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)           # Queries: D → D
        self.W_kv = nn.Linear(embed_dim, 2 * self.head_dim, bias=False)  # Shared key and value: D → 2 * (D/H)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # Compute queries: (B, L, D) → (B, L, H, D/H) → (B, H, L, D/H)
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute shared keys and values: (B, L, D) → (B, L, 2, D/H) → (2, B, L, D/H)
        KV = self.W_kv(x).view(batch_size, seq_len, 2, self.head_dim).permute(2, 0, 1, 3)
        K, V = KV[0], KV[1]  # Each (B, L, D/H), shared across all heads

        # Scaled dot-product attention; the single K/V head broadcasts over the query heads
        attn_weights = torch.einsum("bhqd,bkd->bhqk", Q, K) / (self.head_dim ** 0.5)
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        output = torch.einsum("bhqk,bkd->bhqd", attn_weights, V)

        # Merge heads and apply output projection: (B, H, L, D/H) → (B, L, D)
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
        return self.out_proj(output)
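
For reference, a minimal smoke test of the module above, using arbitrary illustrative sizes (embed_dim=64, num_heads=8, a batch of 2 sequences of length 10); it only checks that the output shape matches the input shape. Note that the shared W_kv projection holds 2 * head_dim * embed_dim weights instead of the 2 * embed_dim * embed_dim a standard multi-head attention layer would use, which is the key/value memory saving multi-query attention is designed for.

if __name__ == "__main__":
    # Arbitrary example sizes, chosen only for illustration
    batch_size, seq_len, embed_dim, num_heads = 2, 10, 64, 8

    mqa = MultiQueryAttention(embed_dim, num_heads)
    x = torch.randn(batch_size, seq_len, embed_dim)
    out = mqa(x)

    print(out.shape)  # torch.Size([2, 10, 64])
    assert out.shape == (batch_size, seq_len, embed_dim)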