@BeMg
Created February 3, 2020 12:20
AAAAA
# Imports assumed for a TVM 0.6-era environment; this code mirrors topi/x86/conv2d.py.
import tvm
from topi import tag
from topi.util import get_const_tuple
from topi.x86 import conv2d_avx_1x1, conv2d_avx_common


def schedule_conv2d_2(_, outs, target):
    """Create schedule for tensors"""
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []

    def traverse(op):
        """Traverse operators from computation graph"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if 'broadcast' in op.tag:
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                    traverse(tensor.op)
        #print(op.tag)
        if 'conv2d_nchw' in op.tag:
            output = op.output(0)
            conv_out = op.input_tensors[0]
            kernel_vec = conv_out.op.input_tensors[1]
            kernel = kernel_vec.op.input_tensors[0]
            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
                s[kernel].compute_inline()
            data_vec = conv_out.op.input_tensors[0]
            data = data_vec.op.input_tensors[0]
            data_pad = None
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]

            C = conv_out
            n, cc, h, w, cb = C.op.axis
            print("C shape: {} {} {} {} {}".format(
                C.shape[0].value, C.shape[1].value, C.shape[2].value,
                C.shape[3].value, C.shape[4].value))
            rc, ry, rx = C.op.reduce_axis
            s[C].reorder(n, h, w, cc, cb)
            c = s[C].fuse(h, w)
            fused = s[C].fuse(n, c)
            # split the inner channel-block axis by the widest vector width that divides it
            if C.shape[4].value % 8 == 0:
                cbo, cbi = s[C].split(cb, factor=8)
            elif C.shape[4].value % 4 == 0:
                cbo, cbi = s[C].split(cb, factor=4)
            else:
                cbo, cbi = s[C].split(cb, factor=1)
            s[C].reorder(fused, rc, cc, cbo, ry, rx, cbi)  # move rc to outer loop
            s[C].unroll(rx)
            s[C].unroll(ry)
            s[C].vectorize(cbi)

        scheduled_ops.append(op)

    traverse(outs[0].op)
    return s
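
# --- Usage sketch (assumption, not part of the original gist) ---
# schedule_conv2d_2 expects the unpacked NCHWc conv2d graph produced by topi's
# x86 conv2d declaration, so one hypothetical way to drive it on a TVM 0.6-era
# install is to build that compute under an llvm target first. The shapes and
# target string below are made up for illustration.
def _example_schedule_conv2d_2():
    import topi
    data = tvm.placeholder((1, 64, 56, 56), name='data')
    kernel = tvm.placeholder((64, 64, 3, 3), name='kernel')
    with tvm.target.create('llvm'):
        conv = topi.nn.conv2d(data, kernel, strides=1, padding=1,
                              dilation=1, layout='NCHW', out_dtype='float32')
        s = schedule_conv2d_2(None, [conv], 'llvm')
        # inspect the lowered loop nest produced by the manual schedule
        print(tvm.lower(s, [data, kernel, conv], simple_mode=True))
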
def schedule_conv2d(cfg, outs):
    """Create schedule for tensors"""
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []

    def traverse(op):
        """Traverse operators from computation graph"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                    traverse(tensor.op)

        if 'conv2d_nchw' in op.tag:
            output = op.output(0)
            conv_out = op.input_tensors[0]
            kernel_vec = conv_out.op.input_tensors[1]
            kernel = kernel_vec.op.input_tensors[0]
            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
                s[kernel].compute_inline()
            data_vec = conv_out.op.input_tensors[0]
            data = data_vec.op.input_tensors[0]
            data_pad = None
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]

            _, _, kh, kw = get_const_tuple(kernel.shape)
            is_kernel_1x1 = kh == 1 and kw == 1
            args = [s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, outs[0]]
            print("IN X86")
            if is_kernel_1x1:
                conv2d_avx_1x1._schedule_conv(*args)
            else:
                conv2d_avx_common._schedule_conv(*args)

        scheduled_ops.append(op)

    traverse(outs[0].op)
    return s
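
# --- Dispatch sketch (assumption, not part of the original gist) ---
# schedule_conv2d mirrors topi's x86 schedule, which is normally reached through
# the generic dispatcher rather than called directly (the autotvm wrapper supplies
# cfg). A rough equivalent on a TVM 0.6-era install, with made-up shapes:
def _example_generic_dispatch():
    import topi
    data = tvm.placeholder((1, 64, 56, 56), name='data')
    kernel = tvm.placeholder((64, 64, 1, 1), name='kernel')
    with tvm.target.create('llvm'):
        conv = topi.nn.conv2d(data, kernel, strides=1, padding=0,
                              dilation=1, layout='NCHW', out_dtype='float32')
        # topi.generic.schedule_conv2d_nchw resolves to the x86 schedule under an
        # llvm target (and hence to conv2d_avx_1x1 here, since the kernel is 1x1).
        s = topi.generic.schedule_conv2d_nchw([conv])
        func = tvm.build(s, [data, kernel, conv], target='llvm')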