Simple transformer KV cache in PyTorch
# /// script
# dependencies = [
#   "torch>=2.5.1",
# ]
# ///
# type: ignore
""" | |
Basic KV cache implementation for a transformer in pytorch. | |
A KV cache speeds up decoding by reusing keys & values from previous | |
tokens instead of recomputing them. This implementation: | |
1. Caches K,V tensors after each forward pass | |
2. Concatenates cached K,V with new token's K,V | |
3. Uses FlexAttention for efficient sparse attention patterns | |
Includes benchmarking code showing typical ~2.0x speedup vs standard generation, | |
at the cost of higher memory usage. | |
Example usage: | |
model = Transformer(dim=2048, num_heads=16) | |
logits, kv_cache = model(tokens, kv_cache=None) # First pass | |
logits, kv_cache = model(new_token, kv_cache=kv_cache) # Use cache | |
Benchmark results: | |
With cache: 48527.77ms (2560 tokens, 42.2 tokens/sec) | |
No cache: 101128.65ms (2560 tokens, 20.3 tokens/sec) | |
Speedup from KV cache: 2.08x | |
""" | |

from typing import List

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import (
    BlockMask,
    flex_attention,
    create_block_mask,
)

def norm(x):
    # Parameter-free RMSNorm
    return F.rms_norm(x, (x.size(-1),))

def create_causal_mod(window: torch.Tensor, offset: int = 0):
    # `offset` = number of cached positions, so that when the query block covers
    # only the new (uncached) tokens it still compares against absolute key/value positions.
    def windowed_causal_mod(b, h, q_idx, kv_idx):
        causal_mask = (q_idx + offset) >= kv_idx
        window_mask = (q_idx + offset) - kv_idx <= window
        return causal_mask & window_mask
    return windowed_causal_mod
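
# Example (illustrative): with window=2 and offset=0, query position 5 may attend
# to key positions {3, 4, 5}: kv_idx <= q_idx and q_idx - kv_idx <= 2.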

class Rotary(nn.Module):
    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.register_buffer("inv_freq", (1 / base) ** (torch.arange(0, dim, 2) / dim))

    def forward(self, x, offset: int = 0):
        # x: (B, T, num_heads, head_dim). Positions start at `offset` so that
        # queries for newly decoded tokens are rotated by their absolute position.
        seq_len = x.shape[1]
        t = torch.arange(offset, offset + seq_len, device=x.device)
        freqs = torch.outer(t, self.inv_freq)
        cos, sin = freqs.cos()[None, :, None, :], freqs.sin()[None, :, None, :]
        x1, x2 = x.chunk(2, dim=3)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x)

class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.c_q = nn.Linear(dim, dim)
        self.c_k = nn.Linear(dim, dim)
        self.c_v = nn.Linear(dim, dim)
        self.rotary = Rotary(dim // num_heads)
        self.c_proj = nn.Linear(dim, dim)

    def forward(self, x, block_mask, kv_cache: torch.Tensor | None = None):
        B, T = x.size(0), x.size(1)
        assert B == 1, "This simple KV-cache implementation assumes batch size 1"
        q = self.c_q(x).view(B, T, self.num_heads, -1)
        k = self.c_k(x).view(B, T, self.num_heads, -1)
        v = self.c_v(x).view(B, T, self.num_heads, -1)
        if kv_cache is not None:
            # Prepend cached (pre-rotary) keys/values along the time dimension
            k = torch.cat([kv_cache[0], k], dim=1)
            v = torch.cat([kv_cache[1], v], dim=1)
        kv_cache = torch.stack([k, v])  # cache shape: (2, B, T_total, num_heads, head_dim)
        # Queries cover only the new tokens, so rotate them by their absolute offset
        q = self.rotary(q, offset=k.size(1) - q.size(1))
        k = self.rotary(k)
        y = flex_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            block_mask=block_mask,
            kernel_options={
                # Kernel options tuned for a 40GB card (GPU poor)
                "BLOCK_M": 64,
                "BLOCK_N": 64,
                "BLOCK_M1": 32,
                "BLOCK_N1": 64,
                "BLOCK_M2": 64,
                "BLOCK_N2": 32,
            },
        )
        y = y.transpose(1, 2).contiguous().view_as(x)
        y = self.c_proj(y)
        return y, kv_cache

class MLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.c_fc = nn.Linear(dim, 2 * dim)
        self.c_proj = nn.Linear(2 * dim, dim)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x)
        x = self.c_proj(x)
        return x

class Block(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.attn = CausalSelfAttention(dim=dim, num_heads=num_heads)
        self.mlp = MLP(dim=dim)

    def forward(self, x: torch.Tensor, block_mask: BlockMask, kv_cache: torch.Tensor | None):
        y, kv_cache = self.attn(norm(x), block_mask, kv_cache)
        x = x + y
        x = x + self.mlp(norm(x))
        return x, kv_cache

class Transformer(nn.Module):
    def __init__(
        self,
        vocab_size: int = 50304,
        dim: int = 768,
        num_heads: int = 6,
        num_layers: int = 12,
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList(
            [Block(dim=dim, num_heads=num_heads) for _ in range(num_layers)]
        )
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, inputs: torch.Tensor, window: int | None = None, kv_cache: List[torch.Tensor] | None = None):
        assert inputs.ndim == 1, "Inputs must be unbatched: shape == (seq_len,)"
        T = inputs.size(-1)
        if kv_cache is not None:
            # Only the tokens not already in the cache need to be processed
            inputs = inputs[kv_cache[0].size(2):]
        else:
            kv_cache = [None] * len(self.blocks)
        offset = T - inputs.size(-1)  # number of cached positions
        # Create block mask: queries are the new tokens, keys/values span the full sequence
        window = (
            torch.tensor(T, dtype=torch.int32, device=inputs.device)
            if window is None
            else window
        )
        block_mask_mod = create_causal_mod(window=window, offset=offset)
        block_mask = create_block_mask(
            block_mask_mod,
            B=None,
            H=None,
            Q_LEN=inputs.size(-1),
            KV_LEN=T,
            device=inputs.device,
        )
        # Forward model
        x = norm(self.embed(inputs[None]))
        for block_idx, block in enumerate(self.blocks):
            x, block_kv_cache = block(x, block_mask, kv_cache[block_idx])
            kv_cache[block_idx] = block_kv_cache
        x = norm(x)
        x = self.head(x)
        return x, kv_cache
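
# Shape walk-through (illustrative, using the settings from __main__ below:
# dim=2048, 16 heads -> head_dim=128, prompt of 512 tokens):
#   prefill:   inputs (512,) -> Q_LEN=512, KV_LEN=512, per-block cache (2, 1, 512, 16, 128)
#   next step: inputs (513,) -> 1 uncached token, Q_LEN=1, KV_LEN=513,
#              per-block cache grows to (2, 1, 513, 16, 128)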

if __name__ == "__main__":
    # Instantiate and compile model
    device = "cuda"
    model = Transformer(
        dim=2048,
        num_heads=16,
        num_layers=32,
    ).eval()
    model.to(device).bfloat16()
    torch.set_float32_matmul_precision("high")
    model = torch.compile(model)

    # Prefill prompt
    prompt_len = 512
    gen_len = 2048
    prompt = torch.randint(0, 50304, (prompt_len,)).to(device)

    # Forward pass (warmup)
    print("\nTesting forward pass...")
    with torch.no_grad():
        for _ in range(10):
            out, _ = model(prompt)
    print(f"Output shape: {out.shape}")

    print("\nTesting generation...")

    def test_generation(use_cache: bool):
        tokens = prompt.clone()
        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)
        start_time.record()
        # Forward passes
        with torch.no_grad():
            kv_cache = None
            for _ in range(gen_len):
                logits, kv_cache = model(tokens, kv_cache=kv_cache if use_cache else None)
                next_token = logits[0, -1].argmax()
                tokens = torch.cat([tokens, next_token.unsqueeze(0)], dim=0)
        end_time.record()
        torch.cuda.synchronize()
        elapsed_ms = start_time.elapsed_time(end_time)
        tokens_per_sec = (gen_len * 1000) / elapsed_ms
        return elapsed_ms, tokens.size(0), tokens_per_sec

    # Test with cache
    time_with_cache, len_with_cache, tps_with_cache = test_generation(use_cache=True)
    print(f"With cache: {time_with_cache:.2f}ms ({len_with_cache} tokens, {tps_with_cache:.1f} tokens/sec)")

    # Test without cache
    time_no_cache, len_no_cache, tps_no_cache = test_generation(use_cache=False)
    print(f"No cache: {time_no_cache:.2f}ms ({len_no_cache} tokens, {tps_no_cache:.1f} tokens/sec)")

    speedup = time_no_cache / time_with_cache
    print(f"\nSpeedup from KV cache: {speedup:.2f}x")