Created
May 21, 2014 21:11
-
-
Save fonnesbeck/4784df751a6d0280345f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"name": "", | |
"signature": "sha256:62fb4b6307f8d3ec1041ab86542cad5920238ff63476f513cd04ebaa74218282" | |
}, | |
"nbformat": 3, | |
"nbformat_minor": 0, | |
"worksheets": [ | |
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"import numpy as np\n", | |
"import pymc as mc\n", | |
"\n", | |
"# Simulate a hierarchical 2-D Gaussian mixture: each of nprov parents is assigned\n", | |
"# to one of n components, and every parent contributes numtoeach observations\n", | |
"# drawn from its component's bivariate normal.\n", | |
"SEED = 20140521\n", | |
"np.random.seed(SEED)  # make the simulated data reproducible\n", | |
"\n", | |
"n = 3            # number of mixture components\n", | |
"B = 5            # distance of component means 1 and 2 from the origin\n", | |
"tau = 3          # scale factor: component 1 has y-variance tau**2, component 2 has x-variance tau**-2\n", | |
"nprov = 60       # number of parent units\n", | |
"numtoeach = 50   # observations generated per parent\n", | |
"mu = [[0., 0.], [0., B], [-B, 0.]]  # true component means\n", | |
"true_cov0 = np.array([[1., 0.], [0., 1.]])\n", | |
"true_cov1 = np.array([[1., 0.], [0., tau**2]])\n", | |
"true_cov2 = np.array([[1. / tau**2, 0.], [0., 1.]])\n", | |
"true_covs = [true_cov0, true_cov1, true_cov2]\n", | |
"trueprobs = [.4, .3, .3]  # probability that a parent belongs to each component\n", | |
"assert len(mu) == len(true_covs) == len(trueprobs) == n\n", | |
"\n", | |
"# One-hot multinomial draw per parent -> integer component label in 0..n-1.\n", | |
"prov = np.random.multinomial(1, trueprobs, size=nprov)\n", | |
"v = prov.argmax(axis=1)\n", | |
"\n", | |
"n_obs = nprov * numtoeach\n", | |
"ndata = n_obs  # same quantity, kept under both names\n", | |
"# Observation i belongs to parent prov1[i] = i % nprov; vAll[i] is that parent's component.\n", | |
"p1 = np.arange(nprov)\n", | |
"prov1 = np.tile(p1, numtoeach)\n", | |
"vAll = v[prov1]\n", | |
"\n", | |
"# Draw each observation only from its own component instead of drawing ndata\n", | |
"# points from every component and masking out the ones that do not apply.\n", | |
"data = np.empty((ndata, 2))\n", | |
"for k in range(n):\n", | |
"    members = vAll == k\n", | |
"    data[members] = np.random.multivariate_normal(mu[k], true_covs[k], int(members.sum()))\n", | |
"\n", | |
"assert data.shape == (n_obs, 2)\n", | |
"assert np.array_equal(vAll, np.tile(v, numtoeach))" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 1 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"p = 2  # number of covariates (dimension of each observation)\n", | |
"\n", | |
"# Priors on the component means: Normal(1, precision 1) in every coordinate.\n", | |
"# All components share identical priors, so the posterior is symmetric under\n", | |
"# relabelling the components -- expect label switching in the traces.\n", | |
"prior_mu1 = np.ones(p)\n", | |
"prior_mu2 = np.ones(p)\n", | |
"prior_mu3 = np.ones(p)\n", | |
"post_mu1 = mc.Normal(\"returns1\", prior_mu1, 1, size=p)\n", | |
"post_mu2 = mc.Normal(\"returns2\", prior_mu2, 1, size=p)\n", | |
"post_mu3 = mc.Normal(\"returns3\", prior_mu3, 1, size=p)\n", | |
"\n", | |
"# Priors on the component precision matrices. PyMC 2's Wishart(n, Tau) has mean\n", | |
"# n * inv(Tau), and n acts as a count of prior pseudo-observations, so it must not\n", | |
"# scale with the data size: n = n_obs would centre the prior on n_obs * I and\n", | |
"# overwhelm the likelihood. p + 1 keeps the prior weak, and Tau = (p + 1) * I\n", | |
"# centres it on the identity matrix.\n", | |
"prior_prec_dof = p + 1\n", | |
"prior_prec_Tau = prior_prec_dof * np.eye(p)\n", | |
"post_cov_matrix_inv1 = mc.Wishart(\"cov_matrix_inv1\", prior_prec_dof, prior_prec_Tau)\n", | |
"post_cov_matrix_inv2 = mc.Wishart(\"cov_matrix_inv2\", prior_prec_dof, prior_prec_Tau)\n", | |
"post_cov_matrix_inv3 = mc.Wishart(\"cov_matrix_inv3\", prior_prec_dof, prior_prec_Tau)\n", | |
"\n", | |
"# Collect the per-component parameters so they can be indexed by component label.\n", | |
"meansAll = np.array([post_mu1, post_mu2, post_mu3], object)\n", | |
"precsAll = np.array([post_cov_matrix_inv1, post_cov_matrix_inv2, post_cov_matrix_inv3], object)\n", | |
"\n", | |
"# Mixture weights and the component label of each parent.\n", | |
"dd = mc.Dirichlet('dd', theta=(1,)*n)\n", | |
"category = mc.Categorical('category', p=dd, size=nprov)\n", | |
"\n", | |
"# Hierarchy: every observation inherits the component of its parent prov1[i].\n", | |
"@mc.deterministic\n", | |
"def mean(category=category, meansAll=meansAll):\n", | |
"    # Mean vector of each observation's component.\n", | |
"    return meansAll[category[prov1]]\n", | |
"\n", | |
"@mc.deterministic\n", | |
"def prec(category=category, precsAll=precsAll):\n", | |
"    # Precision matrix of each observation's component.\n", | |
"    return precsAll[category[prov1]]\n", | |
"\n", | |
"# Per-row likelihood: mc.MvNormal takes a single mean vector and precision matrix,\n", | |
"# whereas here every row has its own (mean, prec) pair.\n", | |
"@mc.observed\n", | |
"def obs(value=data, mean=mean, prec=prec):\n", | |
"    # Score this node's own value (not the global `data`), so the likelihood depends\n", | |
"    # only on what PyMC tracks for the node and its declared parents.\n", | |
"    return sum(mc.mv_normal_like(x, m, T) for x, m, T in zip(value, mean, prec))" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Build the sampler from an explicit list of model nodes. Passing locals() hands\n", | |
"# PyMC the whole notebook namespace (In, Out, _, ...), so a node held in an old\n", | |
"# output or stray variable could silently end up in the model.\n", | |
"model_nodes = [post_mu1, post_mu2, post_mu3,\n", | |
"               post_cov_matrix_inv1, post_cov_matrix_inv2, post_cov_matrix_inv3,\n", | |
"               dd, category, mean, prec, obs]\n", | |
"M = mc.MCMC(model_nodes)\n", | |
"\n", | |
"# The per-row likelihood is slow (a previous run managed ~16 iterations/s),\n", | |
"# so the full chain takes on the order of ten minutes.\n", | |
"N_ITER = 10000  # total iterations\n", | |
"N_BURN = 5000   # iterations discarded as burn-in\n", | |
"M.sample(N_ITER, N_BURN)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [] | |
} | |
], | |
"metadata": {} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment