Dali graph transformation Plan
""" | |
Micro-dali JIT Plan: | |
- contains gemm, operator fusion, elementwise/reduction ops. | |
- supports tensordot | |
- supports 'jit' | |
- supports conversion from gemm + im2col to conv2d (NHWC) | |
- supports 'optimization' passes | |
- supports 'implementation' registries for specialization | |
(e.g. int vs float) | |
TODO: | |
- ? | |
STRECH GOALS: | |
- allow multiple outputs | |
- efficient patterns for optimization registry | |
- allow forced assignments (that should not be replaced by optimization) | |
- reorder matmuls? | |
- paced JIT running (do reductions, then do elementwise, etc...) | |
- async running with locks/mutexes? | |
""" | |
import numpy as np | |
import time | |
import tensorflow as tf | |
## REGISTRIES for implementations or graph transformations | |
IMPLEMENTATIONS = {} | |
OPTIMIZATIONS = [] | |
def tname(x): | |
return type(x).__name__ | |
def register_implementation(opname, implname): | |
""" | |
Mark that elements built from `opname` (e.g. MatMul) | |
can be implemented using `implname`. implname is | |
a callback that receives the current Array and | |
returns the appropriate implementation class. | |
""" | |
IMPLEMENTATIONS[opname.__name__] = implname | |
def register_optimization(condition, transformation):
    """
    If the local Array matches `condition`, update the computation
    graph by calling `transformation` on that node
    (note: optimizations are run bottom-up).
    """
    OPTIMIZATIONS.append((condition, transformation))
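# For example, later in this file the JIT fusion pass and the dtype-specialized
# gemm kernels are wired up through these registries:
#   register_optimization(is_jit_assignment, jit_merge)
#   register_implementation(MatMul, choose_matmul)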
class Expression(object): | |
def ndim(self): | |
return len(self.shape()) | |
def __str__(self): | |
return tname(self) + "(" + ", ".join([str(arg) for arg in self.arguments()]) + ")" | |
def __eq__(self, other): | |
if type(self) != type(other): | |
return False | |
self_args = self.arguments() | |
other_args = other.arguments() | |
if len(self_args) != len(other_args): | |
return False | |
return all(arg == oarg for arg, oarg in zip(self_args, other_args)) | |
def supports_operator(self, operator): | |
return operator == "=" | |
def shape_to_trivial_strides(shape): | |
res = [0 for _ in shape] | |
residual_shape = 1 | |
for i in reversed(range(0, len(shape))): | |
res[i] = residual_shape | |
residual_shape *= shape[i] | |
return res | |
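# e.g. shape_to_trivial_strides([2, 3, 4]) == [12, 4, 1], the row-major
# (contiguous) strides for that shape.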
class Buffer(Expression): | |
def __init__(self, data): | |
self._data = data | |
def __eq__(self, other): | |
if not isinstance(other, Buffer) or self._data.shape != other._data.shape: | |
return False | |
return np.alltrue(np.equal(self._data, other._data)) | |
def shape(self): | |
return self._data.shape | |
def dtype(self): | |
return self._data.dtype | |
def value(self): | |
return self._data | |
def arguments(self): | |
return [] | |
def data(self): | |
return self._data | |
def reshape(self, shape): | |
return Buffer(self._data.reshape(shape)) | |
def dimshuffle(self, axes): | |
return Buffer(self._data.transpose(axes)) | |
def contiguous_memory(self): | |
return self._data.flags.c_contiguous | |
def strides(self): | |
item_size = self._data.itemsize | |
return [stride // item_size for stride in self._data.strides] | |
def is_transpose(self): | |
ndim = self.ndim() | |
if ndim <= 1: | |
return True | |
if self.contiguous_memory(): | |
return False | |
reversed_shape = list(reversed(self.shape())) | |
reversed_strides = shape_to_trivial_strides(reversed_shape) | |
strides = self.strides() | |
for i in range(0, ndim): | |
if reversed_strides[i] != strides[ndim - 1 - i]: | |
return False | |
return True | |
def supports_operator(self, operator): | |
return True | |
class Array(object): | |
def __init__(self, internal): | |
self._expression = internal | |
self._simplified = False | |
def __eq__(self, other): | |
if not isinstance(other, Array): | |
return False | |
return self._expression == other._expression | |
def canonical(self): | |
node = self | |
# assignment pass | |
node = all_assignments_or_buffers(node) | |
# simplification pass (jit, merge, etc...) | |
return simplify_destination(node) | |
def eval(self): | |
if not isinstance(self._expression, Buffer): | |
node = self.canonical() | |
computable = convert_to_ops(node) | |
# run (DAG evaluation) | |
for step in computable: | |
step.run() | |
self._expression = node._expression._left._expression | |
def value(self): | |
self.eval() | |
return self._expression.value() | |
def __str__(self): | |
return str(self._expression) | |
@property | |
def dtype(self): | |
return self._expression.dtype() | |
@property | |
def shape(self): | |
return self._expression.shape() | |
@property | |
def ndim(self): | |
return self._expression.ndim() | |
@property | |
def T(self): | |
return transpose(self) | |
def buffer(x): | |
return Array(Buffer(x.copy())) | |
class Node(Expression): | |
pass | |
class Assignment(Expression): | |
def __init__(self, left, operator, right): | |
self._left = left | |
self._operator = operator | |
self._right = right | |
def arguments(self): | |
return [self._left, self._right] | |
def shape(self): | |
return self._left._expression.shape() | |
def dtype(self): | |
return self._left.dtype | |
def data(self): | |
return self._left._expression.data() | |
def __str__(self): | |
return (tname(self) + "(" + self._operator + ", " + | |
str(self._left) + ", " + str(self._right) + ")") | |
class ControlFlow(Expression): | |
"""Signify that a buffer (left) requires all the operations | |
in 'conditions' to complete before progressing""" | |
def __init__(self, left, conditions): | |
self._left = left | |
self._conditions = conditions | |
def arguments(self): | |
return [self._left] + self._conditions | |
def shape(self): | |
return self._left._expression.shape() | |
def dtype(self): | |
return self._left.dtype | |
def data(self): | |
return self._left._expression.data() | |
def type_promotion(left_dtype, right_dtype): | |
np_left_dtype = np.dtype(left_dtype) | |
np_right_dtype = np.dtype(right_dtype) | |
if np_left_dtype.kind == "i" and np_right_dtype.kind == "i": | |
if np_left_dtype.itemsize > np_right_dtype.itemsize: | |
return left_dtype | |
return right_dtype | |
if np_left_dtype.kind == "f": | |
if np_right_dtype.kind == "f" and np_right_dtype.itemsize > np_left_dtype.itemsize: | |
return right_dtype | |
return left_dtype | |
raise ValueError("unknown dtype promotion.") | |
def autoreduce_assign(left, right): | |
""" | |
Special autoreduction useful for gradient propagation. | |
Implements reduce all non matching dimensions to be 1 and | |
add to the left. Can be JIT-ed and combined with other | |
element-wise operations (if benefitial) | |
""" | |
if not isinstance(right._expression, Buffer): | |
right = to_assignment(right) | |
reduction_axes = [] | |
for axis, (left_dim, right_dim) in enumerate(zip(left.shape, right.shape)): | |
if left_dim == 1 and right_dim > 1: | |
reduction_axes.append(axis) | |
return assign(left, "+=", reduce_sum(right, reduction_axes, keep_dims=True)) | |
def assign(left, operator, right): | |
"""Assign right data to left destination with operator. | |
If the right side is not yet evaluated, and the operator | |
is not simple equality, then a temporary destination for | |
saving the right side is added. | |
""" | |
if operator == "=": | |
return Array(Assignment(left, operator, right)) | |
elif operator == "<<=": | |
res = autoreduce_assign(left, right) | |
return res | |
else: | |
if not isinstance(right._expression, Buffer): | |
right = to_assignment(right) | |
# a temp is added so that non overwriting operators | |
# can be run independently from the right side's evaluation. | |
return Array(Assignment(left, operator, right)) | |
def to_assignment(node): | |
return assign(buffer(np.zeros(node.shape, node.dtype)), | |
"=", Array(node._expression)) | |
def right_args(node): | |
return node._expression._right._expression.arguments() | |
def buffer_arg(node): | |
if isinstance(node._expression, Buffer): | |
return node | |
if isinstance(node._expression, (Assignment, ControlFlow)): | |
return node._expression._left | |
return None | |
class Allocate(object): | |
"""Dummy allocation step. Could involve data movement to/from GPU.""" | |
def __init__(self, node): | |
self._node = node | |
def run(self): | |
pass | |
### GRAPH TRANSFORMATION PASSES ### | |
def buffer_buffer_op(node):
    # Assignments whose right side is itself an Assignment should have been
    # simplified away by the canonicalization passes before reaching here.
    raise ValueError("buffer_buffer_op: unexpected assignment of an "
                     "assignment during op conversion (got %s)." % (str(node),))
    # Intended (currently unreachable) fallback: wrap the right side in an
    # Identity JIT op so it can be copied into the destination buffer.
    identity_node = Array(Identity(node._expression._right))
    el = Array(Assignment(node._expression._left,
                          node._expression._operator,
                          Array(JITRunner(identity_node, [node._expression._right]))))
    return el
def convert_to_ops(root): | |
steps = [] | |
elements = [root] | |
while len(elements) > 0: | |
element = elements.pop() | |
if isinstance(element._expression, Buffer): | |
steps.append(Allocate(element)) | |
elif isinstance(element._expression, Assignment): | |
#if isinstance(element._expression._right._expression, Buffer) or | |
if isinstance(element._expression._right._expression, Assignment): | |
# TODO: clean this up | |
element = buffer_buffer_op(element) | |
name_of_object = tname(element._expression._right._expression) | |
if name_of_object in IMPLEMENTATIONS: | |
steps.append(IMPLEMENTATIONS[name_of_object](element._expression._right)( | |
element._expression._right, | |
element._expression._operator, | |
element._expression._left)) | |
# ensure destination is allocated: | |
elements.append(element._expression._left) | |
elements.extend(right_args(element)) | |
else: | |
raise ValueError("no way to implement %r" % (name_of_object,)) | |
elif isinstance(element._expression, ControlFlow): | |
# add all the dependencies of this node as a step to complete: | |
elements.extend(element._expression.arguments()) | |
else: | |
raise ValueError("can only convert Assignments and Buffers to ops (got %s)." % (str(element),)) | |
return list(reversed(steps)) | |
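# convert_to_ops returns steps in reverse discovery order, so allocations and
# the assignments feeding an op are always scheduled before the op that
# consumes them (a simple topological ordering of the DAG).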
def can_copyless_reshape(node, shape): | |
""" | |
Returns True if the shape/strides of the node | |
are compatible with the new shape without requiring | |
a copy. | |
""" | |
if node._expression.contiguous_memory(): | |
return True | |
ndim = node.ndim | |
shape_ = node.shape | |
if len(shape) > ndim: | |
# check if the lowest dimensions will be identical | |
matching_lowest = True | |
for i in range(0, ndim): | |
if shape[len(shape) - i - 1] != shape_[ndim - i - 1]: | |
matching_lowest = False | |
break | |
is_ones_elsewhere = True | |
for i in range(0, len(shape) - ndim): | |
if shape[i] != 1: | |
is_ones_elsewhere = False | |
break | |
if matching_lowest and is_ones_elsewhere: | |
return True | |
return False | |
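# e.g. can_copyless_reshape on a non-contiguous (3, 4) view returns True for
# the target shape (1, 3, 4) (only leading 1-sized dims are added and the
# trailing dims match), but False for (4, 3), which would require a copy.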
def can_reshape_inplace(node): | |
if not (isinstance(node._expression, Assignment) and | |
node._expression._operator == "=" and | |
isinstance(node._expression._right._expression, Reshape)): | |
return False | |
reshape_node = node._expression._right | |
return can_copyless_reshape(buffer_arg(reshape_node._expression._node), | |
reshape_node._expression._shape) | |
def to_reshape_inplace(node): | |
""" | |
Replace an assignment of buffer = reshape(buffer) by | |
an inplace reshape(buffer) + control flow for maintaining | |
antecedents of the buffer. | |
""" | |
reshape_node = node._expression._right | |
buffer_node = buffer_arg(reshape_node._expression._node) | |
shape = reshape_node._expression._shape | |
return Array(ControlFlow(Array(buffer_node._expression.reshape(shape)), | |
[reshape_node._expression._node])) | |
def is_jit_assignment(node): | |
return (isinstance(node._expression, Assignment) and | |
isinstance(node._expression._right._expression, JITNode) and | |
not isinstance(node._expression._right._expression, JITRunner)) | |
GEMM_OPERATORS = ("+=", "-=") | |
def is_chained_assignment(node): | |
""" assign(Left, operator, assign(Temp, '=', Right)) => | |
assign(Left, operator, Right)""" | |
return (isinstance(node._expression, Assignment) and | |
node._expression._operator == "=" and | |
isinstance(node._expression._right._expression, Assignment) and | |
node._expression._right._expression._operator == "=") | |
def is_chained_or_gemm_assignment(node): | |
return (isinstance(node._expression, Assignment) and | |
isinstance(node._expression._right._expression, Assignment) and | |
node._expression._right._expression._operator == "=" and | |
(node._expression._operator == "=" or | |
(node._expression._operator in GEMM_OPERATORS and | |
isinstance(node._expression._right._expression._right._expression, MatMul)))) | |
def jit_root(node): | |
if isinstance(node._expression, JITRunner): | |
return node._expression._root | |
return node | |
def replace_assign_with_inplace(node): | |
rightside = jit_root(node._expression._right) | |
operator = node._expression._operator | |
if operator == "=": | |
return rightside, None | |
elif operator == "+=": | |
return add(node._expression._left, rightside), node._expression._left | |
elif operator == "-=": | |
        return subtract(node._expression._left, rightside), node._expression._left
elif operator == "*=": | |
return eltmul(node._expression._left, rightside), node._expression._left | |
elif operator == "/=": | |
return eltdiv(node._expression._left, rightside), node._expression._left | |
else: | |
raise ValueError("cannot replace assign inplace with operator %r" % (operator,)) | |
def assign_merge(root): | |
original_root_buffer = root._expression._left | |
original_root_operator = root._expression._operator | |
return Array(Assignment(original_root_buffer, | |
original_root_operator, | |
root._expression._right._expression._right)) | |
def jit_merge(root): | |
leaves = [] | |
root_buffer = root._expression._left | |
root_operator = root._expression._operator | |
for arg in right_args(root): | |
if (isinstance(arg._expression, Assignment) and | |
isinstance(arg._expression._right._expression, JITRunner)): | |
# grab leaves from existing jit-runner recursively: | |
leaves.extend(arg._expression._right._expression._leaves) | |
# if the node is an assignment to a buffer, ensure that | |
# the assignment op gets included within this op | |
# (e.g. by spoofing the assignment and replacing it with | |
# the equivalent JIT op) | |
replaced, left_leaf = replace_assign_with_inplace(arg) | |
# if the assignment involves using the left-side (e.g. | |
# left += right -> left + right), then keep the left node | |
# as a dependency leaf: | |
if left_leaf is not None: | |
leaves.append(left_leaf) | |
# now that the jitrunners and assignments are gone, connect | |
# up the new operation in the graph: | |
arg._expression = replaced._expression | |
# elif isinstance(arg._expression, Assignment): | |
# new_arg = Array(arg._expression) | |
# arg._expression = arg._expression._left._expression | |
# leaves.append(new_arge) | |
else: | |
# this node is either an assignment, or a buffer, | |
# and is needed as an input here: | |
leaves.append(arg) | |
new_root = root._expression._right | |
return Array(Assignment( | |
# keep the original target buffer: | |
root_buffer, root_operator, | |
# use the merged operation instead | |
Array(JITRunner(new_root, leaves)))) | |
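# Roughly, after canonicalization an assignment like
#   assign(dest, "=", add(tanh(a), b))
# is collapsed by jit_merge into a single Assignment whose right side is
#   JITRunner(Add(Tanh(a), b), [a, b])
# i.e. one fused kernel over the original buffers a and b.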
register_optimization(can_reshape_inplace, to_reshape_inplace) | |
register_optimization(is_jit_assignment, jit_merge) | |
register_optimization(is_chained_or_gemm_assignment, assign_merge) | |
def conv2d_merge(root): | |
original_root_buffer = root._expression._left | |
gemm_node = root._expression._conditions[0]._expression._right | |
im2col_node = gemm_node._expression._left._expression._conditions[0]._expression._right | |
x = im2col_node._expression._input | |
strides = im2col_node._expression._strides | |
padding = im2col_node._expression._padding | |
data_format = im2col_node._expression._data_format | |
filter_size = im2col_node._expression._filter_size | |
w = reshape(gemm_node._expression._right, (filter_size[0], filter_size[1], x.shape[3], | |
original_root_buffer.shape[3])) | |
return Array(Assignment(original_root_buffer, | |
"=", | |
Array(Conv2D(x, w, strides, padding, data_format)))) | |
def is_im2col_gemm(node): | |
"""6 micro second check""" | |
# TODO: add check if conv2D is supported for this data type. | |
is_cflow_assign = (isinstance(node._expression, ControlFlow) and | |
len(node._expression._conditions) == 1 and | |
isinstance(node._expression._conditions[0]._expression, Assignment) and | |
node._expression._conditions[0]._expression._operator == "=") | |
if not is_cflow_assign: | |
return False | |
gemm_node = node._expression._conditions[0]._expression._right | |
is_assign_gemm = (isinstance(gemm_node._expression, MatMul) and | |
isinstance(gemm_node._expression._left._expression, ControlFlow) and | |
len(gemm_node._expression._left._expression._conditions) == 1 and | |
isinstance(gemm_node._expression._left._expression._conditions[0]._expression, Assignment)) | |
if not is_assign_gemm: | |
return False | |
im2col_node = gemm_node._expression._left._expression._conditions[0]._expression._right | |
is_im2col = isinstance(im2col_node._expression, Im2col) | |
return is_im2col | |
register_optimization(is_im2col_gemm, conv2d_merge) | |
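# Roughly, this detects the graph that im2col_conv2d (further below) produces
# once the in-place reshape pass has run: a reshape ControlFlow over
# MatMul(<reshaped im2col patches>, <reshaped filter>), and conv2d_merge
# rewrites it into a single Conv2D(x, w, ...) assignment.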
def simplify_destination(root): | |
# leaf node: | |
if isinstance(root._expression, Buffer): | |
return root | |
# recurse on children: | |
children = ([root._expression._right] | |
if isinstance(root._expression, Assignment) | |
else root._expression.arguments()) | |
# recurse on arguments of node: | |
for arg in children: | |
arg._expression = simplify_destination(arg)._expression | |
for condition, transformation in OPTIMIZATIONS: | |
if condition(root): | |
root = transformation(root) | |
return root | |
def all_assignments_or_buffers(root): | |
""" | |
Transform graph so that it only uses | |
Buffers or assignments of buffers. | |
(e.g. give everyone a destination) | |
""" | |
if isinstance(root._expression, Buffer): | |
return root | |
if not isinstance(root._expression, Assignment): | |
root = to_assignment(root) | |
if (isinstance(root._expression._right._expression, Assignment) and | |
root._expression._right._expression._operator == "=" and | |
root._expression._right._expression._right._expression.supports_operator(root._expression._operator)): | |
root._expression._right._expression = root._expression._right._expression._right._expression | |
for arg in right_args(root): | |
arg._expression = all_assignments_or_buffers(arg)._expression | |
return root | |
### OPS (REGISTRY) ### | |
class Computation(object): | |
"""Abstract Computation""" | |
def __init__(self, op, operator, target): | |
self._op = op | |
self._operator = operator | |
self._target = target | |
class JITNode(Node): | |
def supports_operator(self, operator): | |
return True | |
class JITRunner(JITNode): | |
"""Merged jit nodes into one.""" | |
def __init__(self, root, leaves): | |
if isinstance(root._expression, JITRunner): | |
raise ValueError("JITRunner should not contain a JITRunner.") | |
self._root = root | |
self._leaves = leaves | |
def arguments(self): | |
return self._leaves | |
def shape(self): | |
return self._root._expression.shape() | |
def dtype(self): | |
return self._root._expression.dtype() | |
def __str__(self): | |
# pretty print jit merged op into JIT[kernel](inputs) | |
return "JIT[" + str(self._root).replace("Buffer()", "X") + "](" + ", ".join([str(arg) for arg in self.arguments()]) + ")" | |
def jit_execute(root): | |
"""Recursive function for simulating JIT execution.""" | |
if isinstance(root._expression, (Buffer, Assignment, ControlFlow)): | |
return root._expression.data() | |
elif isinstance(root._expression, Add): | |
return jit_execute(root._expression._left) + jit_execute(root._expression._right) | |
elif isinstance(root._expression, Subtract): | |
return jit_execute(root._expression._left) - jit_execute(root._expression._right) | |
elif isinstance(root._expression, EltMul): | |
return jit_execute(root._expression._left) * jit_execute(root._expression._right) | |
elif isinstance(root._expression, EltDiv): | |
return jit_execute(root._expression._left) / jit_execute(root._expression._right) | |
elif isinstance(root._expression, Tanh): | |
return np.tanh(jit_execute(root._expression._node)) | |
elif isinstance(root._expression, Identity): | |
return jit_execute(root._expression._node) | |
elif isinstance(root._expression, ReduceSum): | |
return np_reduce_sum(jit_execute(root._expression._node), | |
root._expression._axis, | |
root._expression._keep_dims) | |
else: | |
raise ValueError("no jit execution for %r (%r)" % (str(root), type(root._expression))) | |
def assign_with_operator(left, operator, right): | |
if operator == "=": | |
left[:] = right | |
elif operator == "+=": | |
left[:] += right | |
elif operator == "-=": | |
left[:] -= right | |
elif operator == "*=": | |
left[:] *= right | |
elif operator == "/=": | |
left[:] /= right | |
else: | |
raise ValueError("unknown operator behavior %r" % (operator,)) | |
class JITRunnerImpl(Computation): | |
def run(self): | |
# simulate generating a kernel | |
# and doing the actual work in one call | |
root = self._op._expression._root | |
assign_with_operator(self._target._expression.data(), | |
self._operator, | |
jit_execute(root)) | |
register_implementation(JITRunner, lambda x: JITRunnerImpl) | |
def buffer_impl(op, operator, target): | |
right_identity = Array(Identity(op)) | |
right_runner = Array(JITRunner(right_identity, [op])) | |
return JITRunnerImpl(right_runner, operator, target) | |
register_implementation(Buffer, lambda x: buffer_impl) | |
class BinaryElementWise(JITNode): | |
def __init__(self, left, right): | |
self._left = left | |
self._right = right | |
def arguments(self): | |
return [self._left, self._right] | |
def shape(self): | |
return self._left.shape | |
def dtype(self): | |
return type_promotion(self._left._expression.dtype(), | |
self._right._expression.dtype()) | |
class UnitaryElementWise(JITNode): | |
def __init__(self, node): | |
self._node = node | |
def arguments(self): | |
return [self._node] | |
def shape(self): | |
return self._node.shape | |
def dtype(self): | |
return self._node.dtype | |
class Add(BinaryElementWise): | |
pass | |
class Subtract(BinaryElementWise): | |
pass | |
class EltMul(BinaryElementWise): | |
pass | |
class EltDiv(BinaryElementWise): | |
pass | |
def add(a, b): | |
return Array(Add(a, b)) | |
def subtract(a, b):
    return Array(Subtract(a, b))
def eltmul(a, b): | |
return Array(EltMul(a, b)) | |
def eltdiv(a, b): | |
return Array(EltDiv(a, b)) | |
class Tanh(UnitaryElementWise): | |
"""tanh""" | |
def dtype(self): | |
return np.float64 | |
class ReduceSum(UnitaryElementWise): | |
def __init__(self, node, axis, keep_dims): | |
self._node = node | |
self._axis = axis | |
self._keep_dims = keep_dims | |
def shape(self): | |
shape = list(self._node.shape) | |
for ax in self._axis: | |
shape[ax] = 1 if self._keep_dims else 0 | |
if self._keep_dims: | |
return shape | |
return [dim for dim in shape if dim > 0] | |
def reduce_sum(node, axis, keep_dims=False): | |
# TODO: distinguish contiguous vs. non contiguous reduction | |
return Array(ReduceSum(node, axis, keep_dims)) | |
def tanh(x): | |
return Array(Tanh(x)) | |
class MatMul(Node): | |
def __init__(self, left, right): | |
self._left = left | |
self._right = right | |
def arguments(self): | |
return [self._left, self._right] | |
def shape(self): | |
return (self._left._expression.shape()[0], | |
self._right._expression.shape()[1]) | |
def dtype(self): | |
return type_promotion(self._left.dtype, self._right.dtype) | |
def calc_pad(pad, in_siz, out_siz, stride, ksize):
    """Calculate padding width.

    Args:
        pad: padding method, "SAME", "VALID", or a manually specified integer.
        in_siz: input size along this dimension.
        out_siz: output size along this dimension.
        stride: stride along this dimension.
        ksize: kernel size along this dimension.

    Returns:
        pad_: actual total padding width.
    """
    if pad == 'SAME':
        return int((out_siz - 1) * stride + ksize - in_siz)
    elif pad == 'VALID':
        return 0
    else:
        return pad
def calc_size(h, kh, pad, sh): | |
"""Calculate output image size on one dimension. | |
Args: | |
h: input image size. | |
kh: kernel size. | |
pad: padding strategy. | |
sh: stride. | |
Returns: | |
s: output size. | |
""" | |
if pad == 'VALID': | |
return int(np.ceil((h - kh + 1) / sh)) | |
elif pad == 'SAME': | |
return int(np.ceil(h / sh)) | |
else: | |
return int(np.ceil((h - kh + pad + 1) / sh)) | |
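# e.g. for a 10-wide dimension, a 3-wide kernel and stride 1:
#   calc_size(10, 3, 'SAME', 1) == 10 and calc_pad('SAME', 10, 10, 1, 3) == 2,
# i.e. one pixel of zero padding on each side.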
def extract_sliding_windows(x, ksize, padding, strides, floor_first=True, out=None):
    """Converts a tensor to sliding windows.

    Args:
        x: input, shaped [N, H, W, C].
        ksize: kernel size [KH, KW].
        padding: padding strategy, "SAME", "VALID", or an integer.
        strides: [NBATCH, SH, SW, NCHANNELS].
        floor_first: if True, place the smaller half of the padding first.
        out: optional preallocated output array.

    Returns:
        y: sliding windows, shaped [N, (H-KH+PH+1)/SH, (W-KW+PW+1)/SW, KH, KW, C].
    """
n = x.shape[0] | |
h = x.shape[1] | |
w = x.shape[2] | |
c = x.shape[3] | |
kh = ksize[0] | |
kw = ksize[1] | |
sh = strides[1] | |
sw = strides[2] | |
h2 = calc_size(h, kh, padding, sh) | |
w2 = calc_size(w, kw, padding, sw) | |
ph = calc_pad(padding, h, h2, sh, kh) | |
pw = calc_pad(padding, w, w2, sw, kw) | |
ph0 = int(np.floor(ph / 2)) | |
ph1 = int(np.ceil(ph / 2)) | |
pw0 = int(np.floor(pw / 2)) | |
pw1 = int(np.ceil(pw / 2)) | |
if floor_first: | |
pph = (ph0, ph1) | |
ppw = (pw0, pw1) | |
else: | |
pph = (ph1, ph0) | |
ppw = (pw1, pw0) | |
x = np.pad( | |
x, ((0, 0), pph, ppw, (0, 0)), | |
mode='constant', | |
constant_values=(0.0, )) | |
if out is None: | |
out = np.zeros([n, h2, w2, kh, kw, c], dtype=x.dtype) | |
for ii in range(h2): | |
for jj in range(w2): | |
xx = ii * sh | |
yy = jj * sw | |
out[:, ii, jj, :, :, :] = x[:, xx:xx + kh, yy:yy + kw, :] | |
return out | |
class Im2col(Node): | |
def __init__(self, input, filter_size, strides, padding, data_format): | |
self._input = input | |
self._filter_size = filter_size | |
self._strides = strides | |
self._padding = padding | |
self._data_format = data_format | |
def arguments(self): | |
return [self._input] | |
def shape(self): | |
ksize = self._filter_size | |
x = self._input | |
n = x.shape[0] | |
h = x.shape[1] | |
w = x.shape[2] | |
c = x.shape[3] | |
kh = ksize[0] | |
kw = ksize[1] | |
sh = self._strides[1] | |
sw = self._strides[2] | |
h2 = calc_size(h, kh, self._padding, sh) | |
w2 = calc_size(w, kw, self._padding, sw) | |
ph = calc_pad(self._padding, h, h2, sh, kh) | |
pw = calc_pad(self._padding, w, w2, sw, kw) | |
ph0 = int(np.floor(ph / 2)) | |
ph1 = int(np.ceil(ph / 2)) | |
pw0 = int(np.floor(pw / 2)) | |
pw1 = int(np.ceil(pw / 2)) | |
return (n, h2, w2, kh, kw, c) | |
def dtype(self): | |
return self._input.dtype | |
class Im2colImpl(Computation): | |
def run(self): | |
op = self._op._expression | |
extract_sliding_windows(op._input._expression.data(), op._filter_size, | |
padding=op._padding, | |
strides=op._strides, | |
out=self._target._expression.data()) | |
register_implementation(Im2col, lambda x: Im2colImpl) | |
class Conv2D(Node): | |
def __init__(self, input, filter, strides, padding, data_format): | |
self._input = input | |
self._filter = filter | |
self._strides = strides | |
self._padding = padding | |
self._data_format = data_format | |
def arguments(self): | |
return [self._input, self._filter] | |
def shape(self): | |
return (self._input.shape[0], | |
calc_size(self._input.shape[1], self._filter.shape[0], self._padding, self._strides[1]), | |
calc_size(self._input.shape[2], self._filter.shape[1], self._padding, self._strides[2]), | |
self._filter.shape[3]) | |
def dtype(self): | |
return type_promotion(self._input.dtype, self._filter.dtype) | |
class Conv2DImpl(Computation): | |
def run(self): | |
w = self._op._expression._filter._expression.data() | |
x = self._op._expression._input._expression.data() | |
ksize = w.shape[:2] | |
x = extract_sliding_windows(x, ksize, | |
padding=self._op._expression._padding, | |
strides=self._op._expression._strides) | |
ws = w.shape | |
w = w.reshape([ws[0] * ws[1] * ws[2], ws[3]]) | |
xs = x.shape | |
x = x.reshape([xs[0] * xs[1] * xs[2], xs[3] * xs[4] * xs[5]]) | |
out = self._target._expression.data() | |
out = out.reshape([x.shape[0], w.shape[1]]) | |
gemm(x, w, out, alpha=1.0, beta=0.0) | |
register_implementation(Conv2D, lambda x: Conv2DImpl) | |
def conv2d(input, filter, strides, padding, data_format="NHWC"): | |
assert(data_format == "NHWC") | |
# test for dimensions here... | |
return Array(Conv2D(input, filter, strides, padding, data_format)) | |
def im2col(input, kernel_size, strides, padding, data_format="NHWC"): | |
assert(data_format == "NHWC") | |
return Array(Im2col(input, kernel_size, strides, padding, data_format)) | |
def im2col_conv2d(input, filter, strides, padding, data_format="NHWC"): | |
ksize = filter.shape[:2] | |
patches = im2col(input, ksize, strides, padding, data_format) | |
patches_2d = reshape(patches, (patches.shape[0] * patches.shape[1] * patches.shape[2], | |
patches.shape[3] * patches.shape[4] * patches.shape[5])) | |
filter_2d = reshape(filter, (filter.shape[0] * filter.shape[1] * filter.shape[2], | |
filter.shape[3])) | |
output_2d = dot(patches_2d, filter_2d) | |
return reshape(output_2d, (input.shape[0], patches.shape[1], | |
patches.shape[2], filter.shape[3])) | |
def gemm(a, b, c, alpha, beta): | |
"""Reference implementation for actual BLAS gemm | |
Note: does nothing special with transposes etc...""" | |
if beta == 0.0: | |
np.matmul(a, b, c) | |
c *= alpha | |
else: | |
c[:] = c * beta + np.matmul(a, b) * alpha | |
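# i.e. gemm(a, b, c, alpha, beta) computes c[:] = alpha * (a @ b) + beta * c,
# writing the result into c in place (the usual BLAS convention).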
class MatMulImpl(Computation): | |
def _get_alpha(self): | |
return -1.0 if self._operator == "-=" else 1.0 | |
def _get_beta(self): | |
return 0.0 if self._operator == "=" else 1.0 | |
def run(self): | |
gemm(self._op._expression._left._expression.data(), | |
self._op._expression._right._expression.data(), | |
self._target._expression.data(), | |
alpha=self._get_alpha(), | |
beta=self._get_beta()) | |
class IMatMulImpl(MatMulImpl): | |
def run(self): | |
gemm(self._op._expression._left._expression.data(), | |
self._op._expression._right._expression.data(), | |
self._target._expression.data(), | |
alpha=int(self._get_alpha()), | |
beta=int(self._get_beta())) | |
def choose_matmul(x): | |
if x.dtype == np.float32 or x.dtype == np.float64: | |
return MatMulImpl | |
elif x.dtype == np.int32 or x.dtype == np.int64: | |
return IMatMulImpl | |
else: | |
raise ValueError("no implementation found.") | |
register_implementation(MatMul, choose_matmul) | |
class Dimshuffle(Node): | |
def __init__(self, node, axes): | |
self._node = node | |
self._axes = axes | |
def arguments(self): | |
return [self._node] | |
def shape(self): | |
original_shape = self._node._expression.shape() | |
return tuple([original_shape[i] for i in self._axes]) | |
def dtype(self): | |
return self._node._expression.dtype() | |
class DimshuffleImpl(Computation): | |
def run(self): | |
print("dimshuffle...") | |
assign_with_operator(self._target._expression.data(), | |
self._operator, | |
self._op._expression._node._expression.data().transpose(self._op._expression._axes)) | |
register_implementation(Dimshuffle, lambda x: DimshuffleImpl) | |
def dimshuffle(node, axes): | |
if isinstance(node._expression, Buffer): | |
return Array(node._expression.dimshuffle(axes)) | |
for i, ax in enumerate(axes): | |
if i != ax: | |
return Array(Dimshuffle(node, axes)) | |
return node | |
def transpose(node, axes=None): | |
if axes is None: | |
axes = list(reversed(range(node._expression.ndim()))) | |
return dimshuffle(node, axes) | |
class Reshape(Node): | |
def __init__(self, node, shape): | |
self._node = node | |
self._shape = shape | |
def arguments(self): | |
return [self._node] | |
def shape(self): | |
return self._shape | |
def dtype(self): | |
return self._node._expression.dtype() | |
class ReshapeImpl(Computation): | |
def run(self): | |
print("reshape...") | |
assign_with_operator(self._target._expression.data(), | |
self._operator, | |
self._op._expression._node._expression.data().reshape(self._op._expression._shape)) | |
register_implementation(Reshape, lambda x: ReshapeImpl) | |
def reshape(node, shape): | |
if tuple(node._expression.shape()) == tuple(shape): | |
return node | |
if isinstance(node._expression, Buffer) and can_copyless_reshape(node, shape): | |
return Array(node._expression.reshape(shape)) | |
return Array(Reshape(node, shape)) | |
class Identity(JITNode): | |
def __init__(self, node): | |
self._node = node | |
def arguments(self): | |
return [self._node] | |
def shape(self): | |
return self._node._expression.shape() | |
def dtype(self): | |
return self._node._expression.dtype() | |
def ascontiguousarray(array): | |
buffer_node = buffer_arg(array) | |
if buffer_node is None: | |
return ascontiguousarray(to_assignment(array)) | |
elif buffer_node._expression.contiguous_memory(): | |
return array | |
else: | |
return Array(Identity(array)) | |
def identity(array): | |
return ascontiguousarray(array) | |
### TENSORDOT ### | |
def check_tensordot_reduce_axes(operand_shape, | |
name, | |
reduce_axes, | |
batched): | |
    # Do not reduce over more dimensions than operand_shape.size().
    if len(reduce_axes) > len(operand_shape):
        raise ValueError(("length of argument {name}_reduce_axes "
                          "should not exceed the number of dimensions of {name}"
                          " ({name}.ndim()={ndim}"
                          ", {name}_reduce_axes.size()={size}).").format(
                              name=name,
                              ndim=len(operand_shape),
                              size=len(reduce_axes)))
    # all reduction axes must be less than operand_shape.size()
    max_reduce_dim = max(reduce_axes) if len(reduce_axes) > 0 else -1
    if not (len(reduce_axes) == 0 or max_reduce_dim < len(operand_shape)):
        raise ValueError(("{name}_reduce_axes contains reduction dimensions "
                          "that are greater than or equal to "
                          "{name}.ndim() ("
                          "{name}.ndim()={size}"
                          ", and found max({name}_reduce_axes)"
                          "={max_reduce_dim}).").format(
                              name=name,
                              size=len(operand_shape),
                              max_reduce_dim=max_reduce_dim))
if batched and 0 in reduce_axes: | |
raise ValueError(("axes to sum over must not contain the batch axis " | |
"({name}_reduce_axes={reduce_axes}).").format( | |
name=name, reduce_axes=reduce_axes)) | |
def tensordot_nonreduced_axes(ndim, reduce_axes, batched): | |
"""Returns all the axes that are not being reduced.""" | |
other_axes = [] | |
for x in range(0, ndim): | |
# when batched, 0 is always kept | |
# as leading dim, and thus will not | |
# be dimshuffled | |
if batched and x == 0: | |
continue | |
if x not in reduce_axes: | |
other_axes.append(x) | |
return other_axes | |
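# e.g. tensordot_nonreduced_axes(4, [1, 3], batched=False) == [0, 2].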
def matrix_multiply_with_reshape(a, b, output_shape, output_shape_2d): | |
if a._expression.ndim() != 2: | |
raise ValueError("a must have ndim=2") | |
if b._expression.ndim() != 2: | |
raise ValueError("b must have ndim=2") | |
left = output_shape_2d[0] | |
middle = max(a._expression.shape()[1], b._expression.shape()[0]) | |
right = output_shape_2d[1] | |
# if the broadcasting fails let ReshapedMatrixMultiplyFunction | |
# throw an error. | |
new_a = ascontiguousarray(reshape(a, (left, middle))) | |
new_b = ascontiguousarray(reshape(b, (middle, right))) | |
return reshape(Array(MatMul(new_a, new_b)), output_shape) | |
def tensordot_as_dot(a, b, a_reduce_axes=None, b_reduce_axes=None, | |
batched=False, axis=None): | |
# This code follows the logic from theano's tensordot as dot | |
# [source https://github.com/Theano/Theano/blob/master/theano/tensor/basic.py#L5628] | |
# Theano code was also originally based elsewhere on | |
# Tijmen Tieleman's gnumpy: | |
# [source http://www.cs.toronto.edu/~tijmen/gnumpy.html] | |
# if 'axes' is a single number of axes to multiply and sum over | |
# (trailing axes of a, leading axes of b), we can just reshape | |
# and use dot. | |
# validate that the axis used for summing | |
# is not out of bounds for the arguments a and b | |
if axis is not None: | |
if axis < 0: | |
raise ValueError(("axis must be a non-negative " | |
"integer (got {axis}).").format(axis=axis)) | |
for i in range(0, 2): | |
operand = a if i == 0 else b | |
operand_name = "a" if i == 0 else "b" | |
if axis > operand._expression.ndim(): | |
raise ValueError(("axis can not be larger than the dimension of " | |
"{name} ({name}.ndim()={ndim}, axis={axis}).").format( | |
axis=axis, name=operand_name, ndim=operand._expression.ndim())) | |
            if axis == operand._expression.ndim() and batched:
                raise ValueError(("axis to sum over must not include the batch axis "
                                  "of {name} ({name}.ndim()={ndim}, axis={axis}).").format(
                                      name=operand_name, axis=axis, ndim=operand._expression.ndim()))
batch_axes = 1 if batched else 0 | |
a_shape, b_shape = [1, 1], [1, 1] | |
a_old_shape = a._expression.shape() | |
b_old_shape = b._expression.shape() | |
# compute total size of summed axes | |
for i in range(0, axis): | |
a_shape[1] *= a_old_shape[len(a_old_shape) - (i + 1)] | |
b_shape[0] *= b_old_shape[batch_axes + i] | |
# compute total size of other axes | |
for i in range(0, a._expression.ndim() - axis - batch_axes): | |
a_shape[0] *= a_old_shape[batch_axes + i] | |
for i in range(0, b._expression.ndim() - axis - batch_axes): | |
b_shape[1] *= b_old_shape[len(b_old_shape) -(i + 1)] | |
if batched: | |
a_shape.insert(0, a_old_shape[0]) | |
b_shape.insert(0, b_old_shape[0]) | |
output_shape = a_old_shape[:len(a_old_shape) - axis] + b_old_shape[batch_axes + axis:] | |
return matrix_multiply_with_reshape( | |
reshape(a, a_shape), | |
reshape(b, b_shape), | |
output_shape, | |
(a_shape[0], b_shape[1])) | |
else: | |
if a_reduce_axes is None or b_reduce_axes is None: | |
raise ValueError("a_reduce_axes and b_reduce_axes must " | |
"not be None if axis is None.") | |
if len(a_reduce_axes) != len(b_reduce_axes): | |
raise ValueError(("must have as many reduction axes for a than b " | |
"(got a_reduce_axes=%r and " | |
"b_reduce_axes=%r).") % (a_reduce_axes, | |
b_reduce_axes)) | |
check_tensordot_reduce_axes(a._expression.shape(), "a", a_reduce_axes, batched) | |
check_tensordot_reduce_axes(b._expression.shape(), "b", b_reduce_axes, batched) | |
a_new_axes = tensordot_nonreduced_axes( | |
a._expression.ndim(), a_reduce_axes, batched) | |
b_new_axes = tensordot_nonreduced_axes( | |
b._expression.ndim(), b_reduce_axes, batched) | |
# for A: add reduction axis at the end of shape | |
a_new_axes.extend(a_reduce_axes) | |
# for B: add reduction axis at the beginning of shape | |
b_new_axes = b_reduce_axes + b_new_axes | |
if batched: | |
a_new_axes.insert(0, 0) | |
b_new_axes.insert(0, 0) | |
return tensordot_as_dot(dimshuffle(a, a_new_axes), | |
dimshuffle(b, b_new_axes), | |
axis=len(a_reduce_axes), | |
batched=batched) | |
def ascontiguousarray_or_simple_transpose(node): | |
"""Gemms support transposed matrix multiplies, but strided | |
memory generally is unsupported.""" | |
buff = buffer_arg(node) | |
if buff is not None and (buff._expression.contiguous_memory() or buff._expression.is_transpose()): | |
return node | |
return ascontiguousarray(node) | |
def dot(a, b): | |
a_ndim, b_ndim = a._expression.ndim(), b._expression.ndim() | |
if a_ndim == 2 and b_ndim == 2: | |
a = ascontiguousarray_or_simple_transpose(a) | |
b = ascontiguousarray_or_simple_transpose(b) | |
return Array(MatMul(a, b)) | |
elif a_ndim > 2 or b_ndim > 2: | |
return tensordot_as_dot(a, b, | |
a_reduce_axes=[a_ndim-1], | |
b_reduce_axes=[b_ndim-2]) | |
else: | |
raise ValueError("dot not implemented for a.ndim = %r, b.ndim = %r" % ( | |
a_ndim, b_ndim)) | |
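# e.g. for a with shape (2, 3, 3) and b with shape (3, 3), dot(a, b) reduces
# a's last axis against b's first axis (matching np.dot) by dimshuffling and
# reshaping both operands into 2-D matrices and issuing a single MatMul.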
def expect_result(op, expected): | |
return np.testing.assert_allclose(op.value(), expected) | |
def np_reduce_sum(array, axis, keep_dims): | |
for ax in axis: | |
array = np.expand_dims(array.sum(ax), ax) | |
if not keep_dims: | |
array = np.squeeze(array, axis) | |
return array | |
def tf_conv2d(input, filter, strides, padding, session=None): | |
if session is None: | |
session = tf.InteractiveSession() | |
return session.run(tf.nn.conv2d(input, filter, strides=strides, padding=padding)) | |
def main(): | |
## arrays: | |
m3x3 = np.ones((3, 3)) | |
z3x3 = np.zeros((3, 3)) | |
z4x3 = np.zeros((4, 3)) | |
array = np.arange(12).reshape((3, 4)).astype(np.float32) | |
array_strided = np.zeros((3, 2, 4))[:, 0, :] | |
array_strided[:] = array | |
# Additions: | |
expect_result(add(add(buffer(m3x3), buffer(m3x3)), | |
add(buffer(m3x3), buffer(m3x3))), | |
m3x3 + m3x3 + m3x3 + m3x3) | |
expect_result(add(buffer(m3x3), add(buffer(m3x3), buffer(m3x3))), | |
m3x3 + m3x3 + m3x3) | |
expect_result(add(buffer(m3x3), buffer(m3x3)), m3x3 + m3x3) | |
# Additions & Tanh | |
expect_result(add(tanh(buffer(m3x3)), tanh(buffer(m3x3))), | |
np.tanh(m3x3) + np.tanh(m3x3)) | |
expect_result(add(tanh(buffer(array)), tanh(buffer(array))), | |
np.tanh(array) + np.tanh(array)) | |
expect_result(tanh(buffer(array)), np.tanh(array)) | |
# GEMM | |
# mix of matrix multiply + elementwise | |
expect_result(tanh(dot(buffer(array), buffer(array).T)), | |
np.tanh(np.dot(array, array.T))) | |
# GEMM sum: | |
expect_result(add(dot(buffer(m3x3), buffer(m3x3)), | |
dot(buffer(m3x3), buffer(m3x3))), | |
np.dot(m3x3, m3x3) + np.dot(m3x3, m3x3)) | |
# CONV op: | |
batch_size = 128 | |
height = 10 | |
width = 10 | |
channels = 3 | |
out_channels = 10 | |
strides = (1, 1, 1, 1) | |
padding = "SAME" | |
x = buffer(np.ones((batch_size, height, width, channels)).astype(np.float32)) | |
w = buffer(np.arange(height * width * channels * out_channels).astype(np.float32).reshape((height, width, channels, out_channels))) | |
expect_result(conv2d(x, w, strides=strides, padding=padding), | |
tf_conv2d(x._expression.data(), w._expression.data(), | |
strides=strides, padding=padding)) | |
expect_result(im2col_conv2d(x, w, strides=strides, padding=padding), | |
tf_conv2d(x._expression.data(), w._expression.data(), | |
strides=strides, padding=padding)) | |
## Tensor dots: | |
expect_result(dot(buffer(np.arange(18).reshape((2, 3, 3))), | |
buffer(np.arange(9).reshape((3, 3)))), | |
np.dot(np.arange(18).reshape((2, 3, 3)), | |
np.arange(9).reshape((3, 3)))) | |
## Transposes: | |
# the transpose is compatible with gemm, no copy: | |
expect_result(dot(buffer(array), buffer(array).T), | |
np.dot(array, array.T)) | |
# the strided nature forces a copy before a gemm: | |
expect_result(dot(buffer(array), buffer(array_strided).T), | |
np.dot(array, array_strided.T)) | |
## Autoreduction: | |
autoreduce_assign_dest = np.zeros((3, 2, 1, 2, 1)) | |
autoreduce_assign_source = np.ones((3, 2, 10, 2, 10)) | |
expect_result(assign(buffer(autoreduce_assign_dest), "<<=", | |
buffer(autoreduce_assign_source)), | |
np_reduce_sum(autoreduce_assign_source, (2, 4), True)) | |
expect_result(reduce_sum(buffer(autoreduce_assign_source), (2, 4), False), | |
np_reduce_sum(autoreduce_assign_source, (2, 4), False)) | |
# create a storage location with data: | |
a = buffer(m3x3) | |
# now subtract from that location | |
c = assign(a, "-=", buffer(m3x3)) | |
# look at the data before evaluation: | |
expect_result(a, m3x3) | |
# calling the assignment changes a's value: | |
c.eval() | |
expect_result(a, m3x3 - m3x3) | |
## Canonicalization: | |
buff = buffer(array) | |
ops = [(identity(identity(identity(buff))), buff), | |
(identity(identity(identity(transpose(buff)))), | |
assign(buffer(z4x3), "=", Array(JITRunner(identity(buff.T), [buff.T])))), | |
(assign(buffer(m3x3), "*=", dot(buffer(m3x3), buffer(m3x3))), | |
Array(Assignment(buffer(m3x3), "*=", Array(Assignment(buffer(z3x3), "=", dot(buffer(m3x3), buffer(m3x3))))))), | |
(assign(buffer(m3x3), "+=", dot(buffer(m3x3), buffer(m3x3))), | |
Array(Assignment(buffer(m3x3), "+=", dot(buffer(m3x3), buffer(m3x3)))))] | |
for op, proposed in ops: | |
assert(op.canonical() == proposed), (str(op.canonical()), str(proposed)) | |
if __name__ == "__main__": | |
main() |