Created
March 2, 2018 15:42
-
-
Save izmailovpavel/65afa6212f21fe752e48056d8d723f9d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import torch\n", | |
"from torch import nn as nn\n", | |
"from torch.autograd import Variable" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 63, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<torch.utils.hooks.RemovableHandle at 0x7f23101cdc18>" | |
] | |
}, | |
"execution_count": 63, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"grads = {}\n", | |
"def save_grad(name):\n", | |
" def hook(grad):\n", | |
" grads[name] = grad\n", | |
" return hook\n", | |
"\n", | |
"def extract_grad(var):\n", | |
" print(var)\n", | |
" print(var.shape)\n", | |
" return var\n", | |
"\n", | |
"n_feat = 10\n", | |
"n_obj = 25\n", | |
"X = np.random.normal(size=(n_obj, n_feat))\n", | |
"y = np.random.randint(low=0, high=10, size=(n_obj))\n", | |
"X_ = Variable(torch.from_numpy(X), requires_grad=True)\n", | |
"y_ = Variable(torch.from_numpy(y))\n", | |
"lsm = nn.LogSoftmax(dim=1)(X_)\n", | |
"l = nn.NLLLoss()(lsm, y_)\n", | |
"\n", | |
"lsm.register_hook(save_grad('lsm'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 64, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Backpropagate the loss: X_.grad receives dL/dX, and the hook on lsm stores dL/dlsm in grads['lsm'].\n", | |
"l.backward()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 65, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Variable containing:\n", | |
"1.00000e-02 *\n", | |
" 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000\n", | |
" 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
" -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
"[torch.DoubleTensor of size 25x10]" | |
] | |
}, | |
"execution_count": 65, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# dL/dlsm captured by the hook. The output shows -1/n_obj (-0.04) at each row's target class\n", | |
"# and 0 elsewhere, consistent with NLLLoss averaging over the n_obj rows.\n", | |
"grads['lsm']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 66, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Variable containing:\n", | |
"1.00000e-02 *\n", | |
" 0.8155 0.2794 -3.8968 0.5208 0.1211 0.1943 0.5007 0.3851 0.1434 0.9365\n", | |
" 0.5790 0.1383 0.5933 0.6330 0.4394 0.6952 0.1665 0.0873 -3.5001 0.1682\n", | |
" 0.5062 0.0578 0.0737 1.0623 1.3683 0.1590 0.1494 0.5018 0.0833 -3.9618\n", | |
" 0.1304 1.4510 -3.7681 0.0102 0.1075 0.6720 0.5830 0.4104 0.2180 0.1856\n", | |
" 0.1847 0.4890 -3.7672 0.2188 0.9282 0.1364 0.0864 0.4260 0.4284 0.8693\n", | |
" 0.0223 -3.5900 1.6741 0.1133 0.0790 0.3652 0.1131 0.0791 0.7763 0.3677\n", | |
" 0.1598 0.1576 0.2000 0.1278 -3.8787 0.7640 0.4842 0.5352 1.2654 0.1847\n", | |
" 0.6389 0.1485 0.1226 -3.6593 0.1309 0.1685 0.3405 1.6761 0.1646 0.2687\n", | |
" 0.6277 -3.3878 0.2061 0.1202 1.0843 0.2062 0.4640 0.0806 0.4340 0.1647\n", | |
" 0.6289 0.2413 0.0680 0.2713 0.2365 0.8447 0.0867 -2.9684 0.2418 0.3492\n", | |
" 0.6839 0.0435 0.2797 0.1037 -3.7896 0.5422 0.6385 0.8123 0.0335 0.6521\n", | |
" 0.2857 0.2350 0.7899 0.2513 0.8590 0.0355 -3.4502 0.2588 0.2522 0.4828\n", | |
" 0.2759 0.1700 0.1678 0.2723 0.0284 0.0710 1.5839 0.7363 -3.4987 0.1932\n", | |
" 0.3499 0.9402 0.4786 0.2367 0.7633 -3.6291 0.3840 0.0831 0.2335 0.1597\n", | |
" 0.2092 0.1013 0.7504 0.2398 0.0702 -3.7920 0.2251 0.3815 1.5150 0.2994\n", | |
" 0.1143 0.0623 0.2665 0.3133 0.5581 -3.8609 0.9103 0.4654 0.4638 0.7068\n", | |
" 0.1304 0.2630 0.3659 1.7599 0.1529 -3.8142 0.3330 0.4531 0.2820 0.0741\n", | |
" 0.0638 0.5095 -3.7011 0.2375 0.0971 0.0885 0.1205 0.0939 2.4365 0.0538\n", | |
" 0.6268 -3.7864 0.1416 1.0899 0.7207 0.3281 0.0347 0.2979 0.2298 0.3171\n", | |
" 0.4224 0.7358 -3.3912 0.2569 0.2338 0.3163 0.3436 0.4604 0.3109 0.3111\n", | |
" 0.5334 0.0474 0.3382 0.8208 0.4691 -3.7323 0.4889 0.1613 0.4568 0.4165\n", | |
" 0.4489 -3.5636 0.4416 0.0672 0.0471 0.8899 0.7920 0.6428 0.1531 0.0810\n", | |
" 0.1111 0.1431 0.1120 -2.9565 0.6451 0.9736 0.0359 0.5117 0.2584 0.1655\n", | |
" 0.3391 0.9316 0.2202 -3.4725 0.0128 0.1691 0.2743 0.4158 0.8600 0.2497\n", | |
" -3.8255 0.6958 0.3361 0.3637 0.0943 0.3866 0.4016 0.6418 0.2407 0.6649\n", | |
"[torch.DoubleTensor of size 25x10]" | |
] | |
}, | |
"execution_count": 66, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Autograd's gradient of the loss w.r.t. the input logits X: the reference for the manual check.\n", | |
"X_.grad" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 56, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# lsm is an intermediate (non-leaf) result, so its .grad is not retained and this displays nothing;\n", | |
"# that is why save_grad's hook is used to capture dL/dlsm instead.\n", | |
"lsm.grad" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 72, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dlsm_dXs = []\n", | |
"for i in range(n_obj):\n", | |
" denom = np.sum(np.exp(X[i]))\n", | |
" dlsm_dXs.append(np.eye(n_feat) - np.exp(X[i][:, None]) / denom)\n", | |
"dlsm_dX = np.array(dlsm_dXs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 74, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Torch wrapper around the numpy Jacobian. NOTE(review): not referenced by any later cell.\n", | |
"dlsm_dX_ = Variable(torch.from_numpy(dlsm_dX))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 80, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Manual chain rule per object: ans[i, j] = sum_k dlsm_dX[i, j, k] * dL/dlsm[i, k], i.e. dL/dX.\n", | |
"ans = np.einsum('ijk, ik -> ij', dlsm_dX, grads['lsm'].data.numpy())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 84, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2.5407775362056137e-17" | |
] | |
}, | |
"execution_count": 84, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.linalg.norm(ans - X_.grad.data.numpy())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment