Some preliminary work to remove "constant" sub-graphs with minimal/no changes to multiplication order
from collections import defaultdict

import aesara.tensor as aet
from aesara import config
from aesara.compile import optdb
from aesara.graph.basic import ancestors, applys_between, graph_inputs
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import (
    EquilibriumOptimizer,
    GlobalOptimizer,
    MergeFeature,
    MergeOptimizer,
    PatternSub,
    local_optimizer,
)
from aesara.graph.optdb import OptimizationQuery
from aesara.printing import debugprint as dprint
from aesara.tensor.math_opt import copy_stack_trace
# We don't need to waste time compiling graphs to C
config.cxx = ""


def optimize_graph(
    fgraph, include=["canonicalize"], custom_opt=None, clone=False, **kwargs
):
    if not isinstance(fgraph, FunctionGraph):
        inputs = list(graph_inputs([fgraph]))
        fgraph = FunctionGraph(inputs, [fgraph], clone=clone)

    canonicalize_opt = optdb.query(OptimizationQuery(include=include, **kwargs))
    _ = canonicalize_opt.optimize(fgraph)

    if custom_opt:
        custom_opt.optimize(fgraph)

    return fgraph
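

# A minimal usage sketch for the helper above (hedged: it assumes the Aesara
# version imported above, and the `_demo_*` name is illustrative, not part of
# the original gist).  Canonicalization should, e.g., drop a multiplication
# by the constant 1.0.
def _demo_optimize_graph():
    x = aet.scalar("x")
    fg = optimize_graph(x * 1.0)
    dprint(fg)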
# a / b -> a * 1/b, for a != 1 and b != 1
div_to_mul_pattern = PatternSub(
    (aet.true_div, "a", "b"),
    (aet.mul, "a", (aet.reciprocal, "b")),
    allow_multiple_clients=True,
    name="div_to_mul",
)
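

# Hedged sanity-check sketch for the pattern above (`_demo_*` is illustrative
# only): applying the rewrite via an `EquilibriumOptimizer` should turn
# `a / b` into `mul(a, reciprocal(b))`.
def _demo_div_to_mul():
    a, b = aet.scalars("ab")
    fg = FunctionGraph([a, b], [a / b], clone=False)
    EquilibriumOptimizer([div_to_mul_pattern], max_use_ratio=10).optimize(fg)
    dprint(fg)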
# a - b -> a + (-b)
sub_to_add_pattern = PatternSub(
    (aet.sub, "a", "b"),
    (aet.add, "a", (aet.neg, "b")),
    allow_multiple_clients=True,
    name="sub_to_add",
)

# a * (x + y) -> a * x + a * y
distribute_mul_pattern_lhs = PatternSub(
    (aet.mul, "a", (aet.add, "x", "y")),
    (aet.add, (aet.mul, "a", "x"), (aet.mul, "a", "y")),
    allow_multiple_clients=True,
    name="distribute_mul_lhs",
)

# (x + y) * a -> x * a + y * a
distribute_mul_pattern_rhs = PatternSub(
    (aet.mul, (aet.add, "x", "y"), "a"),
    (aet.add, (aet.mul, "x", "a"), (aet.mul, "y", "a")),
    allow_multiple_clients=True,
    name="distribute_mul_rhs",
)

# a.dot(x + y) -> a.dot(x) + a.dot(y)
distribute_dot_pattern_lhs = PatternSub(
    (aet.dot, "a", (aet.add, "x", "y")),
    (aet.add, (aet.dot, "a", "x"), (aet.dot, "a", "y")),
    allow_multiple_clients=True,
    name="distribute_dot_lhs",
)

# (x + y).dot(a) -> x.dot(a) + y.dot(a)
distribute_dot_pattern_rhs = PatternSub(
    (aet.dot, (aet.add, "x", "y"), "a"),
    (aet.add, (aet.dot, "x", "a"), (aet.dot, "y", "a")),
    allow_multiple_clients=True,
    name="distribute_dot_rhs",
)
@local_optimizer([aet.add, aet.mul])
def local_add_mul_fusion(fgraph, node):
    """Fuse nested `add`s or `mul`s into a single such node with more inputs."""
    if node.op not in (aet.add, aet.mul):
        return False

    outer_op = node.op

    new_inputs = []
    fused = False

    orig_inputs = len(node.inputs)

    max_inputs = float("inf")
    if hasattr(node.op, "max_inputs"):
        max_inputs = node.op.max_inputs(node)

    for inp in node.inputs:
        if (
            inp.owner
            and inp.owner.op == outer_op
            and (orig_inputs + len(inp.owner.inputs) - 1) <= max_inputs
        ):
            new_inputs.extend(inp.owner.inputs)
            fused = True
        else:
            new_inputs.append(inp)

    if fused:
        output = node.op(*new_inputs)
        copy_stack_trace(node.outputs[0], output)

        # Do the recursion here to help lower the number of `FusionOptimizer`
        # iterations.
        if output.owner:
            next_output = local_add_mul_fusion.transform(fgraph, output.owner)
            if next_output:
                return next_output

        return [output]
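

# Hedged sketch (`_demo_*` is illustrative only): `local_add_mul_fusion`
# should flatten nested `add`s/`mul`s, e.g. `add(add(a, b), c)` into a single
# `add(a, b, c)`.  The `transform` call mirrors the recursive call used
# inside the rewrite itself.
def _demo_add_mul_fusion():
    a, b, c = aet.scalars("abc")
    fg = FunctionGraph([a, b, c], [(a + b) + c], clone=False)
    new_outs = local_add_mul_fusion.transform(fg, fg.outputs[0].owner)
    dprint(new_outs)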
class ExpandOptimizer(EquilibriumOptimizer):
    def __init__(self):
        super().__init__(
            [
                div_to_mul_pattern,
                distribute_mul_pattern_lhs,
                distribute_mul_pattern_rhs,
                distribute_dot_pattern_lhs,
                distribute_dot_pattern_rhs,
                sub_to_add_pattern,
                local_add_mul_fusion,
            ],
            ignore_newtrees=False,
            tracks_on_change_inputs=True,
            max_use_ratio=10000,
        )
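

# Hedged usage sketch (`_demo_*` is illustrative only): `ExpandOptimizer`
# should distribute multiplication over addition, e.g. `a * (x + y)` into
# `add(mul(a, x), mul(a, y))`.
def _demo_expand():
    a, x, y = aet.scalars("axy")
    fg = FunctionGraph([a, x, y], [a * (x + y)], clone=False)
    ExpandOptimizer().optimize(fg)
    dprint(fg)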
def uniquify_terms(fgraph):
    """Unique-ify repeated terms in a graph.

    This is an easy way to perform manipulations on a copy of a graph and
    later determine which manipulated terms correspond to terms in the
    original graph.

    Consider this graph: ``a * (b + c) + log(a) * a``.  If it were
    transformed into a gathered product, i.e. ``a * (b + c + log(a))``, and
    it was decided that the outer-most ``mul(a, ...)`` should be removed *in
    the original graph*, one would need to determine that the ``a`` in the
    first index of the ``mul(a, b + c)`` node and the ``a`` in the second
    index of the ``mul(log(a), a)`` node correspond to the outer-most
    ``mul(a, ...)`` in the transformed graph.

    This function replaces variables that are used in more than one place in
    a graph.  In the example above, ``a * (b + c) + log(a) * a`` is converted
    to ``a_2 * (b + c) + log(a_1) * a_0`` and the map
    ``{a_0: (a, node(log(a_1) * a_0), 1), a_2: (a, node(a_2 * (b + c)), 0), ...}``
    is returned.

    The values in the map correspond to the replaced term, the node in which
    the variable was replaced, and the input index for the term that was
    replaced in the node (i.e. ``node.inputs[index] == term``).  Together
    these values tell you which duplicate term was replaced and exactly
    where.

    """
    vars_to_clones = {}
    nodes_to_update = []
    replaced_nodes = set()

    if not hasattr(fgraph, "merge_feature"):
        fgraph.attach_feature(MergeFeature())

    MergeOptimizer().apply(fgraph)

    rev_ordered_vars = (
        sum([n.outputs for n in reversed(fgraph.toposort())], []) + fgraph.inputs
    )

    for var in rev_ordered_vars:
        clients = fgraph.clients[var]

        if var.owner and var.owner.op in (aet.mul, aet.add, aet.neg, aet.dot):
            continue

        # We don't want to recreate variables that we're trying to rewrite.
        # E.g. if we want to apply the rewrites `a -> a_0` and `log(a) ->
        # log_a_0`, then we can't replace `a` in the `log(a)` node;
        # otherwise, we would produce a new `log(a_0)` for which `log(a) ->
        # log_a_0` no longer applies.
        filtered_clients = [(c, i) for c, i in clients if c not in replaced_nodes]

        if len(filtered_clients) > 1:
            for n, (c, i) in enumerate(filtered_clients):
                new_var = var.clone()
                vars_to_clones[(var, c, i)] = new_var
                new_var.tag.root_var = var
                new_var.name = f"{var.name or var}_{n}"
                nodes_to_update.append(c)

            # XXX: This can be expensive and should be replaced with
            # something more "incremental"
            replaced_nodes |= set(applys_between(fgraph.inputs, [var]))

    # Now that we have all the unique-ified variables, we can recreate the
    # nodes that need to use them.  We couldn't do that above, because some
    # nodes will contain more than one unique-ified variable, and we need to
    # replace them all at once (i.e. we don't want to keep recreating nodes
    # when we find new terms to unique-ify).
    replacements = {}
    for c in nodes_to_update:
        new_c_inputs = [
            vars_to_clones.get((inp, c, idx), inp) for idx, inp in enumerate(c.inputs)
        ]
        new_c = c.op.make_node(*new_c_inputs)

        for new_out, old_out in zip(new_c.outputs, c.outputs):
            replacements[old_out] = new_out

    # XXX: The replacements need to be inside the replaced terms, as well.
    fgraph.replace_all(
        list(replacements.items()), reason="uniquify", import_missing=True
    )

    return {v: k for k, v in vars_to_clones.items()}
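

# Hedged sketch of the docstring example above (`_demo_*` is illustrative
# only): `uniquify_terms` should clone each repeated `a` and return a map
# from each clone back to the original variable and the client node/index
# where it was replaced.
def _demo_uniquify_terms():
    a, b, c = aet.scalars("abc")
    fg = FunctionGraph([a, b, c], [a * (b + c) + aet.log(a) * a], clone=False)
    for clone, (orig, node, idx) in uniquify_terms(fg).items():
        print(f"{clone} -> {orig} at input {idx} of {node}")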
def test_uniquify_terms_1():
    a, b, c, x = aet.scalars("abcx")

    # log_1 = aet.log(a + x)
    # mul_1 = aet.mul(a, x, b, log_1)
    # mul_2 = a * log_1
    # mul_3 = aet.mul(a, c, log_1)
    # f_x = aet.add(
    #     mul_1,
    #     mul_2,
    #     mul_3,
    # )
    # f_x = FunctionGraph([a, b, c, x], [f_x], clone=False)
    # vars_to_orig = uniquify_terms(f_x)
    # dprint(f_x)
    # # TODO: Add some `assert`s
    # assert tuple(vars_to_orig.values()) == ((log_1, mul_1.owner, 0), (log_1, mul_1.owner, 0))

    # This example should confirm that the following `log(a + b)`
    # terms will be unique-ified--and *not* their interiors
    log_1 = aet.log(a + b)
    mul_1 = log_1 * x
    mul_2 = log_1 * c
    f_x = mul_1 * b + mul_2

    f_x = FunctionGraph([a, b, c, x], [f_x], clone=False)

    vars_to_orig = uniquify_terms(f_x)

    dprint(f_x.outputs)

    # These should not be touched; only the log terms should
    assert set(vars_to_orig.values()) == set(
        ((log_1, mul_2.owner, 0), (log_1, mul_1.owner, 0))
    )
def test_uniquify_terms_2():
    a, b, c, x = aet.scalars("abcx")

    mul_1 = a * (x * b)
    mul_2 = a * c
    mul_2.name = "mul_2"
    mul_3 = a * b
    mul_3.name = "mul_3"
    add_1 = a + mul_2
    f_x = mul_1 + add_1 * aet.log(mul_3 * x + mul_2)

    f_x = FunctionGraph([a, b, c, x], [f_x], clone=False)

    dprint(f_x)

    # f_x.replace(mul_2, x)
    # dprint(f_x)

    vars_to_orig = uniquify_terms(f_x)

    dprint(f_x)
class RemoveMulConstants(GlobalOptimizer):
    def __init__(self, input_vars):
        self.input_vars = input_vars
        self.expand_opt = ExpandOptimizer()

    def add_requirements(self, fgraph):
        from aesara.graph.features import ReplaceValidate

        fgraph.attach_feature(ReplaceValidate())

        if not hasattr(fgraph, "merge_feature"):
            fgraph.attach_feature(MergeFeature())

    def apply(self, fgraph):
        # Make sure all duplicate terms are merged so that we can collect them
        MergeOptimizer().apply(fgraph)

        # We need to collapse and distribute terms across multiplication and
        # addition so that we have a simple graph from which to collect the
        # required information about terms (i.e. which multiplicative and
        # additive terms can be removed because they don't depend on
        # `input_vars`).
        # This is performed on a _different_, intermediate graph
        # (i.e. `int_fgraph`), so we need a map that takes us from the terms
        # in the intermediate graph to the original graph (i.e. `fgraph`).
        int_fgraph, equiv = fgraph.clone_get_equiv()
        rev_equiv = {v: k for k, v in equiv.items()}

        vars_to_orig = {
            k: (rev_equiv[v], rev_equiv[n], i)
            for k, (v, n, i) in uniquify_terms(int_fgraph).items()
        }

        new_input_vars = {equiv[v] for v in self.input_vars}

        _ = self.expand_opt.optimize(int_fgraph)

        for trans_out in int_fgraph.outputs:
            trans_out_node = trans_out.owner

            if trans_out_node and trans_out_node.op == aet.add:
                # TODO: Get rid of addends with no connections to `input_vars`
                # I.e. `(a * b * ...) + (c * d * ...) + ...`

                if all(
                    i.owner and i.owner.op == aet.mul for i in trans_out_node.inputs
                ):
                    # We want the multiplicative terms that are present in
                    # every addend, but we need to use the "root_var"s for
                    # each term in these `mul`s, because every term was made
                    # unique above.
                    # To do this, we create a map from each multiplicative
                    # term to its `root_var`(s).
                    inputs_maps = []
                    for i in trans_out_node.inputs:
                        in_dict = defaultdict(list)
                        inputs_maps.append(in_dict)
                        for ii in i.owner.inputs:
                            in_dict[ii].append(getattr(ii.tag, "root_var", ii))

                    distributed_terms = set.intersection(
                        *(set(im.keys()) for im in inputs_maps)
                    )

                    var_root_vars = set()
                    for dt in distributed_terms:
                        for m in inputs_maps:
                            var_root_vars |= {(dt, a) for a in m[dt]}

                    replacements = []
                    for var, root_var in var_root_vars:
                        if any(a in new_input_vars for a in ancestors([root_var])):
                            # We don't want to remove terms that contain the
                            # `input_vars`
                            # TODO: We should cache the results of this
                            # condition so they can be reused in the additive
                            # terms check
                            continue

                        # Here, we obtain the original node in which this
                        # distributed term resided.
                        # E.g. if `mul(a, add(b, c))` was rewritten to
                        # `add(mul(a, b), mul(a, c))`, then we need a
                        # reference to the old `mul(a, ...)` node so we can
                        # replace/remove it in the original `fgraph` (after
                        # remapping with `rev_equiv`).
                        (orig_var, orig_node, orig_in_idx) = vars_to_orig[var]

                        assert orig_node in fgraph.apply_nodes

                        # Create a new version of the node containing this
                        # "constant" term from the fgraph
                        new_inputs = list(orig_node.inputs)
                        del new_inputs[orig_in_idx]

                        if len(new_inputs) == 1:
                            # Get rid of the `add`/`mul` entirely
                            new_out = new_inputs[0]
                        else:
                            # Remove the distributed argument from the
                            # `add`/`mul`
                            # XXX: This will only work for single-output nodes
                            new_out = orig_node.op.make_node(*new_inputs).outputs[0]

                        replacements.append((orig_node.outputs[0], new_out))

                    fgraph.replace_all(replacements)
def test_RemoveMulConstants():
    a, b, c, x = aet.scalars("abcx")

    f_x = a * (x * b + c * aet.log(a * (b * x + c)))

    f_x_fg = optimize_graph(f_x, custom_opt=RemoveMulConstants([x]))

    dprint(f_x_fg)

    # It should've removed the outer-most `mul(a, ...)`
    assert f_x_fg.outputs[0].owner.op == aet.add
    # It shouldn't have touched the rest
    assert f_x_fg.outputs[0] == f_x.owner.inputs[1]

    f_x = aet.log(x) * b + aet.log(x) * c

    f_x_fg = optimize_graph(f_x, custom_opt=RemoveMulConstants([x]))

    dprint(f_x_fg)

    # Make sure that the `log(x)` terms haven't been removed
    mul1 = f_x_fg.outputs[0].owner.inputs[0].owner
    mul2 = f_x_fg.outputs[0].owner.inputs[1].owner
    assert x in mul1.inputs[0].owner.inputs
    assert x in mul2.inputs[0].owner.inputs

    f_x = (aet.log(a + b) * x) * b + aet.log(a + b) * c

    f_x_fg = optimize_graph(f_x, custom_opt=RemoveMulConstants([x]))

    dprint(f_x_fg)

    # Make sure that the `log(a + b)` terms have been removed
    mul_term = f_x_fg.outputs[0].owner.inputs[0].owner
    c_term = f_x_fg.outputs[0].owner.inputs[1]
    assert mul_term.op == aet.mul
    assert c_term is c

    f_x = a * x * b + c * aet.log(a * (b * x + c)) * a

    f_x_fg = optimize_graph(f_x, custom_opt=RemoveMulConstants([x]))

    dprint(f_x_fg)

    assert f_x_fg.outputs[0].owner.op == aet.add

    add_node = f_x_fg.outputs[0].owner
    # Make sure that `a` was removed from the inner-two `mul`s
    assert a not in add_node.inputs[0].owner.inputs
    assert a not in add_node.inputs[1].owner.inputs
    # Make sure `a` wasn't removed from the `log` term
    log_node = add_node.inputs[1].owner.inputs[1].owner
    assert a in log_node.inputs[0].owner.inputs

    f_x = a * x * b + (a + a * c) * aet.log(a * b * x + a * c)

    # dprint(optimize_graph(f_x, custom_opt=ExpandOptimizer()))

    f_x_fg = optimize_graph(f_x, custom_opt=RemoveMulConstants([x]))

    dprint(f_x_fg)