import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleTokenizer:

    def __init__(self):
        # Very simplified tokenizer for demonstration
        self.vocab = {
            "Hello": 0,
            "world": 1,
            "how": 2,
            "are": 3,
            "you": 4,
            "today": 5,
            "<pad>": 6
        }

    def encode(self, text):
        # Simple space-based tokenization; unknown words map to <pad>
        return torch.tensor(
            [[self.vocab.get(word, self.vocab["<pad>"])
              for word in text.split()]])


class SimplifiedAttentionWithKVCache(nn.Module):
    """
    Purpose of the KV cache:

    - When generating text, LLMs need to compute attention over all previous tokens.
    - Without caching, we would recompute the key and value projections for all
      previous tokens at every generation step.
    - The KV cache stores previously computed key and value projections, avoiding
      that redundant computation.

    How it works:

    - During the first forward pass, we compute and store K (keys) and V (values)
      for all prompt tokens.
    - For each subsequent token, we compute K and V only for the new token.
    - We concatenate the cached K, V with the new token's K, V before attention
      (see verify_cache_consistency below for a numerical check of this).
    - This saves significant computation during autoregressive generation.

    Benefits:

    - Dramatically reduces computation time during inference (a rough timing
      comparison is sketched in demonstrate_cache_speedup below).
    - Memory usage increases, since K and V must be stored (a rough estimate is
      sketched in estimate_kv_cache_bytes below), but the speedup is worth it.
    - Essential for practical deployment of large language models.
    """

    def __init__(self, hidden_size=512, num_heads=8):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads

        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, kv_cache=None, use_cache=False):
        batch_size, seq_len, _ = x.shape

        # Project queries, keys, and values
        q = self.q_proj(x)  # (batch_size, seq_len, hidden_size)
        k = self.k_proj(x)  # (batch_size, seq_len, hidden_size)
        v = self.v_proj(x)  # (batch_size, seq_len, hidden_size)

        # Reshape for multi-head attention:
        # (batch_size, num_heads, seq_len, head_size)
        q = q.view(batch_size, seq_len, self.num_heads,
                   self.head_size).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads,
                   self.head_size).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads,
                   self.head_size).transpose(1, 2)

        # If using the KV cache
        if use_cache:
            if kv_cache is not None:
                cached_k, cached_v = kv_cache
                # Concatenate the cached k,v with the new token's k,v along
                # the sequence dimension
                k = torch.cat([cached_k, k], dim=2)
                v = torch.cat([cached_v, v], dim=2)

            # Update the cache with the keys/values of all tokens seen so far
            new_cache = (k, v)

        # Compute scaled dot-product attention scores
        # (a causal mask is omitted here to keep the example simple)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_size ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)

        # Apply attention to the values
        context = torch.matmul(attn_weights, v)

        # Reshape back to (batch_size, seq_len, hidden_size) and project
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.hidden_size)
        output = self.out_proj(context)

        if use_cache:
            return output, new_cache
        return output
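

# ---------------------------------------------------------------------------
# Illustrative sketch (an addition, not part of the original example): a
# back-of-the-envelope estimate of the KV cache footprint mentioned in the
# docstring above. The function name and the default element size (4 bytes,
# i.e. float32) are assumptions made for this illustration.
# ---------------------------------------------------------------------------
def estimate_kv_cache_bytes(batch_size, num_heads, seq_len, head_size,
                            bytes_per_element=4):
    # The cache holds two tensors (K and V), each of shape
    # (batch_size, num_heads, seq_len, head_size).
    return 2 * batch_size * num_heads * seq_len * head_size * bytes_per_element


# For the toy configuration used below (batch 1, 8 heads, 6 cached tokens,
# head size 64, float32): estimate_kv_cache_bytes(1, 8, 6, 64) == 24576 bytes,
# roughly 24 KB. For large models with long contexts this grows to gigabytes.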


class SimplifiedLLMWithCache(nn.Module):

    def __init__(self, vocab_size=7, hidden_size=512, num_heads=8):
        super().__init__()
        self.hidden_size = hidden_size

        # Token embedding layer
        self.embedding = nn.Embedding(vocab_size, hidden_size)

        # Attention layer with KV cache
        self.attention = SimplifiedAttentionWithKVCache(hidden_size, num_heads)

        # Output projection
        self.output_proj = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, kv_cache=None, use_cache=False):
        # Convert token IDs to embeddings
        embeddings = self.embedding(
            input_ids)  # [batch_size, seq_len, hidden_size]

        # Apply attention with an optional KV cache
        if use_cache:
            hidden_states, new_cache = self.attention(embeddings,
                                                      kv_cache,
                                                      use_cache=True)
            logits = self.output_proj(hidden_states)
            return logits, new_cache
        else:
            hidden_states = self.attention(embeddings)
            logits = self.output_proj(hidden_states)
            return logits


def demonstrate_text_generation():
    model = SimplifiedLLMWithCache()
    tokenizer = SimpleTokenizer()

    # Initial prompt
    prompt = "Hello world how"
    input_ids = tokenizer.encode(prompt)
    print(f"Initial prompt: '{prompt}'")
    print(f"Tokenized prompt shape: {input_ids.shape}")  # [1, 3]

    # First forward pass with the full prompt
    # This computes and caches K,V for all prompt tokens
    outputs, kv_cache = model(input_ids, use_cache=True)
    print(f"Full prompt output shape: {outputs.shape}")  # [1, 3, vocab_size]

    # Now process the next token
    # We only need to run attention for "are"; the cached K,V cover the prompt
    next_text = "are"
    next_input = tokenizer.encode(next_text)
    print(f"\nGenerating next token: '{next_text}'")
    print(f"Next token input shape: {next_input.shape}")  # [1, 1]

    # Forward pass that reuses the cached K,V values
    next_outputs, kv_cache = model(next_input,
                                   kv_cache=kv_cache,
                                   use_cache=True)
    print(
        f"Next token output shape: {next_outputs.shape}")  # [1, 1, vocab_size]

    return {
        "initial_outputs": outputs,
        "next_outputs": next_outputs,
        "kv_cache": kv_cache
    }
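

# ---------------------------------------------------------------------------
# Illustrative sketch (an addition, not part of the original example): a quick
# numerical check that the cached path gives (approximately) the same result
# for the newest token as recomputing attention over the full sequence.
# The function name and the tolerance are assumptions made for this sketch.
# ---------------------------------------------------------------------------
def verify_cache_consistency():
    torch.manual_seed(0)
    model = SimplifiedAttentionWithKVCache()

    x1 = torch.randn(1, 5, 512)  # "prompt" hidden states
    x2 = torch.randn(1, 1, 512)  # "new token" hidden states

    with torch.no_grad():
        # Cached path: process the prompt once, then only the new token.
        _, kv_cache = model(x1, use_cache=True)
        cached_out, _ = model(x2, kv_cache=kv_cache, use_cache=True)

        # Uncached path: rerun attention over the full sequence and take the
        # output at the last position.
        full_out = model(torch.cat([x1, x2], dim=1))
        last_out = full_out[:, -1:, :]

    print("Cached output matches full recomputation:",
          torch.allclose(cached_out, last_out, atol=1e-5))
    return cached_out, last_out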


# Example usage
def demonstrate_kv_cache():
    model = SimplifiedAttentionWithKVCache()

    # Initial sequence
    x1 = torch.randn(1, 5, 512)  # (batch_size=1, seq_len=5, hidden_size=512)

    # Generate with cache
    output1, kv_cache = model(x1, use_cache=True)

    # Next token
    x2 = torch.randn(1, 1, 512)  # Single new token
    output2, kv_cache = model(x2, kv_cache=kv_cache, use_cache=True)

    return output1, output2, kv_cache
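

# ---------------------------------------------------------------------------
# Illustrative sketch (an addition, not part of the original example): a rough
# timing comparison of cached vs. uncached processing with the attention layer
# above. The function name, step count, and use of time.perf_counter are
# assumptions; at this toy scale the gap may be small, but it grows quickly
# with sequence length and model size.
# ---------------------------------------------------------------------------
def demonstrate_cache_speedup(num_steps=50, hidden_size=512):
    import time

    model = SimplifiedAttentionWithKVCache(hidden_size=hidden_size)
    prompt = torch.randn(1, 5, hidden_size)

    with torch.no_grad():
        # Without cache: re-run attention over the full (growing) sequence at
        # every step, appending a random "new token" each time.
        start = time.perf_counter()
        sequence = prompt
        for _ in range(num_steps):
            _ = model(sequence)
            sequence = torch.cat(
                [sequence, torch.randn(1, 1, hidden_size)], dim=1)
        uncached_time = time.perf_counter() - start

        # With cache: process the prompt once, then feed one token at a time
        # while reusing the cached keys/values.
        start = time.perf_counter()
        _, kv_cache = model(prompt, use_cache=True)
        for _ in range(num_steps):
            _, kv_cache = model(torch.randn(1, 1, hidden_size),
                                kv_cache=kv_cache, use_cache=True)
        cached_time = time.perf_counter() - start

    print(f"Uncached processing of {num_steps} steps: {uncached_time:.4f}s")
    print(f"Cached processing of {num_steps} steps:   {cached_time:.4f}s")
    return uncached_time, cached_time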


if __name__ == "__main__":
    output1, output2, kv_cache = demonstrate_kv_cache()

    print("Demonstrating KV cache")
    print(f"Output 1 shape: {output1.shape}")  # [1, 5, 512]
    print(f"Output 2 shape: {output2.shape}")  # [1, 1, 512]
    print(f"KV cache shape: {kv_cache[0].shape}")  # [1, 8, 6, 64]

    print("Demonstrating text generation")
    results = demonstrate_text_generation()

    print(f"Output 1 shape: {results['initial_outputs'].shape}")  # [1, 3, 7]
    print(f"Output 2 shape: {results['next_outputs'].shape}")  # [1, 1, 7]
    print(f"KV cache shape: {results['kv_cache'][0].shape}")  # [1, 8, 4, 64]