Skip to content

Instantly share code, notes, and snippets.

@p-i-
Created November 23, 2023 16:01
Show Gist options
  • Save p-i-/73a090f1d00fa367bf39cfa1b333e2cd to your computer and use it in GitHub Desktop.
Save p-i-/73a090f1d00fa367bf39cfa1b333e2cd to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch as th\n",
"# `torch.nn.functional` is a submodule; `torch` has no attribute named `Functional`.\n",
"import torch.nn.functional as F\n",
"from torch import nn\n",
"\n",
"# Fixed seed so the random weight initialisation below is reproducible on Restart & Run All.\n",
"th.manual_seed(0);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class FFF(nn.Module):\n",
"    \"\"\"Fast feedforward (FFF) layer: a balanced binary tree of single-neuron nodes.\n",
"\n",
"    Every input row walks one root-to-leaf path of ``depth + 1`` nodes. At each\n",
"    node, the sign of the dot product with that node's ``w1s`` row picks the child\n",
"    (``2*k + 1`` when the score is <= 0, ``2*k + 2`` when it is > 0). The output is\n",
"    the sum of the visited nodes' ``w2s`` rows, each weighted by GELU of its score.\n",
"\n",
"    Args:\n",
"        input_width: size of the last input dimension.\n",
"        output_width: size of the last output dimension.\n",
"        depth: tree depth; the tree holds ``2**(depth + 1) - 1`` nodes.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_width: int = 64, output_width: int = 32, depth: int = 4) -> None:\n",
"        super().__init__()\n",
"        self.input_width = input_width\n",
"        self.output_width = output_width\n",
"        self.depth = depth\n",
"        self.n_nodes = 2 ** (depth + 1) - 1\n",
"        self._initiate_weights()\n",
"\n",
"    def _initiate_weights(self) -> None:\n",
"        \"\"\"Create ``w1s`` and ``w2s`` once, each drawn from a symmetric uniform distribution.\"\"\"\n",
"        # Fan-in scaling: each plane score is a dot product over input_width terms.\n",
"        l1_init_factor = 1.0 / math.sqrt(self.input_width)\n",
"        # Each output element sums exactly depth + 1 node contributions (one per tree level).\n",
"        l2_init_factor = 1.0 / math.sqrt(self.depth + 1)\n",
"\n",
"        # (n_nodes, input_width): row k is node k's decision hyperplane, so indexing\n",
"        # w1s with per-sample node ids in forward() yields one full weight vector per sample.\n",
"        self.w1s = nn.Parameter(th.empty(self.n_nodes, self.input_width).uniform_(-l1_init_factor, l1_init_factor))\n",
"        # (n_nodes, output_width): row k is node k's contribution to the output.\n",
"        self.w2s = nn.Parameter(th.empty(self.n_nodes, self.output_width).uniform_(-l2_init_factor, l2_init_factor))\n",
"\n",
"    def forward(self, x: th.Tensor) -> th.Tensor:\n",
"        \"\"\"Route every row of ``x`` down the tree and combine the visited nodes' outputs.\n",
"\n",
"        Args:\n",
"            x: tensor of shape ``(..., input_width)``; leading dims are flattened into a batch.\n",
"\n",
"        Returns:\n",
"            Tensor of shape ``(..., output_width)``.\n",
"\n",
"        Raises:\n",
"            ValueError: if the last dimension of ``x`` is not ``input_width``.\n",
"        \"\"\"\n",
"        if x.shape[-1] != self.input_width:\n",
"            raise ValueError(f'expected last dim {self.input_width}, got shape {tuple(x.shape)}')\n",
"        leading_shape = x.shape[:-1]\n",
"        x = x.reshape(-1, self.input_width)\n",
"        batch_size = x.shape[0]\n",
"\n",
"        # Every sample starts at the root (node 0); indices live on x's device.\n",
"        current_node = th.zeros(batch_size, dtype=th.long, device=x.device)\n",
"        visited_nodes, plane_scores = [], []\n",
"\n",
"        for _ in range(self.depth + 1):\n",
"            # (bs,) dot product of each sample with its current node's hyperplane.\n",
"            plane_score = th.einsum('bi,bi->b', x, self.w1s[current_node])\n",
"            visited_nodes.append(current_node)\n",
"            # Keep the raw score (not only its sign) so gradients reach w1s through GELU below.\n",
"            plane_scores.append(plane_score)\n",
"            # Heap indexing: children of node k are 2k+1 (score <= 0) and 2k+2 (score > 0).\n",
"            current_node = current_node * 2 + (plane_score > 0).long() + 1\n",
"\n",
"        all_nodes = th.stack(visited_nodes, dim=1)   # (bs, depth+1)\n",
"        all_logits = th.stack(plane_scores, dim=1)   # (bs, depth+1)\n",
"        selected_w2s = self.w2s[all_nodes]           # (bs, depth+1, output_width)\n",
"        out = th.einsum('bij,bi->bj', selected_w2s, F.gelu(all_logits))\n",
"        return out.reshape(*leading_shape, self.output_width)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment