intro-to-jax-part2.ipynb
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "ada14b6a-d989-4aa4-8b71-0d870933eb13", | |
"metadata": { | |
"id": "ada14b6a-d989-4aa4-8b71-0d870933eb13" | |
}, | |
"source": [ | |
"# Introduction to JAX (Part 2)\n", | |
"\n", | |
"We'll start with a recap of the previous \"intro to JAX\" session, with (hopefully!) enough information to catch up anyone who wasn't there.\n", | |
"\n", | |
"This tutorial is a whirlwind introduction to JAX. It is far from complete, so if you want more information, check out the [JAX docs](https://jax.readthedocs.io).\n", | |
"\n", | |
"JAX uses single-precision (32-bit) floats by default, so in many cases you'll want to uncomment the `jax.config.update` line below to enable double precision:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9dc1746e-58c1-4ae6-824c-f0e0865a784a", | |
"metadata": { | |
"id": "9dc1746e-58c1-4ae6-824c-f0e0865a784a" | |
}, | |
"outputs": [], | |
"source": [ | |
"import jax\n", | |
"\n", | |
"# In many cases you may want to enable support for double precision\n", | |
"# jax.config.update(\"jax_enable_x64\", True)" | |
] | |
}, | |
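As a minimal sketch of what that flag changes, the snippet below enables 64-bit mode and checks the dtype that `jnp.linspace` returns. It assumes a fresh session in which the flag is set before any arrays are created; the name `x64` is only illustrative.

```python
import jax

# Assumption: this runs at startup, before any JAX arrays have been created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x64 = jnp.linspace(0.1, 5.0, 5)
print(x64.dtype)
```

With the `jax.config.update` call left out, the same `linspace` call would return `float32`, as in the outputs shown in the rest of this notebook.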
{ | |
"cell_type": "markdown", | |
"id": "b2f1843d-fe6d-428d-9206-08599ce547de", | |
"metadata": { | |
"id": "b2f1843d-fe6d-428d-9206-08599ce547de" | |
}, | |
"source": [ | |
"## `jax.numpy`\n", | |
"\n", | |
"`jax.numpy` works just like `numpy` (almost always):" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "467fe662-819a-4c5a-b179-8e5d1cb21076", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "467fe662-819a-4c5a-b179-8e5d1cb21076", | |
"outputId": "e2cb6afd-6d9c-46a9-f5b2-ae08a669b443" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(DeviceArray([0.1 , 1.325, 2.55 , 3.775, 5. ], dtype=float32),\n", | |
" DeviceArray([ 0.09983342, 0.9699439 , 0.55768377, -0.5918946 ,\n", | |
" -0.9589243 ], dtype=float32))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 3 | |
} | |
], | |
"source": [ | |
"import jax.numpy as jnp\n", | |
"\n", | |
"x = jnp.linspace(0.1, 5.0, 5)\n", | |
"y = jnp.sin(x)\n", | |
"x, y" | |
] | |
}, | |
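One place where the "almost always" matters is in-place updates: JAX arrays are immutable. The sketch below shows the failing NumPy-style assignment and the functional `.at[...].set(...)` alternative; the flag `assigned_in_place` is a hypothetical name used only to record what happened.

```python
import jax.numpy as jnp

x = jnp.linspace(0.1, 5.0, 5)

# In-place item assignment is not supported on JAX arrays.
try:
    x[0] = 10.0
    assigned_in_place = True
except TypeError:
    assigned_in_place = False

# The functional alternative returns a new array and leaves `x` untouched.
x_new = x.at[0].set(10.0)
print(assigned_in_place, x[0], x_new[0])
```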
{ | |
"cell_type": "markdown", | |
"id": "e0b320f6-d4fd-4f59-bf0a-8cbd748215be", | |
"metadata": { | |
"id": "e0b320f6-d4fd-4f59-bf0a-8cbd748215be" | |
}, | |
"source": [ | |
"We can combine regular `numpy` and `jax.numpy`:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "8e108d6f-5bf7-483f-8485-7bab4533f92f", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "8e108d6f-5bf7-483f-8485-7bab4533f92f", | |
"outputId": "cdfd2bd5-227a-4dcc-be5d-19139a0aa104" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(array([0.1 , 1.325, 2.55 , 3.775, 5. ]),\n", | |
" DeviceArray([ 0.09983342, 0.9699439 , 0.55768377, -0.5918946 ,\n", | |
" -0.9589243 ], dtype=float32))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 3 | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"x = np.linspace(0.1, 5.0, 5)\n", | |
"y = jnp.sin(x)\n", | |
"x, y" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9282cf37-e89e-4f3a-abfc-b247538f6e4b", | |
"metadata": { | |
"id": "9282cf37-e89e-4f3a-abfc-b247538f6e4b" | |
}, | |
"source": [ | |
"## `jax.jit`\n", | |
"\n", | |
"`jax.jit` compiles a function with XLA, which lets JAX fuse operations and run them efficiently on an accelerator such as a GPU.\n", | |
"One of the key points to remember when using JAX is that it works best in a \"functional\" style.\n", | |
"Many of the core JAX functions take a function as input and return a new function.\n", | |
"For example:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "cc4dc0ac-b240-47d1-91d3-f1edbb0a0526", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "cc4dc0ac-b240-47d1-91d3-f1edbb0a0526", | |
"outputId": "f078464d-67f4-43d8-8983-8d2c2acf5e91" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"hi from this function\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([2.6049867, 4.1377964, 3.246622 , 2.053278 , 1.883305 ], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 4 | |
} | |
], | |
"source": [ | |
"def jnp_function(x):\n", | |
"    print(\"hi from this function\")\n", | |
"    arg = jnp.sin(x)\n", | |
"    return 1.5 + jnp.exp(arg)\n", | |
"\n", | |
"jitted_function = jax.jit(jnp_function)\n", | |
"\n", | |
"jitted_function(x)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8c7d5f64-7797-46d8-ac54-4c46cb6d9525", | |
"metadata": { | |
"id": "8c7d5f64-7797-46d8-ac54-4c46cb6d9525" | |
}, | |
"source": [ | |
"What happens if we call that function again?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "79bac88b-202b-42b2-889c-16afe755f667", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "79bac88b-202b-42b2-889c-16afe755f667", | |
"outputId": "4c21b7ef-c766-4415-ce9f-78757d165879" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([2.6049867, 4.1377964, 3.246622 , 2.053278 , 1.883305 ], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 5 | |
} | |
], | |
"source": [ | |
"jitted_function(x)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "5e7c75a4-23da-4c28-9ee0-1e5baa960327", | |
"metadata": { | |
"id": "5e7c75a4-23da-4c28-9ee0-1e5baa960327" | |
}, | |
"source": [ | |
"What about if we call it with a different input?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "35573796-3a57-44ab-94f6-dcd34c481f97", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "35573796-3a57-44ab-94f6-dcd34c481f97", | |
"outputId": "6b5c0dea-ea02-493e-ac8c-1d7592af2fb9" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([2.6048036, 3.7815475, 3.1976116, 2.0723903, 1.9410601], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 6 | |
} | |
], | |
"source": [ | |
"jitted_function(np.sin(x))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "020331f8-2033-4937-8273-7565ab26fd6d", | |
"metadata": { | |
"id": "020331f8-2033-4937-8273-7565ab26fd6d" | |
}, | |
"source": [ | |
"What about an input with a different shape?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "1353bc5c-6a1d-4d12-a9be-28e9204e2c95", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "1353bc5c-6a1d-4d12-a9be-28e9204e2c95", | |
"outputId": "7be7f76c-4f73-405c-8a0b-945d23ba105f" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"hi from this function\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([2.6049867, 4.1377964, 3.246622 , 2.053278 ], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 7 | |
} | |
], | |
"source": [ | |
"jitted_function(x[:-1])" | |
] | |
}, | |
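The pattern in the last few cells (the message prints on the first call and again when the shape changes, but not otherwise) comes from tracing: the Python body only runs when JAX has to trace and compile for a new input signature. Here is a small sketch that counts traces instead of printing; `trace_count` and `traced_function` are hypothetical names used for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, only for illustration

def traced_function(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return 1.5 + jnp.exp(jnp.sin(x))

jitted = jax.jit(traced_function)

x = jnp.linspace(0.1, 5.0, 5)
jitted(x)            # first call: traces and compiles
jitted(x)            # same shape and dtype: reuses the compiled version
jitted(jnp.sin(x))   # new values, same shape and dtype: still cached
jitted(x[:-1])       # new shape: traces and compiles again
print(trace_count)
```

Four calls lead to only two traces here, one per distinct input shape.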
{ | |
"cell_type": "markdown", | |
"id": "b2990af6-a1b8-426a-aa5e-7560c97b64d6", | |
"metadata": { | |
"id": "b2990af6-a1b8-426a-aa5e-7560c97b64d6" | |
}, | |
"source": [ | |
"*Note:* It is common to use `jax.jit` as a \"decorator\":" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "d997f1d0-8117-425b-bd8a-889be7cb821f", | |
"metadata": { | |
"id": "d997f1d0-8117-425b-bd8a-889be7cb821f" | |
}, | |
"outputs": [], | |
"source": [ | |
"@jax.jit\n", | |
"def jitted_function(x):\n", | |
"    arg = jnp.sin(x)\n", | |
"    return 1.5 + jnp.exp(arg)" | |
] | |
}, | |
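{ | |
"cell_type": "markdown", | |
"source": [ | |
"What about control flow? A Python `if` needs a concrete boolean, but inside `jit` the values are abstract tracers, so the first version below fails when it is called; `jnp.where` computes both branches and selects between them instead." | |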
], | |
"metadata": { | |
"id": "Se90GwOHOhP3" | |
}, | |
"id": "Se90GwOHOhP3" | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@jax.jit\n", | |
"def incorrect_conditional_func(x):\n", | |
"    if jnp.all(x > 0):\n", | |
"        return x\n", | |
"    arg = jnp.sin(x)\n", | |
"    return 1.5 + jnp.exp(arg)" | |
], | |
"metadata": { | |
"id": "XG8W5MNUOjkk" | |
}, | |
"id": "XG8W5MNUOjkk", | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# What happens if we run this?\n", | |
"# incorrect_conditional_func(x)" | |
], | |
"metadata": { | |
"id": "p5ZfoQGvOopU" | |
}, | |
"id": "p5ZfoQGvOopU", | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@jax.jit\n", | |
"def correct_conditional_func(x):\n", | |
"    arg = jnp.sin(x)\n", | |
"    return jnp.where(jnp.all(x > 0), x, 1.5 + jnp.exp(arg))" | |
], | |
"metadata": { | |
"id": "iIIaFa_1O88z" | |
}, | |
"id": "iIIaFa_1O88z", | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"correct_conditional_func(x)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "5HNdGm6gPMZ_", | |
"outputId": "12092eb1-8b97-47a4-c827-0972934d2422" | |
}, | |
"id": "5HNdGm6gPMZ_", | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([0.1 , 1.325, 2.55 , 3.775, 5. ], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 12 | |
} | |
] | |
}, | |
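Besides `jnp.where`, another option for a jitted conditional is `jax.lax.cond`, which selects one of two branch functions at run time. The sketch below is not part of the original session; `python_if` and `with_cond` are hypothetical names. It shows the Python `if` version raising an error under `jit` and the `lax.cond` version working for both cases.

```python
import jax
import jax.numpy as jnp

@jax.jit
def python_if(x):
    # `jnp.all(x > 0)` is a tracer here, so `if` cannot convert it to a bool.
    if jnp.all(x > 0):
        return x
    return 1.5 + jnp.exp(jnp.sin(x))

@jax.jit
def with_cond(x):
    # Both branches must return values with the same shape and dtype.
    return jax.lax.cond(
        jnp.all(x > 0),
        lambda v: v,
        lambda v: 1.5 + jnp.exp(jnp.sin(v)),
        x,
    )

x = jnp.linspace(0.1, 5.0, 5)

try:
    python_if(x)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True

print(raised)
print(with_cond(x))
print(with_cond(-x))
```

Unlike `jnp.where`, which needs both candidate results as arrays, `lax.cond` takes the two branches as functions, which is why they are written as lambdas here.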
{ | |
"cell_type": "markdown", | |
"id": "cdc4a86a-cd09-4d4d-a5d0-55438cc9c02b", | |
"metadata": { | |
"id": "cdc4a86a-cd09-4d4d-a5d0-55438cc9c02b" | |
}, | |
"source": [ | |
"## `jax.vmap`\n", | |
"\n", | |
"`jax.vmap` gives a mechanism for applying a \"scalar\" function (one written for a single input) to a whole batch of inputs.\n", | |
"The same effect can often be achieved by manually broadcasting, but `vmap` comes in handy more often than you might think." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "136772d8-a384-4dc3-87ba-e967e5bfbf7c", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "136772d8-a384-4dc3-87ba-e967e5bfbf7c", | |
"outputId": "e5eb8a3d-e696-4880-94e9-5074c315e837" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([[[ 1.19428426e-01, 2.83938229e-01, 1.14193819e-01],\n", | |
" [ 2.83938229e-01, 6.75056338e-01, 2.71493077e-01],\n", | |
" [ 1.14193819e-01, 2.71493077e-01, 1.09188654e-01]],\n", | |
"\n", | |
" [[ 1.69821870e+00, -1.17982101e+00, -5.81696212e-01],\n", | |
" [-1.17982101e+00, 8.19669247e-01, 4.04127836e-01],\n", | |
" [-5.81696212e-01, 4.04127836e-01, 1.99250251e-01]],\n", | |
"\n", | |
" [[ 2.88318753e-01, -3.12033236e-01, -1.95758328e-01],\n", | |
" [-3.12033236e-01, 3.37698251e-01, 2.11859629e-01],\n", | |
" [-1.95758328e-01, 2.11859629e-01, 1.32913038e-01]],\n", | |
"\n", | |
" [[ 8.65139291e-02, 8.35990533e-03, 1.60806060e-01],\n", | |
" [ 8.35990533e-03, 8.07823846e-04, 1.55388089e-02],\n", | |
" [ 1.60806060e-01, 1.55388089e-02, 2.98895091e-01]],\n", | |
"\n", | |
" [[ 5.42364597e-01, 1.19975701e-01, 3.55058730e-01],\n", | |
" [ 1.19975701e-01, 2.65396535e-02, 7.85420388e-02],\n", | |
" [ 3.55058730e-01, 7.85420388e-02, 2.32439041e-01]]], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 13 | |
} | |
], | |
"source": [ | |
"A = np.random.default_rng(1).normal(size=(5, 3))\n", | |
"\n", | |
"def scalar_function(x):\n", | |
"    return jnp.outer(x, x)\n", | |
"\n", | |
"vector_function = jax.vmap(scalar_function)\n", | |
"vector_function(A)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4a861a6d-f6b2-415b-a3d7-a9eba66b6df8", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "4a861a6d-f6b2-415b-a3d7-a9eba66b6df8", | |
"outputId": "cb2d525b-6174-4e9d-818c-61077554e40b" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[[ 1.19428434e-01, 2.83938242e-01, 1.14193830e-01],\n", | |
" [ 2.83938242e-01, 6.75056374e-01, 2.71493097e-01],\n", | |
" [ 1.14193830e-01, 2.71493097e-01, 1.09188661e-01]],\n", | |
"\n", | |
" [[ 1.69821877e+00, -1.17982104e+00, -5.81696252e-01],\n", | |
" [-1.17982104e+00, 8.19669245e-01, 4.04127838e-01],\n", | |
" [-5.81696252e-01, 4.04127838e-01, 1.99250259e-01]],\n", | |
"\n", | |
" [[ 2.88318777e-01, -3.12033246e-01, -1.95758328e-01],\n", | |
" [-3.12033246e-01, 3.37698251e-01, 2.11859620e-01],\n", | |
" [-1.95758328e-01, 2.11859620e-01, 1.32913032e-01]],\n", | |
"\n", | |
" [[ 8.65139256e-02, 8.35990480e-03, 1.60806056e-01],\n", | |
" [ 8.35990480e-03, 8.07823801e-04, 1.55388084e-02],\n", | |
" [ 1.60806056e-01, 1.55388084e-02, 2.98895090e-01]],\n", | |
"\n", | |
" [[ 5.42364622e-01, 1.19975697e-01, 3.55058738e-01],\n", | |
" [ 1.19975697e-01, 2.65396512e-02, 7.85420322e-02],\n", | |
" [ 3.55058738e-01, 7.85420322e-02, 2.32439032e-01]]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 14 | |
} | |
], | |
"source": [ | |
"A[:, None, :] * A[:, :, None]" | |
] | |
}, | |
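When a function takes several arguments and only some of them should be batched, the `in_axes` argument of `jax.vmap` says which ones. A minimal sketch, assuming a made-up `weighted_sum` function and weight vector `w` that are not part of the original session:

```python
import jax
import jax.numpy as jnp
import numpy as np

A = np.random.default_rng(1).normal(size=(5, 3))
w = jnp.array([1.0, 2.0, 3.0])  # hypothetical weights shared by every row

def weighted_sum(row, weights):
    # Written for a single row of shape (3,).
    return jnp.dot(row, weights)

# Map over axis 0 of `A`; `None` means `weights` is passed unchanged to each call.
batched = jax.vmap(weighted_sum, in_axes=(0, None))
out = batched(A, w)
print(out.shape)
```

This gives the same numbers as the matrix-vector product `A @ w`, up to single-precision rounding.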
{ | |
"cell_type": "markdown", | |
"id": "a37c40b8-8a99-46db-860b-a04b2918b976", | |
"metadata": { | |
"id": "a37c40b8-8a99-46db-860b-a04b2918b976" | |
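}, | |
"source": [ | |
"## `jax.grad`\n", | |
"\n", | |
"Functions built from JAX operations can be differentiated: `jax.grad` takes a function and returns a new function that evaluates its derivative." | |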
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "fc0e710d-962b-44e7-8649-51b88e47c806", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "fc0e710d-962b-44e7-8649-51b88e47c806", | |
"outputId": "501f0dd5-9aee-4df1-c3d1-ff14e7e8bf15" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray(1.4174242, dtype=float32, weak_type=True)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 15 | |
} | |
], | |
"source": [ | |
"grad_function = jax.grad(jitted_function)\n", | |
"grad_function(0.5)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "fc51106f-1379-4b53-95f1-031a5d67264d", | |
"metadata": { | |
"id": "fc51106f-1379-4b53-95f1-031a5d67264d" | |
}, | |
"source": [ | |
"`jax.grad` only supports functions with a scalar output:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "c343dc80-4d7e-461b-b520-34dfadff76f2", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "c343dc80-4d7e-461b-b520-34dfadff76f2", | |
"outputId": "eb595fe6-4fda-4bc1-a331-b81f18737388" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
" with pytest.raises(TypeError) as info:\n", | |
"> jax.grad(jitted_function)(x)\n", | |
"E TypeError: Gradient only defined for scalar-output functions. Output had shape: (5,).\n", | |
"\n", | |
"<ipython-input-16-b967da8725c7>:4: TypeError\n" | |
] | |
} | |
], | |
"source": [ | |
"import pytest\n", | |
"\n", | |
"with pytest.raises(TypeError) as info:\n", | |
"    jax.grad(jitted_function)(x)\n", | |
"print(\"\\n\\n\".join(str(info.getrepr()).split(\"\\n\\n\")[-2:]))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6d2a14ce-d5bd-436f-801a-71242d1167f9", | |
"metadata": { | |
"id": "6d2a14ce-d5bd-436f-801a-71242d1167f9" | |
}, | |
"source": [ | |
"But we can combine `grad` with `vmap` to get the derivative at each input point:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "59edef16-1a86-4ce4-8ac6-5e470090c242", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "59edef16-1a86-4ce4-8ac6-5e470090c242", | |
"outputId": "fc867951-c5f6-4e37-88d2-3b98353b8767" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([ 1.0994664 , 0.6418517 , -1.4497899 , -0.4459506 ,\n", | |
" 0.10872913], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 17 | |
} | |
], | |
"source": [ | |
"jax.vmap(jax.grad(jitted_function))(x)" | |
] | |
}, | |
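As a sanity check, the derivative of `1.5 + exp(sin(x))` can be written by hand with the chain rule as `cos(x) * exp(sin(x))`. The sketch below redefines the function under the hypothetical name `f` so it is self-contained, and compares the automatic result against that formula.

```python
import jax
import jax.numpy as jnp

def f(x):
    return 1.5 + jnp.exp(jnp.sin(x))

def analytic_derivative(x):
    # Chain rule: d/dx exp(sin(x)) = cos(x) * exp(sin(x))
    return jnp.cos(x) * jnp.exp(jnp.sin(x))

x = jnp.linspace(0.1, 5.0, 5)
auto = jax.vmap(jax.grad(f))(x)
print(jnp.allclose(auto, analytic_derivative(x), atol=1e-5))
```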
{ | |
"cell_type": "markdown", | |
"id": "fb4d5bbf-4e23-4cd9-be05-d21bf89094fe", | |
"metadata": { | |
"id": "fb4d5bbf-4e23-4cd9-be05-d21bf89094fe" | |
}, | |
"source": [ | |
"Another useful function is `jax.value_and_grad`:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ce76be99-0778-4859-90cf-6fce4c1cf3cd", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ce76be99-0778-4859-90cf-6fce4c1cf3cd", | |
"outputId": "dc0ae4a5-810d-42f3-da39-e5a4a0623bec" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(DeviceArray([2.6049867, 4.1377964, 3.246622 , 2.053278 , 1.883305 ], dtype=float32),\n", | |
" DeviceArray([ 1.0994664 , 0.6418517 , -1.4497899 , -0.4459506 ,\n", | |
" 0.10872913], dtype=float32))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 18 | |
} | |
], | |
"source": [ | |
"jax.vmap(jax.value_and_grad(jitted_function))(x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ba508356-266d-49cb-b6a9-9fbc46c0b7bf", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 265 | |
}, | |
"id": "ba508356-266d-49cb-b6a9-9fbc46c0b7bf", | |
"outputId": "f43e9d52-701c-4309-fb89-352670f40497" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"x_grid = jnp.linspace(-5, 5, 100)\n", | |
"value, grad = jax.vmap(jax.value_and_grad(jitted_function))(x_grid)\n", | |
"plt.plot(x_grid, value, label=\"value\")\n", | |
"plt.plot(x_grid, grad, label=\"grad\")\n", | |
"plt.legend();" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@jax.jit\n", | |
"def f(x):\n", | |
"    # Equivalent to tanh(x)\n", | |
"    y = jnp.exp(-2.0 * x)\n", | |
"    return (1.0 - y) / (1.0 + y)\n", | |
"\n", | |
"dfdx = jax.grad(f)\n", | |
"d2fdx = jax.grad(dfdx)\n", | |
"d3fdx = jax.grad(d2fdx)\n", | |
"d4fdx = jax.grad(d3fdx)\n", | |
"\n", | |
"x = jnp.linspace(-4, 4, 200)\n", | |
"plt.plot(x, f(x), label=\"f\")\n", | |
"plt.plot(x, jax.vmap(dfdx)(x), label=\"f'\")\n", | |
"plt.plot(x, jax.vmap(d2fdx)(x), label=\"f''\")\n", | |
"plt.plot(x, jax.vmap(d3fdx)(x), label=\"f'''\")\n", | |
"plt.plot(x, jax.vmap(d4fdx)(x), label=\"f''''\")\n", | |
"plt.legend(frameon=False, loc='upper right')\n", | |
"plt.gca().axis('off')\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 248 | |
}, | |
"id": "EqDMScSD4mhD", | |
"outputId": "e60899b5-6523-43b9-cbc8-e059fa391567" | |
}, | |
"id": "EqDMScSD4mhD", | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "f0704757-2bd8-4439-a0f7-8d17dd1a138a", | |
"metadata": { | |
"id": "f0704757-2bd8-4439-a0f7-8d17dd1a138a" | |
}, | |
"source": [ | |
"## PyTrees\n", | |
"\n", | |
"Another useful JAX concept is the \"PyTree\": a nested structure of containers such as dicts, lists, and tuples whose leaves are arrays or scalars.\n", | |
"PyTrees let us pass structured inputs and still use `jit`, `vmap`, and `grad`.\n", | |
"For example:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "db182d2d-47c3-4f63-88cc-f0f316fb0ad8", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "db182d2d-47c3-4f63-88cc-f0f316fb0ad8", | |
"outputId": "32cf977c-757a-42fa-d6c7-af71c435c4fa" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray(0.02227585, dtype=float32, weak_type=True)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 20 | |
} | |
], | |
"source": [ | |
"def pytree_func(params):\n", | |
"    return jnp.exp(params[\"log_amp\"]) * jnp.sin(params[\"log_scale\"])\n", | |
"\n", | |
"params = {\n", | |
"    \"log_amp\": -1.5,\n", | |
"    \"log_scale\": 0.1,\n", | |
"}\n", | |
"pytree_func(params)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "59f7514e-4c74-4a64-9f56-bfec4469c41c", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "59f7514e-4c74-4a64-9f56-bfec4469c41c", | |
"outputId": "8b8e8341-ed82-487f-9387-da348e724b9a" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{'log_amp': DeviceArray(0.02227585, dtype=float32, weak_type=True),\n", | |
" 'log_scale': DeviceArray(0.22201544, dtype=float32, weak_type=True)}" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 21 | |
} | |
], | |
"source": [ | |
"jax.grad(pytree_func)(params)" | |
] | |
}, | |
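Because the gradient comes back with the same structure as the parameters, the two can be combined leaf by leaf with `jax.tree_util.tree_map`. A minimal sketch of one such update follows; the `step_size` value is a made-up number chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def pytree_func(params):
    return jnp.exp(params["log_amp"]) * jnp.sin(params["log_scale"])

params = {"log_amp": -1.5, "log_scale": 0.1}
grads = jax.grad(pytree_func)(params)

# Apply the same operation to every leaf; the result keeps the dict structure.
step_size = 0.1  # hypothetical value, chosen only for illustration
new_params = jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)

print(jax.tree_util.tree_leaves(params))
print(new_params)
```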
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Random numbers\n", | |
"\n", | |
"Random number generation in JAX works a little differently than in NumPy.\n", | |
"There is no global random state: every random function takes an explicit \"key\" as input, and the same key always produces the same output:" | |
], | |
"metadata": { | |
"id": "RFR4UqrmMoFQ" | |
}, | |
"id": "RFR4UqrmMoFQ" | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from jax import random\n", | |
"\n", | |
"key = random.PRNGKey(42)\n", | |
"random.normal(key)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "cTAI87MPMnvI", | |
"outputId": "8c93bd9c-31ab-4cfe-abb1-f5dd7cd04135" | |
}, | |
"id": "cTAI87MPMnvI", | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray(-0.18471177, dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 22 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0b6d64a9-6361-4f05-b399-29312c15ef31", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "0b6d64a9-6361-4f05-b399-29312c15ef31", | |
"outputId": "8644972a-d897-4877-c0be-e1dd64da92ef" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray(-0.18471177, dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 23 | |
} | |
], | |
"source": [ | |
"random.normal(key)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"If you want to generate multiple different random numbers, a good approach is to \"split\" the key." | |
], | |
"metadata": { | |
"id": "sOY38c39NQNN" | |
}, | |
"id": "sOY38c39NQNN" | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"key1, key2 = random.split(key)\n", | |
"random.normal(key1), random.uniform(key2)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "9v5Hob9_NEKM", | |
"outputId": "06b28e51-f6ed-40a5-8845-f95ad1a81292" | |
}, | |
"id": "9v5Hob9_NEKM", | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(DeviceArray(0.13790321, dtype=float32),\n", | |
" DeviceArray(0.91457367, dtype=float32))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 24 | |
} | |
] | |
}, | |
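Two related patterns are worth a quick sketch: splitting a key into more than two subkeys, and drawing a whole array from a single key with the `shape` argument. The variable names below are only illustrative.

```python
from jax import random

key = random.PRNGKey(42)

# Reusing a key reproduces the same draw.
same = random.normal(key) == random.normal(key)

# Split into three subkeys, one per consumer.
subkeys = random.split(key, 3)
draws = [random.normal(k) for k in subkeys]

# A single key can also fill a whole array through `shape`.
batch = random.normal(key, shape=(4,))
print(same, draws, batch.shape)
```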
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Optimizers\n", | |
"\n", | |
"The JAX ecosystem is pretty modular and there are various packages available for non-linear function optimization.\n", | |
"Some popular ones include [jaxopt](https://github.com/google/jaxopt) (\"scipy.optimize with support for PyTrees\") and [optax](https://github.com/deepmind/optax) (\"feature-rich framework with a lot more boilerplate\")." | |
], | |
"metadata": { | |
"id": "u01iPar9Wc8B" | |
}, | |
"id": "u01iPar9Wc8B" | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%pip install -q jaxopt optax" | |
], | |
"metadata": { | |
"id": "LPVdUpTUNLLU" | |
}, | |
"id": "LPVdUpTUNLLU", | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import jaxopt\n", | |
"import optax\n", | |
"\n", | |
"def loss(params):\n", | |
"    return jnp.sum(jnp.square(params[\"x\"]))\n", | |
"\n", | |
"params = {\"x\": 12.5}\n", | |
"opt = jaxopt.ScipyMinimize(fun=loss)\n", | |
"soln = opt.run(params)\n", | |
"print(soln)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "1oYP60kxXf1v", | |
"outputId": "e67eddb6-d155-4d25-f599-4607aef8d41e" | |
}, | |
"id": "1oYP60kxXf1v", | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"OptStep(params={'x': DeviceArray(4.7211597e-07, dtype=float32)}, state=ScipyMinimizeInfo(fun_val=DeviceArray(2.228935e-13, dtype=float32, weak_type=True), success=True, status=0, iter_num=2))\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"params = {\"x\": 12.5}\n", | |
"opt = optax.sgd(0.1)\n", | |
"opt_state = opt.init(params)\n", | |
"\n", | |
"@jax.jit\n", | |
"def train(params, opt_state):\n", | |
"    value, grads = jax.value_and_grad(loss)(params)\n", | |
"    updates, opt_state = opt.update(grads, opt_state)\n", | |
"    params = optax.apply_updates(params, updates)\n", | |
"    return value, params, opt_state\n", | |
"\n", | |
"losses = []\n", | |
"for _ in range(100):\n", | |
"    value, params, opt_state = train(params, opt_state)\n", | |
"    losses.append(value)\n", | |
"params" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "4FUgjJwXXuq8", | |
"outputId": "bf12e9a8-d442-414d-f23a-cb0693cca0b7" | |
}, | |
"id": "4FUgjJwXXuq8", | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{'x': DeviceArray(2.5462958e-09, dtype=float32)}" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 27 | |
} | |
] | |
}, | |
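To see where that final value comes from, here is a hand-rolled sketch of the same loop using only `jax`, assuming that plain SGD without momentum amounts to `params - learning_rate * grads`. The helper `sgd_step` is a hypothetical name and not part of optax.

```python
import jax
import jax.numpy as jnp

def loss(params):
    return jnp.sum(jnp.square(params["x"]))

@jax.jit
def sgd_step(params, learning_rate):
    grads = jax.grad(loss)(params)
    # params <- params - learning_rate * grads, applied leaf by leaf
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

params = {"x": 12.5}
for _ in range(100):
    params = sgd_step(params, 0.1)

# Each step maps x to x - 0.1 * 2x = 0.8 * x, so 100 steps give 12.5 * 0.8**100.
print(params["x"], 12.5 * 0.8**100)
```

The closed form `12.5 * 0.8**100` is roughly `2.5e-09`, in line with the optax result above.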
{ | |
"cell_type": "code", | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"plt.plot(losses)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 282 | |
}, | |
"id": "rF30CR93YdEP", | |
"outputId": "9c201c4c-163a-4b95-b542-e57bdf4e9aa4" | |
}, | |
"id": "rF30CR93YdEP", | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[<matplotlib.lines.Line2D at 0x7fbb7f52a710>]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 28 | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.9" | |
}, | |
"colab": { | |
"name": "intro-to-jax-part2.ipynb", | |
"provenance": [], | |
"collapsed_sections": [] | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |