Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Last active December 5, 2024 12:30
Show Gist options
  • Save ricardoV94/6198899c00b7562f4547a898bd4140dd to your computer and use it in GitHub Desktop.
Save ricardoV94/6198899c00b7562f4547a898bd4140dd to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d5e93528-fbba-4b7b-b2cb-ba345aba42f8",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [a, b, p]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9d492e469e4a4fe5a713b37c395f139d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [a, b, p]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "abf539ffd43d4e4d9be83837bbaddac1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"CompoundStep\n",
">NUTS: [a, b]\n",
">NUTS: [p]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0d2fe145c02349439458316fad6e667d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 4 seconds.\n"
]
}
],
"source": [
"import pymc as pm\n",
"\n",
"with pm.Model() as m:\n",
" a = pm.Gamma(\"a\", mu=1, sigma=1)\n",
" b = pm.Gamma(\"b\", mu=1, sigma=1)\n",
" p = pm.Beta(\"p\", a, b, shape=(4,))\n",
" y = pm.Binomial(\"y\", n=100, p=p, observed=[50, 25, 75, 50])\n",
" idata_default_nuts = pm.sample()\n",
" idata_custom_nuts = pm.sample(step=[pm.NUTS([a, b, p])])\n",
" idata_blocked_nuts = pm.sample(step=[pm.NUTS([a, b]), pm.NUTS(p)])"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2edd46d6-2b87-41b5-b40d-373c38309ccd",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean</th>\n",
" <th>sd</th>\n",
" <th>hdi_3%</th>\n",
" <th>hdi_97%</th>\n",
" <th>mcse_mean</th>\n",
" <th>mcse_sd</th>\n",
" <th>ess_bulk</th>\n",
" <th>ess_tail</th>\n",
" <th>r_hat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>a</th>\n",
" <td>1.597</td>\n",
" <td>0.809</td>\n",
" <td>0.344</td>\n",
" <td>3.115</td>\n",
" <td>0.013</td>\n",
" <td>0.010</td>\n",
" <td>3813.0</td>\n",
" <td>2764.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>b</th>\n",
" <td>1.603</td>\n",
" <td>0.832</td>\n",
" <td>0.296</td>\n",
" <td>3.139</td>\n",
" <td>0.014</td>\n",
" <td>0.011</td>\n",
" <td>3749.0</td>\n",
" <td>2565.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail \\\n",
"a 1.597 0.809 0.344 3.115 0.013 0.010 3813.0 2764.0 \n",
"b 1.603 0.832 0.296 3.139 0.014 0.011 3749.0 2565.0 \n",
"\n",
" r_hat \n",
"a 1.0 \n",
"b 1.0 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pm.stats.summary(idata_default_nuts, var_names=[\"a\", \"b\"])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "007ac8e6-cb22-4d8b-8919-937234961857",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean</th>\n",
" <th>sd</th>\n",
" <th>hdi_3%</th>\n",
" <th>hdi_97%</th>\n",
" <th>mcse_mean</th>\n",
" <th>mcse_sd</th>\n",
" <th>ess_bulk</th>\n",
" <th>ess_tail</th>\n",
" <th>r_hat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>a</th>\n",
" <td>1.588</td>\n",
" <td>0.836</td>\n",
" <td>0.305</td>\n",
" <td>3.177</td>\n",
" <td>0.014</td>\n",
" <td>0.011</td>\n",
" <td>3563.0</td>\n",
" <td>2629.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>b</th>\n",
" <td>1.600</td>\n",
" <td>0.853</td>\n",
" <td>0.351</td>\n",
" <td>3.186</td>\n",
" <td>0.015</td>\n",
" <td>0.012</td>\n",
" <td>3399.0</td>\n",
" <td>2289.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail \\\n",
"a 1.588 0.836 0.305 3.177 0.014 0.011 3563.0 2629.0 \n",
"b 1.600 0.853 0.351 3.186 0.015 0.012 3399.0 2289.0 \n",
"\n",
" r_hat \n",
"a 1.0 \n",
"b 1.0 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pm.stats.summary(idata_custom_nuts, var_names=[\"a\", \"b\"])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bbf96661-10de-4c3b-bc17-48ecfb49e864",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean</th>\n",
" <th>sd</th>\n",
" <th>hdi_3%</th>\n",
" <th>hdi_97%</th>\n",
" <th>mcse_mean</th>\n",
" <th>mcse_sd</th>\n",
" <th>ess_bulk</th>\n",
" <th>ess_tail</th>\n",
" <th>r_hat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>a</th>\n",
" <td>1.604</td>\n",
" <td>0.857</td>\n",
" <td>0.266</td>\n",
" <td>3.207</td>\n",
" <td>0.021</td>\n",
" <td>0.016</td>\n",
" <td>1653.0</td>\n",
" <td>1512.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>b</th>\n",
" <td>1.609</td>\n",
" <td>0.868</td>\n",
" <td>0.279</td>\n",
" <td>3.180</td>\n",
" <td>0.024</td>\n",
" <td>0.018</td>\n",
" <td>1441.0</td>\n",
" <td>1620.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail \\\n",
"a 1.604 0.857 0.266 3.207 0.021 0.016 1653.0 1512.0 \n",
"b 1.609 0.868 0.279 3.180 0.024 0.018 1441.0 1620.0 \n",
"\n",
" r_hat \n",
"a 1.0 \n",
"b 1.0 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pm.stats.summary(idata_blocked_nuts, var_names=[\"a\", \"b\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4352dbdf-adaa-4bf6-bb26-0e840bd8bccb",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pymc-dev",
"language": "python",
"name": "pymc-dev"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment