Skip to content

Instantly share code, notes, and snippets.

@josephrocca
Last active October 17, 2022 08:15
Show Gist options
  • Save josephrocca/5c847d793eed1efb67a6566b616eb467 to your computer and use it in GitHub Desktop.
Save josephrocca/5c847d793eed1efb67a6566b616eb467 to your computer and use it in GitHub Desktop.
simple-jax2tf-hello-world-example.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyNX53M17VAZNLWhQl9nqDFj",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/josephrocca/5c847d793eed1efb67a6566b616eb467/simple-jax2tf-hello-world-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KFkH90O4qoeh"
},
"outputs": [],
"source": [
"# Use the `%pip` magic (not `!pip`) so packages install into the running kernel's environment.\n",
"# NOTE(review): versions are unpinned; consider pinning jax/jaxlib (and tensorflow) so this repro stays stable. TODO pick versions.\n",
"%pip install --upgrade jax jaxlib"
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import tensorflow as tf\n",
"from jax.experimental import jax2tf"
],
"metadata": {
"id": "Kd-cBuQCqu7P"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Example 1: scalar float32 function, converted with jax2tf and saved with custom gradients enabled.\n",
"def sincos(x):\n",
"    # Composition of two JAX ops: sin(cos(x)).\n",
"    return jnp.sin(jnp.cos(x))\n",
"\n",
"scalar_signature = [tf.TensorSpec([], tf.float32)]\n",
"sincos_tf = jax2tf.convert(sincos)\n",
"\n",
"my_model = tf.Module()\n",
"my_model.f = tf.function(\n",
"    sincos_tf,\n",
"    autograph=False,\n",
"    jit_compile=True,\n",
"    input_signature=scalar_signature,\n",
")\n",
"# Optional sanity check before saving:\n",
"# print(my_model.f(0.53))\n",
"\n",
"save_options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)\n",
"# <-- Saving this produces a strange graph with no output node\n",
"tf.saved_model.save(my_model, './sincos', options=save_options)"
],
"metadata": {
"id": "CH1NkSY2repG"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Example 2: two named 3x3 float32 inputs, converted with jax2tf without gradient support.\n",
"def multiply(a, b):\n",
"    # Element-wise product of the two inputs.\n",
"    return a * b\n",
"\n",
"matrix_signature = [\n",
"    tf.TensorSpec([3, 3], tf.float32, name=\"a\"),\n",
"    tf.TensorSpec([3, 3], tf.float32, name=\"b\"),\n",
"]\n",
"multiply_tf = jax2tf.convert(multiply, with_gradient=False)\n",
"\n",
"my_model = tf.Module()\n",
"my_model.f = tf.function(\n",
"    multiply_tf,\n",
"    autograph=False,\n",
"    jit_compile=True,\n",
"    input_signature=matrix_signature,\n",
")\n",
"# Optional sanity check before saving (inputs match the float32 signature):\n",
"# print(my_model.f(np.ones([3, 3], np.float32), np.ones([3, 3], np.float32)))\n",
"\n",
"# <-- Saving this produces a strange graph with no output node\n",
"tf.saved_model.save(my_model, './multiply')"
],
"metadata": {
"id": "tmvlJPFjj6nj"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Conversion to TFLite directly from JAX (via `experimental_from_jax`) works correctly:\n",
"example_input = np.ones([3, 3], dtype='float32')\n",
"converter = tf.lite.TFLiteConverter.experimental_from_jax(\n",
"    [multiply],\n",
"    [[('a', example_input), ('b', example_input)]],\n",
")\n",
"tflite_model = converter.convert()\n",
"\n",
"# Write the serialized flatbuffer bytes to disk.\n",
"with open('multiply.tflite', 'wb') as tflite_file:\n",
"    tflite_file.write(tflite_model)"
],
"metadata": {
"id": "IyYApdCTu_oZ"
},
"execution_count": 5,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment