Last active
March 9, 2024 22:22
Building a homemade GPT from scratch. Engineering a modern NLP AI model, based on transformers and attention.
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [
        {
          "file_id": "https://gist.github.com/kirisakow/7af90f26a8bc3f674058ddda71c6f518",
          "timestamp": "1710022634365"
        }
      ],
      "collapsed_sections": [
        "LgjYFf9l4rTB",
        "vVSK4Ow4h3w2",
        "GJf0Jd5esEA-",
        "vnPYzC3StW48",
        "Rp2mQrOwuF6J",
        "Fz-LPul00G_u",
        "lqGcHme1-C78",
        "oy5KNILoYGVY",
        "JoAeSYDjocBZ",
        "pA6FtZzS7EIO",
        "YSbZMCte7q32",
        "MFMwmN7K8IRq",
        "gYozYnOQ8Wru",
        "382xJUI88xic"
      ],
      "authorship_tag": "ABX9TyMIUDDJkOfQJRkL+L4n7Vom",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/kirisakow/7af90f26a8bc3f674058ddda71c6f518/build-a-homemade-gpt-from-scratch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Building a homemade GPT from scratch\n",
        "\n",
        "Engineering a modern NLP AI model, based on transformers and attention.\n",
        "<br><br>\n",
        "\n",
        "This is a refactored version of the [NeetCode ML tutorial][neetcode] and its [original notebook][original_notebook], enhanced with coding best practices (constants instead of hardcoded “magic” literals; reusable code wrapped in separate functions; extensive use of generators; etc.).\n",
        "\n",
        "Other resources:\n",
        "* [PyTorch][pytorch_docs] library documentation;\n",
        "* Andrej Karpathy's [Machine Learning Practice Problems][karpathy_playlist] YouTube playlist;\n",
        "* [Modern Approaches in Natural Language Processing][modern_nlp], a 2020 online ebook;\n",
        "* [Python][python_docs] official documentation;\n",
        "* [Effective Python][effective_python], 2nd ed., a 2019 book by Brett Slatkin;\n",
        "* A comprehensive curated list of [software engineering best practices and concepts][best_practices] such as clean code, simple design, software craftsmanship, YAGNI, KISS, SOLID, TDD, design patterns, and others.\n",
        "\n",
        "Author: [Kiril Isakov][kisakov_linkedin] ([kirisakow][kirisakow_github])\n",
        "\n",
        "[kisakov_linkedin]: https://www.linkedin.com/in/kisakov/\n",
        "[kirisakow_github]: https://github.com/kirisakow\n",
        "[neetcode]: https://neetcode.io/practice\n",
        "[original_notebook]: https://colab.research.google.com/drive/1L92UwfFlVlog-p8PhKe-ioROAaBa4aFc\n",
        "[karpathy_playlist]: https://www.youtube.com/playlist?list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ\n",
        "[pytorch_docs]: https://pytorch.org/docs/stable/\n",
        "[modern_nlp]: https://slds-lmu.github.io/seminar_nlp_ss20\n",
        "[python_docs]: https://docs.python.org/\n",
        "[effective_python]: https://effectivepython.com\n",
        "[best_practices]: https://gitlab.com/kirisakow/clean-code-software-craftsmanship-best-practices"
      ],
      "metadata": {
        "id": "fvuSVgapabdw"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Install and initialize libraries and constants"
      ],
      "metadata": {
        "id": "LgjYFf9l4rTB"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install --quiet torchtyping"
      ],
      "metadata": {
        "id": "NPNprFOc0rZs"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from numpy.typing import NDArray\n",
        "from torchtyping import TensorType\n",
        "from typing import List, Tuple\n",
        "import itertools\n",
        "import math\n",
        "import numpy as np\n",
        "import re\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "COLS_DIM_INDEX, ROWS_DIM_INDEX = 0, 1\n",
        "DECIMAL_PRECISION = 4\n",
        "DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
        "DROPOUT_PROBABILITY = 0.2\n",
        "EMB_LAYER_LEN = 16\n",
        "FINAL_LAYER_LEN = 10\n",
        "FIRST_LAYER_LEN = 512\n",
        "IMG_SHAPE = [28, 28]\n",
        "KNOWN_OUTPUT_DIMENSION_SIZE = 2\n",
        "LEARNING_RATE = 0.01\n",
        "LINEAR_NN_SCALE_FACTOR = 4\n",
        "MASK_FILLER = 0\n",
        "OUTPUT_LAYER_LEN = 1\n",
        "T_DIM_INDEX = 1  # sequence (time) dimension of a (batch, time, channels) tensor\n",
        "TOKEN_DELIMITER = r'\\s'\n",
        "WORDS_INDEX = {0: '\\n', 1: ' ', 2: '!', 3: '\"', 4: '$', 5: '%', 6: '&', 7: \"'\", 8: '(', 9: ')', 10: '*',\n",
        "               11: '+', 12: ',', 13: '-', 14: '.', 15: '/', 16: '0', 17: '1', 18: '2', 19: '3', 20: '4',\n",
        "               21: '5', 22: '6', 23: '7', 24: '8', 25: '9', 26: ':', 27: ';', 28: '?', 29: 'A', 30: 'B',\n",
        "               31: 'C', 32: 'D', 33: 'E', 34: 'F', 35: 'G', 36: 'H', 37: 'I', 38: 'J', 39: 'K', 40: 'L',\n",
        "               41: 'M', 42: 'N', 43: 'O', 44: 'P', 45: 'Q', 46: 'R', 47: 'S', 48: 'T', 49: 'U', 50: 'V',\n",
        "               51: 'W', 52: 'X', 53: 'Y', 54: 'Z', 55: '[', 56: ']', 57: '_', 58: 'a', 59: 'b', 60: 'c',\n",
        "               61: 'd', 62: 'e', 63: 'f', 64: 'g', 65: 'h', 66: 'i', 67: 'j', 68: 'k', 69: 'l', 70: 'm',\n",
        "               71: 'n', 72: 'o', 73: 'p', 74: 'q', 75: 'r', 76: 's', 77: 't', 78: 'u', 79: 'v', 80: 'w',\n",
        "               81: 'x', 82: 'y', 83: 'z', 84: '{', 85: '|', 86: '}', 87: 'à', 88: 'á', 89: 'è', 90: 'é',\n",
        "               91: 'ë', 92: 'ñ', 93: 'ó', 94: 'ú', 95: '\\u2005', 96: '–', 97: '—', 98: '‘', 99: '’', 100: '“',\n",
        "               101: '”', 102: '…', 103: '\\u205f'}\n",
        "WORDS_INDEX = tuple(WORDS_INDEX.values())  # keep the characters only: a character's position is its token id"
      ],
      "metadata": {
        "id": "sB6upU1mr7g_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 0. ML, NLP, and PyTorch fundamentals"
      ],
      "metadata": {
        "id": "vVSK4Ow4h3w2"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 0.1. Gradient Descent"
      ],
      "metadata": {
        "id": "wdJ4NLBGiuIy"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### NeetCode solution"
      ],
      "metadata": {
        "id": "GJf0Jd5esEA-"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A8Sk92kFYhQI"
      },
      "outputs": [],
      "source": [
        "# NeetCode solution: https://neetcode.io/problems/gradient-descent\n",
        "class Solution:\n",
        "    def get_minimizer(self, iterations: int, learning_rate: float, init: int) -> float:\n",
        "        minimizer = init\n",
        "        for _ in range(iterations):\n",
        "            derivative = 2 * minimizer\n",
        "            minimizer = minimizer - learning_rate * derivative\n",
        "        return round(minimizer, 5)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Refactored solution\n",
        "* no “magic” numbers or strings: use constants instead of hardcoded values;\n",
        "* recursion instead of a loop (CPython does not optimize tail calls, so `iterations` is bounded by the interpreter's recursion limit, 1000 frames by default);"
      ],
      "metadata": {
        "id": "rCbwYQ6Cs7UE"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Solution:\n",
        "\n",
        "    def get_minimizer(self, iterations: int, learning_rate: float, init: float) -> float:\n",
        "        \"\"\"Recursively perform a gradient descent toward the minimum of the x² function, with\n",
        "        * `iterations`: number of remaining steps;\n",
        "        * `learning_rate`: step size;\n",
        "        * `init`: current guess for the minimizer\n",
        "        \"\"\"\n",
        "        if iterations == 0:\n",
        "            return round(init, DECIMAL_PRECISION)\n",
        "        derivative = 2 * init  # derivative of x² is 2x\n",
        "        current_guess = init - learning_rate * derivative\n",
        "        return self.get_minimizer(init=current_guess,\n",
        "                                  iterations=iterations - 1,\n",
        "                                  learning_rate=learning_rate)"
      ],
      "metadata": {
        "id": "UU87zm7ej_Uf"
      },
      "execution_count": null,
      "outputs": []
    },
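    {
      "cell_type": "markdown",
      "source": [
        "A quick sanity check, offered as a sketch: for f(x) = x² each step computes x ← x − lr·2x = (1 − 2·lr)·x, so after n steps the result should equal `init · (1 − 2·lr)ⁿ`. The values below (`iterations=10`, `learning_rate=0.1`, `init=5`) are arbitrary examples, not taken from the NeetCode problem. The cell assumes that the `Solution` class defined just above is the one in scope."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Hypothetical example values, chosen only for illustration\n",
        "iterations, learning_rate, init = 10, 0.1, 5\n",
        "\n",
        "# Closed form of repeated gradient steps on f(x) = x²: x_n = init * (1 - 2 * lr) ** n\n",
        "expected = round(init * (1 - 2 * learning_rate) ** iterations, DECIMAL_PRECISION)\n",
        "result = Solution().get_minimizer(iterations=iterations, learning_rate=learning_rate, init=init)\n",
        "\n",
        "assert result == expected, (result, expected)\n",
        "print(result)  # 0.5369"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },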
    {
      "cell_type": "markdown",
      "source": [
        "### 0.2. Linear regression: the `forward()` function"
      ],
      "metadata": {
        "id": "FTnOHMUlmxC5"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### NeetCode solution"
      ],
      "metadata": {
        "id": "vnPYzC3StW48"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# NeetCode solution: https://neetcode.io/problems/linear-regression-forward\n",
        "class Solution:\n",
        "\n",
        "    def get_model_prediction(self, X: NDArray[np.float64], weights: NDArray[np.float64]) -> NDArray[np.float64]:\n",
        "        prediction = np.matmul(X, weights)\n",
        "        return np.round(prediction, 5)\n",
        "\n",
        "    def get_error(self, model_prediction: NDArray[np.float64], ground_truth: NDArray[np.float64]) -> float:\n",
        "        error = np.mean(np.square(model_prediction - ground_truth))\n",
        "        return round(error, 5)"
      ],
      "metadata": {
        "id": "eGUjiRAzoz7C"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Refactored solution\n",
        "* no “magic” numbers or strings: use constants instead of hardcoded values;"
      ],
      "metadata": {
        "id": "qL4xW9kAt15M"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Solution:\n",
        "\n",
        "    def get_model_prediction(self, X: NDArray[np.float64], weights: NDArray[np.float64]) -> NDArray[np.float64]:\n",
        "        prediction = np.matmul(X, weights)\n",
        "        return np.round(prediction, DECIMAL_PRECISION)\n",
        "\n",
        "    def get_error(self, model_prediction: NDArray[np.float64], ground_truth: NDArray[np.float64]) -> float:\n",
        "        error = np.mean(np.square(model_prediction - ground_truth))\n",
        "        return round(error, DECIMAL_PRECISION)"
      ],
      "metadata": {
        "id": "dKrLDvj6pmZ5"
      },
      "execution_count": null,
      "outputs": []
    },
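    {
      "cell_type": "markdown",
      "source": [
        "A minimal usage sketch on a hypothetical toy dataset (2 samples, 2 features), assuming the refactored `Solution` above is in scope. Each prediction is the dot product of one row of `X` with `weights`, and the error is the mean of the squared differences (MSE)."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Hypothetical toy data: 2 samples with 2 features each\n",
        "X = np.array([[1.0, 2.0], [3.0, 4.0]])\n",
        "weights = np.array([0.5, 0.25])\n",
        "ground_truth = np.array([1.0, 2.0])\n",
        "\n",
        "solution = Solution()\n",
        "prediction = solution.get_model_prediction(X, weights)  # [1*0.5 + 2*0.25, 3*0.5 + 4*0.25] = [1.0, 2.5]\n",
        "error = solution.get_error(prediction, ground_truth)     # mean([0.0**2, 0.5**2]) = 0.125\n",
        "\n",
        "assert np.allclose(prediction, [1.0, 2.5])\n",
        "assert np.isclose(error, 0.125)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },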
    {
      "cell_type": "markdown",
      "source": [
        "### 0.3. Linear regression: training"
      ],
      "metadata": {
        "id": "YO5QRRrVqnzg"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### NeetCode solution"
      ],
      "metadata": {
        "id": "Rp2mQrOwuF6J"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# NeetCode solution: https://neetcode.io/problems/linear-regression-training\n",
        "class Solution:\n",
        "    def get_derivative(self, model_prediction: NDArray[np.float64], ground_truth: NDArray[np.float64], N: int, X: NDArray[np.float64], desired_weight: int) -> float:\n",
        "        return -2 * np.dot(ground_truth - model_prediction, X[:, desired_weight]) / N\n",
        "\n",
        "    def get_model_prediction(self, X: NDArray[np.float64], weights: NDArray[np.float64]) -> NDArray[np.float64]:\n",
        "        return np.squeeze(np.matmul(X, weights))\n",
        "\n",
        "    learning_rate = 0.01\n",
        "\n",
        "    def train_model(self, X: NDArray[np.float64], Y: NDArray[np.float64], num_iterations: int, initial_weights: NDArray[np.float64]) -> NDArray[np.float64]:\n",
        "        for _ in range(num_iterations):\n",
        "            model_prediction = self.get_model_prediction(X, initial_weights)\n",
        "\n",
        "            d1 = self.get_derivative(model_prediction, Y, len(X), X, 0)\n",
        "            d2 = self.get_derivative(model_prediction, Y, len(X), X, 1)\n",
        "            d3 = self.get_derivative(model_prediction, Y, len(X), X, 2)\n",
        "\n",
        "            initial_weights[0] = initial_weights[0] - d1 * self.learning_rate\n",
        "            initial_weights[1] = initial_weights[1] - d2 * self.learning_rate\n",
        "            initial_weights[2] = initial_weights[2] - d3 * self.learning_rate\n",
        "\n",
        "        return np.round(initial_weights, 5)"
      ],
      "metadata": {
        "id": "HSnokAW_rFjR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Refactored solution\n",
        "* no “magic” numbers or strings: use constants instead of hardcoded values;\n",
        "* wrap repeated instructions into separate functions for better readability and reusability;"
      ],
      "metadata": {
        "id": "o1IjpcU5u2j1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Solution:\n",
        "\n",
        "    def get_derivative(self, model_prediction: NDArray[np.float64], ground_truth: NDArray[np.float64], N: int, X: NDArray[np.float64], desired_weight: int) -> float:\n",
        "        return -2 * np.dot(ground_truth - model_prediction, X[:, desired_weight]) / N\n",
        "\n",
        "    def get_model_prediction(self, X: NDArray[np.float64], weights: NDArray[np.float64]) -> NDArray[np.float64]:\n",
        "        return np.squeeze(np.matmul(X, weights))\n",
        "\n",
        "    def update_weights(self, actual_weights: NDArray[np.float64], model_prediction: NDArray[np.float64],\n",
        "                       Y: NDArray[np.float64], X: NDArray[np.float64]) -> NDArray[np.float64]:\n",
        "        weights_indices = range(len(X[0]))\n",
        "        for i in weights_indices:\n",
        "            derivative = self.get_derivative(model_prediction, Y, len(X), X, i)\n",
        "            actual_weights[i] -= derivative * LEARNING_RATE\n",
        "        return actual_weights\n",
        "\n",
        "    def train_model(self, X: NDArray[np.float64], Y: NDArray[np.float64], num_iterations: int, initial_weights: NDArray[np.float64]) -> NDArray[np.float64]:\n",
        "        actual_weights = initial_weights\n",
        "        for _ in range(num_iterations):\n",
        "            model_prediction = self.get_model_prediction(X, actual_weights)\n",
        "            actual_weights = self.update_weights(actual_weights, model_prediction, Y, X)\n",
        "        return np.round(actual_weights, DECIMAL_PRECISION)\n"
      ],
      "metadata": {
        "id": "56aHIdUyriJC"
      },
      "execution_count": null,
      "outputs": []
    },
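    {
      "cell_type": "markdown",
      "source": [
        "The loop in `update_weights()` computes each partial derivative ∂L/∂wᵢ = −2/N · (y − ŷ) · X[:, i] separately. One alternative, sketched here and not part of the NeetCode problem or the refactoring above, computes the whole gradient at once as −2/N · Xᵀ(y − ŷ). The cell below checks on hypothetical random data that both approaches give the same numbers, assuming the refactored `Solution` above is in scope."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Hypothetical random data: 8 samples, 3 features\n",
        "rng = np.random.default_rng(0)\n",
        "X = rng.normal(size=(8, 3))\n",
        "Y = rng.normal(size=8)\n",
        "weights = rng.normal(size=3)\n",
        "\n",
        "solution = Solution()\n",
        "prediction = solution.get_model_prediction(X, weights)\n",
        "\n",
        "# One partial derivative per weight, as update_weights() does\n",
        "per_weight = np.array([solution.get_derivative(prediction, Y, len(X), X, i) for i in range(X.shape[1])])\n",
        "# All partial derivatives at once: -2/N * X^T (y - y_hat)\n",
        "vectorized = -2 * X.T @ (Y - prediction) / len(X)\n",
        "assert np.allclose(per_weight, vectorized)\n",
        "\n",
        "# With the vectorized gradient, one training step becomes a single line\n",
        "next_weights = weights - LEARNING_RATE * vectorized"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },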
    {
      "cell_type": "markdown",
      "source": [
        "### 0.4. PyTorch basics"
      ],
      "metadata": {
        "id": "cPN38c420BNN"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### NeetCode solution"
      ],
      "metadata": {
        "id": "Fz-LPul00G_u"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# NeetCode solution: https://neetcode.io/problems/basics-of-pytorch\n",
        "class Solution:\n",
        "    def reshape(self, to_reshape: TensorType[float]) -> TensorType[float]:\n",
        "        M, N = to_reshape.shape\n",
        "        reshaped = torch.reshape(to_reshape, (M * N // 2, 2))\n",
        "        return torch.round(reshaped, decimals=4)\n",
        "\n",
        "    def average(self, to_avg: TensorType[float]) -> TensorType[float]:\n",
        "        averaged = torch.mean(to_avg, dim = 0)\n",
        "        return torch.round(averaged, decimals=4)\n",
        "\n",
        "    def concatenate(self, cat_one: TensorType[float], cat_two: TensorType[float]) -> TensorType[float]:\n",
        "        concatenated = torch.cat((cat_one, cat_two), dim = 1)\n",
        "        return torch.round(concatenated, decimals=4)\n",
        "\n",
        "    def get_loss(self, prediction: TensorType[float], target: TensorType[float]) -> TensorType[float]:\n",
        "        loss = torch.nn.functional.mse_loss(prediction, target)\n",
        "        return torch.round(loss, decimals=4)\n"
      ],
      "metadata": {
        "id": "N6O0I5g80KAJ"
      },
      "execution_count": null,
      "outputs": []
    },
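    {
      "cell_type": "markdown",
      "source": [
        "#### Refactored solution\n",
        "* no “magic” numbers or strings: use constants instead of hardcoded values;\n",
        "* better use of the torch API (see the `reshape()` function): passing `-1` for one dimension lets torch infer its size, here `M⋅N // 2`, so there is no need to unpack `M` and `N` first; `Tensor.reshape()` is preferred over `Tensor.view()` because `view()` fails on non-contiguous tensors, whereas `reshape()` returns a view when it can and copies the data only when it has to;"
      ],
      "metadata": {
        "id": "FM7sOJG20KqG"
      }
    },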
    {
      "cell_type": "code",
      "source": [
        "class Solution:\n",
        "\n",
        "    def reshape(self, to_reshape: TensorType[float]) -> TensorType[float]:\n",
        "        \"\"\"Reshape an M×N tensor into an (M⋅N // 2)×2 tensor\"\"\"\n",
        "        reshaped = to_reshape.reshape(-1, KNOWN_OUTPUT_DIMENSION_SIZE)  # -1: size inferred by torch\n",
        "        return torch.round(reshaped, decimals=DECIMAL_PRECISION)\n",
        "\n",
        "    def average(self, to_avg: TensorType[float]) -> TensorType[float]:\n",
        "        \"\"\"Find the average of every column in a tensor.\"\"\"\n",
        "        averaged = torch.mean(to_avg, dim=COLS_DIM_INDEX)\n",
        "        return torch.round(averaged, decimals=DECIMAL_PRECISION)\n",
        "\n",
        "    def concatenate(self, cat_one: TensorType[float], cat_two: TensorType[float]) -> TensorType[float]:\n",
        "        \"\"\"Combine an M×N tensor and an M×M tensor into an M×(M+N) tensor\"\"\"\n",
        "        concatenated = torch.cat((cat_one, cat_two), dim=ROWS_DIM_INDEX)\n",
        "        return torch.round(concatenated, decimals=DECIMAL_PRECISION)\n",
        "\n",
        "    def get_loss(self, prediction: TensorType[float], target: TensorType[float]) -> TensorType[float]:\n",
        "        \"\"\"Calculate the mean squared error loss between a prediction and a target tensor\"\"\"\n",
        "        loss = torch.nn.functional.mse_loss(prediction, target)\n",
        "        return torch.round(loss, decimals=DECIMAL_PRECISION)"
      ],
      "metadata": {
        "id": "aYFbq5oS0NxC"
      },
      "execution_count": null,
      "outputs": []
    },
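    {
      "cell_type": "markdown",
      "source": [
        "A minimal usage sketch on a hypothetical 3×4 tensor (M=3, N=4), assuming the refactored `Solution` above is in scope. The last lines use plain torch calls to show why `reshape()` is the safer choice: a transposed tensor is not contiguous in memory, and `view()` refuses to reinterpret it, while `reshape()` copies when needed."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "solution = Solution()\n",
        "t = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # hypothetical 3×4 input\n",
        "\n",
        "assert solution.reshape(t).shape == (6, 2)                        # (3⋅4 // 2)×2\n",
        "assert solution.average(t).tolist() == [4.0, 5.0, 6.0, 7.0]       # column means\n",
        "assert solution.concatenate(t, torch.ones(3, 3)).shape == (3, 7)  # 3×(4+3)\n",
        "assert solution.get_loss(torch.zeros(2), torch.tensor([1.0, 3.0])).item() == 5.0  # (1 + 9) / 2\n",
        "\n",
        "transposed = t.T  # a non-contiguous 4×3 view of t\n",
        "assert not transposed.is_contiguous()\n",
        "assert transposed.reshape(-1, 2).shape == (6, 2)  # reshape() copies when needed\n",
        "try:\n",
        "    transposed.view(-1, 2)\n",
        "except RuntimeError:\n",
        "    print('view() cannot reinterpret a non-contiguous tensor; use reshape() or .contiguous() first')"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },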
    {
      "cell_type": "markdown",
      "source": [
        "### 0.5. Handwritten digits classifier (based on the MNIST dataset)"
      ],
      "metadata": {
        "id": "e_tW-g-_46Mw"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### NeetCode solution"
      ],
      "metadata": {
        "id": "G5SqnDMh47sN"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# NeetCode solution: https://neetcode.io/problems/handwritten-digit-classifier\n",
        "class Solution(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        torch.manual_seed(0)\n",
        "        self.first_linear = nn.Linear(784, 512)\n",
        "        self.relu = nn.ReLU()\n",
        "        self.dropout = nn.Dropout(p=0.2)\n",
        "        self.projection = nn.Linear(512, 10)\n",
        "        self.sigmoid = nn.Sigmoid()\n",
        "\n",
        "    def forward(self, images: TensorType[float]) -> TensorType[float]:\n",
        "        torch.manual_seed(0)\n",
        "        out = self.sigmoid(self.projection(self.dropout(self.relu(self.first_linear(images)))))\n",
        "        return torch.round(out, decimals=4)"
      ],
      "metadata": {
        "id": "Cl2CAsrb5PzW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Refactored solution\n",
        "* no “magic” numbers or strings: use constants instead of hardcoded values;\n",
        "* shorter and less complex instructions per line of code (which is especially handy for testing);\n",
        "* wrap repeated instructions into separate functions for better readability and reusability;"
      ],
      "metadata": {
        "id": "m_91hkSu5lKp"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Solution(nn.Module):\n",
        "\n",
        "    def get_img_area(self, img_shape: list) -> int:\n",
        "        \"\"\"Number of pixels in an image, e.g. 28 × 28 = 784, returned as a plain Python int\"\"\"\n",
        "        img_shape_as_tensor = torch.tensor(img_shape)\n",
        "        return int(torch.prod(img_shape_as_tensor))\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        torch.manual_seed(0)\n",
        "        self.first_linear = nn.Linear(self.get_img_area(IMG_SHAPE), FIRST_LAYER_LEN)\n",
        "        self.relu = nn.ReLU()\n",
        "        self.dropout = nn.Dropout(p=DROPOUT_PROBABILITY)\n",
        "        self.projection = nn.Linear(FIRST_LAYER_LEN, FINAL_LAYER_LEN)\n",
        "        self.sigmoid = nn.Sigmoid()\n",
        "\n",
        "    def forward(self, images: TensorType[float]) -> TensorType[float]:\n",
        "        torch.manual_seed(0)  # makes the dropout mask reproducible\n",
        "        ret = self.first_linear(images)\n",
        "        ret = self.relu(ret)\n",
        "        ret = self.dropout(ret)\n",
        "        ret = self.projection(ret)\n",
        "        ret = self.sigmoid(ret)\n",
        "        return torch.round(ret, decimals=DECIMAL_PRECISION)"
      ],
      "metadata": {
        "id": "XP14vrnc5pbs"
      },
      "execution_count": null,
      "outputs": []
    },
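    {
      "cell_type": "markdown",
      "source": [
        "A minimal usage sketch, assuming the refactored `Solution` above is in scope and that each image is fed in as a flattened vector of 28 × 28 = 784 pixels. The batch of 4 random images is hypothetical. Calling `model.eval()` turns dropout into a no-op, so the output no longer depends on a random dropout mask. The model returns one sigmoid score in [0, 1] per digit class."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "model = Solution()\n",
        "model.eval()  # inference mode: dropout becomes a no-op\n",
        "\n",
        "batch_size = 4  # hypothetical\n",
        "images = torch.rand(batch_size, IMG_SHAPE[0] * IMG_SHAPE[1])\n",
        "with torch.no_grad():\n",
        "    scores = model(images)\n",
        "\n",
        "assert scores.shape == (batch_size, FINAL_LAYER_LEN)  # one score per digit, 0 to 9\n",
        "assert ((scores >= 0) & (scores <= 1)).all()\n",
        "predicted_digits = scores.argmax(dim=1)  # index of the highest score in each row"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },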
    {
      "cell_type": "markdown",
      "source": [
        "### 0.6. An introduction to natural language processing (NLP)"
      ],
      "metadata": {
        "id": "FZUF8EtQ9yis"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### NeetCode solution"
      ],
      "metadata": {
        "id": "lqGcHme1-C78"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# NeetCode Solution: https://neetcode.io/problems/nlp-intro\n",
        "class Solution:\n",
        "    def get_dataset(self, positive: List[str], negative: List[str]) -> TensorType[float]:\n",
        "        # First let's get the total set of words\n",
        "        words = set()\n",
        "        combined = positive + negative\n",
        "        for sentence in combined:\n",
        "            for word in sentence.split():\n",
        "                words.add(word)\n",
        "\n",
        "        # Now let's build a mapping\n",
        "        sorted_list = sorted(list(words))\n",
        "        word_to_int = {}\n",
        "        for i, c in enumerate(sorted_list):\n",
        "            word_to_int[c] = i + 1\n",
        "\n",
        "        # Write encode() which is used to build the dataset\n",
        "        def encode(sentence):\n",
        "            integers = []\n",
        "            for word in sentence.split():\n",
        "                integers.append(word_to_int[word])\n",
        "            return integers\n",
        "\n",
        "        var_len_tensors = []\n",
        "        for sentence in combined:\n",
        "            var_len_tensors.append(torch.tensor(encode(sentence)))\n",
        "\n",
        "        return nn.utils.rnn.pad_sequence(var_len_tensors, batch_first = True)"
      ],
      "metadata": {
        "id": "3UhjU433-GUd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Refactored solution\n",
        "* add a missing import (`typing.List`);\n",
        "* make the token delimiter explicit, and store it in a constant instead of a hardcoded value (aka “magic string”);\n",
        "* use a regex pattern for the token delimiter;\n",
        "* wrap repeated instructions into separate functions for better readability and reusability;\n",
        "* use [generators][term-generator] (instead of filling up a list, intended to be iterated on later);\n",
        "* use `chain()` from [`itertools`][itertools], a handy standard-library module, to iterate over both lists of sentences as one;\n",
        "* use [list comprehension][term-list-comprehension] and [dictionary comprehension][term-dictionary-comprehension] expressions to fill up lists and dictionaries;\n",
        "* use [`enumerate(..., start=1)`][enumerate] so that `i` starts from 1 from the get-go, instead of repeatedly computing `i + 1`; id 0 thus stays free for padding;\n",
        "* to access a value by key in a dictionary, use the failsafe method [`dict.get(key[, default_value])`][dict.get] instead of `dict[key]`;\n",
        "\n",
        "[term-generator]: https://docs.python.org/3/glossary.html#term-generator\n",
        "[itertools]: https://docs.python.org/3/library/itertools.html\n",
        "[term-list-comprehension]: https://docs.python.org/3/glossary.html#term-list-comprehension\n",
        "[term-dictionary-comprehension]: https://docs.python.org/3/glossary.html#term-dictionary-comprehension\n",
        "[enumerate]: https://docs.python.org/3/library/functions.html#enumerate\n",
        "[dict.get]: https://docs.python.org/3/library/stdtypes.html#dict.get"
      ],
      "metadata": {
        "id": "19Ug2qAE-M4I"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from typing import Iterator\n",
        "\n",
        "\n",
        "class Solution:\n",
        "\n",
        "    def tokenize_sentence(self, sentence: str) -> Iterator[str]:\n",
        "        yield from [word for word in re.split(TOKEN_DELIMITER, sentence) if word != '']\n",
        "\n",
        "    def get_all_words(self, *args: List[str]) -> Iterator[str]:\n",
        "        for sentence in itertools.chain(*args):\n",
        "            yield from self.tokenize_sentence(sentence)\n",
        "\n",
        "    def get_words_index(self, *args: List[str]) -> dict:\n",
        "        unique_words = list(set(self.get_all_words(*args)))\n",
        "        unique_words.sort()\n",
        "        sorted_words_index = {w: i for i, w in enumerate(unique_words, start=1)}\n",
        "        return sorted_words_index\n",
        "\n",
        "    def encode_sentence(self, sentence: str, words_index: dict) -> Iterator[int]:\n",
        "        for word in self.tokenize_sentence(sentence):\n",
        "            yield words_index.get(word, 0)  # unknown words fall back to the padding id 0\n",
        "\n",
        "    def encode_sentences(self, *args: List[str]) -> Iterator[torch.Tensor]:\n",
        "        words_index = self.get_words_index(*args)\n",
        "        for sentence in itertools.chain(*args):\n",
        "            encoded_sentence = self.encode_sentence(sentence, words_index)\n",
        "            yield torch.tensor(list(encoded_sentence))\n",
        "\n",
        "    def get_dataset(self, positive: List[str], negative: List[str]) -> TensorType[float]:\n",
        "        var_len_tensors = self.encode_sentences(positive, negative)\n",
        "        return nn.utils.rnn.pad_sequence(list(var_len_tensors), batch_first=True)"
      ],
      "metadata": {
        "id": "mUCGxSKQ-REH"
      },
      "execution_count": null,
      "outputs": []
    },
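    {
      "cell_type": "markdown",
      "source": [
        "A minimal usage sketch with hypothetical sentences, assuming the refactored `Solution` above is in scope. Words are sorted with Python's default string ordering (uppercase before lowercase) and numbered from 1. Shorter sentences are right-padded with 0."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "positive = ['Dogs are cute', 'Life is good']  # hypothetical examples\n",
        "negative = ['Bad dog']\n",
        "\n",
        "dataset = Solution().get_dataset(positive, negative)\n",
        "# Sorted vocabulary: Bad=1, Dogs=2, Life=3, are=4, cute=5, dog=6, good=7, is=8\n",
        "assert dataset.tolist() == [[2, 4, 5],\n",
        "                            [3, 8, 7],\n",
        "                            [1, 6, 0]]  # 'Bad dog' is padded with 0\n",
        "print(dataset)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },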
    {
      "cell_type": "markdown",
      "source": [
        "### 0.7. Sentiment analysis"
      ],
      "metadata": {
        "id": "gYwupvPyX8kp"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### NeetCode solution"
      ],
      "metadata": {
        "id": "oy5KNILoYGVY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# NeetCode Solution: https://neetcode.io/problems/sentiment-analysis\n",
        "class Solution(nn.Module):\n",
        "    def __init__(self, vocabulary_size: int):\n",
        "        super().__init__()\n",
        "        torch.manual_seed(0)\n",
        "        self.embedding_layer = nn.Embedding(vocabulary_size, 16)\n",
        "        self.linear_layer = nn.Linear(16, 1)\n",
        "        self.sigmoid_layer = nn.Sigmoid()\n",
        "\n",
        "    def forward(self, x: TensorType[int]) -> TensorType[float]:\n",
        "        embeddings = self.embedding_layer(x)\n",
        "        averaged = torch.mean(embeddings, axis = 1)\n",
        "        projected = self.linear_layer(averaged)\n",
        "        return torch.round(self.sigmoid_layer(projected), decimals=4)"
      ],
      "metadata": {
        "id": "36JSmn_qYJPS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Refactored solution\n",
        "* no “magic” numbers or strings: use constants instead of hardcoded values;\n",
        "* use the documented `dim=` keyword of [`torch.mean(...)`][torch.mean] instead of the NumPy-style `axis=` alias, which is not part of its documented signature;\n",
        "* shorter and less complex instructions per line of code (which is especially handy for testing).\n",
        "\n",
        "[torch.mean]: https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean"
      ],
      "metadata": {
        "id": "-bk4IcqRYrPv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Solution(nn.Module):\n",
        "\n",
        "    def __init__(self, vocabulary_size: int):\n",
        "        super().__init__()\n",
        "        torch.manual_seed(0)\n",
        "        self.embedding_layer = nn.Embedding(num_embeddings=vocabulary_size, embedding_dim=EMB_LAYER_LEN)\n",
        "        self.linear_layer = nn.Linear(EMB_LAYER_LEN, OUTPUT_LAYER_LEN)\n",
        "        self.sigmoid_layer = nn.Sigmoid()\n",
        "\n",
        "    def forward(self, x: TensorType[int]) -> TensorType[float]:\n",
        "        ret = self.embedding_layer(x)\n",
        "        ret = torch.mean(ret, dim=T_DIM_INDEX)\n",
        "        ret = self.linear_layer(ret)\n",
        "        ret = self.sigmoid_layer(ret)\n",
        "        ret = torch.round(ret, decimals=DECIMAL_PRECISION)\n",
        "        return ret"
      ],
      "metadata": {
        "id": "MVrJ3JNgY8qC"
      },
      "execution_count": null,
      "outputs": []
    },
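    {
      "cell_type": "markdown",
      "source": [
        "A minimal usage sketch, assuming the refactored `Solution` above is in scope. The padded batch and the vocabulary size of 9 (8 hypothetical words plus the padding id 0) are made up for illustration. Each sentence of T token ids is embedded into a T×16 matrix, averaged over `T_DIM_INDEX` into a single 16-vector, and projected to one sigmoid score. In this design, padding ids are embedded and averaged like real words."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "model = Solution(vocabulary_size=9)  # hypothetical: 8 words + padding id 0\n",
        "batch = torch.tensor([[2, 4, 5],\n",
        "                      [3, 8, 7],\n",
        "                      [1, 6, 0]])  # 3 sentences, T = 3 tokens each (the last one padded)\n",
        "with torch.no_grad():\n",
        "    scores = model(batch)\n",
        "\n",
        "assert scores.shape == (3, OUTPUT_LAYER_LEN)  # one sentiment score per sentence\n",
        "assert ((scores >= 0) & (scores <= 1)).all()"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },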
    {
      "cell_type": "markdown",
      "source": [
        "### 0.8. GPT dataset"
      ],
      "metadata": {
        "id": "GmGRqozKlHL9"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### NeetCode solution"
      ],
      "metadata": {
        "id": "JoAeSYDjocBZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# NeetCode solution: https://neetcode.io/problems/gpt-dataset\n",
        "class Solution:\n",
        "    def batch_loader(self, raw_dataset: str, context_length: int, batch_size: int) -> Tuple[List[List[str]]]:\n",
        "        torch.manual_seed(0)\n",
        "        tokenized = raw_dataset.split()\n",
        "        indices = torch.randint(low=0, high=len(tokenized) - context_length, size=(batch_size,)).tolist()\n",
        "        X = []\n",
        "        Y = []\n",
        "        for idx in indices:\n",
        "            X.append(tokenized[idx:idx+context_length])\n",
        "            Y.append(tokenized[idx+1:idx+1+context_length])\n",
        "        return X, Y"
      ],
      "metadata": {
        "id": "lx5Cl8DoohFV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Refactored solution\n",
        "* make the token delimiter explicit, and store it in a constant instead of a hardcoded value (aka “magic string”);\n",
        "* use a regex pattern for the token delimiter;\n",
        "* improve readability by breaking down a complex one-liner instruction (`indices = ...`) into multiple lines of code;\n",
        "* wrap repeated instructions into separate functions for better readability and reusability: here, the X and Y lists are built by the same function, called twice, each time with a different value for the `offset=` parameter;\n",
        "* use [generators][term-generator] (instead of filling up a list, intended to be iterated on later);\n",
        "* use a [list comprehension][term-list-comprehension] expression to fill up a list;\n",
        "\n",
        "[term-generator]: https://docs.python.org/3/glossary.html#term-generator\n",
        "[term-list-comprehension]: https://docs.python.org/3/glossary.html#term-list-comprehension"
      ],
      "metadata": {
        "id": "EGEjE6X_ovWA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from typing import Iterator\n",
        "\n",
        "\n",
        "class Solution:\n",
        "\n",
        "    def tokenize(self, text: str) -> Iterator[str]:\n",
        "        yield from [word for word in re.split(TOKEN_DELIMITER, text) if word != '']\n",
        "\n",
        "    def build_batch(self, words: List[str], context_length: int, indices: List[int], offset: int=0) -> Iterator[List[str]]:\n",
        "        for i in indices:\n",
        "            yield words[i + offset:i + offset + context_length]\n",
        "\n",
        "    def batch_loader(self, raw_dataset: str,\n",
        "                     context_length: int,\n",
        "                     batch_size: int) -> Tuple[List[List[str]], List[List[str]]]:\n",
        "        words = list(self.tokenize(raw_dataset))\n",
        "        torch.manual_seed(0)\n",
        "        indices = torch.randint(low=0,\n",
        "                                high=len(words) - context_length,\n",
        "                                size=(batch_size,)).tolist()\n",
        "        X = list(self.build_batch(words, context_length, indices))\n",
        "        Y = list(self.build_batch(words, context_length, indices, offset=1))  # targets: inputs shifted by one token\n",
        "        return X, Y"
      ],
      "metadata": {
        "id": "mRg1R4ggozIF"
      },
      "execution_count": null,
      "outputs": []
    },
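    {
      "cell_type": "markdown",
      "source": [
        "A minimal usage sketch, assuming the refactored `Solution` above is in scope. The sentence and the `context_length`/`batch_size` values are hypothetical. Which windows get sampled depends on the seeded `torch.randint` call, so the cell checks only the structural property that matters for next-token prediction: each target sequence is its input sequence shifted one token to the right."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "raw_dataset = 'the quick brown fox jumps over the lazy dog and runs away'  # hypothetical text\n",
        "context_length, batch_size = 3, 2\n",
        "X, Y = Solution().batch_loader(raw_dataset, context_length, batch_size)\n",
        "\n",
        "assert len(X) == len(Y) == batch_size\n",
        "for x, y in zip(X, Y):\n",
        "    assert len(x) == len(y) == context_length\n",
        "    assert x[1:] == y[:-1]  # y[t] is the token that follows x[t]\n",
        "    print(x, '->', y)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },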
    {
      "cell_type": "markdown",
      "source": [
        "## 1. Build a homemade GPT from scratch"
      ],
      "metadata": {
        "id": "cCDTG28v6I7v"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 1.1. Self-attention class"
      ],
      "metadata": {
        "id": "0W8A0qXT6R6i"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### NeetCode solution"
      ],
      "metadata": {
        "id": "pA6FtZzS7EIO"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# NeetCode solution: https://neetcode.io/problems/self-attention\n",
        "class SingleHeadAttention(nn.Module):\n",
        "\n",
        "    def __init__(self, embedding_dim: int, attention_dim: int):\n",
        "        super().__init__()\n",
        "        torch.manual_seed(0)\n",
        "        self.key_gen = nn.Linear(embedding_dim, attention_dim, bias=False)\n",
        "        self.query_gen = nn.Linear(embedding_dim, attention_dim, bias=False)\n",
        "        self.value_gen = nn.Linear(embedding_dim, attention_dim, bias=False)\n",
        "\n",
        "    def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n",
        "        k = self.key_gen(embedded)\n",
        "        q = self.query_gen(embedded)\n",
        "        v = self.value_gen(embedded)\n",
        "\n",
        "        scores = q @ torch.transpose(k, 1, 2)  # @ is the same as torch.matmul()\n",
        "        context_length, attention_dim = k.shape[1], k.shape[2]\n",
        "        scores = scores / (attention_dim ** 0.5)\n",
        "\n",
        "        lower_triangular = torch.tril(torch.ones(context_length, context_length))\n",
        "        mask = lower_triangular == 0\n",
        "        scores = scores.masked_fill(mask, float('-inf'))\n",
        "        scores = nn.functional.softmax(scores, dim = 2)\n",
        "\n",
        "        return torch.round(scores @ v, decimals=4)"
      ],
      "metadata": {
        "id": "S0AIVkl47J6D"
      },
      "execution_count": null,
      "outputs": []
    },
{ | |
"cell_type": "markdown", | |
"source": [ | |
"#### Refactored solution\n", | |
"* better flexibility, reusability, extensibility: add a Boolean parameter to the constructor to decide whether the output should be rounded;\n", | |
"* no over-optimization: use `math.sqrt(x)` instead of `x ** 0.5`;\n", | |
" * Why? `x ** 0.5` is sometimes written as a micro-optimization, but in Python it buys nothing. `math.sqrt(x)` calls the C library's `sqrt` under the hood, states the intent more clearly, and was measured to be faster than `x ** 0.5` in [this benchmark][math.sqrt]. Timings vary across Python versions and machines, though. Here the square root is computed once per forward pass anyway, so readability is what matters.\n", | |
" * `numpy` provides an equivalent, `np.sqrt`. Use it if `numpy` is already loaded for other reasons, but keep in mind that on a plain Python scalar it adds some per-call overhead. It pays off mainly on arrays.\n", | |
"* wrap repeated instructions into separate functions for better readability and reusability;\n", | |
"* shorter and less complex instructions per line of code (which is especially handy for testing);\n", | |
"* no “magic” numbers or strings: use constants instead of hardcoded values;\n", | |
"\n", | |
"[math.sqrt]: https://stackoverflow.com/questions/327002/which-is-faster-in-python-x-5-or-math-sqrtx/327048#327048" | |
], | |
"metadata": { | |
"id": "Sms2SRPj7MAt" | |
} | |
}, | |
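You can check the speed claim on your own machine with a quick `timeit` comparison like the one below. Absolute numbers depend on the Python version and hardware, so no particular result is implied. NumPy is treated as optional here, so the snippet also runs where it is not installed.

```python
import math
import timeit

x = 252.0  # e.g. an attention dimension

candidates = {
    "x ** 0.5": lambda: x ** 0.5,
    "math.sqrt(x)": lambda: math.sqrt(x),
}
try:
    import numpy as np
    candidates["np.sqrt(x)"] = lambda: np.sqrt(x)
except ImportError:  # NumPy is optional for this comparison
    pass

# all variants must agree on the value before comparing their speed
values = {name: float(fn()) for name, fn in candidates.items()}
assert len({round(v, 12) for v in values.values()}) == 1

for name, fn in candidates.items():
    seconds = timeit.timeit(fn, number=200_000)
    print(f"{name:>14}: {seconds * 1e9 / 200_000:.1f} ns per call")
```

Each variant is wrapped in a lambda, so all of them pay the same call overhead and only the relative differences are meaningful.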
{ | |
"cell_type": "code", | |
"source": [ | |
"class MySingleHeadAttention(nn.Module):\n", | |
"\n", | |
" def __init__(self, model_dim: int, head_size: int, round_output: bool=False):\n", | |
" super().__init__()\n", | |
" # torch.manual_seed(0)\n", | |
" self.key_layer = nn.Linear(model_dim, head_size, bias=False)\n", | |
" self.query_layer = nn.Linear(model_dim, head_size, bias=False)\n", | |
" self.value_layer = nn.Linear(model_dim, head_size, bias=False)\n", | |
" self.round_output = round_output\n", | |
"\n", | |
" def bool_mask(self, context_length: int, filler_val: int, device) -> TensorType[bool]:\n", | |
" lower_triangular = torch.tril(torch.ones(context_length, context_length))\n", | |
" return (lower_triangular == filler_val).to(device)\n", | |
"\n", | |
" which_power_for_e = lambda self, x: np.log(x) if x != 0 else -np.inf\n", | |
"\n", | |
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n", | |
" k = self.key_layer(embedded)\n", | |
" q = self.query_layer(embedded)\n", | |
" v = self.value_layer(embedded)\n", | |
" context_length, attention_dim = k.shape[1], k.shape[2]\n", | |
" scores = q @ torch.transpose(k, 1, 2)\n", | |
" scores = scores / np.sqrt(attention_dim)\n", | |
" mask = self.bool_mask(context_length, MASK_FILLER, embedded.device)\n", | |
" scores = scores.masked_fill(mask, self.which_power_for_e(MASK_FILLER))\n", | |
" scores = nn.functional.softmax(scores, dim=2)\n", | |
" scores = scores @ v\n", | |
" scores = torch.round(scores, decimals=DECIMAL_PRECISION) if self.round_output else scores\n", | |
" return scores" | |
], | |
"metadata": { | |
"id": "js9bPAv47PnB" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### 1.2. Multi-headed self-attention class" | |
], | |
"metadata": { | |
"id": "1khI9cVg7q3x" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"#### NeetCode solution" | |
], | |
"metadata": { | |
"id": "YSbZMCte7q32" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# NeetCode solution: https://neetcode.io/problems/multi-headed-self-attention\n", | |
"class MultiHeadedSelfAttention(nn.Module):\n", | |
"\n", | |
" def __init__(self, embedding_dim: int, attention_dim: int, num_heads: int):\n", | |
" super().__init__()\n", | |
" torch.manual_seed(0)\n", | |
" self.attention_heads = nn.ModuleList()\n", | |
" for i in range(num_heads):\n", | |
" self.attention_heads.append(self.SingleHeadAttention(embedding_dim, attention_dim // num_heads))\n", | |
"\n", | |
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n", | |
" head_outputs = []\n", | |
" for head in self.attention_heads:\n", | |
" head_outputs.append(head(embedded))\n", | |
" concatenated = torch.cat(head_outputs, dim = 2)\n", | |
" return torch.round(concatenated, decimals=4)\n", | |
"\n", | |
" class SingleHeadAttention(nn.Module):\n", | |
" def __init__(self, embedding_dim: int, attention_dim: int):\n", | |
" super().__init__()\n", | |
" torch.manual_seed(0)\n", | |
" self.key_gen = nn.Linear(embedding_dim, attention_dim, bias=False)\n", | |
" self.query_gen = nn.Linear(embedding_dim, attention_dim, bias=False)\n", | |
" self.value_gen = nn.Linear(embedding_dim, attention_dim, bias=False)\n", | |
"\n", | |
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n", | |
" k = self.key_gen(embedded)\n", | |
" q = self.query_gen(embedded)\n", | |
" v = self.value_gen(embedded)\n", | |
"\n", | |
" scores = q @ torch.transpose(k, 1, 2) # @ is the same as torch.matmul()\n", | |
" context_length, attention_dim = k.shape[1], k.shape[2]\n", | |
" scores = scores / (attention_dim ** 0.5)\n", | |
"\n", | |
" lower_triangular = torch.tril(torch.ones(context_length, context_length))\n", | |
" mask = lower_triangular == 0\n", | |
" scores = scores.masked_fill(mask, float('-inf'))\n", | |
" scores = nn.functional.softmax(scores, dim = 2)\n", | |
"\n", | |
" return scores @ v" | |
], | |
"metadata": { | |
"id": "tp5Oxy5V7q35" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"#### Refactored solution\n", | |
"* define `SingleHeadAttention` as an independent top-level class instead of an inner class of `MultiHeadedSelfAttention`: the looser coupling makes it easier to reuse, test, and extend on its own;\n", | |
"* add a Boolean parameter to the constructor to decide whether the output should be rounded;\n", | |
"* wrap repeated instructions into separate functions for better readability and reusability;\n", | |
"* use [generators][term-generator] to produce the head outputs one at a time instead of appending them to a list first (`torch.cat` still needs a sequence, hence the single `list(...)` call);\n", | |
"* after concatenating the heads, apply an output projection (`nn.Linear(model_dim, model_dim)`) and dropout, as in the original Transformer architecture;\n", | |
"\n", | |
"[term-generator]: https://docs.python.org/3/glossary.html#term-generator\n" | |
], | |
"metadata": { | |
"id": "4vKZznxR7q37" | |
} | |
}, | |
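In `MyMultiHeadedSelfAttention`, each head gets `head_size = model_dim // num_heads` columns, and the concatenated output is fed to `nn.Linear(model_dim, model_dim)`. That only lines up when `model_dim` is divisible by `num_heads`. Otherwise, floor division silently drops columns, and the projection receives a narrower tensor than it expects, which PyTorch reports as a shape-mismatch error. The sketch below shows a small guard; `concat_width` and `check_heads` are hypothetical helpers.

```python
def concat_width(model_dim: int, num_heads: int) -> int:
    """Width obtained by concatenating num_heads heads of size model_dim // num_heads."""
    head_size = model_dim // num_heads
    # a generator expression, mirroring get_head_outputs(): one width per head
    return sum(head_size for _ in range(num_heads))


def check_heads(model_dim: int, num_heads: int) -> None:
    width = concat_width(model_dim, num_heads)
    if width != model_dim:
        raise ValueError(f"model_dim={model_dim} is not divisible by num_heads={num_heads}: "
                         f"heads concatenate to {width} columns, projection expects {model_dim}")


check_heads(252, 6)  # the hyperparameters used later in this notebook: 6 heads of 42
print(concat_width(252, 6), concat_width(250, 6))  # 252 246
```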
{ | |
"cell_type": "code", | |
"source": [ | |
"class MyMultiHeadedSelfAttention(nn.Module):\n", | |
"\n", | |
" def __init__(self, model_dim: int, num_heads: int, round_output: bool=False):\n", | |
" super().__init__()\n", | |
" # torch.manual_seed(0)\n", | |
" head_size = model_dim // num_heads\n", | |
" self.attention_heads = nn.ModuleList()\n", | |
" for _ in range(num_heads):\n", | |
" self.attention_heads.append(MySingleHeadAttention(model_dim, head_size))\n", | |
" self.compute = nn.Linear(model_dim, model_dim)\n", | |
" self.dropout = nn.Dropout(p=DROPOUT_PROBABILITY)\n", | |
" self.round_output = round_output\n", | |
"\n", | |
" def get_head_outputs(self, embedded):\n", | |
" for att_head in self.attention_heads:\n", | |
" yield att_head(embedded)\n", | |
"\n", | |
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n", | |
" head_outputs = self.get_head_outputs(embedded)\n", | |
" ret = torch.cat(list(head_outputs), dim=2)\n", | |
" ret = self.compute(ret)\n", | |
" ret = self.dropout(ret)\n", | |
" ret = torch.round(ret, decimals=DECIMAL_PRECISION) if self.round_output else ret\n", | |
" return ret" | |
], | |
"metadata": { | |
"id": "SkEbE7T17q3-" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### 1.3. Transformer block class" | |
], | |
"metadata": { | |
"id": "QGZjrjHW8IRg" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"#### NeetCode solution" | |
], | |
"metadata": { | |
"id": "MFMwmN7K8IRq" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# NeetCode solution: https://neetcode.io/problems/transformer-block\n", | |
"class TransformerBlock(nn.Module):\n", | |
"\n", | |
" def __init__(self, model_dim: int, num_heads: int):\n", | |
" super().__init__()\n", | |
" torch.manual_seed(0)\n", | |
" self.mhsa = self.MultiHeadedSelfAttention(model_dim, num_heads)\n", | |
" self.vanilla_nn = self.VanillaNeuralNetwork(model_dim)\n", | |
" self.layer_norm_one = nn.LayerNorm(model_dim)\n", | |
" self.layer_norm_two = nn.LayerNorm(model_dim)\n", | |
"\n", | |
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n", | |
" # Round answer to 4 decimal places\n", | |
" torch.manual_seed(0)\n", | |
" embedded = embedded + self.mhsa(self.layer_norm_one(embedded)) # skip connection\n", | |
" embedded = embedded + self.vanilla_nn(self.layer_norm_two(embedded)) # another skip connection\n", | |
" return torch.round(embedded, decimals=4)\n", | |
"\n", | |
"\n", | |
" class MultiHeadedSelfAttention(nn.Module):\n", | |
"\n", | |
" class SingleHeadAttention(nn.Module):\n", | |
" def __init__(self, model_dim: int, head_size: int):\n", | |
" super().__init__()\n", | |
" torch.manual_seed(0)\n", | |
" self.key_gen = nn.Linear(model_dim, head_size, bias=False)\n", | |
" self.query_gen = nn.Linear(model_dim, head_size, bias=False)\n", | |
" self.value_gen = nn.Linear(model_dim, head_size, bias=False)\n", | |
"\n", | |
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n", | |
" k = self.key_gen(embedded)\n", | |
" q = self.query_gen(embedded)\n", | |
" v = self.value_gen(embedded)\n", | |
"\n", | |
" scores = q @ torch.transpose(k, 1, 2) # @ is the same as torch.matmul()\n", | |
" context_length, attention_dim = k.shape[1], k.shape[2]\n", | |
" scores = scores / (attention_dim ** 0.5)\n", | |
"\n", | |
" lower_triangular = torch.tril(torch.ones(context_length, context_length))\n", | |
" mask = lower_triangular == 0\n", | |
" scores = scores.masked_fill(mask, float('-inf'))\n", | |
" scores = nn.functional.softmax(scores, dim = 2)\n", | |
"\n", | |
" return scores @ v\n", | |
"\n", | |
" def __init__(self, model_dim: int, num_heads: int):\n", | |
" super().__init__()\n", | |
" torch.manual_seed(0)\n", | |
" self.attention_heads = nn.ModuleList()\n", | |
" for i in range(num_heads):\n", | |
" self.attention_heads.append(self.SingleHeadAttention(model_dim, model_dim // num_heads))\n", | |
"\n", | |
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n", | |
" head_outputs = []\n", | |
" for head in self.attention_heads:\n", | |
" head_outputs.append(head(embedded))\n", | |
" concatenated = torch.cat(head_outputs, dim = 2)\n", | |
" return concatenated\n", | |
"\n", | |
" class VanillaNeuralNetwork(nn.Module):\n", | |
"\n", | |
" def __init__(self, model_dim: int):\n", | |
" super().__init__()\n", | |
" torch.manual_seed(0)\n", | |
" self.first_linear_layer = nn.Linear(model_dim, model_dim * 4)\n", | |
" self.relu = nn.ReLU()\n", | |
" self.second_linear_layer = nn.Linear(model_dim * 4, model_dim)\n", | |
" self.dropout = nn.Dropout(0.2) # using p = 0.2\n", | |
"\n", | |
" def forward(self, x: TensorType[float]) -> TensorType[float]:\n", | |
" torch.manual_seed(0)\n", | |
" return self.dropout(self.second_linear_layer(self.relu(self.first_linear_layer(x))))" | |
], | |
"metadata": { | |
"id": "c8YatG_r8IRt" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"#### Refactored solution\n", | |
"* define each class (`MyMultiHeadedSelfAttention`, `MyVanillaNeuralNetwork`) at top level instead of nesting it inside the transformer block: the looser coupling makes each one easier to reuse, test, and extend;\n", | |
"* add a Boolean parameter to the constructor to decide whether the output should be rounded;\n", | |
"* shorter and less complex instructions per line of code (which is especially handy for testing);\n", | |
"* no “magic” numbers or strings: use constants instead of hardcoded values;\n" | |
], | |
"metadata": { | |
"id": "zx3pgmiF8IRv" | |
} | |
}, | |
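Both versions use the "pre-norm" arrangement. Each sub-layer sees a layer-normalized copy of its input, and its output is added back to the un-normalized input (the skip connection). The pure-Python sketch below shows this data flow for a single token vector. Two simplifications apply: `layer_norm` omits the learnable scale and shift of `nn.LayerNorm`, and the sub-layers passed in are hypothetical stand-ins for attention and the feed-forward network.

```python
import math
from typing import Callable, List

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Normalize to zero mean and unit variance (no learnable gain/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance, as nn.LayerNorm uses
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def pre_norm_block(x: Vector, attention: Callable[[Vector], Vector],
                   feed_forward: Callable[[Vector], Vector]) -> Vector:
    # x = x + attention(norm1(x)); a new list is built, the input is never mutated
    x = [a + b for a, b in zip(x, attention(layer_norm(x)))]
    # x = x + feed_forward(norm2(x))
    x = [a + b for a, b in zip(x, feed_forward(layer_norm(x)))]
    return x


def halve(v: Vector) -> Vector:
    return [0.5 * e for e in v]


token = [1.0, 2.0, 3.0, 4.0]
print(pre_norm_block(token, halve, halve))
print(token)  # unchanged: the residual updates are out of place
```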
{ | |
"cell_type": "code", | |
"source": [ | |
"class MyTransformerBlock(nn.Module):\n", | |
"\n", | |
" def __init__(self, model_dim: int, num_heads: int, round_output: bool=False):\n", | |
" super().__init__()\n", | |
" # torch.manual_seed(0)\n", | |
" self.mhsa = MyMultiHeadedSelfAttention(model_dim, num_heads)\n", | |
" self.vanilla_nn = MyVanillaNeuralNetwork(model_dim)\n", | |
" self.layer_norm_one = nn.LayerNorm(model_dim)\n", | |
" self.layer_norm_two = nn.LayerNorm(model_dim)\n", | |
" self.round_output = round_output\n", | |
"\n", | |
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n", | |
" # torch.manual_seed(0)\n", | |
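" embedded = embedded + self.mhsa(self.layer_norm_one(embedded)) # skip connection, out of place: += would modify a tensor saved for backward\n", | |
" embedded = embedded + self.vanilla_nn(self.layer_norm_two(embedded)) # another skip connection\n", | |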
" embedded = torch.round(embedded, decimals=DECIMAL_PRECISION) if self.round_output else embedded\n", | |
" return embedded\n", | |
"\n", | |
"\n", | |
"class MyVanillaNeuralNetwork(nn.Module):\n", | |
"\n", | |
" def __init__(self, model_dim: int):\n", | |
" super().__init__()\n", | |
" # torch.manual_seed(0)\n", | |
" self.first_linear_layer = nn.Linear(model_dim, model_dim * LINEAR_NN_SCALE_FACTOR)\n", | |
" self.relu = nn.ReLU()\n", | |
" self.second_linear_layer = nn.Linear(model_dim * LINEAR_NN_SCALE_FACTOR, model_dim)\n", | |
" self.dropout = nn.Dropout(p=DROPOUT_PROBABILITY)\n", | |
"\n", | |
" def forward(self, x: TensorType[float]) -> TensorType[float]:\n", | |
" # torch.manual_seed(0)\n", | |
" ret = self.first_linear_layer(x)\n", | |
" ret = self.relu(ret)\n", | |
" ret = self.second_linear_layer(ret)\n", | |
" ret = self.dropout(ret)\n", | |
" return ret" | |
], | |
"metadata": { | |
"id": "KofH8zTU8IRy" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### 1.4. GPT model class" | |
], | |
"metadata": { | |
"id": "gSyqjSYm8Wrq" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"#### NeetCode solution" | |
], | |
"metadata": { | |
"id": "gYozYnOQ8Wru" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# NeetCode solution: https://neetcode.io/problems/code-gpt\n", | |
"class GPT(nn.Module):\n", | |
"\n", | |
" def __init__(self, vocab_size: int, context_length: int, model_dim: int, num_blocks: int, num_heads: int):\n", | |
" super().__init__()\n", | |
" torch.manual_seed(0)\n", | |
" self.word_embeddings = nn.Embedding(vocab_size, model_dim)\n", | |
" self.position_embeddings = nn.Embedding(context_length, model_dim)\n", | |
" self.transformer_blocks = nn.Sequential()\n", | |
" for i in range(num_blocks):\n", | |
" self.transformer_blocks.append(self.TransformerBlock(model_dim, num_heads))\n", | |
" self.layer_norm_three = nn.LayerNorm(model_dim)\n", | |
" self.vocab_projection = nn.Linear(model_dim, vocab_size)\n", | |
"\n", | |
" def forward(self, context: TensorType[int]) -> TensorType[float]:\n", | |
" torch.manual_seed(0)\n", | |
" embedded = self.word_embeddings(context)\n", | |
" context_length = context.shape[1]\n", | |
" positions = torch.arange(context_length)\n", | |
" embedded = embedded + self.position_embeddings(positions)\n", | |
"\n", | |
" raw_output = self.vocab_projection(self.layer_norm_three(self.transformer_blocks(embedded)))\n", | |
" # raw_output is batch by context_length by vocab_size\n", | |
"\n", | |
" probabilities = nn.functional.softmax(raw_output, dim = -1)\n", | |
" return torch.round(probabilities, decimals=4)\n", | |
"\n", | |
" class TransformerBlock(nn.Module):\n", | |
"\n", | |
" class MultiHeadedSelfAttention(nn.Module):\n", | |
"\n", | |
" class SingleHeadAttention(nn.Module):\n", | |
" def __init__(self, model_dim: int, head_size: int):\n", | |
" super().__init__()\n", | |
" torch.manual_seed(0)\n", | |
" self.key_gen = nn.Linear(model_dim, head_size, bias=False)\n", | |
" self.query_gen = nn.Linear(model_dim, head_size, bias=False)\n", | |
" self.value_gen = nn.Linear(model_dim, head_size, bias=False)\n", | |
"\n", | |
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n", | |
" k = self.key_gen(embedded)\n", | |
" q = self.query_gen(embedded)\n", | |
" v = self.value_gen(embedded)\n", | |
"\n", | |
" scores = q @ torch.transpose(k, 1, 2) # @ is the same as torch.matmul()\n", | |
" context_length, attention_dim = k.shape[1], k.shape[2]\n", | |
" scores = scores / (attention_dim ** 0.5)\n", | |
"\n", | |
" lower_triangular = torch.tril(torch.ones(context_length, context_length))\n", | |
" mask = lower_triangular == 0\n", | |
" scores = scores.masked_fill(mask, float('-inf'))\n", | |
" scores = nn.functional.softmax(scores, dim = 2)\n", | |
"\n", | |
" return scores @ v\n", | |
"\n", | |
" def __init__(self, model_dim: int, num_heads: int):\n", | |
" super().__init__()\n", | |
" torch.manual_seed(0)\n", | |
" self.attention_heads = nn.ModuleList()\n", | |
" for i in range(num_heads):\n", | |
" self.attention_heads.append(self.SingleHeadAttention(model_dim, model_dim // num_heads))\n", | |
"\n", | |
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n", | |
" head_outputs = []\n", | |
" for head in self.attention_heads:\n", | |
" head_outputs.append(head(embedded))\n", | |
" concatenated = torch.cat(head_outputs, dim = 2)\n", | |
" return concatenated\n", | |
"\n", | |
" class VanillaNeuralNetwork(nn.Module):\n", | |
"\n", | |
" def __init__(self, model_dim: int):\n", | |
" super().__init__()\n", | |
" torch.manual_seed(0)\n", | |
" self.first_linear_layer = nn.Linear(model_dim, model_dim * 4)\n", | |
" self.relu = nn.ReLU()\n", | |
" self.second_linear_layer = nn.Linear(model_dim * 4, model_dim)\n", | |
" self.dropout = nn.Dropout(0.2) # using p = 0.2\n", | |
"\n", | |
" def forward(self, x: TensorType[float]) -> TensorType[float]:\n", | |
" torch.manual_seed(0)\n", | |
" return self.dropout(self.second_linear_layer(self.relu(self.first_linear_layer(x))))\n", | |
"\n", | |
" def __init__(self, model_dim: int, num_heads: int):\n", | |
" super().__init__()\n", | |
" torch.manual_seed(0)\n", | |
" self.mhsa = self.MultiHeadedSelfAttention(model_dim, num_heads)\n", | |
" self.vanilla_nn = self.VanillaNeuralNetwork(model_dim)\n", | |
" self.layer_norm_one = nn.LayerNorm(model_dim)\n", | |
" self.layer_norm_two = nn.LayerNorm(model_dim)\n", | |
"\n", | |
" def forward(self, embedded: TensorType[float]) -> TensorType[float]:\n", | |
" torch.manual_seed(0)\n", | |
" embedded = embedded + self.mhsa(self.layer_norm_one(embedded)) # skip connection\n", | |
" embedded = embedded + self.vanilla_nn(self.layer_norm_two(embedded)) # another skip connection\n", | |
" return embedded" | |
], | |
"metadata": { | |
"id": "q0h-utP08Wrv" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"#### Refactored solution\n", | |
"* define each class (`MyTransformerBlock`, `MyMultiHeadedSelfAttention`, `MySingleHeadAttention`, `MyVanillaNeuralNetwork`) at top level instead of nesting it inside `GPT`: the looser coupling makes each one easier to reuse, test, and extend;\n", | |
"* add a Boolean parameter to the constructor to decide whether the output should be rounded;\n", | |
"* return raw logits instead of softmax probabilities: the softmax is applied at sampling time, in `Runner.generate`;\n", | |
"* wrap repeated instructions into separate functions for better readability and reusability;\n", | |
"* shorter and less complex instructions per line of code (which is especially handy for testing);\n", | |
"* no “magic” numbers or strings: use constants instead of hardcoded values;\n" | |
], | |
"metadata": { | |
"id": "09I7ov_N8Wrv" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class MyGPT(nn.Module):\n", | |
"\n", | |
" def __init__(self, vocab_size: int, context_length: int, model_dim: int, num_blocks: int, num_heads: int, round_output: bool=False):\n", | |
" super().__init__()\n", | |
" # torch.manual_seed(0)\n", | |
" self.token_embedding = nn.Embedding(vocab_size, model_dim)\n", | |
" self.pos_embedding = nn.Embedding(context_length, model_dim)\n", | |
" self.transformer_blocks = nn.Sequential()\n", | |
" for _ in range(num_blocks):\n", | |
" self.transformer_blocks.append(MyTransformerBlock(model_dim, num_heads))\n", | |
" self.layer_norm_three = nn.LayerNorm(model_dim)\n", | |
" self.vocab_projection = nn.Linear(model_dim, vocab_size)\n", | |
" self.round_output = round_output\n", | |
"\n", | |
" def get_positions(self, context, device):\n", | |
" context_len = context.shape[1]\n", | |
" return torch.arange(context_len).to(device)\n", | |
"\n", | |
" def forward(self, context: TensorType[int]) -> TensorType[float]:\n", | |
" # torch.manual_seed(0)\n", | |
" embedded = self.token_embedding(context)\n", | |
" positions = self.get_positions(context, context.device)\n", | |
" embedded += self.pos_embedding(positions)\n", | |
" ret = self.transformer_blocks(embedded)\n", | |
" ret = self.layer_norm_three(ret)\n", | |
" ret = self.vocab_projection(ret)\n", | |
" ret = torch.round(ret, decimals=DECIMAL_PRECISION) if self.round_output else ret\n", | |
" return ret" | |
], | |
"metadata": { | |
"id": "XgubHNYe8Wrw" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
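To get a feel for the model's size, the parameter count of `MyGPT` can be worked out by hand from the layer shapes. The sketch below assumes `LINEAR_NN_SCALE_FACTOR = 4`, the value hardcoded in the NeetCode version, and a `model_dim` divisible by `num_heads`. Provided those assumptions hold, the result should agree with `sum(p.numel() for p in model.parameters())` on a real instance. `gpt_param_count` is a hypothetical helper.

```python
def gpt_param_count(vocab_size: int, context_length: int, model_dim: int,
                    num_blocks: int, num_heads: int, ffn_scale: int = 4) -> int:
    d = model_dim
    head_size = d // num_heads
    embeddings = vocab_size * d + context_length * d   # token + position tables
    attention = num_heads * 3 * d * head_size          # key/query/value layers, bias=False
    projection = d * d + d                             # self.compute, with bias
    feed_forward = (d * ffn_scale * d + ffn_scale * d) + (ffn_scale * d * d + d)
    layer_norms = 2 * (2 * d)                          # two LayerNorms, weight + bias each
    per_block = attention + projection + feed_forward + layer_norms
    head = 2 * d + (d * vocab_size + vocab_size)       # final LayerNorm + vocab projection
    return embeddings + num_blocks * per_block + head


print(gpt_param_count(vocab_size=10, context_length=4, model_dim=4,
                      num_blocks=1, num_heads=2))  # 346
```

Under those assumptions, the hyperparameters used below (`context_length = 128`, `model_dim = 252`, `num_blocks = 6`, `num_heads = 6`) give 4,620,168 + 505 × `vocab_size` parameters. That is roughly 4.6 million, plus the vocabulary-dependent embedding and output layers.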
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### 1.5. Make GPT talk back" | |
], | |
"metadata": { | |
"id": "fRGhx_by8xiX" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"#### NeetCode solution" | |
], | |
"metadata": { | |
"id": "382xJUI88xic" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# NeetCode solution: https://neetcode.io/problems/make-gpt-talk-back\n", | |
"class Solution:\n", | |
" def generate(self, model, new_chars: int, context: TensorType[int], context_length: int, int_to_char: dict) -> str:\n", | |
" generator = torch.manual_seed(0)\n", | |
" initial_state = generator.get_state()\n", | |
" res = []\n", | |
" for i in range(new_chars):\n", | |
" if len(context.T) > context_length:\n", | |
" context = context[:, -context_length:]\n", | |
" prediction = model(context) # B, T, Vocab_Size\n", | |
" last_time_step = prediction[:, -1, :] # B, Vocab_Size\n", | |
" probabilities = nn.functional.softmax(last_time_step, dim = -1)\n", | |
" next_char = torch.multinomial(probabilities, 1, generator=generator)\n", | |
" generator.set_state(initial_state)\n", | |
" context = torch.cat((context, next_char), dim = -1)\n", | |
" res.append(int_to_char[next_char.item()])\n", | |
" return ''.join(res)" | |
], | |
"metadata": { | |
"id": "2aYvYuLe8xie" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"#### Refactored solution\n", | |
"* generate either random text or a continuation of a user-supplied prompt (encoded by `encode_prompt`);\n", | |
"* use a [generator][term-generator] to yield one predicted index at a time, so that characters can be printed as they are produced instead of after the whole output is complete;\n", | |
"* wrap repeated instructions into separate functions for better readability and reusability;\n", | |
"\n", | |
"[term-generator]: https://docs.python.org/3/glossary.html#term-generator\n" | |
], | |
"metadata": { | |
"id": "0-kTjPo18xig" | |
} | |
}, | |
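In both versions, each generation step comes down to four operations:

1. Keep at most `context_length` tokens.
2. Run the model.
3. Turn the last position's logits into probabilities with softmax.
4. Draw one index from that distribution (`torch.multinomial`).

The dependency-free sketch below covers the trimming and sampling steps, with `random.Random.choices` standing in for `torch.multinomial`. The function names are illustrative.

```python
import math
import random
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def left_trim(context: List[int], context_length: int) -> List[int]:
    # keep only the most recent tokens, like Runner.left_trim_context
    return context[-context_length:] if len(context) > context_length else context


def sample_next(logits: Sequence[float], rng: random.Random) -> int:
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)
print(left_trim([5, 6, 7, 8, 9], 3))        # [7, 8, 9]
print(sample_next([0.0, 0.0, 100.0], rng))  # almost surely 2
```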
{ | |
"cell_type": "code", | |
"source": [ | |
"from typing import Iterator\n", | |
"\n", | |
"class Runner:\n", | |
"\n", | |
" def encode_prompt(self, prompt: str, words_index: tuple, device: torch.device) -> TensorType[int]:\n", | |
" # start from index 0, the same seed token generate() uses when no context is given\n", | |
" context = torch.zeros(1, 1, dtype=torch.int64, device=device)\n", | |
" for char_to_encode in prompt:\n", | |
" char_index = words_index.index(char_to_encode) # raises ValueError for out-of-vocabulary characters\n", | |
" next_index = torch.tensor([[char_index]], dtype=torch.int64, device=device)\n", | |
" context = torch.cat((context, next_index), dim=-1)\n", | |
" return context\n", | |
"\n", | |
" def left_trim_context(self, context: TensorType[int], context_length: int) -> TensorType[int]:\n", | |
" # keep only the most recent context_length tokens: there are no position embeddings beyond that\n", | |
" if context.shape[1] > context_length:\n", | |
" context = context[:, -context_length:]\n", | |
" return context\n", | |
"\n", | |
" @torch.no_grad()\n", | |
" def generate(self, model, output_len: int, context_length: int, context: TensorType[int]=None) -> Iterator[int]:\n", | |
" if context is None:\n", | |
" device = next(model.parameters()).device\n", | |
" context = torch.zeros(1, 1, dtype=torch.int64, device=device)\n", | |
" for _ in range(output_len):\n", | |
" context = self.left_trim_context(context, context_length)\n", | |
" prediction = model(context) # B x T x vocab_size (raw logits)\n", | |
" last_time_step = prediction[:, -1, :] # B x vocab_size\n", | |
" probabilities = nn.functional.softmax(last_time_step, dim=-1)\n", | |
" next_predicted_index = torch.multinomial(probabilities, 1)\n", | |
" context = torch.cat((context, next_predicted_index), dim=-1)\n", | |
" yield next_predicted_index.item()" | |
], | |
"metadata": { | |
"id": "82wwnK_U8xii" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
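`encode_prompt` looks up each character with `tuple.index`, which scans the vocabulary linearly for every character. A dictionary built once gives constant-time lookups and makes the inverse mapping explicit. The sketch below assumes that `WORDS_INDEX` is a tuple of single characters whose position is the token index. It keeps the leading index-0 token that `encode_prompt` prepends. The vocabulary and helper names are made up for illustration.

```python
from typing import Dict, List, Sequence


def build_lookup(words_index: Sequence[str]) -> Dict[str, int]:
    # char -> index; the inverse mapping is simply words_index[i]
    return {char: i for i, char in enumerate(words_index)}


def encode(prompt: str, lookup: Dict[str, int]) -> List[int]:
    # leading 0 mirrors the torch.zeros(1, 1) start token in encode_prompt
    return [0] + [lookup[char] for char in prompt]  # KeyError for unknown characters


def decode(indices: Sequence[int], words_index: Sequence[str]) -> str:
    return "".join(words_index[i] for i in indices)


vocab = ("\n", " ", "I", "a", "b", "n", "o", "r", "s", "t", "w")  # hypothetical vocabulary
lookup = build_lookup(vocab)
ids = encode("I was born to ", lookup)
print(ids)
print(repr(decode(ids[1:], vocab)))  # 'I was born to '
```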
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### 1.6. Download and plug in the pre-trained model" | |
], | |
"metadata": { | |
"id": "vX6JVta_tuCm" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%cd /content\n", | |
"!git clone https://github.com/gptandchill/drake-lyric-generator\n", | |
"%cd drake-lyric-generator\n", | |
"%ls" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "P8wDHQmot0M2", | |
"outputId": "214f6211-303c-4ed1-e903-d61dec1bd142" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"/content\n", | |
"Cloning into 'drake-lyric-generator'...\n", | |
"remote: Enumerating objects: 3, done.\u001b[K\n", | |
"remote: Counting objects: 100% (3/3), done.\u001b[K\n", | |
"remote: Compressing objects: 100% (2/2), done.\u001b[K\n", | |
"remote: Total 3 (delta 0), reused 0 (delta 0), pack-reused 0\u001b[K\n", | |
"Receiving objects: 100% (3/3), 16.53 MiB | 23.51 MiB/s, done.\n", | |
"/content/drake-lyric-generator\n", | |
"weights.pt\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
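"Define the hyperparameters, instantiate the model, and load the weights from training (the previous cell downloaded `weights.pt` into this Colab runtime). The hyperparameters must match those used to train the checkpoint; otherwise `load_state_dict` fails with size-mismatch errors." | |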
], | |
"metadata": { | |
"id": "ckqEPmJUsnfF" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"path_to_pre_trained_model = 'weights.pt'\n", | |
"vocab_size = len(WORDS_INDEX)\n", | |
"context_length = 128\n", | |
"model_dim = 252\n", | |
"num_blocks = 6\n", | |
"num_heads = 6\n", | |
"\n", | |
"model = MyGPT(vocab_size, context_length, model_dim, num_blocks, num_heads).to(DEVICE)\n", | |
"pre_trained_model_state_dict = torch.load(path_to_pre_trained_model, map_location=DEVICE, weights_only=True)\n", | |
"model.load_state_dict(pre_trained_model_state_dict)\n", | |
"_ = model.eval()" | |
], | |
"metadata": { | |
"id": "aBPCi79SskEn" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### 1.7. Run model to generate lyrics" | |
], | |
"metadata": { | |
"id": "Sijuqs-xMi1n" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"#### 1.7.1. Generate random lyrics" | |
], | |
"metadata": { | |
"id": "OmYfLP60tPrS" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"output_len = 150\n", | |
"for next_predicted_index in Runner().generate(model, output_len, context_length):\n", | |
" next_char = WORDS_INDEX[next_predicted_index]\n", | |
" print(next_char, end='', flush=True)" | |
], | |
"metadata": { | |
"id": "qpwXg0YitK0w", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "17142ad0-6c91-413c-878c-812ba67b55fe" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"You don't know that\n", | |
"Put the was fuckin' attien\n", | |
"And I got still like a mil with the fireworks\n", | |
"Tatin' wifey, they don't even regrin' it so like my kill " | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"#### 1.7.2. Generate lyrics as a follow-up to a prompt" | |
], | |
"metadata": { | |
"id": "zscxLW_70gex" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"output_len = 150\n", | |
"prompt = \"I was born to \"\n", | |
"context = Runner().encode_prompt(prompt, WORDS_INDEX, DEVICE)\n", | |
"print(prompt, end='')\n", | |
"for next_predicted_index in Runner().generate(model, output_len, context_length, context):\n", | |
" next_char = WORDS_INDEX[next_predicted_index]\n", | |
" print(next_char, end='', flush=True)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "ff7c5533-e5a7-43bb-f4b6-0a2e19d4297f", | |
"id": "SsrBY2Wb0gez" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"I was born to anothough, got the just like I love you s\n", | |
"It's all your soldier flow it\n", | |
"Uppy, now after tryna distage is to getting if a madrista\n", | |
"There's so smoker fe" | |
] | |
} | |
] | |
} | |
] | |
} |