Skip to content

Instantly share code, notes, and snippets.

@PirosB3
Created March 28, 2026 23:33
Show Gist options
  • Select an option

  • Save PirosB3/eb88d0c10d69cf45d6872dc71f6af5a4 to your computer and use it in GitHub Desktop.

Select an option

Save PirosB3/eb88d0c10d69cf45d6872dc71f6af5a4 to your computer and use it in GitHub Desktop.
import torch
from torch.nn import functional as F
from torch import nn
class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel; their outputs are concatenated
    along the channel axis and mixed by a final linear projection."""

    def __init__(self, num_heads: int, context_size: int, embedding_dim: int) -> None:
        super().__init__()
        # Split the embedding evenly across heads.
        self.head_size = embedding_dim // num_heads
        heads = [Head(context_size, self.head_size, embedding_dim) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        self.projection_layer = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        # Each head produces (B, T, head_size); concatenation restores (B, T, C).
        per_head = [head(x) for head in self.heads]
        combined = torch.cat(per_head, dim=-1)
        return self.projection_layer(combined)
class Block(nn.Module):
    """One transformer block: pre-norm multi-head attention followed by a
    pre-norm feed-forward network, each wrapped in a residual connection."""

    def __init__(self, num_heads: int, context_size: int, embedding_dim: int) -> None:
        super().__init__()
        self.mat = MultiHeadAttention(num_heads, context_size, embedding_dim)
        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)
        # Hidden layer is 4x the embedding width, as in the original GPT design.
        self.ff = FeedForward(embedding_dim, embedding_dim * 4)

    def forward(self, x):
        # Pre-norm residual: x + Attn(LN(x)), then x + FF(LN(x)).
        attended = self.mat(self.ln1(x))
        x = x + attended
        transformed = self.ff(self.ln2(x))
        return x + transformed
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Linear, projecting back to the
    embedding dimension."""

    def __init__(self, embedding_dim: int, hidden_layer: int) -> None:
        super().__init__()
        layers = [
            nn.Linear(embedding_dim, hidden_layer),
            nn.ReLU(),
            nn.Linear(hidden_layer, embedding_dim),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        # Applied independently at every position of x.
        out = self.network(x)
        return out
class TokenAndPositionEmbedding(nn.Module):
    """Sum of learned token embeddings and learned absolute position embeddings.

    Args:
        vocab_size: number of distinct token ids.
        num_embeddings: embedding (channel) dimension C.
        context_window: maximum sequence length supported by the position table.
    """

    def __init__(self, vocab_size: int, num_embeddings: int, context_window: int) -> None:
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, num_embeddings)
        self.positional = nn.Embedding(context_window, num_embeddings)

    def forward(self, x):
        # x: (B, T) integer token ids; T must be <= context_window.
        width = x.shape[1]
        token_embeds = self.embeddings(x)  # (B, T, C)
        # BUG FIX: build position ids on the same device as the input.
        # torch.arange defaults to CPU, which crashed this module on GPU.
        positions = self.positional(torch.arange(width, device=x.device))  # (T, C)
        return token_embeds + positions  # broadcasts over the batch -> (B, T, C)
class Head(nn.Module):
    """A single head of causal (masked) self-attention."""

    def __init__(self, context_size: int, head_size: int, embedding_dim: int) -> None:
        super().__init__()
        self.context_size = context_size
        self.head_size = head_size
        self.embedding_dim = embedding_dim
        self.Q = nn.Linear(embedding_dim, head_size, bias=False)
        self.K = nn.Linear(embedding_dim, head_size, bias=False)
        self.V = nn.Linear(embedding_dim, head_size, bias=False)
        # Lower-triangular causal mask; a buffer so it follows the module's
        # device/dtype but is never trained.
        causal_mask = torch.tril(torch.ones(context_size, context_size))
        self.register_buffer('tril', causal_mask)

    def forward(self, x):
        _, T, _ = x.shape
        queries = self.Q(x)
        keys = self.K(x)
        # Scaled dot-product scores (B, T, T); 1/sqrt(head_size) keeps the
        # score variance roughly constant regardless of head width.
        scores = queries @ keys.transpose(-2, -1) * (self.head_size ** -0.5)
        # Block attention to future positions, then normalize each row.
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        # Weighted sum of values -> (B, T, head_size).
        return weights @ self.V(x)
class GPTModel(nn.Module):
    """Decoder-only transformer language model.

    Pipeline: token+position embeddings -> num_blocks transformer blocks ->
    final LayerNorm -> linear head producing vocabulary logits.
    """

    def __init__(self, context_window: int, num_embeddings: int, num_heads: int, num_blocks: int, vocab_size: int) -> None:
        super().__init__()
        # Start with the embedding layer.
        self.embeddings = TokenAndPositionEmbedding(vocab_size, num_embeddings, context_window)
        self.blocks = nn.Sequential(*[
            Block(num_heads, context_window, num_embeddings)
            for _ in range(num_blocks)
        ])
        self.ln = nn.LayerNorm(num_embeddings)
        self.lm_head = nn.Linear(num_embeddings, vocab_size)

    def forward(self, X, targets=None):
        """Return (logits, loss); loss is None when targets is None.

        X: (B, T) token ids. targets, if given: (B, T) next-token ids.
        """
        B, T = X.shape
        embeddings = self.embeddings(X)
        processed = self.blocks(embeddings)
        logits = self.lm_head(self.ln(processed))  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            # cross_entropy expects (N, C) logits against (N,) class ids.
            logits_flat = logits.view(B * T, -1)
            targets_flat = targets.view(B * T)
            loss = F.cross_entropy(logits_flat, targets_flat)
        return logits, loss

    def generate(self, idx, max_new_tokens, block_size):
        """Autoregressively sample max_new_tokens tokens, appending to idx.

        idx: (B, T) seed token ids; returns (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # Crop the context to the model's window.
            idx_cond = idx[:, -block_size:]
            # Inference only: skip autograd bookkeeping to save memory.
            with torch.no_grad():
                logits, _ = self(idx_cond)
            # Distribution over the next token comes from the last position.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

    def generate_stream(self, idx, max_new_tokens, block_size, decode_fn):
        """Like generate(), but yields each new token decoded via decode_fn.

        NOTE(review): idx_next.item() assumes batch size 1 — confirm callers.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            # Inference only: skip autograd bookkeeping to save memory.
            with torch.no_grad():
                logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            yield decode_fn([idx_next.item()])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment