Last active
December 5, 2024 12:30
-
-
Save ricardoV94/6198899c00b7562f4547a898bd4140dd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "d5e93528-fbba-4b7b-b2cb-ba345aba42f8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Initializing NUTS using jitter+adapt_diag...\n", | |
"Multiprocess sampling (4 chains in 4 jobs)\n", | |
"NUTS: [a, b, p]\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "9d492e469e4a4fe5a713b37c395f139d", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Output()" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |
], | |
"text/plain": [] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.\n", | |
"Multiprocess sampling (4 chains in 4 jobs)\n", | |
"NUTS: [a, b, p]\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "abf539ffd43d4e4d9be83837bbaddac1", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Output()" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |
], | |
"text/plain": [] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.\n", | |
"Multiprocess sampling (4 chains in 4 jobs)\n", | |
"CompoundStep\n", | |
">NUTS: [a, b]\n", | |
">NUTS: [p]\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "0d2fe145c02349439458316fad6e667d", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Output()" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |
], | |
"text/plain": [] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 4 seconds.\n" | |
] | |
} | |
], | |
"source": [ | |
"import pymc as pm\n", | |
"\n", | |
"with pm.Model() as m:\n", | |
" a = pm.Gamma(\"a\", mu=1, sigma=1)\n", | |
" b = pm.Gamma(\"b\", mu=1, sigma=1)\n", | |
" p = pm.Beta(\"p\", a, b, shape=(4,))\n", | |
" y = pm.Binomial(\"y\", n=100, p=p, observed=[50, 25, 75, 50])\n", | |
" idata_default_nuts = pm.sample()\n", | |
" idata_custom_nuts = pm.sample(step=[pm.NUTS([a, b, p])])\n", | |
" idata_blocked_nuts = pm.sample(step=[pm.NUTS([a, b]), pm.NUTS(p)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "2edd46d6-2b87-41b5-b40d-373c38309ccd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>mean</th>\n", | |
" <th>sd</th>\n", | |
" <th>hdi_3%</th>\n", | |
" <th>hdi_97%</th>\n", | |
" <th>mcse_mean</th>\n", | |
" <th>mcse_sd</th>\n", | |
" <th>ess_bulk</th>\n", | |
" <th>ess_tail</th>\n", | |
" <th>r_hat</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>a</th>\n", | |
" <td>1.597</td>\n", | |
" <td>0.809</td>\n", | |
" <td>0.344</td>\n", | |
" <td>3.115</td>\n", | |
" <td>0.013</td>\n", | |
" <td>0.010</td>\n", | |
" <td>3813.0</td>\n", | |
" <td>2764.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>b</th>\n", | |
" <td>1.603</td>\n", | |
" <td>0.832</td>\n", | |
" <td>0.296</td>\n", | |
" <td>3.139</td>\n", | |
" <td>0.014</td>\n", | |
" <td>0.011</td>\n", | |
" <td>3749.0</td>\n", | |
" <td>2565.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail \\\n", | |
"a 1.597 0.809 0.344 3.115 0.013 0.010 3813.0 2764.0 \n", | |
"b 1.603 0.832 0.296 3.139 0.014 0.011 3749.0 2565.0 \n", | |
"\n", | |
" r_hat \n", | |
"a 1.0 \n", | |
"b 1.0 " | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pm.stats.summary(idata_default_nuts, var_names=[\"a\", \"b\"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "007ac8e6-cb22-4d8b-8919-937234961857", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>mean</th>\n", | |
" <th>sd</th>\n", | |
" <th>hdi_3%</th>\n", | |
" <th>hdi_97%</th>\n", | |
" <th>mcse_mean</th>\n", | |
" <th>mcse_sd</th>\n", | |
" <th>ess_bulk</th>\n", | |
" <th>ess_tail</th>\n", | |
" <th>r_hat</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>a</th>\n", | |
" <td>1.588</td>\n", | |
" <td>0.836</td>\n", | |
" <td>0.305</td>\n", | |
" <td>3.177</td>\n", | |
" <td>0.014</td>\n", | |
" <td>0.011</td>\n", | |
" <td>3563.0</td>\n", | |
" <td>2629.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>b</th>\n", | |
" <td>1.600</td>\n", | |
" <td>0.853</td>\n", | |
" <td>0.351</td>\n", | |
" <td>3.186</td>\n", | |
" <td>0.015</td>\n", | |
" <td>0.012</td>\n", | |
" <td>3399.0</td>\n", | |
" <td>2289.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail \\\n", | |
"a 1.588 0.836 0.305 3.177 0.014 0.011 3563.0 2629.0 \n", | |
"b 1.600 0.853 0.351 3.186 0.015 0.012 3399.0 2289.0 \n", | |
"\n", | |
" r_hat \n", | |
"a 1.0 \n", | |
"b 1.0 " | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pm.stats.summary(idata_custom_nuts, var_names=[\"a\", \"b\"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "bbf96661-10de-4c3b-bc17-48ecfb49e864", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>mean</th>\n", | |
" <th>sd</th>\n", | |
" <th>hdi_3%</th>\n", | |
" <th>hdi_97%</th>\n", | |
" <th>mcse_mean</th>\n", | |
" <th>mcse_sd</th>\n", | |
" <th>ess_bulk</th>\n", | |
" <th>ess_tail</th>\n", | |
" <th>r_hat</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>a</th>\n", | |
" <td>1.604</td>\n", | |
" <td>0.857</td>\n", | |
" <td>0.266</td>\n", | |
" <td>3.207</td>\n", | |
" <td>0.021</td>\n", | |
" <td>0.016</td>\n", | |
" <td>1653.0</td>\n", | |
" <td>1512.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>b</th>\n", | |
" <td>1.609</td>\n", | |
" <td>0.868</td>\n", | |
" <td>0.279</td>\n", | |
" <td>3.180</td>\n", | |
" <td>0.024</td>\n", | |
" <td>0.018</td>\n", | |
" <td>1441.0</td>\n", | |
" <td>1620.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail \\\n", | |
"a 1.604 0.857 0.266 3.207 0.021 0.016 1653.0 1512.0 \n", | |
"b 1.609 0.868 0.279 3.180 0.024 0.018 1441.0 1620.0 \n", | |
"\n", | |
" r_hat \n", | |
"a 1.0 \n", | |
"b 1.0 " | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pm.stats.summary(idata_blocked_nuts, var_names=[\"a\", \"b\"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4352dbdf-adaa-4bf6-bb26-0e840bd8bccb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "pymc-dev", | |
"language": "python", | |
"name": "pymc-dev" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.12.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment