- Use `jax.tree_util.Partial` to pass partially-applied functions to JIT-compiled code.
- Use `static_argnames`, not `static_argnums`, whenever possible.
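The first lesson can be illustrated with a minimal sketch. The helper names `scale_and_shift` and `apply_fn` are hypothetical and exist only for this example. `jax.tree_util.Partial` is registered as a pytree: the wrapped function is stored as static metadata, and the bound arguments become pytree leaves. A jitted function can therefore accept it as an ordinary argument. A plain `functools.partial` is not a valid JAX type, so passing it as a non-static argument fails with a `TypeError`.

```python
import functools

import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def scale_and_shift(x, scale, shift):
    # Hypothetical helper used only for illustration.
    return x * scale + shift


@jax.jit
def apply_fn(fn, x):
    # `fn` is a regular (non-static) argument. Because Partial is a pytree,
    # jit flattens it: the bound values are traced, the function is static.
    return fn(x)


x = jnp.arange(3.0)

# Bind `scale` and `shift`; they become pytree leaves of `f`.
f = Partial(scale_and_shift, scale=2.0, shift=1.0)
result = apply_fn(f, x)  # values: [1., 3., 5.]
print(result)

# A plain functools.partial is treated as a single opaque leaf, which jit
# cannot interpret as an array, so the call raises a TypeError.
g = functools.partial(scale_and_shift, scale=2.0, shift=1.0)
try:
    apply_fn(g, x)
    rejected = False
except TypeError:
    rejected = True
print("functools.partial rejected by jit:", rejected)
```

Because the bound values are leaves rather than part of the cache key, rebinding them to new values with the same shape and dtype should reuse the compiled function; only a different wrapped function changes the tree structure and triggers a retrace.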
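For the second lesson, here is a sketch comparing the two ways of marking an argument static. The function names `repeat_double` and `repeat_double_by_index` are hypothetical. With `static_argnames`, `jax.jit` uses the function signature to find the matching positional index, so callers can pass `n_steps` either positionally or by keyword. One commonly cited reason to prefer names is robustness: with `static_argnums`, inserting a new parameter before `n_steps` would silently shift which argument index 1 refers to.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Static by name: survives signature changes and reads clearly.
@partial(jax.jit, static_argnames=("n_steps",))
def repeat_double(x, n_steps):
    # `n_steps` is a concrete Python int at trace time, so a Python loop
    # is unrolled during tracing.
    for _ in range(n_steps):
        x = x * 2
    return x


# Equivalent, but position-based: index 1 must be kept in sync by hand.
@partial(jax.jit, static_argnums=(1,))
def repeat_double_by_index(x, n_steps):
    for _ in range(n_steps):
        x = x * 2
    return x


x = jnp.ones(3)
by_keyword = repeat_double(x, n_steps=3)
by_position = repeat_double(x, 3)  # also treated as static
by_index = repeat_double_by_index(x, 3)
print(by_keyword, by_position, by_index)
```

Each distinct value of a static argument compiles a separate version of the function, so static arguments are best reserved for values such as loop counts or flags that take few distinct values.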
Created June 23, 2025 22:06
Gist: yberreby/79850710c020dbe8d36828c54cde8e82
JAX Lessons Learned