Created
November 23, 2023 16:01
-
-
Save p-i-/73a090f1d00fa367bf39cfa1b333e2cd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# All imports for the notebook live in this cell.\n",
  "# The functional API is the module torch.nn.functional; torch has no name 'Functional'.\n",
  "import math\n",
  "\n",
  "import torch as th\n",
  "import torch.nn.functional as F\n",
  "from torch import nn"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "class FFF(nn.Module):\n",
  "    \"\"\"Fast feedforward (FFF) layer: a binary tree of nodes evaluated along one path.\n",
  "\n",
  "    Each sample visits ``depth + 1`` nodes, starting at the root. At every visited\n",
  "    node the sample is scored against that node's row of ``w1s``; the sign of the\n",
  "    score selects the child to visit next. The output is the sum, over the visited\n",
  "    nodes, of the GELU-activated score times that node's row of ``w2s``.\n",
  "\n",
  "    Args:\n",
  "        input_width: Number of features per input sample.\n",
  "        output_width: Number of features per output sample.\n",
  "        depth: Number of branching levels; the tree has ``2 ** (depth + 1) - 1`` nodes.\n",
  "    \"\"\"\n",
  "\n",
  "    def __init__(self, input_width: int = 64, output_width: int = 32, depth: int = 4) -> None:\n",
  "        super().__init__()\n",
  "\n",
  "        self.input_width = input_width\n",
  "        self.output_width = output_width\n",
  "        self.depth = depth\n",
  "        self.n_nodes = 2 ** (depth + 1) - 1\n",
  "\n",
  "        self._initiate_weights()\n",
  "\n",
  "    def _initiate_weights(self) -> None:\n",
  "        \"\"\"Create ``w1s`` and ``w2s`` with uniform initialisation.\n",
  "\n",
  "        Each parameter is defined exactly once and has one row per tree node,\n",
  "        because ``forward`` looks rows up by node index.\n",
  "        \"\"\"\n",
  "        init_factor_l1 = 1.0 / math.sqrt(self.input_width)\n",
  "        # Every output is a sum of depth + 1 terms (one per visited node).\n",
  "        init_factor_l2 = 1.0 / math.sqrt(self.depth + 1)\n",
  "\n",
  "        # (n_nodes, input_width): scoring vector of each node.\n",
  "        self.w1s = nn.Parameter(\n",
  "            th.empty(self.n_nodes, self.input_width).uniform_(-init_factor_l1, init_factor_l1)\n",
  "        )\n",
  "        # (n_nodes, output_width): output vector of each node.\n",
  "        self.w2s = nn.Parameter(\n",
  "            th.empty(self.n_nodes, self.output_width).uniform_(-init_factor_l2, init_factor_l2)\n",
  "        )\n",
  "\n",
  "    def forward(self, x: th.Tensor) -> th.Tensor:\n",
  "        \"\"\"Route each sample down the tree and combine the nodes it visits.\n",
  "\n",
  "        Args:\n",
  "            x: Tensor of shape ``(batch_size, input_width)``.\n",
  "\n",
  "        Returns:\n",
  "            Tensor of shape ``(batch_size, output_width)``.\n",
  "\n",
  "        Raises:\n",
  "            ValueError: If ``x`` is not 2-D with ``input_width`` columns.\n",
  "        \"\"\"\n",
  "        if x.dim() != 2 or x.shape[1] != self.input_width:\n",
  "            raise ValueError(\n",
  "                f'expected input of shape (batch_size, {self.input_width}), got {tuple(x.shape)}'\n",
  "            )\n",
  "\n",
  "        batch_size = x.shape[0]\n",
  "\n",
  "        # Every sample starts at the root (node 0). The index tensor is created on\n",
  "        # x's device so the layer also works when it has been moved off the CPU.\n",
  "        current_node = th.zeros(batch_size, dtype=th.long, device=x.device)\n",
  "\n",
  "        visited_nodes = []\n",
  "        visited_logits = []\n",
  "        for _ in range(self.depth + 1):\n",
  "            # Dot product of each sample with the weights of its current node -> (batch_size,)\n",
  "            plane_score = th.einsum('b i, b i -> b', x, self.w1s[current_node])\n",
  "            visited_nodes.append(current_node)\n",
  "            # The scores carry the gradient that trains the decision boundaries.\n",
  "            visited_logits.append(plane_score)\n",
  "\n",
  "            # Children of node n are 2n + 1 (score <= 0) and 2n + 2 (score > 0).\n",
  "            plane_choice = (plane_score > 0).long()\n",
  "            current_node = current_node * 2 + plane_choice + 1\n",
  "\n",
  "        all_nodes = th.stack(visited_nodes, dim=1)    # (batch_size, depth + 1)\n",
  "        all_logits = th.stack(visited_logits, dim=1)  # (batch_size, depth + 1)\n",
  "\n",
  "        # Same form as a feed-forward block, GeLU(x @ W1) @ W2, restricted to the\n",
  "        # rows of W1 / W2 that belong to the nodes each sample visited.\n",
  "        selected_w2s = self.w2s[all_nodes]  # (batch_size, depth + 1, output_width)\n",
  "        return th.einsum('b i j, b i -> b j', selected_w2s, F.gelu(all_logits))"
 ]
},
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".venv", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment