Last active
January 1, 2021 11:00
-
-
Save twiecki/e758db2c3d2df5f3368fc49e6087e58f to your computer and use it in GitHub Desktop.
Skeleton for writing a Theano graph optimizer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
import numpy as np
import theano.tensor as tt
from theano import config
from theano.compile import optdb
from theano.gof.fg import FunctionGraph
from theano.gof.graph import inputs as tt_inputs
from theano.gof.opt import EquilibriumOptimizer, GlobalOptimizer, PatternSub
from theano.gof.optdb import Query
from theano.gof.toolbox import ReplaceValidate
from theano.printing import debugprint as tt_dprint
from theano.scalar import float64, add, mul, true_div
from theano.tensor.opt import get_clients

# We don't need to waste time compiling graphs to C
config.cxx = ""


# a / b -> a * inv(b)   (applies to any a, b; no constants are excluded)
div_to_mul_pattern = PatternSub(
    (tt.true_div, "a", "b"),
    (tt.mul, "a", (tt.inv, "b")),
    allow_multiple_clients=True,
    name="div_to_mul",
    tracks=[tt.true_div],
    get_nodes=get_clients,
)

# a - b -> a + (-b)
sub_to_add_pattern = PatternSub(
    (tt.sub, "a", "b"),
    (tt.add, "a", (tt.neg, "b")),
    allow_multiple_clients=True,
    name="sub_to_add",
    tracks=[tt.sub],
    get_nodes=get_clients,
)

# a * (x + y) -> a * x + a * y
distribute_mul_pattern = PatternSub(
    (tt.mul, "a", (tt.add, "x", "y")),
    (tt.add, (tt.mul, "a", "x"), (tt.mul, "a", "y")),
    allow_multiple_clients=True,
    name="distribute_mul",
    tracks=[tt.mul],
    get_nodes=get_clients,
)


class RemoveNormalizingConstants(GlobalOptimizer):
    """Work-in-progress global optimizer that finds additive terms involving a target input.

    For every ``add`` node in the graph, each operand's sub-graph is searched
    for a root input named ``target_name``. At the moment matches are only
    reported; the intended next step (see the TODO below) is to drop the
    operands that do *not* depend on the target, i.e. normalizing constants.

    Parameters
    ----------
    target_name : str, default "value"
        Name of the graph input whose presence marks an additive term to keep.
    """

    def __init__(self, target_name="value"):
        super().__init__()
        self.target_name = target_name

    def add_requirements(self, fgraph):
        # ReplaceValidate provides `fgraph.replace_validate`, needed once this
        # pass actually rewrites the graph.
        fgraph.attach_feature(ReplaceValidate())

    def _depends_on_target(self, var):
        """Return True if any root input of ``var``'s graph is named ``target_name``."""
        return any(inp.name == self.target_name for inp in tt_inputs([var]))

    def apply(self, fgraph):
        for node in fgraph.toposort():
            if node.op != tt.add:
                continue
            # `add` can be n-ary (canonicalization merges nested adds), so look
            # at every operand, and search each operand's whole sub-graph rather
            # than only its immediate name.
            if any(self._depends_on_target(term) for term in node.inputs):
                print(f"{self.target_name} found in subgraph")
                # TODO: mark the terms that depend on the target as "keep" and
                # remove the remaining (constant) terms.


# Apply the expansion rewrites, plus the global pass above, until the graph
# stops changing.
expand_opt = EquilibriumOptimizer(
    [div_to_mul_pattern, distribute_mul_pattern, sub_to_add_pattern, RemoveNormalizingConstants()],
    ignore_newtrees=False,
    tracks_on_change_inputs=True,
    max_use_ratio=config.optdb__max_use_ratio,
)


def optimize_graph(fgraph, include=("canonicalize",), custom_opt=None, **kwargs):
    """Run the ``optdb`` optimizations selected by ``include``, then an optional custom optimizer.

    Parameters
    ----------
    fgraph : FunctionGraph or Variable
        Graph to optimize. A bare output variable is wrapped in a
        ``FunctionGraph`` over its root inputs with ``clone=False``, so the
        rewrites act on the caller's own variables.
    include : sequence of str, default ("canonicalize",)
        Optimization tags passed to ``Query(include=...)``.
    custom_opt : optimizer, optional
        Extra optimizer applied after the queried ones.
    **kwargs
        Forwarded to ``Query`` (e.g. ``exclude``, ``require``).

    Returns
    -------
    FunctionGraph
        The optimized graph (optimized in place).
    """
    if not isinstance(fgraph, FunctionGraph):
        inputs = tt_inputs([fgraph])
        fgraph = FunctionGraph(inputs, [fgraph], clone=False)

    canonicalize_opt = optdb.query(Query(include=list(include), **kwargs))
    canonicalize_opt.optimize(fgraph)

    if custom_opt is not None:
        custom_opt.optimize(fgraph)

    return fgraph


tau = tt.dscalar("tau")
value = tt.dscalar("value")
mu = tt.dscalar("mu")

# Normal log-density in precision form: (log(tau / (2*pi)) - tau * (value - mu)**2) / 2
logp = (-tau * (value - mu) ** 2 + tt.log(tau / np.pi / 2.0)) / 2.0

logp_fg = optimize_graph(logp, custom_opt=expand_opt)

tt_dprint(logp_fg)

# TODO: Remove additive terms that do not involve `value` (e.g. the
# normalizing constant `log(tau / (2 * pi)) / 2`, which depends only on `tau`).

# This is what we want from the optimization
# logp_goal = -tau * (value - mu) ** 2
# NOTE(review): the goal omits the constant factor 1/2 present in `logp` --
# confirm whether multiplicative constants should be dropped as well.
"execution_count": 41, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "value found in subgraph\nElemwise{add,no_inplace} [id A] '' 8\n |Elemwise{mul,no_inplace} [id B] '' 7\n | |TensorConstant{0.5} [id C]\n | |Elemwise{mul,no_inplace} [id D] '' 6\n | |TensorConstant{-1.0} [id E]\n | |tau [id F]\n | |Elemwise{pow,no_inplace} [id G] '' 5\n | |Elemwise{add,no_inplace} [id H] '' 4\n | | |value [id I]\n | | |Elemwise{neg,no_inplace} [id J] '' 3\n | | |mu [id K]\n | |TensorConstant{2} [id L]\n |Elemwise{mul,no_inplace} [id M] '' 2\n |TensorConstant{0.5} [id C]\n |Elemwise{log,no_inplace} [id N] '' 1\n |Elemwise{mul,no_inplace} [id O] '' 0\n |TensorConstant{0.15915494309189535} [id P]\n |tau [id F]\n", | |
"name": "stdout" | |
} | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"name": "pymc3theano", | |
"display_name": "pymc3theano", | |
"language": "python" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.8.5", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"gist": { | |
"id": "e758db2c3d2df5f3368fc49e6087e58f", | |
"data": { | |
"description": "Skeleton to write graph optimizer", | |
"public": true | |
} | |
}, | |
"_draft": { | |
"nbviewer_url": "https://gist.github.com/e758db2c3d2df5f3368fc49e6087e58f" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment