Last active
October 31, 2024 19:11
-
-
Save ricardoV94/62f58a564ba39fbbfdc09755f703f0f8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"id": "initial_id", | |
"metadata": { | |
"collapsed": true, | |
"ExecuteTime": { | |
"end_time": "2024-10-31T19:10:45.531028Z", | |
"start_time": "2024-10-31T19:10:43.534555Z" | |
} | |
}, | |
"source": [ | |
"import pytensor.tensor as pt\n", | |
"import pymc as pm\n", | |
"\n", | |
"from pytensor.graph.fg import FunctionGraph\n", | |
"from pytensor.graph.rewriting.basic import node_rewriter, out2in\n", | |
"from pymc.model.fgraph import fgraph_from_model, model_from_fgraph, ModelObservedRV, model_observed_rv" | |
], | |
"outputs": [], | |
"execution_count": 1 | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2024-10-31T19:10:45.547421Z", | |
"start_time": "2024-10-31T19:10:45.533816Z" | |
} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"@node_rewriter(tracks=[ModelObservedRV])\n", | |
"def summary_stats_normal(fgraph: FunctionGraph, node):\n", | |
"    \"\"\"Replace an observed Normal by observed summary statistics of its data.\n", | |
"\n", | |
"    This applies the equivalence (up to a normalizing constant) described in:\n", | |
"\n", | |
"    https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics\n", | |
"\n", | |
"    When the rewrite applies, two observed variables are added as outputs of\n", | |
"    `fgraph` (a Normal on the mean of the data and a Gamma on its sample\n", | |
"    variance) and the original observed node is removed. The graph is edited\n", | |
"    in place and the function returns `None` in every case.\n", | |
"\n", | |
"    The rewrite is skipped, leaving the graph untouched, unless:\n", | |
"\n", | |
"    * the observed variable is a `pm.Normal` that is not itself a scalar;\n", | |
"    * the observed variable has no client besides its use as an output;\n", | |
"    * both `mu` and `sigma` are broadcastable in every dimension;\n", | |
"    * neither `mu` nor `sigma` has more than two clients.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    fgraph : FunctionGraph\n", | |
"        Graph being rewritten; modified in place.\n", | |
"    node\n", | |
"        `ModelObservedRV` node whose two inputs are the random variable and\n", | |
"        its observed data.\n", | |
"    \"\"\"\n", | |
"    [observed_rv] = node.outputs\n", | |
"    [rv, data] = node.inputs\n", | |
"\n", | |
"    if not isinstance(rv.owner.op, pm.Normal):\n", | |
"        return None\n", | |
"\n", | |
"    # Check the normal RV is not just a scalar\n", | |
"    if all(rv.type.broadcastable):\n", | |
"        return None\n", | |
"\n", | |
"    # Check that the observed RV is not used anywhere else (like a Potential or Deterministic)\n", | |
"    # There should be only one use: as an \"output\"\n", | |
"    if len(fgraph.clients[observed_rv]) > 1:\n", | |
"        return None\n", | |
"\n", | |
"    mu, sigma = rv.owner.op.dist_params(rv.owner)\n", | |
"\n", | |
"    # Both mu and sigma must be scalar. If either one has a non-broadcastable\n", | |
"    # dimension the data cannot be collapsed into a single mean and a single\n", | |
"    # variance, so bail out as soon as *either* parameter is non-scalar.\n", | |
"    if not (all(mu.type.broadcastable) and all(sigma.type.broadcastable)):\n", | |
"        return None\n", | |
"\n", | |
"    # Check that mu and sigma are not used anywhere else\n", | |
"    # Note: This is too restrictive, it's fine if they're used in Deterministics!\n", | |
"    # There should only be two uses: as an \"output\" and as the param of the `rv`\n", | |
"    if len(fgraph.clients[mu]) > 2 or len(fgraph.clients[sigma]) > 2:\n", | |
"        return None\n", | |
"\n", | |
"    # Remove expand_dims (every remaining dimension is broadcastable here)\n", | |
"    mu = mu.squeeze()\n", | |
"    sigma = sigma.squeeze()\n", | |
"\n", | |
"    # Summary statistics of the data: mean and sample variance (ddof=1).\n", | |
"    # NOTE(review): the names are cleared explicitly; presumably so the derived\n", | |
"    # data stays anonymous when the model is rebuilt -- confirm.\n", | |
"    mean_data = pt.mean(data)\n", | |
"    mean_data.name = None\n", | |
"    var_data = pt.var(data, ddof=1)\n", | |
"    var_data.name = None\n", | |
"    N = data.size\n", | |
"    sqrt_N = pt.sqrt(N)\n", | |
"    nm1_over2 = 0.5 * (N - 1)\n", | |
"\n", | |
"    # mean(data) ~ Normal(mu, sigma / sqrt(N))\n", | |
"    observed_mean = model_observed_rv(\n", | |
"        pm.Normal.dist(mu=mu, sigma=sigma / sqrt_N),\n", | |
"        mean_data,\n", | |
"    )\n", | |
"    observed_mean.name = f\"{rv.name}_mean\"\n", | |
"\n", | |
"    # var(data, ddof=1) ~ Gamma(alpha=(N - 1) / 2, beta=(N - 1) / (2 * sigma**2))\n", | |
"    observed_var = model_observed_rv(\n", | |
"        pm.Gamma.dist(alpha=nm1_over2, beta=nm1_over2 / (sigma ** 2)),\n", | |
"        var_data,\n", | |
"    )\n", | |
"    observed_var.name = f\"{rv.name}_var\"\n", | |
"\n", | |
"    # Swap the original observed node for the two summary-statistic nodes\n", | |
"    fgraph.add_output(observed_mean, import_missing=True)\n", | |
"    fgraph.add_output(observed_var, import_missing=True)\n", | |
"    fgraph.remove_node(node)\n", | |
"\n", | |
"summary_stats_rewrite = out2in(summary_stats_normal, ignore_newtrees=True)" | |
], | |
"id": "16c05fe619c60446", | |
"outputs": [], | |
"execution_count": 2 | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2024-10-31T19:10:47.995140Z", | |
"start_time": "2024-10-31T19:10:45.548754Z" | |
} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"# Simulated data: 1000 draws from Normal(mu=1, sigma=0.5) with a fixed seed\n", | |
"y_data = pm.draw(\n", | |
"    pm.Normal.dist(mu=1, sigma=0.5, size=(1000,)),\n", | |
"    random_seed=42,\n", | |
")\n", | |
"\n", | |
"# Reference model: one mu and one sigma shared by every observation\n", | |
"m = pm.Model()\n", | |
"with m:\n", | |
"    mu = pm.Normal(\"mu\")\n", | |
"    sigma = pm.HalfNormal(\"sigma\")\n", | |
"    y = pm.Normal(\"y\", mu=mu, sigma=sigma, observed=y_data)\n", | |
"\n", | |
"m.to_graphviz()" | |
], | |
"id": "e1352af22b0a4f2a", | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 12.0.0 (20240803.0821)\n -->\n<!-- Pages: 1 -->\n<svg width=\"264pt\" height=\"254pt\"\n viewBox=\"0.00 0.00 264.15 254.25\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 250.25)\">\n<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-250.25 260.15,-250.25 260.15,4 -4,4\"/>\n<g id=\"clust1\" class=\"cluster\">\n<title>cluster1000</title>\n<path fill=\"none\" stroke=\"black\" d=\"M93.18,-8C93.18,-8 183.18,-8 183.18,-8 189.18,-8 195.18,-14 195.18,-20 195.18,-20 195.18,-121.75 195.18,-121.75 195.18,-127.75 189.18,-133.75 183.18,-133.75 183.18,-133.75 93.18,-133.75 93.18,-133.75 87.18,-133.75 81.18,-127.75 81.18,-121.75 81.18,-121.75 81.18,-20 81.18,-20 81.18,-14 87.18,-8 93.18,-8\"/>\n<text text-anchor=\"middle\" x=\"169.18\" y=\"-15.95\" font-family=\"Times,serif\" font-size=\"14.00\">1000</text>\n</g>\n<!-- sigma -->\n<g id=\"node1\" class=\"node\">\n<title>sigma</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"70.18\" cy=\"-204\" rx=\"70.18\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-216.57\" font-family=\"Times,serif\" font-size=\"14.00\">sigma</text>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-199.32\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-182.07\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n</g>\n<!-- y -->\n<g id=\"node3\" class=\"node\">\n<title>y</title>\n<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"138.18\" cy=\"-83.5\" rx=\"48.97\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"138.18\" y=\"-96.07\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n<text text-anchor=\"middle\" x=\"138.18\" 
y=\"-78.82\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"138.18\" y=\"-61.57\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n</g>\n<!-- sigma->y -->\n<g id=\"edge1\" class=\"edge\">\n<title>sigma->y</title>\n<path fill=\"none\" stroke=\"black\" d=\"M92.82,-163.54C98.68,-153.34 105.04,-142.25 111.1,-131.69\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"114.09,-133.51 116.03,-123.09 108.02,-130.02 114.09,-133.51\"/>\n</g>\n<!-- mu -->\n<g id=\"node2\" class=\"node\">\n<title>mu</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"207.18\" cy=\"-204\" rx=\"48.97\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"207.18\" y=\"-216.57\" font-family=\"Times,serif\" font-size=\"14.00\">mu</text>\n<text text-anchor=\"middle\" x=\"207.18\" y=\"-199.32\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"207.18\" y=\"-182.07\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n</g>\n<!-- mu->y -->\n<g id=\"edge2\" class=\"edge\">\n<title>mu->y</title>\n<path fill=\"none\" stroke=\"black\" d=\"M185.54,-165.83C179.18,-154.91 172.14,-142.82 165.48,-131.38\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"168.56,-129.73 160.51,-122.84 162.52,-133.25 168.56,-129.73\"/>\n</g>\n</g>\n</svg>\n", | |
"text/plain": [ | |
"<graphviz.graphs.Digraph at 0x7c2622fd2300>" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"execution_count": 3 | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2024-10-31T19:10:48.074849Z", | |
"start_time": "2024-10-31T19:10:47.997525Z" | |
} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"# Convert the model to a FunctionGraph. Only the first element of the returned\n", | |
"# pair is needed; indexing avoids assigning to `_`, which would shadow\n", | |
"# IPython's last-output variable for the rest of the session.\n", | |
"fgraph = fgraph_from_model(m)[0]\n", | |
"# fgraph.dprint()  # uncomment to inspect the graph before the rewrite\n", | |
"\n", | |
"# Apply the rewrite; `fgraph` is edited in place. The call is not the last\n", | |
"# expression of the cell, so its result is not displayed.\n", | |
"summary_stats_rewrite.apply(fgraph)\n", | |
"# fgraph.dprint()  # uncomment to inspect the graph after the rewrite\n", | |
"\n", | |
"# Rebuild a model from the rewritten graph and show its structure\n", | |
"new_m = model_from_fgraph(fgraph)\n", | |
"new_m.to_graphviz()" | |
], | |
"id": "a75390e598ae4fe2", | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 12.0.0 (20240803.0821)\n -->\n<!-- Pages: 1 -->\n<svg width=\"264pt\" height=\"213pt\"\n viewBox=\"0.00 0.00 264.15 213.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 209)\">\n<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-209 260.15,-209 260.15,4 -4,4\"/>\n<!-- y_var -->\n<g id=\"node1\" class=\"node\">\n<title>y_var</title>\n<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"70.18\" cy=\"-42.25\" rx=\"50.03\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-54.82\" font-family=\"Times,serif\" font-size=\"14.00\">y_var</text>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-37.57\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-20.32\" font-family=\"Times,serif\" font-size=\"14.00\">Gamma</text>\n</g>\n<!-- sigma -->\n<g id=\"node2\" class=\"node\">\n<title>sigma</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"70.18\" cy=\"-162.75\" rx=\"70.18\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-175.32\" font-family=\"Times,serif\" font-size=\"14.00\">sigma</text>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-158.07\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-140.82\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n</g>\n<!-- sigma->y_var -->\n<g id=\"edge3\" class=\"edge\">\n<title>sigma->y_var</title>\n<path fill=\"none\" stroke=\"black\" d=\"M70.18,-120.3C70.18,-112.5 70.18,-104.24 70.18,-96.15\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"73.68,-96.39 70.18,-86.39 66.68,-96.39 73.68,-96.39\"/>\n</g>\n<!-- y_mean -->\n<g id=\"node4\" 
class=\"node\">\n<title>y_mean</title>\n<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"197.18\" cy=\"-42.25\" rx=\"49.5\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"197.18\" y=\"-54.82\" font-family=\"Times,serif\" font-size=\"14.00\">y_mean</text>\n<text text-anchor=\"middle\" x=\"197.18\" y=\"-37.57\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"197.18\" y=\"-20.32\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n</g>\n<!-- sigma->y_mean -->\n<g id=\"edge1\" class=\"edge\">\n<title>sigma->y_mean</title>\n<path fill=\"none\" stroke=\"black\" d=\"M107.6,-126.84C122.78,-112.67 140.38,-96.25 155.94,-81.73\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"158.31,-84.31 163.23,-74.92 153.53,-79.19 158.31,-84.31\"/>\n</g>\n<!-- mu -->\n<g id=\"node3\" class=\"node\">\n<title>mu</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"207.18\" cy=\"-162.75\" rx=\"48.97\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"207.18\" y=\"-175.32\" font-family=\"Times,serif\" font-size=\"14.00\">mu</text>\n<text text-anchor=\"middle\" x=\"207.18\" y=\"-158.07\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"207.18\" y=\"-140.82\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n</g>\n<!-- mu->y_mean -->\n<g id=\"edge2\" class=\"edge\">\n<title>mu->y_mean</title>\n<path fill=\"none\" stroke=\"black\" d=\"M203.68,-120.3C203.02,-112.5 202.33,-104.24 201.64,-96.15\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"205.15,-96.05 200.82,-86.38 198.17,-96.64 205.15,-96.05\"/>\n</g>\n</g>\n</svg>\n", | |
"text/plain": [ | |
"<graphviz.graphs.Digraph at 0x7c261da83c80>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"execution_count": 4 | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2024-10-31T19:10:48.745053Z", | |
"start_time": "2024-10-31T19:10:48.076171Z" | |
} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"# Point at which both models are compared\n", | |
"ip = m.initial_point()\n", | |
"\n", | |
"# Compiled log-probability functions of the rewritten and the original model.\n", | |
"# The two are expected to agree up to an additive (in log space)\n", | |
"# normalization constant.\n", | |
"new_m_logp = new_m.compile_logp()\n", | |
"m_logp = m.compile_logp()\n", | |
"\n", | |
"ip" | |
], | |
"id": "9ba2337e7ff5f6a3", | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'mu': array(0.), 'sigma_log__': array(0.)}" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"execution_count": 5 | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2024-10-31T19:10:48.751046Z", | |
"start_time": "2024-10-31T19:10:48.746750Z" | |
} | |
}, | |
"cell_type": "code", | |
"source": "m_logp(ip) - new_m_logp(ip)", | |
"id": "29f91a23ff4b0bf8", | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"-711.5313632106429" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"execution_count": 6 | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2024-10-31T19:10:48.758491Z", | |
"start_time": "2024-10-31T19:10:48.752459Z" | |
} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"# Compare again at a second point. A shifted copy is built instead of\n", | |
"# updating `ip` in place, so re-running this cell always evaluates the same\n", | |
"# point and `ip` still matches the value displayed earlier.\n", | |
"ip_shifted = {\n", | |
"    **ip,\n", | |
"    \"mu\": ip[\"mu\"] + 0.5,\n", | |
"    \"sigma_log__\": ip[\"sigma_log__\"] + 1.5,\n", | |
"}\n", | |
"\n", | |
"m_logp(ip_shifted) - new_m_logp(ip_shifted)" | |
], | |
"id": "5af996f937779528", | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"-711.5313632106386" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"execution_count": 7 | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment