Skip to content

Instantly share code, notes, and snippets.

@comaniac
Created February 3, 2021 01:40
Show Gist options
  • Save comaniac/720b120425d41208966f8a3d03140da6 to your computer and use it in GitHub Desktop.
Save comaniac/720b120425d41208966f8a3d03140da6 to your computer and use it in GitHub Desktop.
import numpy as np
import tvm
from tvm import auto_scheduler, te, topi
from tvm.te import schedule
# The last layer in resnet
H, W, CO, CI, KH, KW, strides, padding = 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
def conv2d(N, H, W, CO, CI, KH, KW, stride, padding):
data = te.placeholder((N, CI, H, W), name="data")
kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
out = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
return [data, kernel, out]
data, kernel, out = conv2d(1, H, W, CO, CI, KH, KW, strides, padding)
def schedule_0(data, kernel, out):
s = te.create_schedule([out.op])
compute = out
pad_temp, _ = out.op.input_tensors
pad_temp_i0, pad_temp_i1, pad_temp_i2, pad_temp_i3 = tuple(pad_temp.op.axis) + tuple(pad_temp.op.reduce_axis)
compute_nn, compute_ff, compute_yy, compute_xx, compute_rc, compute_ry, compute_rx = tuple(compute.op.axis) + tuple(compute.op.reduce_axis)
compute_local, = s.cache_write([compute], "local")
compute_local_nn_c, compute_local_ff_c, compute_local_yy_c, compute_local_xx_c, compute_local_rc, compute_local_ry, compute_local_rx = tuple(compute_local.op.axis) + tuple(compute_local.op.reduce_axis)
compute_local_nn_c_o_i, compute_local_nn_c_i = s[compute_local].split(compute_local_nn_c, factor=1)
compute_local_nn_c_o_o_i, compute_local_nn_c_o_i = s[compute_local].split(compute_local_nn_c_o_i, factor=1)
compute_local_nn_c_o_o_o_i, compute_local_nn_c_o_o_i = s[compute_local].split(compute_local_nn_c_o_o_i, factor=1)
compute_local_nn_c_o_o_o_o, compute_local_nn_c_o_o_o_i = s[compute_local].split(compute_local_nn_c_o_o_o_i, factor=1)
compute_local_ff_c_o_i, compute_local_ff_c_i = s[compute_local].split(compute_local_ff_c, factor=1)
compute_local_ff_c_o_o_i, compute_local_ff_c_o_i = s[compute_local].split(compute_local_ff_c_o_i, factor=1)
compute_local_ff_c_o_o_o_i, compute_local_ff_c_o_o_i = s[compute_local].split(compute_local_ff_c_o_o_i, factor=16)
compute_local_ff_c_o_o_o_o, compute_local_ff_c_o_o_o_i = s[compute_local].split(compute_local_ff_c_o_o_o_i, factor=2)
compute_local_yy_c_o_i, compute_local_yy_c_i = s[compute_local].split(compute_local_yy_c, factor=7)
compute_local_yy_c_o_o_i, compute_local_yy_c_o_i = s[compute_local].split(compute_local_yy_c_o_i, factor=1)
compute_local_yy_c_o_o_o_i, compute_local_yy_c_o_o_i = s[compute_local].split(compute_local_yy_c_o_o_i, factor=1)
compute_local_yy_c_o_o_o_o, compute_local_yy_c_o_o_o_i = s[compute_local].split(compute_local_yy_c_o_o_o_i, factor=1)
compute_local_xx_c_o_i, compute_local_xx_c_i = s[compute_local].split(compute_local_xx_c, factor=1)
compute_local_xx_c_o_o_i, compute_local_xx_c_o_i = s[compute_local].split(compute_local_xx_c_o_i, factor=1)
compute_local_xx_c_o_o_o_i, compute_local_xx_c_o_o_i = s[compute_local].split(compute_local_xx_c_o_o_i, factor=7)
compute_local_xx_c_o_o_o_o, compute_local_xx_c_o_o_o_i = s[compute_local].split(compute_local_xx_c_o_o_o_i, factor=1)
compute_local_rc_o_i, compute_local_rc_i = s[compute_local].split(compute_local_rc, factor=8)
compute_local_rc_o_o, compute_local_rc_o_i = s[compute_local].split(compute_local_rc_o_i, factor=4)
compute_local_ry_o_i, compute_local_ry_i = s[compute_local].split(compute_local_ry, factor=3)
compute_local_ry_o_o, compute_local_ry_o_i = s[compute_local].split(compute_local_ry_o_i, factor=1)
compute_local_rx_o_i, compute_local_rx_i = s[compute_local].split(compute_local_rx, factor=3)
compute_local_rx_o_o, compute_local_rx_o_i = s[compute_local].split(compute_local_rx_o_i, factor=1)
s[compute_local].reorder(compute_local_nn_c_o_o_o_o, compute_local_ff_c_o_o_o_o, compute_local_yy_c_o_o_o_o, compute_local_xx_c_o_o_o_o, compute_local_nn_c_o_o_o_i, compute_local_ff_c_o_o_o_i, compute_local_yy_c_o_o_o_i, compute_local_xx_c_o_o_o_i, compute_local_nn_c_o_o_i, compute_local_ff_c_o_o_i, compute_local_yy_c_o_o_i, compute_local_xx_c_o_o_i, compute_local_rc_o_o, compute_local_ry_o_o, compute_local_rx_o_o, compute_local_rc_o_i, compute_local_ry_o_i, compute_local_rx_o_i, compute_local_nn_c_o_i, compute_local_ff_c_o_i, compute_local_yy_c_o_i, compute_local_xx_c_o_i, compute_local_rc_i, compute_local_ry_i, compute_local_rx_i, compute_local_nn_c_i, compute_local_ff_c_i, compute_local_yy_c_i, compute_local_xx_c_i)
compute_nn_o_i, compute_nn_i = s[compute].split(compute_nn, factor=1)
compute_nn_o_o_i, compute_nn_o_i = s[compute].split(compute_nn_o_i, factor=1)
compute_nn_o_o_o, compute_nn_o_o_i = s[compute].split(compute_nn_o_o_i, factor=1)
compute_ff_o_i, compute_ff_i = s[compute].split(compute_ff, factor=1)
compute_ff_o_o_i, compute_ff_o_i = s[compute].split(compute_ff_o_i, factor=16)
compute_ff_o_o_o, compute_ff_o_o_i = s[compute].split(compute_ff_o_o_i, factor=2)
compute_yy_o_i, compute_yy_i = s[compute].split(compute_yy, factor=7)
compute_yy_o_o_i, compute_yy_o_i = s[compute].split(compute_yy_o_i, factor=1)
compute_yy_o_o_o, compute_yy_o_o_i = s[compute].split(compute_yy_o_o_i, factor=1)
compute_xx_o_i, compute_xx_i = s[compute].split(compute_xx, factor=1)
compute_xx_o_o_i, compute_xx_o_i = s[compute].split(compute_xx_o_i, factor=7)
compute_xx_o_o_o, compute_xx_o_o_i = s[compute].split(compute_xx_o_o_i, factor=1)
s[compute].reorder(compute_nn_o_o_o, compute_ff_o_o_o, compute_yy_o_o_o, compute_xx_o_o_o, compute_nn_o_o_i, compute_ff_o_o_i, compute_yy_o_o_i, compute_xx_o_o_i, compute_nn_o_i, compute_ff_o_i, compute_yy_o_i, compute_xx_o_i, compute_nn_i, compute_ff_i, compute_yy_i, compute_xx_i)
s[compute_local].compute_at(s[compute], compute_xx_o_i)
kernel_shared = s.cache_read(kernel, "shared", [compute_local])
kernel_shared_ax0, kernel_shared_ax1, kernel_shared_ax2, kernel_shared_ax3 = tuple(kernel_shared.op.axis)
s[kernel_shared].compute_at(s[compute_local], compute_local_rx_o_o)
pad_temp_shared = s.cache_read(pad_temp, "shared", [compute_local])
pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3 = tuple(pad_temp_shared.op.axis)
s[pad_temp_shared].compute_at(s[compute_local], compute_local_rx_o_o)
s[pad_temp].compute_inline()
compute_nn_o_o_o_ff_o_o_o_fused_yy_o_o_o_fused_xx_o_o_o_fused = s[compute].fuse(compute_nn_o_o_o, compute_ff_o_o_o, compute_yy_o_o_o, compute_xx_o_o_o)
s[compute].bind(compute_nn_o_o_o_ff_o_o_o_fused_yy_o_o_o_fused_xx_o_o_o_fused, te.thread_axis("blockIdx.x"))
compute_nn_o_o_i_ff_o_o_i_fused_yy_o_o_i_fused_xx_o_o_i_fused = s[compute].fuse(compute_nn_o_o_i, compute_ff_o_o_i, compute_yy_o_o_i, compute_xx_o_o_i)
s[compute].bind(compute_nn_o_o_i_ff_o_o_i_fused_yy_o_o_i_fused_xx_o_o_i_fused, te.thread_axis("vthread"))
compute_nn_o_i_ff_o_i_fused_yy_o_i_fused_xx_o_i_fused = s[compute].fuse(compute_nn_o_i, compute_ff_o_i, compute_yy_o_i, compute_xx_o_i)
s[compute].bind(compute_nn_o_i_ff_o_i_fused_yy_o_i_fused_xx_o_i_fused, te.thread_axis("threadIdx.x"))
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[kernel_shared].fuse(kernel_shared_ax0, kernel_shared_ax1, kernel_shared_ax2, kernel_shared_ax3)
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[kernel_shared].split(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=3)
s[kernel_shared].vectorize(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[kernel_shared].split(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=112)
s[kernel_shared].bind(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[pad_temp_shared].fuse(pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3)
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=4)
s[pad_temp_shared].vectorize(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=112)
s[pad_temp_shared].bind(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
s[compute_local].pragma(compute_local_nn_c_o_o_o_o, "auto_unroll_max_step", 512)
s[compute_local].pragma(compute_local_nn_c_o_o_o_o, "unroll_explicit", True)
return s
target = 'cuda -model=t4'
ctx = tvm.gpu(0)
s = schedule_0(data, kernel, out)
#print(tvm.lower(s, [data, kernel, out], simple_mode=True))
func = tvm.build(s, [data, kernel, out], target)
#print(func.imported_modules[0].get_source())
# TODO: Correctness checking
data_np = np.random.uniform(size=[v.value for v in data.shape]).astype(np.float32)
weight_np = np.random.uniform(size=[v.value for v in kernel.shape]).astype(np.float32)
# Evaluate execution time
data_args = []
data_args.append(tvm.nd.array(data_np, ctx=ctx))
data_args.append(tvm.nd.array(weight_np, ctx=ctx))
data_args.append(tvm.nd.empty([v.value for v in out.shape], ctx=ctx))
evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=1000)
print(
"Median execution time: %.3f ms"
% (np.median(evaluator(*data_args).results) * 1000)
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment