# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import (
    grid,
    split_scan_grid,
    grid_combo_kernels,
    start_graph,
    end_graph,
    cooperative_reduction_grid,
)
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
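# This module is AOT Inductor output (AOT ID '0_inference' above) for a graph
# containing a single aten.grid_sampler_2d call: the prelude above is the
# standard Inductor wrapper boilerplate, followed by one fused Triton kernel
# and the call() wrapper that launches it.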
# kernel path: /tmp/torchinductor_dberard/jp/cjppgi2xgpusm3gtom2ox3wa7qq5qs6yldf5xxci2jgbmg35tnmb.py
# Topologically Sorted Source Nodes: [grid_sampler_2d_default], Original ATen: [aten.grid_sampler_2d]
# Source node to ATen node mapping:
# grid_sampler_2d_default => add, add_1, add_2, add_3, add_4, add_5, add_6, floor, floor_1, full_default_11, full_default_2, full_default_5, full_default_8, ge, ge_1, ge_2, ge_3, ge_4, ge_5, ge_6, ge_7, index, index_1, index_2, index_3, logical_and, logical_and_1, logical_and_10, logical_and_11, logical_and_2, logical_and_3, logical_and_4, logical_and_5, logical_and_6, logical_and_7, logical_and_8, logical_and_9, lt, lt_1, lt_2, lt_3, lt_4, lt_5, lt_6, lt_7, mul, mul_1, mul_2, mul_3, mul_4, mul_5, mul_6, mul_7, mul_8, mul_9, sub, sub_1, sub_2, sub_3, sub_4, sub_5, sub_6, sub_7, where_11, where_2, where_5, where_8
# Graph fragment:
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%select, 176.0), kwargs = {})
# %add : [num_users=5] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, 175.5), kwargs = {})
# %floor : [num_users=9] = call_function[target=torch.ops.aten.floor.default](args = (%add,), kwargs = {})
# %ge : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%floor, 0), kwargs = {})
# %lt : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%floor, 352), kwargs = {})
# %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%select_1, 176.0), kwargs = {})
# %add_1 : [num_users=5] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_1, 175.5), kwargs = {})
# %floor_1 : [num_users=9] = call_function[target=torch.ops.aten.floor.default](args = (%add_1,), kwargs = {})
# %ge_1 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%floor_1, 0), kwargs = {})
# %lt_1 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%floor_1, 352), kwargs = {})
# %logical_and : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_1, %lt_1), kwargs = {})
# %logical_and_1 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%lt, %logical_and), kwargs = {})
# %logical_and_2 : [num_users=3] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge, %logical_and_1), kwargs = {})
# %index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%view_1, %view_2, %where_1, %where]), kwargs = {})
# %add_2 : [num_users=8] = call_function[target=torch.ops.aten.add.Tensor](args = (%floor, 1), kwargs = {})
# %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_2, %add), kwargs = {})
# %add_3 : [num_users=8] = call_function[target=torch.ops.aten.add.Tensor](args = (%floor_1, 1), kwargs = {})
# %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_3, %add_1), kwargs = {})
# %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %sub_1), kwargs = {})
# %full_default_2 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_2 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%logical_and_2, %mul_2, %full_default_2), kwargs = {})
# %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%index, %where_2), kwargs = {})
# %ge_2 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%add_2, 0), kwargs = {})
# %lt_2 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%add_2, 352), kwargs = {})
# %ge_3 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%floor_1, 0), kwargs = {})
# %lt_3 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%floor_1, 352), kwargs = {})
# %logical_and_3 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_3, %lt_3), kwargs = {})
# %logical_and_4 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%lt_2, %logical_and_3), kwargs = {})
# %logical_and_5 : [num_users=3] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_2, %logical_and_4), kwargs = {})
# %index_1 : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%view_1, %view_2, %where_4, %where_3]), kwargs = {})
# %sub_2 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add, %floor), kwargs = {})
# %sub_3 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_3, %add_1), kwargs = {})
# %mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_2, %sub_3), kwargs = {})
# %full_default_5 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_5 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%logical_and_5, %mul_3, %full_default_5), kwargs = {})
# %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%index_1, %where_5), kwargs = {})
# %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_6, %mul_7), kwargs = {})
# %ge_4 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%floor, 0), kwargs = {})
# %lt_4 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%floor, 352), kwargs = {})
# %ge_5 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%add_3, 0), kwargs = {})
# %lt_5 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%add_3, 352), kwargs = {})
# %logical_and_6 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_5, %lt_5), kwargs = {})
# %logical_and_7 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%lt_4, %logical_and_6), kwargs = {})
# %logical_and_8 : [num_users=3] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_4, %logical_and_7), kwargs = {})
# %index_2 : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%view_1, %view_2, %where_7, %where_6]), kwargs = {})
# %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_2, %add), kwargs = {})
# %sub_5 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_1, %floor_1), kwargs = {})
# %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_4, %sub_5), kwargs = {})
# %full_default_8 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_8 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%logical_and_8, %mul_4, %full_default_8), kwargs = {})
# %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%index_2, %where_8), kwargs = {})
# %add_5 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_4, %mul_8), kwargs = {})
# %ge_6 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%add_2, 0), kwargs = {})
# %lt_6 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%add_2, 352), kwargs = {})
# %ge_7 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%add_3, 0), kwargs = {})
# %lt_7 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%add_3, 352), kwargs = {})
# %logical_and_9 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_7, %lt_7), kwargs = {})
# %logical_and_10 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%lt_6, %logical_and_9), kwargs = {})
# %logical_and_11 : [num_users=3] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_6, %logical_and_10), kwargs = {})
# %index_3 : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%view_1, %view_2, %where_10, %where_9]), kwargs = {})
# %sub_6 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add, %floor), kwargs = {})
# %sub_7 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_1, %floor_1), kwargs = {})
# %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %sub_7), kwargs = {})
# %full_default_11 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_11 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%logical_and_11, %mul_5, %full_default_11), kwargs = {})
# %mul_9 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%index_3, %where_11), kwargs = {})
# %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %mul_9), kwargs = {})
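# The fragment above is the standard bilinear grid-sample decomposition.
# With align_corners=False and H == W == 352, a normalized grid coordinate
# u in [-1, 1] is unnormalized as
#     ix = ((u + 1) * 352 - 1) / 2 = u * 176.0 + 175.5,
# which is where the mul-by-176.0 / add-175.5 constants come from. The four
# integer corners (floor, floor + 1) are gathered, each weighted by the
# opposite fractional distances, and out-of-range corners contribute zero,
# consistent with padding_mode='zeros'.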
triton_poi_fused_grid_sampler_2d_0 = async_compile.triton('triton_poi_fused_grid_sampler_2d_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints=[4194304],
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=132, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_grid_sampler_2d_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '5F27510833D0212BDB44E6935BFB3CF3F2FC121004971730AF324EFA11390927', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_grid_sampler_2d_0(in_out_ptr0, in_ptr0, in_ptr1, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2230272
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 123904
    x2 = (xindex // 371712)
    x4 = (xindex // 123904)
    x3 = xindex
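    # Index decomposition over the (6, 3, 352, 352) output:
    #   x0 = h * 352 + w   (spatial offset; 123904 = 352 * 352)
    #   x2 = batch index   (371712 = 3 * 123904)
    #   x4 = flattened (batch, channel) index
    # The grid loads below depend only on (x0, x2), so all three channels
    # sample at the same (h, w) grid coordinate.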
    tmp0 = tl.load(in_ptr0 + ((2*x0) + (247808*x2)), xmask, eviction_policy='evict_last')
    tmp10 = tl.load(in_ptr0 + (1 + (2*x0) + (247808*x2)), xmask, eviction_policy='evict_last')
    tmp1 = 176.0
    tmp2 = tmp0 * tmp1
    tmp3 = 175.5
    tmp4 = tmp2 + tmp3
    tmp5 = libdevice.floor(tmp4)
    tmp6 = 0.0
    tmp7 = tmp5 >= tmp6
    tmp8 = 352.0
    tmp9 = tmp5 < tmp8
    tmp11 = tmp10 * tmp1
    tmp12 = tmp11 + tmp3
    tmp13 = libdevice.floor(tmp12)
    tmp14 = tmp13 >= tmp6
    tmp15 = tmp13 < tmp8
    tmp16 = tmp14 & tmp15
    tmp17 = tmp9 & tmp16
    tmp18 = tmp7 & tmp17
    tmp19 = tmp13.to(tl.int64)
    tmp20 = tl.full([1], 0, tl.int64)
    tmp21 = tl.where(tmp18, tmp19, tmp20)
    tmp22 = tl.full([XBLOCK], 352, tl.int32)
    tmp23 = tmp21 + tmp22
    tmp24 = tmp21 < 0
    tmp25 = tl.where(tmp24, tmp23, tmp21)
    tl.device_assert(((0 <= tmp25) & (tmp25 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp25 < 352")
    tmp27 = tmp5.to(tl.int64)
    tmp28 = tl.where(tmp18, tmp27, tmp20)
    tmp29 = tmp28 + tmp22
    tmp30 = tmp28 < 0
    tmp31 = tl.where(tmp30, tmp29, tmp28)
    tl.device_assert(((0 <= tmp31) & (tmp31 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp31 < 352")
    tmp33 = tl.load(in_ptr1 + (tmp31 + (352*tmp25) + (123904*x4)), xmask, eviction_policy='evict_last')
    tmp34 = 1.0
    tmp35 = tmp5 + tmp34
    tmp36 = tmp35 - tmp4
    tmp37 = tmp13 + tmp34
    tmp38 = tmp37 - tmp12
    tmp39 = tmp36 * tmp38
    tmp40 = tl.where(tmp18, tmp39, tmp6)
    tmp41 = tmp33 * tmp40
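    # tmp41 is the first of the four bilinear corner terms: the value gathered
    # at (floor_y, floor_x), weighted by (floor_x + 1 - x) * (floor_y + 1 - y).
    # The three blocks below repeat the same clamp/gather/weight pattern for
    # the remaining corners; tmp99..tmp101 sum the four contributions.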
    tmp42 = tmp35 >= tmp6
    tmp43 = tmp35 < tmp8
    tmp44 = tmp43 & tmp16
    tmp45 = tmp42 & tmp44
    tmp46 = tl.where(tmp45, tmp19, tmp20)
    tmp47 = tmp46 + tmp22
    tmp48 = tmp46 < 0
    tmp49 = tl.where(tmp48, tmp47, tmp46)
    tl.device_assert(((0 <= tmp49) & (tmp49 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 352")
    tmp51 = tmp35.to(tl.int64)
    tmp52 = tl.where(tmp45, tmp51, tmp20)
    tmp53 = tmp52 + tmp22
    tmp54 = tmp52 < 0
    tmp55 = tl.where(tmp54, tmp53, tmp52)
    tl.device_assert(((0 <= tmp55) & (tmp55 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp55 < 352")
    tmp57 = tl.load(in_ptr1 + (tmp55 + (352*tmp49) + (123904*x4)), xmask, eviction_policy='evict_last')
    tmp58 = tmp4 - tmp5
    tmp59 = tmp58 * tmp38
    tmp60 = tl.where(tmp45, tmp59, tmp6)
    tmp61 = tmp57 * tmp60
    tmp62 = tmp37 >= tmp6
    tmp63 = tmp37 < tmp8
    tmp64 = tmp62 & tmp63
    tmp65 = tmp9 & tmp64
    tmp66 = tmp7 & tmp65
    tmp67 = tmp37.to(tl.int64)
    tmp68 = tl.where(tmp66, tmp67, tmp20)
    tmp69 = tmp68 + tmp22
    tmp70 = tmp68 < 0
    tmp71 = tl.where(tmp70, tmp69, tmp68)
    tl.device_assert(((0 <= tmp71) & (tmp71 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp71 < 352")
    tmp73 = tl.where(tmp66, tmp27, tmp20)
    tmp74 = tmp73 + tmp22
    tmp75 = tmp73 < 0
    tmp76 = tl.where(tmp75, tmp74, tmp73)
    tl.device_assert(((0 <= tmp76) & (tmp76 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp76 < 352")
    tmp78 = tl.load(in_ptr1 + (tmp76 + (352*tmp71) + (123904*x4)), xmask, eviction_policy='evict_last')
    tmp79 = tmp12 - tmp13
    tmp80 = tmp36 * tmp79
    tmp81 = tl.where(tmp66, tmp80, tmp6)
    tmp82 = tmp78 * tmp81
    tmp83 = tmp43 & tmp64
    tmp84 = tmp42 & tmp83
    tmp85 = tl.where(tmp84, tmp67, tmp20)
    tmp86 = tmp85 + tmp22
    tmp87 = tmp85 < 0
    tmp88 = tl.where(tmp87, tmp86, tmp85)
    tl.device_assert(((0 <= tmp88) & (tmp88 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp88 < 352")
    tmp90 = tl.where(tmp84, tmp51, tmp20)
    tmp91 = tmp90 + tmp22
    tmp92 = tmp90 < 0
    tmp93 = tl.where(tmp92, tmp91, tmp90)
    tl.device_assert(((0 <= tmp93) & (tmp93 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp93 < 352")
    tmp95 = tl.load(in_ptr1 + (tmp93 + (352*tmp88) + (123904*x4)), xmask, eviction_policy='evict_last')
    tmp96 = tmp58 * tmp79
    tmp97 = tl.where(tmp84, tmp96, tmp6)
    tmp98 = tmp95 * tmp97
    tmp99 = tmp41 + tmp61
    tmp100 = tmp99 + tmp82
    tmp101 = tmp100 + tmp98
    tl.store(in_out_ptr0 + (x3), tmp101, xmask)
''', device_str='cuda')
async_compile.wait(globals())
del async_compile
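# A minimal eager-mode reference for the fused kernel above (an inference
# from the graph fragment's 176.0/175.5 unnormalization and zero padding,
# not part of the generated output):
def _eager_grid_sample_reference(input, grid):
    import torch.nn.functional as F
    # bilinear + padding_mode='zeros' + align_corners=False matches the
    # decomposition captured in the ATen graph comments above.
    return F.grid_sample(input, grid, mode='bilinear',
                         padding_mode='zeros', align_corners=False)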
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (6, 3, 352, 352), (371712, 123904, 352, 1))
    assert_size_stride(arg1_1, (6, 352, 352, 2), (247808, 704, 2, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((6, 3, 352, 352), (371712, 123904, 352, 1), torch.float32)
        buf4 = buf0; del buf0 # reuse
        # Topologically Sorted Source Nodes: [grid_sampler_2d_default], Original ATen: [aten.grid_sampler_2d]
        stream0 = get_raw_stream(0)
        triton_poi_fused_grid_sampler_2d_0.run(buf4, arg1_1, arg0_1, 2230272, grid=grid(2230272), stream=stream0)
        del arg0_1
        del arg1_1
    return (buf4, )
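# Hedged sanity check (hypothetical helper, not emitted by Inductor): run the
# compiled call() against the eager reference on tensors matching the
# assert_size_stride guards above. Requires a CUDA device.
def _check_against_eager():
    x = torch.randn(6, 3, 352, 352, device='cuda')
    g = torch.rand(6, 352, 352, 2, device='cuda') * 2 - 1  # coords in [-1, 1)
    out = call([x, g])[0]  # call() clears the arg list; x and g stay alive here
    ref = _eager_grid_sample_reference(x, g)
    torch.testing.assert_close(out, ref)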
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((6, 3, 352, 352), (371712, 123904, 352, 1), device='cuda:0', dtype=torch.float32)
    arg1_1 = rand_strided((6, 352, 352, 2), (247808, 704, 2, 1), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([arg0_1, arg1_1])
    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)