Created
August 11, 2024 18:23
-
-
Save fonnesbeck/a9a0b1624324ee2df30ff32c51e452a8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# posteriordb Model Template\n", | |
"\n", | |
"Below is an example of how to add a new model to the posterior database, using the non-centered eight schools model. \n", | |
"\n", | |
"Everything below assumes that this notebook is run from the root of the `posteriordb` repository, since all file paths are relative to it. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pymc as pm\n", | |
"import arviz as az\n", | |
"import numpy as np\n", | |
"import zipfile\n", | |
"import json\n", | |
"\n", | |
"import sys\n", | |
"from datetime import datetime\n", | |
"import pkg_resources\n", | |
"from pathlib import Path" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'J': 8,\n", | |
" 'y': [28, 8, -3, 7, -1, 1, 18, 12],\n", | |
" 'sigma': [15, 10, 16, 11, 9, 11, 10, 18]}" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Read the JSON data directly from the archive; extracting it would leave an\n", | |
"# untracked eight_schools.json file behind in the repository.\n", | |
"data_zip_path = Path(\"posterior_database/data/data/eight_schools.json.zip\")\n", | |
"with zipfile.ZipFile(data_zip_path) as zip_ref:\n", | |
"    with zip_ref.open(\"eight_schools.json\") as file:\n", | |
"        data = json.load(file)\n", | |
"\n", | |
"data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def model(data):\n", | |
"    \"\"\"Build the non-centered eight schools model.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    data : dict\n", | |
"        Must contain \"J\" (number of schools), \"y\" (estimated treatment\n", | |
"        effects) and \"sigma\" (standard errors of those estimates).\n", | |
"\n", | |
"    Returns\n", | |
"    -------\n", | |
"    pm.Model\n", | |
"        The model; school effects are exposed as the deterministic `theta`.\n", | |
"    \"\"\"\n", | |
"    y_obs = np.array(data[\"y\"])  # estimated treatment effects\n", | |
"    sigma = np.array(data[\"sigma\"])  # standard errors of the estimated effects\n", | |
"    coords = {\"school\": np.arange(data[\"J\"])}\n", | |
"    with pm.Model(coords=coords) as pymc_model:\n", | |
"        # Population mean; Normal(0, 5) is weakly informative, not non-informative.\n", | |
"        mu = pm.Normal(\"mu\", mu=0, sigma=5)\n", | |
"        # tau scales the school offsets, so it is a standard deviation and must be\n", | |
"        # positive: Cauchy(0, 5) restricted to tau > 0, i.e. HalfCauchy(5). An\n", | |
"        # unbounded Cauchy lets tau take negative values.\n", | |
"        tau = pm.HalfCauchy(\"tau\", beta=5)\n", | |
"        # Non-centered parameterization: theta = mu + tau * theta_trans,\n", | |
"        # with standard-normal offsets theta_trans.\n", | |
"        theta_trans = pm.Normal(\"theta_trans\", mu=0, sigma=1, dims=\"school\")\n", | |
"        theta = pm.Deterministic(\"theta\", mu + tau * theta_trans, dims=\"school\")\n", | |
"        pm.Normal(\"y\", mu=theta, sigma=sigma, observed=y_obs)\n", | |
"    return pymc_model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Sampler settings; they are also recorded in the reference-posterior info file.\n", | |
"sampler_config = dict(\n", | |
"    chains=10,\n", | |
"    draws=2000,\n", | |
"    tune=1000,\n", | |
"    random_seed=4711,\n", | |
"    target_accept=0.95,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Auto-assigning NUTS sampler...\n", | |
"Initializing NUTS using jitter+adapt_diag...\n", | |
"Multiprocess sampling (10 chains in 4 jobs)\n", | |
"NUTS: [mu, tau, theta_trans]\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "f6ddb14a4ca14c3d931a58f45babcb1f", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Output()" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |
], | |
"text/plain": [] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |
"</pre>\n" | |
], | |
"text/plain": [ | |
"\n" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Sampling 10 chains for 1_000 tune and 2_000 draw iterations (10_000 + 20_000 draws total) took 6 seconds.\n" | |
] | |
} | |
], | |
"source": [ | |
"# Build the model and draw the reference posterior with the settings above.\n", | |
"eight_schools_model = model(data)\n", | |
"trace = pm.sample(model=eight_schools_model, **sampler_config)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>mean</th>\n", | |
" <th>sd</th>\n", | |
" <th>hdi_3%</th>\n", | |
" <th>hdi_97%</th>\n", | |
" <th>mcse_mean</th>\n", | |
" <th>mcse_sd</th>\n", | |
" <th>ess_bulk</th>\n", | |
" <th>ess_tail</th>\n", | |
" <th>r_hat</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>mu</th>\n", | |
" <td>4.417</td>\n", | |
" <td>3.327</td>\n", | |
" <td>-1.791</td>\n", | |
" <td>10.742</td>\n", | |
" <td>0.024</td>\n", | |
" <td>0.018</td>\n", | |
" <td>19657.0</td>\n", | |
" <td>14859.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>tau</th>\n", | |
" <td>0.124</td>\n", | |
" <td>4.806</td>\n", | |
" <td>-9.143</td>\n", | |
" <td>9.313</td>\n", | |
" <td>0.046</td>\n", | |
" <td>0.037</td>\n", | |
" <td>11821.0</td>\n", | |
" <td>10546.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>theta[0]</th>\n", | |
" <td>6.252</td>\n", | |
" <td>5.729</td>\n", | |
" <td>-3.693</td>\n", | |
" <td>17.391</td>\n", | |
" <td>0.045</td>\n", | |
" <td>0.035</td>\n", | |
" <td>17136.0</td>\n", | |
" <td>14445.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>theta[1]</th>\n", | |
" <td>4.960</td>\n", | |
" <td>4.701</td>\n", | |
" <td>-3.705</td>\n", | |
" <td>14.123</td>\n", | |
" <td>0.033</td>\n", | |
" <td>0.026</td>\n", | |
" <td>19976.0</td>\n", | |
" <td>17196.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>theta[2]</th>\n", | |
" <td>3.953</td>\n", | |
" <td>5.252</td>\n", | |
" <td>-6.297</td>\n", | |
" <td>13.432</td>\n", | |
" <td>0.040</td>\n", | |
" <td>0.030</td>\n", | |
" <td>18096.0</td>\n", | |
" <td>15365.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>theta[3]</th>\n", | |
" <td>4.714</td>\n", | |
" <td>4.789</td>\n", | |
" <td>-4.407</td>\n", | |
" <td>13.826</td>\n", | |
" <td>0.034</td>\n", | |
" <td>0.026</td>\n", | |
" <td>19727.0</td>\n", | |
" <td>16743.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>theta[4]</th>\n", | |
" <td>3.607</td>\n", | |
" <td>4.671</td>\n", | |
" <td>-5.728</td>\n", | |
" <td>11.980</td>\n", | |
" <td>0.034</td>\n", | |
" <td>0.026</td>\n", | |
" <td>19315.0</td>\n", | |
" <td>16464.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>theta[5]</th>\n", | |
" <td>4.073</td>\n", | |
" <td>4.902</td>\n", | |
" <td>-5.511</td>\n", | |
" <td>13.221</td>\n", | |
" <td>0.036</td>\n", | |
" <td>0.028</td>\n", | |
" <td>19180.0</td>\n", | |
" <td>16186.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>theta[6]</th>\n", | |
" <td>6.287</td>\n", | |
" <td>5.140</td>\n", | |
" <td>-2.986</td>\n", | |
" <td>16.205</td>\n", | |
" <td>0.038</td>\n", | |
" <td>0.030</td>\n", | |
" <td>18853.0</td>\n", | |
" <td>16037.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>theta[7]</th>\n", | |
" <td>4.897</td>\n", | |
" <td>5.301</td>\n", | |
" <td>-5.293</td>\n", | |
" <td>14.812</td>\n", | |
" <td>0.040</td>\n", | |
" <td>0.032</td>\n", | |
" <td>18420.0</td>\n", | |
" <td>15791.0</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", | |
"mu 4.417 3.327 -1.791 10.742 0.024 0.018 19657.0 \n", | |
"tau 0.124 4.806 -9.143 9.313 0.046 0.037 11821.0 \n", | |
"theta[0] 6.252 5.729 -3.693 17.391 0.045 0.035 17136.0 \n", | |
"theta[1] 4.960 4.701 -3.705 14.123 0.033 0.026 19976.0 \n", | |
"theta[2] 3.953 5.252 -6.297 13.432 0.040 0.030 18096.0 \n", | |
"theta[3] 4.714 4.789 -4.407 13.826 0.034 0.026 19727.0 \n", | |
"theta[4] 3.607 4.671 -5.728 11.980 0.034 0.026 19315.0 \n", | |
"theta[5] 4.073 4.902 -5.511 13.221 0.036 0.028 19180.0 \n", | |
"theta[6] 6.287 5.140 -2.986 16.205 0.038 0.030 18853.0 \n", | |
"theta[7] 4.897 5.301 -5.293 14.812 0.040 0.032 18420.0 \n", | |
"\n", | |
" ess_tail r_hat \n", | |
"mu 14859.0 1.0 \n", | |
"tau 10546.0 1.0 \n", | |
"theta[0] 14445.0 1.0 \n", | |
"theta[1] 17196.0 1.0 \n", | |
"theta[2] 15365.0 1.0 \n", | |
"theta[3] 16743.0 1.0 \n", | |
"theta[4] 16464.0 1.0 \n", | |
"theta[5] 16186.0 1.0 \n", | |
"theta[6] 16037.0 1.0 \n", | |
"theta[7] 15791.0 1.0 " | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Summarise the sampled parameters, excluding variables whose names match \"trans\".\n", | |
"post_summary = az.summary(\n", | |
"    trace,\n", | |
"    var_names=[\"~trans\"],\n", | |
"    filter_vars=\"regex\",\n", | |
")\n", | |
"post_summary" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Reference posteriors draw info" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 49, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"info_path = Path(\"posterior_database/reference_posteriors/draws/info\")\n", | |
"info_filename = \"eight_schools-eight_schools_noncentered.pymc.info.json\"\n", | |
"\n", | |
"# Sampler diagnostics are computed from the trace (one value per chain) rather than\n", | |
"# hard-coded; a placeholder EFMI of 0 would make `efmi_above_0_2` always False.\n", | |
"divergences_per_chain = trace.sample_stats[\"diverging\"].sum(dim=\"draw\").values.tolist()\n", | |
"efmi_per_chain = az.bfmi(trace).tolist()\n", | |
"n_total_draws = sampler_config[\"draws\"] * sampler_config[\"chains\"]\n", | |
"\n", | |
"# Define the information to be included in the info.json file\n", | |
"info = {\n", | |
"    \"name\": \"eight_schools-eight_schools_noncentered\",\n", | |
"    \"inference\": {\n", | |
"        \"method\": \"pymc_sampling\",\n", | |
"        \"method_arguments\": {\n", | |
"            \"chains\": sampler_config[\"chains\"],\n", | |
"            # `iter` counts tuning plus retained draws per chain.\n", | |
"            \"iter\": sampler_config[\"draws\"] + sampler_config[\"tune\"],\n", | |
"            \"warmup\": sampler_config[\"tune\"],\n", | |
"            \"thin\": 1,\n", | |
"            \"seed\": sampler_config[\"random_seed\"],\n", | |
"            \"control\": {\n", | |
"                # PyMC's target_accept is recorded under the key adapt_delta.\n", | |
"                \"adapt_delta\": sampler_config[\"target_accept\"]\n", | |
"            }\n", | |
"        }\n", | |
"    },\n", | |
"    \"diagnostics\": {\n", | |
"        \"diagnostic_information\": {\n", | |
"            # Taken from the summary so the order matches the ESS / R-hat lists.\n", | |
"            \"names\": post_summary.index.tolist()\n", | |
"        },\n", | |
"        \"ndraws\": n_total_draws,\n", | |
"        \"nchains\": sampler_config[\"chains\"],\n", | |
"        \"effective_sample_size_bulk\": post_summary[\"ess_bulk\"].values.tolist(),\n", | |
"        \"effective_sample_size_tail\": post_summary[\"ess_tail\"].values.tolist(),\n", | |
"        \"r_hat\": post_summary[\"r_hat\"].values.tolist(),\n", | |
"        \"divergent_transitions\": divergences_per_chain,\n", | |
"        \"expected_fraction_of_missing_information\": efmi_per_chain\n", | |
"    },\n", | |
"    \"checks_made\": {\n", | |
"        \"ndraws_is_10k\": n_total_draws >= 10000,\n", | |
"        \"nchains_is_gte_4\": sampler_config[\"chains\"] >= 4,\n", | |
"        \"ess_within_bounds\": all(ess > 1000 for ess in post_summary[\"ess_bulk\"].values),\n", | |
"        \"r_hat_below_1_01\": all(r_hat < 1.01 for r_hat in post_summary[\"r_hat\"].values),\n", | |
"        \"efmi_above_0_2\": all(efmi > 0.2 for efmi in efmi_per_chain)\n", | |
"    },\n", | |
"    \"comments\": \"\",\n", | |
"    \"added_by\": \"Chris Fonnesbeck\",\n", | |
"    \"added_date\": datetime.now().strftime(\"%Y-%m-%d\"),\n", | |
"    \"versions\": {\n", | |
"        \"pymc_version\": str(pm.__version__),\n", | |
"        \"python_version\": str(sys.version),\n", | |
"        \"python_environment\": str([f\"{package.project_name}=={package.version}\" for package in pkg_resources.working_set])\n", | |
"    }\n", | |
"}\n", | |
"\n", | |
"# Save the info to a JSON file\n", | |
"with open(info_path / info_filename, \"w\") as file:\n", | |
"    json.dump(info, file)\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Reference posterior draws\n", | |
"\n", | |
"Converts the posterior draws to nested lists and writes them to a JSON file." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 62, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"posterior_draws_path = Path(\"posterior_database/reference_posteriors/draws/posterior_draws/\")\n", | |
"posterior_draws_filename = \"eight_schools-eight_schools_noncentered.pymc.json\"\n", | |
"\n", | |
"# post_summary.index holds flattened labels such as \"theta[0]\", which never equal the\n", | |
"# posterior variable name \"theta\"; match on the base name so vector parameters are kept.\n", | |
"exported_vars = {label.split(\"[\")[0] for label in post_summary.index}\n", | |
"posterior_draws = {\n", | |
"    key: value.values.tolist()\n", | |
"    for key, value in trace.posterior.items()\n", | |
"    if key in exported_vars\n", | |
"}\n", | |
"\n", | |
"# Export the filtered posterior to a JSON file\n", | |
"with open(posterior_draws_path / posterior_draws_filename, \"w\") as file:\n", | |
"    json.dump(posterior_draws, file)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Mean summary statistics" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 52, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mean_summary_info_path = Path(\"posterior_database/reference_posteriors/summary_statistics/mean/mean/\")\n", | |
"mean_summary_filename = \"eight_schools-eight_schools_noncentered.pymc.json\"\n", | |
"\n", | |
"mean_json = {\n", | |
"    \"names\": post_summary.index.tolist(),\n", | |
"    \"mean\": post_summary[\"mean\"].values.tolist(),\n", | |
"    \"sd\": post_summary[\"sd\"].values.tolist(),\n", | |
"    \"mcse_mean\": post_summary[\"mcse_mean\"].values.tolist()\n", | |
"}\n", | |
"\n", | |
"# Save under the model-specific filename; a fixed name such as \"mean_summary.json\"\n", | |
"# would be overwritten by every other model written to this directory.\n", | |
"mean_summary_path = mean_summary_info_path / mean_summary_filename\n", | |
"with open(mean_summary_path, \"w\") as file:\n", | |
"    json.dump(mean_json, file)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "pymc-dev", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment