Created May 9, 2023 22:35
-
-
Save HDCharles/c6413717039002c2c20b6cd669edba3e to your computer and use it in GitHub Desktop.
Triton graph for `safe_int_mm`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ctypes import c_void_p, c_long | |
import torch | |
import math | |
import random | |
import os | |
import tempfile | |
from torch._inductor.hooks import run_intermediate_hooks | |
from torch._inductor.utils import maybe_profile | |
from torch import empty_strided, as_strided, device | |
from torch._inductor.codecache import AsyncCompile | |
from torch._inductor.select_algorithm import extern_kernels | |
aten = torch.ops.aten | |
assert_size_stride = torch._C._dynamo.guards.assert_size_stride | |
async_compile = AsyncCompile() | |
# kernel path: /tmp/torchinductor_cdhernandez/va/cvakdsbvsebtiaosp3kolpke7cfvzv5o6jcxjkfli4t2yuz2vum2.py
# Original ATen: aten._int_mm
# aten._int_mm => _int_mm
#
# Triton matmul template for int8 x int8 -> int32 (pointer signature
# '*i8', '*i8', '*i32'), specialized to this graph's shapes.
#
# Operand and output layouts:
#   - A is M x K = 8 x 17 with strides (17, 1).
#   - B is K x N = 17 x 8 with strides (8, 1).
#   - The result tile is stored at offset 8*row + col, i.e. a contiguous 8 x 8
#     int32 output.
#
# Tiling:
#   - Tiles are BLOCK_M x BLOCK_N = 16 x 16, so grid_m = grid_n = 1 and a
#     single tile spans the whole output.
#   - Row/column indices beyond M/N are wrapped with `% M` / `% N` for the
#     loads and masked off at the store.
#   - EVEN_K is False because K = 17 is not a multiple of BLOCK_K = 32, so
#     the K tail is masked on every load.
#   - Partial products accumulate in an int32 (ACC_TYPE) tile.
#
# NOTE(review): the masked loads pass `other=0.` (a float literal) for int8
# pointers; presumably Triton casts it to the element type. Confirm.
#
# The string below is passed verbatim to async_compile.triton. Editing it,
# even its comments, changes the kernel source that gets compiled.
triton_unk_fused__int_mm_0 = async_compile.triton('triton_', '''
import triton.language as tl
import triton
from torch._inductor.triton_heuristics import template
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@template(num_stages=2, num_warps=1, meta={'signature': {0: '*i8', 1: '*i8', 2: '*i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def triton_(arg_A, arg_B, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = False
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.int32
    BLOCK_M : tl.constexpr = 16
    BLOCK_N : tl.constexpr = 16
    BLOCK_K : tl.constexpr = 32
    A = arg_A
    B = arg_B
    M = 8
    N = 8
    K = 17
    stride_am = 17
    stride_ak = 1
    stride_bk = 8
    stride_bn = 1
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    xindex = idx_n + (8*idx_m)
    tl.store(out_ptr0 + (tl.broadcast_to(xindex, mask.shape)), acc, mask)
''')
# Launch-side imports: kernel runtime helpers, raw CUDA stream access, and
# the matmul grid helper used by call().
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import end_graph, grid, start_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
import torch._inductor.kernel.mm_common

# Launch meta-parameters for triton_unk_fused__int_mm_0. The values mirror
# the constexprs baked into the kernel source and are passed to mm_grid when
# the kernel is launched in call().
meta0 = dict(
    GROUP_M=8,
    EVEN_K=False,
    ALLOW_TF32=False,
    ACC_TYPE='tl.int32',
    BLOCK_M=16,
    BLOCK_N=16,
    BLOCK_K=32,
)

# Resolve all pending compilations into this module's globals (presumably
# swapping the async handles for runnable kernels), then drop the compiler
# handle since no further kernels are submitted.
async_compile.wait(globals())
del async_compile
def call(args):
    """Execute the compiled graph: an int8 ``_int_mm`` into a fresh int32 buffer.

    ``args`` must contain exactly two tensors, the A and B operands. The list
    is emptied in place, presumably so the inputs can be freed as soon as the
    kernel launch no longer needs them.

    Returns a one-element tuple holding the (8, 8) int32 result allocated on
    CUDA device 0.
    """
    mat_a, mat_b = args
    args.clear()
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        result = empty_strided((8, 8), (8, 1), device='cuda', dtype=torch.int32)
        # Stream is fetched before the grid is computed, matching launch order.
        raw_stream = get_cuda_stream(0)
        launch_grid = torch._inductor.kernel.mm_common.mm_grid(8, 8, meta0)
        triton_unk_fused__int_mm_0.run(
            mat_a,
            mat_b,
            result,
            grid=launch_grid,
            stream=raw_stream,
        )
        # Release this frame's references to the inputs.
        del mat_a, mat_b
        return (result, )
def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on random int8 inputs shaped like the compiled graph.

    The inputs are A of shape (8, 17) and B of shape (17, 8), both row-major,
    on ``cuda:0``. ``times`` and ``repeat`` are forwarded unchanged to
    ``print_performance``, whose result is returned.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    lhs = rand_strided((8, 17), (17, 1), device='cuda:0', dtype=torch.int8)
    rhs = rand_strided((17, 8), (8, 1), device='cuda:0', dtype=torch.int8)

    def run_once():
        # A fresh list each time, because call() clears the list it receives.
        return call([lhs, rhs])

    return print_performance(run_once, times=times, repeat=repeat)
# Script entry point: run this module's benchmark through inductor's standard
# compiled-module runner. The first argument 'None' is kept exactly as
# inductor emitted it (presumably the benchmark name).
if __name__ == "__main__":
    from torch._inductor.utils import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment