Last active
November 8, 2022 18:37
-
-
Save ricardoV94/0cf8fd0f69a09d7eff0a5b41cb111965 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "7205a740", | |
"metadata": {}, | |
"source": [ | |
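"# Marginalizing discrete RVs\n", | |
"\n", | |
"A discrete RV with a finite domain is marginalized by evaluating the joint model logp at every value of its domain and combining the results with `logsumexp`: $\\log p(y) = \\log \\sum_k p(y, x=k)$. Each result is checked against an equivalent model written with `pm.NormalMixture`." | |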
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "6a238994", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import aeppl\n", | |
"import aesara\n", | |
"import aesara.tensor as at\n", | |
"from aesara.graph import FunctionGraph\n", | |
"from aesara.compile.builders import OpFromGraph\n", | |
"import numpy as np\n", | |
"\n", | |
"import pymc as pm" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e05a3823", | |
"metadata": {}, | |
"source": [ | |
"## Marginalizing a single RV" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "cf25b739", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"with pm.Model() as m:\n", | |
" p = pm.Dirichlet(\"p\", [1, 1])\n", | |
" x = pm.Categorical(\"x\", p=p)\n", | |
" y = pm.Normal(\"y\", pm.math.stack([-1, 1])[x], 1, observed=1) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "3f94e5a6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"p_vv = m.rvs_to_values[p]\n", | |
"x_vv = m.rvs_to_values[x]\n", | |
"logp = m.logp()\n", | |
"logp_op = OpFromGraph([p_vv, x_vv], [logp], inline=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "d1a4e4f7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"OpFromGraph{inline=True}.0" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"logp_op(p_vv, x_vv)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "6f626312", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'p_simplex__': array([0.]), 'x': array(0)}" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"m.initial_point()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "872648d2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array(-4.30523289), array(-2.30523289))" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"logp_op(np.array([0]), 0).eval(), logp_op(np.array([0]), 1).eval()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "7367bbe3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x_domain = range(2) # Possible values of the categorical\n", | |
"marginal_logp = at.logsumexp([logp_op(p_vv, x_vv_const) for x_vv_const in x_domain])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "7df0efd4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(-2.17830488)" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"marginal_logp.eval({p_vv: np.array([0])})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "35e97418", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(-2.17830488)" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"with pm.Model() as m_ref:\n", | |
" p = pm.Dirichlet(\"p\", [1, 1])\n", | |
" y = pm.NormalMixture(\"y\", w=p, mu=[-1, 1], sigma=1, observed=1) \n", | |
"m_ref.compile_logp()({\"p_simplex__\": np.array([0])})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "c4846826", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"f = aesara.function([p_vv], marginal_logp)\n", | |
"# aesara.dprint(f)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "c5f15010", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"52" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(f.maker.fgraph.apply_nodes)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "37ef8d2b", | |
"metadata": {}, | |
"source": [ | |
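"## Marginalizing multiple RVs\n", | |
"\n", | |
"The categorical RVs are marginalized one at a time, in reverse topological order. Each step wraps the current logp graph in an `OpFromGraph` and replaces it with the `logsumexp` over that RV's domain (hard-coded to two values here)." | |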
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "81d63419", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def explicit_mixture(name, categorical_idx, components):\n", | |
" return pm.Normal(name, pm.math.stack(components)[categorical_idx], 1)\n", | |
" \n", | |
"with pm.Model() as m:\n", | |
" p1 = pm.Dirichlet(\"p1\", [1, 1])\n", | |
" mix_comp1 = pm.Categorical(\"mix_comp1\", p=p1) \n", | |
" y1 = explicit_mixture(\"y1\", mix_comp1, [-1, 1])\n", | |
" \n", | |
" p2 = pm.Dirichlet(\"p2\", [1, 1])\n", | |
" mix_comp2 = pm.Categorical(\"mix_comp2\", p=p2) \n", | |
" y2 = explicit_mixture(\"y2\", mix_comp2, [-2, 2])\n", | |
" \n", | |
" p3 = pm.Dirichlet(\"p3\", [1, 1])\n", | |
" mix_comp3 = pm.Categorical(\"mix_comp3\", p=p3)\n", | |
" y3 = explicit_mixture(\"y3\", mix_comp3, [y1, y2])\n", | |
" \n", | |
" pm.Normal(\"llike\", y3, 1, observed=9)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "bf922bc3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mix_comp3\n", | |
"mix_comp2\n", | |
"mix_comp1\n" | |
] | |
} | |
], | |
"source": [ | |
"logp_graph = m.logp()\n", | |
"rvs = list(m.free_RVs)\n", | |
"marginalize_rvs = {mix_comp1, mix_comp2, mix_comp3}\n", | |
"fg = FunctionGraph(outputs=rvs, clone=False)\n", | |
"order = fg.toposort()\n", | |
"for rv in sorted(marginalize_rvs, key=lambda x: order.index(x.owner), reverse=True):\n", | |
" print(rv)\n", | |
" rvs.remove(rv)\n", | |
" vv = m.rvs_to_values[rv]\n", | |
" vvs = [m.rvs_to_values[rv] for rv in rvs]\n", | |
" logp_op = OpFromGraph([vv, *vvs], [logp_graph], inline=True)\n", | |
" rv_domain = range(2) # Hard-coded\n", | |
" logp_graph = at.logsumexp([logp_op(vv_const, *vvs) for vv_const in rv_domain])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "5c59cafa", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'p1_simplex__': array([0.]),\n", | |
" 'y1': array(-1.),\n", | |
" 'p2_simplex__': array([0.]),\n", | |
" 'y2': array(-2.),\n", | |
" 'p3_simplex__': array([0.]),\n", | |
" 'y3': array(-1.)}" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ip = m.initial_point()\n", | |
"ip.pop(\"mix_comp3\", None)\n", | |
"ip.pop(\"mix_comp2\", None)\n", | |
"ip.pop(\"mix_comp1\", None)\n", | |
"ip" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "8a2f1de8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(-57.23329681)" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"f = m.compile_fn(logp_graph)\n", | |
"f(ip)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "2c8ba2ac", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"193" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(f.f.maker.fgraph.apply_nodes)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "c5ef7a5c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(-57.23329681)" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"with pm.Model() as m_ref:\n", | |
" p1 = pm.Dirichlet(\"p1\", [1, 1])\n", | |
" y1 = pm.NormalMixture(\"y1\", p1, [-1, 1])\n", | |
" \n", | |
" p2 = pm.Dirichlet(\"p2\", [1, 1])\n", | |
" y2 = pm.NormalMixture(\"y2\", p2, [-2, 2])\n", | |
" \n", | |
" p3 = pm.Dirichlet(\"p3\", [1, 1])\n", | |
" y3 = pm.NormalMixture(\"y3\", p3, [y1, y2])\n", | |
" \n", | |
" pm.Normal(\"llike\", y3, 1, observed=9)\n", | |
" \n", | |
"m_ref.compile_logp()(ip)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "98072f1c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"hide_input": false, | |
"kernelspec": { | |
"display_name": "pymc", | |
"language": "python", | |
"name": "pymc" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.4" | |
}, | |
"toc": { | |
"base_numbering": 1, | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"title_cell": "Table of Contents", | |
"title_sidebar": "Contents", | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": true, | |
"toc_window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment