codegen
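TorchInductor output for a fused mixed-dtype matmul: a bf16 activation is multiplied against an int8 matrix (upcast to bf16 in the Triton prologue, per `B_PROLOGUE_CAST_TYPE`), with the `aten.mul` and `aten.add` epilogue fused into the same kernel.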
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
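# async_compile.triton(...) below registers the kernel source for compilation
# in background workers; async_compile.wait(...) further down blocks until
# every registered kernel has finished compiling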
# kernel path: /tmp/torchinductor_cdhernandez/me/cme6svhcsihecrrpkvmkdyfs65x3wwf3kkkf372kqjwnoi3ltmyn.py
# Original ATen: aten.add, aten.mm, aten.mul
# aten.add => add
# aten.mm => tuned_mixed_dtype_mm
# aten.mul => mul
triton_unk_fused_add_mm_mul_0 = async_compile.triton('triton_unk_fused_add_mm_mul_0', '''
import triton.language as tl
import triton
from torch._inductor.triton_heuristics import template
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers

@template(num_stages=4, num_warps=8, meta={'signature': {0: '*bf16', 1: '*i8', 2: '*bf16', 3: '*bf16', 4: '*bf16'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]})
@triton.jit
def triton_unk_fused_add_mm_mul_0(arg_A, arg_B, in_ptr2, in_ptr3, out_ptr1):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = tl.bfloat16
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 32
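    # inductor specializes the problem shape (M, N, K) and strides into
    # compile-time constants for this exact input configuration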
    A = arg_A
    B = arg_B
    M = 4096
    N = 1280
    K = 1280
    stride_am = 1280
    stride_ak = 1
    stride_bk = 1280
    stride_bn = 1
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
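    # main reduction loop over K in BLOCK_K steps; the int8 B tile is cast to
    # bf16 in the prologue so tl.dot runs in bf16 with fp32 accumulation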
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    xindex = idx_n + (1280*idx_m)
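    # in_ptr2 and in_ptr3 are indexed by column (idx_n) only, so each loaded
    # value broadcasts down the rows of the accumulator tile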
    tmp0 = tl.load(in_ptr2 + (tl.broadcast_to(idx_n, mask.shape)), mask, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tl.load(in_ptr3 + (tl.broadcast_to(idx_n, mask.shape)), mask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = acc * tmp0
    tmp3 = tmp1 + tmp2
    tl.store(out_ptr1 + (tl.broadcast_to(idx_n + (1280*idx_m), mask.shape)), tmp3, mask)
''')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
import torch._inductor.kernel.mm_common

meta0 = {'GROUP_M': 8, 'EVEN_K': True, 'ALLOW_TF32': False, 'ACC_TYPE': 'tl.float32', 'B_PROLOGUE_CAST_TYPE': 'tl.bfloat16', 'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}
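# meta0 repeats the template's tuning constants; call() passes it to
# mm_grid(4096, 1280, meta0), which uses the block sizes to compute the
# launch grid for the kernel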
async_compile.wait(globals())
del async_compile
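# call() takes [arg0_1 (1280, 1280) int8, arg1_1 (1280,) bf16,
# arg2_1 (1280,) bf16, arg3_1 (4096, 1280) bf16] and returns the fused bf16
# result; the shapes and dtypes suggest a weight-only-quantized linear layer
# (weight, scale, bias, input)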
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1280, 1280), (1280, 1))
    assert_size_stride(arg1_1, (1280, ), (1, ))
    assert_size_stride(arg2_1, (1280, ), (1, ))
    assert_size_stride(arg3_1, (4096, 1280), (1280, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf1 = empty_strided((4096, 1280), (1280, 1), device='cuda', dtype=torch.bfloat16)
        stream0 = get_cuda_stream(0)
        triton_unk_fused_add_mm_mul_0.run(arg3_1, arg0_1, arg1_1, arg2_1, buf1, grid=torch._inductor.kernel.mm_common.mm_grid(4096, 1280, meta0), stream=stream0)
        del arg0_1
        del arg1_1
        del arg2_1
        del arg3_1
        return (buf1, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((1280, 1280), (1280, 1), device='cuda:0', dtype=torch.int8)
    arg1_1 = rand_strided((1280, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg2_1 = rand_strided((1280, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg3_1 = rand_strided((4096, 1280), (1280, 1), device='cuda:0', dtype=torch.bfloat16)
    return print_performance(lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]), times=times, repeat=repeat)

if __name__ == "__main__":
    from torch._inductor.utils import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
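For reference, here is a minimal eager-mode sketch of what this compiled module computes, based on a reading of the kernel rather than anything stated in the gist: the int8 matrix is upcast to bf16 in the matmul prologue, and the fused epilogue applies a per-column multiply and add. The function name and argument names are hypothetical.

import torch

def eager_reference(x, w_int8, vec_mul, vec_add):
    # hypothetical equivalent of call([w_int8, vec_mul, vec_add, x])[0]:
    # mixed-dtype mm (bf16 activation @ int8 matrix upcast to bf16),
    # followed by a per-column multiply (aten.mul) and add (aten.add)
    return (x @ w_int8.to(torch.bfloat16)) * vec_mul + vec_add

With x of shape (4096, 1280) in bf16 and w_int8 of shape (1280, 1280), this produces the same (4096, 1280) bf16 layout as buf1 above.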