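# Benchmark CUDA injective (elementwise) schedules on four fp16/fp32
# workloads, sweeping the max block count and threads-per-block, and report
# whether each generated kernel uses vectorized half2 accesses.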
import numpy as np

import tvm
from tvm import te, tir, topi
from tvm.topi import utils

dev = tvm.device("gpu", 0)
target = tvm.target.Target("cuda")
### Copied from topi/cuda/injective.py, with the block/thread counts made configurable.
def schedule_injective_from_existing(sch, out, max_block, num_thread):
    fused = sch[out].fuse(*sch[out].op.axis)

    # Vectorize on the fp16 data type to better utilize memory bandwidth.
    vector_width = 4 if out.dtype == "float16" else 1

    is_dynamic_output = False
    for dim in out.shape:
        if not isinstance(dim, tvm.tir.IntImm):
            is_dynamic_output = True
            break

    out_len = utils.prod(out.shape)

    try:
        const_size = utils.get_const_int(out_len)
        need_block_split = const_size > max_block * num_thread * vector_width
    except ValueError:
        need_block_split = False
        const_size = 0

    if vector_width > 1:
        fused, v = sch[out].split(fused, vector_width)
        sch[out].vectorize(v)

    if need_block_split:
        xo, xi = sch[out].split(fused, factor=num_thread * max_block)
        bx, tx = sch[out].split(xi, factor=num_thread)
        sch[out].reorder(bx, tx, xo)
        sch[out].bind(bx, te.thread_axis("blockIdx.x"))
        sch[out].bind(tx, te.thread_axis("threadIdx.x"))
    else:
        # Use fewer threads for dynamic-shape ops to avoid runtime errors.
        if is_dynamic_output:
            num_thread //= 2
        if const_size != 0 and const_size < num_thread:
            bx, tx = sch[out].split(fused, factor=const_size)
        else:
            bx, tx = sch[out].split(fused, factor=num_thread)
        sch[out].bind(tx, te.thread_axis("threadIdx.x"))
        sch[out].bind(bx, te.thread_axis("blockIdx.x"))
    return sch
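# A minimal, self-contained sketch (not part of the benchmark) of the
# split/vectorize/bind pattern used above: peel a width-4 inner axis off the
# fused loop, vectorize it, and bind the remainder to CUDA blocks/threads.
# The names below (X, Y, _vectorize_bind_sketch) are illustrative only.
def _vectorize_bind_sketch():
    n = 1 << 20
    X = te.placeholder((n,), name="X", dtype="float16")
    Y = te.compute((n,), lambda i: X[i] + tvm.tir.const(1, "float16"), name="Y")
    s = te.create_schedule(Y.op)
    fused = s[Y].fuse(*s[Y].op.axis)
    fused, v = s[Y].split(fused, 4)  # peel off a width-4 vector-lane axis
    s[Y].vectorize(v)                # emit vectorized fp16 loads/stores
    bx, tx = s[Y].split(fused, factor=1024)
    s[Y].bind(bx, te.thread_axis("blockIdx.x"))
    s[Y].bind(tx, te.thread_axis("threadIdx.x"))
    # The lowered TIR shows the vectorized (ramp) accesses; building with the
    # cuda target should reveal the half2 casts the benchmark below greps for.
    print(tvm.lower(s, [X, Y], simple_mode=True))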
def schedule_injective(outs, max_block, num_thread):
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])
    # Inline all injective (elementwise) stages into their consumers so the
    # whole chain is scheduled as a single fused kernel.
    tvm.te.schedule.AutoInlineInjective(s)
    for out in outs:
        if not utils.is_empty_shape(out.shape):
            schedule_injective_from_existing(s, out, max_block, num_thread)
    return s
### Schedule end
def case0():
    # FP32 multiply + FP16 add: the multiply runs in fp32, and only its
    # result (and C) are cast to fp16 before the add.
    A = te.placeholder((1,), name="A", dtype="float32")
    B = te.placeholder((768, 3072), name="B", dtype="float32")
    C = te.placeholder((768, 3072), name="C", dtype="float32")
    D = te.compute(B.shape, lambda *i: A[0] * B[i])
    E = topi.cast(D, "float16")
    F = topi.cast(C, "float16")
    G = te.compute(C.shape, lambda *i: E[i] + F[i], name="G")
    args = [A, B, C, G]

    a = tvm.nd.array(np.random.uniform(size=(1,)).astype(A.dtype), dev)
    b = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(B.dtype), dev)
    c = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(C.dtype), dev)
    g = tvm.nd.array(np.zeros((768, 3072), dtype=G.dtype), dev)
    data = [a, b, c, g]
    return "FP32Mul_FP16Add", G, args, data
def case1():
    # FP16 multiply + FP16 add: the inputs are cast to fp16 first, so both
    # the multiply and the add compute in fp16.
    A = te.placeholder((1,), name="A", dtype="float32")
    A_ = topi.cast(A, "float16")
    B = te.placeholder((768, 3072), name="B", dtype="float32")
    B_ = topi.cast(B, "float16")
    C = te.placeholder((768, 3072), name="C", dtype="float32")
    D = te.compute(B.shape, lambda *i: A_[0] * B_[i])
    F = topi.cast(C, "float16")
    G = te.compute(C.shape, lambda *i: D[i] + F[i], name="G")
    args = [A, B, C, G]

    a = tvm.nd.array(np.random.uniform(size=(1,)).astype(A.dtype), dev)
    b = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(B.dtype), dev)
    c = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(C.dtype), dev)
    g = tvm.nd.array(np.zeros((768, 3072), dtype=G.dtype), dev)
    data = [a, b, c, g]
    return "FP16Mul_FP16Add", G, args, data
def case2():
    # FP16 multiply
    A = te.placeholder((768, 3072), name="A", dtype="float32")
    A_ = topi.cast(A, "float16")
    B = te.placeholder((768, 3072), name="B", dtype="float32")
    B_ = topi.cast(B, "float16")
    G = te.compute(B.shape, lambda *i: A_[i] * B_[i], name="G")
    args = [A, B, G]

    a = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(A.dtype), dev)
    b = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(B.dtype), dev)
    g = tvm.nd.array(np.zeros((768, 3072), dtype=G.dtype), dev)
    data = [a, b, g]
    return "FP16Mul", G, args, data
def case3():
    # Just a cast
    B = te.placeholder((768, 3072), name="B", dtype="float32")
    G = topi.cast(B, "float16")
    args = [B, G]

    b = tvm.nd.array(np.random.uniform(size=(768, 3072)).astype(B.dtype), dev)
    g = tvm.nd.array(np.zeros((768, 3072), dtype=G.dtype), dev)
    data = [b, g]
    return "Cast", G, args, data
cases = [case0, case1, case2, case3]

# Benchmark: each config is (max_block, thread_divisor), where the thread
# count per block is target.max_num_threads // thread_divisor.
for case in cases:
    name, G, args, data = case()
    logs = []
    curr_best = (-1, float("inf"))
    for idx, cfg in enumerate([(256, 1), (128, 1), (64, 1), (256, 2), (128, 2), (64, 2)]):
        block = cfg[0]
        thread = target.max_num_threads // cfg[1]
        s = schedule_injective(G, block, thread)
        func = tvm.build(s, args, target, name="func")
        # Check whether the generated CUDA source contains vectorized half2 accesses.
        use_half2 = "(half2*)" in func.imported_modules[0].get_source()
        evaluator = func.time_evaluator(func.entry_name, dev, number=100)
        mean_time = evaluator(*data).mean * 1000
        if mean_time < curr_best[1]:
            curr_best = (idx, mean_time)
        logs.append(
            "%18s, block=%3d, thread=%4d, prod=%6d: %.4fms, use-half2? %s"
            % (name, block, thread, block * thread, mean_time, use_half2)
        )
    for idx, rpt in enumerate(logs):
        print("%s %s" % (rpt, "--> best" if idx == curr_best[0] else ""))
    print("==============")