Excerpt from this issue: patrick-kidger/diffrax#202
Created October 8, 2023 09:38
Save llandsmeer/d11db31cbfe8a1cf6ddf44127aad8308 to your computer and use it in GitHub Desktop.
Make jax.lax.cond work in jax2tf
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import onnxruntime as rt | |
def patch_jax2tf_with_case_instead_of_switch_case():
    """Re-register jax2tf's lowering of ``jax.lax.cond_p`` to use ``tf.case``.

    As the name says, this replaces a ``tf.switch_case``-based lowering.
    Presumably the motivation is that ``tf.switch_case`` does not export
    cleanly to ONNX (see the issue referenced at the top); TODO confirm.

    The replacement emits one ``(tf.equal(i, index), branch_fn)`` pair per
    branch and dispatches with ``tf.case(..., exclusive=True)``.

    This mutates the module-level ``jax2tf.tf_impl`` table, so it affects
    every later jax2tf conversion in this process. Call it before converting.
    """
    from functools import partial

    import tensorflow as tf
    import jax
    from jax.experimental.jax2tf import jax2tf

    def _cond(index, *operands, branches, linear):
        del linear  # unused by this lowering
        # tf.case takes zero-argument callables. Bind each branch's jaxpr and
        # index eagerly with functools.partial. A bare ``lambda`` inside the
        # comprehension would capture the loop variables late. tf.case only
        # calls the thunks after the comprehension has finished, so every
        # branch would then trace the *last* jaxpr under the last branch name.
        pred_fn_pairs = [
            (
                tf.equal(i, index),
                partial(
                    jax2tf._interpret_jaxpr,
                    jaxpr,
                    *operands,
                    extra_name_stack=f'branch_{i}_fun',
                ),
            )
            for i, jaxpr in enumerate(branches)
        ]
        return tf.case(pred_fn_pairs, exclusive=True)

    jax2tf.tf_impl[jax.lax.cond_p] = _cond
def build_onnx_model():
    """Export a small diffrax ODE solve to ONNX via ``equinox.internal.to_onnx``.

    The solve integrates dy/dt = -y from t0=0 to t1=1 with the explicit
    Euler solver and a step of 0.1. The exporter is called with the example
    initial condition ``1.0``.

    Returns:
        The first element of the pair produced by the exporter (the ONNX
        model). The second element is discarded; the original name
        ``_none`` suggests it is None, which is unverified.
    """
    import equinox.internal as eqxi
    from diffrax import Euler, ODETerm, diffeqsolve

    def vector_field(t, y, args):
        # Right-hand side of the decay ODE; t and args are ignored.
        return -y

    def simulate(y0):
        sol = diffeqsolve(
            terms=ODETerm(vector_field),
            solver=Euler(),
            t0=0,
            t1=1,
            dt0=0.1,
            y0=y0,
        )
        # First saved state. Presumably y(t1) under diffrax's default
        # SaveAt; confirm.
        return sol.ys[0]

    export = eqxi.to_onnx(simulate)
    onnx_model, _unused = export(1.0)
    return onnx_model
# Swap in the tf.case lowering of lax.cond before anything is traced.
patch_jax2tf_with_case_instead_of_switch_case()

onnx_model = build_onnx_model()
sess = rt.InferenceSession(onnx_model.SerializeToString())

# Feed a 0-d float32 array to the model's first input and keep the first output.
input_name = sess.get_inputs()[0].name
onnx_output = sess.run(
    None,
    {input_name: np.array(100.0, dtype='float32')},
)[0]
print(onnx_output)

# The exact solution of dy/dt = -y with y(0) = 100 is y(1) = 100 / e. Euler
# with dt0=0.1 only approximates it, hence the loose tolerances.
assert np.isclose(onnx_output, 100 * np.exp(-1), rtol=0.1, atol=0.1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.