Skip to content

Instantly share code, notes, and snippets.

@fonnesbeck
Created May 21, 2014 21:11
Show Gist options
  • Save fonnesbeck/4784df751a6d0280345f to your computer and use it in GitHub Desktop.
Save fonnesbeck/4784df751a6d0280345f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"metadata": {
"name": "",
"signature": "sha256:62fb4b6307f8d3ec1041ab86542cad5920238ff63476f513cd04ebaa74218282"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": [
"import numpy as np\n",
"import pymc as mc\n",
"# --- Simulation settings -------------------------------------------------\n",
"n = 3           # number of mixture components\n",
"B = 5           # offset that separates the component means\n",
"tau = 3         # scale applied to one axis of the non-reference covariances\n",
"nprov = 60      # number of parent units; each parent belongs to one component\n",
"numtoeach = 50  # observations generated per parent\n",
"\n",
"# Seed NumPy's global generator so that Restart & Run All regenerates the\n",
"# same simulated data set.\n",
"SEED = 20140521\n",
"np.random.seed(SEED)\n",
"\n",
"# True component means and covariance matrices.\n",
"mu = [[0, 0], [0, B], [-B, 0]]\n",
"true_cov0 = np.array([[1., 0.], [0., 1.]])\n",
"true_cov1 = np.array([[1., 0.], [0., tau**(2)]])\n",
"true_cov2 = np.array([[tau**(-2), 0], [0., 1.]])\n",
"trueprobs = [.4, .3, .3]  # probability of a parent falling in each component\n",
"\n",
"# One-hot draw of each parent's component, collapsed to an integer label 0/1/2.\n",
"prov = np.random.multinomial(1, trueprobs, size=nprov)\n",
"v = prov[:, 1] + (prov[:, 2]) * 2\n",
"\n",
"# n_obs and ndata are the same quantity; both names are kept because later\n",
"# cells refer to n_obs while this cell uses ndata.\n",
"n_obs = nprov * numtoeach\n",
"ndata = n_obs\n",
"\n",
"# Repeat the parent labels / parent indices so that observation i belongs to\n",
"# parent prov1[i], whose true component is vAll[i].\n",
"vAll = np.tile(v, numtoeach)\n",
"p1 = range(nprov)\n",
"prov1 = np.tile(p1, numtoeach)\n",
"\n",
"# Draw ndata candidate points from every component (same draw order as the\n",
"# mask-and-add formulation), then keep for observation i the candidate that\n",
"# comes from its own component vAll[i]. Result has shape (ndata, 2).\n",
"draws = np.array([\n",
"    np.random.multivariate_normal(mu[0], true_cov0, ndata),\n",
"    np.random.multivariate_normal(mu[1], true_cov1, ndata),\n",
"    np.random.multivariate_normal(mu[2], true_cov2, ndata),\n",
"])\n",
"data = draws[vAll, np.arange(ndata)]"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"p = 2  # dimension of each observation\n",
"\n",
"# Prior means for the three component mean vectors (all ones).\n",
"prior_mu1 = np.ones(p)\n",
"prior_mu2 = np.ones(p)\n",
"prior_mu3 = np.ones(p)\n",
"\n",
"# Component mean vectors: independent Normal priors with precision 1.\n",
"post_mu1 = mc.Normal(\"returns1\", mu=prior_mu1, tau=1, size=p)\n",
"post_mu2 = mc.Normal(\"returns2\", mu=prior_mu2, tau=1, size=p)\n",
"post_mu3 = mc.Normal(\"returns3\", mu=prior_mu3, tau=1, size=p)\n",
"\n",
"# Component precision matrices: Wishart priors with an identity scale matrix.\n",
"# NOTE(review): the degrees-of-freedom argument is n_obs (nprov*numtoeach),\n",
"# which looks like a very concentrated prior -- confirm this is intended\n",
"# rather than a small value such as p + 1.\n",
"post_cov_matrix_inv1 = mc.Wishart(\"cov_matrix_inv1\", n=n_obs, Tau=np.eye(p))\n",
"post_cov_matrix_inv2 = mc.Wishart(\"cov_matrix_inv2\", n=n_obs, Tau=np.eye(p))\n",
"post_cov_matrix_inv3 = mc.Wishart(\"cov_matrix_inv3\", n=n_obs, Tau=np.eye(p))\n",
"\n",
"# Collect the per-component means and precision matrices so that they can be\n",
"# indexed by category label.\n",
"meansAll = np.array([post_mu1, post_mu2, post_mu3], dtype=object)\n",
"precsAll = np.array(\n",
"    [post_cov_matrix_inv1, post_cov_matrix_inv2, post_cov_matrix_inv3],\n",
"    dtype=object)\n",
"\n",
"# Mixture weights and one latent category per parent.\n",
"dd = mc.Dirichlet('dd', theta=(1,) * n)\n",
"category = mc.Categorical('category', p=dd, size=nprov)\n",
"\n",
"# This step accounts for the hierarchy: each observation's mean is equal to its parent's mean.\n",
"# prov1 holds the parent index of every observation.\n",
"\n",
"# Per-observation mean: prov1 gives each observation's parent index,\n",
"# category[...] turns that into the parent's current component label, and the\n",
"# label selects the matching entry of meansAll.\n",
"@mc.deterministic\n",
"def mean(category=category, meansAll=meansAll):\n",
"    return meansAll[category[prov1]]\n",
"\n",
"# Per-observation precision matrix, selected the same way as the mean:\n",
"# observation -> parent index (prov1) -> parent's component label -> precsAll.\n",
"@mc.deterministic\n",
"def prec(category=category, precsAll=precsAll):\n",
"    labels = category[prov1]\n",
"    selected = precsAll[labels]\n",
"    return selected\n",
"\n",
"# Observed likelihood: sum of the multivariate-normal log-likelihoods of each\n",
"# observation under its own mean vector and precision matrix.\n",
"# The rows are taken from the node's `value` argument (which defaults to\n",
"# `data`) rather than from the global `data`, so the likelihood always\n",
"# reflects the value the node actually holds.\n",
"@mc.observed\n",
"def obs(value=data, mean=mean, prec=prec):\n",
"    return sum(mc.mv_normal_like(x, m, T) for x, m, T in zip(value, mean, prec))\n",
"\n",
"#obs = mc.MvNormal( \"observed returns\", mean, prec, observed = True, value = data)"
],
"language": "python",
"metadata": {},
"outputs": []
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Hand the sampler the model nodes explicitly. At notebook top level,\n",
"# locals() is the whole interactive namespace, so it also contains unrelated\n",
"# objects and anything left over from earlier runs of this cell.\n",
"model_nodes = [post_mu1, post_mu2, post_mu3,\n",
"               post_cov_matrix_inv1, post_cov_matrix_inv2, post_cov_matrix_inv3,\n",
"               dd, category, mean, prec, obs]\n",
"M = mc.MCMC(model_nodes)\n",
"\n",
"# 10000 iterations in total, the first 5000 discarded as burn-in.\n",
"M.sample(iter=10000, burn=5000)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 0% ] 8 of 10000 complete in 0.6 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 0% ] 17 of 10000 complete in 1.1 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 0% ] 26 of 10000 complete in 1.6 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 0% ] 35 of 10000 complete in 2.1 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 0% ] 44 of 10000 complete in 2.7 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 0% ] 53 of 10000 complete in 3.2 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 0% ] 62 of 10000 complete in 3.7 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 0% ] 71 of 10000 complete in 4.3 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 0% ] 80 of 10000 complete in 4.8 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 0% ] 88 of 10000 complete in 5.3 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 0% ] 97 of 10000 complete in 5.9 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 1% ] 105 of 10000 complete in 6.4 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 1% ] 114 of 10000 complete in 6.9 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 1% ] 123 of 10000 complete in 7.4 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 1% ] 132 of 10000 complete in 8.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 1% ] 141 of 10000 complete in 8.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 1% ] 150 of 10000 complete in 9.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 1% ] 157 of 10000 complete in 9.6 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 1% ] 165 of 10000 complete in 10.1 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 1% ] 173 of 10000 complete in 10.7 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 1% ] 181 of 10000 complete in 11.2 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 1% ] 189 of 10000 complete in 11.7 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 1% ] 197 of 10000 complete in 12.2 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 2% ] 206 of 10000 complete in 12.7 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 2% ] 215 of 10000 complete in 13.3 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 2% ] 223 of 10000 complete in 13.8 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 2% ] 232 of 10000 complete in 14.3 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 2% ] 240 of 10000 complete in 14.8 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 2% ] 249 of 10000 complete in 15.4 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [ 2% ] 258 of 10000 complete in 15.9 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 2% ] 267 of 10000 complete in 16.4 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 2% ] 276 of 10000 complete in 17.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 2% ] 285 of 10000 complete in 17.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 2% ] 294 of 10000 complete in 18.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 3% ] 303 of 10000 complete in 18.6 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 3% ] 312 of 10000 complete in 19.1 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 3% ] 321 of 10000 complete in 19.6 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 3% ] 330 of 10000 complete in 20.2 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 3% ] 339 of 10000 complete in 20.7 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 3% ] 348 of 10000 complete in 21.2 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 3% ] 357 of 10000 complete in 21.8 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 3% ] 366 of 10000 complete in 22.3 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 3% ] 375 of 10000 complete in 22.9 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 3% ] 384 of 10000 complete in 23.4 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 3% ] 393 of 10000 complete in 23.9 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 4% ] 402 of 10000 complete in 24.4 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 4% ] 411 of 10000 complete in 25.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 4% ] 420 of 10000 complete in 25.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 4% ] 429 of 10000 complete in 26.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 4% ] 438 of 10000 complete in 26.6 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 4% ] 447 of 10000 complete in 27.1 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 4% ] 456 of 10000 complete in 27.6 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 4% ] 464 of 10000 complete in 28.2 sec"
]
}
]
}
],
"metadata": {}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment