Skip to content

Instantly share code, notes, and snippets.

@GVRV
Created August 2, 2024 14:42
Show Gist options
  • Save GVRV/24b2af70d3b14409d9a3192a35122cd6 to your computer and use it in GitHub Desktop.
Save GVRV/24b2af70d3b14409d9a3192a35122cd6 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d25adea8-7d12-4ca1-9beb-a31820cc01d2",
"metadata": {},
"outputs": [],
"source": [
"# Training corpus: one name per line; splitlines() strips the line endings.\n",
"with open('names.txt', mode='r') as f:\n",
"    names = f.read().splitlines()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4a7c0ab3-1efb-490f-9df4-32f60b49ae97",
"metadata": {},
"outputs": [],
"source": [
"# Count every (ch1, ch2, ch3) trigram; '.' pads the start and end of each name.\n",
"trigram = {}\n",
"for name in names:\n",
"    chars = '.' + name + '.'\n",
"    for ch1, ch2, ch3 in zip(chars, chars[1:], chars[2:]):\n",
"        b = (ch1, ch2, ch3)\n",
"        trigram[b] = trigram.get(b, 0) + 1"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "49810a94-76a5-4d6e-a1b9-8c0aac217612",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(('a', 'h', '.'), 1714),\n",
" (('n', 'a', '.'), 1673),\n",
" (('a', 'n', '.'), 1509),\n",
" (('o', 'n', '.'), 1503),\n",
" (('.', 'm', 'a'), 1453),\n",
" (('.', 'j', 'a'), 1255),\n",
" (('.', 'k', 'a'), 1254),\n",
" (('e', 'n', '.'), 1217),\n",
" (('l', 'y', 'n'), 976),\n",
" (('y', 'n', '.'), 953)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The ten most frequent trigrams, most common first.\n",
"sorted(trigram.items(), key=lambda kv: kv[1], reverse=True)[:10]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "02e3fd83-670e-47b0-9d67-68831881136a",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "09a079ed-aa4b-4077-85d9-15c4ee582884",
"metadata": {},
"outputs": [],
"source": [
"# Vocabulary: '.' (start/end marker) at index 0, then every distinct character in the data, sorted.\n",
"tokens = ['.'] + sorted(set(''.join(names)))\n",
"stoi = {c: i for i, c in enumerate(tokens)}\n",
"itos = {i: c for c, i in stoi.items()}\n",
"# Later cells hard-code a vocabulary of 27 tokens (27 * 27 context rows); fail fast if the data disagrees.\n",
"assert len(tokens) == 27, f'expected 27 tokens, got {len(tokens)}'"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "af09e2f8-f85d-4266-af28-ffccf192094d",
"metadata": {},
"outputs": [],
"source": [
"# Count matrix: row 27 * stoi[ch1] + stoi[ch2] is the (ch1, ch2) context, column stoi[ch3] the next character.\n",
"trigram_tensor = torch.zeros(27 * 27, 27, dtype=torch.int32)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "f89fff87-a89b-4e91-b8d1-8c5ef9083a43",
"metadata": {},
"outputs": [],
"source": [
"# Reset first so that re-running this cell does not double the counts.\n",
"trigram_tensor.zero_()\n",
"for name in names:\n",
"    chars = '.' + name + '.'\n",
"    for ch1, ch2, ch3 in zip(chars, chars[1:], chars[2:]):\n",
"        # flatten the (ch1, ch2) context into a single row index\n",
"        trigram_tensor[(27 * stoi[ch1]) + stoi[ch2], stoi[ch3]] += 1"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "09a4b656-bca4-4974-a517-2fd01e0d9a85",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[203, 337, 0, 0, 0, 331, 0, 0, 0, 271, 0, 0, 0, 1,\n",
" 0, 58, 0, 0, 0, 8, 0, 10, 0, 1, 0, 124, 1],\n",
" [ 6, 32, 0, 0, 0, 7, 0, 0, 0, 10, 0, 0, 0, 0,\n",
" 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0]],\n",
" dtype=torch.int32)"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inspect two consecutive context rows of the raw count matrix.\n",
"trigram_tensor[336:338]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "12f3a432-6470-411d-8ffc-3691e29749f8",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 94,
"id": "d882349b-272b-470e-b686-4457e4157257",
"metadata": {},
"outputs": [],
"source": [
"# Add-one smoothing so that no trigram ends up with zero probability; float32 for normalisation.\n",
"prob_trigram_tensor = (trigram_tensor + 1).to(torch.float32)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"id": "389c2cba-2596-4768-9b26-4378a6d020c4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[204., 338., 1., 1., 1., 332., 1., 1., 1., 272., 1., 1.,\n",
" 1., 2., 1., 59., 1., 1., 1., 9., 1., 11., 1., 2.,\n",
" 1., 125., 2.],\n",
" [ 7., 33., 1., 1., 1., 8., 1., 1., 1., 11., 1., 1.,\n",
" 1., 1., 1., 3., 1., 1., 1., 2., 1., 1., 1., 1.,\n",
" 1., 3., 1.]])"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The same two rows after smoothing: every entry is count + 1.\n",
"prob_trigram_tensor[336:338]"
]
},
{
"cell_type": "code",
"execution_count": 96,
"id": "0dc87947-69f7-4b62-8bf5-413fe3a9aca1",
"metadata": {},
"outputs": [],
"source": [
"# Turn each row of smoothed counts into a probability distribution over the next character.\n",
"prob_trigram_tensor /= prob_trigram_tensor.sum(dim=1, keepdim=True)"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "d127fd7f-1211-43ae-a7f4-715b32b9f99a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.1487, 0.2464, 0.0007, 0.0007, 0.0007, 0.2420, 0.0007, 0.0007, 0.0007,\n",
" 0.1983, 0.0007, 0.0007, 0.0007, 0.0015, 0.0007, 0.0430, 0.0007, 0.0007,\n",
" 0.0007, 0.0066, 0.0007, 0.0080, 0.0007, 0.0015, 0.0007, 0.0911, 0.0015],\n",
" [0.0805, 0.3793, 0.0115, 0.0115, 0.0115, 0.0920, 0.0115, 0.0115, 0.0115,\n",
" 0.1264, 0.0115, 0.0115, 0.0115, 0.0115, 0.0115, 0.0345, 0.0115, 0.0115,\n",
" 0.0115, 0.0230, 0.0115, 0.0115, 0.0115, 0.0115, 0.0115, 0.0345, 0.0115]])"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The same two rows, now normalised so each sums to 1.\n",
"prob_trigram_tensor[336:338]"
]
},
{
"cell_type": "code",
"execution_count": 98,
"id": "e076a9b5-407c-4e41-a1dd-70266e9208fc",
"metadata": {},
"outputs": [],
"source": [
"# Uniform reference model: every row assigns probability 1/27 to each next character.\n",
"random_trigram_tensor = torch.ones(27 * 27, 27, dtype=torch.float32)\n",
"random_trigram_tensor /= random_trigram_tensor.sum(dim=1, keepdim=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7615ee06-5203-4bc9-8cd8-a541fcc70193",
"metadata": {},
"outputs": [],
"source": [
"g = torch.Generator().manual_seed(2147483647)\n",
"\n",
"# The ('.', '.') context (row 0) never occurs in the counts: every name is padded as\n",
"# '.' + name + '.', so its first trigram is ('.', name[0], name[1]). Row 0 of\n",
"# prob_trigram_tensor is therefore pure add-one smoothing, i.e. uniform over all 27\n",
"# tokens (including '.', which would yield an empty name). Draw the first letter from\n",
"# the observed first-letter counts instead: row c (< 27) is the context ('.', c), and\n",
"# summing it over ch3 gives the number of names that start with c.\n",
"first_char_counts = trigram_tensor[:27].sum(1).float()\n",
"\n",
"for i in range(10):\n",
"    first = torch.multinomial(first_char_counts, num_samples=1, replacement=True, generator=g).item()\n",
"    out = [itos[first]]\n",
"    context = first  # row index 27 * stoi['.'] + first\n",
"    while True:\n",
"        ich2 = context % 27  # second character of the current (ch1, ch2) context\n",
"        p = prob_trigram_tensor[context]\n",
"        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()\n",
"        out.append(itos[ix])\n",
"        if ix == 0:  # sampled the end token '.'\n",
"            break\n",
"        context = (ich2 * 27) + ix  # slide the window: (ch1, ch2) -> (ch2, ch3)\n",
"    print(''.join(out))"
]
},
{
"cell_type": "code",
"execution_count": 104,
"id": "96af7103-7742-4160-9a28-320e9cc9e794",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"log_likelihood=tensor(-410414.9688)\n",
"nll=tensor(410414.9688)\n",
"2.092747449874878\n"
]
}
],
"source": [
"# Average negative log-likelihood of the training names under the smoothed trigram model.\n",
"log_likelihood = 0.0\n",
"n = 0\n",
"\n",
"# Swap the loop header for `for w in [\"andrejq\"]:` to score a single word.\n",
"for w in names:\n",
"    chs = ['.'] + list(w) + ['.']\n",
"    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):\n",
"        ix1 = stoi[ch1]\n",
"        ix2 = stoi[ch2]\n",
"        ix3 = stoi[ch3]\n",
"        prob = prob_trigram_tensor[(ix1 * 27) + ix2, ix3]\n",
"        logprob = torch.log(prob)\n",
"        log_likelihood += logprob\n",
"        n += 1\n",
"        # print(f'{ch1}{ch2}{ch3}: {prob:.4f} {logprob:.4f}')\n",
"\n",
"print(f'{log_likelihood=}')\n",
"nll = -log_likelihood\n",
"print(f'{nll=}')\n",
"print(f'{nll/n}')"
]
},
{
"cell_type": "code",
"execution_count": 105,
"id": "c2cda117-8d12-41a1-b179-e876db1ccebb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of examples: 196113\n"
]
}
],
"source": [
"# create the dataset: input is the flattened (ch1, ch2) context index, target is ch3\n",
"xs, ys = [], []\n",
"for w in names:\n",
"    chs = ['.'] + list(w) + ['.']\n",
"    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):\n",
"        ix1 = stoi[ch1]\n",
"        ix2 = stoi[ch2]\n",
"        ix3 = stoi[ch3]\n",
"        xs.append((ix1 * 27) + ix2)\n",
"        ys.append(ix3)\n",
"xs = torch.tensor(xs)\n",
"ys = torch.tensor(ys)\n",
"num = xs.nelement()\n",
"print('number of examples: ', num)\n",
"\n",
"# initialize the 'network': one row of 27 logits per (ch1, ch2) context\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"W = torch.randn((27 * 27, 27), generator=g, requires_grad=True)"
]
},
{
"cell_type": "code",
"execution_count": 111,
"id": "3aa90b13-011f-42b9-9373-769601997158",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"# The inputs never change between steps, so build the one-hot encoding (num x 729 floats,\n",
"# roughly 570 MB for 196113 examples) and the row index once, outside the loop.\n",
"xenc = F.one_hot(xs, num_classes=27*27).float()  # input to the network: one-hot encoding\n",
"rows = torch.arange(num)\n",
"\n",
"# gradient descent\n",
"for k in range(1000):\n",
"\n",
"    # forward pass\n",
"    logits = xenc @ W  # predict log-counts\n",
"    counts = logits.exp()  # counts, equivalent to N\n",
"    probs = counts / counts.sum(1, keepdims=True)  # probabilities for next character\n",
"    loss = -probs[rows, ys].log().mean()  # + 0.01*(W**2).mean()\n",
"\n",
"    # backward pass\n",
"    W.grad = None  # set to zero the gradient\n",
"    loss.backward()\n",
"\n",
"    # update\n",
"    W.data += -50 * W.grad"
]
},
{
"cell_type": "code",
"execution_count": 112,
"id": "8f470d40-9d50-4d81-af67-b2532c12fd82",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.0905237197875977\n"
]
}
],
"source": [
"# Final training loss: mean negative log-likelihood per example.\n",
"print(float(loss))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "39f675a0-bcd7-4de8-8afa-58e15b787892",
"metadata": {},
"outputs": [],
"source": [
"# finally, sample from the 'neural net' model\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"\n",
"# The ('.', '.') context (row 0 of W) never appears in xs, so its one-hot column is always\n",
"# zero, that row of W receives no gradient, and it keeps its random initialisation.\n",
"# Draw the first letter from the observed first-letter counts instead: xs < 27 exactly\n",
"# when ch1 == '.', i.e. for the first trigram of each name, where xs is the first letter.\n",
"first_char_counts = torch.bincount(xs[xs < 27], minlength=27).float()\n",
"\n",
"for i in range(5):\n",
"    first = torch.multinomial(first_char_counts, num_samples=1, replacement=True, generator=g).item()\n",
"    out = [itos[first]]\n",
"    context = first  # row index 27 * stoi['.'] + first\n",
"    while True:\n",
"        ich2 = context % 27  # second character of the current (ch1, ch2) context\n",
"        # forward pass for a single context, the same computation as in training\n",
"        xenc = F.one_hot(torch.tensor([context]), num_classes=27 * 27).float()\n",
"        logits = xenc @ W  # predict log-counts\n",
"        counts = logits.exp()  # counts, equivalent to N\n",
"        p = counts / counts.sum(1, keepdims=True)  # probabilities for next character\n",
"        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()\n",
"        out.append(itos[ix])\n",
"        if ix == 0:  # sampled the end token '.'\n",
"            break\n",
"        context = (ich2 * 27) + ix  # slide the window: (ch1, ch2) -> (ch2, ch3)\n",
"    print(''.join(out))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "790a81f7-2c9b-4638-bde8-e44729d3bec0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment