Kuma and HardKuma distributions in JAX using distrax
"""Adapted from https://github.com/bastings/interpretable_predictions""" | |
import math

import distrax
import jax
import jax.numpy as jnp

EPS = 1e-6
@jax.jit
def hard_tanh(x, min_val=-1.0, max_val=1.0):
    """Hard-clips x to the closed interval [min_val, max_val]."""
    return jnp.where(x > max_val, max_val, jnp.where(x < min_val, min_val, x))
@jax.jit
def lbeta(x):
    """Computes the log of the multivariate Beta function along the last axis:
    log B(x) = sum_i log Gamma(x_i) - log Gamma(sum_i x_i).
    """
    x_abs = jnp.abs(x)
    log_prod_gamma_x = jax.lax.lgamma(x_abs).sum(-1)
    log_gamma_sum_x = jax.lax.lgamma(x_abs.sum(-1))
    return log_prod_gamma_x - log_gamma_sum_x


@jax.jit
def _harmonic_number(x):
    """Computes the harmonic number from its analytic continuation."""
    one = jnp.ones(1)
    return jax.lax.digamma(x + one) - jax.lax.digamma(one)
@jax.jit
def kuma_mean(a, b):
    """Computes the mean of a Kumaraswamy as its first moment."""
    return kuma_moments(a, b, 1)


@jax.jit
def kuma_moments(a, b, n):
    """Computes the n-th moment of a Kumaraswamy,
    E[X^n] = b * B(1 + n/a, b),
    evaluated in log-space via jax.lax.lgamma for numerical stability.
    """
    arg1 = 1 + n / a
    log_value = jax.lax.lgamma(jnp.abs(arg1))
    log_value += jax.lax.lgamma(jnp.abs(b))
    log_value -= jax.lax.lgamma(jnp.abs(arg1 + b))
    return b * jnp.exp(log_value)
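

# Sanity-check sketch (not part of the original gist): the closed-form first
# moment from kuma_moments should agree with a Monte Carlo estimate of E[X]
# under Kuma(a, b). The parameter values, seed, and sample count below are
# arbitrary choices.
def _check_kuma_mean(n_samples=100_000):
    key = jax.random.PRNGKey(0)
    a = jnp.array([2.0])
    b = jnp.array([3.0])
    # inverse-CDF sampling, as in Kuma._sample_n below
    u = jax.random.uniform(key, shape=(n_samples, 1), minval=EPS, maxval=1.0 - EPS)
    x = (1.0 - (1.0 - u) ** (1.0 / b)) ** (1.0 / a)
    # the two values should agree to a few decimal places
    return kuma_mean(a, b), x.mean(axis=0)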
class Kuma(distrax.Distribution):
    """
    A Kumaraswamy, or Kuma for short, is similar to a Beta distribution,
    though not an exponential family. Kuma variables are specified by two
    shape parameters, similar to Beta, though settings that would typically
    yield a symmetric Beta won't necessarily yield a symmetric Kuma.

        X ~ Kuma(a, b)   where a, b > 0

    Or equivalently,

        U ~ U(0, 1)
        x = (1 - (1 - u)^(1/b))^(1/a)

    In practice we sample from U(0 + eps, 1 - eps) for some small positive
    constant eps to avoid instabilities.
    """

    def __init__(self, params: list):
        self.a = params[0]
        self.b = params[1]

    def params(self):
        return [self.a, self.b]

    def mean(self):
        return kuma_moments(self.a, self.b, 1)

    @property
    def event_shape(self):
        return ()

    @property
    def batch_shape(self):
        return self.a.shape

    def _sample_n(self, key, n, eps=0.001):
        # inverse-CDF sampling: draw u ~ U(eps, 1 - eps) and transform
        shape = [n] + list(self.a.shape)
        u = jax.random.uniform(key, shape=shape, minval=eps, maxval=1.0 - eps)
        return (1.0 - (1.0 - u) ** jnp.reciprocal(self.b)) ** jnp.reciprocal(self.a)

    def log_prob(self, x):
        """
        Kuma(x|a, b) = U(s(x)|0, 1) |det J_s|
        where x = t(u), u = s(x), and J_s is the Jacobian matrix of s(x).
        """
        # EPS terms guard against degenerate values at the support boundaries
        t1 = jnp.log(self.a) + jnp.log(self.b)
        t2 = (self.a - 1.0 + EPS) * jnp.log(x)
        pow_x_a = (x ** self.a) + EPS
        t3b = jnp.log(1.0 - pow_x_a)
        t3 = (self.b - 1.0 + EPS) * t3b
        return t1 + t2 + t3

    def log_cdf(self, x):
        # cdf(x) = 1 - (1 - x^a)^b, clamped away from 0 and 1 in log-space
        r = 1.0 - ((1.0 - (x ** self.a)) ** self.b)
        r = jnp.log(r + EPS)
        return jax.lax.clamp(math.log(EPS), r, math.log(1 - EPS))
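

# Usage sketch (illustrative, not part of the original gist): draw samples
# from a Kuma through distrax's public `sample` API (which dispatches to
# `_sample_n`) and score them with `log_prob`. The parameter values, batch
# shape, and seed below are arbitrary choices.
def _demo_kuma():
    key = jax.random.PRNGKey(0)
    a = jnp.full((3,), 0.5)
    b = jnp.full((3,), 0.5)
    dist = Kuma([a, b])
    x = dist.sample(seed=key, sample_shape=4)  # shape (4, 3), values in (0, 1)
    return x, dist.log_prob(x), dist.mean()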
class StretchedVariable(distrax.Distribution):
    """
    A continuous variable over the open interval (left, right).

        X ~ StretchedVariable(RelaxedBinary, [left, right])
        where left < 0 and right > 1

    Or equivalently,

        Y ~ RelaxedBinary()
        x = location + y * scale
        where location = left and scale = right - left
    """

    def __init__(self, dist: distrax.Distribution, support: list):
        """
        :param dist: a RelaxedBinary variable (e.g. BinaryConcrete or Kuma)
        :param support: a pair specifying the limits of the stretched support
            (e.g. [-1, 2]); we use these values to compute
            location = pair[0] and scale = pair[1] - pair[0]
        """
        assert support[0] < support[1], "I need an ordered support, got %s" % support
        self._dist = dist
        self.loc = support[0]
        self.scale = support[1] - support[0]

    def params(self):
        return self._dist.params()

    @property
    def event_shape(self):
        return self._dist.event_shape

    @property
    def batch_shape(self):
        return self._dist.batch_shape

    def _sample_n(self, key, n, eps=0.001):
        # sample a relaxed binary variable
        x_ = self._dist._sample_n(key, n, eps=eps)
        # and stretch it
        return x_ * self.scale + self.loc

    def log_prob(self, x):
        # shrink the stretched variable
        x_ = (x - self.loc) / self.scale
        # and assess the stretched pdf using the original pdf
        # see eq 25 (left) of Louizos et al.
        return self._dist.log_prob(x_) - jnp.log(self.scale)

    def log_cdf(self, x):
        # shrink the stretched variable
        x_ = (x - self.loc) / self.scale
        # assess its cdf
        # see eq 25 (right) of Louizos et al.
        r = self._dist.log_cdf(x_)
        return jax.lax.clamp(math.log(EPS), r, math.log(1 - EPS))
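

# Usage sketch (illustrative, not part of the original gist): stretch a Kuma
# onto the open interval (-0.1, 1.1), the kind of support that HardBinary
# below rectifies back onto [0, 1]. The support endpoints, parameters, and
# seed are arbitrary choices.
def _demo_stretched():
    key = jax.random.PRNGKey(1)
    base = Kuma([jnp.full((2,), 0.5), jnp.full((2,), 0.5)])
    stretched = StretchedVariable(base, support=[-0.1, 1.1])
    y = stretched.sample(seed=key, sample_shape=3)  # values in (-0.1, 1.1)
    return y, stretched.log_prob(y)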
class HardBinary(distrax.Distribution):
    """
    A continuous variable over the closed interval [0, 1] which can assign
    non-zero probability mass to {0} and {1} (which are sets of zero measure
    in a standard RelaxedBinary or StretchedVariable).

        X ~ HardBinary(StretchedVariable)

    Or equivalently,

        Y ~ StretchedVariable()
        x = hardsigmoid(y)
    """

    def __init__(self, dist: StretchedVariable):
        self._dist = dist

    @property
    def event_shape(self):
        return self._dist.event_shape

    @property
    def batch_shape(self):
        return self._dist.batch_shape

    def _sample_n(self, key, n, eps=0.001):
        # sample a stretched variable and rectify it onto [0, 1]
        x_ = self._dist._sample_n(key, n, eps=eps)
        return hard_tanh(x_, min_val=0.0, max_val=1.0)

    def log_prob(self, x):
        """
        We obtain pdf(0) by integrating the stretched variable over the
        interval [left, 0]:
            HardBinary.pdf(0) = StretchedVariable.cdf(0)
        and pdf(1) by integrating the stretched variable over [1, right],
        or equivalently:
            HardBinary.pdf(1) = 1 - StretchedVariable.cdf(1)
        Finally, for values in the open (0, 1) the pdf of the stretched
        variable is used directly; its interior mass together with the two
        atoms already sums to one, so no rescaling is needed.

        Note that the total mass over the discrete set {0, 1} is
            HardBinary.pdf(0) + HardBinary.pdf(1);
        in other words, with this probability we will be sampling a discrete
        value. Whenever this probability is greater than 0.5, most probability
        mass is away from continuous samples.
        """
        # cache these for faster computation
        log_cdf_0 = self._dist.log_cdf(jnp.zeros(1))
        cdf_1 = self._dist.cdf(jnp.ones(1))
        # first we fix log_pdf for 0s and 1s:
        # log Q(0) at x == 0, log (1 - Q(1)) at x == 1
        log_p = jnp.where(x == 0.0, log_cdf_0, jnp.log(1.0 - cdf_1))
        # then for those that are in the open (0, 1)
        log_p = jnp.where((0.0 < x) & (x < 1.0), self._dist.log_prob(x), log_p)
        # see eq 26 of Louizos et al.
        return log_p

    def log_cdf(self, x):
        """
        Note that HardBinary.cdf(0) = HardBinary.pdf(0) by definition of
        HardBinary.pdf(0); also note that HardBinary.cdf(1) = 1 by definition,
        because the support of HardBinary is the *closed* interval [0, 1]
        and not the open interval (left, right), which is the support of the
        stretched variable.
        """
        # for x >= 1 all of the mass has been accumulated: log cdf = log 1 = 0
        log_c = jnp.where(x < 1.0, self._dist.log_cdf(x), 0.0)
        return jax.lax.clamp(math.log(EPS), log_c, math.log(1 - EPS))
class HardKuma(HardBinary):
    def __init__(self, params: list, support: list):
        super().__init__(StretchedVariable(Kuma(params), support))
        # shortcut to the underlying a and b
        self.a = self._dist._dist.a
        self.b = self._dist._dist.b

    def mean(self):
        # first moment of the underlying Kuma (not of the rectified variable)
        return kuma_moments(self.a, self.b, 1)
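

# End-to-end sketch (illustrative, not part of the original gist): HardKuma
# samples live in the closed interval [0, 1] and land on exactly 0 and
# exactly 1 with non-zero probability. The shapes, seed, and support below
# are arbitrary choices.
def _demo_hardkuma():
    key = jax.random.PRNGKey(2)
    a = jnp.full((5,), 0.3)
    b = jnp.full((5,), 0.3)
    hk = HardKuma([a, b], support=[-0.1, 1.1])
    x = hk.sample(seed=key, sample_shape=1000)  # shape (1000, 5), in [0, 1]
    frac_zero = (x == 0.0).mean()  # strictly positive: P(X = 0) > 0
    frac_one = (x == 1.0).mean()   # strictly positive: P(X = 1) > 0
    return x, frac_zero, frac_one, hk.log_prob(x)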