Skip to content

Instantly share code, notes, and snippets.

Last active July 13, 2021 16:11
Show Gist options
  • Save dfm/6b7f9167ce08284721d8c11c4019c12a to your computer and use it in GitHub Desktop.
Save dfm/6b7f9167ce08284721d8c11c4019c12a to your computer and use it in GitHub Desktop.
Noncentral chi-squared distribution
import tensorflow_probability.substrates.jax as tfp
import jax.numpy as jnp
import jax.scipy as jsp
from jax import lax
import jax.random as random
from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample
def _random_chi2(key, df, shape=(), dtype=float):
    """Draw chi-squared variates with ``df`` degrees of freedom.

    Uses the identity ``Chi2(df) == 2 * Gamma(df / 2)`` with a unit-scale
    gamma distribution.

    Args:
        key: JAX PRNG key.
        df: Degrees of freedom (scalar or array broadcastable to ``shape``).
        shape: Output shape forwarded to ``jax.random.gamma``.
        dtype: Floating dtype forwarded to ``jax.random.gamma``. The default
            ``float`` matches ``jax.random.gamma``'s own default: it resolves to
            float32, or float64 when ``jax_enable_x64`` is set. This avoids the
            "requested float64 ... will be truncated" warning that an explicit
            float64 request (the former ``jnp.float_`` default) triggers on every
            call when x64 is disabled.

    Returns:
        Array of chi-squared samples with the requested shape and dtype.
    """
    return 2.0 * random.gamma(key, 0.5 * df, shape=shape, dtype=dtype)
class NoncentralChi2(Distribution):
    """Noncentral chi-squared distribution for NumPyro.

    Args:
        df: Degrees of freedom; constrained to be positive.
        nc: Noncentrality parameter; constrained to be positive. The central
            case ``nc == 0`` is excluded: ``log_prob`` divides by ``nc``.
        validate_args: Forwarded to :class:`Distribution`.
    """

    arg_constraints = {
        "df": constraints.positive,
        "nc": constraints.positive,
    }
    support = constraints.positive
    # NOTE(review): only the ``df > 1`` sampling branch is differentiable in
    # ``nc``. The ``df <= 1`` branch draws a Poisson count from ``nc``, and that
    # count carries no gradient. Confirm that listing ``nc`` here is intended.
    reparametrized_params = ["df", "nc"]

    def __init__(self, df, nc, validate_args=None):
        # NumPyro helper that aligns the parameters' shapes for broadcasting.
        self.df, self.nc = promote_shapes(df, nc)
        batch_shape = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nc))
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        Mirrors NumPy's ``random_noncentral_chisquare``:

        - for ``df > 1``: ``Chi2(df - 1) + (Z + sqrt(nc))**2`` with ``Z ~ N(0, 1)``;
        - otherwise: ``Chi2(df + 2 * i)`` with ``i ~ Poisson(nc / 2)``.
        """
        # Ref: NumPy, numpy/random/src/distributions/distributions.c,
        # function ``random_noncentral_chisquare``.
        assert is_prng_key(key)
        shape = sample_shape + self.batch_shape + self.event_shape
        key1, key2, key3 = random.split(key, 3)
        i = random.poisson(key1, 0.5 * self.nc, shape=shape)
        n = random.normal(key2, shape=shape) + jnp.sqrt(self.nc)
        cond = jnp.greater(self.df, 1.0)
        # Select the per-element chi-squared degrees of freedom first, so a
        # single gamma draw serves both branches. This avoids data-dependent
        # Python control flow.
        chi2 = _random_chi2(
            key3, jnp.where(cond, self.df - 1.0, self.df + 2.0 * i), shape=shape
        )
        return jnp.where(cond, chi2 + n * n, chi2)

    @validate_sample
    def log_prob(self, value):
        """Return the log-density at ``value``.

        Uses ``(xs**2 + ns**2) / 2 == (xs - ns)**2 / 2 + xs * ns`` and folds
        ``exp(-xs * ns)`` into the exponentially scaled Bessel function ``ive``
        for numerical stability at large arguments. Where the Bessel term
        underflows to zero, the result is ``-inf``.
        """
        # Ref: port of SciPy's ncx2 log-pdf (``_ncx2_log_pdf`` in
        # scipy/stats/_continuous_distns.py in older SciPy releases).
        df2 = self.df / 2.0 - 1.0
        xs, ns = jnp.sqrt(value), jnp.sqrt(self.nc)
        res = jsp.special.xlogy(df2 / 2.0, value / self.nc) - 0.5 * (xs - ns) ** 2
        corr = tfp.math.bessel_ive(df2, xs * ns) / 2.0
        positive = jnp.greater(corr, 0.0)
        # Double-``where``: give ``log`` a harmless 1.0 where ``corr == 0``.
        # Otherwise the masked-out ``log(0)`` branch turns gradients into NaN,
        # even though it is never selected.
        safe_corr = jnp.where(positive, corr, 1.0)
        return jnp.where(positive, res + jnp.log(safe_corr), -jnp.inf)

    @property
    def mean(self):
        """Mean of the distribution: ``df + nc``."""
        return self.df + self.nc

    @property
    def variance(self):
        """Variance of the distribution: ``2 * (df + 2 * nc)``."""
        return 2.0 * (self.df + 2.0 * self.nc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment