# --- GitHub gist page header (not part of the generated module) ---
# Gist davidberard98/9f7551f94d58f4a02423934316efd76d, created November 25, 2024 19:33.
# GitHub notice: this file contains hidden or bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review, open the
# file in an editor that reveals hidden Unicode characters.
# AOT ID: ['0_inference'] | |
# Module preamble of the generated wrapper. The import order is kept exactly as
# generated; several of these names are not referenced in the visible part of
# this module but are retained because they are part of the generated preamble.
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import (
    grid,
    split_scan_grid,
    grid_combo_kernels,
    start_graph,
    end_graph,
    cooperative_reduction_grid,
)
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

# Short aliases for operator namespaces.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized

# Short aliases for the guard/allocation helpers used by call().
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool

# Compile handle used for the kernel definition further down; the module
# deletes it again once compilation has been awaited.
async_compile = AsyncCompile()

# NOTE(review): this attribute lookup runs at import time; presumably it needs a
# torch build that exposes torch._C._distributed_c10d -- confirm.
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# kernel path: /tmp/torchinductor_dberard/jp/cjppgi2xgpusm3gtom2ox3wa7qq5qs6yldf5xxci2jgbmg35tnmb.py | |
# Topologically Sorted Source Nodes: [grid_sampler_2d_default], Original ATen: [aten.grid_sampler_2d] | |
# Source node to ATen node mapping: | |
# grid_sampler_2d_default => add, add_1, add_2, add_3, add_4, add_5, add_6, floor, floor_1, full_default_11, full_default_2, full_default_5, full_default_8, ge, ge_1, ge_2, ge_3, ge_4, ge_5, ge_6, ge_7, index, index_1, index_2, index_3, logical_and, logical_and_1, logical_and_10, logical_and_11, logical_and_2, logical_and_3, logical_and_4, logical_and_5, logical_and_6, logical_and_7, logical_and_8, logical_and_9, lt, lt_1, lt_2, lt_3, lt_4, lt_5, lt_6, lt_7, mul, mul_1, mul_2, mul_3, mul_4, mul_5, mul_6, mul_7, mul_8, mul_9, sub, sub_1, sub_2, sub_3, sub_4, sub_5, sub_6, sub_7, where_11, where_2, where_5, where_8 | |
# Graph fragment: | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%select, 176.0), kwargs = {}) | |
# %add : [num_users=5] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, 175.5), kwargs = {}) | |
# %floor : [num_users=9] = call_function[target=torch.ops.aten.floor.default](args = (%add,), kwargs = {}) | |
# %ge : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%floor, 0), kwargs = {}) | |
# %lt : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%floor, 352), kwargs = {}) | |
# %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%select_1, 176.0), kwargs = {}) | |
# %add_1 : [num_users=5] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_1, 175.5), kwargs = {}) | |
# %floor_1 : [num_users=9] = call_function[target=torch.ops.aten.floor.default](args = (%add_1,), kwargs = {}) | |
# %ge_1 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%floor_1, 0), kwargs = {}) | |
# %lt_1 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%floor_1, 352), kwargs = {}) | |
# %logical_and : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_1, %lt_1), kwargs = {}) | |
# %logical_and_1 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%lt, %logical_and), kwargs = {}) | |
# %logical_and_2 : [num_users=3] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge, %logical_and_1), kwargs = {}) | |
# %index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%view_1, %view_2, %where_1, %where]), kwargs = {}) | |
# %add_2 : [num_users=8] = call_function[target=torch.ops.aten.add.Tensor](args = (%floor, 1), kwargs = {}) | |
# %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_2, %add), kwargs = {}) | |
# %add_3 : [num_users=8] = call_function[target=torch.ops.aten.add.Tensor](args = (%floor_1, 1), kwargs = {}) | |
# %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_3, %add_1), kwargs = {}) | |
# %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %sub_1), kwargs = {}) | |
# %full_default_2 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %where_2 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%logical_and_2, %mul_2, %full_default_2), kwargs = {}) | |
# %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%index, %where_2), kwargs = {}) | |
# %ge_2 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%add_2, 0), kwargs = {}) | |
# %lt_2 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%add_2, 352), kwargs = {}) | |
# %ge_3 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%floor_1, 0), kwargs = {}) | |
# %lt_3 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%floor_1, 352), kwargs = {}) | |
# %logical_and_3 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_3, %lt_3), kwargs = {}) | |
# %logical_and_4 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%lt_2, %logical_and_3), kwargs = {}) | |
# %logical_and_5 : [num_users=3] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_2, %logical_and_4), kwargs = {}) | |
# %index_1 : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%view_1, %view_2, %where_4, %where_3]), kwargs = {}) | |
# %sub_2 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add, %floor), kwargs = {}) | |
# %sub_3 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_3, %add_1), kwargs = {}) | |
# %mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_2, %sub_3), kwargs = {}) | |
# %full_default_5 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %where_5 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%logical_and_5, %mul_3, %full_default_5), kwargs = {}) | |
# %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%index_1, %where_5), kwargs = {}) | |
# %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_6, %mul_7), kwargs = {}) | |
# %ge_4 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%floor, 0), kwargs = {}) | |
# %lt_4 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%floor, 352), kwargs = {}) | |
# %ge_5 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%add_3, 0), kwargs = {}) | |
# %lt_5 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%add_3, 352), kwargs = {}) | |
# %logical_and_6 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_5, %lt_5), kwargs = {}) | |
# %logical_and_7 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%lt_4, %logical_and_6), kwargs = {}) | |
# %logical_and_8 : [num_users=3] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_4, %logical_and_7), kwargs = {}) | |
# %index_2 : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%view_1, %view_2, %where_7, %where_6]), kwargs = {}) | |
# %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_2, %add), kwargs = {}) | |
# %sub_5 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_1, %floor_1), kwargs = {}) | |
# %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_4, %sub_5), kwargs = {}) | |
# %full_default_8 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %where_8 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%logical_and_8, %mul_4, %full_default_8), kwargs = {}) | |
# %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%index_2, %where_8), kwargs = {}) | |
# %add_5 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_4, %mul_8), kwargs = {}) | |
# %ge_6 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%add_2, 0), kwargs = {}) | |
# %lt_6 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%add_2, 352), kwargs = {}) | |
# %ge_7 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%add_3, 0), kwargs = {}) | |
# %lt_7 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%add_3, 352), kwargs = {}) | |
# %logical_and_9 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_7, %lt_7), kwargs = {}) | |
# %logical_and_10 : [num_users=1] = call_function[target=torch.ops.aten.logical_and.default](args = (%lt_6, %logical_and_9), kwargs = {}) | |
# %logical_and_11 : [num_users=3] = call_function[target=torch.ops.aten.logical_and.default](args = (%ge_6, %logical_and_10), kwargs = {}) | |
# %index_3 : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%view_1, %view_2, %where_10, %where_9]), kwargs = {}) | |
# %sub_6 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add, %floor), kwargs = {}) | |
# %sub_7 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_1, %floor_1), kwargs = {}) | |
# %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %sub_7), kwargs = {}) | |
# %full_default_11 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %where_11 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%logical_and_11, %mul_5, %full_default_11), kwargs = {}) | |
# %mul_9 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%index_3, %where_11), kwargs = {}) | |
# %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %mul_9), kwargs = {}) | |
# Fused pointwise Triton kernel generated by Inductor for aten.grid_sampler_2d.
#
# Each program instance handles XBLOCK of the 2230272 (= 6 * 3 * 352 * 352)
# output elements.
#   in_out_ptr0 : output buffer; this kernel only stores to it.
#   in_ptr0     : sampling grid; two consecutive floats are read per output
#                 location (call() passes arg1_1, asserted as (6, 352, 352, 2)).
#   in_ptr1     : tensor that is gathered from (call() passes arg0_1, asserted
#                 as (6, 3, 352, 352)).
#
# Per output element:
#   1. Each grid coordinate is mapped to a pixel position with
#      coord * 176.0 + 175.5.
#      NOTE(review): this matches un-normalising [-1, 1] over 352 pixels with
#      align_corners=False -- confirm against the traced call.
#   2. floor(pos) and floor(pos) + 1 give the four neighbouring pixels; every
#      neighbour is checked against the range [0, 352) on both axes.
#   3. A neighbour that fails the check gets index 0 and weight 0.0, so it
#      contributes nothing to the result.
#   4. The four gathered values are weighted by products of the fractional
#      distances, summed, and stored.
#
# The triple-quoted text is the Triton source handed to async_compile.triton().
# Only scrape residue was removed and indentation restored; no comments are
# added inside the string so the kernel source stays exactly as generated.
triton_poi_fused_grid_sampler_2d_0 = async_compile.triton('triton_poi_fused_grid_sampler_2d_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints=[4194304],
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=132, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_grid_sampler_2d_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '5F27510833D0212BDB44E6935BFB3CF3F2FC121004971730AF324EFA11390927', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_grid_sampler_2d_0(in_out_ptr0, in_ptr0, in_ptr1, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2230272
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 123904
    x2 = (xindex // 371712)
    x4 = (xindex // 123904)
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + ((2*x0) + (247808*x2)), xmask, eviction_policy='evict_last')
    tmp10 = tl.load(in_ptr0 + (1 + (2*x0) + (247808*x2)), xmask, eviction_policy='evict_last')
    tmp1 = 176.0
    tmp2 = tmp0 * tmp1
    tmp3 = 175.5
    tmp4 = tmp2 + tmp3
    tmp5 = libdevice.floor(tmp4)
    tmp6 = 0.0
    tmp7 = tmp5 >= tmp6
    tmp8 = 352.0
    tmp9 = tmp5 < tmp8
    tmp11 = tmp10 * tmp1
    tmp12 = tmp11 + tmp3
    tmp13 = libdevice.floor(tmp12)
    tmp14 = tmp13 >= tmp6
    tmp15 = tmp13 < tmp8
    tmp16 = tmp14 & tmp15
    tmp17 = tmp9 & tmp16
    tmp18 = tmp7 & tmp17
    tmp19 = tmp13.to(tl.int64)
    tmp20 = tl.full([1], 0, tl.int64)
    tmp21 = tl.where(tmp18, tmp19, tmp20)
    tmp22 = tl.full([XBLOCK], 352, tl.int32)
    tmp23 = tmp21 + tmp22
    tmp24 = tmp21 < 0
    tmp25 = tl.where(tmp24, tmp23, tmp21)
    tl.device_assert(((0 <= tmp25) & (tmp25 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp25 < 352")
    tmp27 = tmp5.to(tl.int64)
    tmp28 = tl.where(tmp18, tmp27, tmp20)
    tmp29 = tmp28 + tmp22
    tmp30 = tmp28 < 0
    tmp31 = tl.where(tmp30, tmp29, tmp28)
    tl.device_assert(((0 <= tmp31) & (tmp31 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp31 < 352")
    tmp33 = tl.load(in_ptr1 + (tmp31 + (352*tmp25) + (123904*x4)), xmask, eviction_policy='evict_last')
    tmp34 = 1.0
    tmp35 = tmp5 + tmp34
    tmp36 = tmp35 - tmp4
    tmp37 = tmp13 + tmp34
    tmp38 = tmp37 - tmp12
    tmp39 = tmp36 * tmp38
    tmp40 = tl.where(tmp18, tmp39, tmp6)
    tmp41 = tmp33 * tmp40
    tmp42 = tmp35 >= tmp6
    tmp43 = tmp35 < tmp8
    tmp44 = tmp43 & tmp16
    tmp45 = tmp42 & tmp44
    tmp46 = tl.where(tmp45, tmp19, tmp20)
    tmp47 = tmp46 + tmp22
    tmp48 = tmp46 < 0
    tmp49 = tl.where(tmp48, tmp47, tmp46)
    tl.device_assert(((0 <= tmp49) & (tmp49 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 352")
    tmp51 = tmp35.to(tl.int64)
    tmp52 = tl.where(tmp45, tmp51, tmp20)
    tmp53 = tmp52 + tmp22
    tmp54 = tmp52 < 0
    tmp55 = tl.where(tmp54, tmp53, tmp52)
    tl.device_assert(((0 <= tmp55) & (tmp55 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp55 < 352")
    tmp57 = tl.load(in_ptr1 + (tmp55 + (352*tmp49) + (123904*x4)), xmask, eviction_policy='evict_last')
    tmp58 = tmp4 - tmp5
    tmp59 = tmp58 * tmp38
    tmp60 = tl.where(tmp45, tmp59, tmp6)
    tmp61 = tmp57 * tmp60
    tmp62 = tmp37 >= tmp6
    tmp63 = tmp37 < tmp8
    tmp64 = tmp62 & tmp63
    tmp65 = tmp9 & tmp64
    tmp66 = tmp7 & tmp65
    tmp67 = tmp37.to(tl.int64)
    tmp68 = tl.where(tmp66, tmp67, tmp20)
    tmp69 = tmp68 + tmp22
    tmp70 = tmp68 < 0
    tmp71 = tl.where(tmp70, tmp69, tmp68)
    tl.device_assert(((0 <= tmp71) & (tmp71 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp71 < 352")
    tmp73 = tl.where(tmp66, tmp27, tmp20)
    tmp74 = tmp73 + tmp22
    tmp75 = tmp73 < 0
    tmp76 = tl.where(tmp75, tmp74, tmp73)
    tl.device_assert(((0 <= tmp76) & (tmp76 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp76 < 352")
    tmp78 = tl.load(in_ptr1 + (tmp76 + (352*tmp71) + (123904*x4)), xmask, eviction_policy='evict_last')
    tmp79 = tmp12 - tmp13
    tmp80 = tmp36 * tmp79
    tmp81 = tl.where(tmp66, tmp80, tmp6)
    tmp82 = tmp78 * tmp81
    tmp83 = tmp43 & tmp64
    tmp84 = tmp42 & tmp83
    tmp85 = tl.where(tmp84, tmp67, tmp20)
    tmp86 = tmp85 + tmp22
    tmp87 = tmp85 < 0
    tmp88 = tl.where(tmp87, tmp86, tmp85)
    tl.device_assert(((0 <= tmp88) & (tmp88 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp88 < 352")
    tmp90 = tl.where(tmp84, tmp51, tmp20)
    tmp91 = tmp90 + tmp22
    tmp92 = tmp90 < 0
    tmp93 = tl.where(tmp92, tmp91, tmp90)
    tl.device_assert(((0 <= tmp93) & (tmp93 < 352)) | ~(xmask), "index out of bounds: 0 <= tmp93 < 352")
    tmp95 = tl.load(in_ptr1 + (tmp93 + (352*tmp88) + (123904*x4)), xmask, eviction_policy='evict_last')
    tmp96 = tmp58 * tmp79
    tmp97 = tl.where(tmp84, tmp96, tmp6)
    tmp98 = tmp95 * tmp97
    tmp99 = tmp41 + tmp61
    tmp100 = tmp99 + tmp82
    tmp101 = tmp100 + tmp98
    tl.store(in_out_ptr0 + (x3), tmp101, xmask)
''', device_str='cuda')

# Wait for the compile requested above before call() can run. wait() is handed
# this module's globals() -- presumably so it can swap the placeholder bound
# above for the finished kernel; confirm against AsyncCompile.wait.
async_compile.wait(globals())
# The compile handle is not needed once the kernel is ready.
del async_compile
def call(args):
    """Run the compiled grid_sampler_2d graph on CUDA device 0.

    Args:
        args: A two-element list ``[arg0_1, arg1_1]``. ``arg0_1`` must have
            size (6, 3, 352, 352) and ``arg1_1`` size (6, 352, 352, 2), with
            the contiguous strides asserted below. The list is emptied in
            place.

    Returns:
        A one-element tuple holding the float32 output buffer of size
        (6, 3, 352, 352).
    """
    arg0_1, arg1_1 = args
    # Empty the caller's list so this function holds the only references it
    # knows about; presumably this lets the inputs be released at the `del`s
    # below -- confirm against the caller.
    args.clear()
    assert_size_stride(arg0_1, (6, 3, 352, 352), (371712, 123904, 352, 1))
    assert_size_stride(arg1_1, (6, 352, 352, 2), (247808, 704, 2, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        # Output buffer; the kernel writes every one of its elements.
        buf4 = empty_strided_cuda((6, 3, 352, 352), (371712, 123904, 352, 1), torch.float32)
        # Topologically Sorted Source Nodes: [grid_sampler_2d_default], Original ATen: [aten.grid_sampler_2d]
        stream0 = get_raw_stream(0)
        # 2230272 == 6 * 3 * 352 * 352, the number of output elements.
        # Argument order is (output, grid, image): arg1_1 feeds in_ptr0 and
        # arg0_1 feeds in_ptr1.
        triton_poi_fused_grid_sampler_2d_0.run(buf4, arg1_1, arg0_1, 2230272, grid=grid(2230272), stream=stream0)
        del arg0_1
        del arg1_1
    return (buf4, )
def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark ``call`` on freshly generated inputs.

    Args:
        times: Forwarded to ``print_performance`` as ``times``.
        repeat: Forwarded to ``print_performance`` as ``repeat``.

    Returns:
        Whatever ``print_performance`` returns for the benchmarked callable.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    # Inputs use the same sizes and strides that call() asserts.
    arg0_1 = rand_strided((6, 3, 352, 352), (371712, 123904, 352, 1), device='cuda:0', dtype=torch.float32)
    arg1_1 = rand_strided((6, 352, 352, 2), (247808, 704, 2, 1), device='cuda:0', dtype=torch.float32)

    def fn():
        # A new list is built on every invocation because call() clears the
        # list it receives.
        return call([arg0_1, arg1_1])

    return print_performance(fn, times=times, repeat=repeat)
# Script entry point: hand the benchmark function to Inductor's benchmark
# driver. The first argument is the literal string 'None', passed through
# exactly as generated.
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
# --- GitHub gist page footer (not part of the generated module) ---
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment