@manuel-delverme
Created August 30, 2021 00:19
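import jax
import jax.numpy
import jax.scipy.special  # import explicitly; `import jax` alone may not load jax.scipy

# Raise an error as soon as any operation produces a NaN.
jax.config.update("jax_debug_nans", True)

# Run op-by-op (no jit) so the NaN check is applied to each primitive.
with jax.disable_jit():
    x0 = jax.numpy.zeros(1)
    # Mathematically finite: logsumexp([0.]) = log(1) = 0.0, gradient = softmax = [1.0].
    jax.value_and_grad(lambda x: jax.scipy.special.logsumexp(x))(x0)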
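For comparison, the quantity being differentiated has a closed form: logsumexp(x) = m + log(sum_i exp(x_i - m)) with m = max(x), and its gradient is softmax(x). The sketch below is a plain-Python reference written for illustration; `logsumexp_and_grad` is a hypothetical helper, not part of JAX, and it assumes a non-empty list of finite floats. At the repro's input it gives a finite value and gradient, so no NaN is expected mathematically.

```python
import math
from typing import List, Tuple


def logsumexp_and_grad(xs: List[float]) -> Tuple[float, List[float]]:
    """Hypothetical reference helper (not part of JAX).

    Returns log(sum(exp(x_i))) and its gradient softmax(x), using the
    max-shift trick so exp() never overflows. Assumes a non-empty list
    of finite floats (an all -inf input would make x - m undefined).
    """
    m = max(xs)
    # Shift by the max: every exponent is <= 0, so each term lies in (0, 1].
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    value = m + math.log(s)
    # d/dx_i logsumexp(x) = exp(x_i - m) / s, i.e. softmax(x)_i.
    grad = [e / s for e in exps]
    return value, grad


# Same input as the repro above: a single zero.
value, grad = logsumexp_and_grad([0.0])
print(value, grad)  # 0.0 [1.0]
```

The max shift is the usual way to keep logsumexp numerically stable; the gradient reuses the shifted exponentials, so no extra exp() calls are needed.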