Multi-Head Latent Attention
# /// script
# dependencies = [
#   "numpy",
# ]
# ///
# multi head latent attention
import numpy as np
# SHAPES AND SIZES
BATCH_SIZE = 2
SEQ_LEN = 8
EMBED_DIM = 8
N_HEADS = 4
HEAD_DIM = 8
ROPE_DIM = 4
LORA_DIM = 2
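# LORA_DIM is the width of the low-rank latent (compression) bottleneck used for
# both the query and the shared key/value projections; ROPE_DIM is the per-head
# width of the decoupled RoPE component that carries positional information.
# sizes are kept tiny so the printed shapes below are easy to read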
# ===
# utils
# ===
# identity placeholders so the script runs end to end and stays focused on shapes;
# real implementations are sketched after the listing below
def apply_rope(x): return x
def softmax(x): return x
# ===
# weight matrices
# ===
# for query
w_dq = np.random.rand(EMBED_DIM, LORA_DIM)
w_uq = np.random.rand(LORA_DIM, N_HEADS*HEAD_DIM)
w_rq = np.random.rand(LORA_DIM, N_HEADS*ROPE_DIM)
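# w_dq down-projects the input to a small query latent, w_uq up-projects that
# latent to per-head content queries, and w_rq maps the same latent to the
# per-head decoupled RoPE queries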
# for kv
w_dkv = np.random.rand(EMBED_DIM, LORA_DIM)
w_uk = np.random.rand(LORA_DIM, N_HEADS*HEAD_DIM)
w_rk = np.random.rand(EMBED_DIM, ROPE_DIM)
w_uv = np.random.rand(LORA_DIM, N_HEADS * HEAD_DIM)
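# w_dkv produces a single latent shared by keys and values (this is what gets
# cached), w_uk / w_uv expand it back to per-head keys and values, and w_rk
# produces one RoPE key per token that is shared across all heads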
# for out
w_out = np.random.rand(N_HEADS * HEAD_DIM, EMBED_DIM)
# ===
# forward pass
# ===
x = np.random.rand(BATCH_SIZE, SEQ_LEN, EMBED_DIM)
# let's first do query
dqc = x @ w_dq
qc = dqc @ w_uq
qr = apply_rope(dqc @ w_rq)
qc = qc.reshape(BATCH_SIZE, SEQ_LEN, N_HEADS, HEAD_DIM)
qr = qr.reshape(BATCH_SIZE, SEQ_LEN, N_HEADS, ROPE_DIM)
q = np.concatenate([qc, qr], axis=-1)
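# each query head now has HEAD_DIM content dims plus ROPE_DIM positional dims,
# i.e. q has shape (BATCH_SIZE, SEQ_LEN, N_HEADS, HEAD_DIM + ROPE_DIM)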
# let's do key-value
dkvc = x @ w_dkv  # cached
v = (dkvc @ w_uv).reshape(BATCH_SIZE, SEQ_LEN, N_HEADS, HEAD_DIM)
kc = dkvc @ w_uk
kr = apply_rope(x @ w_rk)  # cached
kc = kc.reshape(BATCH_SIZE, SEQ_LEN, N_HEADS, HEAD_DIM)
kr = kr.reshape(BATCH_SIZE, SEQ_LEN, 1, ROPE_DIM)
kr = np.broadcast_to(kr, (BATCH_SIZE, SEQ_LEN, N_HEADS, ROPE_DIM))
k = np.concatenate([kc, kr], axis=-1)
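# at inference time only dkvc and kr need to be kept in the KV cache:
# LORA_DIM + ROPE_DIM = 6 values per token here, versus
# 2 * N_HEADS * HEAD_DIM = 64 values per token for standard multi-head attention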
# regular multi head attention
# move heads to axis 1 so attention scores are taken over sequence positions, not over heads
q, k, v = [t.transpose(0, 2, 1, 3) for t in (q, k, v)]
attn = softmax((q @ k.transpose(0, 1, 3, 2)) / ((HEAD_DIM + ROPE_DIM) ** (1/2)))
x_ = (attn @ v).transpose(0, 2, 1, 3).reshape(BATCH_SIZE, SEQ_LEN, N_HEADS * HEAD_DIM)
out = x_ @ w_out
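# out has shape (BATCH_SIZE, SEQ_LEN, EMBED_DIM), the same as the input x,
# so the block can be stacked like any other attention layer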
# Print shapes for all the matrices with clear headings
print("Shapes of matrices:")
print()
print("Input shape (x):", x.shape)
print("Query weight shapes:")
print(" w_dq :", w_dq.shape)
print(" w_uq :", w_uq.shape)
print(" w_rq :", w_rq.shape)
print()
print("Key-Value weight shapes:")
print(" w_dkv:", w_dkv.shape)
print(" w_uk :", w_uk.shape)
print(" w_rk :", w_rk.shape)
print(" w_uv :", w_uv.shape)
print()
print("Output weight shape (w_out):", w_out.shape)
print()
print("Intermediate shapes:")
print(" dqc  :", dqc.shape)
print(" qc   :", qc.shape)
print(" qr   :", qr.shape)
print(" q    :", q.shape)
print()
print(" dkvc :", dkvc.shape)
print(" v    :", v.shape)
print(" kc   :", kc.shape)
print(" kr   :", kr.shape)
print(" k    :", k.shape)
print()
print("Attention shape (attn):", attn.shape)
print("Output shape (out):", out.shape)
Updates based on https://x.com/stochasticchasm/status/1887247394628378990