# Gist: yzh119/33786c349be41731c0dd7acacac2dd58 -- "sparse workloads in tir"
# Created: July 30, 2021 14:53
# NOTE: GitHub flagged this file as containing bidirectional Unicode text that
# may be interpreted or compiled differently than it appears. Review it in an
# editor that reveals hidden Unicode characters.
import tvm | |
from tvm import tir | |
from tvm.script import ty | |
# Sparse (indptr/indices/data layout) x dense matrix multiply, written as a
# TVMScript sketch.  For every output element (vi, vj):
#
#   C[vi, vj] = sum_{vk in [indptr[vi], indptr[vi + 1])} A[vk] * B[indices[vk], vj]
#
# Buffers bound from the handle arguments (shapes taken from match_buffer):
#   indptr_  -> indptr  : int32   [m + 1]
#   indices_ -> indices : int32   [nnz]
#   a_data   -> A       : float32 [nnz]
#   b        -> B       : float32 [k, n]
#   c        -> C       : float32 [m, n]
#
# NOTE(review): the block nesting below was reconstructed from a source whose
# indentation had been stripped; the inner block is placed inside the outer
# one because it reads vi and vj -- confirm against the upstream gist.
@tvm.script.tir
def csr_spmm(
    indptr_: ty.handle,
    indices_: ty.handle,
    a_data: ty.handle,
    b: ty.handle,
    c: ty.handle,
) -> None:
    # Symbolic extents, all int32.
    m = tir.var('int32')
    n = tir.var('int32')
    k = tir.var('int32')
    nnz = tir.var('int32')

    indptr = tir.match_buffer(indptr_, [m + 1], 'int32')
    indices = tir.match_buffer(indices_, [nnz], 'int32')
    A = tir.match_buffer(a_data, [nnz], 'float32')
    B = tir.match_buffer(b, [k, n], 'float32')
    C = tir.match_buffer(c, [m, n], 'float32')

    # Outer block iterates over every output element and zero-initialises it.
    with tir.block([m, n], 'spmm_outer') as [vi, vj]:
        with tir.init():
            C[vi, vj] = 0.
        # Inner block accumulates over row vi's slice of the nonzeros; the
        # reduce-axis bounds are read from indptr, i.e. they are data-dependent.
        with tir.block([tir.reduce_axis(indptr[vi], indptr[vi + 1])], 'spmm_inner') as [vk]:
            C[vi, vj] = C[vi, vj] + A[vk] * B[indices[vk], vj]
# Sampled dense-dense matrix multiply, written as a TVMScript sketch.  For
# every stored nonzero position eid:
#
#   C[eid] = sum_{vk in [0, k)} A[row[eid], vk] * B[vk, col[eid]]
#
# Note that, despite the "csr" in the name, the sparsity pattern is supplied
# as explicit per-nonzero row and column index arrays (both of length nnz).
#
# Buffers bound from the handle arguments (shapes taken from match_buffer):
#   row_ -> row : int32   [nnz]
#   col_ -> col : int32   [nnz]
#   a    -> A   : float32 [m, k]
#   b    -> B   : float32 [k, n]
#   c    -> C   : float32 [nnz]
@tvm.script.tir
def csr_sddmm(
    row_: ty.handle,
    col_: ty.handle,
    a: ty.handle,
    b: ty.handle,
    c: ty.handle,
) -> None:
    # Symbolic extents, all int32.
    m = tir.var('int32')
    n = tir.var('int32')
    k = tir.var('int32')
    nnz = tir.var('int32')

    row = tir.match_buffer(row_, [nnz,], 'int32')
    col = tir.match_buffer(col_, [nnz,], 'int32')
    A = tir.match_buffer(a, [m, k], 'float32')
    B = tir.match_buffer(b, [k, n], 'float32')
    C = tir.match_buffer(c, [nnz,], 'float32')

    # One block: eid walks the nonzeros, vk is the reduction over the shared
    # dimension of A and B.
    with tir.block([nnz, tir.reduce_axis(0, k)], 'sddmm') as [eid, vk]:
        with tir.init():
            C[eid] = 0.
        C[eid] = C[eid] + A[row[eid], vk] * B[vk, col[eid]]
# Block-sparse (indptr/indices over square blocks) x dense matrix multiply,
# written as a TVMScript sketch.  For every block row io, in-block row ii and
# output column j:
#
#   C[io, ii, j] = sum_{ki in [0, block_size)}
#                  sum_{ko in [indptr[io], indptr[io + 1])}
#                      A[ko, ii, ki] * B[indices[ko], ki, j]
#
# Buffers bound from the handle arguments (shapes taken from match_buffer):
#   indptr_  -> indptr  : int32   [mb + 1]
#   indices_ -> indices : int32   [nnzb]
#   a_data   -> A       : float32 [nnzb, block_size, block_size]
#   b        -> B       : float32 [kb, block_size, n]
#   c        -> C       : float32 [mb, block_size, n]
#
# NOTE(review): the block nesting below was reconstructed from a source whose
# indentation had been stripped; the inner block is placed inside the outer
# one because it reads io, ki, ii and j -- confirm against the upstream gist.
@tvm.script.tir
def bsr_spmm(
    indptr_: ty.handle,
    indices_: ty.handle,
    a_data: ty.handle,
    b: ty.handle,
    c: ty.handle,
) -> None:
    # Symbolic extents, all int32.
    mb = tir.var('int32')
    n = tir.var('int32')
    kb = tir.var('int32')
    nnzb = tir.var('int32')
    block_size = tir.var('int32')

    indptr = tir.match_buffer(indptr_, [mb + 1], 'int32')
    indices = tir.match_buffer(indices_, [nnzb], 'int32')
    A = tir.match_buffer(a_data, [nnzb, block_size, block_size], 'float32')
    B = tir.match_buffer(b, [kb, block_size, n], 'float32')
    C = tir.match_buffer(c, [mb, block_size, n], 'float32')

    # Outer block: io/ii/j address the output element, ki is the in-block
    # reduction axis.
    with tir.block(
        [mb, tir.reduce_axis(0, block_size), block_size, n], 'spmm_outer'
    ) as [io, ki, ii, j]:
        with tir.init():
            C[io, ii, j] = 0.
        # Inner block reduces over the stored blocks of block row io; the
        # reduce-axis bounds are read from indptr, i.e. they are data-dependent.
        with tir.block([tir.reduce_axis(indptr[io], indptr[io + 1])], 'spmm_inner') as [ko]:
            C[io, ii, j] = C[io, ii, j] + A[ko, ii, ki] * B[indices[ko], ki, j]
# Blocked sampled dense-dense matrix multiply, written as a TVMScript sketch.
# For every stored block bid and in-block position (vi, vj):
#
#   C[bid, vi, vj] = sum_{vk in [0, k)} A[row[bid], vi, vk] * B[vk, col[bid], vj]
#
# The sparsity pattern is supplied as explicit per-block row and column index
# arrays (both of length nnzb).
#
# Buffers bound from the handle arguments (shapes taken from match_buffer):
#   row_ -> row : int32   [nnzb]
#   col_ -> col : int32   [nnzb]
#   a    -> A   : float32 [mb, block_size, k]
#   b    -> B   : float32 [k, nb, block_size]
#   c    -> C   : float32 [nnzb, block_size, block_size]
@tvm.script.tir
def bsr_sddmm(
    row_: ty.handle,
    col_: ty.handle,
    a: ty.handle,
    b: ty.handle,
    c: ty.handle,
) -> None:
    # Symbolic extents, all int32.
    mb = tir.var('int32')
    nb = tir.var('int32')
    k = tir.var('int32')
    nnzb = tir.var('int32')
    block_size = tir.var('int32')

    row = tir.match_buffer(row_, [nnzb,], 'int32')
    col = tir.match_buffer(col_, [nnzb,], 'int32')
    A = tir.match_buffer(a, [mb, block_size, k], 'float32')
    B = tir.match_buffer(b, [k, nb, block_size], 'float32')
    C = tir.match_buffer(c, [nnzb, block_size, block_size], 'float32')

    # One block: bid walks the stored blocks, (vi, vj) the position inside a
    # block, and vk is the reduction over the shared dimension of A and B.
    with tir.block(
        [nnzb, block_size, block_size, tir.reduce_axis(0, k)], 'sddmm'
    ) as [bid, vi, vj, vk]:
        with tir.init():
            C[bid, vi, vj] = 0.
        C[bid, vi, vj] = C[bid, vi, vj] + A[row[bid], vi, vk] * B[vk, col[bid], vj]
# Fixed-width (ell_cols stored blocks per block row) block-sparse x dense
# matrix multiply, written as a TVMScript sketch.  For every block row io,
# in-block row ii and output column j:
#
#   C[io, ii, j] = sum_{ko in [0, ell_cols)}
#                  sum_{ki in [0, block_size)}
#                      A[io, ko, ii, ki] * B[indices[io, ko], ki, j]
#
# Buffers bound from the handle arguments (shapes taken from match_buffer):
#   indices_ -> indices : int32   [mb, ell_cols]
#   a_data   -> A       : float32 [mb, ell_cols, block_size, block_size]
#   b        -> B       : float32 [kb, block_size, n]
#   c        -> C       : float32 [mb, block_size, n]
@tvm.script.tir
def ell_spmm(
    indices_: ty.handle,
    a_data: ty.handle,
    b: ty.handle,
    c: ty.handle,
) -> None:
    # Symbolic extents, all int32.
    mb = tir.var('int32')
    n = tir.var('int32')
    kb = tir.var('int32')
    block_size = tir.var('int32')
    ell_cols = tir.var('int32')

    indices = tir.match_buffer(indices_, [mb, ell_cols], 'int32')
    A = tir.match_buffer(a_data, [mb, ell_cols, block_size, block_size], 'float32')
    B = tir.match_buffer(b, [kb, block_size, n], 'float32')
    C = tir.match_buffer(c, [mb, block_size, n], 'float32')

    # Single block: both reduction axes (ko over stored blocks, ki inside a
    # block) have static extents, so no nested block is needed here.
    with tir.block(
        [mb, tir.reduce_axis(0, ell_cols), tir.reduce_axis(0, block_size), block_size, n],
        'spmm',
    ) as [io, ko, ki, ii, j]:
        with tir.init():
            C[io, ii, j] = 0.
        C[io, ii, j] = C[io, ii, j] + A[io, ko, ii, ki] * B[indices[io, ko], ki, j]
if __name__ == '__main__':
    # Print each kernel object, in definition order, when run as a script.
    for kernel in (csr_spmm, csr_sddmm, bsr_spmm, bsr_sddmm, ell_spmm):
        print(kernel)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment