Last active
May 21, 2020 22:14
-
-
Save comaniac/1f399dfdfee05a2f7a087c65c21f550c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tvm | |
from tvm import relay | |
from tvm.relay.dataflow_pattern import * | |
#class ExtractArgs(tvm.relay.ExprMutator): | |
# # Mutate the graph to replace the matched graph input with new vars. | |
# def __init__(self, mapped_vars): | |
# super(ExtractArgs, self).__init__() | |
# self.mapped_vars = set(mapped_vars) | |
# self.vars = [] | |
# self.var_cnt = 0 | |
# | |
# def visit(self, node): | |
# post_node = super().visit(node) | |
# if node in self.mapped_vars: | |
# new_var = relay.Var('in{}'.format(self.var_cnt)) | |
# self.var_cnt += 1 | |
# self.vars.append(new_var) | |
# return new_var | |
# return post_node | |
#class PatternCallback(DFPatternCallback): | |
# # Replace matched graph to a function call and create a composite function. | |
# def __init__(self, pat_vars, pattern): | |
# self.pat_vars = pat_vars | |
# self.pattern = pattern | |
# | |
# def callback(self, pre, post, node_map): | |
# extractor = ExtractArgs([node_map[v][0] for v in self.pat_vars]) | |
# new_post = extractor.visit(post) | |
# | |
# func = relay.Function(extractor.vars, new_post)
# func = func.with_attr('Composite', 'my-inst')
# | |
# return relay.Call(func, [node_map[v][0] for v in self.pat_vars]) | |
# Relay graphs to be matched and partitioned. | |
def graph_add():
    """Build a small Relay graph made of four ``add`` ops.

    The graph computes ``(x + (x + (y + y))) + (x + (y + y))``. The
    ``x + (y + y)`` node is shared: it feeds both the third add and the root.

    Returns:
        The root ``add`` expression.
    """
    x, y = relay.var('x'), relay.var('y')
    doubled_y = relay.add(y, y)
    shared = relay.add(x, doubled_y)
    # The inner add is built before the outer one, matching the original order.
    return relay.add(relay.add(x, shared), shared)
def graph_conv2d():
    """Build a Relay graph of two chained NHWC conv2d layers.

    The first conv2d is followed by bias_add + relu; the second only by
    bias_add. The graph therefore contains one conv2d+bias+relu chain and one
    conv2d+bias chain.

    Returns:
        The final ``nn.bias_add`` expression of the graph.
    """
    # NHWC input: batch 1, 10x10 spatial, 10 channels.
    x = relay.var('x', shape=(1, 10, 10, 10))
    # OIHW kernels: 10 output channels, 10 input channels, 3x3 window.
    # Each parameter gets a distinct name so the printed graph is unambiguous.
    w1 = relay.var('w1', shape=(10, 10, 3, 3))
    w2 = relay.var('w2', shape=(10, 10, 3, 3))
    # One bias value per output channel (10 channels).
    b1 = relay.var('b1', shape=(10,))
    b2 = relay.var('b2', shape=(10,))

    conv_attrs = dict(kernel_size=(3, 3),
                      kernel_layout="OIHW",
                      data_layout="NHWC")
    # bias_add defaults to axis=1, which is H in NHWC; the channel axis is 3.
    channel_axis = 3

    conv = relay.nn.conv2d(x, w1, **conv_attrs)
    bias = relay.nn.bias_add(conv, b1, axis=channel_axis)
    relu = relay.nn.relu(bias)
    conv = relay.nn.conv2d(relu, w2, **conv_attrs)
    bias = relay.nn.bias_add(conv, b2, axis=channel_axis)
    return bias
# Every Relay graph the demo partitions, built in the order they are printed.
graphs = [make_graph() for make_graph in (graph_add, graph_conv2d)]
# Patterns. | |
def pattern_add():
    """Return a pattern matching three chained ``add`` ops.

    Structure: ``first = a + b``, ``second = first + a``, and the root is
    ``first + second``. ``a`` and ``b`` are wildcards, and ``a`` appears in
    two places.
    """
    lhs, rhs = wildcard(), wildcard()
    first = is_op('add')(lhs, rhs)
    second = is_op('add')(first, lhs)
    return is_op('add')(first, second)
def pattern_conv2d_bias_relu():
    """Return a conv2d -> bias_add -> relu pattern.

    The conv2d must carry ``data_layout="NHWC"``. Its data and weight
    operands, and the bias operand of bias_add, are unconstrained wildcards.
    """
    data, weight, bias_val = wildcard(), wildcard(), wildcard()
    conv = is_op('nn.conv2d')(data, weight).has_attr({'data_layout': 'NHWC'})
    return is_op('nn.relu')(is_op('nn.bias_add')(conv, bias_val))
def pattern_conv2d_bias():
    """Return a conv2d -> bias_add pattern (conv2d+bias+relu without the relu).

    The conv2d must carry ``data_layout="NHWC"``. All operands are
    unconstrained wildcards.
    """
    data, weight, bias_val = wildcard(), wildcard(), wildcard()
    conv = is_op('nn.conv2d')(data, weight).has_attr({'data_layout': 'NHWC'})
    return is_op('nn.bias_add')(conv, bias_val)
# (Composite label, pattern) pairs, applied in list order. The
# conv2d+bias+relu pattern is listed before its conv2d+bias sub-pattern, so
# the larger chain is partitioned before the smaller pattern gets a chance.
pattern_list = [
    ('inst-add', pattern_add()),
    ('inst-conv2d_bias_relu', pattern_conv2d_bias_relu()),
    ('inst-conv2d_bias', pattern_conv2d_bias()),
]
for before in graphs:
    # Three-line banner followed by the unpartitioned graph.
    print('===================', '=== Relay Graph ===', '===================',
          sep='\n')
    print(before)
    # Each pattern is applied to the previous pattern's output, so earlier
    # partitions carry through to later results.
    after = before
    for label, pattern in pattern_list:
        after = pattern.partition(after, {'Composite': label})
        print('=== After %s ===' % label)
        print(after)
# Output logs | |
#=================== | |
#=== Relay Graph === | |
#=================== | |
#free_var %x | |
#free_var %y | |
#%0 = add(%y, %y); | |
#%1 = add(%x, %0); | |
#%2 = add(%x, %1); | |
#add(%2, %1) | |
#=== After inst-add === | |
#free_var %x | |
#free_var %y | |
#%0 = add(%y, %y); | |
#%3 = fn (%FunctionVar_0_0, %FunctionVar_0_1, Composite="inst-add", PartitionedFromPattern="add_add_add_") { | |
# %1 = add(%FunctionVar_0_0, %FunctionVar_0_1); | |
# %2 = add(%FunctionVar_0_0, %1); | |
# add(%2, %1) | |
#}; | |
#%3(%x, %0) | |
#=== After inst-conv2d_bias_relu === | |
#free_var %x | |
#free_var %y | |
#%0 = add(%y, %y); | |
#%3 = fn (%FunctionVar_0_0, %FunctionVar_0_1, Composite="inst-add", PartitionedFromPattern="add_add_add_") { | |
# %1 = add(%FunctionVar_0_0, %FunctionVar_0_1); | |
# %2 = add(%FunctionVar_0_0, %1); | |
# add(%2, %1) | |
#}; | |
#%3(%x, %0) | |
#=== After inst-conv2d_bias === | |
#free_var %x | |
#free_var %y | |
#%0 = add(%y, %y); | |
#%3 = fn (%FunctionVar_0_0, %FunctionVar_0_1, Composite="inst-add", PartitionedFromPattern="add_add_add_") { | |
# %1 = add(%FunctionVar_0_0, %FunctionVar_0_1); | |
# %2 = add(%FunctionVar_0_0, %1); | |
# add(%2, %1) | |
#}; | |
#%3(%x, %0) | |
#=================== | |
#=== Relay Graph === | |
#=================== | |
#free_var %x: Tensor[(1, 10, 10, 10), float32] | |
#free_var %w: Tensor[(10, 10, 3, 3), float32] | |
#%0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC"); | |
#free_var %b: Tensor[(8), float32] | |
#%1 = nn.bias_add(%0, %b); | |
#%2 = nn.relu(%1); | |
#free_var %w1: Tensor[(10, 10, 3, 3), float32] | |
#%3 = nn.conv2d(%2, %w1, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC"); | |
#free_var %b1: Tensor[(8), float32] | |
#nn.bias_add(%3, %b1) | |
#=== After inst-add === | |
#free_var %x: Tensor[(1, 10, 10, 10), float32] | |
#free_var %w: Tensor[(10, 10, 3, 3), float32] | |
#%0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC"); | |
#free_var %b: Tensor[(8), float32] | |
#%1 = nn.bias_add(%0, %b); | |
#%2 = nn.relu(%1); | |
#free_var %w1: Tensor[(10, 10, 3, 3), float32] | |
#%3 = nn.conv2d(%2, %w1, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC"); | |
#free_var %b1: Tensor[(8), float32] | |
#nn.bias_add(%3, %b1) | |
#=== After inst-conv2d_bias_relu === | |
#free_var %x: Tensor[(1, 10, 10, 10), float32] | |
#free_var %w: Tensor[(10, 10, 3, 3), float32] | |
#free_var %b: Tensor[(8), float32] | |
#%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, Composite="inst-conv2d_bias_relu", PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") { | |
# %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC"); | |
# %1 = nn.bias_add(%0, %FunctionVar_0_2); | |
# nn.relu(%1) | |
#}; | |
#%3 = %2(%x, %w, %b); | |
#free_var %w1: Tensor[(10, 10, 3, 3), float32] | |
#%4 = nn.conv2d(%3, %w1, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC"); | |
#free_var %b1: Tensor[(8), float32] | |
#nn.bias_add(%4, %b1) | |
#=== After inst-conv2d_bias === | |
#free_var %x: Tensor[(1, 10, 10, 10), float32] | |
#free_var %w: Tensor[(10, 10, 3, 3), float32] | |
#free_var %b: Tensor[(8), float32] | |
#%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, Composite="inst-conv2d_bias_relu", PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") { | |
# %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC"); | |
# %1 = nn.bias_add(%0, %FunctionVar_0_2); | |
# nn.relu(%1) | |
#}; | |
#%3 = %2(%x, %w, %b); | |
#free_var %w1: Tensor[(10, 10, 3, 3), float32] | |
#free_var %b1: Tensor[(8), float32] | |
#%5 = fn (%FunctionVar_0_01, %FunctionVar_0_11, %FunctionVar_0_21, Composite="inst-conv2d_bias", PartitionedFromPattern="nn.conv2d_nn.bias_add_") { | |
# %4 = nn.conv2d(%FunctionVar_0_01, %FunctionVar_0_11, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC"); | |
# nn.bias_add(%4, %FunctionVar_0_21) | |
#}; | |
#%5(%3, %w1, %b1) |
I actually just wanted to demonstrate overlapped patterns but this is a good example too. Will do.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Something @masahi talked about in the discuss forum that could be applied here:
You're getting duplicate information in the tags:
Composite="inst-conv2d_bias_relu", PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_")
If you change it to have just one pattern (the snippet that followed "like this:" in the original comment was not captured here), pass in a single Composite attribute such as inst-conv2d. You can then use the PartitionedFromPattern tag to tell whether the relu variant was matched.