Last active
October 14, 2020 17:14
-
-
Save ahartikainen/0eb924cea21409600ced23881c156dc5 to your computer and use it in GitHub Desktop.
Prior predictive with widgets
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import panel as pn\n", | |
"pn.extension()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from cmdstanpy import CmdStanModel\n", | |
"from cmdstanpy.utils import cxx_toolchain_path" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import arviz as az" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import platform\n", | |
"import re" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"if platform.system() == \"Windows\":\n", | |
" cxx_toolchain_path();" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"stan_model = \"\"\"\n", | |
"data {\n", | |
" int<lower = 0> N;\n", | |
" vector[N] x;\n", | |
" \n", | |
" // priors\n", | |
" real alpha_mu;\n", | |
" real<lower=0> alpha_sd;\n", | |
" real beta_mu;\n", | |
" real<lower=0> beta_sd;\n", | |
" real sigma_nu;\n", | |
" real sigma_mu;\n", | |
" real sigma_sd;\n", | |
"}\n", | |
"generated quantities {\n", | |
" real alpha = normal_rng(alpha_mu, alpha_sd);\n", | |
" real beta = normal_rng(beta_mu, beta_sd);\n", | |
" real sigma;\n", | |
" for (i in 1:100) {\n", | |
" sigma = student_t_rng(sigma_nu, sigma_mu, sigma_sd);\n", | |
" if (sigma > 0) {\n", | |
" break;\n", | |
" }\n", | |
" }\n", | |
" real y_sim[N] = normal_rng(alpha + beta * x, sigma);\n", | |
"}\n", | |
"\"\"\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"with open(\"./prior_predictive.stan\", \"w\") as f:\n", | |
" print(stan_model, file=f)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%time model = CmdStanModel(stan_file=\"./prior_predictive.stan\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# fake \"real\" x,y;\n", | |
"x = np.sort(np.random.rand(14)) * 10\n", | |
"y = 2.34 * x + np.random.randn(14) + 14.34" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from bokeh.plotting import figure\n", | |
"from bokeh.layouts import gridplot" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def make_plot(idata):\n", | |
" p = figure(width=700, height=400, toolbar_location=\"above\")\n", | |
"\n", | |
" n = 100\n", | |
"\n", | |
" random_sample = np.sort(np.random.choice(idata.prior_predictive.draw, size=n, replace=False))\n", | |
" x_data = idata.constant_data.x.values\n", | |
" y_data = idata.prior_predictive.isel({\"draw\": random_sample}).y_sim.values\n", | |
"\n", | |
" alpha = idata.prior.alpha.isel({\"draw\": random_sample}).values\n", | |
" beta = idata.prior.beta.isel({\"draw\": random_sample}).values\n", | |
" y_sim = alpha + beta * x_data[:, None]\n", | |
"\n", | |
" for i in range(n):\n", | |
" p.circle(x_data, y_data[0, i], fill_color=\"orange\", fill_alpha=0.5, line_color=None)\n", | |
"\n", | |
" for i in range(n):\n", | |
" p.line(x_data, y_sim[:, i], line_color=\"black\", line_alpha=0.3)\n", | |
"\n", | |
"\n", | |
" p.circle(x=idata.constant_data.x_obs.values, y=idata.observed_data.y_obs.values, fill_color=\"red\", fill_alpha=0.9, line_color=None, size=10)\n", | |
" return p" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def analyze(\n", | |
" model,\n", | |
" model_code,\n", | |
" y_obs,\n", | |
" x_obs,\n", | |
" x,\n", | |
" widgets,\n", | |
"):\n", | |
" \n", | |
" @pn.depends(\n", | |
" alpha_mu=alpha_mu, \n", | |
" alpha_sd=alpha_sd, \n", | |
" beta_mu=beta_mu, \n", | |
" beta_sd=beta_sd, \n", | |
" sigma_nu=sigma_nu, \n", | |
" sigma_mu=sigma_mu, \n", | |
" sigma_sd=sigma_sd\n", | |
" )\n", | |
" def analyze_plots(\n", | |
" alpha_mu=0,\n", | |
" alpha_sd=1,\n", | |
" beta_mu=0,\n", | |
" beta_sd=1,\n", | |
" sigma_nu=3,\n", | |
" sigma_mu=0,\n", | |
" sigma_sd=1,\n", | |
" ):\n", | |
" N = len(x)\n", | |
" stan_data = dict(\n", | |
" N=N,\n", | |
" x=x,\n", | |
" alpha_mu=float(alpha_mu),\n", | |
" alpha_sd=float(alpha_sd),\n", | |
" beta_mu=float(beta_mu),\n", | |
" beta_sd=float(beta_sd),\n", | |
" sigma_nu=float(sigma_nu),\n", | |
" sigma_mu=float(sigma_mu),\n", | |
" sigma_sd=float(sigma_sd),\n", | |
" )\n", | |
" fit = model.sample(data=stan_data, iter_sampling=500, fixed_param=True)\n", | |
" idata = az.from_cmdstanpy(\n", | |
" prior=fit, \n", | |
" prior_predictive=\"y_sim\", \n", | |
" observed_data={\"y_obs\": y_obs},\n", | |
" constant_data={\"x_obs\": x_obs, **stan_data}\n", | |
" )\n", | |
"\n", | |
" p_regression = make_plot(idata)\n", | |
" p_pair = gridplot(az.plot_pair(idata.prior, var_names=[\"alpha\", \"beta\", \"sigma\"], backend=\"bokeh\", backend_kwargs={\"width\": 220, \"height\": 220}, show=False).tolist())\n", | |
" p_trace = gridplot(az.plot_trace(idata.prior, var_names=[\"alpha\", \"beta\", \"sigma\"], backend=\"bokeh\", show=False).tolist())\n", | |
" summary_p = az.summary(idata.prior, var_names=[\"alpha\", \"beta\", \"sigma\"], kind=\"stats\")\n", | |
" \n", | |
" model_code_pane = pn.pane.Markdown(\"`\"*3+\"stan\\n\"+re.sub(r\"\\n\", \" \\n\", stan_model)+\"`\"*3)\n", | |
" \n", | |
" return pn.Column(pn.Row(pn.Column(widgets, summary_p, p_regression), model_code_pane), p_trace, p_pair, width=800)\n", | |
" \n", | |
" return analyze_plots" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x_pred = np.linspace(0,10,100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"alpha_mu = pn.widgets.TextInput(name='alpha_mu', value=\"0\", width=80)\n", | |
"alpha_sd = pn.widgets.TextInput(name='alpha_sd', value=\"1\", width=80)\n", | |
"\n", | |
"beta_mu = pn.widgets.TextInput(name='beta_mu', value=\"0\", width=80)\n", | |
"beta_sd = pn.widgets.TextInput(name='beta_sd', value=\"1\", width=80)\n", | |
"\n", | |
"sigma_nu = pn.widgets.TextInput(name='sigma_nu', value=\"3\", width=80)\n", | |
"sigma_mu = pn.widgets.TextInput(name='sigma_mu', value=\"0\", width=80)\n", | |
"sigma_sd = pn.widgets.TextInput(name='sigma_sd', value=\"1\", width=80)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"analysis = pn.panel(analyze(model, stan_model, y, x, x_pred, pn.Row(alpha_mu, alpha_sd, beta_mu, beta_sd, sigma_nu, sigma_mu, sigma_sd)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"analysis.save(\"prior_predictive_panel\", resources=\"inline\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"analysis.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment