Last active
January 14, 2017 05:51
-
-
Save usptact/7476f3d3a3d89f38090b87e24f38ca91 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# Attempt to implement Dawid-Skene model in PyMC3 (probably working) | |
# | |
import pymc3 as pm | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import theano | |
import sys | |
import theano.tensor as T | |
from pymc3.backends import SQLite | |
# Load the annotation tensor and the reference classes shipped with it.
# NOTE(review): `data` is assumed to be (items, annotators, classes), one-hot
# along the last axis — TODO confirm against the data generator.
# `trueCls` is presumably the ground-truth class per item; it is loaded here
# but not otherwise used by this script.
data = np.load('preset_MC_80_5_4.npz.npy')
trueCls = np.load('preset_MC_80_5_4_reference_classes.npy')

# Problem dimensions: items, annotators, classes, and one label per
# (item, annotator) pair.
I, J, K = data.shape[0], data.shape[1], data.shape[2]
N = I * J
# create data triplets | |
jj = list() | |
ii = list() | |
y = list() | |
# create data triplets | |
for i in range( I ): | |
for j in range( J ): | |
dat = data[ i, j, : ] | |
k = np.where( dat == 1 )[0][0] | |
ii.append( i ) | |
jj.append( j ) | |
y.append( k ) | |
# Class-prevalence prior: Dirichlet(1, ..., 1) is the uniform ("flat")
# distribution over the probability simplex. The previous 1/K concentrations
# were not flat: for K >= 2 every concentration is < 1, which concentrates
# prior mass near the simplex corners (i.e. favors one dominant class).
alpha = np.ones(K)

# Per-annotator confusion-matrix prior (rows = true class, columns = given
# label): concentration 2 on the diagonal and 1 elsewhere, so the prior mean
# favors correct labels (2/(K+1) on the diagonal vs 1/(K+1) off it).
# Each row is then scaled to sum to 1. Scaling leaves the prior mean
# unchanged but reduces the total concentration to 1 per row, i.e. a weak,
# sparse prior.
beta = np.ones((K, K)) + np.eye(K)
row_sums = beta.sum(axis=1)
beta = beta / row_sums[:, np.newaxis]
# Dawid-Skene generative model:
#   pi             ~ Dirichlet(alpha)          class prevalence, shape (K,)
#   theta[j, k, :] ~ Dirichlet(beta[k, :])     annotator j's label distribution
#                                               given true class k, shape (J, K, K)
#   z[i]           ~ Categorical(pi)           latent true class of item i, shape (I,)
#   y[n]           ~ Categorical(theta[jj[n], z[ii[n]], :])
model = pm.Model()
with model:
    pi = pm.Dirichlet('pi', a=alpha)
    theta = pm.Dirichlet('theta', a=beta, shape=(J, K, K))
    # Exactly one latent class per item. (shape=(I, K) would create K
    # unrelated latent labels for every item.)
    z = pm.Categorical('z', p=pi, shape=I)
    for n, (i, j, k) in enumerate(zip(ii, jj, y)):
        # The label annotator j gave item i is drawn from the confusion-matrix
        # row of item i's *true* class z[i]. The observed label k must only be
        # the observation, never an index into its own probability vector.
        y_obs = pm.Categorical('y_obs_' + str(n), p=theta[j, z[i]], observed=k)
# Compound sampler: Metropolis for the continuous simplex parameters
# (pi, theta) and CategoricalGibbsMetropolis for the discrete latent classes
# z. Draws are stored through PyMC3's SQLite trace backend.
with model:
    step1 = pm.Metropolis(vars=[pi, theta])
    step2 = pm.CategoricalGibbsMetropolis(vars=[z])
    backend = SQLite('dawid-skene.sqlite')
    trace = pm.sample(
        draws=10000,
        step=[step1, step2],
        trace=backend,
        progressbar=True,
    )

# Posterior summary and trace plots.
pm.summary(trace)
pm.traceplot(trace)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment