Last active
February 15, 2022 20:47
-
-
Save ririw/2e3a4415dc8271bd2d132c476b98b567 to your computer and use it in GitHub Desktop.
PyMC3 zero-truncated Poisson distribution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pymc3 as pm | |
from pymc3.distributions.dist_math import bound, logpow, factln | |
from pymc3.distributions import draw_values, generate_samples | |
import theano.tensor as tt | |
import numpy as np | |
import scipy.stats.distributions | |
class ZTP(pm.Discrete):
    """Zero-truncated Poisson distribution for PyMC3.

    The probability mass function is

                       mu^k
        P(k | mu) = ------------ ,   k = 1, 2, 3, ...
                    k! (e^mu - 1)

    i.e. a Poisson(mu) variable conditioned on being non-zero.
    See https://en.wikipedia.org/wiki/Zero-truncated_Poisson_distribution

    Parameters
    ----------
    mu : float or tensor
        Rate of the underlying (untruncated) Poisson distribution;
        expected to be > 0.
    *args, **kwargs
        Forwarded unchanged to ``pm.Discrete.__init__``.
    """

    def __init__(self, mu, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mu = mu = tt.as_tensor_variable(mu)
        # The ZTP mode is floor(mu) for mu >= 1 and 1 for mu < 1, so clamp
        # from *below* at 1. Clamping from above would make the mode (and
        # hence the default test value) 0, which lies outside the support.
        self.mode = tt.maximum(tt.floor(mu).astype('int32'), 1)

    def zpt_cdf(self, mu, size=None):
        """Draw zero-truncated Poisson samples by inverse-CDF sampling.

        A uniform variate is drawn on [P(X <= 0), 1) for X ~ Poisson(mu),
        which skips the probability mass at zero, and is mapped back through
        the Poisson quantile function. (The method name is kept as-is for
        compatibility with existing callers.)

        Parameters
        ----------
        mu : array_like
            Poisson rate(s); broadcast against ``size``.
        size : int, tuple of int, or None
            Output shape forwarded to ``np.random.uniform``. ``None`` yields
            one draw per element of ``mu``.

        Returns
        -------
        numpy.ndarray (or numpy scalar) of int64
            Samples, each >= 1.
        """
        mu = np.asarray(mu)
        dist = scipy.stats.distributions.poisson(mu)
        lower_cdf = dist.cdf(0)  # P(X = 0) for the untruncated Poisson
        # np.random.uniform accepts size=None or a tuple and broadcasts
        # ``low`` against ``size``; np.random.rand(size) fails on both.
        sample = np.random.uniform(low=lower_cdf, high=1.0, size=size)
        draws = dist.ppf(sample)
        # uniform() may return ``low`` itself, whose quantile is 0; clamp so
        # every draw lies in the support {1, 2, ...}.
        return np.maximum(draws, 1).astype('int64')  # Thanks to @omrihar for this fix!

    def random(self, point=None, size=None, repeat=None):
        """Draw random samples from the distribution.

        Parameters
        ----------
        point : dict, optional
            Values for parent random variables, passed to ``draw_values``.
        size : int or tuple, optional
            Number/shape of samples, passed to ``generate_samples``.
        repeat : optional
            Accepted for signature compatibility; not used.
        """
        # draw_values returns a list with one value per requested variable;
        # unpack it so zpt_cdf receives the rate itself rather than a
        # one-element list (which would add a spurious leading axis).
        mu = draw_values([self.mu], point=point)[0]
        return generate_samples(self.zpt_cdf, mu,
                                dist_shape=self.shape,
                                size=size)

    def logp(self, value):
        """Log-probability of ``value`` under ZTP(mu).

        log P(k) = k*log(mu) - log(k!) - log(e^mu - 1) for k >= 1, mu > 0,
        and -inf outside that region. As a special case kept from the
        original implementation, (mu == 0, value == 0) returns 0.
        """
        mu = self.mu
        # log(e^mu - 1) == mu + log(1 - e^-mu). Evaluating 1 - e^-mu as
        # -expm1(-mu) is accurate for small mu and cannot overflow for large
        # mu, unlike the naive log(exp(mu) - 1).
        log_norm = mu + tt.log(-tt.expm1(-mu))
        p = logpow(mu, value) - (factln(value) + log_norm)
        # Zero is excluded from the support, hence value >= 1; mu must be
        # strictly positive for the normaliser to be finite.
        log_prob = bound(p, mu > 0, value >= 1)
        # Return zero when mu and value are both zero.
        return tt.switch(tt.and_(tt.eq(mu, 0), tt.eq(value, 0)),
                         0, log_prob)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks! I've fixed it up.