@JonathanRaiman
Created July 29, 2018 22:00
Memoized Graph Optimization
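# Overview: a lazy expression graph (Arrays wrapping Expressions) is first put into a
# canonical form -- every non-assignable node becomes an assignment into a fresh Buffer,
# and nested JIT expressions are fused into a single JITRunner (see canonical()). That
# structural rewrite is then recorded as a TransformationGraph, so later graphs built the
# same way can be transformed by replaying the recording instead of re-running the
# optimization passes (see transformation_graph(), transform_with_graph() and main()).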
from contextlib import contextmanager
import time
CURRENT_SCOPE = []
@contextmanager
def printing_scope(message):
    CURRENT_SCOPE.append(message)
    yield
    last = CURRENT_SCOPE.pop()
    assert last == message
def scope():
return "/".join(CURRENT_SCOPE)
OPTIMIZATION_SCOPE = None
@contextmanager
def optimization_scope():
    global OPTIMIZATION_SCOPE
    old_array_scope = OPTIMIZATION_SCOPE
    OPTIMIZATION_SCOPE = True
    yield
    OPTIMIZATION_SCOPE = old_array_scope
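# Array is a mutable handle around an Expression: rewrite passes call set_expression()
# to swap the underlying expression in place while existing references keep pointing
# at the same handle.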
class Array(object):
    def __init__(self, expression):
        self._expression = expression
        # print(scope(), self)
    def shape(self):
        return self._expression.shape()
    def is_buffer(self):
        return isinstance(self._expression, Buffer)
    def is_jit_node(self):
        return isinstance(self._expression, JITNode)
    def is_jit_runner(self):
        return isinstance(self._expression, JITRunner)
    def is_scalar(self):
        return False
    def is_assignment(self):
        return isinstance(self._expression, Assignment)
    def is_available_buffer(self):
        return self.is_buffer() and len(self._expression.arguments()) == 0
    def is_assignable(self):
        return self._expression.is_assignable()
    def expression(self):
        return self._expression
    def is_stateless(self):
        return self._expression is None
    def set_expression(self, expression):
        self._expression = expression
    def buffer_arg(self):
        barg = self._expression.buffer_arg()
        return Array(barg)
    def __str__(self):
        return str(self._expression)
    def __repr__(self):
        return str(self)
class Expression(object):
    def __init__(self, arguments):
        self._optimization_scope = OPTIMIZATION_SCOPE
        self._arguments = arguments
    def shape(self):
        return "Shape"
    def arguments(self):
        return self._arguments
    def is_assignable(self):
        return False
    def __str__(self):
        return type(self).__name__ + "(" + ", ".join([str(arg) for arg in self._arguments]) + ")"
    def __repr__(self):
        return str(self)
    def supports_operator(self, op):
        if op == "=":
            return True
        return False
class JITNode(Expression):
    pass
class Buffer(Expression):
    def __init__(self, shape=None, shape_fun=None):
        super().__init__([])
        self._shape = shape
        self.shape_fun = shape_fun
    def shape(self):
        return self._shape
    def is_assignable(self):
        return True
    def buffer_arg(self):
        return self
    def rebuild_instructions(self, source_graph):
        buffer_shape_location = find_node(self.shape_fun, source_graph)
        return lambda new_graph, args: Array(Buffer(shape=recover_node(new_graph, buffer_shape_location, 0).shape()))
class Add(JITNode):
    def __init__(self, left, right):
        super().__init__([left, right])
    def right(self):
        return self._arguments[1]
    def left(self):
        return self._arguments[0]
class Assignment(Expression):
    def __init__(self, left, operator_t, right):
        super().__init__([left, right])
        self._operator_t = operator_t
    def right(self):
        return self._arguments[1]
    def left(self):
        return self._arguments[0]
    def buffer_arg(self):
        return self.left().expression().buffer_arg()
    def is_assignable(self):
        return True
    def rebuild_instructions(self, source_graph):
        operator = self._operator_t
        return lambda new_graph, args: Array(Assignment(args[0], operator, args[1]))
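# JITRunner is the fused JIT operation produced by jit_merge: it holds the fused root
# expression, the non-JIT input leaves (as its arguments), the assignment operator and
# the destination buffer. Its rebuild_instructions() records how to reconstruct the
# fused node from a fresh source graph.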
class JITRunner(JITNode):
    def __init__(self, root, leaves, operator_t, dest):
        super().__init__(leaves)
        self._dest = dest
        self._root = root
        self._operator_t = operator_t
    def __str__(self):
        return "JITRunner[" + str(self._root) + "](" + ", ".join([str(arg) for arg in self._arguments]) + ")"
    def rebuild_instructions(self, source_graph):
        root_tgraph = transformation_graph(source_graph, construct_pointer_tree(self._root))
        dest_tgraph = transformation_graph(source_graph, construct_pointer_tree(self._dest))
        operator = self._operator_t
        def rebuild(new_graph, args):
            return Array(JITRunner(transform_with_graph(new_graph, root_tgraph),
                                   args,
                                   operator,
                                   transform_with_graph(new_graph, dest_tgraph)))
        return rebuild
class Identity(JITNode):
    def __init__(self, arg):
        super().__init__([arg])
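# Wrap a node into "buffer = node". Inside an optimization pass the new Buffer also
# remembers which source expression its shape comes from (shape_fun), so the buffer
# can be rebuilt against a fresh graph during memoized replay.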
def to_assignment(node):
    with printing_scope("to_assignment"):
        if OPTIMIZATION_SCOPE:
            buff = Array(Buffer(shape=node.shape(), shape_fun=node.expression()))
        else:
            buff = Array(Buffer(shape=node.shape()))
        return Array(Assignment(buff, "=", Array(node.expression())))
def right_args(node):
    return node.expression().right().expression().arguments()
def identity(node):
    return Array(Identity(node))
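# Fold the arguments of a JIT expression into the operation being fused: nested JIT
# assignments/runners are inlined, while assignments and buffers are detached and
# collected into `leaves` as inputs of the fused op.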
def jit_merge_arguments(arguments, leaves):
    for arg_idx, arg in enumerate(arguments):
        if arg.is_assignment() and arg.expression().right().is_jit_node():
            if arg.expression().right().is_jit_runner():
                # grab leaves from existing jit-runner recursively:
                extra_leaves = arg.expression().right().expression().arguments()
                leaves.extend(extra_leaves)
            # if the node is an assignment to a buffer, ensure that
            # the assignment op gets included within this op
            # (e.g. by spoofing the assignment and replacing it with
            # the equivalent JIT op)
            replaced, left_leaf = replace_assign_with_inplace(arg)
            # if the assignment involves using the left side (e.g.
            # left += right -> left + right), then keep the left node
            # as a dependency leaf:
            if not left_leaf.is_stateless():
                leaves.append(left_leaf)
            # now that the jit-runners and assignments are gone, connect
            # up the new operation in the graph:
            arguments[arg_idx] = Array(replaced.expression())
        elif arg.is_assignment() or arg.is_buffer():
            # detach the assignment subgraph and only keep the left node (buffer view)
            leaf_arg = Array(arg.expression())
            # both assignments and buffers get stripped of their data. If a buffer
            # had control dependencies, those are temporarily dropped to build the JIT graph
            arg.set_expression(arg.expression().buffer_arg())
            leaves.append(leaf_arg)
        elif arg.is_jit_node():
            # has already been added to the operation
            pass
        else:
            # this node is neither an assignment, a buffer, nor a JIT node,
            # so it is kept as an input leaf here:
            leaves.append(arg)
def jit_root(array):
    if array.is_jit_runner():
        return array.expression()._root
    return array
def replace_assign_with_inplace(node):
    assign = node.expression()
    rightside = jit_root(assign.right())
    operator_t = assign._operator_t
    if operator_t == "=":
        # in cases where the assignment was an implicit cast, ensure that type gets
        # assigned to the new destination (jit_tile_scalar is not defined in this
        # sketch; Array.is_scalar() always returns False, so the branch is never taken)
        if rightside.is_scalar() and not node.is_scalar():
            rightside = jit_tile_scalar(rightside)
        return (rightside, Array(None))
    else:
        raise ValueError()
    # remaining operator cases from the C++ original:
    # } else if (operator_t == OPERATOR_T_ADD) {
    #     return std::tuple<Array, Array>(op::add(assign->left(), rightside), assign->left());
    # } else if (operator_t == OPERATOR_T_SUB) {
    #     return std::tuple<Array, Array>(op::subtract(assign->left(), rightside), assign->left());
    # } else if (operator_t == OPERATOR_T_MUL) {
    #     return std::tuple<Array, Array>(op::eltmul(assign->left(), rightside), assign->left());
    # } else if (operator_t == OPERATOR_T_DIV) {
    #     return std::tuple<Array, Array>(op::eltdiv(assign->left(), rightside), assign->left());
    # } else {
    #     throw std::runtime_error(utils::make_message("No way to replace_assign_with_inplace using operator ",
    #                                                  operator_to_name(operator_t), "."));
    # }
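# Fuse an assignment whose right-hand side is JIT work into a single
# "destination = JITRunner(...)" node, gathering all non-JIT inputs as leaves.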
def jit_merge(root):
    leaves = []
    assign = root.expression()
    assert assign.left().is_assignable(), "Assignment destination is not assignable"
    root_buffer = root.buffer_arg()
    assert not root_buffer.is_stateless(), ("Assignment destination for JIT assignment does not contain a valid "
                                            "buffer destination (check if the left side of the assignment is assignable).")
    root_operator = assign._operator_t
    new_root = Array(assign.right().expression())
    if new_root.is_assignment() or new_root.is_buffer():
        leaves.append(new_root)
        new_root = identity(new_root.buffer_arg())
    jit_merge_arguments(new_root.expression().arguments(), leaves)
    jit_merge_arguments(root_buffer.expression().arguments(), leaves)
    return Array(Assignment(
        # keep the original target buffer:
        assign.left(), root_operator,
        # use the merged operation instead
        Array(JITRunner(new_root, leaves, root_operator, root_buffer))))
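# Canonicalization pass: rewrite the graph so that every non-assignable node becomes the
# right-hand side of an assignment into a fresh buffer, and collapse chained assignments
# (a = (b = c) -> a = c) when the operator allows it. The expression_to_assign and
# assignment_visited memos avoid revisiting subexpressions shared between parents.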
def all_assignments_or_buffers(node, parent_is_assignment, expression_to_assign, assignment_visited):
    if node.is_buffer():
        for arg_idx, arg in enumerate(node.expression().arguments()):
            new_arg_expr = all_assignments_or_buffers(arg,
                                                      False,
                                                      expression_to_assign,
                                                      assignment_visited).expression()
            if new_arg_expr != arg.expression():
                arg.set_expression(new_arg_expr)
        return node
    else:
        if not parent_is_assignment and not node.is_assignment() and not node.is_assignable():
            if node.expression() not in expression_to_assign:
                prev_expr = node.expression()
                node.set_expression(to_assignment(node).expression())
                expression_to_assign[prev_expr] = node.expression()
            else:
                node.set_expression(expression_to_assign[node.expression()])
        if node.is_assignment():
            if node.expression() not in assignment_visited:
                assignment_visited.add(node.expression())
                prev_expr = node.expression()
                node_assign = node.expression()
                node_right_is_assign = node_assign.right().is_assignment()
                # replace chained assignments with a direct assignment:
                # TODO(jonathan): check if anyone was depending on the intermediary storage
                if node_right_is_assign:
                    node_right_assign = node_assign.right().expression()
                    if (node_right_assign._operator_t == "=" and
                            node_right_assign.right().expression().supports_operator(node_assign._operator_t)):
                        new_node_right_assign_right_expr = node_right_assign.right().expression()
                        node_assign.right().set_expression(new_node_right_assign_right_expr)
                        node_right_is_assign = node_assign.right().is_assignment()
                # ensure children of the assignment are also only assignments or buffers:
                for arg in right_args(node):
                    new_arg_expr = all_assignments_or_buffers(arg, node_right_is_assign, expression_to_assign, assignment_visited).expression()
                    arg.set_expression(new_arg_expr)
                new_left_expr = all_assignments_or_buffers(node_assign.left(), False, expression_to_assign, assignment_visited).expression()
                node_assign.left().set_expression(new_left_expr)
        else:
            for arg_idx, arg in enumerate(node.expression().arguments()):
                new_arg_expr = all_assignments_or_buffers(arg,
                                                          False,
                                                          expression_to_assign,
                                                          assignment_visited).expression()
                arg.set_expression(new_arg_expr)
        return node
OPTIMIZATIONS = []
class Optimization(object):
    def __init__(self, condition, transform):
        self._condition = condition
        self._transform = transform
    def matches(self, root):
        return self._condition(root)
    def transform(self, root):
        return self._transform(root)
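# Post-order pass that applies the registered OPTIMIZATIONS to every node, memoizing the
# result per expression so shared subgraphs are only optimized once.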
def simplify_destination(root, expression_to_opti):
    # leaf node:
    if root.is_available_buffer():
        return
    # TODO(jonathan): this optimization lookup could possibly be avoided by pre-recognizing
    # that certain nodes can be grouped and optimized once.
    if root.expression() not in expression_to_opti:
        ptr = root.expression()
        # recurse on children/arguments of node:
        for arg in root.expression().arguments():
            simplify_destination(arg, expression_to_opti)
        for optimization in OPTIMIZATIONS:
            if optimization.matches(root):
                new_root = optimization.transform(root)
                root.set_expression(new_root.expression())
        expression_to_opti[ptr] = root.expression()
    else:
        root.set_expression(expression_to_opti[root.expression()])
class PointerGraph(object):
    def __init__(self, value, children):
        self._value = value
        self._children = children
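# Snapshot of the graph's structure keyed by expression identity; shared expressions map
# to a single PointerGraph node.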
def construct_pointer_tree(node, constructed_nodes=None):
    if constructed_nodes is None:
        constructed_nodes = {}
    if node.expression() in constructed_nodes:
        return constructed_nodes[node.expression()]
    else:
        root = PointerGraph(node.expression(), [])
        constructed_nodes[node.expression()] = root
        root._children = [construct_pointer_tree(arg, constructed_nodes=constructed_nodes)
                          for arg in node.expression().arguments()]
        return root
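# Full optimization pipeline: canonicalize, then simplify. Returns the optimized node
# together with pointer trees of the graph before and after the rewrite, which are the
# inputs needed to record the transformation.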
def canonical(node):
    with optimization_scope():
        with printing_scope("all_assignments_or_buffers"):
            expression_to_assign = {}
            assignment_visited = set()
            pointer_tree = construct_pointer_tree(node)
            node = all_assignments_or_buffers(node, False, expression_to_assign, assignment_visited)
        with printing_scope("simplify_destination"):
            expression_to_opti = {}
            simplify_destination(node, expression_to_opti)
    new_pointer_tree = construct_pointer_tree(node)
    return node, (pointer_tree, new_pointer_tree)
def is_nested_jit_assignment(node):
    return (node.is_assignment() and
            ((node.expression().left().is_jit_node() and
              len(node.expression().right().expression().arguments()) > 0) or
             (node.expression().right().is_jit_node() and
              not node.expression().right().is_jit_runner() and
              len(node.expression().right().expression().arguments()) > 0)))
OPTIMIZATIONS.append(Optimization(is_nested_jit_assignment, jit_merge))
def add(left, right):
    with printing_scope("add"):
        return Array(Add(left, right))
# Memoization
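# The idea: given the pointer tree of a source graph and of its optimized form, record,
# for every node of the optimized graph, either (a) the path of argument indices that
# reaches the equivalent node in the source graph, or (b) for nodes created during the
# optimization (marked by _optimization_scope), a closure returned by
# rebuild_instructions() that can recreate them. Replaying this record on a structurally
# identical fresh graph reproduces the optimized graph without running the passes again.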
def find_node(node, graph, path=None):
    if path is None:
        path = []
    if node == graph._value:
        return path
    for idx, child in enumerate(graph._children):
        solution = find_node(node, child, path=path + [idx])
        if solution is not None:
            return solution
    return None
class TransformationGraph(object):
    def __init__(self, source_getter, children):
        self._source_getter = source_getter
        self._children = children
    def __str__(self):
        return str(self._source_getter) + "(" + ", ".join([str(arg) for arg in self._children]) + ")"
    def __repr__(self):
        return str(self)
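# Build the TransformationGraph for target_graph relative to source_graph (both are
# PointerGraphs); nodes shared inside target_graph are memoized via constructed_nodes.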
def transformation_graph(source_graph, target_graph, constructed_nodes=None):
    if constructed_nodes is None:
        constructed_nodes = {}
    if target_graph in constructed_nodes:
        return constructed_nodes[target_graph]
    else:
        if target_graph._value._optimization_scope:
            node_location = target_graph._value.rebuild_instructions(source_graph)
        else:
            node_location = find_node(target_graph._value, source_graph)
            assert node_location is not None, "Could not locate {} in source.".format(target_graph._value)
        transfo_graph = TransformationGraph(node_location, [])
        constructed_nodes[target_graph] = transfo_graph
        transfo_graph._children = [transformation_graph(source_graph, arg, constructed_nodes)
                                   for arg in target_graph._children]
        return transfo_graph
def recover_node(graph, directions, offset):
    if offset == len(directions):
        return graph
    else:
        return recover_node(graph.expression().arguments()[directions[offset]], directions, offset + 1)
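# Replay a recorded transformation against a fresh input graph: children are rebuilt
# first, then the node is either fetched from the new graph by its recorded path (and
# its arguments rewired) or recreated by calling the stored rebuild closure.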
def transform_with_graph(node, stored_transformation):
    with optimization_scope():
        arguments = [transform_with_graph(node, arg) for arg in stored_transformation._children]
        if isinstance(stored_transformation._source_getter, list):
            array = recover_node(node, stored_transformation._source_getter, 0)
            assert isinstance(array, Array)
            if len(arguments) > 0:
                # TODO: ensure this rewired argument gets only set once.
                array.expression()._arguments = arguments
        else:
            with optimization_scope():
                array = stored_transformation._source_getter(node, arguments)
                assert isinstance(array, Array), "did not get Array with {}".format(stored_transformation._source_getter)
        return array
# test logic:
def compare(left, other, path=""):
    if isinstance(left, other.__class__):
        if isinstance(left, (int, str, bool, type(None))):
            return left == other
        if isinstance(left, list):
            return len(left) == len(other) and all([compare(l, r, path + "[{}]".format(idx))
                                                    for idx, (l, r) in enumerate(zip(left, other))])
        if list(left.__dict__.keys()) != list(other.__dict__.keys()):
            print("different keys {}".format(path))
            return False
        for key in left.__dict__.keys():
            if key == "shape_fun":
                continue
            if not compare(left.__dict__[key], other.__dict__[key], path=path + "." + key):
                return False
        return True
    else:
        print("different classes {}".format(path), type(left), type(other))
        return False
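# Benchmark: time full canonicalization over many freshly built, structurally identical
# graphs, then time replaying the transformation recorded on the first one, checking
# that both routes produce equivalent results.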
def main():
    samples = 10000
    elapsed1 = 0.0
    transformation = None
    for i in range(samples):
        CURRENT_SCOPE.clear()
        with printing_scope("main"):
            addition = add(Array(Buffer()), to_assignment(to_assignment(add(Array(Buffer()), Array(Buffer())))))
            t0 = time.time()
            old_addition, trees = canonical(addition)
            t1 = time.time()
            elapsed1 += t1 - t0
            if i == 0:
                transformation = transformation_graph(trees[0], trees[1])
    elapsed2 = 0.0
    for i in range(samples):
        CURRENT_SCOPE.clear()
        with printing_scope("main"):
            addition = add(Array(Buffer()), to_assignment(to_assignment(add(Array(Buffer()), Array(Buffer())))))
            t0 = time.time()
            new_addition = transform_with_graph(addition, transformation)
            t1 = time.time()
            elapsed2 += t1 - t0
            assert compare(new_addition, old_addition, path="root"), "did not get the same transformation."
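# Entry point (assumption; the snippet defines main() but does not call it here): run the benchmark.
if __name__ == "__main__":
    main()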