Created
August 28, 2019 22:07
-
-
Save yzhliu/5439777c6b2d8fd3f5aeef43ee2048e5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tvm | |
import topi | |
from topi.util import get_const_tuple | |
import numpy as np | |
from topi.nn.pad import pad | |
# Remote board setup ("a2"): start the TVM RPC server there with
#     python3 -m tvm.exec.rpc_server --port=8499
#
# Other targets that were tried, kept for reference:
#     'llvm -mcpu=core-avx2'
#     'llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.4a,+fp16fml,+fullfp16'
#
# Active compilation target: AArch64, ARMv8.2-A with full fp16 arithmetic and NEON.
target = 'llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+fullfp16,+fp-armv8,+dotprod,+crc,+crypto,+neon'

# Element type of the input/weight placeholders; conv_compute also accumulates in data.dtype.
dtype = 'float16'
def get_fp32_len():
    """Return the SIMD lane count used as the upper bound for channel blocking.

    Always 8. NOTE(review): despite the name, with the float16 dtype used in this
    script 8 lanes presumably corresponds to one 128-bit NEON register of fp16
    values (fp32 would give 4); the name looks inherited from TOPI's x86
    schedule — confirm before reusing with another dtype/target.
    """
    lanes = 8
    return lanes
def _fallback_schedule(in_channel, height, width, num_filter, filter_height, filter_width, padding, strides): | |
WPAD, HPAD = padding | |
WSTR, HSTR = strides | |
simd_width = get_fp32_len() | |
out_width = (width + 2 * WPAD - filter_width) // WSTR + 1 | |
oc_bn = 1 | |
for bn in range(simd_width, 0, -1): | |
if num_filter % bn == 0: | |
oc_bn = bn | |
break | |
ic_bn = 1 | |
for bn in range(oc_bn, 0, -1): | |
if in_channel % bn == 0: | |
ic_bn = bn | |
break | |
reg_n = 1 | |
for n in range(31, 0, -1): | |
if out_width % n == 0: | |
reg_n = n | |
break | |
return ic_bn, oc_bn, reg_n, False | |
def conv_compute(data, kernel, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides):
    """Declare a direct NCHW conv2d that is computed in a channel-blocked layout.

    Pipeline:
      1. zero-pad ``data`` (only when some padding is non-zero),
      2. repack data to (N, C // ic_bn, H_pad, ic_bn, W_pad)            -> 'data_vec',
      3. repack kernel to (OC // oc_bn, C // ic_bn, KH, KW, ic_bn, oc_bn) -> 'kernel_vec',
      4. reduce over (ic, kh, kw) into (N, OC // oc_bn, OH, OW, oc_bn)     -> 'conv',
      5. unpack back to (N, OC, OH, OW)                                   -> 'output_unpack'.
    Accumulation is done in ``data.dtype``; dilation is fixed to 1.

    Args:
        data: input tensor of shape (N, C, H, W).
        kernel: weight tensor of shape (OC, C, KH, KW).
        in_channel, num_filter: overridden by the static shapes of data/kernel.
        height, width, filter_height, filter_width: forwarded to _fallback_schedule.
        padding: (pad_h, pad_w) per-side zero padding.
        strides: (stride_h, stride_w).

    Returns:
        The NCHW output tensor (name 'output_unpack', tag 'conv2d_nchw').
    """
    out_dtype = data.dtype
    dilation_h = dilation_w = 1
    HPAD, WPAD = padding
    HSTR, WSTR = strides

    # Static shapes; these intentionally shadow the in_channel / num_filter arguments.
    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
    num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD
    kh_extent = (kernel_height - 1) * dilation_h + 1
    kw_extent = (kernel_width - 1) * dilation_w + 1
    out_height = (pad_height - kh_extent) // HSTR + 1
    out_width = (pad_width - kw_extent) // WSTR + 1

    needs_pad = (HPAD != 0 or WPAD != 0)
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") if needs_pad else data

    # Blocking factors shared with conv_schedule (reg_n / unroll_kw are only used there).
    ic_bn, oc_bn, _, _ = _fallback_schedule(in_channel, height, width, num_filter,
                                            filter_height, filter_width, padding, strides)

    # NOTE: tvm.compute names its iteration variables after these parameter names.
    def pack_data(n, C, h, c, w):
        return data_pad[n, C * ic_bn + c, h, w]

    def pack_kernel(CO, CI, h, w, ci, co):
        return kernel[CO * oc_bn + co, CI * ic_bn + ci, h, w]

    data_vec = tvm.compute((batch_size, in_channel // ic_bn, pad_height, ic_bn, pad_width),
                           pack_data, name='data_vec')
    kernel_vec = tvm.compute((num_filter // oc_bn, in_channel // ic_bn,
                              kernel_height, kernel_width, ic_bn, oc_bn),
                             pack_kernel, name='kernel_vec')

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

    def conv_body(n, oc_chunk, oh, ow, oc_block):
        # data_vec is read before kernel_vec so conv's input_tensors order is
        # [data_vec, kernel_vec], which conv_schedule relies on.
        x = data_vec[n, ic//ic_bn, oh*HSTR+kh*dilation_h, ic%ic_bn,
                     ow*WSTR+kw*dilation_w]
        wgt = kernel_vec[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn,
                         oc_block]
        return tvm.sum(x.astype(out_dtype) * wgt.astype(out_dtype), axis=[ic, kh, kw])

    conv = tvm.compute((batch_size, num_filter//oc_bn, out_height, out_width, oc_bn),
                       conv_body, name='conv')

    def unpack_body(n, c, h, w):
        return conv[n, c // oc_bn, h, w, c % oc_bn].astype(out_dtype)

    return tvm.compute((batch_size, num_filter, out_height, out_width),
                       unpack_body,
                       name='output_unpack',
                       tag='conv2d_nchw')
def conv_schedule(C, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides):
    """Create a CPU schedule for the graph built by conv_compute.

    Walks back from the unpacked output ``C`` to locate the packed conv stage,
    the kernel-pack stage, the data-pack stage and (if present) the padding
    stage. It then inlines the padding, parallelizes both packing stages, and
    tiles/vectorizes the convolution with the blocking factors returned by
    _fallback_schedule.

    Args:
        C: tensor returned by conv_compute. Assumed (not checked) to be the
            unpack stage whose first input is the packed conv.
        in_channel, height, width, num_filter, filter_height, filter_width,
        padding, strides: the same workload arguments given to conv_compute.
            They are forwarded to _fallback_schedule so compute and schedule
            agree on ic_bn / oc_bn / reg_n.

    Returns:
        The tvm schedule ``s``.

    NOTE(review): the unpack stage (``output`` / ``O0``) receives no scheduling.
    The fused (oc_chunk, oh) axes of the conv stage are never marked parallel,
    so only the two packing stages run multi-threaded. TOPI's x86 conv2d
    schedule, which this appears to be adapted from, also splits, vectorizes and
    parallelizes the output stage. Confirm whether that was dropped on purpose.
    """
    s = tvm.create_schedule(C.op)
    op = C.op
    output = op.output(0)
    # Packed conv stage: (N, OC//oc_bn, OH, OW, oc_bn); its inputs are [data_vec, kernel_vec].
    conv_out = op.input_tensors[0]
    kernel_vec = conv_out.op.input_tensors[1]
    kernel = kernel_vec.op.input_tensors[0]
    # With conv_compute as written, kernel is the raw input tensor, so this branch
    # is not expected to trigger; it only handles a dilated-kernel compute stage.
    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
        s[kernel].compute_inline()
    data_vec = conv_out.op.input_tensors[0]
    data = data_vec.op.input_tensors[0]
    data_pad = None
    # topi's pad stage is tagged with "pad"; if present, step past it to the real input.
    if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
        data_pad = data
        data = data_pad.op.input_tensors[0]
    # NOTE(review): kh/kw here are unused; both names are rebound to reduce axes below.
    _, _, kh, kw = get_const_tuple(kernel.shape)
    # fetch schedule
    ic_bn, oc_bn, reg_n, unroll_kw = _fallback_schedule(in_channel, height, width, num_filter,
                                                        filter_height, filter_width, padding, strides)
    # The tensors do not carry padding info, so recompute DOPAD from the argument
    # exactly as conv_compute does.
    HPAD, WPAD = padding
    DOPAD = (HPAD != 0 or WPAD != 0)
    A, W = data, kernel_vec  # A is not used below
    A0, A1 = data_pad, data_vec
    # schedule data: inline zero-padding into the packing stage, then
    # parallelize over the fused (ic_chunk, ih) axes.
    if DOPAD:
        s[A0].compute_inline()
    batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
    parallel_axis = s[A1].fuse(ic_chunk, ih)
    s[A1].parallel(parallel_axis)
    # schedule kernel pack (oh/ow are the kernel's h/w axes here): move the kernel
    # row next to oc_chunk, vectorize the innermost oc_block, and parallelize
    # the fused (oc_chunk, oh).
    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis
    s[W].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
    if oc_bn > 1:
        s[W].vectorize(oc_block)
    parallel_axis = s[W].fuse(oc_chunk, oh)
    s[W].parallel(parallel_axis)
    # schedule conv
    # C is rebound here to the packed conv stage (shadowing the parameter); O0 is unused.
    C, O0 = conv_out, output
    # Accumulate in a cache-write stage CC; C then only stores reg_n x oc_bn tiles.
    CC = s.cache_write(C, 'global')
    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
    ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    # The fused axis is discarded and not parallelized (see NOTE in the docstring).
    s[C].fuse(oc_chunk, oh)
    s[C].vectorize(oc_block)
    s[CC].compute_at(s[C], ow_chunk)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, kh, kw = s[CC].op.reduce_axis
    ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
    # Split the reduction over input channels to match the ic_bn packing.
    ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn)
    # _fallback_schedule always returns unroll_kw=False, so the else branch is taken.
    if unroll_kw:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw, ow_block, oc_block)
        s[CC].unroll(kw)
    else:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, kw, ic_block, ow_block, oc_block)
    s[CC].fuse(oc_chunk, oh)
    # Innermost register tile: ow_block (unrolled) x oc_block (vectorized).
    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_block)
    return s
def run_conv2d(batch_size, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides,
               host='0.0.0.0', port=8499):
    """Compile the packed conv2d for one workload, run it over RPC and print its latency.

    Steps:
      1. build the compute and schedule and print the lowered IR;
      2. compile for the module-level ``target``;
      3. save the generated assembly to ``conv.s`` in the working directory;
      4. upload the library to the TVM RPC server at ``host:port``;
      5. time it with ``time_evaluator`` (number=50) and print the mean in ms.

    Args:
        batch_size, in_channel, height, width: input shape (NCHW).
        num_filter, filter_height, filter_width: kernel shape (OIHW, I == in_channel).
        padding: (pad_h, pad_w) per-side zero padding.
        strides: (stride_h, stride_w).
        host, port: address of the RPC server; defaults keep the previously
            hard-coded values.
    """
    A = tvm.placeholder((batch_size, in_channel, height, width), name='A', dtype=dtype)
    W = tvm.placeholder((num_filter, in_channel, filter_height, filter_width), name='W', dtype=dtype)
    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)

    def get_ref_data():
        """Random inputs plus a NumPy reference conv2d for the same workload."""
        a_np = np.random.uniform(size=a_shape).astype(dtype)
        w_np = np.random.uniform(size=w_shape).astype(dtype)
        from topi.testing.conv2d_nchw_python import conv2d_nchw_python
        # Use this workload's stride/padding; hard-coding (1,1)/(1,1) made the
        # reference wrong for every other workload.
        conv_np = conv2d_nchw_python(a_np, w_np, stride=strides, padding=padding)
        return a_np, w_np, conv_np

    a_np, w_np, conv_np = get_ref_data()
    C = conv_compute(A, W, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides)
    s = conv_schedule(C, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides)
    print(tvm.lower(s, [A, W, C], simple_mode=True))

    from tvm import rpc
    from tvm.contrib import util
    remote = rpc.connect(host, port)
    ctx = remote.cpu()

    func = tvm.build(s, [A, W, C], target=target)
    func.save('conv.s')
    temp = util.tempdir()
    path = temp.relpath('lib.tar')
    func.export_library(path)
    remote.upload(path)
    func = remote.load_module('lib.tar')

    conv = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), ctx)
    time_f = func.time_evaluator(func.entry_name, ctx, number=50)
    cost_conv = time_f(tvm.nd.array(a_np, ctx), tvm.nd.array(w_np, ctx), conv).mean
    print('conv: %g ms/op' % (cost_conv * 1000.0))
    # Correctness check is disabled. NOTE(review): rtol=1e-5 is likely far tighter
    # than float16 results can meet; loosen the tolerance before re-enabling.
    # np.testing.assert_allclose(conv.asnumpy(), conv_np, rtol=1e-5)
if __name__ == "__main__":
    # Benchmark workload: 1x64x56x56 input, 64 output channels, 3x3 kernel,
    # padding 1 and stride 1 in both dimensions.
    workload = dict(
        batch_size=1, in_channel=64, height=56, width=56,
        num_filter=64, filter_height=3, filter_width=3,
        padding=(1, 1), strides=(1, 1),
    )
    run_conv2d(**workload)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment