# Gist by @davidberard98 (gist id a5ead2e729ef7cb951abec1ddd226ff5), created November 4, 2024 18:33.
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import (
grid,
split_scan_grid,
grid_combo_kernels,
start_graph,
end_graph,
cooperative_reduction_grid,
)
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# Short aliases for torch op namespaces and internal helpers. Only some of them
# are used by the code visible in this module (assert_size_stride,
# empty_strided_cuda, async_compile); the rest are presumably part of the
# preamble the Inductor code generator always emits.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Validates a tensor's sizes and strides; `call` uses it on its input.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
# Per-device allocators taking (sizes, strides, dtype); presumably fast-path
# equivalents of torch.empty_strided -- NOTE(review): confirm.
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Kernels are submitted via async_compile.triton(...) and collected later by
# async_compile.wait(globals()).
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# kernel path: /tmp/torchinductor_dberard/mo/cmoajmagos2lkqlkap4bce7xzq4tbfcudx7zomabn5saz6tk6g57.py
# Topologically Sorted Source Nodes: [logcumsumexp], Original ATen: [aten.logcumsumexp]
# Source node to ATen node mapping:
# logcumsumexp => logcumsumexp
# Graph fragment:
# %logcumsumexp : [num_users=1] = call_function[target=torch.ops.aten.logcumsumexp.default](args = (%arg2_1, 0), kwargs = {})
#
# Looped-scan Triton kernel for logcumsumexp along dim 0 (see graph fragment).
# The x dimension (xnumel) indexes independent columns. The r dimension
# (rnumel) is the scanned dimension and is processed RBLOCK elements at a
# time; a per-column carry holds the log-sum-exp of all earlier chunks.
#
# NOTE: the kernel source below is an indentation-sensitive Python string.
# Its block structure must be preserved, or compiling it raises
# IndentationError.
triton_red_fused_logcumsumexp_0 = async_compile.triton('triton_red_fused_logcumsumexp_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton.jit
def _triton_helper_fn_minimum_maximum_ne_isinf_bitwise_not_bitwise_or_sub_exp_log1p_add_where0(arg0_0, arg1_0):
    # Scan combine: log(exp(a) + exp(b)) evaluated as max + log1p(exp(min - max)).
    # If min == max and both are infinite, min - max would be inf - inf = NaN,
    # so the first operand (equal to both) is returned unchanged instead.
    tmp0 = triton_helpers.minimum(arg0_0, arg1_0)
    tmp1 = triton_helpers.maximum(arg0_0, arg1_0)
    tmp2 = tmp0 != tmp1
    tmp3 = libdevice.isinf(tmp0).to(tl.int1)
    tmp4 = ~tmp3
    tmp5 = tmp2 | tmp4
    tmp6 = tmp0 - tmp1
    tmp7 = tl_math.exp(tmp6)
    tmp8 = libdevice.log1p(tmp7)
    tmp9 = tmp8 + tmp1
    tmp10 = tl.where(tmp5, tmp9, arg0_0)
    return tmp10

@triton_heuristics.reduction(
    size_hints=[32, 16],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=132, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_logcumsumexp_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '63A118008A8D9F5C10AB08E75F67129677D36B011F4277B312BCA9383911253E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused_logcumsumexp_0(in_ptr0, out_ptr0, ks0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    # Per-column carry: log-sum-exp of every element in earlier chunks.
    # The NaN start value is only a placeholder. It is never selected, because
    # the first chunk (roffset == 0) bypasses the carry in the selects below.
    tmp3 = tl.full([XBLOCK, 1], float('nan'), tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        # Read offset x0 + ks0*r1: ks0 is the distance between consecutive
        # scanned elements.
        tmp0 = tl.load(in_ptr0 + (x0 + (ks0*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        # Inclusive log-sum-exp scan within this chunk.
        tmp4, = tl.associative_scan((tmp2,), 1, _triton_helper_fn_minimum_maximum_ne_isinf_bitwise_not_bitwise_or_sub_exp_log1p_add_where0)
        # Chunk total = scan value at lane RBLOCK - 1. Padding lanes (loaded as
        # 0.0) can only occur in the final chunk, whose updated carry is never read.
        tmp5 = triton_helpers.select_one((tmp4), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)
        # tmp6..tmp16: combine(carry, chunk total), the next carry.
        tmp6 = triton_helpers.minimum(tmp3, tmp5)
        tmp7 = triton_helpers.maximum(tmp3, tmp5)
        tmp8 = tmp6 != tmp7
        tmp9 = libdevice.isinf(tmp6).to(tl.int1)
        tmp10 = ~tmp9
        tmp11 = tmp8 | tmp10
        tmp12 = tmp6 - tmp7
        tmp13 = tl_math.exp(tmp12)
        tmp14 = libdevice.log1p(tmp13)
        tmp15 = tmp14 + tmp7
        tmp16 = tl.where(tmp11, tmp15, tmp3)
        # tmp17..tmp27: combine(carry, each in-chunk prefix), the chunk outputs.
        tmp17 = triton_helpers.minimum(tmp3, tmp4)
        tmp18 = triton_helpers.maximum(tmp3, tmp4)
        tmp19 = tmp17 != tmp18
        tmp20 = libdevice.isinf(tmp17).to(tl.int1)
        tmp21 = ~tmp20
        tmp22 = tmp19 | tmp21
        tmp23 = tmp17 - tmp18
        tmp24 = tl_math.exp(tmp23)
        tmp25 = libdevice.log1p(tmp24)
        tmp26 = tmp25 + tmp18
        tmp27 = tl.where(tmp22, tmp26, tmp3)
        # First chunk: outputs are the in-chunk scan and the carry starts at the
        # chunk total. Later chunks fold the carry into both.
        tmp28 = tl.where(roffset > 0, tmp27, tmp4)
        tmp3 = tl.where(roffset > 0, tmp16, tmp5)
        tl.store(out_ptr0 + (x0 + (ks0*r1)), tmp28, rmask & xmask)
''', device_str='cuda')
# Wait for the kernels submitted above via async_compile.triton(...). This
# module's globals are passed in, presumably so the pending kernel handles can
# be replaced by the compiled kernels -- NOTE(review): confirm.
# Must run after every async_compile.triton(...) call and before `call` is used.
async_compile.wait(globals())
# The compiler object is not referenced again after this point.
del async_compile
def call(args):
    """Run the compiled graph: ``aten.logcumsumexp(x, 0)`` on a CUDA tensor.

    Args:
        args: list ``[s0, s1, x]``, i.e. the two size values followed by the
            input ``x``. ``x`` must have sizes ``(s0, s1)`` and strides
            ``(s1, 1)``, which ``assert_size_stride`` checks. The list is
            emptied in place.

    Returns:
        A 1-tuple holding a newly allocated float32 tensor with sizes
        ``(s0, s1)`` and strides ``(s1, 1)``, created while the device guard
        for CUDA device 0 is active.
    """
    arg0_1, arg1_1, arg2_1 = args
    # Drop the references held by the caller's list (presumably so the input
    # can be freed early; see also `del arg2_1` below).
    args.clear()
    s0 = arg0_1
    s1 = arg1_1
    assert_size_stride(arg2_1, (s0, s1), (s1, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((s0, s1), (s1, 1), torch.float32)
        # Topologically Sorted Source Nodes: [logcumsumexp], Original ATen: [aten.logcumsumexp]
        stream0 = get_raw_stream(0)
        # Kernel args: ks0=s1 (distance between scanned elements),
        # xnumel=s1 (independent columns), rnumel=s0 (scanned length).
        triton_red_fused_logcumsumexp_0.run(arg2_1, buf0, s1, s1, s0, grid=grid(s1), stream=stream0)
        # Release the local reference to the input once the kernel is launched.
        del arg2_1
    return (buf0, )
def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on a random float32 CUDA input of shape (16, 32).

    ``times`` and ``repeat`` are forwarded unchanged to Inductor's
    ``print_performance``, whose result is returned.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    num_rows, num_cols = 16, 32
    example = rand_strided((num_rows, num_cols), (num_cols, 1), device='cuda:0', dtype=torch.float32)

    def run_once():
        # Build a new argument list on every invocation, because ``call``
        # clears the list it is given.
        return call([num_rows, num_cols, example])

    return print_performance(run_once, times=times, repeat=repeat)
if __name__ == "__main__":
    # When executed as a script, hand `benchmark_compiled_module` to Inductor's
    # wrapper-benchmark entry point. The 'None' first argument is as emitted by
    # the code generator -- NOTE(review): presumably a missing compile id.
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
# End of gist content (GitHub page footer removed).