Created
December 24, 2016 01:43
-
-
Save usptact/8670925f3609346552f2989d41c014d3 to your computer and use it in GitHub Desktop.
Estimating coin bias from noisy observations
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#
# Estimate coin bias given multiple observations from adversarial experts.
#
import sys | |
import pymc3 as pm | |
import numpy as np | |
import theano.tensor as tt | |
import matplotlib.pyplot as plt | |
from scipy.optimize import fmin_powell | |
from scipy.stats.distributions import bernoulli | |
# Simulation settings.
p_true = 0.1  # true coin bias (probability of heads)
a_true = np.array([0.1, 0.2, 0.3])  # per-expert noise (flipping probability)
N = 1000  # number of coin observations
K = len(a_true)  # number of experts: one noise level per expert

# Generate reference data: N clean coin tosses drawn with bias p_true.
X = bernoulli.rvs(size=N, p=p_true)

# Corrupt data with noise: row k of Y is expert k's report of X, in which
# every toss is flipped independently with probability a_true[k].
Y = np.zeros((K, N))
for k in range(K):
    flip_or_not = bernoulli.rvs(size=N, p=a_true[k])
    # Where the mask is 1 report the opposite outcome (1 - X), otherwise
    # copy the clean toss unchanged.
    Y[k, :] = np.where(flip_or_not == 1, 1 - X, X)
# Bayesian model of the noisy expert reports.
#   alpha0, beta0 -- HalfCauchy hyperpriors for the two Beta shape parameters
#   p             -- coin bias, Beta(alpha0, beta0)
#   a[k]          -- flipping probability of expert k, Uniform on [0, 0.5]
#   q[k]          -- probability that expert k reports 1, i.e.
#                    p * (1 - a[k]) + (1 - p) * a[k] = a[k] + p - 2 * a[k] * p
model = pm.Model()
with model:
    alpha0 = pm.HalfCauchy('alpha0', beta=1)
    beta0 = pm.HalfCauchy('beta0', beta=1)
    p = pm.Beta('p', alpha=alpha0, beta=beta0)
    a = pm.Uniform('a', lower=0, upper=0.5, shape=K)
    q = a + p - 2 * a * p

    # One latent vector of tosses shared by every expert. It is created once,
    # before the per-expert loop, so it gets a fixed name instead of one
    # built from a loop index that is not defined by this model.
    y_hidden = pm.Bernoulli('y_hidden', p=p, shape=N)

    for k in range(K):
        # Expert k's reports, observed with success probability q[k].
        y_obs = pm.Bernoulli('y_obs_' + str(k), p=q[k], observed=Y[k, :])
        # Large negative log-probability term for every position where the
        # latent tosses differ from expert k's reports.
        # NOTE(review): the same y_hidden is penalised against all K report
        # vectors, which generally disagree with each other -- confirm that
        # this is the intended constraint.
        pot = pm.Potential('pot_' + str(k), -1000000 * (tt.sum((y_hidden - y_obs) ** 2)))
    #for i in range( N ):
    #    potential = pm.Potential( 'potential_' + str(i), -1000*(y_hidden[i] - y_obs[i]) ** 2 )

    # Start the sampler from the MAP estimate and use it to scale NUTS.
    # NOTE(review): y_hidden is a discrete variable -- verify how find_MAP and
    # NUTS treat it in the installed PyMC3 version.
    start = pm.find_MAP()
    step = pm.NUTS(scaling=start)
    trace = pm.sample(1000, start=start, step=step)
# Plot the sampled traces of the expert flipping probabilities ('a') and the
# coin bias ('p'), then display the figure.
varnames = ['a', 'p']
pm.traceplot(trace, varnames=varnames)
# pm.autocorrplot( trace )
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment