How to compute leave-one-out cross validation stats for a numpyro (jax) model with a Gaussian process from celerite2
{
"cells": [
{
"cell_type": "markdown",
"id": "151c0c61-b90b-4f03-b80f-e78fdb4bcdcc",
"metadata": {},
"source": [
"# Compute LOO for models with GPs\n",
"\n",
"The techniques used in this notebook are explained in [this tutorial for stats folk](https://mc-stan.org/loo/articles/loo2-non-factorizable.html)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4142a99-bc9f-4812-bbe4-c7c5b7fa8664",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import numpyro\n",
"cpu_cores = 8\n",
"numpyro.set_host_device_count(cpu_cores)\n",
"\n",
"import numpyro.distributions as dist\n",
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"\n",
"from jax.config import config\n", | |
"config.update('jax_enable_x64', True)\n", | |
"\n", | |
"from jax import random, numpy as jnp\n", | |
"\n", | |
"from celerite2 import terms as terms_py, GaussianProcess as GaussianProcess_py\n", | |
"from celerite2.jax import terms, GaussianProcess\n", | |
"\n", | |
"import arviz\n", | |
"from arviz.stats.stats_utils import logsumexp as _logsumexp\n", | |
"from arviz.stats.stats import _ic_matrix\n", | |
"\n", | |
"import pandas as pd\n", | |
"import xarray as xr\n", | |
"from corner import corner\n", | |
"from scipy.optimize import minimize\n", | |
"from scipy import stats as st" | |
] | |
}, | |
{
"cell_type": "markdown",
"id": "602ad2e3-4360-40a7-beb2-4506754fa17f",
"metadata": {},
"source": [
"Generate synthetic data with a GP. I'll use the pure-Python (not JAX) interface of celerite2 in this cell only, so that we can draw a random sample:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36d7fc22-c630-43b1-a5f8-c098af1abb26",
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"\n",
"t = np.linspace(0, 100, 1000)\n",
"\n",
"true_sigma = 0.2\n",
"true_rho = 9\n",
"true_tau = 100\n",
"true_mean = 2\n",
"kernel = terms_py.SHOTerm(sigma=true_sigma, rho=true_rho, tau=true_tau)\n",
"gp = GaussianProcess_py(kernel, t=t, mean=true_mean)\n",
"\n",
"yerr = 0.05\n",
"y = gp.sample() + np.random.normal(scale=yerr, size=len(t))\n",
"\n",
"plt.plot(t, y)"
]
},
{
"cell_type": "markdown",
"id": "1d3617fc-76ce-4c83-82cb-82a6cdb6e365",
"metadata": {},
"source": [
"Fit the synthetic data with a GP using numpyro+celerite2:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "353e0376-5e74-47a4-a678-c502c9610f9d",
"metadata": {},
"outputs": [],
"source": [
"def numpyro_model():\n",
"    # this model looks like the tutorial in celerite2 here:\n",
"    # https://celerite2.readthedocs.io/en/latest/tutorials/first/#posterior-inference-using-numpyro\n",
"\n",
"    mean = numpyro.sample(\"mean\", dist.Normal(0.0, 1))\n",
"    log_jitter = numpyro.sample(\"log_jitter\", dist.Normal(-7, 0.5))\n",
"\n",
"    log_sigma = numpyro.sample(\"log_sigma\", dist.Normal(-1, 1))\n",
"    log_rho = numpyro.sample(\"log_rho\", dist.Normal(2, 2))\n",
"    log_tau = numpyro.sample(\"log_tau\", dist.Normal(4, 2))\n",
"    kernel = terms.UnderdampedSHOTerm(\n",
"        sigma=jnp.exp(log_sigma), rho=jnp.exp(log_rho), tau=jnp.exp(log_tau)\n",
"    )\n",
"\n",
"    gp = GaussianProcess(kernel, mean=mean)\n",
"    gp.compute(t, diag=yerr**2 + jnp.exp(log_jitter), check_sorted=False)\n",
"\n",
"    numpyro.sample(\"obs\", gp.numpyro_dist(), obs=y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34895e72-113d-42cc-bba1-03119fe427eb",
"metadata": {},
"outputs": [],
"source": [
"nuts_kernel = NUTS(numpyro_model, dense_mass=True)\n",
"mcmc = MCMC(\n",
"    nuts_kernel,\n",
"    num_warmup=1000,\n",
"    num_samples=1000,\n",
"    num_chains=cpu_cores,\n",
"    progress_bar=True,\n",
")\n",
"rng_key = random.PRNGKey(34923)\n",
"mcmc.run(rng_key)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7d04dca-beed-42e4-b1f4-bb3c8ef98eae",
"metadata": {},
"outputs": [],
"source": [
"result = arviz.from_numpyro(mcmc)\n",
"\n",
"corner(\n",
"    result,\n",
"    var_names='log_sigma log_rho log_tau mean'.split(),\n",
"    truths=np.log([true_sigma, true_rho, true_tau]).tolist() + [true_mean]\n",
")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "9fa9201f-066c-46cc-b651-6b30d9a86ad6",
"metadata": {},
"source": [
"But note that the log likelihood is tracked as a single number per draw for the whole time series, rather than pointwise per observation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2386e929-6e06-4a27-b109-740cea79d298",
"metadata": {},
"outputs": [],
"source": [
"result.log_likelihood"
]
},
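{
"cell_type": "markdown",
"id": "log-likelihood-shape-check",
"metadata": {},
"source": [
"As a quick sanity check (a sketch; the exact dimensions depend on how `arviz.from_numpyro` stores a multivariate observation site), the `obs` log likelihood should have one value per chain and draw, with no per-observation dimension:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "log-likelihood-shape-check-code",
"metadata": {},
"outputs": [],
"source": [
"# expect something like (n_chains, n_draws), with no time-series dimension\n",
"result.log_likelihood['obs'].shape"
]
},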
{
"cell_type": "markdown",
"id": "7cc7764e-ac2a-4375-8f06-e1656e467bbb",
"metadata": {},
"source": [
"Here's how we can modify the numpyro model to (optionally) compute the pointwise log likelihood, which allows us to use Leave-One-Out cross validation:"
]
},
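{
"cell_type": "markdown",
"id": "loo-conditional-math",
"metadata": {},
"source": [
"For reference, here is a sketch of the math from the linked Stan vignette that the `pointwise` branch below implements. With the full covariance $\\Sigma = K + \\mathrm{diag}(\\sigma_n^2)$ (GP kernel plus the diagonal noise term) and mean model $\\mu$, define\n",
"\n",
"$$g_i = \\left[\\Sigma^{-1}(y - \\mu)\\right]_i, \\qquad c_{ii} = \\left[\\Sigma^{-1}\\right]_{ii}.$$\n",
"\n",
"The leave-one-out predictive distribution for point $i$ is Gaussian with mean $\\tilde{\\mu}_{i} = y_i - g_i / c_{ii}$ and variance $\\tilde{\\sigma}_i^2 = 1 / c_{ii}$, so the pointwise log likelihood is\n",
"\n",
"$$\\log p(y_i \\mid y_{-i}) = -\\tfrac{1}{2}\\log(2\\pi) + \\tfrac{1}{2}\\log c_{ii} - \\frac{g_i^2}{2\\,c_{ii}}.$$"
]
},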
{
"cell_type": "code",
"execution_count": null,
"id": "740e2f93-81b1-4544-b18f-76aecedfc4e9",
"metadata": {},
"outputs": [],
"source": [
"def numpyro_model_pointwise(pointwise=False):\n",
"    mean = numpyro.sample(\"mean\", dist.Normal(0.0, 1))\n",
"    log_jitter = numpyro.sample(\"log_jitter\", dist.Normal(-7, 0.5))\n",
"\n",
"    log_sigma = numpyro.sample(\"log_sigma\", dist.Normal(-1, 1))\n",
"    log_rho = numpyro.sample(\"log_rho\", dist.Normal(2, 2))\n",
"    log_tau = numpyro.sample(\"log_tau\", dist.Normal(4, 2))\n",
"    kernel = terms.UnderdampedSHOTerm(\n",
"        sigma=jnp.exp(log_sigma), rho=jnp.exp(log_rho), tau=jnp.exp(log_tau)\n",
"    )\n",
"\n",
"    gp = GaussianProcess(kernel, mean=mean)\n",
"    gp.compute(t, diag=yerr**2 + jnp.exp(log_jitter), check_sorted=False)\n",
"\n",
"    numpyro.sample(\"obs\", gp.numpyro_dist(), obs=y)\n",
"\n",
"    if pointwise:\n",
"\n",
"        # if you have a non-uniform mean model (which is independent of the GP),\n",
"        # you should assign it to `mean_model` here. In the model above,\n",
"        # the mean model is described by a single parameter `mean`:\n",
"        mean_model = mean\n",
"\n",
"        # this must match the `diag` passed to gp.compute above:\n",
"        diag = yerr ** 2 + jnp.exp(log_jitter)\n",
"        K_s = kernel.to_dense(t.flatten(), np.zeros_like(t))\n",
"        covariance_matrix = K_s + jnp.eye(*K_s.shape) * diag\n",
"        inv_cov = jnp.linalg.inv(covariance_matrix)\n",
"\n",
"        # https://mc-stan.org/loo/articles/loo2-non-factorizable.html\n",
"        g_i = inv_cov @ (y - mean_model)\n",
"        c_ii = jnp.diag(inv_cov)\n",
"\n",
"        lnlike = (\n",
"            -0.5 * jnp.log(2 * np.pi) + 0.5 * jnp.log(c_ii) -\n",
"            0.5 * (g_i**2 / c_ii)\n",
"        )\n",
"\n",
"        numpyro.deterministic(\"pointwise\", lnlike)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca0ff5f2-327a-48c5-ad40-52482f5940d0",
"metadata": {},
"outputs": [],
"source": [
"nuts_kernel = NUTS(numpyro_model_pointwise, dense_mass=True)\n",
"mcmc_pointwise = MCMC(\n",
"    nuts_kernel,\n",
"    num_warmup=1000,\n",
"    num_samples=1000,\n",
"    num_chains=cpu_cores,\n",
"    progress_bar=True,\n",
")\n",
"rng_key = random.PRNGKey(34923)\n",
"mcmc_pointwise.run(rng_key)\n",
"result_pointwise = arviz.from_numpyro(mcmc_pointwise)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "213d0751-4b1e-4266-bab5-82b3db3444b8",
"metadata": {},
"outputs": [],
"source": [
"corner(\n",
"    result_pointwise,\n",
"    var_names='log_sigma log_rho log_tau mean'.split(),\n",
"    truths=np.log([true_sigma, true_rho, true_tau]).tolist() + [true_mean]\n",
")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "15d13fb9-90ea-45a9-98f8-7b3bc8ef24bb",
"metadata": {},
"source": [
"The results in `result_pointwise` are no different from those in `result` above, since `pointwise=False` by default. But here's where we can make use of that new logic:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02594df2-03f0-4f47-985c-33fc8845b971",
"metadata": {},
"outputs": [],
"source": [
"def recompute_pointwise_lnlike(result_pointwise, n_draws_recompute=150):\n",
"    n_chains = len(result_pointwise.sample_stats.chain)\n",
"\n",
"    posterior_sample = {\n",
"        k: result_pointwise.posterior[k].data.ravel()\n",
"        if len(np.shape(result_pointwise.posterior[k])) == 2 else\n",
"        result_pointwise.posterior[k].data.reshape(\n",
"            (result_pointwise.posterior[k].data.shape[0] *\n",
"             result_pointwise.posterior[k].data.shape[1], -1))\n",
"        for k in result_pointwise.posterior.keys()\n",
"    }\n",
"\n",
"    last_n_samples = -n_draws_recompute * n_chains\n",
"\n",
"    pred_kwargs = {key: posterior_sample[key][last_n_samples:] for key in posterior_sample}\n",
"    pred = Predictive(\n",
"        numpyro_model_pointwise,\n",
"        pred_kwargs,\n",
"        return_sites=[\"pointwise\"],\n",
"        batch_ndims=1\n",
"    )\n",
"\n",
"    pointwise_logps = pred(rng_key, pointwise=True)['pointwise']\n",
"\n",
"    draws = result_pointwise.sample_stats.draw.draw[-n_draws_recompute:]\n",
"    posterior = result_pointwise.posterior.sel(draw=draws)\n",
"\n",
"    new_shape = (\n",
"        n_chains,\n",
"        pointwise_logps.shape[0] // n_chains,\n",
"        pointwise_logps.shape[-1]\n",
"    )\n",
"\n",
"    log_likelihood = xr.Dataset(\n",
"        {\n",
"            'obs': xr.DataArray(\n",
"                pointwise_logps.reshape(new_shape),\n",
"                dims=['chain', 'draw', 'obs_dim_0']\n",
"            )\n",
"        }\n",
"    )\n",
"\n",
"    return arviz.InferenceData(log_likelihood=log_likelihood, posterior=posterior)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1aed95c-076b-4318-9866-1ccfc1beeec4",
"metadata": {},
"outputs": [],
"source": [
"result_pointwise_subset = recompute_pointwise_lnlike(result_pointwise)"
]
},
{
"cell_type": "markdown",
"id": "417c724d-d405-468d-9250-98d91273f792",
"metadata": {},
"source": [
"Now we have computed the pointwise log likelihood for a subset of the samples in `result_pointwise`, and stored it in `result_pointwise_subset`. If you look at the shape of the log likelihood, you'll see it now has an `obs_dim_0` dimension, which corresponds to the time axis of the original data `y`."
]
},
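{
"cell_type": "markdown",
"id": "pointwise-shape-check",
"metadata": {},
"source": [
"For example, with the default `n_draws_recompute=150` used above, the recomputed log likelihood should have a shape along the lines of `(n_chains, 150, len(t))`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pointwise-shape-check-code",
"metadata": {},
"outputs": [],
"source": [
"# one log-likelihood value per chain, retained draw, and observation\n",
"result_pointwise_subset.log_likelihood['obs'].shape"
]
},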
{
"cell_type": "markdown",
"id": "23e9eecb-091c-48b9-a8c2-9f703c3dab83",
"metadata": {},
"source": [
"Now we can compute the LOO, after first computing the relative efficiency `reff` (the effective sample size divided by the total number of posterior draws):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "696538f0-b3a2-4225-98b8-a4ffd9f871e8",
"metadata": {},
"outputs": [],
"source": [
"def loo(result_pointwise_subset):\n",
"    n = np.prod(result_pointwise_subset.posterior['mean'].shape)\n",
"    reff = arviz.ess(result_pointwise_subset, method='mean').mean() / n\n",
"\n",
"    reff = (\n",
"        np.hstack([reff[v].values.flatten() for v in reff.data_vars]).mean()\n",
"    )\n",
"\n",
"    return arviz.loo(result_pointwise_subset, pointwise=True, reff=reff)\n",
"\n",
"loo_result = loo(result_pointwise_subset)\n",
"loo_result"
]
},
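{
"cell_type": "markdown",
"id": "pareto-k-check",
"metadata": {},
"source": [
"PSIS-LOO can be unreliable for observations with large Pareto $\\hat{k}$ values, so it's worth glancing at the diagnostics. A minimal sketch (an addition here, not part of the original workflow) using arviz's built-in diagnostic plot:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pareto-k-check-code",
"metadata": {},
"outputs": [],
"source": [
"# points with pareto_k > 0.7 indicate unreliable importance sampling\n",
"arviz.plot_khat(loo_result)\n",
"plt.show()"
]
},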
{
"cell_type": "markdown",
"id": "b845976c-16f6-41f3-8671-47830051e3d0",
"metadata": {},
"source": [
"Now that we've computed LOO CV results for the GP model, let's compare them against a simpler, purely sinusoidal model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "37358117-dc9c-4af1-97c3-e79609374dd0",
"metadata": {},
"outputs": [],
"source": [
"def numpyro_model_sinusoid():\n",
"    \"\"\"\n",
"    This version of the model uses a strict sinusoid as the\n",
"    mean model and no GP, rather than the GP with the SHO\n",
"    kernel used above.\n",
"    \"\"\"\n",
"\n",
"    mean = numpyro.sample(\"mean\", dist.Normal(0.0, 1))\n",
"    log_jitter = numpyro.sample(\"log_jitter\", dist.Normal(-7, 0.5))\n",
"\n",
"    amp = numpyro.sample(\"amp\", dist.Uniform(0, 0.5))\n",
"    phase = numpyro.sample(\"phase\", dist.Uniform(0, 2*np.pi))\n",
"    period = numpyro.sample(\"period\", dist.Uniform(1, 3))\n",
"    model = amp * jnp.sin(2*np.pi * t / period - phase) + mean\n",
"\n",
"    numpyro.sample(\n",
"        \"obs\",\n",
"        dist.Normal(\n",
"            loc=model,\n",
"            scale=yerr + jnp.exp(log_jitter)\n",
"        ),\n",
"        obs=y\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f3187a1-f356-4529-8b16-a392dba7f844",
"metadata": {},
"outputs": [],
"source": [
"nuts_kernel_sin = NUTS(numpyro_model_sinusoid, dense_mass=True)\n",
"mcmc_pointwise_sinusoid = MCMC(\n",
"    nuts_kernel_sin,\n",
"    num_warmup=1000,\n",
"    num_samples=1000,\n",
"    num_chains=cpu_cores,\n",
"    progress_bar=True,\n",
")\n",
"rng_key = random.PRNGKey(34923)\n",
"mcmc_pointwise_sinusoid.run(rng_key)\n",
"result_sinusoid = arviz.from_numpyro(mcmc_pointwise_sinusoid)"
]
},
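{
"cell_type": "markdown",
"id": "sinusoid-corner-check",
"metadata": {},
"source": [
"As with the GP model, it's worth a quick look at the sinusoid posteriors before comparing (a minimal check added here; there are no \"true\" values to overplot, since the data were not generated from a sinusoid):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sinusoid-corner-check-code",
"metadata": {},
"outputs": [],
"source": [
"corner(\n",
"    result_sinusoid,\n",
"    var_names='amp period phase mean'.split()\n",
")\n",
"plt.show()"
]
},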
{
"cell_type": "code",
"execution_count": null,
"id": "d9d9b8a3-b460-46a5-90ca-4ece9f82788e",
"metadata": {},
"outputs": [],
"source": [
"loo_sinusoid = arviz.loo(result_sinusoid, pointwise=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "137e4048-242c-4da4-9c52-806704fd7d99",
"metadata": {},
"outputs": [],
"source": [
"def compare(\n",
"    dataset_dict, ic='loo', method='stacking',\n",
"    ascending=False, b_samples=1000, alpha=1, seed=None\n",
"):\n",
"    \"\"\"\n",
"    This is a modified version of arviz.compare that works\n",
"    on the LOO outputs generated above.\n",
"    \"\"\"\n",
"    scale_value = 1\n",
"    np.random.seed(seed)\n",
"    if ic != 'loo':\n",
"        raise NotImplementedError()\n",
"\n",
"    names = list(dataset_dict.keys())\n",
"\n",
"    ic_se = f\"{ic}_se\"\n",
"    p_ic = f\"p_{ic}\"\n",
"    ic_i = f\"{ic}_i\"\n",
"    scale_col = f\"{ic}_scale\"\n",
"    df_comp = pd.DataFrame(\n",
"        index=names,\n",
"        columns=[\n",
"            \"rank\",\n",
"            \"loo\",\n",
"            \"p_loo\",\n",
"            \"d_loo\",\n",
"            \"weight\",\n",
"            \"se\",\n",
"            \"dse\",\n",
"            \"warning\",\n",
"            \"loo_scale\",\n",
"        ],\n",
"        dtype=float,\n",
"    )\n",
"\n",
"    ics = pd.DataFrame()\n",
"    names = []\n",
"    for name, dataset in dataset_dict.items():\n",
"        names.append(name)\n",
"        try:\n",
"            # The LOO results were already computed above, so here we only\n",
"            # collect the precomputed ELPDData objects into a single DataFrame\n",
"            ics = pd.concat([ics, pd.DataFrame([dataset_dict[name]])], ignore_index=True)\n",
"\n",
"        except Exception as e:\n",
"            raise e.__class__(f\"Encountered error trying to compute {ic} from model {name}.\") from e\n",
"    ics.index = names\n",
"    ics.sort_values(by=ic, inplace=True, ascending=ascending)\n",
"    ics[ic_i] = ics[ic_i].apply(lambda x: x.values.flatten())\n",
"\n",
"    if method.lower() == \"stacking\":\n",
"        rows, cols, ic_i_val = _ic_matrix(ics, ic_i)\n",
"        exp_ic_i = np.exp(ic_i_val / scale_value)\n",
"        km1 = cols - 1\n",
"\n",
"        def w_fuller(weights):\n",
"            return np.concatenate((weights, [max(1.0 - np.sum(weights), 0.0)]))\n",
"\n",
"        def log_score(weights):\n",
"            w_full = w_fuller(weights)\n",
"            score = 0.0\n",
"            for i in range(rows):\n",
"                score += np.log(np.dot(exp_ic_i[i], w_full))\n",
"            return -score\n",
"\n",
"        def gradient(weights):\n",
"            w_full = w_fuller(weights)\n",
"            grad = np.zeros(km1)\n",
"            for k in range(km1):\n",
"                for i in range(rows):\n",
"                    grad[k] += (exp_ic_i[i, k] - exp_ic_i[i, km1]) / np.dot(exp_ic_i[i], w_full)\n",
"            return -grad\n",
"\n",
"        theta = np.full(km1, 1.0 / cols)\n",
"        bounds = [(0.0, 1.0) for _ in range(km1)]\n",
"        constraints = [\n",
"            {\"type\": \"ineq\", \"fun\": lambda x: -np.sum(x) + 1.0},\n",
"            {\"type\": \"ineq\", \"fun\": np.sum},\n",
"        ]\n",
"\n",
"        weights = minimize(\n",
"            fun=log_score, x0=theta, jac=gradient, bounds=bounds, constraints=constraints\n",
"        )\n",
"\n",
"        weights = w_fuller(weights[\"x\"])\n",
"        ses = ics[ic_se]\n",
"\n",
"    elif method.lower() == \"bb-pseudo-bma\":\n",
"        rows, cols, ic_i_val = _ic_matrix(ics, ic_i)\n",
"        ic_i_val = ic_i_val * rows\n",
"\n",
"        b_weighting = st.dirichlet.rvs(alpha=[alpha] * rows, size=b_samples, random_state=seed)\n",
"        weights = np.zeros((b_samples, cols))\n",
"        z_bs = np.zeros_like(weights)\n",
"        for i in range(b_samples):\n",
"            z_b = np.dot(b_weighting[i], ic_i_val)\n",
"            u_weights = np.exp((z_b - np.max(z_b)) / scale_value)\n",
"            z_bs[i] = z_b  # pylint: disable=unsupported-assignment-operation\n",
"            weights[i] = u_weights / np.sum(u_weights)\n",
"\n",
"        weights = weights.mean(axis=0)\n",
"        ses = pd.Series(z_bs.std(axis=0), index=names)  # pylint: disable=no-member\n",
"\n",
"    elif method.lower() == \"pseudo-bma\":\n",
"        min_ic = ics.iloc[0][ic]\n",
"        z_rv = np.exp((ics[ic] - min_ic) / scale_value)\n",
"        weights = z_rv / np.sum(z_rv)\n",
"        ses = ics[ic_se]\n",
"\n",
"    if np.any(weights):\n",
"        min_ic_i_val = ics[ic_i].iloc[0]\n",
"        for idx, val in enumerate(ics.index):\n",
"            res = ics.loc[val]\n",
"            if scale_value < 0:\n",
"                diff = res[ic_i] - min_ic_i_val\n",
"            else:\n",
"                diff = min_ic_i_val - res[ic_i]\n",
"            d_ic = np.sum(diff)\n",
"            d_std_err = np.sqrt(len(diff) * np.var(diff))\n",
"            std_err = ses.loc[val]\n",
"            weight = weights[idx]\n",
"            df_comp.loc[val] = (\n",
"                idx,\n",
"                res[ic],\n",
"                res[p_ic],\n",
"                d_ic,\n",
"                weight,\n",
"                std_err,\n",
"                d_std_err,\n",
"                res[\"warning\"],\n",
"                res[scale_col],\n",
"            )\n",
"\n",
"    df_comp[\"rank\"] = df_comp[\"rank\"].astype(int)\n",
"    df_comp[\"warning\"] = df_comp[\"warning\"].astype(bool)\n",
"    return df_comp.sort_values(by=ic, ascending=ascending)"
]
},
{
"cell_type": "markdown",
"id": "3c61e651-416e-4d89-81c5-8e69db522f3d",
"metadata": {},
"source": [
"Here's the result:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "413bb045-dbbe-462b-abea-243a9bc77b43",
"metadata": {},
"outputs": [],
"source": [
"compare({'sinusoid': loo_sinusoid, 'gp': loo_result})"
]
},
{
"cell_type": "markdown",
"id": "9b707983-5b3f-4704-9a3a-8462a0fe819e",
"metadata": {},
"source": [
"The GP model is preferred!"
]
},
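{
"cell_type": "markdown",
"id": "elpd-diff-check",
"metadata": {},
"source": [
"A rough way to gauge how decisive this is (a heuristic added here, not part of the original analysis): compare the ELPD difference `d_loo` for the sinusoid model to its standard error `dse`. A difference several times larger than its standard error indicates a clear preference for the GP model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "elpd-diff-check-code",
"metadata": {},
"outputs": [],
"source": [
"comparison = compare({'sinusoid': loo_sinusoid, 'gp': loo_result})\n",
"# d_loo much larger than dse => the GP model is strongly favored\n",
"comparison.loc['sinusoid', ['d_loo', 'dse']]"
]
},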
{
"cell_type": "code",
"execution_count": null,
"id": "01ce72d1-2ab4-4471-b13b-fda9ef0af00c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |