Adds a KV cache to my GPT-2 inference implementation: https://gist.github.com/gg2001/dfda4ce523223afe67dd5d9f9038fc47 (>10x faster on CPU)
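To run it (a sketch, assuming the file is saved as e.g. gpt2_kv_cache.py, a name chosen here for illustration, and that torch and transformers are installed):

python gpt2_kv_cache.py --prompt "The meaning of life is" --tokens 50

--prompt is required; --tokens defaults to 50. Each sampled token is streamed to stdout as it is generated.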
import argparse

import torch
from torch.nn.functional import gelu, layer_norm, softmax
from transformers import GPT2Model, AutoTokenizer
from typing import TypedDict

MODEL_NAME = "gpt2"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT2Model.from_pretrained(MODEL_NAME).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, clean_up_tokenization_spaces=True)

# Hyperparameters
cfg = model.config.to_dict()
vocab_size: int = cfg["vocab_size"]  # 50257, token dictionary size
n_positions: int = cfg["n_positions"]  # 1024, "context length"
n_embd: int = cfg["n_embd"]  # 768, token embedding dimensions
n_head: int = cfg["n_head"]  # 12, attention heads per block
n_layer: int = cfg["n_layer"]  # 12, number of transformer blocks
d_mlp: int = n_embd * 4  # 3072, MLP hidden layer dimensions
d_head: int = n_embd // n_head  # 64, dimensions of each attention head

# Parameters
class Attention(TypedDict):
    c_attn_weight: torch.Tensor  # (768 = 64 * 12, 2304 = 768 * 3)
    c_attn_bias: torch.Tensor  # (768 * 3,)
    c_proj_weight: torch.Tensor  # (768, 768)
    c_proj_bias: torch.Tensor  # (768,)


class MLP(TypedDict):
    c_fc_weight: torch.Tensor  # (768, 3072)
    c_fc_bias: torch.Tensor  # (3072,)
    c_proj_weight: torch.Tensor  # (3072, 768)
    c_proj_bias: torch.Tensor  # (768,)


class LayerNorm(TypedDict):
    weight: torch.Tensor  # (768,)
    bias: torch.Tensor  # (768,)


class Block(TypedDict):
    ln_1: LayerNorm  # layer norm 1
    attn: Attention  # multi-head self-attention
    ln_2: LayerNorm  # layer norm 2
    mlp: MLP  # feed-forward


state_dict = model.state_dict()
wte: torch.Tensor = state_dict["wte.weight"].to(device)  # (50257, 768)
wpe: torch.Tensor = state_dict["wpe.weight"].to(device)  # (1024, 768)
blocks: list[Block] = []
for i in range(n_layer):
    blocks.append(
        {
            "ln_1": {
                "weight": state_dict[f"h.{i}.ln_1.weight"].to(device),
                "bias": state_dict[f"h.{i}.ln_1.bias"].to(device),
            },
            "attn": {
                "c_attn_weight": state_dict[f"h.{i}.attn.c_attn.weight"].to(device),
                "c_attn_bias": state_dict[f"h.{i}.attn.c_attn.bias"].to(device),
                "c_proj_weight": state_dict[f"h.{i}.attn.c_proj.weight"].to(device),
                "c_proj_bias": state_dict[f"h.{i}.attn.c_proj.bias"].to(device),
            },
            "ln_2": {
                "weight": state_dict[f"h.{i}.ln_2.weight"].to(device),
                "bias": state_dict[f"h.{i}.ln_2.bias"].to(device),
            },
            "mlp": {
                "c_fc_weight": state_dict[f"h.{i}.mlp.c_fc.weight"].to(device),
                "c_fc_bias": state_dict[f"h.{i}.mlp.c_fc.bias"].to(device),
                "c_proj_weight": state_dict[f"h.{i}.mlp.c_proj.weight"].to(device),
                "c_proj_bias": state_dict[f"h.{i}.mlp.c_proj.bias"].to(device),
            },
        }
    )
ln_f: LayerNorm = {
    "weight": state_dict["ln_f.weight"].to(device),
    "bias": state_dict["ln_f.bias"].to(device),
}


def forward(
    input: torch.Tensor,
    kv_cache: torch.Tensor,
    past_length: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    batch_size, token_len = input.shape
    # token embeddings + position embeddings
    positions = torch.arange(
        past_length, past_length + token_len, device=device
    )  # (token_len,)
    x = wte[input] + wpe[positions]  # (batch_size, token_len, n_embd)
    # mask for causal attention
    # mask not needed when sampling a single token from a populated KV cache
    if past_length == 0 and token_len != 1:
        mask = torch.tril(torch.ones((token_len, token_len), device=device))
    # increase size of KV cache
    if past_length != 0:
        new_cache = torch.zeros(
            (batch_size, 2, n_layer, n_head, past_length + token_len, d_head),
            dtype=torch.float32,
            device=device,
        )
        new_cache[:, :, :, :, :past_length, :] = kv_cache
        kv_cache = new_cache
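    # (the cache grows by one position per decode step: a new tensor one position
    #  longer is allocated and the old keys/values are copied in; simple, at the
    #  cost of copying the whole cache for every generated token)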
    # residual stream
    for i, block in enumerate(blocks):
        ########################################
        # layer norm 1
        ########################################
        ln_1 = layer_norm(
            x, (n_embd,), weight=block["ln_1"]["weight"], bias=block["ln_1"]["bias"]
        )  # (batch_size, token_len, n_embd)
        ########################################
        # multi-head self-attention
        ########################################
        heads = block["attn"]
        # query, key, value
        qkv = (
            ln_1 @ heads["c_attn_weight"] + heads["c_attn_bias"]
        )  # (batch_size, token_len, n_embd * 3)
        q, k, v = (
            qkv[..., :n_embd],
            qkv[..., n_embd : 2 * n_embd],
            qkv[..., 2 * n_embd :],
        )  # (batch_size, token_len, n_embd)
        # separate the heads
        q = q.view(batch_size, token_len, n_head, d_head).transpose(
            1, 2
        )  # (batch_size, n_head, token_len, d_head)
        k = k.view(batch_size, token_len, n_head, d_head).transpose(1, 2)
        v = v.view(batch_size, token_len, n_head, d_head).transpose(1, 2)
        # update KV cache for current layer
        kv_cache[:, 0, i, :, past_length : past_length + token_len, :] = k
        kv_cache[:, 1, i, :, past_length : past_length + token_len, :] = v
        # load full KV from cache
        k_full = kv_cache[:, 0, i, :, : past_length + token_len, :]
        v_full = kv_cache[:, 1, i, :, : past_length + token_len, :]
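        # (this is where the KV cache pays off: when decoding, q holds only the
        #  new token's query, while k_full / v_full hold every cached position,
        #  so each step computes one new row of attention scores instead of
        #  recomputing attention for the whole sequence)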
        # attention scores + mask
        attn: torch.Tensor = (
            q @ k_full.transpose(-2, -1)
        ) * d_head**-0.5  # (batch_size, n_head, token_len, past_length + token_len)
        # apply mask if not sampling a single token from a populated KV cache
        if past_length == 0 and token_len != 1:
            attn = attn.masked_fill(
                mask[None, None, -token_len:, :] == 0, float("-inf")
            )
        scores = softmax(attn, dim=-1)
        heads_output = scores @ v_full  # (batch_size, n_head, token_len, d_head)
        # merge heads + linear layer
        concat_heads = heads_output.transpose(1, 2).reshape(
            batch_size, token_len, n_embd
        )  # (batch_size, token_len, n_embd)
        attn_output = (
            concat_heads @ heads["c_proj_weight"] + heads["c_proj_bias"]
        )  # (batch_size, token_len, n_embd)
        ########################################
        # residual connection 1 + layer norm 2
        ########################################
        x = x + attn_output  # (batch_size, token_len, n_embd)
        ln_2 = layer_norm(
            x, (n_embd,), weight=block["ln_2"]["weight"], bias=block["ln_2"]["bias"]
        )
        ########################################
        # mlp
        ########################################
        mlp = block["mlp"]
        # hidden layer
        mlp_output = (
            ln_2 @ mlp["c_fc_weight"] + mlp["c_fc_bias"]
        )  # (batch_size, token_len, d_mlp)
        mlp_output = gelu(mlp_output)
        # output layer
        mlp_output = (
            mlp_output @ mlp["c_proj_weight"] + mlp["c_proj_bias"]
        )  # (batch_size, token_len, n_embd)
        ########################################
        # residual connection 2
        ########################################
        x = x + mlp_output
    # final layer norm
    x = layer_norm(
        x, (n_embd,), weight=ln_f["weight"], bias=ln_f["bias"]
    )  # (batch_size, token_len, n_embd)
    # unembed layer = transpose of the embedding layer
    logits = x @ wte.T  # (batch_size, token_len, vocab_size)
    return logits, kv_cache


def generate(input: str, num_tokens: int, stream: bool = False) -> str:
    # (batch_size, token_len)
    tokens = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.int64, device=device)
    if input != "":
        tokenized: torch.Tensor = tokenizer(input, return_tensors="pt").input_ids.to(
            device
        )  # (1, token_len)
        tokens = torch.cat([tokens, tokenized], dim=1)
    new_tokens = torch.empty(0, dtype=torch.int64, device=device)
    past_length = tokens.size(1)
    # Initialize KV cache
    kv_cache = torch.zeros(
        (1, 2, n_layer, n_head, past_length, d_head),
        dtype=torch.float32,
        device=device,
    )  # (batch_size, kv, n_layer, n_head, token_len, d_head); index 0 holds keys, 1 holds values
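    # (rough size for these hyperparameters at fp32: 2 * n_layer * n_embd
    #  = 18,432 floats, about 72 KiB per cached token, so roughly 72 MiB if
    #  the full 1024-token context were cached)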
    # Populate the KV cache with the context
    with torch.no_grad():
        logits, kv_cache = forward(tokens, kv_cache, 0)  # (1, token_len, vocab_size)
    for _ in range(num_tokens):
        # Convert logits to probabilities
        probs = softmax(logits[0, -1, :], dim=-1)  # (vocab_size,)
        # Sample from the distribution
        next_token = torch.multinomial(probs, num_samples=1)  # (1,)
        new_tokens = torch.cat([new_tokens, next_token], dim=0)
        if stream:
            print(tokenizer.decode(next_token, skip_special_tokens=True), end="")
        # Stop if we generate an EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break
        # Process only the sampled token
        with torch.no_grad():
            input_token = next_token.unsqueeze(0)  # (1, 1)
            logits, kv_cache = forward(input_token, kv_cache, past_length)
        past_length += 1
    if stream:
        print()
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
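
# generate() can also be called directly instead of via the CLI entry point
# below, e.g. (hypothetical usage):
#   print(generate("The meaning of life is", num_tokens=20))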


if __name__ == "__main__":
    default_num_tokens = 50
    parser = argparse.ArgumentParser(description="GPT-2")
    parser.add_argument(
        "--prompt", type=str, required=True, help="Input prompt for text generation"
    )
    parser.add_argument(
        "--tokens",
        type=int,
        default=default_num_tokens,
        help=f"Number of tokens to generate (default: {default_num_tokens})",
    )
    args = parser.parse_args()
    generate(args.prompt, args.tokens, stream=True)