Skip to content

Instantly share code, notes, and snippets.

@MarcoGorelli
Last active November 7, 2021 09:07
Show Gist options
  • Save MarcoGorelli/6cbbca162a40597d3b81792629717cad to your computer and use it in GitHub Desktop.
Save MarcoGorelli/6cbbca162a40597d3b81792629717cad to your computer and use it in GitHub Desktop.
analytical solution
def inv_g(x_tilde):
    """Evaluate the inverse of `g` at a single point.

    The first coordinate is sent through the logit function and the second
    through the natural logarithm. Any entries past the second are ignored.

    Args:
        x_tilde: Indexable with at least two entries. It is used here as one
            length-2 point; batching is done with `jax.vmap` by the caller.

    Returns:
        A `jnp` array ``[logit(x_tilde[0]), log(x_tilde[1])]``.
    """
    first, second = x_tilde[0], x_tilde[1]
    # logit undoes a sigmoid and log undoes exp. Presumably `g` applies
    # sigmoid / exp coordinate-wise (g is defined elsewhere; confirm).
    return jnp.stack([jax.scipy.special.logit(first), jnp.log(second)])
# Evaluation grid in the transformed ("tilde") space: 1000 samples per axis,
# stored as a (1000, 2) array whose columns are the two coordinates.
# The endpoints stay strictly inside (0, 1) for the first coordinate and
# strictly above 0 for the second, so logit / log in `inv_g` stay finite.
x_tilde = jnp.stack(
    [
        jnp.linspace(0.001, 0.999, 1000),  # first coordinate, inside (0, 1)
        jnp.linspace(0.001, 3, 1000),  # second coordinate, positive
    ],
    axis=1,
)
# Pull every row back through `inv_g`; the result keeps the (1000, 2) layout.
pre_x_tilde = jax.vmap(inv_g, in_axes=0)(x_tilde)
@functools.partial(jax.vmap, in_axes=(0, None))
@functools.partial(jax.vmap, in_axes=(None, 0))
def probability_density(x_0, x_1):
    """Evaluate the density of `distribution` on the grid ``x_0 x x_1``.

    The two stacked `jax.vmap` decorators turn this scalar-pair function
    into an outer-product evaluation. The outer map runs over ``x_0`` and
    the inner over ``x_1``, so entry ``[i, j]`` of the result is the density
    at the point ``(x_0[i], x_1[j])``. For 1-D inputs of lengths m and n the
    result is (m, n), assuming ``log_prob`` returns a scalar per point.

    `distribution` is a module-level object defined elsewhere. It is assumed
    to expose ``log_prob`` taking a length-2 point (TODO confirm).
    """
    point = jnp.stack([x_0, x_1])
    log_density = distribution.log_prob(point)
    return jnp.exp(log_density)
@functools.partial(jax.vmap, in_axes=(0, None))
@functools.partial(jax.vmap, in_axes=(None, 0))
def inv_det_jacobian_g(x_0, x_1):
    """
    Evaluate ``1 / det(J_g)`` on the grid ``x_0 x x_1``.

    As with the stacked `jax.vmap` decorators used for the density, entry
    ``[i, j]`` of the result corresponds to the point ``(x_0[i], x_1[j])``.
    The Jacobian of `g` (defined elsewhere) is obtained with
    `jax.jacobian` at that point.

    The determinant is returned with its sign; no absolute value is taken
    here.
    """
    point = jnp.stack([x_0, x_1])
    jacobian_matrix = jax.jacobian(g)(point)
    determinant = jnp.linalg.det(jacobian_matrix)
    return 1 / determinant
# Density of the pushforward of `distribution` under `g`, via the
# change-of-variables formula
#     p_Y(y) = p_X(x) / |det J_g(x)|,   with x = g^{-1}(y).
# Both factors are evaluated on the grid built from the columns of
# `pre_x_tilde`, so entry [i, j] belongs to the pre-image point
# (pre_x_tilde[i, 0], pre_x_tilde[j, 1]). That point is the pre-image of the
# grid point (x_tilde[i, 0], x_tilde[j, 1]) because `inv_g` transforms each
# coordinate independently.
# The formula needs the *absolute* Jacobian determinant. Without `jnp.abs`,
# an orientation-reversing `g` (negative determinant) would produce a
# negative "density".
pushforward_density = jnp.abs(
    inv_det_jacobian_g(pre_x_tilde[:, 0], pre_x_tilde[:, 1])
) * probability_density(pre_x_tilde[:, 0], pre_x_tilde[:, 1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment