# mini_gpt2_torch.py
# A minimal GPT-2-style decoder-only Transformer in PyTorch. The forward pass is
# written with explicit ops; autograd handles gradients and optimization.
import math, time
import torch
import torch.nn as nn
import torch.nn.functional as F
# -------------- Tokenizer --------------
class CharTokenizer:
    def __init__(self, text):
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(self.stoi)
    def encode(self, s): return torch.tensor([self.stoi[ch] for ch in s], dtype=torch.long)
    def decode(self, ids): return "".join(self.itos[int(i)] for i in ids)
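# Quick sanity check (illustrative, not part of the original gist):
#   tok = CharTokenizer("hello world")
#   ids = tok.encode("hello")          # LongTensor of character ids
#   assert tok.decode(ids) == "hello"  # encode/decode round-trips on in-vocab text
# Note: encode() raises KeyError for characters absent from the construction text.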
# -------------- Model components --------------
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, max_seq_len):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)
        # register a causal mask buffer (not a parameter)
        mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        self.register_buffer("causal_mask", mask)
    def forward(self, x):
        B, T, D = x.shape
        h, Hd = self.n_heads, self.head_dim
        q = self.q(x).view(B, T, h, Hd).transpose(1, 2)  # (B,h,T,Hd)
        k = self.k(x).view(B, T, h, Hd).transpose(1, 2)  # (B,h,T,Hd)
        v = self.v(x).view(B, T, h, Hd).transpose(1, 2)  # (B,h,T,Hd)
        # scaled dot-product attention
        att = (q @ k.transpose(-2, -1)) / math.sqrt(Hd)  # (B,h,T,T)
        att = att.masked_fill(self.causal_mask[:T, :T], float("-inf"))
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B,h,T,Hd)
        y = y.transpose(1, 2).contiguous().view(B, T, D)  # (B,T,D)
        return self.proj(y)
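# Note: on PyTorch 2.0+ the explicit mask/softmax/matmul in forward() above is
# equivalent to the fused kernel F.scaled_dot_product_attention(q, k, v,
# is_causal=True), which builds the same causal mask internally; it is spelled
# out here so the attention math stays visible.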
class MLP(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_hidden)
        self.fc2 = nn.Linear(d_hidden, d_model)
    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x), approximate="tanh"))  # GPT-2 uses the tanh GELU approximation
class Block(nn.Module):
    def __init__(self, d_model, n_heads, d_mlp, max_seq_len):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, max_seq_len)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, d_mlp)
    def forward(self, x):
        x = x + self.attn(self.ln1(x))  # pre-LN residual
        x = x + self.mlp(self.ln2(x))
        return x
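# This is the pre-LN ordering GPT-2 uses: LayerNorm at the input of each
# sub-block (plus a final ln_f below), which keeps the residual stream
# unnormalized and trains more stably than the original post-LN Transformer.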
class GPT2(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_layers=4, n_heads=4, d_mlp=None, max_seq_len=128, tie_weights=True):
        super().__init__()
        if d_mlp is None: d_mlp = 4 * d_model
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([Block(d_model, n_heads, d_mlp, max_seq_len) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)  # GPT-2's LM head has no bias
        if tie_weights:
            self.lm_head.weight = self.token_emb.weight  # weight tying
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
    def forward(self, idx, targets=None):
        B, T = idx.shape
        assert T <= self.max_seq_len
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)  # (1,T)
        x = self.token_emb(idx) + self.pos_emb(pos)  # (B,T,D)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B,T,V)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1))
        return logits, loss
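    # Autoregressive sampling: feed the last max_seq_len tokens back in, take the
    # logits at the final position, optionally sharpen with temperature and restrict
    # to the top-k candidates, then sample one token and append it. There is no KV
    # cache, so each step re-runs the full forward pass (fine at this scale).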
    @torch.no_grad()
    def generate(self, idx, max_new_tokens=100, temperature=1.0, top_k=None):
        self.eval()
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self.forward(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-6)  # next-token logits
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                thresh = v[:, -1].unsqueeze(-1)
                logits = torch.where(logits < thresh, torch.full_like(logits, -1e9), logits)
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # (B,1)
            idx = torch.cat([idx, next_id], dim=1)
        return idx
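# Worked example of the top-k filter in generate() (illustrative numbers, not from
# the original file): with logits [2.0, 1.0, 0.5, -1.0] and top_k=2, topk gives
# [2.0, 1.0], so thresh = 1.0; entries below it become -1e9, leaving
# [2.0, 1.0, -1e9, -1e9], and softmax then samples only from the top two tokens.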
# -------------- Tiny training driver --------------
def get_batch(data, batch_size, block_size, device):
    # data: 1D tensor of token ids
    ix = torch.randint(0, len(data) - block_size, (batch_size,))  # valid start indices
    x = torch.stack([data[i:i+block_size] for i in ix]).to(device)
    y = torch.stack([data[i+1:i+block_size+1] for i in ix]).to(device)
    return x, y
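# The targets are the inputs shifted left by one, so position t is trained to
# predict token t+1. A small sketch with block_size=4 and start index i=0:
#   x = data[0:4] = [d0, d1, d2, d3]
#   y = data[1:5] = [d1, d2, d3, d4]   # y[t] is the label for x[t]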
def demo():
    torch.manual_seed(123)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    corpus = """
Alice was beginning to get very tired of sitting by her sister on the bank.
She had nothing to do: once or twice she had peeped into the book her sister was reading,
but it had no pictures or conversations in it, and what is the use of a book, thought Alice, without pictures or conversations?
"""
    tok = CharTokenizer(corpus)
    data = tok.encode(corpus)
    model = GPT2(vocab_size=tok.vocab_size, d_model=192, n_layers=3, n_heads=3,
                 d_mlp=768, max_seq_len=128, tie_weights=True).to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.999), weight_decay=0.01)
    batch_size, block_size, steps = 32, 64, 300
    print(f"device={device}, vocab={tok.vocab_size}, params={sum(p.numel() for p in model.parameters()):,}")
    model.train()
    t0 = time.time()
    for it in range(1, steps + 1):
        x, y = get_batch(data, batch_size, block_size, device)
        optim.zero_grad(set_to_none=True)
        _, loss = model(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()
        if it % 50 == 0 or it == 1:
            dt = (time.time() - t0) * 1000; t0 = time.time()
            print(f"iter {it:4d} | loss {loss.item():.3f} | {dt:6.1f} ms")
    # sample
    model.eval()
    with torch.no_grad():
        start = tok.encode("Alice ").unsqueeze(0).to(device)
        out = model.generate(start, max_new_tokens=200, temperature=0.9, top_k=40)[0].tolist()
        print("\n=== SAMPLE ===")
        print(tok.decode(out))

if __name__ == "__main__":
    demo()