Created
July 29, 2018 22:00
-
-
Save JonathanRaiman/f798aa2a0afcab72fd473e65921ffcf9 to your computer and use it in GitHub Desktop.
Memoized Graph Optimization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contextlib import contextmanager | |
import time | |
# Stack of currently active scope labels; `scope()` joins them into a path.
CURRENT_SCOPE = []


@contextmanager
def printing_scope(message):
    """Push ``message`` onto ``CURRENT_SCOPE`` for the body of a ``with`` block.

    The label is popped again when the block exits. The pop happens in a
    ``finally`` clause so that an exception raised inside the block does not
    leave a stale label on the stack.

    Args:
        message: label to push for the duration of the block.

    Raises:
        AssertionError: if the label popped on exit is not ``message``
            (i.e. the stack was modified in an unbalanced way inside the block).
    """
    CURRENT_SCOPE.append(message)
    try:
        yield
    finally:
        last = CURRENT_SCOPE.pop()
        assert last == message
def scope():
    """Return the active scope labels joined by "/", in the order they were pushed."""
    separator = "/"
    return separator.join(CURRENT_SCOPE)
# Truthy while code runs inside `optimization_scope()`; None at import time.
OPTIMIZATION_SCOPE = None


@contextmanager
def optimization_scope():
    """Set the global ``OPTIMIZATION_SCOPE`` flag to True inside a ``with`` block.

    The previous value is saved on entry and restored on exit, so scopes may be
    nested. Restoration happens in a ``finally`` clause so that an exception
    raised inside the block cannot leave the flag stuck at True.
    """
    global OPTIMIZATION_SCOPE
    previous_scope = OPTIMIZATION_SCOPE
    OPTIMIZATION_SCOPE = True
    try:
        yield
    finally:
        OPTIMIZATION_SCOPE = previous_scope
class Array(object):
    """Mutable handle around an expression node.

    An ``Array`` holds a single expression (or ``None``) that can be swapped
    out with ``set_expression``. The ``is_*`` predicates classify the wrapped
    expression by its class.
    """

    def __init__(self, expression):
        self._expression = expression
        # print(scope(), self)

    # -- access to the wrapped expression ---------------------------------

    def expression(self):
        """Return the wrapped expression (may be ``None``)."""
        return self._expression

    def set_expression(self, expression):
        """Replace the wrapped expression in place."""
        self._expression = expression

    def is_stateless(self):
        """Return True when no expression is attached."""
        return self._expression is None

    def shape(self):
        """Return the shape reported by the wrapped expression."""
        return self._expression.shape()

    def buffer_arg(self):
        """Wrap the expression's ``buffer_arg()`` in a fresh ``Array``."""
        return Array(self._expression.buffer_arg())

    # -- classification predicates ----------------------------------------

    def is_buffer(self):
        return isinstance(self._expression, Buffer)

    def is_jit_node(self):
        return isinstance(self._expression, JITNode)

    def is_jit_runner(self):
        return isinstance(self._expression, JITRunner)

    def is_assignment(self):
        return isinstance(self._expression, Assignment)

    def is_scalar(self):
        # Always False in this file.
        return False

    def is_available_buffer(self):
        """Return True for a buffer whose expression has no arguments."""
        if not self.is_buffer():
            return False
        return len(self._expression.arguments()) == 0

    def is_assignable(self):
        return self._expression.is_assignable()

    # -- printing ----------------------------------------------------------

    def __str__(self):
        return str(self._expression)

    def __repr__(self):
        return str(self)
class Expression(object):
    """Base class for graph nodes: a list of arguments plus a scope flag.

    On construction the node records the value of the module-level
    ``OPTIMIZATION_SCOPE`` flag at that moment in ``_optimization_scope``.
    """

    def __init__(self, arguments):
        # Attribute order is kept as-is: `compare` checks `__dict__` key order.
        self._optimization_scope = OPTIMIZATION_SCOPE
        self._arguments = arguments

    def arguments(self):
        """Return the argument list (the same list object, not a copy)."""
        return self._arguments

    def shape(self):
        """Return the placeholder shape string used by the base class."""
        return "Shape"

    def is_assignable(self):
        return False

    def supports_operator(self, op):
        """Return True only for the plain assignment operator ``"="``."""
        return bool(op == "=")

    def __str__(self):
        rendered_args = ", ".join(str(arg) for arg in self._arguments)
        return "{}({})".format(type(self).__name__, rendered_args)

    def __repr__(self):
        return str(self)
class JITNode(Expression):
    """Marker subclass of ``Expression``; ``Array.is_jit_node`` tests for it."""
class Buffer(Expression):
    """Assignable leaf expression with no arguments.

    Args:
        shape: value returned by ``shape()``.
        shape_fun: expression the shape was taken from (set by
            ``to_assignment`` when inside an optimization scope); used by
            ``rebuild_instructions`` to locate the shape source in a graph.
    """

    def __init__(self, shape=None, shape_fun=None):
        super(Buffer, self).__init__([])
        self._shape = shape
        self.shape_fun = shape_fun

    def shape(self):
        return self._shape

    def is_assignable(self):
        return True

    def buffer_arg(self):
        """A buffer is its own buffer argument."""
        return self

    def rebuild_instructions(self, source_graph):
        """Return a callable that recreates this buffer on another graph.

        The path to ``shape_fun`` inside ``source_graph`` is looked up once
        here; the returned callable follows that path in ``new_graph`` and
        builds a new buffer with the shape found there. ``args`` is ignored.
        """
        buffer_shape_location = find_node(self.shape_fun, source_graph)
        return lambda new_graph, args: Array(
            Buffer(
                shape=recover_node(new_graph, buffer_shape_location, 0).shape()
            )
        )
class Add(JITNode):
    """JIT node with two arguments: ``left`` at index 0, ``right`` at index 1."""

    def __init__(self, left, right):
        operands = [left, right]
        super(Add, self).__init__(operands)

    def left(self):
        return self.arguments()[0]

    def right(self):
        return self.arguments()[1]
class Assignment(Expression):
    """Expression storing ``right`` into ``left`` with operator ``operator_t``.

    The two sides are kept as arguments 0 (left) and 1 (right).
    """

    def __init__(self, left, operator_t, right):
        super(Assignment, self).__init__([left, right])
        self._operator_t = operator_t

    def left(self):
        return self.arguments()[0]

    def right(self):
        return self.arguments()[1]

    def is_assignable(self):
        return True

    def buffer_arg(self):
        """Return the buffer argument of the left-hand side's expression."""
        destination = self.left().expression()
        return destination.buffer_arg()

    def rebuild_instructions(self, source_graph):
        """Return a callable rebuilding this assignment from rebuilt arguments.

        ``source_graph`` is unused; the callable takes ``(new_graph, args)``
        and uses ``args[0]`` / ``args[1]`` as the new left / right sides.
        """
        operator = self._operator_t
        return lambda new_graph, args: Array(
            Assignment(args[0], operator, args[1])
        )
class JITRunner(JITNode):
    """JIT node bundling a root operation, its leaves, an operator and a destination.

    ``leaves`` become the node's arguments; ``root`` and ``dest`` are stored
    separately as ``_root`` and ``_dest``.
    """

    def __init__(self, root, leaves, operator_t, dest):
        super(JITRunner, self).__init__(leaves)
        # Attribute order is kept as-is: `compare` checks `__dict__` key order.
        self._dest = dest
        self._root = root
        self._operator_t = operator_t

    def __str__(self):
        rendered_leaves = ", ".join(str(arg) for arg in self._arguments)
        return "JITRunner[{}]({})".format(self._root, rendered_leaves)

    def rebuild_instructions(self, source_graph):
        """Return a callable that recreates this runner on another graph.

        Transformation graphs for the root and the destination are computed
        once against ``source_graph``; the returned callable replays both on
        ``new_graph`` and uses ``args`` as the new leaves.
        """
        root_pointers = construct_pointer_tree(self._root)
        root_tgraph = transformation_graph(source_graph, root_pointers)
        dest_pointers = construct_pointer_tree(self._dest)
        dest_tgraph = transformation_graph(source_graph, dest_pointers)
        operator = self._operator_t

        def rebuild(new_graph, args):
            new_root = transform_with_graph(new_graph, root_tgraph)
            new_dest = transform_with_graph(new_graph, dest_tgraph)
            return Array(JITRunner(new_root, args, operator, new_dest))

        return rebuild
class Identity(JITNode):
    """JIT node with a single argument."""

    def __init__(self, arg):
        super(Identity, self).__init__([arg])
def to_assignment(node):
    """Wrap ``node`` in an ``"="`` assignment to a freshly created buffer.

    The buffer takes its shape from ``node``. Inside an optimization scope the
    buffer additionally remembers ``node``'s expression as ``shape_fun``.
    The right-hand side is a new ``Array`` sharing ``node``'s expression.
    """
    with printing_scope("to_assignment"):
        buffer_kwargs = {"shape": node.shape()}
        if OPTIMIZATION_SCOPE:
            buffer_kwargs["shape_fun"] = node.expression()
        destination = Array(Buffer(**buffer_kwargs))
        source = Array(node.expression())
        return Array(Assignment(destination, "=", source))
def right_args(node):
    """Return the arguments of the expression on the right side of ``node``.

    ``node`` is expected to wrap an expression exposing ``right()``.
    """
    assignment = node.expression()
    right_side = assignment.right()
    return right_side.expression().arguments()
def identity(node):
    """Return a new ``Array`` wrapping an ``Identity`` node around ``node``."""
    wrapped = Identity(node)
    return Array(wrapped)
def jit_merge_arguments(arguments, leaves):
    """Fold the arguments of a JIT operation into it, collecting its leaves.

    Both parameters are mutated in place.

    Args:
        arguments: list of ``Array`` arguments of the operation being merged.
            Entries that are assignments of a JIT node are replaced by the
            JIT operation itself.
        leaves: list receiving every ``Array`` that must remain an input of
            the merged operation.

    Raises:
        ValueError: propagated from ``replace_assign_with_inplace`` when an
            assignment uses an operator other than ``"="``.
    """
    for arg_idx, arg in enumerate(arguments):
        if arg.is_assignment() and arg.expression().right().is_jit_node():
            right = arg.expression().right()
            if right.is_jit_runner():
                # grab leaves from existing jit-runner recursively:
                leaves.extend(right.expression().arguments())
            # if the node is an assignment to a buffer, ensure that
            # the assignment op gets included within this op
            # (e.g. by spoofing the assignment and replacing it with
            # the equivalent JIT op)
            replaced, left_leaf = replace_assign_with_inplace(arg)
            # if the assignment involves using the left-side (e.g.
            # left += right -> left + right), then keep the left node
            # as a dependency leaf:
            if not left_leaf.is_stateless():
                leaves.append(left_leaf)
            # now that the jitrunners and assignments are gone, connect
            # up the new operation in the graph:
            arguments[arg_idx] = Array(replaced.expression())
        elif arg.is_assignment() or arg.is_buffer():
            # The method must be *called* here: the bare bound method
            # `arg.is_assignment` is always truthy, which sent every remaining
            # argument down this branch and made the branches below dead code.
            #
            # detach the assignment subgraph and only keep the left node(bufferview)
            leaf_arg = Array(arg.expression())
            # both assignments and buffers get stripped of their data. If a buffer
            # had control dependencies, those are temporarily dropped to build the JIT graph
            arg.set_expression(arg.expression().buffer_arg())
            leaves.append(leaf_arg)
        elif not arg.is_jit_node():
            # Neither an assignment, a buffer, nor a JIT node: keep the
            # argument as an input of the merged operation. (A bare JIT node
            # is already part of the operation and needs nothing here.)
            leaves.append(arg)
def jit_root(array):
    """Return the ``_root`` of a JIT runner, or ``array`` itself otherwise."""
    return array.expression()._root if array.is_jit_runner() else array
def replace_assign_with_inplace(node):
    """Replace an assignment by the operation it assigns.

    Args:
        node: ``Array`` wrapping an assignment.

    Returns:
        A ``(replacement, left_leaf)`` tuple. For ``"="`` the replacement is
        the right-hand side (unwrapped to its root when it is a JIT runner)
        and ``left_leaf`` is a stateless ``Array(None)`` because the left side
        is not read.

    Raises:
        ValueError: for any operator other than ``"="``.
    """
    assign = node.expression()
    operator_t = assign._operator_t
    if operator_t != "=":
        # The commented-out C++ that used to follow this function sketches the
        # missing cases: for +=, -=, *= and /= return
        # (op(assign.left(), rightside), assign.left()) so the left side stays
        # a dependency leaf. None of them is implemented here.
        raise ValueError(
            "No way to replace_assign_with_inplace using operator {!r}.".format(
                operator_t))
    rightside = jit_root(assign.right())
    # in cases where assignment was an implicit cast, ensure that type gets
    # assigned to new destination
    if rightside.is_scalar() and not node.is_scalar():
        # NOTE(review): `jit_tile_scalar` is not defined in the visible file.
        # This branch cannot run while `Array.is_scalar` always returns False.
        rightside = jit_tile_scalar(rightside)
    return (rightside, Array(None))
def jit_merge(root):
    """Fuse the right-hand side of an assignment into a single ``JITRunner``.

    Args:
        root: ``Array`` wrapping an assignment whose left side is assignable.

    Returns:
        A new ``Array`` wrapping an assignment with the same left side and
        operator, whose right side is a ``JITRunner`` built from the merged
        operation, the collected leaves, and the destination buffer.

    Raises:
        AssertionError: if the left side is not assignable or the assignment
            has no buffer destination.
    """
    leaves = []
    assign = root.expression()
    assert assign.left().is_assignable(), "Assignment destination is not assignable"
    root_buffer = root.buffer_arg()
    assert not root_buffer.is_stateless(), (
        "Assignment destination for JIT assignment does not contain a valid "
        "buffer destination (check if left side of the assignment is assignable).")
    root_operator = assign._operator_t
    new_root = Array(assign.right().expression())
    if new_root.is_assignment() or new_root.is_buffer():
        # The right side is itself storage: keep it as a leaf and route its
        # buffer through an Identity node so there is a JIT op to run.
        leaves.append(new_root)
        new_root = identity(new_root.buffer_arg())
    # Both calls mutate the argument lists and `leaves` in place.
    jit_merge_arguments(new_root.expression().arguments(), leaves)
    jit_merge_arguments(root_buffer.expression().arguments(), leaves)
    return Array(Assignment(
        # keep the original target buffer:
        assign.left(), root_operator,
        # use the merged operation instead
        Array(JITRunner(new_root, leaves, root_operator, root_buffer))))
def all_assignments_or_buffers(node, parent_is_assignment, expression_to_assign, assignment_visited):
    """Rewrite the graph under ``node`` so values are assignments or buffers.

    The graph is modified in place through ``set_expression``.

    Args:
        node: ``Array`` at the root of the (sub)graph to rewrite.
        parent_is_assignment: True when ``node`` is the direct right-hand side
            of an assignment, in which case it is not wrapped again.
        expression_to_assign: dict mapping an already wrapped expression to
            the assignment expression created for it, so a shared expression
            is wrapped only once.
        assignment_visited: set of assignment expressions already processed.

    Returns:
        ``node`` itself.
    """
    if node.is_buffer():
        for arg in node.expression().arguments():
            # The recursive call used to pass `node` as an extra third
            # positional argument, which raises TypeError for any buffer that
            # has arguments.
            new_arg_expr = all_assignments_or_buffers(arg,
                                                      False,
                                                      expression_to_assign,
                                                      assignment_visited).expression()
            if new_arg_expr != arg.expression():
                arg.set_expression(new_arg_expr)
        return node

    if not parent_is_assignment and not node.is_assignment() and not node.is_assignable():
        if node.expression() not in expression_to_assign:
            prev_expr = node.expression()
            node.set_expression(to_assignment(node).expression())
            expression_to_assign[prev_expr] = node.expression()
        else:
            node.set_expression(expression_to_assign[node.expression()])

    if node.is_assignment():
        if node.expression() not in assignment_visited:
            assignment_visited.add(node.expression())
            node_assign = node.expression()
            node_right_is_assign = node_assign.right().is_assignment()
            # replace chained assignments with a direct assignment:
            # TODO(jonathan): check if anyone was depending on the intermediary storage
            if node_right_is_assign:
                node_right_assign = node_assign.right().expression()
                if (node_right_assign._operator_t == "=" and
                        node_right_assign.right().expression().supports_operator(node_assign._operator_t)):
                    new_node_right_assign_right_expr = node_right_assign.right().expression()
                    node_assign.right().set_expression(new_node_right_assign_right_expr)
                    node_right_is_assign = node_assign.right().is_assignment()
            # ensure children of the assignment are also only assignments or buffers:
            for arg in right_args(node):
                new_arg_expr = all_assignments_or_buffers(
                    arg, node_right_is_assign, expression_to_assign, assignment_visited).expression()
                arg.set_expression(new_arg_expr)
            new_left_expr = all_assignments_or_buffers(
                node_assign.left(), False, expression_to_assign, assignment_visited).expression()
            node_assign.left().set_expression(new_left_expr)
    else:
        # Node left unwrapped (e.g. the direct right side of an assignment):
        # still make sure its own arguments are assignments or buffers.
        for arg in node.expression().arguments():
            new_arg_expr = all_assignments_or_buffers(arg,
                                                      False,
                                                      expression_to_assign,
                                                      assignment_visited).expression()
            arg.set_expression(new_arg_expr)
    return node
# Registry of Optimization rules, tried in order by `simplify_destination`.
OPTIMIZATIONS = []


class Optimization(object):
    """A rewrite rule: a predicate paired with a transformation.

    Args:
        condition: callable taking a root node and returning whether the rule
            applies to it.
        transform: callable taking a root node and returning its replacement.
    """

    def __init__(self, condition, transform):
        self._condition = condition
        self._transform = transform

    def matches(self, root):
        """Return the result of calling the condition on ``root``."""
        predicate = self._condition
        return predicate(root)

    def transform(self, root):
        """Return the result of calling the transformation on ``root``."""
        rewrite = self._transform
        return rewrite(root)
def simplify_destination(root, expression_to_opti):
    """Apply every matching rule in ``OPTIMIZATIONS`` to the graph, bottom-up.

    ``root`` is modified in place. ``expression_to_opti`` maps an expression
    that was already processed to its optimized replacement, so a shared
    expression is optimized only once. Returns ``None``.
    """
    # leaf node:
    if root.is_available_buffer():
        return

    original = root.expression()
    # TODO(jonathan): this optimization lookup could possibly be avoided by pre-recognizing that certain nodes
    # can be grouped and optimized once.
    if original in expression_to_opti:
        root.set_expression(expression_to_opti[original])
        return

    # recurse on children/arguments of node:
    for child in original.arguments():
        simplify_destination(child, expression_to_opti)

    for rule in OPTIMIZATIONS:
        if rule.matches(root):
            rewritten = rule.transform(root)
            root.set_expression(rewritten.expression())

    expression_to_opti[original] = root.expression()
class PointerGraph(object):
    """Plain graph node: a value plus a list of child ``PointerGraph`` nodes.

    ``construct_pointer_tree`` stores an expression as the value.
    """

    def __init__(self, value, children):
        self._value, self._children = value, children
def construct_pointer_tree(node, constructed_nodes=None):
    """Build a ``PointerGraph`` mirroring the expression graph under ``node``.

    Args:
        node: ``Array`` at the root of the graph.
        constructed_nodes: optional dict mapping expressions to the
            ``PointerGraph`` already built for them; an expression reached
            more than once maps to the same ``PointerGraph`` object.

    Returns:
        The ``PointerGraph`` for ``node``'s expression.
    """
    cache = {} if constructed_nodes is None else constructed_nodes
    expression = node.expression()
    if expression in cache:
        return cache[expression]

    # Register the node before descending so children can refer back to it.
    tree = PointerGraph(expression, [])
    cache[expression] = tree
    tree._children = [
        construct_pointer_tree(child, constructed_nodes=cache)
        for child in expression.arguments()
    ]
    return tree
def canonical(node):
    """Optimize the graph under ``node`` in place and snapshot it before and after.

    Returns:
        ``(node, (pointer_tree, new_pointer_tree))`` where the two trees are
        ``PointerGraph`` snapshots taken before and after the rewrite passes.
    """
    with optimization_scope():
        with printing_scope("all_assignments_or_buffers"):
            seen_expressions = {}
            seen_assignments = set()
            pointer_tree = construct_pointer_tree(node)
            node = all_assignments_or_buffers(
                node, False, seen_expressions, seen_assignments)
        with printing_scope("simplify_destination"):
            optimized_expressions = {}
            simplify_destination(node, optimized_expressions)
        new_pointer_tree = construct_pointer_tree(node)
        snapshots = (pointer_tree, new_pointer_tree)
        return node, snapshots
def is_nested_jit_assignment(node):
    """Return True when ``node`` is an assignment that ``jit_merge`` can fuse.

    That is the case when ``node`` is an assignment and either

    * its left side is a JIT node and its right side's expression has
      arguments, or
    * its right side is a JIT node other than a ``JITRunner`` and that
      node's expression has arguments.
    """
    if not node.is_assignment():
        return False
    assign = node.expression()
    left = assign.left()
    right = assign.right()
    # The first clause used to call `node.right()`, but `right()` lives on the
    # assignment expression, not on the Array, so it raised AttributeError
    # whenever the left side was a JIT node.
    # NOTE(review): by symmetry with the second clause the intent may have
    # been to inspect the *left* side's arguments instead - confirm.
    if left.is_jit_node() and len(right.expression().arguments()) > 0:
        return True
    return (right.is_jit_node() and
            not right.is_jit_runner() and
            len(right.expression().arguments()) > 0)
# Register the JIT merge rewrite so that `simplify_destination` applies it.
OPTIMIZATIONS.append(
    Optimization(condition=is_nested_jit_assignment, transform=jit_merge)
)
def add(left, right):
    """Return a new ``Array`` wrapping ``Add(left, right)``."""
    with printing_scope("add"):
        node = Add(left, right)
        return Array(node)
# Memoization
def find_node(node, graph, path=None):
    """Depth-first search for ``node`` among the values of ``graph``.

    Args:
        node: value to look for (compared with ``==`` against ``_value``).
        graph: object with ``_value`` and ``_children`` attributes.
        path: child indices already followed to reach ``graph``.

    Returns:
        The list of child indices leading from the starting graph to the first
        match, or ``None`` when ``node`` is not found.
    """
    trail = [] if path is None else path
    if node == graph._value:
        return trail
    for position, child in enumerate(graph._children):
        found = find_node(node, child, path=trail + [position])
        if found is not None:
            return found
    return None
class TransformationGraph(object):
    """Node of a recorded transformation: a source getter plus child nodes.

    ``transform_with_graph`` treats a list ``source_getter`` as a path of
    argument indices, and anything else as a callable taking
    ``(graph, arguments)``.
    """

    def __init__(self, source_getter, children):
        self._source_getter = source_getter
        self._children = children

    def __str__(self):
        rendered_children = ", ".join(str(child) for child in self._children)
        return "{}({})".format(self._source_getter, rendered_children)

    def __repr__(self):
        return str(self)
def transformation_graph(source_graph, target_graph, constructed_nodes=None):
    """Record how to obtain each node of ``target_graph`` from ``source_graph``.

    Nodes created inside an optimization scope contribute their
    ``rebuild_instructions``; all other nodes contribute the path at which
    they are found in ``source_graph``.

    Args:
        source_graph: ``PointerGraph`` of the graph before the rewrite.
        target_graph: ``PointerGraph`` of the graph after the rewrite.
        constructed_nodes: optional dict mapping target nodes to the
            ``TransformationGraph`` already built for them.

    Returns:
        The ``TransformationGraph`` for ``target_graph``.

    Raises:
        AssertionError: if a node that was not created inside an optimization
            scope cannot be found in ``source_graph``.
    """
    memo = {} if constructed_nodes is None else constructed_nodes
    if target_graph in memo:
        return memo[target_graph]

    target_value = target_graph._value
    if target_value._optimization_scope:
        node_location = target_value.rebuild_instructions(source_graph)
    else:
        node_location = find_node(target_value, source_graph)
        assert node_location is not None, "Could not locate {} in source.".format(target_value)

    # Register the node before descending so shared children are reused.
    result = TransformationGraph(node_location, [])
    memo[target_graph] = result
    result._children = [
        transformation_graph(source_graph, child, memo)
        for child in target_graph._children
    ]
    return result
def recover_node(graph, directions, offset):
    """Follow ``directions[offset:]`` down from ``graph`` and return the node reached.

    Each direction is an index into the current node's expression arguments.
    """
    current = graph
    step = offset
    while step != len(directions):
        current = current.expression().arguments()[directions[step]]
        step += 1
    return current
def transform_with_graph(node, stored_transformation):
    """Replay a recorded transformation on the graph rooted at ``node``.

    Args:
        node: ``Array`` at the root of the graph to transform.
        stored_transformation: ``TransformationGraph`` describing the result.

    Returns:
        The ``Array`` described by ``stored_transformation``: either an
        existing node of the graph (rewired in place to the replayed
        arguments) or a node produced by the stored rebuild callable.

    Raises:
        AssertionError: if the recovered or rebuilt value is not an ``Array``.
    """
    getter = stored_transformation._source_getter
    with optimization_scope():
        arguments = [
            transform_with_graph(node, child)
            for child in stored_transformation._children
        ]
        if not isinstance(getter, list):
            # Stored rebuild callable: construct the node from its arguments.
            with optimization_scope():
                array = getter(node, arguments)
            assert isinstance(array, Array), "did not get Array with {}".format(getter)
            return array

        # Stored path: fetch the existing node from the graph.
        array = recover_node(node, getter, 0)
        assert isinstance(array, Array)
        if len(arguments) > 0:
            # TODO: ensure this rewired argument gets only set once.
            array.expression()._arguments = arguments
        return array
# Test logic:
def compare(left, other, path=""):
    """Recursively compare two values structurally.

    Objects are compared attribute by attribute through ``__dict__`` (the
    attribute named ``shape_fun`` is skipped), lists and tuples element by
    element, and anything else with ``==``. A diagnostic line is printed when
    classes or attribute names differ.

    Args:
        left: first value.
        other: second value; ``left`` must be an instance of its class.
        path: textual location of the values, used in diagnostics.

    Returns:
        True when the two values match, False otherwise.
    """
    if not isinstance(left, other.__class__):
        print("different classes {}".format(path), type(left), type(other))
        return False
    if isinstance(left, (int, str, bool, type(None))):
        return left == other
    if isinstance(left, (list, tuple)):
        # Generator instead of a list so comparison stops at the first mismatch.
        return len(left) == len(other) and all(
            compare(l, r, path + "[{}]".format(idx))
            for idx, (l, r) in enumerate(zip(left, other)))
    if not hasattr(left, "__dict__") or not hasattr(other, "__dict__"):
        # Values without instance attributes (float, bytes, ...) used to raise
        # AttributeError below; fall back to plain equality for them.
        return left == other
    left_attrs = left.__dict__
    other_attrs = other.__dict__
    # Order-sensitive on purpose: attributes must have been set in the same order.
    if list(left_attrs.keys()) != list(other_attrs.keys()):
        print("different keys {}".format(path))
        return False
    for key in left_attrs:
        if key == "shape_fun":
            continue
        if not compare(left_attrs[key], other_attrs[key], path=path + "." + key):
            return False
    return True
def main():
    """Time ``canonical`` against replaying its recorded transformation.

    The first loop builds a sample graph, runs ``canonical`` on it and, on the
    first iteration, records the transformation between the graph snapshots.
    The second loop builds the same sample graph and replays the recorded
    transformation, checking the result against the last canonical result.

    NOTE(review): the indentation of this function was lost in extraction; the
    nesting below (in particular the final assert living inside the second
    loop) is a reconstruction - confirm against the original.

    Raises:
        AssertionError: if a replayed graph differs from the canonical one.
    """
    samples = 10000

    # Pass 1: full optimization on every sample.
    elapsed1 = 0.0
    transformation = None
    for i in range(samples):
        CURRENT_SCOPE.clear()
        with printing_scope("main"):
            addition = add(Array(Buffer()), to_assignment(to_assignment(add(Array(Buffer()), Array(Buffer())))))
            start = time.time()
            old_addition, trees = canonical(addition)
            elapsed1 += time.time() - start
            if i == 0:
                source_tree, optimized_tree = trees[0], trees[1]
                transformation = transformation_graph(source_tree, optimized_tree)

    # Pass 2: replay the recorded transformation on every sample.
    elapsed2 = 0.0
    for i in range(samples):
        CURRENT_SCOPE.clear()
        with printing_scope("main"):
            addition = add(Array(Buffer()), to_assignment(to_assignment(add(Array(Buffer()), Array(Buffer())))))
            start = time.time()
            new_addition = transform_with_graph(addition, transformation)
            elapsed2 += time.time() - start
            assert compare(new_addition, old_addition, path="root"), "did not get the same transformation."
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment