Created
June 21, 2024 09:59
-
-
Save joshlk/20f1f51900e7f299feb618aa2fc55921 to your computer and use it in GitHub Desktop.
Print PyTorch backwards ops
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9ba437fc", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from torch import nn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "b5cb8743", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"Tracing back tensors:\n", | |
"<MulBackward0 object at 0x11c9ffaf0>\n", | |
"<SumBackward0 object at 0x12e9c3a60>\n", | |
"<MulBackward0 object at 0x12e9c3400>\n", | |
"<AccumulateGrad object at 0x12e9c3070>\n", | |
"Tensor with grad found: tensor([0.3643, 0.6264, 0.1329, 0.5581, 0.3163], requires_grad=True)\n", | |
" - gradient: tensor([3., 3., 3., 3., 3.])\n", | |
"\n", | |
"<AccumulateGrad object at 0x12e9c3ee0>\n", | |
"Tensor with grad found: tensor([1., 1., 1., 1., 1.], requires_grad=True)\n", | |
" - gradient: tensor([1.0928, 1.8791, 0.3986, 1.6743, 0.9489])\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"input1 = torch.randn(100, 128, requires_grad=True)\n", | |
"input2 = torch.randn(100, 128, requires_grad=True)\n", | |
"cos = nn.CosineSimilarity(dim=1, eps=1e-6)\n", | |
"output = cos(input1, input2)\n", | |
"\n", | |
"print()\n", | |
"print('Tracing back tensors:')\n", | |
"def getBack(var_grad_fn):\n", | |
" print(var_grad_fn)\n", | |
" for n in var_grad_fn.next_functions:\n", | |
" if n[0]:\n", | |
" try:\n", | |
" tensor = getattr(n[0], 'variable')\n", | |
" print(n[0])\n", | |
" print('Tensor with grad found:', tensor)\n", | |
" print(' - gradient:', tensor.grad)\n", | |
" print()\n", | |
" except AttributeError as e:\n", | |
" getBack(n[0])\n", | |
"\n", | |
"output.sum().backward()\n", | |
"getBack(loss.grad_fn)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "3e320c04", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[-0.8534, 0.8103, -0.8892, ..., -0.2915, -0.8183, -0.6481],\n", | |
" [-0.9195, -0.0834, 1.6122, ..., 0.0945, -0.0291, -0.7190],\n", | |
" [ 0.3594, 1.0440, -0.5852, ..., -0.2921, -0.4885, 0.1041],\n", | |
" ...,\n", | |
" [-0.3927, 0.2467, 0.3223, ..., 0.1250, -0.3101, 0.2410],\n", | |
" [ 1.3586, -1.4949, 0.3142, ..., -0.1608, 0.8276, 0.8251],\n", | |
" [-0.2738, 0.9730, -0.6034, ..., -0.8690, 0.0268, -0.8985]],\n", | |
" grad_fn=<AddmmBackward0>)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"output" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ad4fd55f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment