{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "## JAX vmap\n",
        "\n",
        "This is the source material for a tweet thread I did recently: https://twitter.com/jakevdp/status/1612544608646606849\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/jakevdp/467da4f567d34c59c1f34559790ef85f)"
      ],
      "metadata": {
        "id": "4sEj41C8MWRj"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "Let's talk about JAX's vmap! It's a transformation that can automatically create vectorized, batched versions of your functions... but what exactly it does is sometimes misunderstood. So let's dig in!\n",
        "\n",
        "<img src=\"https://jax.readthedocs.io/en/latest/_static/jax_logo_250px.png\"/>\n",
        "<font size=6>\n",
        "\n",
        "```python\n",
        "from jax import vmap\n",
        "```\n",
        "\n",
        "</font>\n"
      ],
      "metadata": {
        "id": "C-rrr7PTnWf3"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "Suppose you've implemented a model that maps a vector input to a scalar output. As an example, here's a simple function similar to a single neuron in a neural net:"
      ],
      "metadata": {
        "id": "HSUIp2U_c-na"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.RandomState(8675309)  # PRNGenny\n",
        "W = rng.randn(3, 5)  # weights\n",
        "b = 1.0  # bias\n",
        "\n",
        "def model(v, W=W, b=b):\n",
        "  return jnp.tanh(W @ v + b).sum()"
      ],
      "metadata": {
        "id": "v_3lI5DxrCWL"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "This function accepts a single length-5 vector of inputs and outputs a scalar:"
      ],
      "metadata": {
        "id": "-qoIeQhztUVS"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "v = rng.randn(5)\n",
        "print(model(v))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ydrO9URuuN3O",
        "outputId": "124cef9d-4b3e-4d64-ec14-69bd73f491fd"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "2.0699806\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "Now, suppose you want to apply this model across a 2D array, where each row of the array is an input. Passing this batched data directly leads to an error:"
      ],
      "metadata": {
        "id": "obtKROhinUnQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# This tells IPython to print one-line summaries of exceptions.\n",
        "%xmode minimal"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RdETKwxQzmRa",
        "outputId": "be02c84c-a159-442e-d1bb-362f51e03825"
      },
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Exception reporting mode: Minimal\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "v_batch = rng.randn(4, 5)  # a batch of 4 input vectors\n",
        "model(v_batch)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 97
        },
        "id": "6DUvyosYuw9i",
        "outputId": "029301bb-da66-4f1e-d839-828a6abea55b"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "error",
          "ename": "ValueError",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 4 is different from 5)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "This error arises because our function is not defined in a way that can handle batched input. So what do we do? The easiest approach might be to use a simple Python list comprehension:"
      ],
      "metadata": {
        "id": "FDa9OAzMvAH7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "jnp.array([model(v) for v in v_batch])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NFWmiu3EvOSX",
        "outputId": "e55cb1c7-fa44-4042-d2c1-8f7855e403d6"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DeviceArray([-2.263083 , -1.4514356,  0.9401485,  2.9187164], dtype=float32)"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "This works, of course, but if you're familiar with NumPy-style computing in Python, you'll immediately recognize the problem: loops in Python are typically slow compared to the native vectorized operations offered by NumPy & JAX."
      ],
      "metadata": {
        "id": "GDc2O5BOu67p"
      }
    },
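    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "To make that cost concrete, here's a quick timing sketch on a larger batch (the batch size here is arbitrary, and exact numbers will vary by machine). Each loop iteration dispatches a separate JAX computation, so the runtime grows linearly with the number of rows:"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# A larger batch makes the per-iteration Python overhead visible.\n",
        "big_batch = rng.randn(1000, 5)\n",
        "\n",
        "%timeit jnp.array([model(v) for v in big_batch])"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },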
    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "In the old days, you'd have to rewrite your model to explicitly accept batched data. This sometimes takes some thought; for example, here the simple matrix product becomes an Einstein summation:"
      ],
      "metadata": {
        "id": "d8fy5VFSv8Vs"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def batched_model(v_batch, W=W, b=b):\n",
        "  # Here are the dimensions for the batched matrix product:\n",
        "  #   W: (m, k)\n",
        "  #   v_batch: (n_batches, k)\n",
        "  #   output: (n_batches, m)\n",
        "  return jnp.tanh(jnp.einsum(\"mk,nk->nm\", W, v_batch) + b).sum(1)\n",
        "\n",
        "# Results should match!\n",
        "print(jnp.array([model(v) for v in v_batch]))\n",
        "print(batched_model(v_batch))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GHZxQ4DlwIHE",
        "outputId": "e4dd3798-17da-4215-e347-58e938d004cd"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[-2.263083  -1.4514356  0.9401485  2.9187164]\n",
            "[-2.263083  -1.4514352  0.9401484  2.9187164]\n"
          ]
        }
      ]
    },
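    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "(A quick sanity check on that einsum, in case the notation is unfamiliar: this particular signature contracts over k, so it's just a matmul against the transposed weights. The equivalence won't always be this simple for more complex models, which is exactly why manual batching takes thought.)"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# einsum(\"mk,nk->nm\", W, v_batch) is equivalent to v_batch @ W.T here.\n",
        "print(jnp.tanh(v_batch @ W.T + b).sum(1))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },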
    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "As models get more complex, this sort of manual batchification can be complicated and error-prone. This is where jax.vmap comes in: it can transform your function into an efficient and correct batched version automatically!"
      ],
      "metadata": {
        "id": "aepo4NQHwt3H"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from jax import vmap\n",
        "\n",
        "print(batched_model(v_batch))  # manual batching\n",
        "print(vmap(model)(v_batch))  # automatic batching!"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZrFT2m7DxxEK",
        "outputId": "df9a3130-7f35-46ce-be91-524857889481"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[-2.263083  -1.4514352  0.9401484  2.9187164]\n",
            "[-2.263083  -1.4514351  0.9401484  2.9187164]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "You might now ask which approach is more efficient: surely vmap must come at a cost? In most cases, however, vmap produces operations virtually identical to those of the manual implementation, which we can see by printing the jaxpr (JAX's internal function representation) for each."
      ],
      "metadata": {
        "id": "AzHxQrUkyAFV"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "jax.make_jaxpr(batched_model)(v_batch)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uyhwf0NOzu2O",
        "outputId": "a772259f-ddbb-4391-93e8-12b72e27ba9d"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{ lambda a:f32[3,5]; b:f32[4,5]. let\n",
              "    c:f32[4,3] = xla_call[\n",
              "      call_jaxpr={ lambda ; d:f32[3,5] e:f32[4,5]. let\n",
              "          f:f32[4,3] = dot_general[\n",
              "            dimension_numbers=(((1,), (1,)), ((), ()))\n",
              "            precision=None\n",
              "            preferred_element_type=None\n",
              "          ] e d\n",
              "        in (f,) }\n",
              "      name=_einsum\n",
              "    ] a b\n",
              "    g:f32[4,3] = add c 1.0\n",
              "    h:f32[4,3] = tanh g\n",
              "    i:f32[4] = reduce_sum[axes=(1,)] h\n",
              "  in (i,) }"
            ]
          },
          "metadata": {},
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "jax.make_jaxpr(vmap(model))(v_batch)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QO4yt0ahywiB",
        "outputId": "c82a8ea8-e38a-4e43-8ad8-9cc07792e5d7"
      },
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{ lambda a:f32[3,5]; b:f32[4,5]. let\n",
              "    c:f32[3,4] = dot_general[\n",
              "      dimension_numbers=(((1,), (1,)), ((), ()))\n",
              "      precision=None\n",
              "      preferred_element_type=None\n",
              "    ] a b\n",
              "    d:f32[3,4] = add c 1.0\n",
              "    e:f32[3,4] = tanh d\n",
              "    f:f32[4] = reduce_sum[axes=(0,)] e\n",
              "  in (f,) }"
            ]
          },
          "metadata": {},
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "The details differ slightly (for example, the xla_call comes from the fact that einsum is JIT-compiled), but the essential steps in the computation match more-or-less exactly: dot_general, then add, then tanh, then reduce_sum.\n",
        "\n",
        "<pre>\n",
        "{ lambda a:f32[3,5]; b:f32[4,5]. let                      { lambda a:f32[3,5]; b:f32[4,5]. let\n",
        "    c:f32[4,3] = xla_call[                                    c:f32[3,4] = <mark>dot_general</mark>[\n",
        "      call_jaxpr={ lambda ; d:f32[3,5] e:f32[4,5]. let          dimension_numbers=(((1,), (1,)), ((), ()))\n",
        "          f:f32[4,3] = <mark>dot_general</mark>[                             precision=None\n",
        "            dimension_numbers=(((1,), (1,)), ((), ()))          preferred_element_type=None\n",
        "            precision=None                                    ] a b\n",
        "            preferred_element_type=None                       d:f32[3,4] = <mark>add</mark> c 1.0\n",
        "          ] e d                                               e:f32[3,4] = <mark>tanh</mark> d\n",
        "        in (f,) }                                             f:f32[4] = <mark>reduce_sum</mark>[axes=(0,)] e\n",
        "      name=_einsum                                          in (f,) }\n",
        "    ] a b\n",
        "    g:f32[4,3] = <mark>add</mark> c 1.0\n",
        "    h:f32[4,3] = <mark>tanh</mark> g\n",
        "    i:f32[4] = <mark>reduce_sum</mark>[axes=(1,)] h\n",
        "  in (i,) }\n",
        "</pre>"
      ],
      "metadata": {
        "id": "hupLvslAz8o2"
      }
    },
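    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "If you'd like to confirm this empirically, here's a rough timing sketch. Both versions are wrapped in jax.jit so we time the compiled computation rather than per-call tracing overhead, and block_until_ready() accounts for JAX's asynchronous dispatch. Exact numbers will vary by machine, but the two should be nearly identical:"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "batched_jit = jax.jit(batched_model)\n",
        "vmapped_jit = jax.jit(vmap(model))\n",
        "\n",
        "# Trigger compilation once before timing.\n",
        "_ = batched_jit(v_batch)\n",
        "_ = vmapped_jit(v_batch)\n",
        "\n",
        "%timeit batched_jit(v_batch).block_until_ready()\n",
        "%timeit vmapped_jit(v_batch).block_until_ready()"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },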
    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "And this is what jax.vmap gives you: a way to automatically create efficient batched versions of your functions, lowering to fast vectorized computations, without having to rewrite your code by hand.\n",
        "\n",
        "You can read more about vmap and related transforms in the JAX docs: https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html"
      ],
      "metadata": {
        "id": "yVbYunFrddch"
      }
    },
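    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "One last sketch before you go: by default, vmap maps over the leading axis of every positional argument. Its in_axes argument lets you control this, e.g. to broadcast the weights and bias rather than map over them. (model2 below is just a positional-argument variant of our model, defined here for illustration.)"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# A positional-argument version of the model, so that in_axes\n",
        "# can refer to each argument explicitly.\n",
        "def model2(v, W, b):\n",
        "  return jnp.tanh(W @ v + b).sum()\n",
        "\n",
        "# Map over axis 0 of v_batch; broadcast W and b unchanged.\n",
        "print(vmap(model2, in_axes=(0, None, None))(v_batch, W, b))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    }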
  ]
} |