Skip to content

Instantly share code, notes, and snippets.

@ipashchenko
Last active August 29, 2015 14:21
Show Gist options
  • Select an option

  • Save ipashchenko/fa2b9e76ce7335a1270d to your computer and use it in GitHub Desktop.

Select an option

Save ipashchenko/fa2b9e76ce7335a1270d to your computer and use it in GitHub Desktop.
Bayesian estimation of the number of trials in a binomial distribution
import numpy as np
from scipy import stats
try:
import matplotlib.pyplot as plt
except:
plt = None
def posterior_trials(n_max, data, alpha, beta, delta=0.005):
    """
    Posterior for ``(p, n)`` marginalized over ``p``.

    See https://www.cna.org/sites/default/files/research/2787018500.pdf

    The unnormalized posterior ``q`` is built recursively, starting at
    ``n = max(data)`` (where ``q`` is set to 1) and multiplying by the ratio
    ``q(n) / q(n - 1)`` at every step.  The recursion is carried out in log
    space so that large samples do not overflow floating point.

    :param n_max:
        Upper limit for uniform prior on binomial parameter ``n``.
    :param data:
        Container of data (number of successes in each of the unknown
        ``n`` number of trials). Expected to be one-dimensional and non-empty.
    :param alpha:
        Beta distribution parameter for prior on ``p``.
    :param beta:
        Beta distribution parameter for prior on ``p``.
    :param delta: (optional)
        Early-stopping threshold. The recursion stops as soon as the most
        recent term is less than ``delta`` times the running sum of all terms
        computed so far (or when ``n_max`` is reached). (default: ``0.005``)

    :return:
        Two numpy arrays - for ``n`` and for posterior probability of ``n``.
        ``n`` always starts at ``max(data)``; if ``n_max <= max(data)`` only
        that single value is returned, with probability 1.

    :raises ValueError:
        If ``data`` is empty.
    """
    alpha = float(alpha)
    beta = float(beta)
    data = np.asarray(data)
    r = len(data)
    if r == 0:
        raise ValueError("data must contain at least one observation")
    t = data.sum()
    xmax = data.max()
    offsets = np.arange(r)

    # log of the unnormalized posterior for n = xmax, xmax + 1, ...
    log_q = [0.0]
    # log of the running sum of the unnormalized posterior terms.
    log_total = 0.0
    j = 0
    while j < n_max - xmax and np.exp(log_q[j] - log_total) >= delta:
        j += 1
        base = r * (xmax + j - 1)
        # Ratio of the Beta-function terms for n = xmax + j versus n - 1.
        log_beta_ratio = np.sum(np.log(base - t + beta + offsets) -
                                np.log(base + alpha + beta + offsets))
        # Ratio of the products of binomial coefficients C(n, x_i).
        log_binom_ratio = (r * np.log(float(xmax + j)) -
                           np.sum(np.log(float(xmax) - data + j)))
        log_q.append(log_q[-1] + log_beta_ratio + log_binom_ratio)
        log_total = np.logaddexp(log_total, log_q[-1])

    log_q = np.asarray(log_q)
    # Subtract the maximum before exponentiating so the largest term is
    # exactly 1 and the normalization cannot overflow.
    weights = np.exp(log_q - log_q.max())
    return xmax + np.arange(len(log_q)), weights / weights.sum()
if __name__ == '__main__':
    # Fixed seed so the simulated sample is reproducible.
    np.random.seed(42)

    # Simulate 250 experiments, each with 10 trials and success probability
    # 0.3.  (The gist author observed floating-point overflow in a
    # non-log-space recurrence for samples much larger than ~250.)
    true_model = stats.binom(10, 0.3)
    successes = true_model.rvs(250)

    # Beta(1, 1) prior on the binomial ``p``.
    prior_alpha = 1.
    prior_beta = 1.
    # Upper limit for the uniform prior on the binomial ``n``.
    trials_upper = 15

    trials, posterior = posterior_trials(n_max=trials_upper, data=successes,
                                         alpha=prior_alpha, beta=prior_beta)

    # Plot only when matplotlib could be imported.
    if plt is not None:
        plt.plot(trials, posterior, '.k')
        plt.xlabel(r'n for binomial model')
        plt.ylabel(r'p(n | data)')
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment