@0xBigBoss
Last active October 27, 2024 02:06
A simple LLM and tokenizer for demonstrating the KV cache and inference (text generation) with the transformer architecture.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleTokenizer:
    def __init__(self):
        # Very simplified tokenizer for demonstration
        self.vocab = {
            "Hello": 0,
            "world": 1,
            "how": 2,
            "are": 3,
            "you": 4,
            "today": 5,
            "<pad>": 6,
        }

    def encode(self, text):
        # Simple space-based tokenization
        return torch.tensor(
            [[self.vocab.get(word, 6) for word in text.split()]])
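
# Example: SimpleTokenizer().encode("Hello world how") -> tensor([[0, 1, 2]])
# Unknown words fall back to the <pad> id (6).
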
class SimplifiedAttentionWithKVCache(nn.Module):
    """
    Purpose of the KV cache:
        When generating text, an LLM must compute attention over all previous
        tokens. Without caching, the key and value projections for every
        previous token would be recomputed at each generation step. The KV
        cache stores the previously computed key and value projections,
        avoiding that redundant work.

    How it works:
        On the first forward pass, K (keys) and V (values) are computed and
        stored for all tokens. For each subsequent token, K and V are computed
        only for the new token and concatenated with the cached K, V. This
        saves significant computation during autoregressive generation.

    Benefits:
        Dramatically reduces computation time during inference. Memory usage
        increases (the K, V tensors must be stored), but the speedup is worth
        it. Essential for practical deployment of large language models.
    """
    def __init__(self, hidden_size=512, num_heads=8):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, kv_cache=None, use_cache=False):
        batch_size, seq_len, _ = x.shape

        # Project queries, keys, and values
        q = self.q_proj(x)  # (batch_size, seq_len, hidden_size)
        k = self.k_proj(x)  # (batch_size, seq_len, hidden_size)
        v = self.v_proj(x)  # (batch_size, seq_len, hidden_size)

        # Reshape for multi-head attention:
        # (batch_size, num_heads, seq_len, head_size)
        q = q.view(batch_size, seq_len, self.num_heads,
                   self.head_size).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads,
                   self.head_size).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads,
                   self.head_size).transpose(1, 2)

        # If using the KV cache, prepend any cached keys/values and
        # update the cache with the full (cached + current) tensors
        if use_cache:
            if kv_cache is not None:
                cached_k, cached_v = kv_cache
                # Concatenate cached k,v with the current token's k,v
                k = torch.cat([cached_k, k], dim=2)
                v = torch.cat([cached_v, v], dim=2)
            new_cache = (k, v)

        # Compute scaled dot-product attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.head_size))
        attn_weights = F.softmax(scores, dim=-1)

        # Apply attention to values
        context = torch.matmul(attn_weights, v)

        # Reshape and project output
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.hidden_size)
        output = self.out_proj(context)

        if use_cache:
            return output, new_cache
        return output
class SimplifiedLLMWithCache(nn.Module):
    def __init__(self, vocab_size=7, hidden_size=512, num_heads=8):
        super().__init__()
        self.hidden_size = hidden_size
        # Token embedding layer
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # Attention layer with KV cache
        self.attention = SimplifiedAttentionWithKVCache(hidden_size, num_heads)
        # Output projection to vocabulary logits
        self.output_proj = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, kv_cache=None, use_cache=False):
        # Convert token IDs to embeddings
        embeddings = self.embedding(
            input_ids)  # (batch_size, seq_len, hidden_size)

        # Apply attention with optional KV cache
        if use_cache:
            hidden_states, new_cache = self.attention(embeddings,
                                                      kv_cache,
                                                      use_cache=True)
            logits = self.output_proj(hidden_states)
            return logits, new_cache
        else:
            hidden_states = self.attention(embeddings)
            logits = self.output_proj(hidden_states)
            return logits
def demonstrate_text_generation():
    model = SimplifiedLLMWithCache()
    tokenizer = SimpleTokenizer()

    # Initial prompt
    prompt = "Hello world how"
    input_ids = tokenizer.encode(prompt)
    print(f"Initial prompt: '{prompt}'")
    print(f"Tokenized prompt shape: {input_ids.shape}")  # [1, 3]

    # First forward pass with the full prompt.
    # This computes and caches K,V for all prompt tokens.
    outputs, kv_cache = model(input_ids, use_cache=True)
    print(f"Full prompt output shape: {outputs.shape}")  # [1, 3, vocab_size]

    # Now process the next token.
    # We only need to run "are" through the model and can reuse the cached K,V.
    next_text = "are"
    next_input = tokenizer.encode(next_text)
    print(f"\nGenerating next token: '{next_text}'")
    print(f"Next token input shape: {next_input.shape}")  # [1, 1]

    # Forward pass with cached KV values
    next_outputs, kv_cache = model(next_input,
                                   kv_cache=kv_cache,
                                   use_cache=True)
    print(
        f"Next token output shape: {next_outputs.shape}")  # [1, 1, vocab_size]

    return {
        "initial_outputs": outputs,
        "next_outputs": next_outputs,
        "kv_cache": kv_cache
    }
# Example usage
def demonstrate_kv_cache():
    model = SimplifiedAttentionWithKVCache()

    # Initial sequence
    x1 = torch.randn(1, 5, 512)  # (batch_size=1, seq_len=5, hidden_size=512)

    # First pass: compute attention and populate the cache
    output1, kv_cache = model(x1, use_cache=True)

    # Next token: only one new position is processed, reusing the cache
    x2 = torch.randn(1, 1, 512)  # Single new token
    output2, kv_cache = model(x2, kv_cache=kv_cache, use_cache=True)

    return output1, output2, kv_cache
if __name__ == "__main__":
    output1, output2, kv_cache = demonstrate_kv_cache()
    print("Demonstrating KV cache")
    print(f"Output 1 shape: {output1.shape}")  # [1, 5, 512]
    print(f"Output 2 shape: {output2.shape}")  # [1, 1, 512]
    print(f"KV cache shape: {kv_cache[0].shape}")  # [1, 8, 6, 64]

    print("Demonstrating text generation")
    results = demonstrate_text_generation()
    print(f"Output 1 shape: {results['initial_outputs'].shape}")  # [1, 3, 7]
    print(f"Output 2 shape: {results['next_outputs'].shape}")  # [1, 1, 7]
    print(f"KV cache shape: {results['kv_cache'][0].shape}")  # [1, 8, 4, 64]

Demo a Simple LLM

A simple LLM and tokenizer for demonstrating the KV cache and inference (text generation) with the transformer architecture.

Requires PyTorch to be installed.

$ python ./demo_llm.py
Demonstrating KV cache
Output 1 shape: torch.Size([1, 5, 512])
Output 2 shape: torch.Size([1, 1, 512])
KV cache shape: torch.Size([1, 8, 6, 64])

Demonstrating text generation
Initial prompt: 'Hello world how'
Tokenized prompt shape: torch.Size([1, 3])
Full prompt output shape: torch.Size([1, 3, 7])

Generating next token: 'are'
Next token input shape: torch.Size([1, 1])
Next token output shape: torch.Size([1, 1, 7])
Output 1 shape: torch.Size([1, 3, 7])
Output 2 shape: torch.Size([1, 1, 7])
KV cache shape: torch.Size([1, 8, 4, 64])
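
The demo above only prints tensor shapes. As a rough sketch of how the same model and cache could drive actual autoregressive decoding, the loop below greedily takes the argmax token from the last position's logits and feeds only that token back in, reusing the cache at every step. The `generate_greedy` helper and its `id_to_word` reverse vocabulary are illustrative additions, not part of the gist, and since the model is untrained the continuation will be arbitrary.

def generate_greedy(model, tokenizer, prompt, max_new_tokens=3):
    # Reverse vocabulary for mapping predicted ids back to words (illustrative helper)
    id_to_word = {i: w for w, i in tokenizer.vocab.items()}

    input_ids = tokenizer.encode(prompt)

    # Prime the cache with the full prompt
    logits, kv_cache = model(input_ids, use_cache=True)
    generated = []

    for _ in range(max_new_tokens):
        # Greedy choice: argmax over the logits at the last position
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # shape [1, 1]
        generated.append(id_to_word[next_id.item()])

        # Feed only the new token back in, reusing the cached K,V
        logits, kv_cache = model(next_id, kv_cache=kv_cache, use_cache=True)

    return prompt + " " + " ".join(generated)


# Example usage (untrained model, so the output words are arbitrary):
# model = SimplifiedLLMWithCache()
# print(generate_greedy(model, SimpleTokenizer(), "Hello world how"))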