{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX/TF eager autodiff compatibility.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "mDVvqRieQy8q",
"colab_type": "text"
},
"source": [
"# Proof of concept for JAX/TF eager autodiff compatibility\n", | |
"\n", | |
"Author: [email protected]\n", | |
"\n", | |
"Date: April 16, 2020\n", | |
"\n", | |
"The wrapped functions compose well with JAX's autodiff system!\n", | |
"\n", | |
"Limitations (both directions):\n", | |
"\n", | |
"- Only supports TF 2, not TF 1 yet (but that should also be straightforward)\n", | |
"- Only supports a single input/output array for now, not arbitrary function values.\n", | |
"\n", | |
"Limitations for TF in JAX:\n", | |
"- Doesn't support `jit`. This would need support for [wrapping Python functions](https://github.com/google/jax/issues/766) via XLA's CustomCall.\n", | |
"- Doesn't support other transformations like `vmap`. Conceivably if we implemented this via a JAX Primitive instead, we could define a batching rule with [`tf.vectorized_map`](https://www.tensorflow.org/api_docs/python/tf/vectorized_map).\n", | |
"\n", | |
"Current rough edges:\n", | |
"\n", | |
"- Neither TF nor JAX directly converting in-memory arrays into the library directly, so we do everything via NumPy. Ideally this would work transparently and would also transfer arrays directly on device.\n", | |
"- To use TF's auto-diff need to pass a closure between from the forward to backwards passes, but `custom_vjp` insists on auxiliary outputs being pytrees. So we lie and wrap the closure in `tree_util.Partial`.\n" | |
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XhYoqxRldO5j",
"colab_type": "text"
},
"source": [
"## License"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ltR5KBWLdOlM",
"colab_type": "text"
},
"source": [
"```\n",
"Copyright 2020 Google LLC\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
"    https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License.\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TjXNT6k1Vb53",
"colab_type": "text"
},
"source": [
"## Wrap TF in JAX"
]
},
{
"cell_type": "code",
"metadata": {
"id": "3hOyB6bxfrSt",
"colab_type": "code",
"colab": {}
},
"source": [
"! pip install -q -U jaxlib jax"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "zHWbqUl3Qu_I",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "3a25e300-311b-42b8-f41e-54e7a8b8c8a6"
},
"source": [
"%tensorflow_version 2.x\n",
"import jax\n",
"import tensorflow as tf\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from functools import partial\n",
"\n",
"def as_jax(x):\n",
"  # Convert to a JAX array by round-tripping through NumPy.\n",
"  return jnp.asarray(np.asarray(x))\n",
"\n",
"def as_tf(x):\n",
"  # Convert to a TF tensor by round-tripping through NumPy.\n",
"  return tf.convert_to_tensor(np.asarray(x))\n",
"\n",
"def wrap_tf_in_jax(tf_func):\n",
"  # Expose a TF eager function to JAX, with a custom VJP rule that defers\n",
"  # the backward pass to tf.GradientTape.\n",
"  @jax.custom_vjp  # requires latest JAX release\n",
"  def f(x):\n",
"    return as_jax(tf_func(as_tf(x)))\n",
"  def f_fwd(x):\n",
"    with tf.GradientTape() as tape:\n",
"      x = as_tf(x)\n",
"      tape.watch(x)\n",
"      y = tf_func(x)\n",
"    # Wrap the TF closure in tree_util.Partial so custom_vjp accepts it\n",
"    # as a residual output.\n",
"    vjp_func = jax.tree_util.Partial(partial(tape.gradient, y, x))\n",
"    return as_jax(y), vjp_func\n",
"  def f_rev(vjp_func, ct_y):\n",
"    ct_x = vjp_func(as_tf(ct_y))\n",
"    return (as_jax(ct_x),)\n",
"  f.defvjp(f_fwd, f_rev)\n",
"  return f\n",
"\n",
"x = jnp.arange(3.0)\n",
"wrapped_sum = wrap_tf_in_jax(tf.reduce_sum)\n",
"np.testing.assert_allclose(wrapped_sum(x), jnp.sum(x))\n",
"np.testing.assert_allclose(jax.grad(wrapped_sum)(x), jax.grad(jnp.sum)(x))\n",
"\n",
"wrapped_square = wrap_tf_in_jax(tf.square)\n",
"def tf_and_jax(x):\n",
"  return wrapped_square(x).sum()\n",
"def jax_only(x):\n",
"  return (x ** 2).sum()\n",
"np.testing.assert_allclose(tf_and_jax(x), jax_only(x))\n",
"np.testing.assert_allclose(jax.grad(tf_and_jax)(x), jax.grad(jax_only)(x))\n"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.\n",
"  warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
],
"name": "stderr"
}
]
},
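{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick extra sanity check (a sketch, assuming the cell above has been run): since `custom_vjp` supplies a complete VJP rule, the wrapped TF function should also compose with `jax.vjp`, not just `jax.grad`. The choice of `tf.sin` below is only an illustrative nonlinear example, not part of the original proof of concept.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Illustrative extra check (not from the original notebook): compose the\n",
"# wrapper with jax.grad and jax.vjp, using tf.sin as an arbitrary nonlinear op.\n",
"wrapped_sin = wrap_tf_in_jax(tf.sin)\n",
"x = jnp.arange(3.0)\n",
"np.testing.assert_allclose(wrapped_sin(x), jnp.sin(x), rtol=1e-6)\n",
"np.testing.assert_allclose(\n",
"    jax.grad(lambda x: wrapped_sin(x).sum())(x), jnp.cos(x), rtol=1e-6)\n",
"# jax.vjp also works, since custom_vjp defines the full backward rule.\n",
"y, vjp_fn = jax.vjp(wrapped_sin, x)\n",
"ct_x, = vjp_fn(jnp.ones_like(y))\n",
"np.testing.assert_allclose(ct_x, jnp.cos(x), rtol=1e-6)\n"
],
"execution_count": 0,
"outputs": []
},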
{
"cell_type": "markdown",
"metadata": {
"id": "zp8ANLTRVeYE",
"colab_type": "text"
},
"source": [
"## Wrap JAX in TF"
]
},
{
"cell_type": "code",
"metadata": {
"id": "AS3chbsLQ1Cj",
"colab_type": "code",
"colab": {}
},
"source": [
"def wrap_jax_in_tf(jax_func):\n",
"  # Expose a JAX function to TF eager mode; tf.custom_gradient defers the\n",
"  # backward pass to jax.vjp.\n",
"  @tf.custom_gradient\n",
"  def f(x):\n",
"    y, jax_vjp_fn = jax.vjp(jax_func, as_jax(x))\n",
"    def tf_vjp_fn(ct_y):\n",
"      # jax_vjp_fn returns a tuple with one cotangent per argument.\n",
"      ct_x, = as_tf(jax_vjp_fn(as_jax(ct_y)))\n",
"      return ct_x\n",
"    return as_tf(y), tf_vjp_fn\n",
"  return f\n",
"\n",
"def tf_grad(f):\n",
"  # Convenience wrapper: compute gradients of f with tf.GradientTape.\n",
"  def f2(x):\n",
"    with tf.GradientTape() as g:\n",
"      g.watch(x)\n",
"      y = f(x)\n",
"    return g.gradient(y, x)\n",
"  return f2\n",
"\n",
"x = tf.range(3.0)\n",
"wrapped_sum = wrap_jax_in_tf(jnp.sum)\n",
"np.testing.assert_allclose(wrapped_sum(x), tf.reduce_sum(x))\n",
"np.testing.assert_allclose(tf_grad(wrapped_sum)(x), tf_grad(tf.reduce_sum)(x))\n",
"\n",
"wrapped_square = wrap_jax_in_tf(jnp.square)\n",
"def tf_and_jax(x):\n",
"  return tf.reduce_sum(wrapped_square(x))\n",
"def tf_only(x):\n",
"  return tf.reduce_sum(x ** 2)\n",
"np.testing.assert_allclose(tf_and_jax(x), tf_only(x))\n",
"np.testing.assert_allclose(tf_grad(tf_and_jax)(x), tf_grad(tf_only)(x))\n"
],
"execution_count": 0,
"outputs": []
},
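{
"cell_type": "markdown",
"metadata": {},
"source": [
"A symmetric sketch for this direction (assuming the `wrap_jax_in_tf` and `tf_grad` definitions above have been run): the gradient of a wrapped nonlinear JAX function should agree with the corresponding native TF op. The choice of `jnp.sin` is only illustrative and not part of the original proof of concept.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Illustrative extra check (not from the original notebook): gradients of a\n",
"# wrapped nonlinear JAX function agree with the native TF op.\n",
"wrapped_sin = wrap_jax_in_tf(jnp.sin)\n",
"x = tf.range(3.0)\n",
"np.testing.assert_allclose(wrapped_sin(x), tf.sin(x), rtol=1e-6)\n",
"np.testing.assert_allclose(tf_grad(wrapped_sin)(x), tf.cos(x), rtol=1e-6)\n"
],
"execution_count": 0,
"outputs": []
}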
]
} |