Skip to content

Instantly share code, notes, and snippets.

@leslie-fang-intel
Created October 31, 2024 02:24
Show Gist options
  • Select an option

  • Save leslie-fang-intel/af3e3c8e8d603bfcf4ca707ff1352566 to your computer and use it in GitHub Desktop.

Select an option

Save leslie-fang-intel/af3e3c8e8d603bfcf4ca707ff1352566 to your computer and use it in GitHub Desktop.
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_zailiwan/7m/c7mvzabkfbqphgue6ybjw4pie7x43i3canab67s3j2urllq5hcja.py
# Topologically Sorted Source Nodes: [setitem], Original ATen: [aten.index_put]
# Source node to ATen node mapping:
# setitem => index_put
# Graph fragment:
# %index_put : [num_users=0] = call_function[target=torch.ops.aten.index_put_.default](args = (%arg3_1, [%add], %view_1), kwargs = {})
#
# Fused pointwise kernel. Each lane handles one (x1, x0) pair, where
# x1 = xindex // ks0 selects an entry of the index tensor (in_ptr0) and
# x0 = xindex % ks0 selects a column. The loaded index is shifted by +10,
# wrapped when negative (by adding ks1), bounds-checked against ks1, and the
# element at row `tmp6`, column x0 of in_ptr1 is loaded, doubled (tmp8 + tmp8)
# and stored to the same offset of out_ptr0.
# NOTE(review): `call` passes the same tensor for in_ptr1 and out_ptr0 (both are
# also listed in 'mutated_arg_names'), so this is an unsynchronized in-place
# read-modify-write. If the index tensor holds duplicate values, different
# lanes/programs load and store the same element with no ordering, so results
# may differ from eager mode -- TODO confirm duplicates cannot reach here.
# The kernel source below is parsed as Python, so its indentation is
# significant (the function body must be indented under its `def`).
triton_poi_fused_index_put_0 = async_compile.triton('triton_poi_fused_index_put_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints=[64],
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=70, major=7, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=80, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_0', 'mutated_arg_names': ['in_ptr1', 'out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '024CABE5EDBCB1F8EDED0A0630073015B30C6F277AA90C3672353DD2CFA2D29D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_0(in_ptr0, in_ptr1, out_ptr0, ks0, ks1, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // ks0)
    x0 = xindex % ks0
    tmp0 = tl.load(in_ptr0 + (x1), xmask, eviction_policy='evict_last')
    tmp1 = tl.full([1], 10, tl.int64)
    tmp2 = tmp0 + tmp1
    tmp3 = ks1
    tmp4 = tmp2 + tmp3
    tmp5 = tmp2 < 0
    tmp6 = tl.where(tmp5, tmp4, tmp2)
    tl.device_assert(((0 <= tmp6) & (tmp6 < ks1)) | ~(xmask), "index out of bounds: 0 <= tmp6 < ks1")
    tmp8 = tl.load(in_ptr1 + (x0 + (ks0*tmp6)), xmask, eviction_policy='evict_last')
    tmp9 = tmp8 + tmp8
    tl.store(out_ptr0 + (x0 + (ks0*tmp6)), tmp9, xmask)
''', device_str='cuda')
# Wait for every kernel submitted through async_compile, passing this module's
# globals() -- presumably so pending compile results bound above are replaced
# by the compiled kernels (verify against AsyncCompile.wait).
async_compile.wait(globals())
# The compile scheduler is not needed after this point.
del async_compile
def call(args):
    """Run the compiled graph: one fused in-place index_put kernel launch.

    ``args`` is the flat input list ``[s0, arg1_1, s1, arg3_1]``. The list is
    cleared on entry, so the caller's list is consumed. ``arg1_1`` must be a
    contiguous ``(1, s0)`` tensor and ``arg3_1`` a contiguous ``(s0, s1)``
    tensor. ``arg3_1`` is passed as both the read and the write buffer of the
    kernel, so it is updated in place.

    Returns:
        An empty tuple; the graph has no outputs besides the mutation.
    """
    num_rows, index_tensor, num_cols, values = args
    args.clear()
    assert_size_stride(index_tensor, (1, num_rows), (num_rows, 1))
    assert_size_stride(values, (num_rows, num_cols), (num_cols, 1))
    total = num_rows * num_cols  # one kernel lane per (index entry, column)
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream = get_raw_stream(0)
        # Topologically Sorted Source Nodes: [setitem], Original ATen: [aten.index_put]
        triton_poi_fused_index_put_0.run(
            index_tensor,  # in_ptr0
            values,        # in_ptr1 (read)
            values,        # out_ptr0 (written in place)
            num_cols,      # ks0
            num_rows,      # ks1
            total,         # xnumel
            grid=grid(total),
            stream=stream,
        )
        del index_tensor, values
    return ()
def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on synthetic CUDA inputs and print the measurement.

    Args:
        times: Forwarded unchanged to ``print_performance``.
        repeat: Forwarded unchanged to ``print_performance``.

    Returns:
        Whatever ``print_performance`` returns.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = 2
    # The kernel reads row (index + 10) of arg3_1, whose first dim has size
    # arg0_1, and device_asserts 0 <= row < arg0_1 after wrapping negatives.
    # rand_strided zero-fills non-floating-point dtypes (torch/_dynamo/testing.py;
    # verify for your torch version), which would address row 10 and fail that
    # check. Build distinct in-range indices instead: [-10, -9, ...] -> rows
    # [0, 1, ...]. Distinct values also avoid duplicate-index conflicts in the
    # in-place kernel. Shape (1, arg0_1) with stride (arg0_1, 1) matches the
    # layout `call` asserts.
    arg1_1 = (
        torch.arange(arg0_1, device='cuda:0', dtype=torch.int64) - 10
    ).view(1, arg0_1)
    arg2_1 = 32
    arg3_1 = rand_strided((2, 32), (32, 1), device='cuda:0', dtype=torch.float32)

    def fn():
        # Build a fresh list on every invocation: `call` clears the list it gets.
        return call([arg0_1, arg1_1, arg2_1, arg3_1])

    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    # Script entry only: hand the benchmark callable to Inductor's shared
    # wrapper-benchmark driver. The first argument is the string 'None',
    # exactly as the code generator emitted it.
    from torch._inductor import wrapper_benchmark
    wrapper_benchmark.compiled_module_main('None', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment