Skip to content

Instantly share code, notes, and snippets.

@brandonwillard
Created August 30, 2021 22:20
Show Gist options
  • Save brandonwillard/10f8bdfe1c307c590472f73500f6ae67 to your computer and use it in GitHub Desktop.
Some preliminary work to remove "constant" sub-graphs with minimal/no changes to multiplication order
from collections import defaultdict
import aesara.tensor as aet
from aesara import config
from aesara.compile import optdb
from aesara.graph.basic import ancestors, applys_between, graph_inputs
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import (
EquilibriumOptimizer,
GlobalOptimizer,
MergeFeature,
MergeOptimizer,
PatternSub,
local_optimizer,
)
from aesara.graph.optdb import OptimizationQuery
from aesara.printing import debugprint as dprint
from aesara.tensor.math_opt import copy_stack_trace, get_clients
# Disable C compilation: these experiments only build and inspect graphs,
# so we don't need to waste time compiling them to C.
config.cxx = ""
def optimize_graph(
    fgraph, include=("canonicalize",), custom_opt=None, clone=False, **kwargs
):
    """Apply optimization-database rewrites (and an optional extra optimizer) to a graph.

    Parameters
    ----------
    fgraph : Variable or FunctionGraph
        The graph to optimize.  A bare output variable is wrapped in a
        `FunctionGraph` first.
    include : sequence of str
        Tags passed to `OptimizationQuery` to select rewrites from `optdb`.
        (A tuple default is used instead of a list to avoid the shared
        mutable-default-argument pitfall.)
    custom_opt : GlobalOptimizer, optional
        An additional optimizer applied after the queried rewrites.
    clone : bool
        Whether to clone the graph when wrapping it in a `FunctionGraph`.

    Returns
    -------
    FunctionGraph
        The optimized graph (optimized in place).
    """
    if not isinstance(fgraph, FunctionGraph):
        inputs = list(graph_inputs([fgraph]))
        fgraph = FunctionGraph(inputs, [fgraph], clone=clone)

    canonicalize_opt = optdb.query(OptimizationQuery(include=include, **kwargs))
    _ = canonicalize_opt.optimize(fgraph)

    if custom_opt:
        custom_opt.optimize(fgraph)

    return fgraph
# Local rewrites that "expand" a graph toward a sum of products: divisions
# and subtractions become multiplications and additions, and `mul`/`dot`
# are distributed over `add`.  All six share the same `PatternSub` settings,
# so we build them through a single private factory.


def _expansion_pattern(in_pattern, out_pattern, name, tracks):
    """Construct a `PatternSub` with the settings common to all expansion rewrites."""
    return PatternSub(
        in_pattern,
        out_pattern,
        allow_multiple_clients=True,
        name=name,
        tracks=tracks,
        get_nodes=get_clients,
    )


# a / b -> a * 1/b, for a != 1 and b != 1
div_to_mul_pattern = _expansion_pattern(
    (aet.true_div, "a", "b"),
    (aet.mul, "a", (aet.reciprocal, "b")),
    "div_to_mul",
    [aet.true_div],
)

# a - b -> a + (-b)
sub_to_add_pattern = _expansion_pattern(
    (aet.sub, "a", "b"),
    (aet.add, "a", (aet.neg, "b")),
    "sub_to_add",
    [aet.sub],
)

# a * (x + y) -> a * x + a * y
distribute_mul_pattern_lhs = _expansion_pattern(
    (aet.mul, "a", (aet.add, "x", "y")),
    (aet.add, (aet.mul, "a", "x"), (aet.mul, "a", "y")),
    "distribute_mul_lhs",
    [aet.mul],
)

# (x + y) * a -> x * a + y * a
distribute_mul_pattern_rhs = _expansion_pattern(
    (aet.mul, (aet.add, "x", "y"), "a"),
    (aet.add, (aet.mul, "x", "a"), (aet.mul, "y", "a")),
    "distribute_mul_rhs",
    [aet.mul],
)

# a.dot(x + y) -> a.dot(x) + a.dot(y)
distribute_dot_pattern_lhs = _expansion_pattern(
    (aet.dot, "a", (aet.add, "x", "y")),
    (aet.add, (aet.dot, "a", "x"), (aet.dot, "a", "y")),
    "distribute_dot_lhs",
    [aet.dot],
)

# (x + y).dot(a) -> x.dot(a) + y.dot(a)
distribute_dot_pattern_rhs = _expansion_pattern(
    (aet.dot, (aet.add, "x", "y"), "a"),
    (aet.add, (aet.dot, "x", "a"), (aet.dot, "y", "a")),
    "distribute_dot_rhs",
    [aet.dot],
)
@local_optimizer([aet.add, aet.mul])
def local_add_mul_fusion(fgraph, node):
    """Flatten nested `add`s (or `mul`s) into a single node with more inputs."""
    op = node.op
    if op not in (aet.add, aet.mul):
        return False

    n_node_inputs = len(node.inputs)
    # Some `Op`s cap how many inputs a single node may take.
    input_limit = op.max_inputs(node) if hasattr(op, "max_inputs") else float("inf")

    flattened_inputs = []
    did_fuse = False
    for in_var in node.inputs:
        inner_node = in_var.owner
        if (
            inner_node
            and inner_node.op == op
            and (n_node_inputs + len(inner_node.inputs) - 1) <= input_limit
        ):
            flattened_inputs.extend(inner_node.inputs)
            did_fuse = True
        else:
            flattened_inputs.append(in_var)

    if not did_fuse:
        return None

    new_out = op(*flattened_inputs)
    copy_stack_trace(node.outputs[0], new_out)

    # Do the recursion here to help lower the number of `FusionOptimizer`
    # iterations.
    if new_out.owner:
        deeper_result = local_add_mul_fusion.transform(fgraph, new_out.owner)
        if deeper_result:
            return deeper_result

    return [new_out]
class ExpandOptimizer(EquilibriumOptimizer):
    """An equilibrium optimizer that expands a graph into a sum-of-products form.

    It repeatedly applies the division/subtraction conversions, the
    `mul`/`dot`-over-`add` distribution rewrites, and `add`/`mul` flattening
    until a fixed point is reached.
    """

    def __init__(self):
        expansion_rewrites = [
            div_to_mul_pattern,
            distribute_mul_pattern_lhs,
            distribute_mul_pattern_rhs,
            distribute_dot_pattern_lhs,
            distribute_dot_pattern_rhs,
            sub_to_add_pattern,
            local_add_mul_fusion,
        ]
        super().__init__(
            expansion_rewrites,
            ignore_newtrees=False,
            tracks_on_change_inputs=True,
            max_use_ratio=10000,
        )
def uniquify_terms(fgraph):
    """Unique-ify repeated terms in a graph.

    This is an easy way to perform manipulations on a copy of a graph and later
    determine which manipulated terms correspond to terms in the original
    graph.

    Consider this graph: ``a * (b + c) + log(a) * a``. If it were transformed
    into a gathered product, i.e. ``a * (b + c + log(a))``, and decided that the
    outer-most ``mul(a, ...)`` should be removed *in the original graph*, one
    would need to determine that the ``a`` in the first index of the
    ``mul(a, b + c)`` node and the ``a`` in the second index of the
    ``mul(log(a), a)`` node correspond to the outer-most ``mul(a, ...)`` in the
    transformed graph.

    This function replaces variables that are used in more than one place in a
    graph. In the example above, ``a * (b + c) + log(a) * a`` is converted to
    ``a_2 * (b + c) + log(a_1) * a_0`` and the map
    ``{a_0: (a, node(log(a_1) * a_0), 1), a_2: (a, node(a_2 * (b + c)), 0), ...}``
    is returned.

    The values in the map correspond to the replaced term, the node in which
    the variable was replaced, and the input index for the term that was
    replaced in the node (i.e. ``node.inputs[index] == term``). Together these
    values tell you which duplicate term was replaced and exactly where.
    """
    vars_to_clones = {}
    nodes_to_update = []
    replaced_nodes = set()

    if not hasattr(fgraph, "merge_feature"):
        fgraph.attach_feature(MergeFeature())

    # Merge duplicate sub-graphs first so that shared terms are actually
    # shared variables (and thus show up with multiple clients).
    MergeOptimizer().apply(fgraph)

    rev_ordered_vars = (
        sum([n.outputs for n in reversed(fgraph.toposort())], []) + fgraph.inputs
    )

    for var in rev_ordered_vars:
        clients = fgraph.clients[var]

        # Outputs of the "expandable" `Op`s are not unique-ified themselves.
        if var.owner and var.owner.op in (aet.mul, aet.add, aet.neg, aet.dot):
            continue

        # We don't want to recreate variables that we're trying to rewrite.
        # e.g. if we want to apply the rewrites `a -> a_0` and `log(a) ->
        # log_a_0`, then we can't replace `a` in the `log(a)` node;
        # otherwise, we would produce a new `log(a_0)` for which `log(a) ->
        # log_a_0` no longer applies.
        filtered_clients = [(c, i) for c, i in clients if c not in replaced_nodes]

        if len(filtered_clients) > 1:
            n = 0
            for c, i in filtered_clients:
                new_var = var.clone()
                vars_to_clones[(var, c, i)] = new_var
                # Remember which variable this clone stands in for.
                new_var.tag.root_var = var
                new_var.name = f"{var.name or var}_{n}"
                n += 1
                nodes_to_update.append(c)

            # XXX: This can be expensive and should be replaced with something
            # more "incremental"
            replaced_nodes |= set(applys_between(fgraph.inputs, [var]))

    # Now that we have all the unique-ified variables, we can recreate the
    # nodes that need to use them. We couldn't do that above, because some
    # nodes will contain more than one unique-ified variable, and we need to
    # replace them all at once (i.e. we don't want to keep recreating nodes
    # when we find new terms to unique-ify).
    replacements = {}
    for c in nodes_to_update:
        new_c_inputs = [
            vars_to_clones.get((inp, c, idx), inp) for idx, inp in enumerate(c.inputs)
        ]
        new_c = c.op.make_node(*new_c_inputs)
        for new_out, old_out in zip(new_c.outputs, c.outputs):
            replacements[old_out] = new_out

    # FIX: removed a leftover `breakpoint()` debug call here that would halt
    # execution whenever this function ran.

    # XXX: The replacements need to be applied inside the replaced terms, as
    # well.
    fgraph.replace_all(
        list(replacements.items()), reason="uniquify", import_missing=True
    )

    return {v: k for k, v in vars_to_clones.items()}
def test_uniquify_terms_1():
    """Check that only shared `log(a + b)` terms are unique-ified, not their interiors."""
    a, b, c, x = aet.scalars("abcx")

    shared_log = aet.log(a + b)
    first_mul = shared_log * x
    second_mul = shared_log * c
    out = first_mul * b + second_mul

    out_fg = FunctionGraph([a, b, c, x], [out], clone=False)
    vars_to_orig = uniquify_terms(out_fg)
    dprint(out_fg.outputs)

    # Only the shared `log` term should have been replaced; everything else
    # should be untouched.
    expected = {
        (shared_log, second_mul.owner, 0),
        (shared_log, first_mul.owner, 0),
    }
    assert set(vars_to_orig.values()) == expected
def test_uniquify_terms_2():
    """Exercise `uniquify_terms` on a graph with several repeated sub-terms."""
    a, b, c, x = aet.scalars("abcx")

    prod_1 = a * (x * b)
    prod_2 = a * c
    prod_2.name = "mul_2"
    prod_3 = a * b
    prod_3.name = "mul_3"
    sum_1 = a + prod_2
    out = prod_1 + sum_1 * aet.log(prod_3 * x + prod_2)

    out_fg = FunctionGraph([a, b, c, x], [out], clone=False)
    dprint(out_fg)

    vars_to_orig = uniquify_terms(out_fg)
    dprint(out_fg)
class RemoveMulConstants(GlobalOptimizer):
    """Remove multiplicative "constant" terms from a graph.

    A term is "constant" when it has no dependency on ``input_vars``.  The
    analysis is performed on an expanded, unique-ified *copy* of the graph,
    and the removals are then mapped back onto the original graph so that
    the remaining multiplication order is disturbed as little as possible.
    """

    def __init__(self, input_vars):
        # Variables with respect to which a term counts as "non-constant".
        self.input_vars = input_vars
        self.expand_opt = ExpandOptimizer()

    def add_requirements(self, fgraph):
        from aesara.graph.features import ReplaceValidate

        fgraph.attach_feature(ReplaceValidate())
        if not hasattr(fgraph, "merge_feature"):
            fgraph.attach_feature(MergeFeature())

    def apply(self, fgraph):
        # Make sure all duplicate terms are merged so that we can collect them
        MergeOptimizer().apply(fgraph)

        # We need to collapse and distribute terms across multiplication and
        # addition so that we have a simple graph from which to collect the
        # required information about terms (i.e. which multiplicative and
        # additive terms can be removed because they don't depend on
        # `input_vars`).
        #
        # This is performed on a _different_, intermediate graph
        # (i.e. `int_fgraph`), so we need a map that takes us from the terms in
        # the intermediate graph to the original graph (i.e. `fgraph`).
        int_fgraph, equiv = fgraph.clone_get_equiv()
        rev_equiv = {v: k for k, v in equiv.items()}

        # FIX: iterate `.items()` — iterating the `dict` returned by
        # `uniquify_terms` directly yields only its keys, which cannot be
        # unpacked into `(k, (v, n, i))` (a `TypeError` at runtime).
        vars_to_orig = {
            k: (rev_equiv[v], rev_equiv[n], i)
            for k, (v, n, i) in uniquify_terms(int_fgraph).items()
        }

        new_input_vars = {equiv[v] for v in self.input_vars}

        _ = self.expand_opt.optimize(int_fgraph)

        for trans_out in int_fgraph.outputs:
            trans_out_node = trans_out.owner

            if trans_out_node and trans_out_node.op == aet.add:
                # TODO: Get rid of addends with no connections to `input_vars`
                # I.e. `(a * b * ...) + (c * d * ...) + ...`

                if all(
                    i.owner and i.owner.op == aet.mul for i in trans_out_node.inputs
                ):
                    # We want the multiplicative terms that are present in
                    # every addend, but we need to use the "root_var"s for each
                    # term in these `mul`s, because every term was made unique
                    # above.
                    # To do this, we create a map from each root var to its
                    # unique-ified term(s) in each addend.
                    inputs_maps = []
                    for i in trans_out_node.inputs:
                        in_dict = defaultdict(list)
                        inputs_maps.append(in_dict)
                        for ii in i.owner.inputs:
                            # FIX: key on the root var and collect the
                            # unique-ified clones as values.  Keying on the
                            # clones themselves (as before) makes the
                            # intersection below empty, because each clone is
                            # distinct per client, and it also disagrees with
                            # the `(var, root_var)` unpacking further down.
                            in_dict[getattr(ii.tag, "root_var", ii)].append(ii)

                    # Root vars present in *every* addend.
                    distributed_terms = set.intersection(
                        *(set(im.keys()) for im in inputs_maps)
                    )

                    var_root_vars = set()
                    for dt in distributed_terms:
                        for m in inputs_maps:
                            var_root_vars |= {(a, dt) for a in m[dt]}

                    replacements = []
                    for var, root_var in var_root_vars:
                        if any(a in new_input_vars for a in ancestors([root_var])):
                            # We don't want to remove terms that contain the `input_vars`
                            # TODO: We should cache the results of this
                            # condition so they can be reused in the additive
                            # terms check
                            continue

                        # Here, we obtain the original node in which this
                        # distributed term resided.
                        # E.g. if `mul(a, add(b, c))` was rewritten to
                        # `add(mul(a, b), mul(a, c))`, then we need a reference
                        # to the old `mul(a, ...)` node so we can
                        # replace/remove it in the original `fgraph` (after
                        # remapping with `rev_equiv`).
                        (orig_var, orig_node, orig_in_idx) = vars_to_orig[var]

                        assert orig_node in fgraph.apply_nodes

                        # Create a new version of the node without this
                        # "constant" term from the fgraph
                        new_inputs = list(orig_node.inputs)
                        del new_inputs[orig_in_idx]

                        if len(new_inputs) == 1:
                            # Get rid of the `add`/`mul` entirely
                            new_out = new_inputs[0]
                        else:
                            # Remove the distributed argument from the `add`/`mul`
                            # XXX: This will only work for single-output nodes
                            new_out = orig_node.op.make_node(*new_inputs).outputs[0]

                        replacements.append((orig_node.outputs[0], new_out))

                    fgraph.replace_all(replacements)
def test_RemoveMulConstants():
    """End-to-end checks for `RemoveMulConstants` on several example graphs."""
    a, b, c, x = aet.scalars("abcx")

    graph = a * (x * b + c * aet.log(a * (b * x + c)))
    result_fg = optimize_graph(graph, custom_opt=RemoveMulConstants([x]))
    dprint(result_fg)

    # It should've removed the outer-most `mul(a, ...)`
    assert result_fg.outputs[0].owner.op == aet.add
    # It shouldn't have touched the rest
    assert result_fg.outputs[0] == graph.owner.inputs[1]

    graph = aet.log(x) * b + aet.log(x) * c
    result_fg = optimize_graph(graph, custom_opt=RemoveMulConstants([x]))
    dprint(result_fg)

    # Make sure that the `log(x)` terms haven't been removed
    first_mul = result_fg.outputs[0].owner.inputs[0].owner
    second_mul = result_fg.outputs[0].owner.inputs[1].owner
    assert x in first_mul.inputs[0].owner.inputs
    assert x in second_mul.inputs[0].owner.inputs

    graph = (aet.log(a + b) * x) * b + aet.log(a + b) * c
    result_fg = optimize_graph(graph, custom_opt=RemoveMulConstants([x]))
    dprint(result_fg)

    # Make sure that the `log(a)` terms have been removed
    mul_term = result_fg.outputs[0].owner.inputs[0].owner
    c_term = result_fg.outputs[0].owner.inputs[1]
    assert mul_term.op == aet.mul
    assert c_term is c

    graph = a * x * b + c * aet.log(a * (b * x + c)) * a
    result_fg = optimize_graph(graph, custom_opt=RemoveMulConstants([x]))
    dprint(result_fg)

    assert result_fg.outputs[0].owner.op == aet.add

    add_node = result_fg.outputs[0].owner
    # Make sure that `a` was removed from the inner-two `mul`s
    assert a not in add_node.inputs[0].owner.inputs
    assert a not in add_node.inputs[1].owner.inputs
    # Make sure `a` wasn't removed from the `log` term
    log_node = add_node.inputs[1].owner.inputs[1].owner
    assert a in log_node.inputs[0].owner.inputs

    graph = a * x * b + (a + a * c) * aet.log(a * b * x + a * c)
    result_fg = optimize_graph(graph, custom_opt=RemoveMulConstants([x]))
    dprint(result_fg)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment