@danyaljj
Created March 11, 2025 16:04
import torch
import torch.nn as nn


class MultiQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
        # Independent queries, but shared keys and values
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)           # Queries (one per head)
        self.W_kv = nn.Linear(embed_dim, 2 * self.head_dim, bias=False)  # Shared Key and Value
        self.out_proj = nn.Linear(embed_dim, embed_dim)
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        # Compute Queries: (B, L, D) → (B, L, H, D/H) → (B, H, L, D/H)
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Compute shared Keys and Values: (B, L, D) → (B, L, 2 * (D/H)) → 2 × (B, L, D/H)
        KV = self.W_kv(x).view(batch_size, seq_len, 2, self.head_dim).permute(2, 0, 1, 3)
        K, V = KV[0], KV[1]  # Each (B, L, D/H), shared across all heads
        # Scaled dot-product attention; K and V have no head axis, so they are shared by every query head
        attn_weights = torch.einsum("bhqd,bkd->bhqk", Q, K) / (self.head_dim ** 0.5)
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        output = torch.einsum("bhqk,bkd->bhqd", attn_weights, V)
        # Merge heads and apply output projection
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
        return self.out_proj(output)
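
A minimal smoke test, assuming the class above is in scope; the sizes (embed_dim=64, num_heads=8, batch of 2, sequence length of 10) are arbitrary illustrative choices, not part of the original gist:

# Quick sanity check with illustrative sizes
if __name__ == "__main__":
    torch.manual_seed(0)
    mqa = MultiQueryAttention(embed_dim=64, num_heads=8)
    x = torch.randn(2, 10, 64)   # (batch, seq_len, embed_dim)
    out = mqa(x)
    print(out.shape)             # torch.Size([2, 10, 64])

Because keys and values are projected to a single head of size head_dim, W_kv holds embed_dim × 2·head_dim weights instead of the 2 × embed_dim × embed_dim of standard multi-head attention, which is the main parameter and KV-cache saving of multi-query attention.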