@comaniac
Created July 9, 2021 19:43
import numpy as np
import tvm
from tvm import te, tir, topi
from tvm.topi import utils
dev = tvm.device("gpu", 0)
target = tvm.target.Target("cuda")
### Copied from topi/cuda/injective.py, with configurable block/thread numbers
def schedule_injective_from_existing(sch, out, max_block, num_thread):
    fused = sch[out].fuse(*sch[out].op.axis)

    # Vectorize on the fp16 data type to better utilize memory bandwidth.
    vector_width = 4 if out.dtype == "float16" else 1

    is_dynamic_output = False
    for dim in out.shape:
        if not isinstance(dim, tvm.tir.IntImm):
            is_dynamic_output = True
            break

    out_len = utils.prod(out.shape)

    try:
        const_size = utils.get_const_int(out_len)
        need_block_split = const_size > max_block * num_thread * vector_width
    except ValueError:
        need_block_split = False
        const_size = 0

    if vector_width > 1:
        fused, v = sch[out].split(fused, vector_width)
        sch[out].vectorize(v)

    if need_block_split:
        xo, xi = sch[out].split(fused, factor=num_thread * max_block)
        bx, tx = sch[out].split(xi, factor=num_thread)
        sch[out].reorder(bx, tx, xo)
        sch[out].bind(bx, te.thread_axis("blockIdx.x"))
        sch[out].bind(tx, te.thread_axis("threadIdx.x"))
    else:
        # Use fewer threads for dynamic-shape ops to avoid runtime errors.
        if is_dynamic_output:
            num_thread //= 2
        if const_size != 0 and const_size < num_thread:
            bx, tx = sch[out].split(fused, factor=const_size)
        else:
            bx, tx = sch[out].split(fused, factor=num_thread)
        sch[out].bind(tx, te.thread_axis("threadIdx.x"))
        sch[out].bind(bx, te.thread_axis("blockIdx.x"))
    return sch
def schedule_injective(outs, max_block, num_thread):
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])
    tvm.te.schedule.AutoInlineInjective(s)
    for out in outs:
        if not utils.is_empty_shape(out.shape):
            schedule_injective_from_existing(s, out, max_block, num_thread)
    return s
### Schedule end
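# Illustrative sketch (not part of the original gist): the effect of the schedule above
# can be inspected without running anything by lowering a small fp16 injective op and
# printing its IR. The shapes and names below are arbitrary examples. The helper is
# defined but intentionally never called, so it does not change the script's output.
def _print_example_lowered_ir():
    X = te.placeholder((768, 3072), name="X", dtype="float32")
    Y = topi.cast(X, "float16")
    example_sch = schedule_injective(Y, max_block=256, num_thread=target.max_num_threads)
    # The lowered IR should show the fused axis split over blockIdx.x/threadIdx.x and,
    # because the output is fp16, an innermost vectorized axis of width 4.
    print(tvm.lower(example_sch, [X, Y], simple_mode=True))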
def case0():
    # FP32 multiply + FP16 add
    A = te.placeholder((1,), name="A", dtype="float32")
    B = te.placeholder((768, 3072), name="B", dtype="float32")
    C = te.placeholder((768, 3072), name="C", dtype="float32")
    D = te.compute(B.shape, lambda *i: A[0] * B[i])
    E = topi.cast(D, "float16")
    F = topi.cast(C, "float16")
    G = te.compute(C.shape, lambda *i: E[i] + F[i], name="G")
    args = [A, B, C, G]

    a = tvm.nd.array(np.random.uniform(size=(1,)).astype(A.dtype), dev)
    b = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(B.dtype), dev)
    c = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(C.dtype), dev)
    g = tvm.nd.array(np.zeros((768, 3072), dtype=G.dtype), dev)
    data = [a, b, c, g]
    return "FP32Mul_FP16Add", G, args, data
def case1():
    # FP16 multiply + FP16 add
    A = te.placeholder((1,), name="A", dtype="float32")
    A_ = topi.cast(A, "float16")
    B = te.placeholder((768, 3072), name="B", dtype="float32")
    B_ = topi.cast(B, "float16")
    C = te.placeholder((768, 3072), name="C", dtype="float32")
    D = te.compute(B.shape, lambda *i: A_[0] * B_[i])
    F = topi.cast(C, "float16")
    G = te.compute(C.shape, lambda *i: D[i] + F[i], name="G")
    args = [A, B, C, G]

    a = tvm.nd.array(np.random.uniform(size=(1,)).astype(A.dtype), dev)
    b = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(B.dtype), dev)
    c = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(C.dtype), dev)
    g = tvm.nd.array(np.zeros((768, 3072), dtype=G.dtype), dev)
    data = [a, b, c, g]
    return "FP16Mul_FP16Add", G, args, data
def case2():
    # FP16 multiply
    A = te.placeholder((768, 3072), name="A", dtype="float32")
    A_ = topi.cast(A, "float16")
    B = te.placeholder((768, 3072), name="B", dtype="float32")
    B_ = topi.cast(B, "float16")
    G = te.compute(B.shape, lambda *i: A_[i] * B_[i], name="G")
    args = [A, B, G]

    a = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(A.dtype), dev)
    b = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(B.dtype), dev)
    g = tvm.nd.array(np.zeros((768, 3072), dtype=G.dtype), dev)
    data = [a, b, g]
    return "FP16Mul", G, args, data
def case3():
    # Just a cast
    B = te.placeholder((768, 3072), name="B", dtype="float32")
    G = topi.cast(B, "float16")
    args = [B, G]

    b = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(B.dtype), dev)
    g = tvm.nd.array(np.zeros((768, 3072), dtype=G.dtype), dev)
    data = [b, g]
    return "Cast", G, args, data
cases = [case0, case1, case2, case3]
# Benchmark
for case in cases:
    name, G, args, data = case()
    logs = []
    curr_best = (-1, float("inf"))
    for idx, cfg in enumerate([(256, 1), (128, 1), (64, 1), (256, 2), (128, 2), (64, 2)]):
        block = cfg[0]
        thread = target.max_num_threads // cfg[1]
        s = schedule_injective(G, block, thread)
        func = tvm.build(s, args, target, name="func")
        # Check whether the generated CUDA kernel uses the vectorized half2 type.
        use_half2 = "(half2*)" in func.imported_modules[0].get_source()
        evaluator = func.time_evaluator(func.entry_name, dev, number=100)
        mean_time = evaluator(*data).mean * 1000
        if mean_time < curr_best[1]:
            curr_best = (idx, mean_time)
        logs.append(
            "%18s, block=%3d, thread=%4d, prod=%6d: %.4fms, use-half2? %s"
            % (name, block, thread, block * thread, mean_time, use_half2)
        )
    for idx, rpt in enumerate(logs):
        print("%s %s" % (rpt, "--> best" if idx == curr_best[0] else ""))
    print("==============")