Skip to content

Instantly share code, notes, and snippets.

@csullivan
Created September 2, 2021 23:22
Show Gist options
  • Save csullivan/170c9c3a3b22240d0fb347f138342801 to your computer and use it in GitHub Desktop.
Save csullivan/170c9c3a3b22240d0fb347f138342801 to your computer and use it in GitHub Desktop.
import tvm
from tvm import te
def intrin_vadd(xo, m, n):
    """Declare a tensor intrinsic describing an ``n``-element vector add.

    Two placeholders of shape ``(n,)`` named "vx" and "vy" are summed
    element-wise into a compute named "z". When ``m`` is not a multiple of
    ``n`` the sum is wrapped in a ``Select`` guarded by ``xo * n + i < m``,
    producing a zero constant wherever the guard is false.

    Parameters
    ----------
    xo
        Value used as the outer index in the guard ``xo * n + i < m``.
    m
        Total extent the guard compares against.
    n
        Length of the placeholders, i.e. the width of the intrinsic.

    Returns
    -------
    The result of ``te.decl_tensor_intrin`` for the "z" compute op, whose
    lowering is a packed call to "vadd".
    """
    lhs = te.placeholder((n,), name="vx")
    rhs = te.placeholder((n,), name="vy")

    # NOTE: the argument is deliberately kept as ``i``; te.compute may derive
    # the axis name from the callable's argument name.
    if m % n == 0:

        def body(i):
            return lhs[i] + rhs[i]

    else:

        def body(i):
            # Guard the tail: positions at or beyond ``m`` yield zero.
            return tvm.tir.Select(
                xo * n + i < m,
                lhs[i] + rhs[i],
                tvm.tir.const(0, dtype=lhs.dtype),
            )

    out = te.compute(lhs.shape, body, name="z")

    def emit_packed_call(ins, outs):
        """Lower the intrinsic to a packed call: vadd(in0, in1, out0)."""
        first, second = ins
        result = outs[0]
        return tvm.tir.call_packed("vadd", first, second, result)

    return te.decl_tensor_intrin(
        out.op,
        emit_packed_call,
        default_buffer_params={"offset_factor": 16},
    )
def add(m):
    """Build an element-wise addition over two length-``m`` placeholders.

    Returns
    -------
    A tuple ``(x, y, z)``: the placeholders named "x" and "y", and the
    compute named "z" defined as ``x[i] + y[i]``.
    """
    lhs = te.placeholder((m,), name="x")
    rhs = te.placeholder((m,), name="y")

    # Argument name ``i`` is kept so the resulting axis naming is unchanged.
    def elementwise_sum(i):
        return lhs[i] + rhs[i]

    total = te.compute(lhs.shape, elementwise_sum, name="z")
    return lhs, rhs, total
def check_cache_write(m, factor):
    """Tensorize a cache-write stage of a split vector add and print the IR.

    Steps, in order: build the add over extent ``m``, create a schedule,
    split the output's only axis by ``factor``, create a "global" cache-write
    stage for the output, tensorize that stage's second axis with the
    intrinsic from ``intrin_vadd``, then print the lowered function.

    Parameters
    ----------
    m
        Extent of the input and output tensors.
    factor
        Split factor, also passed to ``intrin_vadd`` as the intrinsic width.
    """
    lhs, rhs, out = add(m)
    sched = te.create_schedule(out.op)

    # The split results are not needed here; only the effect on the schedule.
    sched[out].split(out.op.axis[0], factor=factor)

    out_cache = sched.cache_write(out, "global")
    # Unpacking requires the cache stage's op to expose exactly two axes.
    outer, inner = out_cache.op.axis

    intrinsic = intrin_vadd(outer, m, factor)
    sched[out_cache].tensorize(inner, intrinsic)

    print(tvm.lower(sched, [lhs, rhs, out]))
if __name__ == "__main__":
    # Evenly divisible case: extent 128 with split factor 16.
    check_cache_write(m=128, factor=16)
    # Non-divisible case (129 is not a multiple of 16), currently disabled:
    # check_cache_write(129, 16)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment