Created September 1, 2025 12:03
GPT-2 implementation
# mini_gpt2_torch.py
# A minimal GPT-2–style decoder-only Transformer in PyTorch with explicit ops,
# but autograd handles gradients and optimization.
import math, time

import torch
import torch.nn as nn
import torch.nn.functional as F
# -------------- Tokenizer --------------
class CharTokenizer:
    def __init__(self, text):
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(self.stoi)

    def encode(self, s): return torch.tensor([self.stoi[ch] for ch in s], dtype=torch.long)

    def decode(self, ids): return "".join(self.itos[int(i)] for i in ids)
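# Illustrative round-trip note (not part of the original file): encode/decode are
# inverses on any text whose characters were seen at construction time, e.g.
#   _tok = CharTokenizer("hello world")
#   assert _tok.decode(_tok.encode("hello")) == "hello"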
# -------------- Model components --------------
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, max_seq_len):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)
        # register a causal mask buffer (not a parameter)
        mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        self.register_buffer("causal_mask", mask)

    def forward(self, x):
        B, T, D = x.shape
        h, Hd = self.n_heads, self.head_dim
        q = self.q(x).view(B, T, h, Hd).transpose(1, 2)  # (B,h,T,Hd)
        k = self.k(x).view(B, T, h, Hd).transpose(1, 2)  # (B,h,T,Hd)
        v = self.v(x).view(B, T, h, Hd).transpose(1, 2)  # (B,h,T,Hd)
        # scaled dot-product attention
        att = (q @ k.transpose(-2, -1)) / math.sqrt(Hd)  # (B,h,T,T)
        att = att.masked_fill(self.causal_mask[:T, :T], float("-inf"))
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B,h,T,Hd)
        y = y.transpose(1, 2).contiguous().view(B, T, D)  # (B,T,D)
        return self.proj(y)
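# Equivalent fused alternative (assumes PyTorch >= 2.0), shown for reference only;
# is_causal=True applies the same upper-triangular mask as the explicit version above:
#   y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (B,h,T,Hd)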
class MLP(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_hidden)
        self.fc2 = nn.Linear(d_hidden, d_model)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))  # exact GELU; GPT-2 used the tanh approximation
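# For closer parity with GPT-2 (assumes PyTorch >= 1.12), the forward pass above could
# instead use the tanh-approximated GELU:
#   return self.fc2(F.gelu(self.fc1(x), approximate="tanh"))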
class Block(nn.Module):
    def __init__(self, d_model, n_heads, d_mlp, max_seq_len):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, max_seq_len)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, d_mlp)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))  # pre-LN residual
        x = x + self.mlp(self.ln2(x))
        return x
class GPT2(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_layers=4, n_heads=4, d_mlp=None, max_seq_len=128, tie_weights=True):
        super().__init__()
        if d_mlp is None: d_mlp = 4 * d_model
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([Block(d_model, n_heads, d_mlp, max_seq_len) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=True)
        if tie_weights:
            self.lm_head.weight = self.token_emb.weight  # weight tying
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size

    def forward(self, idx, targets=None):
        B, T = idx.shape
        assert T <= self.max_seq_len
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)  # (1,T)
        x = self.token_emb(idx) + self.pos_emb(pos)  # (B,T,D)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B,T,V)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1))
        return logits, loss
    @torch.no_grad()
    def generate(self, idx, max_new_tokens=100, temperature=1.0, top_k=None):
        self.eval()
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.max_seq_len:]
            logits, _ = self.forward(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-6)  # next-token logits
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                thresh = v[:, -1].unsqueeze(-1)
                logits = torch.where(logits < thresh, torch.full_like(logits, -1e9), logits)
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # (B,1)
            idx = torch.cat([idx, next_id], dim=1)
        return idx
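# Sanity-check note (not in the original file): with random initialization the
# cross-entropy loss should start near math.log(vocab_size), since the untrained
# model is roughly uniform over the vocabulary (e.g. ~3.7 for 40 symbols).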
# -------------- Tiny training driver --------------
def get_batch(data, batch_size, block_size, device):
    # data: 1D tensor of token ids
    ix = torch.randint(0, len(data) - block_size - 1, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix]).to(device)
    y = torch.stack([data[i+1:i+block_size+1] for i in ix]).to(device)
    return x, y
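# A hedged evaluation sketch (not part of the original file): averaging the loss over a
# few random batches gives a less noisy estimate than the loss of a single training step.
@torch.no_grad()
def estimate_loss(model, data, batch_size, block_size, device, iters=20):
    model.eval()
    losses = []
    for _ in range(iters):
        x, y = get_batch(data, batch_size, block_size, device)
        _, loss = model(x, y)
        losses.append(loss.item())
    model.train()
    return sum(losses) / len(losses)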
def demo():
    torch.manual_seed(123)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    corpus = """
Alice was beginning to get very tired of sitting by her sister on the bank.
She had nothing to do: once or twice she had peeped into the book her sister was reading,
but it had no pictures or conversations in it, and what is the use of a book, thought Alice, without pictures or conversations?
"""
    tok = CharTokenizer(corpus)
    data = tok.encode(corpus)
    model = GPT2(vocab_size=tok.vocab_size, d_model=192, n_layers=3, n_heads=3,
                 d_mlp=768, max_seq_len=128, tie_weights=True).to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.999), weight_decay=0.01)
    batch_size, block_size, steps = 32, 64, 300
    print(f"device={device}, vocab={tok.vocab_size}, params={sum(p.numel() for p in model.parameters()):,}")
    model.train()
    t0 = time.time()
    for it in range(1, steps + 1):
        x, y = get_batch(data, batch_size, block_size, device)
        optim.zero_grad(set_to_none=True)
        _, loss = model(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()
        if it % 50 == 0 or it == 1:
            dt = (time.time() - t0) * 1000; t0 = time.time()
            print(f"iter {it:4d} | loss {loss.item():.3f} | {dt:6.1f} ms")
    # sample
    model.eval()
    with torch.no_grad():
        start = tok.encode("Alice ").unsqueeze(0).to(device)
        out = model.generate(start, max_new_tokens=200, temperature=0.9, top_k=40)[0].tolist()
    print("\n=== SAMPLE ===")
    print(tok.decode(out))
if __name__ == "__main__":
    demo()
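# Checkpointing sketch (the file name "mini_gpt2.pt" is illustrative, not from the
# original file): trained weights can be persisted and restored with the standard
# state_dict API, e.g.
#   torch.save(model.state_dict(), "mini_gpt2.pt")
#   model.load_state_dict(torch.load("mini_gpt2.pt", map_location=device))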