Created
June 9, 2023 14:44
-
-
Save ricardoV94/1038c2c45a9acfd081654a2e64e757b4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "1d614fbd", | |
"metadata": {}, | |
"source": [ | |
"# Automatic probabiltity\n", | |
"\n", | |
"Slides for the accompanying talk: [link](https://docs.google.com/presentation/d/1xLvEOdnEC2nZ0jqBSbssP2vWsiNHHmMX-kprhuNH5Cg/edit?usp=sharing)\n", | |
"\n", | |
"Source code: [link](https://github.com/pymc-devs/pymc/tree/main/pymc/logprob)\n", | |
"\n", | |
"Versioned source code: [link](https://github.com/pymc-devs/pymc/tree/2ac88afa4212dcbeaf9471c6f54c7f50b5a3db53/pymc/logprob)\n", | |
"\n", | |
"Previous related talk: [link](https://www.youtube.com/watch?v=_APNiXTfYJw)\n", | |
"\n", | |
"Relevant documentation:\n", | |
"* PyMC and PyTensor: [link](https://www.pymc.io/projects/docs/en/stable/learn/core_notebooks/pymc_pytensor.html)\n", | |
"* CustomDist: [link](https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.CustomDist.html)\n", | |
"\n", | |
"Compatible PyMC version: 5.5.0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "0fc1322a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pymc as pm\n", | |
"import pytensor\n", | |
"import numpy as np\n", | |
"\n", | |
"import pytensor.tensor as pt\n", | |
"\n", | |
"pytensor.config.mode = \"FAST_COMPILE\"" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ae0cff50", | |
"metadata": {}, | |
"source": [ | |
"# Single random variable transformation" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "db7aede3", | |
"metadata": {}, | |
"source": [ | |
"## 1-to-1 transformations" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a2388f43", | |
"metadata": {}, | |
"source": [ | |
"### Linear shift" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "1c797a5d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array(-1.10557431), array(3.89442569))" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = pm.Normal.dist()\n", | |
"z = x + 5\n", | |
"\n", | |
"x_draw, z_draw = pm.draw([x, z])\n", | |
"x_draw, z_draw" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "3a2f9cbf", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(0.21651709)" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"def prob(rv, value):\n", | |
" return pm.logp(rv, value).exp()\n", | |
"\n", | |
"prob(x, x_draw).eval()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "ad1c43a2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(0.21651709)" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"prob(z, z_draw).eval()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "5ade8be6", | |
"metadata": {}, | |
"source": [ | |
"### Exponentiation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "2b13f751", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array(-0.28362562), array(0.75304852))" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = pm.Normal.dist()\n", | |
"z = pt.exp(x)\n", | |
"\n", | |
"x_draw, z_draw = pm.draw([x, z])\n", | |
"x_draw, z_draw" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "1b5b6c9b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(0.38321454)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"prob(x, x_draw).eval()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "13ef22b7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(0.50888427)" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"prob(z, z_draw).eval()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "834f9841", | |
"metadata": {}, | |
"source": [ | |
"### Inverse" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "a6b60ffd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([3.96040008e-09, 1.41976891e-01])" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = pm.Normal.dist(shape=(2,))\n", | |
"z = 1 / x\n", | |
"\n", | |
"value = np.array([-0.15, 1.5])\n", | |
"prob(z, value).eval()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "97f1f260", | |
"metadata": {}, | |
"source": [ | |
"### Others" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "49bcac15", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0.24197072451914337\n", | |
"0.48394144903828673\n", | |
"0.07820853879509117\n", | |
"0.026958231758816044\n", | |
"0.39614036832836785\n" | |
] | |
} | |
], | |
"source": [ | |
"x = pm.Normal.dist()\n", | |
"\n", | |
"print(\n", | |
" prob(x ** 2, 1).eval(),\n", | |
" prob(pt.sqrt(x), 1).eval(),\n", | |
" prob(x * 5, 1).eval(),\n", | |
" prob(pt.log(x), 1).eval(),\n", | |
" prob(pt.erf(x), 0.5).eval(),\n", | |
" sep=\"\\n\"\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e7f3c223", | |
"metadata": {}, | |
"source": [ | |
"## Many-to-1 transformations" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "105c8bd0", | |
"metadata": {}, | |
"source": [ | |
"### Abs" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "39a7a3c3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = pm.Normal.dist(0, 1)\n", | |
"z = pt.abs(x)\n", | |
"\n", | |
"np.isclose(\n", | |
" prob(x, 0.3).eval(),\n", | |
" (prob(z, 0.3) / 2).eval(),\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "577cbfad", | |
"metadata": {}, | |
"source": [ | |
"### Clipping" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "03cb9b86", | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"x=-0.84 → z=-0.84 \n", | |
"x=-1.16 → z=-1.00 (clipped)\n", | |
"x=-0.76 → z=-0.76 \n", | |
"x=0.61 → z=0.61 \n", | |
"x=2.07 → z=2.07 \n", | |
"x=1.24 → z=1.24 \n", | |
"x=0.90 → z=0.90 \n", | |
"x=-0.89 → z=-0.89 \n", | |
"x=-0.14 → z=-0.14 \n", | |
"x=-0.78 → z=-0.78 \n" | |
] | |
} | |
], | |
"source": [ | |
"x = pm.Normal.dist()\n", | |
"z = pt.clip(x, -1, np.inf)\n", | |
"\n", | |
"draws = pm.draw([x, z], draws=10)\n", | |
"\n", | |
"for x_draw, z_draw in zip(*draws):\n", | |
" print(\n", | |
" f\"x={x_draw: <5.2f} → z={z_draw:.2f}\"\n", | |
" f\" {'(clipped)' if x_draw < -1 else ''}\"\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "e6a7ec97", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[0.15865525 0.35206533 0.02275013]\n", | |
"[0.15865525 0.35206533 0.02275013]\n" | |
] | |
} | |
], | |
"source": [ | |
"rv = pm.Normal.dist(shape=(3,))\n", | |
"\n", | |
"clipped_rv = pt.clip(rv, -1, 2)\n", | |
"censored_rv = pm.Censored.dist(rv, lower=-1, upper=2)\n", | |
"\n", | |
"clipped_value = [-1, 0.5, 2]\n", | |
"print(\n", | |
" prob(clipped_rv, clipped_value).eval(),\n", | |
" prob(censored_rv, clipped_value).eval(),\n", | |
" sep=\"\\n\"\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4ab7ba28", | |
"metadata": {}, | |
"source": [ | |
"### Others" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "d7b7543e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0.24197072451914337\n", | |
"0.24173033519680465\n", | |
"0.3413447502096956\n", | |
"0.1359051181327496\n" | |
] | |
} | |
], | |
"source": [ | |
"rv = pm.Normal.dist(1)\n", | |
"\n", | |
"print(\n", | |
" prob(rv, 0).eval(),\n", | |
" prob(pt.round(rv), 0).eval(),\n", | |
" prob(pt.floor(rv), 0).eval(),\n", | |
" prob(pt.ceil(rv), 0).eval(),\n", | |
" sep=\"\\n\",\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6b8d9b78", | |
"metadata": {}, | |
"source": [ | |
"## Chained transformations" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "6b549a5f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(0.19394715)" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = pm.Laplace.dist(0, 1)\n", | |
"y = x + 2\n", | |
"z = pt.abs(y)\n", | |
"\n", | |
"prob(z, 0.9).eval()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "edb0473c", | |
"metadata": {}, | |
"source": [ | |
"# Multiple random variable transformations" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "51444a8a", | |
"metadata": {}, | |
"source": [ | |
"## Conditional probability" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "67f828c9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[array(1.30169683), array(0.03433906), array(1.26735777)]" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = pm.Normal.dist()\n", | |
"y = pm.Normal.dist()\n", | |
"z = x - y\n", | |
"\n", | |
"pm.draw([x, y, z])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "bd697b50", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from pymc.logprob.basic import conditional_logp\n", | |
"\n", | |
"def conditional_prob(rvs_to_values: dict, fn=True):\n", | |
" logps = conditional_logp(rvs_to_values).values()\n", | |
" probs = [logp.exp() for logp in logps]\n", | |
" if fn:\n", | |
" values = list(rvs_to_values.values())\n", | |
" return pytensor.function(values, probs)\n", | |
" else:\n", | |
" return probs\n", | |
" \n", | |
"x_value = pt.scalar(\"x_value\")\n", | |
"y_value = pt.scalar(\"y_value\")\n", | |
"z_value = pt.scalar(\"z_value\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "e0076461", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[array(0.35206533), array(0.26608525)]" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"prob_fn = conditional_prob({x: x_value, y: y_value})\n", | |
"prob_fn(x_value=0.5, y_value=0.9)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "c4fd46ac", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[array(0.35206533), array(0.26608525)]" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"prob_fn = conditional_prob({x: x_value, z: z_value})\n", | |
"prob_fn(x_value=0.5, z_value=0.5+0.9)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "92794504", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[array(0.26608525), array(0.35206533)]" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"prob_fn = conditional_prob({y: y_value, z: z_value})\n", | |
"prob_fn(y_value=0.9, z_value=0.5-0.9)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "8511c691", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"RuntimeError: The logprob terms of the following value variables could not be derived: {z_value}\n" | |
] | |
} | |
], | |
"source": [ | |
"try:\n", | |
" prob_fn = conditional_prob({z: z_value})\n", | |
"except RuntimeError as err:\n", | |
" print(f\"RuntimeError: {err}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "7f42dd87", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"RuntimeError: The logprob terms of the following value variables could not be derived: {z_value}\n" | |
] | |
} | |
], | |
"source": [ | |
"try:\n", | |
" prob_fn = conditional_prob({x: x_value, y: y_value, z: z_value})\n", | |
"except RuntimeError as err:\n", | |
" print(f\"RuntimeError: {err}\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ad65290b", | |
"metadata": {}, | |
"source": [ | |
"**Note:** Once we condition other variables, the last example is just a single variable logp expression" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "f0d9c89a", | |
"metadata": {}, | |
"source": [ | |
"## Control flow" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "3c3f7f3c", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"switch=0 → x=0.91\n", | |
"switch=0 → x=1.05\n", | |
"switch=1 → x=-0.45\n", | |
"switch=0 → x=1.96\n", | |
"switch=1 → x=-2.03\n" | |
] | |
} | |
], | |
"source": [ | |
"from pytensor.ifelse import ifelse\n", | |
"\n", | |
"switch = pm.Bernoulli.dist(p=0.7)\n", | |
"x1 = pm.Normal.dist(-1)\n", | |
"x2 = pm.Laplace.dist(1, 1)\n", | |
"x = ifelse(switch, x1, x2)\n", | |
"\n", | |
"for switch_draw, x_draw in zip(*pm.draw([switch, x], draws=5)):\n", | |
" print(f\"switch={switch_draw} → x={x_draw:.2f}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "2584d27f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"switch_value = pt.scalar(\"switch_value\", dtype=int)\n", | |
"x_value = pt.scalar(\"x_value\", dtype=float)\n", | |
"prob_fn = conditional_prob({switch: switch_value, x: x_value})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"id": "481da31e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[array(0.7), array(0.39695255)]" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"prob_fn(switch_value=1, x_value=-0.9)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "3855115b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[array(0.3), array(0.07478431)]" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"prob_fn(switch_value=0, x_value=-0.9)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "28a2d551", | |
"metadata": {}, | |
"source": [ | |
"## Stacking operations" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"id": "f29f8e0f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[array([1., 1., 1.])]" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x0 = pm.Uniform.dist(0, 1)\n", | |
"x1 = pm.Uniform.dist(x0, x0+1)\n", | |
"x2 = pm.Uniform.dist(x1, x1+1)\n", | |
"xs = pt.stack([x0, x1, x2])\n", | |
"\n", | |
"xs_values = pt.vector(\"xs\", shape=(3,))\n", | |
"conditional_prob_fn = conditional_prob({xs: xs_values})\n", | |
"\n", | |
"conditional_prob_fn([0.5, 1.5, 2.5])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"id": "1256d2bd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[array([1., 1., 0.])]" | |
] | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"conditional_prob_fn([0.5, 1.5, 0.5])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"id": "a52245d8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([-0.03677558, 1.32534683, 1.25944947, 0.96523192, 2.29075459])" | |
] | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mu = 0\n", | |
"sigma = 1\n", | |
"\n", | |
"y0 = np.array(0.5)\n", | |
"ys = []\n", | |
"y_tm1 = y0\n", | |
"for i in range(5):\n", | |
" y = y_tm1 + pm.Normal.dist(mu, sigma)\n", | |
" ys = pt.concatenate([ys, [y]])\n", | |
" y_tm1 = y\n", | |
" \n", | |
"pm.draw(ys)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "48146b70", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"62" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from pytensor.graph.basic import ancestors\n", | |
"len(list(ancestors(ys)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"id": "c09289b5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[array(1.43436108),\n", | |
" array(0.48492827),\n", | |
" array([1.79509336, 2.27295269, 3.73986879, 5.14309914, 6.35157192])]" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from pymc.pytensorf import collect_default_updates\n", | |
"\n", | |
"mu = pm.Normal.dist()\n", | |
"sigma = pm.HalfNormal.dist()\n", | |
"y0 = np.array(0.5)\n", | |
"\n", | |
"def rw_step(*args):\n", | |
" y_tm1, mu, sigma = args\n", | |
" y = y_tm1 + pm.Normal.dist(mu=mu, sigma=sigma)\n", | |
" return y, collect_default_updates(inputs=args, outputs=[y])\n", | |
"\n", | |
"ys, _ = pytensor.scan(\n", | |
" fn=rw_step,\n", | |
" outputs_info=[y0],\n", | |
" non_sequences=[mu, sigma],\n", | |
" n_steps=5,\n", | |
")\n", | |
"\n", | |
"pm.draw([mu, sigma, ys])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "906c79bd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[array(0.24197072),\n", | |
" array(0.10798193),\n", | |
" array([0.19333406, 0.19947114, 0.19947114, 0.00043634, 0.19947114])]" | |
] | |
}, | |
"execution_count": 31, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mu_value = pt.scalar(\"mu_value\")\n", | |
"sigma_value = pt.scalar(\"sigma_value\")\n", | |
"ys_values = pt.vector(\"ys_values\")\n", | |
"\n", | |
"prob_fn = conditional_prob({mu: mu_value, sigma: sigma_value, ys: ys_values})\n", | |
"\n", | |
"prob_fn(mu_value=1, sigma_value=2, ys_values=[1, 2, 3, -3, -2])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "2b7d577e", | |
"metadata": {}, | |
"source": [ | |
"# Extras" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "3ab72c17", | |
"metadata": {}, | |
"source": [ | |
"## CustomDist" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"id": "5ace2829", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([-1.93039708, 4.40054909, -3.67255424, 2.06813603, -4.47753441])" | |
] | |
}, | |
"execution_count": 32, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"def dist1(loc, lam, size):\n", | |
" return loc + pm.Exponential.dist(lam, shape=size)\n", | |
" \n", | |
"def dist2(alpha, beta, lower, upper, size):\n", | |
" range_ = upper - lower\n", | |
" return pm.Beta.dist(alpha, beta, shape=size) * (range_) + lower\n", | |
" \n", | |
"comp1 = pm.CustomDist.dist(-5, 1, dist=dist1)\n", | |
"comp2 = pm.CustomDist.dist(1, 1, -5, 5, dist=dist2)\n", | |
"mix = pm.Mixture.dist([0.3, 0.7], comp_dists=[comp1, comp2], shape=(5,))\n", | |
"pm.draw(mix)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"id": "5780cc39", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(0.00408677)" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"prob(comp1, 0.5).eval()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"id": "c89059c1", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([0.08493612, 0.07549469, 0.07202138, 0.07074363, 0.07027356])" | |
] | |
}, | |
"execution_count": 34, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"prob(mix, [-2, -1, 0, 1, 2]).eval()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "2a23c513", | |
"metadata": {}, | |
"source": [ | |
"## Indexing" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"id": "2aeb18f3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Monty Hall\n", | |
"from pytensor.ifelse import ifelse\n", | |
"\n", | |
"first_pick = pt.scalar(\"first_pick\", dtype=int)\n", | |
"correct = pm.Categorical.dist([1/3, 1/3, 1/3])\n", | |
"opened_door = ( \n", | |
" ifelse(\n", | |
" pt.eq(first_pick, 0),\n", | |
" pt.stack([\n", | |
" pm.Categorical.dist([0, 1/2, 1/2]), # correct = 0\n", | |
" pm.Categorical.dist([0, 0, 1]), # correct = 1\n", | |
" pm.Categorical.dist([0, 1, 0]), # correct = 2\n", | |
" ])[correct],\n", | |
" # else(first pick != 0)\n", | |
" ifelse( \n", | |
" pt.eq(first_pick, 1),\n", | |
" pt.stack([\n", | |
" pm.Categorical.dist([0, 0, 1]), # correct = 0\n", | |
" pm.Categorical.dist([1/2, 0, 1/2]), # correct = 1\n", | |
" pm.Categorical.dist([1, 0, 0]), # correct = 2\n", | |
" ])[correct],\n", | |
" # else (first_pick == 2)\n", | |
" pt.stack([\n", | |
" pm.Categorical.dist([0, 1, 0]), # correct = 0\n", | |
" pm.Categorical.dist([1, 0, 0]), # correct = 1\n", | |
" pm.Categorical.dist([1/2, 1/2, 0]), # correct = 2\n", | |
" ])[correct], \n", | |
" ), \n", | |
" )\n", | |
")\n", | |
"\n", | |
"correct_value = pt.scalar(\"correct\", dtype=int)\n", | |
"opened_door_value = pt.scalar(\"opened_door\", dtype=int)\n", | |
"\n", | |
"prob_correct, prob_opened_door = conditional_prob({correct: correct_value, opened_door: opened_door_value}, fn=False)\n", | |
"total_prob = prob_correct * prob_opened_door" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"id": "0ba756ec", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0.16666666666666666\n", | |
"0.0\n", | |
"0.3333333333333333\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/ricardo/miniconda3/envs/pymc/lib/python3.10/site-packages/pytensor/tensor/elemwise.py:779: RuntimeWarning: divide by zero encountered in log\n", | |
" variables = ufunc(*ufunc_args, **ufunc_kwargs)\n" | |
] | |
} | |
], | |
"source": [ | |
"for c in (0, 1, 2):\n", | |
" print(total_prob.eval({first_pick: 0, opened_door_value: 1, correct_value:c}))" | |
] | |
} | |
], | |
"metadata": { | |
"hide_input": false, | |
"kernelspec": { | |
"display_name": "pymc", | |
"language": "python", | |
"name": "pymc" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.8" | |
}, | |
"toc": { | |
"base_numbering": 1, | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"title_cell": "Table of Contents", | |
"title_sidebar": "Contents", | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": true, | |
"toc_window_display": true | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment