Last active
November 8, 2022 18:37
-
-
Save ricardoV94/0cf8fd0f69a09d7eff0a5b41cb111965 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"outputs": [], | |
"source": [ | |
"from collections import defaultdict\n", | |
"\n", | |
"import aesara\n", | |
"import aesara.tensor as at\n", | |
"import numpy as np\n", | |
"\n", | |
"from aeppl import factorized_joint_logprob, joint_logprob" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"outputs": [], | |
"source": [ | |
"def marginalize_factorized_logp_dict(logp_dict, marginalize):\n",
"    \"\"\"Sum discrete value variables out of a factorized log-probability dict.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    logp_dict : dict\n",
"        Maps each value variable to its log-probability graph, e.g. the\n",
"        output of ``factorized_joint_logprob``.\n",
"    marginalize : dict\n",
"        Maps each value variable to sum out (it must also be a key of\n",
"        ``logp_dict``) to an iterable of the values it is summed over.\n",
"\n",
"    Returns\n",
"    -------\n",
"    dict\n",
"        ``logp_dict`` itself when ``marginalize`` is empty. Otherwise a new\n",
"        plain ``dict`` mapping every value variable of ``logp_dict`` that is\n",
"        not marginalized to the ``logsumexp``, over all combinations of the\n",
"        marginalized values, of its logp plus the marginalized variables'\n",
"        logps. Looking up a key that is not present raises ``KeyError``.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If a marginalized variable is given no values to sum over.\n",
"    KeyError\n",
"        If a marginalized variable is not a key of ``logp_dict``.\n",
"\n",
"    Notes\n",
"    -----\n",
"    Each remaining factor is marginalized on its own, so adding up the\n",
"    returned factors gives the marginal joint logp only when at most one of\n",
"    them depends on the marginalized variables.\n",
"    \"\"\"\n",
"    if not marginalize:\n",
"        return logp_dict\n",
"\n",
"    # Copy so that popping does not mutate the caller's mapping.\n",
"    remaining = dict(marginalize)\n",
"    marginalized_vv, constant_values = remaining.popitem()\n",
"    # Materialize so one-shot iterators work and an empty support is caught.\n",
"    constant_values = list(constant_values)\n",
"    if not constant_values:\n",
"        raise ValueError(f\"No values given to marginalize {marginalized_vv} over\")\n",
"\n",
"    # For each remaining factor, one logp term per value of marginalized_vv.\n",
"    terms = defaultdict(list)\n",
"    for constant_value in constant_values:\n",
"        constant = at.constant(\n",
"            constant_value,\n",
"            dtype=marginalized_vv.dtype,\n",
"            name=f\"{marginalized_vv}={constant_value}\",\n",
"        )\n",
"        # Substitute the value into every factor graph with a single clone.\n",
"        conditional_logp_dict = dict(zip(\n",
"            logp_dict.keys(),\n",
"            aesara.graph.clone_replace(\n",
"                list(logp_dict.values()),\n",
"                replace={marginalized_vv: constant},\n",
"            ),\n",
"        ))\n",
"        marginalized_var_constant_logp = conditional_logp_dict.pop(marginalized_vv)\n",
"        # Sum out the other marginalized variables given this value.\n",
"        conditional_logp_dict = marginalize_factorized_logp_dict(\n",
"            conditional_logp_dict, remaining\n",
"        )\n",
"        for value_var, logp_expr in conditional_logp_dict.items():\n",
"            terms[value_var].append(logp_expr + marginalized_var_constant_logp)\n",
"\n",
"    # One logsumexp per factor over a new leading axis (elementwise for\n",
"    # non-scalar factors), instead of a chain of pairwise logsumexps seeded\n",
"    # with -inf.\n",
"    return {\n",
"        value_var: at.logsumexp(at.stack(value_terms), axis=0)\n",
"        for value_var, value_terms in terms.items()\n",
"    }"
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"outputs": [], | |
"source": [ | |
"# Toy model: z is normal with mean x and scale y + 1; x and y are categorical.\n",
"x_probs = [0.1, 0.2, 0.3, 0.4]\n",
"y_probs = [0.1, 0.3, 0.6]\n",
"x = at.random.categorical(x_probs, name=\"x\")\n",
"y = at.random.categorical(y_probs, name=\"y\")\n",
"z = at.random.normal(x, y + 1, name=\"z\")\n",
"\n",
"# Value variables: the inputs at which the logp graphs are evaluated.\n",
"x_vv, y_vv, z_vv = x.clone(), y.clone(), z.clone()\n",
"\n",
"# Reference joint logp, compiled for the brute-force check further down.\n",
"ref_logp = joint_logprob({x: x_vv, y: y_vv, z: z_vv}, sum=True)\n",
"ref_logp_fn = aesara.function([x_vv, y_vv, z_vv], ref_logp)\n",
"\n",
"# Sum x and y out of every factor, enumerating their full supports.\n",
"factorized_logp_dict = factorized_joint_logprob({x: x_vv, y: y_vv, z: z_vv})\n",
"logp_dict = marginalize_factorized_logp_dict(\n",
"    factorized_logp_dict,\n",
"    marginalize={x_vv: range(len(x_probs)), y_vv: range(len(y_probs))},\n",
")"
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "array(-1.97233263)" | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Marginal logp of z at z = 1, with x and y summed out.\n",
"z_marginal_logp = logp_dict[z_vv]\n",
"z_marginal_logp.eval({z_vv: 1})"
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "-1.9723326261602925" | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Brute-force reference: sum the joint density over every (x, y) pair.\n",
"# np.logaddexp.reduce avoids the underflow of np.log(np.sum(np.exp(...))).\n",
"ref_marginal_logp = np.logaddexp.reduce([\n",
"    ref_logp_fn(x=x_val, y=y_val, z=1)\n",
"    for x_val in range(4)\n",
"    for y_val in range(3)\n",
"])\n",
"# The marginalized graph must agree with the brute-force sum.\n",
"np.testing.assert_allclose(ref_marginal_logp, logp_dict[z_vv].eval({z_vv: 1}))\n",
"ref_marginal_logp"
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3"
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment