@RohanAwhad
Last active February 5, 2025 22:51
Multi-Head Latent Attention
# /// script
# dependencies = [
# "numpy",
# ]
# ///
# multi-head latent attention (MLA): attention with a low-rank compressed KV cache
# and a decoupled RoPE branch
import numpy as np
# SHAPES AND SIZES
BATCH_SIZE = 2
SEQ_LEN = 8
EMBED_DIM = 8
N_HEADS = 4
HEAD_DIM = 8
ROPE_DIM = 4
LORA_DIM = 2
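# LORA_DIM is the width of the compressed latent used by the low-rank (down/up)
# projections; ROPE_DIM is the extra per-head width that carries positional
# information on a separate, decoupled RoPE branch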
# ===
# utils
# ===
# apply_rope stays an identity placeholder so the script runs end-to-end;
# softmax is implemented for real so the attention weights actually sum to 1
def apply_rope(x): return x
def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)
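# a minimal sketch of an actual rotary embedding (my assumption, not part of the
# original gist): rotate consecutive pairs of the last dim by position-dependent
# angles. It is shape-preserving, so it could be swapped in for the identity stub
# above; note the script applies RoPE before the per-head reshape, so a faithful
# per-head version would reshape to (..., N_HEADS, ROPE_DIM) first.
def apply_rope_rotary(x, base=10000.0):
    seq_len, dim = x.shape[-2], x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))  # (dim/2,)
    angles = np.arange(seq_len)[:, None] * inv_freq          # (seq_len, dim/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out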
# ===
# weight matrices
# ===
# for query
w_dq = np.random.rand(EMBED_DIM, LORA_DIM)
w_uq = np.random.rand(LORA_DIM, N_HEADS*HEAD_DIM)
w_rq = np.random.rand(LORA_DIM, N_HEADS*ROPE_DIM)
# for kv
w_dkv = np.random.rand(EMBED_DIM, LORA_DIM)
w_uk = np.random.rand(LORA_DIM, N_HEADS*HEAD_DIM)
w_rk = np.random.rand(EMBED_DIM, ROPE_DIM)
w_uv = np.random.rand(LORA_DIM, N_HEADS * HEAD_DIM)
# for out
w_out = np.random.rand(N_HEADS * HEAD_DIM, EMBED_DIM)
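# the low-rank factorization is the point of MLA: keys and values are reconstructed
# from the small latent dkvc (LORA_DIM wide), so at inference only dkvc plus the
# shared rope key kr need to be cached per token instead of full per-head K and V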
# ===
# forward pass
# ===
x = np.random.rand(BATCH_SIZE, SEQ_LEN, EMBED_DIM)
# let's do the query path first
dqc = x @ w_dq
qc = dqc @ w_uq
qr = apply_rope(dqc @ w_rq)
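# dqc is the compressed query latent; qc is the content query and qr the decoupled
# rope query, both up-projected from it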
qc = qc.reshape(BATCH_SIZE, SEQ_LEN, N_HEADS, HEAD_DIM)
qr = qr.reshape(BATCH_SIZE, SEQ_LEN, N_HEADS, ROPE_DIM)
q = np.concatenate([qc, qr], axis=-1)
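# each query head is now [HEAD_DIM content dims | ROPE_DIM positional dims] wide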
# now the key-value path
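# the compressed latent dkvc is what gets cached per token; keys and values are
# up-projected from it on the fly, and a single rope key kr (computed directly
# from x) carries the positional part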
dkvc = x @ w_dkv # cached
v = (dkvc @ w_uv).reshape(BATCH_SIZE, SEQ_LEN, N_HEADS, HEAD_DIM)
kc = dkvc @ w_uk
kr = apply_rope(x @ w_rk) # cached
kc = kc.reshape(BATCH_SIZE, SEQ_LEN, N_HEADS, HEAD_DIM)
kr = kr.reshape(BATCH_SIZE, SEQ_LEN, 1, ROPE_DIM)
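# the single rope key per token is shared by every head, hence the broadcast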
kr = np.broadcast_to(kr, (BATCH_SIZE, SEQ_LEN, N_HEADS, ROPE_DIM))
k = np.concatenate([kc, kr], axis=-1)
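# each head's query/key is HEAD_DIM + ROPE_DIM wide (content + positional), which
# is why the logits below are scaled by sqrt(HEAD_DIM + ROPE_DIM); values stay
# HEAD_DIM wide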
# regular multi head attention (attention over sequence positions, per head)
q, k, v = (t.transpose(0, 2, 1, 3) for t in (q, k, v))  # move heads before seq: (B, N_HEADS, SEQ_LEN, dim)
attn = softmax((q @ k.transpose(0, 1, 3, 2)) / ((HEAD_DIM + ROPE_DIM) ** (1/2)))
x_ = (attn @ v).transpose(0, 2, 1, 3).reshape(BATCH_SIZE, SEQ_LEN, N_HEADS * HEAD_DIM)
out = x_ @ w_out
# Print shapes for all the matrices with clear headings
print("Shapes of matrices:")
print()
print("Input shape (x):", x.shape)
print("Query weight shapes:")
print(" w_dq :", w_dq.shape)
print(" w_uq :", w_uq.shape)
print(" w_rq :", w_rq.shape)
print()
print("Key-Value weight shapes:")
print(" w_dkv:", w_dkv.shape)
print(" w_uk :", w_uk.shape)
print(" w_rk :", w_rk.shape)
print(" w_uv :", w_uv.shape)
print()
print("Output weight shape (w_out):", w_out.shape)
print()
print("Intermediate shapes:")
print(" dqc :", dqc.shape)
print(" qc :", qc.shape)
print(" qr :", qr.shape)
print(" q :", q.shape)
print()
print(" dkvc :", dkvc.shape)
print(" v :", v.shape)
print(" kc :", kc.shape)
print(" kr :", kr.shape)
print(" k :", k.shape)
print()
print("Attention shape (attn):", attn.shape)
print("Output shape (out):", out.shape)