Created
August 2, 2024 14:42
-
-
Save GVRV/24b2af70d3b14409d9a3192a35122cd6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "d25adea8-7d12-4ca1-9beb-a31820cc01d2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"with open('names.txt', 'r') as f:\n", | |
" names = f.read().splitlines()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "4a7c0ab3-1efb-490f-9df4-32f60b49ae97", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"trigram = {}\n", | |
"for name in names:\n", | |
" chars = '.' + name + '.'\n", | |
" for ch1, ch2, ch3 in zip(chars, chars[1:], chars[2:]):\n", | |
" b = (ch1, ch2, ch3)\n", | |
" trigram[b] = trigram.get(b, 0) + 1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "49810a94-76a5-4d6e-a1b9-8c0aac217612", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[(('a', 'h', '.'), 1714),\n", | |
" (('n', 'a', '.'), 1673),\n", | |
" (('a', 'n', '.'), 1509),\n", | |
" (('o', 'n', '.'), 1503),\n", | |
" (('.', 'm', 'a'), 1453),\n", | |
" (('.', 'j', 'a'), 1255),\n", | |
" (('.', 'k', 'a'), 1254),\n", | |
" (('e', 'n', '.'), 1217),\n", | |
" (('l', 'y', 'n'), 976),\n", | |
" (('y', 'n', '.'), 953)]" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sorted(trigram.items(), key=lambda kv: -kv[1])[:10]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "02e3fd83-670e-47b0-9d67-68831881136a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "09a079ed-aa4b-4077-85d9-15c4ee582884", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"tokens = ['.'] + sorted(list(set(''.join(names))))\n", | |
"stoi = {c: i for i, c in enumerate(tokens)}\n", | |
"itos = {i: c for c, i in stoi.items()}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"id": "af09e2f8-f85d-4266-af28-ffccf192094d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"trigram_tensor = torch.zeros((27 * 27, 27), dtype=torch.int32)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"id": "f89fff87-a89b-4e91-b8d1-8c5ef9083a43", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"for name in names:\n", | |
" chars = '.' + name + '.'\n", | |
" for ch1, ch2, ch3 in zip(chars, chars[1:], chars[2:]):\n", | |
" trigram_tensor[(27 * stoi[ch1]) + stoi[ch2], stoi[ch3]] += 1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"id": "09a4b656-bca4-4974-a517-2fd01e0d9a85", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[203, 337, 0, 0, 0, 331, 0, 0, 0, 271, 0, 0, 0, 1,\n", | |
" 0, 58, 0, 0, 0, 8, 0, 10, 0, 1, 0, 124, 1],\n", | |
" [ 6, 32, 0, 0, 0, 7, 0, 0, 0, 10, 0, 0, 0, 0,\n", | |
" 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0]],\n", | |
" dtype=torch.int32)" | |
] | |
}, | |
"execution_count": 36, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"trigram_tensor[336:338,:]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "12f3a432-6470-411d-8ffc-3691e29749f8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 94, | |
"id": "d882349b-272b-470e-b686-4457e4157257", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"prob_trigram_tensor = (trigram_tensor + 1).float()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 95, | |
"id": "389c2cba-2596-4768-9b26-4378a6d020c4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[204., 338., 1., 1., 1., 332., 1., 1., 1., 272., 1., 1.,\n", | |
" 1., 2., 1., 59., 1., 1., 1., 9., 1., 11., 1., 2.,\n", | |
" 1., 125., 2.],\n", | |
" [ 7., 33., 1., 1., 1., 8., 1., 1., 1., 11., 1., 1.,\n", | |
" 1., 1., 1., 3., 1., 1., 1., 2., 1., 1., 1., 1.,\n", | |
" 1., 3., 1.]])" | |
] | |
}, | |
"execution_count": 95, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"prob_trigram_tensor[336:338,:]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 96, | |
"id": "0dc87947-69f7-4b62-8bf5-413fe3a9aca1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"prob_trigram_tensor /= prob_trigram_tensor.sum(1, keepdims=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 97, | |
"id": "d127fd7f-1211-43ae-a7f4-715b32b9f99a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[0.1487, 0.2464, 0.0007, 0.0007, 0.0007, 0.2420, 0.0007, 0.0007, 0.0007,\n", | |
" 0.1983, 0.0007, 0.0007, 0.0007, 0.0015, 0.0007, 0.0430, 0.0007, 0.0007,\n", | |
" 0.0007, 0.0066, 0.0007, 0.0080, 0.0007, 0.0015, 0.0007, 0.0911, 0.0015],\n", | |
" [0.0805, 0.3793, 0.0115, 0.0115, 0.0115, 0.0920, 0.0115, 0.0115, 0.0115,\n", | |
" 0.1264, 0.0115, 0.0115, 0.0115, 0.0115, 0.0115, 0.0345, 0.0115, 0.0115,\n", | |
" 0.0115, 0.0230, 0.0115, 0.0115, 0.0115, 0.0115, 0.0115, 0.0345, 0.0115]])" | |
] | |
}, | |
"execution_count": 97, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"prob_trigram_tensor[336:338, :]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 98, | |
"id": "e076a9b5-407c-4e41-a1dd-70266e9208fc", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"random_trigram_tensor = torch.ones(27 * 27, 27).float()\n", | |
"random_trigram_tensor /= random_trigram_tensor.sum(1, keepdims=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 103, | |
"id": "7615ee06-5203-4bc9-8cd8-a541fcc70193", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"ce.\n", | |
"za.\n", | |
"zogh.\n", | |
"uriana.\n", | |
"kaydnevonimittain.\n", | |
"luwak.\n", | |
"ka.\n", | |
"da.\n", | |
"samiyah.\n", | |
"javer.\n" | |
] | |
} | |
], | |
"source": [ | |
"g = torch.Generator().manual_seed(2147483647)\n", | |
"for i in range(10):\n", | |
" out = []\n", | |
" ix = 0\n", | |
" while True:\n", | |
" ich1 = int(ix / 27)\n", | |
" ich2 = int(ix % 27)\n", | |
" ch1 = itos[ich1]\n", | |
" ch2 = itos[ich2]\n", | |
" # print(f\"Probability distribution for '{ch1}{ch2}': {prob_trigram_tensor[ix, :]}\")\n", | |
" p = prob_trigram_tensor[ix]\n", | |
" ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()\n", | |
" out.append(itos[ix])\n", | |
" if ix % 27 == 0:\n", | |
" break\n", | |
" ix = (ich2 * 27) + ix\n", | |
" print(''.join(out))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 104, | |
"id": "96af7103-7742-4160-9a28-320e9cc9e794", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"log_likelihood=tensor(-410414.9688)\n", | |
"nll=tensor(410414.9688)\n", | |
"2.092747449874878\n" | |
] | |
} | |
], | |
"source": [ | |
"log_likelihood = 0.0\n", | |
"n = 0\n", | |
"\n", | |
"for w in names:\n", | |
"# for w in [\"andrejq\"]:\n", | |
" chs = ['.'] + list(w) + ['.']\n", | |
" for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):\n", | |
" ix1 = stoi[ch1]\n", | |
" ix2 = stoi[ch2]\n", | |
" ix3 = stoi[ch3]\n", | |
" prob = prob_trigram_tensor[(ix1 * 27) + ix2, ix3]\n", | |
" logprob = torch.log(prob)\n", | |
" log_likelihood += logprob\n", | |
" n += 1\n", | |
" # print(f'{ch1}{ch2}: {prob:.4f} {logprob:.4f}')\n", | |
"\n", | |
"print(f'{log_likelihood=}')\n", | |
"nll = -log_likelihood\n", | |
"print(f'{nll=}')\n", | |
"print(f'{nll/n}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 105, | |
"id": "c2cda117-8d12-41a1-b179-e876db1ccebb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"number of examples: 196113\n" | |
] | |
} | |
], | |
"source": [ | |
"# create the dataset\n", | |
"xs, ys = [], []\n", | |
"for w in names:\n", | |
" chs = ['.'] + list(w) + ['.']\n", | |
" for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):\n", | |
" ix1 = stoi[ch1]\n", | |
" ix2 = stoi[ch2]\n", | |
" ix3 = stoi[ch3]\n", | |
" xs.append((ix1 * 27) + ix2)\n", | |
" ys.append(ix3)\n", | |
"xs = torch.tensor(xs)\n", | |
"ys = torch.tensor(ys)\n", | |
"num = xs.nelement()\n", | |
"print('number of examples: ', num)\n", | |
"\n", | |
"# initialize the 'network'\n", | |
"g = torch.Generator().manual_seed(2147483647)\n", | |
"W = torch.randn((27 * 27, 27), generator=g, requires_grad=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 111, | |
"id": "3aa90b13-011f-42b9-9373-769601997158", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch.nn.functional as F\n", | |
"# gradient descent\n", | |
"for k in range(1000):\n", | |
" \n", | |
" # forward pass\n", | |
" xenc = F.one_hot(xs, num_classes=27*27).float() # input to the network: one-hot encoding\n", | |
" logits = xenc @ W # predict log-counts\n", | |
" counts = logits.exp() # counts, equivalent to N\n", | |
" probs = counts / counts.sum(1, keepdims=True) # probabilities for next character\n", | |
" loss = -probs[torch.arange(num), ys].log().mean() # + 0.01*(W**2).mean()\n", | |
" # print(loss.item())\n", | |
" \n", | |
" # backward pass\n", | |
" W.grad = None # set to zero the gradient\n", | |
" loss.backward()\n", | |
" \n", | |
" # update\n", | |
" W.data += -50 * W.grad" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 112, | |
"id": "8f470d40-9d50-4d81-af67-b2532c12fd82", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2.0905237197875977\n" | |
] | |
} | |
], | |
"source": [ | |
"print(loss.item())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 113, | |
"id": "39f675a0-bcd7-4de8-8afa-58e15b787892", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"zexza.\n", | |
"zoganuriana.\n", | |
"otah.\n", | |
"oll.\n", | |
"imittain.\n" | |
] | |
} | |
], | |
"source": [ | |
"# finally, sample from the 'neural net' model\n", | |
"g = torch.Generator().manual_seed(2147483647)\n", | |
"\n", | |
"for i in range(5):\n", | |
" \n", | |
" out = []\n", | |
" ix = 0\n", | |
" while True:\n", | |
" \n", | |
" # ----------\n", | |
" # BEFORE:\n", | |
" #p = P[ix]\n", | |
" # ----------\n", | |
" # NOW:\n", | |
" xenc = F.one_hot(torch.tensor([ix]), num_classes=27 * 27).float()\n", | |
" logits = xenc @ W # predict log-counts\n", | |
" counts = logits.exp() # counts, equivalent to N\n", | |
" p = counts / counts.sum(1, keepdims=True) # probabilities for next character\n", | |
" # ----------\n", | |
" ich1 = int(ix / 27)\n", | |
" ich2 = ix % 27\n", | |
" ch1 = itos[ich1]\n", | |
" ch2 = itos[ich2]\n", | |
" # print(f\"Probability distribution for '{ch1}{ch2}': {p}\")\n", | |
" ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()\n", | |
" out.append(itos[ix])\n", | |
" if ix % 27 == 0:\n", | |
" break\n", | |
" ix = (ich2 * 27) + ix\n", | |
" print(''.join(out))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "790a81f7-2c9b-4638-bde8-e44729d3bec0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.12.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment