Simple transformer KV cache in PyTorch
@Greg-Tarr · Created January 17, 2025 21:23
# /// script
# dependencies = [
# "torch>=2.5.1",
# ]
# ///
# type: ignore
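# The dependency block above is PEP 723 inline script metadata, so a runner
# that understands it (e.g. `uv run`) can fetch torch automatically before
# executing this script.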
"""
Basic KV cache implementation for a transformer in pytorch.
A KV cache speeds up decoding by reusing keys & values from previous
tokens instead of recomputing them. This implementation:
1. Caches K,V tensors after each forward pass
2. Concatenates cached K,V with new token's K,V
3. Uses FlexAttention for efficient sparse attention patterns
Includes benchmarking code showing typical ~2.0x speedup vs standard generation,
at the cost of higher memory usage.
Example usage:
model = Transformer(dim=2048, num_heads=16)
logits, kv_cache = model(tokens, kv_cache=None) # First pass
logits, kv_cache = model(new_token, kv_cache=kv_cache) # Use cache
Benchmark results:
With cache: 48527.77ms (2560 tokens, 42.2 tokens/sec)
No cache: 101128.65ms (2560 tokens, 20.3 tokens/sec)
Speedup from KV cache: 2.08x
"""

from typing import List

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import (
    BlockMask,
    flex_attention,
    create_block_mask,
)


def norm(x):
    return F.rms_norm(x, (x.size(-1),))


def create_causal_mod(window: torch.Tensor, offset: torch.Tensor | int = 0):
    # `offset` is the number of tokens already held in the KV cache, so the
    # local query indices of the new tokens map onto absolute positions.
    def windowed_causal_mod(b, h, q_idx, kv_idx):
        causal_mask = (q_idx + offset) >= kv_idx
        window_mask = (q_idx + offset) - kv_idx <= window
        return causal_mask & window_mask
    return windowed_causal_mod
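
# Example: with window=2 and no cache (offset=0), the query at position 5 may
# attend to key positions 3, 4 and 5. With a cache holding 5 tokens
# (offset=5), the single new query at local index 0 attends to the same
# absolute positions.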


class Rotary(nn.Module):
    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.register_buffer("inv_freq", (1 / base) ** (torch.arange(0, dim, 2) / dim))

    def forward(self, x, offset: int = 0):
        # `offset` shifts the positions so that queries computed on top of a
        # KV cache are rotated with their absolute positions, not local ones.
        seq_len = x.shape[1]
        t = torch.arange(seq_len, device=x.device) + offset
        freqs = torch.outer(t, self.inv_freq)
        cos, sin = freqs.cos()[None, :, None, :], freqs.sin()[None, :, None, :]
        x1, x2 = x.chunk(2, dim=3)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x)


class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.c_q = nn.Linear(dim, dim)
        self.c_k = nn.Linear(dim, dim)
        self.c_v = nn.Linear(dim, dim)
        self.rotary = Rotary(dim // num_heads)
        self.c_proj = nn.Linear(dim, dim)

    def forward(self, x, block_mask, kv_cache: torch.Tensor | None = None):
        B, T = x.size(0), x.size(1)
        assert B == 1, "Must use batch size = 1 for FlexAttention"
        # Absolute position of the first new token (= number of cached tokens).
        offset = 0 if kv_cache is None else kv_cache[0].size(1)
        q = self.c_q(x).view(B, T, self.num_heads, -1)
        k = self.c_k(x).view(B, T, self.num_heads, -1)
        v = self.c_v(x).view(B, T, self.num_heads, -1)
        if kv_cache is not None:
            # Prepend the cached (pre-rotary) K/V along the sequence dimension.
            k = torch.cat([kv_cache[0], k], dim=1)
            v = torch.cat([kv_cache[1], v], dim=1)
        kv_cache = torch.stack([k, v])
        # Keys cover absolute positions 0..T_total-1; queries cover only the
        # new positions, so they are rotated with an offset.
        q, k = self.rotary(q, offset=offset), self.rotary(k)
        y = flex_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            block_mask=block_mask,
            kernel_options={
                # Kernel options sized for a 40GB card (gpu poor).
                "BLOCK_M": 64,
                "BLOCK_N": 64,
                "BLOCK_M1": 32,
                "BLOCK_N1": 64,
                "BLOCK_M2": 64,
                "BLOCK_N2": 32,
            },
        )
        y = y.transpose(1, 2).contiguous().view_as(x)
        y = self.c_proj(y)
        return y, kv_cache


class MLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.c_fc = nn.Linear(dim, 2 * dim)
        self.c_proj = nn.Linear(2 * dim, dim)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x)
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.attn = CausalSelfAttention(dim=dim, num_heads=num_heads)
        self.mlp = MLP(dim=dim)

    def forward(self, x: torch.Tensor, block_mask: BlockMask, kv_cache: torch.Tensor | None):
        y, kv_cache = self.attn(norm(x), block_mask, kv_cache)
        x = x + y
        x = x + self.mlp(norm(x))
        return x, kv_cache


class Transformer(nn.Module):
    def __init__(
        self,
        vocab_size: int = 50304,
        dim: int = 768,
        num_heads: int = 6,
        num_layers: int = 12,
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList(
            [Block(dim=dim, num_heads=num_heads) for _ in range(num_layers)]
        )
        self.head = nn.Linear(dim, vocab_size)

    def forward(
        self,
        inputs: torch.Tensor,
        window: int | None = None,
        kv_cache: List[torch.Tensor] | None = None,
    ):
        assert inputs.ndim == 1, "Inputs must be unbatched: shape == (seq_len,)"
        T = inputs.size(-1)
        # Only the tokens not yet in the cache are run through the model.
        offset = 0
        if kv_cache is not None:
            offset = kv_cache[0].size(2)
            inputs = inputs[offset:]
        else:
            kv_cache = [None] * len(self.blocks)
        # Create block mask. Query indices are local to the new tokens, so the
        # mask mod shifts them by `offset` to recover absolute positions.
        window = (
            torch.tensor(T, dtype=torch.int32, device=inputs.device)
            if window is None
            else window
        )
        block_mask_mod = create_causal_mod(
            window=window,
            offset=torch.tensor(offset, dtype=torch.int32, device=inputs.device),
        )
        block_mask = create_block_mask(
            block_mask_mod,
            B=None,
            H=None,
            Q_LEN=inputs.size(-1),
            KV_LEN=T,
        )
        # Forward model
        x = norm(self.embed(inputs[None]))
        for block_idx, block in enumerate(self.blocks):
            x, block_kv_cache = block(x, block_mask, kv_cache[block_idx])
            kv_cache[block_idx] = block_kv_cache
        x = norm(x)
        x = self.head(x)
        return x, kv_cache


if __name__ == "__main__":
    # Instantiate and compile model
    device = "cuda"
    model = Transformer(
        dim=2048,
        num_heads=16,
        num_layers=32,
    ).eval()
    model.to(device).bfloat16()
    torch.set_float32_matmul_precision("high")
    model = torch.compile(model)

    # Prefill prompt
    prompt_len = 512
    gen_len = 2048
    prompt = torch.randint(0, 50304, (prompt_len,)).to(device)

    # Forward pass (warmup)
    print("\nTesting forward pass...")
    with torch.no_grad():
        for _ in range(10):
            out, _ = model(prompt)
    print(f"Output shape: {out.shape}")

    print("\nTesting generation...")

    def test_generation(use_cache: bool):
        tokens = prompt.clone()
        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)
        start_time.record()
        # Forward passes
        with torch.no_grad():
            kv_cache = None
            for _ in range(gen_len):
                # The full token sequence is passed each step; with a cache the
                # model internally trims it to the tokens not yet cached.
                logits, kv_cache = model(tokens, kv_cache=kv_cache if use_cache else None)
                next_token = logits[0, -1].argmax()
                tokens = torch.cat([tokens, next_token.unsqueeze(0)], dim=0)
        end_time.record()
        torch.cuda.synchronize()
        elapsed_ms = start_time.elapsed_time(end_time)
        tokens_per_sec = (gen_len * 1000) / elapsed_ms
        return elapsed_ms, tokens.size(0), tokens_per_sec

    # Test with cache
    time_with_cache, len_with_cache, tps_with_cache = test_generation(use_cache=True)
    print(f"With cache: {time_with_cache:.2f}ms ({len_with_cache} tokens, {tps_with_cache:.1f} tokens/sec)")

    # Test without cache
    time_no_cache, len_no_cache, tps_no_cache = test_generation(use_cache=False)
    print(f"No cache: {time_no_cache:.2f}ms ({len_no_cache} tokens, {tps_no_cache:.1f} tokens/sec)")

    speedup = time_no_cache / time_with_cache
    print(f"\nSpeedup from KV cache: {speedup:.2f}x")