Created November 21, 2022 07:46
computing-jacobians-different-ways.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyPQ0nIndzX7JLKjmo9XmJYY",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/myazdani/33bb403e0eaadffdff6fddce22a59649/computing-jacobians-different-ways.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
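{
"cell_type": "markdown",
"source": [
"# Comparing different ways of computing Jacobians in PyTorch\n",
"\n",
"My main aim is to understand the different ways Jacobians can be computed in PyTorch and how they compare with functorch.\n",
"\n",
"The computation graph is as follows (all operations are elementwise):\n",
"$$\n",
"\\begin{align*}\n",
"a &= [2, 3] \\\\\n",
"b &= [6, 4] \\\\\n",
"Q &= 3a^3 - b^2\n",
"\\end{align*}\n",
"$$\n",
"\n",
"Since $Q_i$ depends only on $a_i$ and $b_i$, both Jacobian blocks are diagonal, so the expected result is\n",
"$$\n",
"\\frac{\\partial Q}{\\partial a} = \\operatorname{diag}(9a^2) = \\begin{bmatrix} 36 & 0 \\\\ 0 & 81 \\end{bmatrix}, \\qquad\n",
"\\frac{\\partial Q}{\\partial b} = \\operatorname{diag}(-2b) = \\begin{bmatrix} -12 & 0 \\\\ 0 & -8 \\end{bmatrix}.\n",
"$$"
],
"metadata": {
"id": "7_lpd9m9Ihvk"
}
},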
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ipdeWoKmGuMe",
"outputId": "3ad25698-7d6f-4332-c3f0-e7b79dbb4fef"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Found existing installation: torch 1.14.0.dev20221120+cpu\n",
"Uninstalling torch-1.14.0.dev20221120+cpu:\n",
" Successfully uninstalled torch-1.14.0.dev20221120+cpu\n"
]
}
],
"source": [
"!pip uninstall -y torch"
]
},
{
"cell_type": "code",
"source": [
"!pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --upgrade"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "J9LdEFC1G0aY",
"outputId": "88a4a9f2-f51d-4471-9b16-687f0a570bb1"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Looking in links: https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n",
"Collecting torch\n",
" Using cached https://download.pytorch.org/whl/nightly/cpu/torch-1.14.0.dev20221120%2Bcpu-cp37-cp37m-linux_x86_64.whl (197.9 MB)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch) (4.1.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.7/dist-packages (from torch) (2.6.3)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.7/dist-packages (from torch) (1.7.1)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.7/dist-packages (from sympy->torch) (1.2.1)\n",
"Installing collected packages: torch\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"torchvision 0.13.1+cu113 requires torch==1.12.1, but you have torch 1.14.0.dev20221120+cpu which is incompatible.\n",
"torchtext 0.13.1 requires torch==1.12.1, but you have torch 1.14.0.dev20221120+cpu which is incompatible.\n",
"torchaudio 0.12.1+cu113 requires torch==1.12.1, but you have torch 1.14.0.dev20221120+cpu which is incompatible.\n",
"fastai 2.7.10 requires torch<1.14,>=1.7, but you have torch 1.14.0.dev20221120+cpu which is incompatible.\u001b[0m\n",
"Successfully installed torch-1.14.0.dev20221120+cpu\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Computing the Jacobian \"manually\" with `torch.autograd.grad`"
],
"metadata": {
"id": "vldHGSoRIDtg"
}
},
{
"cell_type": "code",
"source": [
"import torch"
],
"metadata": {
"id": "L79OfXZTG1XK"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"a = torch.tensor([2., 3.], requires_grad=True)\n",
"b = torch.tensor([6., 4.], requires_grad=True)\n",
"Q = 3*a**3 - b**2"
],
"metadata": {
"id": "CYEZOGkiHHJF"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"x = (a, b)  # inputs\n",
"y = Q       # output, shape (2,)\n",
"\n",
"# Seeding the vector-Jacobian product with the i-th row of the identity\n",
"# (one row per *output* element) returns row i of dQ/da and of dQ/db.\n",
"N = y.numel()\n",
"V = torch.eye(N)\n",
"\n",
"jacobian = []\n",
"for i in range(N):\n",
"    # retain_graph=True keeps the graph alive for the next backward pass\n",
"    dy_i_dx = torch.autograd.grad(y, x, grad_outputs=V[i], retain_graph=True)\n",
"    jacobian.append(dy_i_dx)\n",
"jacobian"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZXZeYr-pHg0n",
"outputId": "87567473-9bb1-44cb-835a-a67525050e7b"
},
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[(tensor([36., 0.]), tensor([-12., -0.])),\n",
" (tensor([ 0., 81.]), tensor([-0., -8.]))]"
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "markdown",
"source": [
"### Using `torch.autograd.functional.jacobian`"
],
"metadata": {
"id": "d2RsD6KiIN2b"
}
},
{
"cell_type": "code",
"source": [
"a = torch.tensor([2., 3.], requires_grad=True)\n",
"b = torch.tensor([6., 4.], requires_grad=True)"
],
"metadata": {
"id": "u5vmAdgMHgtf"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"torch.autograd.functional.jacobian(lambda a, b: 3*a**3 - b**2, (a, b))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Oc4sx49-HJ86",
"outputId": "fa38d0b7-9f67-4366-9f24-5737d5ef0cc0"
},
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(tensor([[36., 0.],\n",
" [ 0., 81.]]), tensor([[-12., -0.],\n",
" [ -0., -8.]]))"
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "markdown",
"source": [
"### Using `vmap` & `vjp` from `functorch`"
],
"metadata": {
"id": "D5AVPIjYITgr"
}
},
{
"cell_type": "code",
"source": [
"from functorch import vmap, vjp"
],
"metadata": {
"id": "dnrsGaULHLAF"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"a = torch.tensor([2., 3.], requires_grad=True)\n",
"b = torch.tensor([6., 4.], requires_grad=True)"
],
"metadata": {
"id": "4BR6XU-dHQ8R"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# vjp returns (Q, vjp_fn); vjp_fn(v) computes v^T J for each input (a and b)\n",
"_, vjp_fn = vjp(lambda a, b: 3*a**3 - b**2, a, b)"
],
"metadata": {
"id": "1Ci2SR95HU2D"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# vmap applies vjp_fn to every row of the identity in one batched call,\n",
"# stacking the rows into the full Jacobian blocks for a and b\n",
"ft_jacobian = vmap(vjp_fn)(torch.eye(2))\n",
"ft_jacobian"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SN-5_S7lHWNY",
"outputId": "7a9e2098-f99f-489f-a1bb-dfcfe79592dc"
},
"execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(tensor([[36., 0.],\n",
" [ 0., 81.]], grad_fn=<MulBackward0>), tensor([[-12., -0.],\n",
" [ -0., -8.]], grad_fn=<MulBackward0>))"
]
},
"metadata": {},
"execution_count": 11
}
]
}
]
}
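Both the manual loop and `vmap(vjp_fn)(torch.eye(2))` rest on the same idea: a vector-Jacobian product seeded with the i-th row of the identity returns row i of the Jacobian, so one seed per output element recovers the whole matrix. The plain-Python sketch below illustrates this without autograd, assuming the hand-derived blocks dQ/da = diag(9a^2) and dQ/db = diag(-2b); its `vjp` helper is hypothetical and is not `functorch.vjp`.

```python
# Sketch only: the Jacobian blocks of Q = 3a^3 - b^2 are written out by hand
# here (no autograd); `vjp` is a hypothetical helper, not functorch.vjp.

a = [2.0, 3.0]
b = [6.0, 4.0]

# dQ/da = diag(9 a^2) and dQ/db = diag(-2 b), stored as lists of rows.
J_a = [[9 * a[0] ** 2, 0.0], [0.0, 9 * a[1] ** 2]]
J_b = [[-2 * b[0], 0.0], [0.0, -2 * b[1]]]


def vjp(v, jac):
    """Return the row vector v^T J: entry j is sum_i v[i] * jac[i][j]."""
    return [sum(v[i] * jac[i][j] for i in range(len(v))) for j in range(len(jac[0]))]


# One seed per output element, i.e. the rows of the 2x2 identity.
identity = [[1.0, 0.0], [0.0, 1.0]]
rows = [(vjp(e, J_a), vjp(e, J_b)) for e in identity]
print(rows)  # → [([36.0, 0.0], [-12.0, 0.0]), ([0.0, 81.0], [0.0, -8.0])]

# An all-ones seed yields the column sums of each block instead of full rows.
print(vjp([1.0, 1.0], J_a), vjp([1.0, 1.0], J_b))  # → [36.0, 81.0] [-12.0, -8.0]
```

In PyTorch terms, the all-ones seed corresponds to something like `Q.backward(torch.ones(2))`; that recovers the Jacobian here only because both blocks happen to be diagonal.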
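As an independent check on the numbers all three methods print, the Jacobian blocks can also be estimated numerically. This is a minimal sketch using central finite differences in plain Python, a different technique from the reverse-mode autodiff used in the notebook; the helper names `q` and `fd_jacobian` and the step size `h=1e-4` are illustrative assumptions.

```python
# Sketch only: central finite differences (not autograd) for the notebook's
# Q(a, b) = 3*a**3 - b**2. `q` and `fd_jacobian` are hypothetical helpers.


def q(a, b):
    """Elementwise Q = 3a^3 - b^2 for equal-length lists a and b."""
    return [3 * ai ** 3 - bi ** 2 for ai, bi in zip(a, b)]


def fd_jacobian(f, args, wrt, h=1e-4):
    """Estimate d f(*args) / d args[wrt] with central differences.

    Entry [i][j] approximates the derivative of output i with respect to
    element j of args[wrt].
    """
    x = list(args[wrt])
    n_out = len(f(*args))
    jac = [[0.0] * len(x) for _ in range(n_out)]
    for j in range(len(x)):
        # Perturb element j of the chosen input up and down by h.
        x_plus, x_minus = list(x), list(x)
        x_plus[j] += h
        x_minus[j] -= h
        args_plus, args_minus = list(args), list(args)
        args_plus[wrt] = x_plus
        args_minus[wrt] = x_minus
        f_plus, f_minus = f(*args_plus), f(*args_minus)
        for i in range(n_out):
            jac[i][j] = (f_plus[i] - f_minus[i]) / (2 * h)
    return jac


a, b = [2.0, 3.0], [6.0, 4.0]
J_a = fd_jacobian(q, (a, b), wrt=0)
J_b = fd_jacobian(q, (a, b), wrt=1)
print([[round(v, 4) for v in row] for row in J_a])  # → [[36.0, 0.0], [0.0, 81.0]]
print([[round(v, 4) for v in row] for row in J_b])  # → [[-12.0, 0.0], [0.0, -8.0]]
```

For this cubic, the central difference on the diagonal of dQ/da carries a truncation error of 3h² (about 3e-8 here), while the quadratic b-terms are exact up to floating-point rounding, so the rounded values line up with the autograd results.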