@maedoc
Created October 5, 2023 10:33
Simple examples of simulation based inference
{
"cells": [
{
"cell_type": "markdown",
"id": "a7a70d0f-b789-4d98-a5b5-ef8030860ec0",
"metadata": {},
"source": [
"# simple simulation based inference\n",
"\n",
"this notebook explains simulation based inference (SBI) using simpler methods than those in better known packages such as mackelab's.\n",
"## what is sbi actually doing?\n",
"The approach is very general: the simulator generates many sample pairs of (parameter values, data features), and a Bayesian regression of the parameters on the data features is then fit. In a very simple case this could be just least squares (for point estimates) or a simple linear Bayesian model."
]
},
{
"cell_type": "markdown",
"id": "5b96bac7-6ebc-4dbf-b1af-fce605b4cefd",
"metadata": {},
"source": [
"## setup\n",
"\n",
"in case of least squares we don't need anything other than numpy,"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9fb5d1b3-af2a-4b6f-9c84-be01af89f3ca",
"metadata": {},
"outputs": [],
"source": [
"%pylab inline\n",
"import tqdm"
]
},
{
"cell_type": "markdown",
"id": "2d2b13a5-145a-4494-9b85-a82686885117",
"metadata": {},
"source": [
"some dimensions of the problem"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "b21c3749-4fff-48d4-bab7-6f8c466aaffc",
"metadata": {},
"outputs": [],
"source": [
"num_params = 10\n",
"num_samples = 1000\n",
"num_features = 5"
]
},
{
"cell_type": "markdown",
"id": "a4b11ea9-d7d2-4fd5-9e1e-c951cee20aea",
"metadata": {},
"source": [
"a simple simulator of a latent linear system where we are going to estimate its initial conditions"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "20f3bb71-738c-45ed-9964-894cc17f0ce0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ -85435.66261261, 308974.00453229, 243658.29502642,\n",
" -215082.5979462 , -388077.38372195])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"A = np.random.randn(num_params, num_params)\n",
"G = np.random.randn(num_features, num_params)\n",
"\n",
"def simulator(p): # linear system\n",
"    x = p\n",
"    for i in range(10):\n",
"        x = A @ x\n",
"    return G @ x\n",
"\n",
"simulator(np.random.randn(num_params))"
]
},
{
"cell_type": "markdown",
"id": "79b8e9ad-0410-4911-bae0-fffab1ec7d38",
"metadata": {},
"source": [
"sample parameters and run simulations"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "6403a706-e988-44b8-bcdf-c23a6a69e6ac",
"metadata": {},
"outputs": [],
"source": [
"params = np.random.randn(num_samples, num_params)\n",
"features = np.array([simulator(p) for p in params])"
]
},
{
"cell_type": "markdown",
"id": "f93c4df2-da45-41b1-a9ac-a6e37957b7cd",
"metadata": {},
"source": [
"do least squares and compute losses"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "e155f089-2f5a-460b-a096-5aa2d5781ec2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5077.507087203544"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictor, *_ = np.linalg.lstsq(features, params, rcond=None)\n",
"assert predictor.shape == (features.shape[1], num_params)\n",
"params_est = features @ predictor\n",
"loss = np.sum(np.square( params - params_est ))\n",
"loss"
]
},
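{
"cell_type": "markdown",
"id": "heldout-check-note",
"metadata": {},
"source": [
"the loss above is computed on the same samples that were used for the fit. A minimal sketch, assuming the `simulator` and `predictor` defined above, of checking the predictor on freshly simulated data; the names `num_test`, `params_test`, `features_test`, `params_test_est` and `loss_test` are introduced here only for illustration,"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "heldout-check-code",
"metadata": {},
"outputs": [],
"source": [
"# hypothetical held-out check, not part of the original analysis\n",
"num_test = 200\n",
"params_test = np.random.randn(num_test, num_params)\n",
"features_test = np.array([simulator(p) for p in params_test])\n",
"params_test_est = features_test @ predictor\n",
"# mean squared error per entry on simulations the fit has not seen\n",
"loss_test = np.mean(np.square(params_test - params_test_est))\n",
"loss_test"
]
},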
{
"cell_type": "markdown",
"id": "812ea6cd-3caa-4960-9468-160ad81fdd0d",
"metadata": {},
"source": [
"compare the true and estimated params,"
]
},
{
"cell_type": "code",
"execution_count": 113,
"id": "371cdefc-f0ed-4163-8612-ac62f675814d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 800x200 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"figure(figsize=(8,2)); subplot(121)\n",
"plot(params.ravel(), params_est.ravel(), ',')\n",
"grid(1), axis('equal'), xlabel('params'), ylabel('estimated params');\n",
"subplot(122), hist(params.ravel() - params_est.ravel()), grid(1), xlabel('params - estimated params');"
]
},
{
"cell_type": "markdown",
"id": "e0da43fb-e0c2-4adb-8315-0f89085f3b67",
"metadata": {},
"source": [
"given how simple least squares is, this is already quite ok.\n",
"\n",
"## with a Bayesian model fit by maximum likelihood instead\n",
"\n",
"The full SBI techniques use various (deep neural) predictors of the parameters of the approximate posterior. Plain least squares cannot fit such predictors, since fitting them requires derivatives of a loss, so we can try with Jax,"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "8943d961-adc2-4fe6-8817-ccdaf6ae8594",
"metadata": {},
"outputs": [],
"source": [
"import jax, jax.numpy as jp, jax.scipy as js"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "4be32af2-32b0-4911-b5eb-32e5ae3c2eac",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"data": {
"text/plain": [
"Array(-1.7370857, dtype=float32, weak_type=True)"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"js.stats.norm.logpdf(0,1,2)"
]
},
{
"cell_type": "markdown",
"id": "3f4e7443-791a-4987-b5ff-3a72e28c25d6",
"metadata": {},
"source": [
"instead of solving a least squares problem directly, we write a loss function with a linear Bayesian model"
]
},
{
"cell_type": "code",
"execution_count": 90,
"id": "c0cdd2e9-6b45-475f-8237-46f2351c2379",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(127184.44, dtype=float32)"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"j_params = jp.array(params) # num_samples, num_params\n",
"j_features = jp.array(features) # num_samples, num_features\n",
"\n",
"# predict a mean & std per parameter\n",
"j_predictor = jp.ones((2, num_features, num_params))\n",
"\n",
"def loss(j_predictor):\n",
"    mu = j_features @ j_predictor[0]\n",
"    sd = j_features @ j_predictor[1]\n",
"    lp_per = js.stats.norm.logpdf(j_params, mu, sd)\n",
"    return -jp.sum(lp_per)\n",
"\n",
"loss(j_predictor)"
]
},
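{
"cell_type": "markdown",
"id": "gaussian-nll-note",
"metadata": {},
"source": [
"written out, the loss above is the negative log likelihood of a Gaussian whose mean and sd are both linear in the features. With $F$ for `j_features`, $\\theta$ for `j_params`, and $W_\\mu$, $W_\\sigma$ for `j_predictor[0]`, `j_predictor[1]` (symbols introduced here only for illustration),\n",
"\n",
"$$\\mu = F W_\\mu, \\qquad \\sigma = F W_\\sigma, \\qquad \\mathrm{loss} = -\\sum_{i,j} \\log \\mathcal{N}(\\theta_{ij} \\mid \\mu_{ij}, \\sigma_{ij})$$\n",
"\n",
"$$\\log \\mathcal{N}(\\theta \\mid \\mu, \\sigma) = -\\tfrac{1}{2} \\log(2 \\pi \\sigma^2) - \\frac{(\\theta - \\mu)^2}{2 \\sigma^2}$$\n",
"\n",
"note that nothing in this sketch of the model constrains $\\sigma$ to be positive."
]
},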
{
"cell_type": "markdown",
"id": "f3670d68-d696-400e-8d8f-c331b071adcb",
"metadata": {},
"source": [
"with a loss defined, we can descend the gradient,"
]
},
{
"cell_type": "code",
"execution_count": 106,
"id": "e8190a6e-acc7-4b15-99eb-ed9ad2731df7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"loss=270608.5625: 100%|█████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 81.80it/s]\n"
]
}
],
"source": [
"jit_grad_loss = jax.jit(jax.grad(loss))\n",
"\n",
"num_iter = 10\n",
"lr = 1e-3\n",
"j_predictor = jp.ones((2, num_features, num_params))\n",
"for i in (pbar := tqdm.trange(num_iter)):\n",
"    j_predictor = j_predictor - lr * jit_grad_loss(j_predictor)\n",
"    pbar.set_description(f'loss={loss(j_predictor)}')"
]
},
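{
"cell_type": "markdown",
"id": "loss-history-note",
"metadata": {},
"source": [
"the progress bar only shows the most recent loss value. A minimal sketch, assuming the `loss`, `jit_grad_loss`, `lr` and `num_iter` defined above, of recording the loss at every step to see whether it actually decreases; `j_check` and `loss_history` are hypothetical names that leave `j_predictor` untouched,"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "loss-history-code",
"metadata": {},
"outputs": [],
"source": [
"# hypothetical diagnostic: rerun the same descent from the same starting point\n",
"j_check = jp.ones((2, num_features, num_params))\n",
"loss_history = [float(loss(j_check))]  # loss before any update\n",
"for i in range(num_iter):\n",
"    j_check = j_check - lr * jit_grad_loss(j_check)\n",
"    loss_history.append(float(loss(j_check)))\n",
"loss_history"
]
},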
{
"cell_type": "markdown",
"id": "a9c38aba-af7d-4cb3-b5bf-0590032e570c",
"metadata": {},
"source": [
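"the loss stabilizes quickly at ~270000: there's no nonlinearity in the predictor, so the loss landscape is comparatively easy to descend (it is convex in the mean weights for a fixed sd, though not jointly in the mean and sd weights).\n",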
"\n",
"a simple way to evaluate the result is to z-score the true params based on the estimated mu & sd,"
]
},
{
"cell_type": "code",
"execution_count": 105,
"id": "dbdaa2d8-72a1-4044-810c-5824d6047362",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 800x200 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"mu = j_features @ j_predictor[0]\n",
"sd = j_features @ j_predictor[1]\n",
"z = (j_params - mu) / sd\n",
"figure(figsize=(8,2)); hist(z.ravel(), np.r_[-2:2:0.1]), grid(1);"
]
},
{
"cell_type": "markdown",
"id": "76a2b662-9166-416a-92fd-2491ce970bb6",
"metadata": {},
"source": [
"if the z scores are centered close to zero, with a spread of order one, we can be reasonably confident that this has worked."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}