@yzhliu
Created July 4, 2018 00:18

import tvm


def compute_conv2d(A, W, stride, padding):
    batch_size, in_channel, height, width = A.shape
    out_channel, _ = W.shape
    kh = 1
    kw = 1
    out_height = (height + 2 * padding - kh) // stride + 1
    out_width = (width + 2 * padding - kw) // stride + 1
    # Repack the input from NCHW to NHWC so the innermost axis is the channel
    # axis that the matmul intrinsic reduces over.
    A = tvm.compute((batch_size, height, width, in_channel),
                    lambda n, h, w, c: A[n, c, h, w])
    # Convolution: with a 1x1 kernel there are no spatial reduce axes, so each
    # output pixel is a dot product over the input channels.
    oshape = (batch_size, out_channel, out_height, out_width)
    ic = tvm.reduce_axis((0, in_channel), name='ic')
    conv = tvm.compute(oshape, lambda n, oc, oh, ow:
                       tvm.sum(A[n, oh * stride, ow * stride, ic] * W[oc, ic],
                               axis=[ic]),
                       name='conv2d', tag="conv2d")
    return conv
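
# Note (added; not part of the original gist): for a 1x1 kernel with stride 1
# and no padding, compute_conv2d above is just a per-pixel matmul between the
# NHWC-repacked input and the (out_channel, in_channel) weight. A minimal
# numpy reference for checking a built module could look like the sketch
# below; the helper name and the (N, C, H, W) / (OC, C) array shapes are
# assumptions.
def conv2d_1x1_nchw_ref(a_np, w_np):
    # a_np: numpy array of shape (N, C, H, W); w_np: numpy array of shape (OC, C)
    n, c, h, w = a_np.shape
    pixels = a_np.transpose(0, 2, 3, 1).reshape(n * h * w, c)  # one row per pixel
    out = pixels.dot(w_np.T)                                   # (N*H*W, OC)
    return out.reshape(n, h, w, -1).transpose(0, 3, 1, 2)      # back to NCHW
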
def matmul():
    # Declare the computation the intrinsic covers: inp (16, 128) times
    # wgt (1, 128) transposed, giving a (16, 1) output block.
    wgt = tvm.placeholder((1, 128), name="wgt")
    inp = tvm.placeholder((16, 128), name="inp")
    k = tvm.reduce_axis((0, 128), name="k")
    out = tvm.compute((16, 1),
                      lambda i, j: tvm.sum(inp[i, k] * wgt[j, k], axis=[k]))

    def intrin_func(inputs, outputs):
        def body():
            # Lower the matched region to a single extern call.
            irb = tvm.ir_builder.create()
            irb.emit(tvm.call_extern("float32", "Matmul"))
            return irb.get()

        def reset():
            return body()

        def update():
            return body()

        return body(), reset(), update()

    return tvm.decl_tensor_intrin(out.op, intrin_func, name="Matmul")
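
# Sketch (added; an assumption, not the original author's code): a fuller
# tensor intrinsic would normally bind explicit buffers and pass raw pointers
# into the matched regions to the extern call via access_ptr, so the emitted
# call actually receives its operands. The C symbol "Matmul" is still assumed
# to be provided at link time.
def matmul_with_pointers():
    wgt = tvm.placeholder((1, 128), name="wgt")
    inp = tvm.placeholder((16, 128), name="inp")
    k = tvm.reduce_axis((0, 128), name="k")
    out = tvm.compute((16, 1),
                      lambda i, j: tvm.sum(inp[i, k] * wgt[j, k], axis=[k]),
                      name="out")
    inp_buf = tvm.decl_buffer(inp.shape, inp.dtype, name="inp_buf", offset_factor=1)
    wgt_buf = tvm.decl_buffer(wgt.shape, wgt.dtype, name="wgt_buf", offset_factor=1)
    out_buf = tvm.decl_buffer(out.shape, out.dtype, name="out_buf", offset_factor=1)

    def intrin_func(ins, outs):
        # Returning a single statement declares it as the intrinsic body
        # (no separate reset/update).
        irb = tvm.ir_builder.create()
        irb.emit(tvm.call_extern("int32", "Matmul",
                                 outs[0].access_ptr("w"),
                                 ins[0].access_ptr("r"),
                                 ins[1].access_ptr("r")))
        return irb.get()

    return tvm.decl_tensor_intrin(out.op, intrin_func, name="Matmul",
                                  binds={inp: inp_buf, wgt: wgt_buf, out: out_buf})
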
def schedule_conv2d(out):
    s = tvm.create_schedule(out.op)
    conv = out.op.output(0)
    data, kernel = conv.op.input_tensors
    batch, oc, oh, ow = s[conv].op.axis
    ic, = s[conv].op.reduce_axis
    s[conv].tensorize(ow, matmul())
    return s
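
# What tensorize does here (descriptive note, added): s[conv].tensorize(ow, matmul())
# tells TVM to stop lowering at the `ow` loop and to replace the (ow, ic) loop
# nest for each (n, oc, oh) with the declared intrinsic. During lowering TVM
# structurally compares the computation in that region against the intrinsic's
# declared compute, a (16, 128) x (1, 128) -> (16, 1) matmul, and raises an
# error if the two patterns do not match.
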
def verify_conv2d_nchw(batch, in_channel, in_height, in_width, num_filter,
                       kernel, stride, padding):
    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
    W = tvm.placeholder((num_filter, in_channel), name='W')
    B = compute_conv2d(A, W, stride, padding)
    s = schedule_conv2d(B)
    s = s.normalize()
    print(tvm.lower(s, [A, W, B], simple_mode=True))


def test_conv2d_nchw():
    verify_conv2d_nchw(batch=1, in_channel=128, in_height=16, in_width=16,
                       num_filter=64, kernel=1, stride=1, padding=0)


if __name__ == "__main__":
    test_conv2d_nchw()
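
# Note (added, hypothetical): tvm.lower above only checks that the tensorize
# pattern matches and prints the resulting IR. Actually executing the schedule
# would additionally require building it, e.g. tvm.build(s, [A, W, B], "llvm")
# inside verify_conv2d_nchw, and linking a C implementation of the extern
# "Matmul" symbol that the intrinsic emits.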