Last active
January 13, 2017 16:46
-
-
Save usptact/b93137a76ea321227046b8c97562929b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# Attempt to implement Dawid-Skene model in PyMC3 (broken) | |
# | |
import pymc3 as pm | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import theano | |
import sys | |
import theano.tensor as T | |
# data data matrix ix 80 x 5 x 4 (80 items, 5 annotators and 4 categories) | |
# categories are one-hot encoded | |
data = np.load( 'preset_MC_80_5_4.npz.npy' ) | |
trueCls = np.load( 'preset_MC_80_5_4_reference_classes.npy' ) | |
I = data.shape[0] # number of items | |
J = data.shape[1] # number of annotators | |
K = data.shape[2] # number of classes | |
N = I * J # number of observations | |
print( I, J, K, N ) | |
# create data triplets | |
jj = list() | |
ii = list() | |
y = list() | |
# create data triplets | |
for i in range( I ): | |
for j in range( J ): | |
dat = data[ i, j, : ] | |
k = np.where( dat == 1 )[0][0] | |
ii.append( i ) | |
jj.append( j ) | |
y.append( j ) | |
print( len(ii), len(jj), len(y) ) | |
alpha = 1. / K * np.ones( K ) | |
beta = 1. / K * np.ones( K ) # has to be K x K with a dominant diagonal but keeping it simple | |
model = pm.Model() | |
with model: | |
# RV over K simplex | |
pi = pm.Dirichlet( 'pi', a=alpha ) | |
# I expected an array shaped JxK of Dirichlet RVs; each has to be over K simplext | |
# it looks like underlying code tries to broadcast/align the shape of beta with dimensions of (J,K)... | |
theta = pm.Dirichlet( 'theta', a=beta, shape=(J,K) ) # this does not give a J x K array of Dirichlet RVs | |
theta = T.stack( [ pm.Dirichlet( 'theta_'+str(idx), a=beta ) for idx in range(J*K) ] ) # this is very slow but appears to give an expected flat JxK array of Dirichlet RVs | |
# similarly here, we want an array of I Categorical RVs that are distrib. over K categories each | |
# here it also tries to broadcast the shape of pi with I | |
z = pm.Categorical( 'z', p=pi, shape=I ) | |
#z = T.stack( [ pm.Categorical('z_i'+str(i), p=pi) for i in range(I) ] ) | |
# I get indexing errors on theta array which has to be indexed with 2 indices | |
for n in range( N ): | |
# y_obs = pm.Categorical('y_obs_' + str(n), p=theta[ z[ii[n]] ], observed=y[n]) | |
y_obs = pm.Categorical( 'y_obs_'+str(n), p=theta[ jj[ n ], z[ ii[n] ] ], observed=y[n] ) | |
with model: | |
step = pm.Metropolis() | |
trace = pm.sample( 5, step=step, progressbar=True ) | |
pm.summary( trace ) | |
#pm.traceplot( trace ) | |
#plt.show() | |
# | |
# Output of the script where shapes are used (no T.stack() calls) | |
# | |
80 5 4 400 | |
400 400 400 | |
log_thunk_trace: There was a problem executing an Op. | |
Traceback (most recent call last): | |
File "C:/Users/vladi/PycharmProjects/dawid-skene/dawid-skene.py", line 56, in <module> | |
y_obs = pm.Categorical( 'y_obs_'+str(n), p=theta[ jj[ n ], z[ ii[n] ] ], observed=y[n] ) | |
File "C:\Users\vladi\Anaconda2\envs\py35\lib\site-packages\pymc3\distributions\distribution.py", line 30, in __new__ | |
dist = cls.dist(*args, **kwargs) | |
File "C:\Users\vladi\Anaconda2\envs\py35\lib\site-packages\pymc3\distributions\distribution.py", line 41, in dist | |
dist.__init__(*args, **kwargs) | |
File "C:\Users\vladi\Anaconda2\envs\py35\lib\site-packages\pymc3\distributions\discrete.py", line 377, in __init__ | |
self.k = tt.shape(p)[-1].tag.test_value | |
File "C:\Users\vladi\Anaconda2\envs\py35\lib\site-packages\theano\tensor\var.py", line 575, in __getitem__ | |
lambda entry: isinstance(entry, Variable))) | |
File "C:\Users\vladi\Anaconda2\envs\py35\lib\site-packages\theano\gof\op.py", line 663, in __call__ | |
required = thunk() | |
File "C:\Users\vladi\Anaconda2\envs\py35\lib\site-packages\theano\gof\op.py", line 832, in rval | |
fill_storage() | |
File "C:\Users\vladi\Anaconda2\envs\py35\lib\site-packages\theano\gof\cc.py", line 1701, in __call__ | |
reraise(exc_type, exc_value, exc_trace) | |
File "C:\Users\vladi\Anaconda2\envs\py35\lib\site-packages\six.py", line 686, in reraise | |
raise value | |
IndexError: index out of bounds |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment