Created
July 19, 2016 10:40
-
-
Save taku-y/0d5965a184066699cfdc22d9e71d032b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Applied stickbreaking-transform to p and added transformed p_stickbreaking_ to model.\n", | |
"Applied interval-transform to sd and added transformed sd_interval_ to model.\n", | |
" [-----------------100%-----------------] 100 of 100 complete in 0.2 sec[[ 9.99747864e-01 2.52136138e-04]\n", | |
" [ 9.51298297e-01 4.87017033e-02]\n", | |
" [ 9.99573099e-01 4.26900692e-04]\n", | |
" [ 7.42235726e-01 2.57764274e-01]\n", | |
" [ 9.01139410e-01 9.88605896e-02]\n", | |
" [ 2.95292867e-02 9.70470713e-01]\n", | |
" [ 9.49719240e-02 9.05028076e-01]\n", | |
" [ 6.70368286e-04 9.99329632e-01]\n", | |
" [ 5.84711404e-04 9.99415289e-01]\n", | |
" [ 3.64180996e-02 9.63581900e-01]]\n", | |
"[0 0 0 0 0 1 1 1 1 1]\n" | |
] | |
} | |
], | |
"source": [ | |
"import os\n", | |
"import sys\n", | |
"\n", | |
"# Prefer a local pymc3 checkout over the installed package; point\n", | |
"# PYMC3_DEV_PATH at your own checkout instead of editing this cell.\n", | |
"PYMC3_DEV_PATH = os.environ.get(\n", | |
"    'PYMC3_DEV_PATH', os.path.expanduser('~/work/git/github/taku-y/pymc3'))\n", | |
"sys.path.insert(0, PYMC3_DEV_PATH)\n", | |
"\n", | |
"import numpy as np\n", | |
"import pymc3 as pm\n", | |
"\n", | |
"SEED = 0  # fixed sampler seed so Restart & Run All reproduces the printed draws\n", | |
"\n", | |
"# artificial data: two groups of five points, centred at 2.5 and 7.5\n", | |
"k = 2  # number of mixture components\n", | |
"data = np.array([1, 2.2, 2.5, 2.8, 4, 6, 7.2, 7.5, 7.8, 9])\n", | |
"ndata = len(data)  # derived from data so the two can never disagree\n", | |
"\n", | |
"# model\n", | |
"# Each row of p holds one point's component-membership probabilities;\n", | |
"# a Dirichlet concentration below 1 favours rows close to a single component.\n", | |
"alpha = 0.1 * np.ones((ndata, k))\n", | |
"with pm.Model() as model:\n", | |
"    p = pm.Dirichlet('p', alpha, shape=(ndata, k))\n", | |
"    mu = pm.Normal('mu', mu=5, sd=3, shape=k)\n", | |
"    sd = pm.Uniform('sd', lower=0.1, upper=0.5, shape=k)\n", | |
"    categ = pm.Categorical('categ', p=p, shape=ndata)\n", | |
"    # each observation uses the mean and sd of its assigned component\n", | |
"    obs = pm.Normal('obs',\n", | |
"                    mu=mu[categ],\n", | |
"                    sd=sd[categ],\n", | |
"                    observed=data)\n", | |
"\n", | |
"    # `obs` is observed, not a free variable, so it is not listed here:\n", | |
"    # its graph inputs include `categ`, which step2 already updates.\n", | |
"    step1 = pm.Metropolis(vars=[p, sd, mu])\n", | |
"    step2 = pm.ElemwiseCategorical(vars=[categ])\n", | |
"\n", | |
"    tr = pm.sample(100, step=[step1, step2], random_seed=SEED)\n", | |
"\n", | |
"# last draw of the membership probabilities and of the assignments\n", | |
"print(tr['p'][-1, :])\n", | |
"print(tr['categ'][-1, :])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.1" | |
}, | |
"nav_menu": {}, | |
"toc": { | |
"navigate_menu": true, | |
"number_sections": true, | |
"sideBar": true, | |
"threshold": 6, | |
"toc_cell": false, | |
"toc_section_display": "block", | |
"toc_window_display": false | |
}, | |
"widgets": { | |
"state": {}, | |
"version": "1.1.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment