Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Last active October 31, 2024 19:11
Show Gist options
  • Save ricardoV94/62f58a564ba39fbbfdc09755f703f0f8 to your computer and use it in GitHub Desktop.
Save ricardoV94/62f58a564ba39fbbfdc09755f703f0f8 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-10-31T19:10:45.531028Z",
"start_time": "2024-10-31T19:10:43.534555Z"
}
},
"source": [
"import pytensor.tensor as pt\n",
"import pymc as pm\n",
"\n",
"from pytensor.graph.fg import FunctionGraph\n",
"from pytensor.graph.rewriting.basic import node_rewriter, out2in\n",
"from pymc.model.fgraph import fgraph_from_model, model_from_fgraph, ModelObservedRV, model_observed_rv"
],
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-31T19:10:45.547421Z",
"start_time": "2024-10-31T19:10:45.533816Z"
}
},
"cell_type": "code",
"source": [
"@node_rewriter(tracks=[ModelObservedRV])\n",
"def summary_stats_normal(fgraph: FunctionGraph, node):\n",
"    \"\"\"Replace an observed i.i.d. Normal likelihood by its sufficient statistics.\n",
"\n",
"    An observed ``Normal(mu, sigma)`` whose ``mu`` and ``sigma`` are both\n",
"    scalar-like is swapped for two observed variables built from the data:\n",
"    the sample mean, ``Normal(mu, sigma / sqrt(N))``, and the unbiased sample\n",
"    variance, ``Gamma((N - 1) / 2, (N - 1) / (2 * sigma ** 2))``. This applies\n",
"    the equivalence (up to a normalizing constant) described in:\n",
"\n",
"    https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics\n",
"\n",
"    The graph is edited in place (two outputs are added and ``node`` is\n",
"    removed). The function always returns ``None``, both when the rewrite is\n",
"    applied and when one of the guards below rejects the node.\n",
"    \"\"\"\n",
"    [observed_rv] = node.outputs\n",
"    [rv, data] = node.inputs\n",
"\n",
"    if not isinstance(rv.owner.op, pm.Normal):\n",
"        return None\n",
"\n",
"    # Check the normal RV is not just a scalar\n",
"    if all(rv.type.broadcastable):\n",
"        return None\n",
"\n",
"    # Check that the observed RV is not used anywhere else (like a Potential or Deterministic)\n",
"    # There should be only one use: as an \"output\"\n",
"    if len(fgraph.clients[observed_rv]) > 1:\n",
"        return None\n",
"\n",
"    mu, sigma = rv.owner.op.dist_params(rv.owner)\n",
"\n",
"    # Both mu and sigma must be scalar-like (every dim broadcastable).\n",
"    # Bail out as soon as EITHER of them varies across observations: the data\n",
"    # are then not identically distributed, so the sample mean and variance\n",
"    # no longer summarize the likelihood.\n",
"    if not (all(mu.type.broadcastable) and all(sigma.type.broadcastable)):\n",
"        return None\n",
"\n",
"    # Check that mu and sigma are not used anywhere else\n",
"    # Note: This is too restrictive, it's fine if they're used in Deterministics!\n",
"    # There should only be two uses: as an \"output\" and as the param of the `rv`\n",
"    if len(fgraph.clients[mu]) > 2 or len(fgraph.clients[sigma]) > 2:\n",
"        return None\n",
"\n",
"    # Remove expand_dims\n",
"    mu = mu.squeeze()\n",
"    sigma = sigma.squeeze()\n",
"\n",
"    # Apply the rewrite\n",
"    mean_data = pt.mean(data)\n",
"    mean_data.name = None\n",
"    # ddof=1: unbiased sample variance, matching the (N - 1) in the Gamma below\n",
"    var_data = pt.var(data, ddof=1)\n",
"    var_data.name = None\n",
"    N = data.size\n",
"    sqrt_N = pt.sqrt(N)\n",
"    nm1_over2 = 0.5 * (N - 1)\n",
"\n",
"    observed_mean = model_observed_rv(\n",
"        pm.Normal.dist(mu=mu, sigma=sigma / sqrt_N),\n",
"        mean_data,\n",
"    )\n",
"    observed_mean.name = f\"{rv.name}_mean\"\n",
"\n",
"    observed_var = model_observed_rv(\n",
"        pm.Gamma.dist(alpha=nm1_over2, beta=nm1_over2 / (sigma ** 2)),\n",
"        var_data,\n",
"    )\n",
"    observed_var.name = f\"{rv.name}_var\"\n",
"\n",
"    # Register the two summary-statistic variables as graph outputs, then\n",
"    # drop the original observed node.\n",
"    fgraph.add_output(observed_mean, import_missing=True)\n",
"    fgraph.add_output(observed_var, import_missing=True)\n",
"    fgraph.remove_node(node)\n",
"\n",
"summary_stats_rewrite = out2in(summary_stats_normal, ignore_newtrees=True)"
],
"id": "16c05fe619c60446",
"outputs": [],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-31T19:10:47.995140Z",
"start_time": "2024-10-31T19:10:45.548754Z"
}
},
"cell_type": "code",
"source": [
"y_data = pm.draw(pm.Normal.dist(mu=1, sigma=0.5, size=(1000,)), random_seed=42)\n",
"\n",
"with pm.Model() as m:\n",
" mu = pm.Normal(\"mu\")\n",
" sigma = pm.HalfNormal(\"sigma\")\n",
" y = pm.Normal(\"y\", mu=mu, sigma=sigma, observed=y_data)\n",
" \n",
"m.to_graphviz()"
],
"id": "e1352af22b0a4f2a",
"outputs": [
{
"data": {
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 12.0.0 (20240803.0821)\n -->\n<!-- Pages: 1 -->\n<svg width=\"264pt\" height=\"254pt\"\n viewBox=\"0.00 0.00 264.15 254.25\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 250.25)\">\n<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-250.25 260.15,-250.25 260.15,4 -4,4\"/>\n<g id=\"clust1\" class=\"cluster\">\n<title>cluster1000</title>\n<path fill=\"none\" stroke=\"black\" d=\"M93.18,-8C93.18,-8 183.18,-8 183.18,-8 189.18,-8 195.18,-14 195.18,-20 195.18,-20 195.18,-121.75 195.18,-121.75 195.18,-127.75 189.18,-133.75 183.18,-133.75 183.18,-133.75 93.18,-133.75 93.18,-133.75 87.18,-133.75 81.18,-127.75 81.18,-121.75 81.18,-121.75 81.18,-20 81.18,-20 81.18,-14 87.18,-8 93.18,-8\"/>\n<text text-anchor=\"middle\" x=\"169.18\" y=\"-15.95\" font-family=\"Times,serif\" font-size=\"14.00\">1000</text>\n</g>\n<!-- sigma -->\n<g id=\"node1\" class=\"node\">\n<title>sigma</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"70.18\" cy=\"-204\" rx=\"70.18\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-216.57\" font-family=\"Times,serif\" font-size=\"14.00\">sigma</text>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-199.32\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-182.07\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n</g>\n<!-- y -->\n<g id=\"node3\" class=\"node\">\n<title>y</title>\n<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"138.18\" cy=\"-83.5\" rx=\"48.97\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"138.18\" y=\"-96.07\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n<text text-anchor=\"middle\" x=\"138.18\" 
y=\"-78.82\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"138.18\" y=\"-61.57\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n</g>\n<!-- sigma&#45;&gt;y -->\n<g id=\"edge1\" class=\"edge\">\n<title>sigma&#45;&gt;y</title>\n<path fill=\"none\" stroke=\"black\" d=\"M92.82,-163.54C98.68,-153.34 105.04,-142.25 111.1,-131.69\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"114.09,-133.51 116.03,-123.09 108.02,-130.02 114.09,-133.51\"/>\n</g>\n<!-- mu -->\n<g id=\"node2\" class=\"node\">\n<title>mu</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"207.18\" cy=\"-204\" rx=\"48.97\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"207.18\" y=\"-216.57\" font-family=\"Times,serif\" font-size=\"14.00\">mu</text>\n<text text-anchor=\"middle\" x=\"207.18\" y=\"-199.32\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"207.18\" y=\"-182.07\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n</g>\n<!-- mu&#45;&gt;y -->\n<g id=\"edge2\" class=\"edge\">\n<title>mu&#45;&gt;y</title>\n<path fill=\"none\" stroke=\"black\" d=\"M185.54,-165.83C179.18,-154.91 172.14,-142.82 165.48,-131.38\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"168.56,-129.73 160.51,-122.84 162.52,-133.25 168.56,-129.73\"/>\n</g>\n</g>\n</svg>\n",
"text/plain": [
"<graphviz.graphs.Digraph at 0x7c2622fd2300>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-31T19:10:48.074849Z",
"start_time": "2024-10-31T19:10:47.997525Z"
}
},
"cell_type": "code",
"source": [
"fgraph, _ = fgraph_from_model(m)\n",
"# fgraph.dprint()\n",
"_ = summary_stats_rewrite.apply(fgraph)\n",
"# fgraph.dprint()\n",
"\n",
"new_m = model_from_fgraph(fgraph)\n",
"new_m.to_graphviz()"
],
"id": "a75390e598ae4fe2",
"outputs": [
{
"data": {
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 12.0.0 (20240803.0821)\n -->\n<!-- Pages: 1 -->\n<svg width=\"264pt\" height=\"213pt\"\n viewBox=\"0.00 0.00 264.15 213.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 209)\">\n<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-209 260.15,-209 260.15,4 -4,4\"/>\n<!-- y_var -->\n<g id=\"node1\" class=\"node\">\n<title>y_var</title>\n<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"70.18\" cy=\"-42.25\" rx=\"50.03\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-54.82\" font-family=\"Times,serif\" font-size=\"14.00\">y_var</text>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-37.57\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-20.32\" font-family=\"Times,serif\" font-size=\"14.00\">Gamma</text>\n</g>\n<!-- sigma -->\n<g id=\"node2\" class=\"node\">\n<title>sigma</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"70.18\" cy=\"-162.75\" rx=\"70.18\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-175.32\" font-family=\"Times,serif\" font-size=\"14.00\">sigma</text>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-158.07\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"70.18\" y=\"-140.82\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n</g>\n<!-- sigma&#45;&gt;y_var -->\n<g id=\"edge3\" class=\"edge\">\n<title>sigma&#45;&gt;y_var</title>\n<path fill=\"none\" stroke=\"black\" d=\"M70.18,-120.3C70.18,-112.5 70.18,-104.24 70.18,-96.15\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"73.68,-96.39 70.18,-86.39 66.68,-96.39 73.68,-96.39\"/>\n</g>\n<!-- y_mean -->\n<g 
id=\"node4\" class=\"node\">\n<title>y_mean</title>\n<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"197.18\" cy=\"-42.25\" rx=\"49.5\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"197.18\" y=\"-54.82\" font-family=\"Times,serif\" font-size=\"14.00\">y_mean</text>\n<text text-anchor=\"middle\" x=\"197.18\" y=\"-37.57\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"197.18\" y=\"-20.32\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n</g>\n<!-- sigma&#45;&gt;y_mean -->\n<g id=\"edge1\" class=\"edge\">\n<title>sigma&#45;&gt;y_mean</title>\n<path fill=\"none\" stroke=\"black\" d=\"M107.6,-126.84C122.78,-112.67 140.38,-96.25 155.94,-81.73\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"158.31,-84.31 163.23,-74.92 153.53,-79.19 158.31,-84.31\"/>\n</g>\n<!-- mu -->\n<g id=\"node3\" class=\"node\">\n<title>mu</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"207.18\" cy=\"-162.75\" rx=\"48.97\" ry=\"42.25\"/>\n<text text-anchor=\"middle\" x=\"207.18\" y=\"-175.32\" font-family=\"Times,serif\" font-size=\"14.00\">mu</text>\n<text text-anchor=\"middle\" x=\"207.18\" y=\"-158.07\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n<text text-anchor=\"middle\" x=\"207.18\" y=\"-140.82\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n</g>\n<!-- mu&#45;&gt;y_mean -->\n<g id=\"edge2\" class=\"edge\">\n<title>mu&#45;&gt;y_mean</title>\n<path fill=\"none\" stroke=\"black\" d=\"M203.68,-120.3C203.02,-112.5 202.33,-104.24 201.64,-96.15\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"205.15,-96.05 200.82,-86.38 198.17,-96.64 205.15,-96.05\"/>\n</g>\n</g>\n</svg>\n",
"text/plain": [
"<graphviz.graphs.Digraph at 0x7c261da83c80>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-31T19:10:48.745053Z",
"start_time": "2024-10-31T19:10:48.076171Z"
}
},
"cell_type": "code",
"source": [
"# Confirm equivalent (up to an additive (in log space) normalization constant)\n",
"m_logp = m.compile_logp()\n",
"new_m_logp = new_m.compile_logp()\n",
"\n",
"ip = m.initial_point()\n",
"ip"
],
"id": "9ba2337e7ff5f6a3",
"outputs": [
{
"data": {
"text/plain": [
"{'mu': array(0.), 'sigma_log__': array(0.)}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 5
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-31T19:10:48.751046Z",
"start_time": "2024-10-31T19:10:48.746750Z"
}
},
"cell_type": "code",
"source": "m_logp(ip) - new_m_logp(ip)",
"id": "29f91a23ff4b0bf8",
"outputs": [
{
"data": {
"text/plain": [
"-711.5313632106429"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-31T19:10:48.758491Z",
"start_time": "2024-10-31T19:10:48.752459Z"
}
},
"cell_type": "code",
"source": [
"ip[\"mu\"] += 0.5\n",
"ip[\"sigma_log__\"] += 1.5\n",
"\n",
"m_logp(ip) - new_m_logp(ip)"
],
"id": "5af996f937779528",
"outputs": [
{
"data": {
"text/plain": [
"-711.5313632106386"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 7
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment