Last active
February 15, 2022 20:47
-
-
Save ririw/2e3a4415dc8271bd2d132c476b98b567 to your computer and use it in GitHub Desktop.
PyMC3 zero-truncated Poisson distribution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pymc3 as pm | |
from pymc3.distributions.dist_math import bound, logpow, factln | |
from pymc3.distributions import draw_values, generate_samples | |
import theano.tensor as tt | |
import numpy as np | |
import scipy.stats.distributions | |
class ZTP(pm.Discrete):
    """Zero-truncated Poisson distribution for PyMC3.

    The probability mass function is

        P(X = k) = mu**k / (k! * (exp(mu) - 1)),   k = 1, 2, ...

    See https://en.wikipedia.org/wiki/Zero-truncated_Poisson_distribution

    As a convention of this class, ``mu == 0`` together with ``value == 0``
    is assigned log-probability 0 (see ``logp``).

    Parameters
    ----------
    mu : tensor-like
        Rate of the underlying (untruncated) Poisson distribution.
    """

    def __init__(self, mu, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mu = mu = tt.as_tensor_variable(mu)
        # The ratio P(k+1)/P(k) = mu/(k+1) matches the plain Poisson, so the
        # mode is floor(mu) -- but never below 1, because 0 is outside the
        # support. ``maximum`` (not ``minimum``) enforces that lower bound.
        self.mode = tt.maximum(tt.floor(mu).astype('int32'), 1)

    def zpt_cdf(self, mu, size=None):
        """Draw zero-truncated Poisson samples by inverse-CDF sampling.

        A uniform draw is mapped into ``[P(X <= 0), 1)`` of the untruncated
        Poisson CDF, so inverting it with ``ppf`` lands on ``k >= 1``.

        Parameters
        ----------
        mu : array-like
            Poisson rate(s); broadcast against the uniform draw.
        size : int, tuple of int, or None
            Shape of the uniform draw (``None`` draws a single value).

        Returns
        -------
        numpy.ndarray of int64
            The sampled counts.
        """
        mu = np.asarray(mu)
        dist = scipy.stats.distributions.poisson(mu)
        lower_cdf = dist.cdf(0)
        upper_cdf = 1
        nrm = upper_cdf - lower_cdf
        # ``random_sample`` accepts None, an int or a tuple for ``size``;
        # ``np.random.rand(size)`` rejects both None and tuples.
        sample = np.random.random_sample(size) * nrm + lower_cdf
        draws = dist.ppf(sample)
        # For very small ``mu``, exp(-mu) rounds to exactly 1.0, so ``sample``
        # is 1 and ``ppf`` returns +inf; the distribution is then essentially
        # a point mass at 1.
        draws = np.where(np.isposinf(draws), 1, draws)
        # Rounding can also put ``sample`` exactly on ``lower_cdf``, where
        # ``ppf`` returns the excluded value 0; clip to the support.
        draws = np.maximum(draws, 1)
        # ``ppf`` returns float64; cast so downstream tools (e.g. ArviZ) see
        # integer draws. Thanks to @omrihar for pointing this out!
        return draws.astype('int64')

    def random(self, point=None, size=None, repeat=None):
        """Draw random samples, e.g. for prior/posterior predictive checks.

        ``repeat`` is accepted for signature compatibility and is unused.
        """
        # ``draw_values`` returns a list with one value per requested
        # parameter; unpack it so ``mu`` does not gain a leading dimension.
        mu = draw_values([self.mu], point=point)[0]
        return generate_samples(self.zpt_cdf, mu,
                                dist_shape=self.shape,
                                size=size)

    def logp(self, value):
        """Log-probability of ``value`` under the zero-truncated Poisson.

            log P(k) = k*log(mu) - log(k!) - log(exp(mu) - 1),   k >= 1

        ``log(exp(mu) - 1)`` is evaluated as ``mu + log(-expm1(-mu))``. The
        two are algebraically equal, but the rewritten form does not overflow
        for large ``mu`` and keeps precision for small ``mu``.

        Returns ``-inf`` for values outside the support (``value < 1``) or a
        non-positive ``mu``, except that ``mu == 0`` with ``value == 0``
        returns 0.
        """
        mu = self.mu
        # log(e^mu - 1) = mu + log(1 - e^-mu) = mu + log(-expm1(-mu))
        log_norm = mu + tt.log(-tt.expm1(-mu))
        p = logpow(mu, value) - (factln(value) + log_norm)
        log_prob = bound(
            p,
            mu > 0, value >= 1)
        # Convention kept from the original: treat mu == 0 as a point mass at
        # 0, so (mu == 0, value == 0) has log-probability 0.
        return tt.switch(1 * tt.eq(mu, 0) * tt.eq(value, 0),
                         0, log_prob)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi @ririw, thanks for this class and the blog post it accompanies!
It helped me a lot.
If I may make one suggestion: when performing prior predictive checks, I noticed that this class misbehaves when plotting, because it outputs floats rather than integers. Digging a little deeper, I found that while `poisson.rvs()` returns `int64` values, `poisson.ppf()` returns `float64`.
There is an easy fix, though: just add `.astype('int64')` to the end of line 23, and everything works nicely with ArviZ :)