Created
August 5, 2021 08:00
-
-
Save yzh119/9c71e6b0b9c43741bf7f1288b195d275 to your computer and use it in GitHub Desktop.
ELL SpMM with multi-level tiling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tvm
from tvm import tir
from tvm.script import ty
from tvm.tir.schedule.schedule import Schedule
# ELL-format sparse-dense matrix multiplication (SpMM) as a TVM script function.
#
# Buffers (all extents are symbolic until the function is specialized):
#   indices : [mb, ell_cols]                          int32
#   A       : [mb, ell_cols, block_size, block_size]  float32
#   B       : [kb, block_size, n]                     float32
#   C       : [mb, block_size, n]                     float32
#
# For every output element C[io, ii, j] the "spmm" block accumulates
#   A[io, ko, ii, ki] * B[indices[io, ko], ki, j]
# over the reduction axes ko in [0, ell_cols) and ki in [0, block_size);
# indices[io, ko] picks which leading-axis slice of B is multiplied.
#
# Documentation is kept in `#` comments on purpose so that nothing but the
# original statements is handed to the TVM script parser.
@tvm.script.tir
def ell_spmm(indices_: ty.handle, a_data: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Symbolic extents, bound later through `specialize`.
    mb = tir.var('int32')
    n = tir.var('int32')
    kb = tir.var('int32')
    block_size = tir.var('int32')
    ell_cols = tir.var('int32')

    indices = tir.match_buffer(indices_, [mb, ell_cols], 'int32')
    A = tir.match_buffer(a_data, [mb, ell_cols, block_size, block_size], 'float32')
    B = tir.match_buffer(b, [kb, block_size, n], 'float32')
    C = tir.match_buffer(c, [mb, block_size, n], 'float32')

    # Block axes, in order: io (spatial), ko (reduce), ki (reduce), ii (spatial), j (spatial).
    with tir.block([mb, tir.reduce_axis(0, ell_cols), tir.reduce_axis(0, block_size), block_size, n], 'spmm') as [io, ko, ki, ii, j]:
        with tir.init():
            C[io, ii, j] = 0.
        C[io, ii, j] = C[io, ii, j] + A[io, ko, ii, ki] * B[indices[io, ko], ki, j]
def schedule(sch: Schedule):
    """Apply a multi-level tiling schedule to the ``spmm`` block and print it.

    The schedule is applied in place on ``sch``:

    1. Split ``ii``, ``ki`` and ``j`` by 16 and blockize the resulting
       16x16x16 inner tile.
    2. Tile the remaining outer loops with sampled perfect-tile factors,
       fuse them pairwise and bind them to ``blockIdx.x``, ``vthread`` and
       ``threadIdx.x``.
    3. Stage the output through a ``local`` cache, operands 1 and 2 through
       ``shared`` caches (with a vectorized inner loop), and operand 3 of the
       inner block through a ``local`` cache.
    4. Add ``wmma.matrix_a`` / ``wmma.matrix_b`` / ``wmma.accumulator``
       caches and decompose the reduction's init.

    Finally the scheduled ``main`` function is printed as TVM script.

    Args:
        sch: Schedule over a module whose ``main`` contains a block named
            ``"spmm"`` with five loops (io, ko, ki, ii, j).

    Returns:
        None. ``sch`` is mutated in place.
    """
    block = sch.get_block("spmm")
    io, ko, ki, ii, j = sch.get_loops(block)

    # Peel a 16-wide inner loop off each of ii / ki / j. ``None`` lets the
    # schedule infer the outer extent.
    ii, i_tc = sch.split(ii, factors=[None, 16])
    ki, k_tc = sch.split(ki, factors=[None, 16])
    j, j_tc = sch.split(j, factors=[None, 16])
    sch.reorder(
        io, j, ko, ii, ki,
        i_tc, j_tc, k_tc,
    )

    # Wrap the 16x16x16 tile into its own block; the original block handle
    # becomes the inner one.
    block_inner = sch.blockize(i_tc)
    block_outer, block_inner = block_inner, block
    del block

    # Multi-level tiling of the outer loops with sampled factors.
    i0, i1, i2, i3 = sch.split(io, factors=sch.sample_perfect_tile(io, n=4))
    j0, j1, j2, j3, j4 = sch.split(j, factors=sch.sample_perfect_tile(j, n=5))
    k0, k1 = sch.split(ko, factors=sch.sample_perfect_tile(ko, n=2))
    sch.reorder(
        # fmt: off
        i0, j0,  # S => blockIdx.x
        i1, j1,  # S => vthread
        i2, j2,  # S => threadIdx.x
        # cache_write here
        k0,      # R
        # vectorized cooperative fetching here
        k1,      # R
        i3, j3,  # S
        ki,      # R
        ii, j4,  # S
        # fmt: on
    )

    block_idx = sch.fuse(i0, j0)
    vthread = sch.fuse(i1, j1)
    thread_idx = sch.fuse(i2, j2)
    # Thread tags are case-sensitive.
    sch.bind(block_idx, "blockIdx.x")
    sch.bind(vthread, "vthread")
    sch.bind(thread_idx, "threadIdx.x")

    # Write the result through a "local" cache. After the swap, block_outer
    # is the handle returned by cache_write and block_write_c the previous
    # outer block, which is moved under the threadIdx loop.
    block_write_c = sch.cache_write(block_outer, 0, "local")
    block_outer, block_write_c = block_write_c, block_outer
    sch.reverse_compute_at(block_write_c, thread_idx)

    def fetch_to_shared(block, idx, ndim):
        """Cache read operand ``idx`` of ``block`` in "shared" scope under ``k0``.

        The innermost ``ndim`` loops of the cache-read block are fused and
        split by 4; the outer part is annotated for lazy cooperative fetch
        and the inner part is vectorized.
        """
        block_read = sch.cache_read(block, idx, "shared")
        sch.compute_at(block_read, k0)
        fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
        fused_0, fused_1 = sch.split(fused, factors=[None, 4])
        sch.mark_loop(fused_0, "loop_type", "lazy_cooperative_fetch")
        sch.vectorize(fused_1)

    fetch_to_shared(block_outer, 1, 2)
    fetch_to_shared(block_outer, 2, 2)

    # Read indices from global to local
    indices_read = sch.cache_read(block_inner, 3, "local")
    loop = sch.get_loops(block_outer)[-1]
    sch.compute_at(indices_read, loop)

    # Step 3. Postproc-Rewrite-Tensorize
    # Step 3.1. Cache read
    # Re-query the innermost loop after the compute_at above.
    loop = sch.get_loops(block_outer)[-1]
    block_read_a = sch.cache_read(block_inner, 1, "wmma.matrix_a")
    block_read_b = sch.cache_read(block_inner, 2, 'wmma.matrix_b')
    sch.compute_at(block_read_a, loop)
    sch.compute_at(block_read_b, loop)

    # Step 3.2. Cache write
    block_write_c = sch.cache_write(block_outer, 0, 'wmma.accumulator')
    block_outer, block_write_c = block_write_c, block_outer
    sch.reverse_compute_at(block_write_c, loop)

    # Step 3.3. Decompose
    # Split the reduction's init out at the loop with index 3; the returned
    # init-block handle is not needed afterwards.
    loop = sch.get_loops(block_outer)[3]
    sch.decompose_reduction(block_outer, loop)

    print(tvm.script.asscript(sch.mod['main']))
# Script entry point: specialize ell_spmm to concrete shapes, build a traced
# Schedule for it, and run the schedule() transformation (which prints the
# scheduled function).
if __name__ == '__main__':
    f = ell_spmm
    m, n, k = 4096, 512, 4096
    block_size = 32
    ell_cols = 16

    # Bind each handle parameter to a concrete buffer shape so the symbolic
    # extents (mb, kb, n, block_size, ell_cols) become constants.
    # NOTE(review): decl_buffer is called without a dtype for every buffer,
    # while ell_spmm matches `indices` as int32 — confirm the default dtype
    # is acceptable to specialize here.
    indices_, a_data, b, c = f.params
    f = f.specialize({
        indices_: tir.decl_buffer([m // block_size, ell_cols]),
        a_data: tir.decl_buffer([m // block_size, ell_cols, block_size, block_size]),
        b: tir.decl_buffer([k // block_size, block_size, n]),
        c: tir.decl_buffer([m // block_size, block_size, n]),
    })

    # `Schedule` (the class) builds the schedule object; `schedule` (the
    # function defined in this file) applies the transformations to it.
    sch = Schedule(f, debug_mode=True, traced=True)
    schedule(sch)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment