Last active
October 16, 2020 19:28
-
-
Save ahartikainen/8713171d259718cf737d8a483500e0c2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluating `log_prob` and `grad_log_prob` on a PyStan fit\n",
    "\n",
    "Draw a few samples from a toy model, map one draw to the unconstrained space with `unconstrain_pars`, and check that `log_prob` reproduces the `lp__` the sampler recorded for that draw."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pystan\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "stan_code = \"\"\"\n",
    "parameters {\n",
    "  real<lower=0> a;\n",
    "  matrix[3,4] B;\n",
    "}\n",
    "model {\n",
    "  a ~ normal(0,1);\n",
    "  for (n in 1:3) {\n",
    "    for (m in 1:4) {\n",
    "      B[n,m] ~ normal(0,2);\n",
    "    }\n",
    "  }\n",
    "}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_a71ba528c20fc622bc4c49e3064eafab NOW.\n"
     ]
    }
   ],
   "source": [
    "model = pystan.StanModel(model_code=stan_code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# minimal run for checking log_prob: one iteration, no warmup, adaptation off, fixed init and seed\n",
    "fit = model.sampling(iter=1, warmup=0, init=0, seed=1, control={\"adapt_engaged\": False}, check_hmc_diagnostics=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = fit.extract(permuted=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "lp = sample[\"lp__\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# keep only model parameters; sampler quantities such as lp__ end in \"__\".\n",
    "# A new name is used so `sample` (and the lp__ lookup above) stays valid on re-run.\n",
    "param_draws = {key: values for key, values in sample.items() if not key.endswith(\"__\")}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a': array([0.3855873 , 0.98054701, 0.49286373, 1.11726411]),\n",
       " 'B': array([[[ 0.31295903, 1.59176432, 0.25751408, -0.54202137],\n",
       " [-0.28700276, 0.00954776, 0.80052233, 0.21879722],\n",
       " [ 0.07881552, -0.20168992, 1.10617236, 2.33979838]],\n",
       " \n",
       " [[ 4.1422781 , 3.68825208, -0.17252071, -1.68395211],\n",
       " [-2.86501378, 0.33404798, -1.44207571, -1.80443805],\n",
       " [-1.10181232, 0.2255809 , -0.37776586, 0.5961288 ]],\n",
       " \n",
       " [[-0.22384381, 0.3285261 , 0.27917157, 2.15246902],\n",
       " [ 1.5110607 , -1.25719988, 0.80260026, -0.40612884],\n",
       " [ 1.11420704, 0.18069843, -1.24951964, -0.30942338]],\n",
       " \n",
       " [[ 0.27401254, -1.66957992, -0.26767151, -0.44180366],\n",
       " [ 0.25835625, -0.16191208, 0.95659342, 1.78234786],\n",
       " [ 0.26586376, 0.91323483, 0.18991228, 0.13852963]]])}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "param_draws"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-2.34083794, -6.63107518, -2.38813117, -1.54752603])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a (4,)\n",
      "B (4, 3, 4)\n"
     ]
    }
   ],
   "source": [
    "for key, values in param_draws.items():\n",
    "    print(key, values.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Map a single draw to the unconstrained space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a': 0.38558730390174756,\n",
       " 'B': array([[ 0.31295903, 1.59176432, 0.25751408, -0.54202137],\n",
       " [-0.28700276, 0.00954776, 0.80052233, 0.21879722],\n",
       " [ 0.07881552, -0.20168992, 1.10617236, 2.33979838]])}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# first draw (index 0) of every parameter\n",
    "example_dict = {key: values[0] for key, values in param_draws.items()}\n",
    "example_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a ()\n",
      "B (3, 4)\n"
     ]
    }
   ],
   "source": [
    "for key, values in example_dict.items():\n",
    "    print(key, values.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[], [3, 4]]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit.par_dims"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.95298764, 0.31295903, -0.28700276, 0.07881552, 1.59176432,\n",
       " 0.00954776, -0.20168992, 0.25751408, 0.80052233, 1.10617236,\n",
       " -0.54202137, 0.21879722, 2.33979838])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# in the output, `a` (declared real<lower=0>) appears as log(a); the entries of B are\n",
    "# unconstrained and appear unchanged, flattened column by column\n",
    "unconstrained = fit.unconstrain_pars(example_dict)\n",
    "unconstrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['a',\n",
       " 'B.1.1',\n",
       " 'B.2.1',\n",
       " 'B.3.1',\n",
       " 'B.1.2',\n",
       " 'B.2.2',\n",
       " 'B.3.2',\n",
       " 'B.1.3',\n",
       " 'B.2.3',\n",
       " 'B.3.3',\n",
       " 'B.1.4',\n",
       " 'B.2.4',\n",
       " 'B.3.4']"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# names of the unconstrained coordinates in the (column-major) order of the vector above;\n",
    "# unconstrain_pars does this flattening itself, starting from the per-parameter dict\n",
    "fit.unconstrained_param_names()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate `log_prob`\n",
    "\n",
    "Called with its defaults on the unconstrained draw, `log_prob` should match the `lp__` the sampler recorded for the same draw."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-2.3408379411644353"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit.log_prob(unconstrained)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-2.3408379411644353"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lp[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# compare with a tolerance rather than exact floating-point equality\n",
    "np.isclose(fit.log_prob(unconstrained), lp[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.14867757, -0.07823976, 0.07175069, -0.01970388, -0.39794108,\n",
       " -0.00238694, 0.05042248, -0.06437852, -0.20013058, -0.27654309,\n",
       " 0.13550534, -0.0546993 , -0.5849496 ])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# gradient of the log density without the Jacobian term\n",
    "fit.grad_log_prob(unconstrained, adjust_transform=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.85132243, -0.07823976, 0.07175069, -0.01970388, -0.39794108,\n",
       " -0.00238694, 0.05042248, -0.06437852, -0.20013058, -0.27654309,\n",
       " 0.13550534, -0.0546993 , -0.5849496 ])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# with the Jacobian adjustment: only the first component (`a`, declared real<lower=0>) changes\n",
    "# (by +1 here); B is unconstrained, so its components match the unadjusted gradient above\n",
    "fit.grad_log_prob(unconstrained, adjust_transform=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment