Skip to content

Instantly share code, notes, and snippets.

@fonnesbeck
Created August 11, 2024 18:23
Show Gist options
  • Save fonnesbeck/a9a0b1624324ee2df30ff32c51e452a8 to your computer and use it in GitHub Desktop.
Save fonnesbeck/a9a0b1624324ee2df30ff32c51e452a8 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# posteriordb Model Template\n",
"\n",
"Below is an example of how to create a new model in the posterior database, using the eight schools noncentered model as an example. \n",
"\n",
"Everything below assumes that this notebook is in the root of the `posteriordb` repository. "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import pymc as pm\n",
"import arviz as az\n",
"import numpy as np\n",
"import zipfile\n",
"import json\n",
"\n",
"import sys\n",
"from datetime import datetime\n",
"import pkg_resources\n",
"from pathlib import Path"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'J': 8,\n",
" 'y': [28, 8, -3, 7, -1, 1, 18, 12],\n",
" 'sigma': [15, 10, 16, 11, 9, 11, 10, 18]}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Unzip the file\n",
"with zipfile.ZipFile(\"posterior_database/data/data/eight_schools.json.zip\", 'r') as zip_ref:\n",
" zip_ref.extractall(\"posterior_database/data/data/\")\n",
"\n",
"# Load the JSON file into a Python dictionary\n",
"with open(\"posterior_database/data/data/eight_schools.json\", \"r\") as file:\n",
" data = json.load(file)\n",
"\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def model(data):\n",
" y_obs = np.array(data[\"y\"]) # estimated treatment\n",
" sigma = np.array(data[\"sigma\"]) # std of estimated effect\n",
" coords = {\"school\": np.arange(data[\"J\"])}\n",
" with pm.Model(coords=coords) as pymc_model:\n",
"\n",
" mu = pm.Normal(\n",
" \"mu\", mu=0, sigma=5\n",
" ) # hyper-parameter of mean, non-informative prior\n",
" tau = pm.Cauchy(\"tau\", alpha=0, beta=5) # hyper-parameter of sigma\n",
" theta_trans = pm.Normal(\"theta_trans\", mu=0, sigma=1, dims=\"school\")\n",
" theta = pm.Deterministic(\"theta\", mu + tau * theta_trans)\n",
" y = pm.Normal(\"y\", mu=theta, sigma=sigma, observed=y_obs)\n",
" return pymc_model"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"sampler_config = {\n",
" \"chains\": 10,\n",
" \"draws\": 2000,\n",
" \"tune\": 1000,\n",
" \"random_seed\": 4711,\n",
" \"target_accept\": 0.95\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (10 chains in 4 jobs)\n",
"NUTS: [mu, tau, theta_trans]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f6ddb14a4ca14c3d931a58f45babcb1f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling 10 chains for 1_000 tune and 2_000 draw iterations (10_000 + 20_000 draws total) took 6 seconds.\n"
]
}
],
"source": [
"with model(data) as eight_schools_model:\n",
" trace = pm.sample(**sampler_config)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean</th>\n",
" <th>sd</th>\n",
" <th>hdi_3%</th>\n",
" <th>hdi_97%</th>\n",
" <th>mcse_mean</th>\n",
" <th>mcse_sd</th>\n",
" <th>ess_bulk</th>\n",
" <th>ess_tail</th>\n",
" <th>r_hat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>mu</th>\n",
" <td>4.417</td>\n",
" <td>3.327</td>\n",
" <td>-1.791</td>\n",
" <td>10.742</td>\n",
" <td>0.024</td>\n",
" <td>0.018</td>\n",
" <td>19657.0</td>\n",
" <td>14859.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>tau</th>\n",
" <td>0.124</td>\n",
" <td>4.806</td>\n",
" <td>-9.143</td>\n",
" <td>9.313</td>\n",
" <td>0.046</td>\n",
" <td>0.037</td>\n",
" <td>11821.0</td>\n",
" <td>10546.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>theta[0]</th>\n",
" <td>6.252</td>\n",
" <td>5.729</td>\n",
" <td>-3.693</td>\n",
" <td>17.391</td>\n",
" <td>0.045</td>\n",
" <td>0.035</td>\n",
" <td>17136.0</td>\n",
" <td>14445.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>theta[1]</th>\n",
" <td>4.960</td>\n",
" <td>4.701</td>\n",
" <td>-3.705</td>\n",
" <td>14.123</td>\n",
" <td>0.033</td>\n",
" <td>0.026</td>\n",
" <td>19976.0</td>\n",
" <td>17196.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>theta[2]</th>\n",
" <td>3.953</td>\n",
" <td>5.252</td>\n",
" <td>-6.297</td>\n",
" <td>13.432</td>\n",
" <td>0.040</td>\n",
" <td>0.030</td>\n",
" <td>18096.0</td>\n",
" <td>15365.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>theta[3]</th>\n",
" <td>4.714</td>\n",
" <td>4.789</td>\n",
" <td>-4.407</td>\n",
" <td>13.826</td>\n",
" <td>0.034</td>\n",
" <td>0.026</td>\n",
" <td>19727.0</td>\n",
" <td>16743.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>theta[4]</th>\n",
" <td>3.607</td>\n",
" <td>4.671</td>\n",
" <td>-5.728</td>\n",
" <td>11.980</td>\n",
" <td>0.034</td>\n",
" <td>0.026</td>\n",
" <td>19315.0</td>\n",
" <td>16464.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>theta[5]</th>\n",
" <td>4.073</td>\n",
" <td>4.902</td>\n",
" <td>-5.511</td>\n",
" <td>13.221</td>\n",
" <td>0.036</td>\n",
" <td>0.028</td>\n",
" <td>19180.0</td>\n",
" <td>16186.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>theta[6]</th>\n",
" <td>6.287</td>\n",
" <td>5.140</td>\n",
" <td>-2.986</td>\n",
" <td>16.205</td>\n",
" <td>0.038</td>\n",
" <td>0.030</td>\n",
" <td>18853.0</td>\n",
" <td>16037.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>theta[7]</th>\n",
" <td>4.897</td>\n",
" <td>5.301</td>\n",
" <td>-5.293</td>\n",
" <td>14.812</td>\n",
" <td>0.040</td>\n",
" <td>0.032</td>\n",
" <td>18420.0</td>\n",
" <td>15791.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n",
"mu 4.417 3.327 -1.791 10.742 0.024 0.018 19657.0 \n",
"tau 0.124 4.806 -9.143 9.313 0.046 0.037 11821.0 \n",
"theta[0] 6.252 5.729 -3.693 17.391 0.045 0.035 17136.0 \n",
"theta[1] 4.960 4.701 -3.705 14.123 0.033 0.026 19976.0 \n",
"theta[2] 3.953 5.252 -6.297 13.432 0.040 0.030 18096.0 \n",
"theta[3] 4.714 4.789 -4.407 13.826 0.034 0.026 19727.0 \n",
"theta[4] 3.607 4.671 -5.728 11.980 0.034 0.026 19315.0 \n",
"theta[5] 4.073 4.902 -5.511 13.221 0.036 0.028 19180.0 \n",
"theta[6] 6.287 5.140 -2.986 16.205 0.038 0.030 18853.0 \n",
"theta[7] 4.897 5.301 -5.293 14.812 0.040 0.032 18420.0 \n",
"\n",
" ess_tail r_hat \n",
"mu 14859.0 1.0 \n",
"tau 10546.0 1.0 \n",
"theta[0] 14445.0 1.0 \n",
"theta[1] 17196.0 1.0 \n",
"theta[2] 15365.0 1.0 \n",
"theta[3] 16743.0 1.0 \n",
"theta[4] 16464.0 1.0 \n",
"theta[5] 16186.0 1.0 \n",
"theta[6] 16037.0 1.0 \n",
"theta[7] 15791.0 1.0 "
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"post_summary = az.summary(trace, var_names=[\"~trans\"], filter_vars='regex')\n",
"post_summary"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reference posteriors draw info"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"\n",
"info_path = Path(\"posterior_database/reference_posteriors/draws/info\")\n",
"info_filename = \"eight_schools-eight_schools_noncentered.pymc.info.json\"\n",
"\n",
"# Define the information to be included in the info.json file\n",
"info = {\n",
" \"name\": \"eight_schools-eight_schools_noncentered\",\n",
" \"inference\": {\n",
" \"method\": \"pymc_sampling\",\n",
" \"method_arguments\": {\n",
" \"chains\": sampler_config[\"chains\"],\n",
" \"iter\": sampler_config[\"draws\"] + sampler_config[\"tune\"],\n",
" \"warmup\": sampler_config[\"tune\"],\n",
" \"thin\": 1,\n",
" \"seed\": sampler_config[\"random_seed\"],\n",
" \"control\": {\n",
" \"adapt_delta\": sampler_config[\"target_accept\"]\n",
" }\n",
" }\n",
" },\n",
" \"diagnostics\": {\n",
" \"diagnostic_information\": {\n",
" \"names\": [\"mu\", \"tau\", \"theta[0]\", \"theta[1]\", \"theta[2]\", \"theta[3]\", \"theta[4]\", \"theta[5]\", \"theta[6]\", \"theta[7]\"]\n",
" },\n",
" \"ndraws\": sampler_config[\"draws\"] * sampler_config[\"chains\"],\n",
" \"nchains\": sampler_config[\"chains\"],\n",
" \"effective_sample_size_bulk\": post_summary[\"ess_bulk\"].values.tolist(),\n",
" \"effective_sample_size_tail\": post_summary[\"ess_tail\"].values.tolist(),\n",
" \"r_hat\": post_summary[\"r_hat\"].values.tolist(),\n",
" \"divergent_transitions\": [0] * len(post_summary), # Assuming no divergent transitions\n",
" \"expected_fraction_of_missing_information\": [0] * len(post_summary) # Assuming no missing information\n",
" },\n",
" \"checks_made\": {\n",
" \"ndraws_is_10k\": sampler_config[\"draws\"] * sampler_config[\"chains\"] >= 10000,\n",
" \"nchains_is_gte_4\": sampler_config[\"chains\"] >= 4,\n",
" \"ess_within_bounds\": all(ess > 1000 for ess in post_summary[\"ess_bulk\"].values),\n",
" \"r_hat_below_1_01\": all(r_hat < 1.01 for r_hat in post_summary[\"r_hat\"].values),\n",
" \"efmi_above_0_2\": all(efmi > 0.2 for efmi in [0] * len(post_summary)) # Assuming no missing information\n",
" },\n",
" \"comments\": \"\",\n",
" \"added_by\": \"Chris Fonnesbeck\",\n",
" \"added_date\": datetime.now().strftime(\"%Y-%m-%d\"),\n",
" \"versions\": {\n",
" \"pymc_version\": str(pm.__version__),\n",
" \"python_version\": str(sys.version),\n",
" \"python_environment\": str([f\"{package.project_name}=={package.version}\" for package in pkg_resources.working_set])\n",
" }\n",
"}\n",
"\n",
"# Save the info to a JSON file\n",
"with open(info_path / info_filename, \"w\") as file:\n",
" json.dump(info, file)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reference posterior draws\n",
"\n",
"Converts the posterior draws to lists in a JSON."
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"posterior_draws_path = Path(\"posterior_database/reference_posteriors/draws/posterior_draws/\")\n",
"posterior_draws_filename = \"eight_schools-eight_schools_noncentered.pymc.json\"\n",
"\n",
"posterior_draws = {key: value.values.tolist() for key, value in trace.posterior.items() if key in post_summary.index}\n",
"\n",
"# Export the filtered posterior to a JSON file\n",
"with open(posterior_draws_path / posterior_draws_filename, \"w\") as file:\n",
" json.dump(posterior_draws, file)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Mean summary statistics"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"mean_summary_info_path = Path(\"posterior_database/reference_posteriors/summary_statistics/mean/mean/\")\n",
"mean_summary_filename = \"eight_schools-eight_schools_noncentered.pymc.json\"\n",
"\n",
"mean_json = {\n",
" \"names\": post_summary.index.tolist(),\n",
" \"mean\": post_summary[\"mean\"].values.tolist(),\n",
" \"sd\": post_summary[\"sd\"].values.tolist(),\n",
" \"mcse_mean\": post_summary[\"mcse_mean\"].values.tolist()\n",
"}\n",
"\n",
"# Save the mean summary to a JSON file\n",
"mean_summary_path = mean_summary_info_path / \"mean_summary.json\"\n",
"with open(mean_summary_path, \"w\") as file:\n",
" json.dump(mean_json, file)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pymc-dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment