Created
March 28, 2026 23:33
-
-
Save PirosB3/eb88d0c10d69cf45d6872dc71f6af5a4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch.nn import functional as F | |
| from torch import nn | |
class MultiHeadAttention(nn.Module):
    """Runs several attention heads in parallel and linearly mixes the result.

    Each head works at head_size = embedding_dim // num_heads; the per-head
    outputs are concatenated back to embedding_dim and passed through a final
    linear projection.
    """

    def __init__(self, num_heads: int, context_size: int, embedding_dim: int) -> None:
        super().__init__()
        self.head_size = embedding_dim // num_heads
        self.heads = nn.ModuleList(
            Head(context_size, self.head_size, embedding_dim)
            for _ in range(num_heads)
        )
        self.projection_layer = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        # Concatenate the per-head outputs along the channel axis, then mix.
        per_head = [head(x) for head in self.heads]
        combined = torch.cat(per_head, dim=-1)
        return self.projection_layer(combined)
class Block(nn.Module):
    """A pre-norm transformer block: self-attention then an MLP, each wrapped
    in a residual (skip) connection with LayerNorm applied before the sublayer."""

    def __init__(self, num_heads: int, context_size: int, embedding_dim: int) -> None:
        super().__init__()
        self.mat = MultiHeadAttention(num_heads, context_size, embedding_dim)
        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)
        # MLP hidden layer is 4x the model width, as in the original transformer.
        self.ff = FeedForward(embedding_dim, embedding_dim * 4)

    def forward(self, x):
        # Pre-norm residuals: x + Attn(LN(x)), then x + FFN(LN(x)).
        x = x + self.mat(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x
class FeedForward(nn.Module):
    """Position-wise MLP: expand to hidden_layer units, ReLU, project back."""

    def __init__(self, embedding_dim: int, hidden_layer: int) -> None:
        super().__init__()
        layers = [
            nn.Linear(embedding_dim, hidden_layer),
            nn.ReLU(),
            nn.Linear(hidden_layer, embedding_dim),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        # Applied independently at every position; only the last dim is mixed.
        return self.network(x)
class TokenAndPositionEmbedding(nn.Module):
    """Sums learned token embeddings with learned absolute position embeddings.

    Args:
        vocab_size: number of distinct token ids.
        num_embeddings: embedding dimension (channels, C).
        context_window: maximum sequence length supported by the position table.
    """

    def __init__(self, vocab_size: int, num_embeddings: int, context_window: int) -> None:
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, num_embeddings)
        self.positional = nn.Embedding(context_window, num_embeddings)

    def forward(self, x):
        # x: (B, T) integer token ids; T must be <= context_window.
        width = x.shape[1]
        token_embeds = self.embeddings(x)  # (B, T, C)
        # Fix: build the position index on the same device as the input so the
        # module works on GPU (torch.arange defaults to CPU otherwise).
        positions = self.positional(torch.arange(width, device=x.device))  # (T, C)
        # (T, C) broadcasts across the batch dimension of (B, T, C).
        return token_embeds + positions
class Head(nn.Module):
    """One causal self-attention head: scaled dot-product attention with a
    lower-triangular mask so each position only sees itself and the past."""

    def __init__(self, context_size: int, head_size: int, embedding_dim: int) -> None:
        super().__init__()
        self.context_size = context_size
        self.head_size = head_size
        self.embedding_dim = embedding_dim
        # Bias-free query/key/value projections.
        self.Q = nn.Linear(embedding_dim, head_size, bias=False)
        self.K = nn.Linear(embedding_dim, head_size, bias=False)
        self.V = nn.Linear(embedding_dim, head_size, bias=False)
        # Causal mask buffer (not a parameter, moves with the module's device).
        self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))

    def forward(self, x):
        seq_len = x.shape[1]
        queries = self.Q(x)
        keys = self.K(x)
        # Attention scores (B, T, T), scaled by 1/sqrt(head_size); higher
        # score = more relevant key for that query position.
        scores = queries @ keys.transpose(-2, -1) * (self.head_size ** -0.5)
        # Block attention to future positions before normalising.
        scores = scores.masked_fill(self.tril[:seq_len, :seq_len] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        # Weighted sum of value vectors.
        return weights @ self.V(x)
class GPTModel(nn.Module):
    """Decoder-only transformer language model.

    Pipeline: token+position embedding -> num_blocks transformer blocks ->
    final LayerNorm -> linear head producing per-token vocabulary logits.
    """

    def __init__(self, context_window: int, num_embeddings: int, num_heads: int, num_blocks: int, vocab_size: int) -> None:
        super().__init__()
        self.embeddings = TokenAndPositionEmbedding(vocab_size, num_embeddings, context_window)
        self.blocks = nn.Sequential(*(
            Block(num_heads, context_window, num_embeddings)
            for _ in range(num_blocks)
        ))
        self.ln = nn.LayerNorm(num_embeddings)
        self.lm_head = nn.Linear(num_embeddings, vocab_size)

    def forward(self, X, targets=None):
        """Return (logits, loss); loss is None when targets are not given."""
        batch, seq = X.shape
        hidden = self.blocks(self.embeddings(X))
        logits = self.lm_head(self.ln(hidden))  # (B, T, vocab_size)
        if targets is None:
            return logits, None
        # Flatten (B, T) into one axis for cross-entropy over the vocab.
        loss = F.cross_entropy(logits.view(batch * seq, -1), targets.view(batch * seq))
        return logits, loss

    def generate(self, idx, max_new_tokens, block_size):
        """Extend idx with max_new_tokens tokens via multinomial sampling."""
        for _ in range(max_new_tokens):
            # Condition only on the last block_size tokens.
            window = idx[:, -block_size:]
            logits, _ = self(window)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
        return idx

    def generate_stream(self, idx, max_new_tokens, block_size, decode_fn):
        """Yield each newly sampled token, decoded via decode_fn.

        NOTE(review): `.item()` assumes batch size 1 — confirm at call sites.
        """
        for _ in range(max_new_tokens):
            window = idx[:, -block_size:]
            logits, _ = self(window)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
            yield decode_fn([next_token.item()])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment