Skip to content

Instantly share code, notes, and snippets.

@leslie-fang-intel
Created October 30, 2024 12:07
Show Gist options
  • Select an option

  • Save leslie-fang-intel/425116e5d47e0928d9ce8f164780256d to your computer and use it in GitHub Desktop.

Select an option

Save leslie-fang-intel/425116e5d47e0928d9ce8f164780256d to your computer and use it in GitHub Desktop.
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
# Fused C++ kernel for add -> index -> index_put, compiled via AsyncCompile.
#
# Arguments, in the order of the signature list passed to cpp_pybinding:
#   in_ptr0  (const int64_t*) - raw index values; the kernel adds 10 to each.
#   in_ptr1  (const float*)   - gather source, addressed as rows of 32 floats.
#   out_ptr0 (int64_t*)       - receives the shifted indices (in_ptr0 + 10).
#   out_ptr1 (float*)         - gathered rows: row x0 = in_ptr1 row idx[x0].
#   out_ptr2 (float*)         - scatter target: row idx[x0] = 2 * out_ptr1 row x0.
#
# Index handling: a negative shifted index is wrapped by adding 2, and
# AOTI_TORCH_CHECK rejects anything outside [0, 2). A raw index value
# therefore passes the check only if it lies in [-12, -8).
#
# The gather (first loop nest) finishes for every x0 before the scatter
# (second loop nest) starts. So when out_ptr2 aliases in_ptr1, as in the
# wrapper's call, the gather still reads the pre-update values.
#
# Each x1 loop covers the 32 columns in two 16-wide Vectorized<float> chunks.
#
# NOTE(review): the #include path points into the original author's inductor
# cache under /tmp and is machine-specific. Presumably it is inductor's C++
# prefix header; confirm before reusing this module on another machine.
cpp_fused_add_index_index_put_0 = async_compile.cpp_pybinding(['const int64_t*', 'const float*', 'int64_t*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/2r/c2rnilspx43ivnzu4uieul65kx65dfhfbptbh5og4wk6rqebuxoo.h"
extern "C" void kernel(const int64_t* in_ptr0,
const float* in_ptr1,
int64_t* out_ptr0,
float* out_ptr1,
float* out_ptr2)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
{
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x0)];
auto tmp1 = static_cast<int64_t>(10);
auto tmp2 = decltype(tmp0)(tmp0 + tmp1);
out_ptr0[static_cast<int64_t>(x0)] = tmp2;
}
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(32L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = out_ptr0[static_cast<int64_t>(x0)];
auto tmp1 = 2L;
auto tmp2 = c10::convert<int64_t>(tmp1);
auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
auto tmp4 = tmp0 < 0;
auto tmp5 = tmp4 ? tmp3 : tmp0;
auto tmp6 = tmp5;
auto tmp7 = c10::convert<int64_t>(tmp6);
AOTI_TORCH_CHECK((0 <= tmp7) & (tmp7 < 2L), "index out of bounds: 0 <= tmp7 < 2L");
auto tmp9 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x1 + (32L*tmp5)), static_cast<int64_t>(16));
tmp9.store(out_ptr1 + static_cast<int64_t>(x1 + (32L*x0)));
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(32L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = out_ptr0[static_cast<int64_t>(x0)];
auto tmp9 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (32L*x0)), static_cast<int64_t>(16));
auto tmp1 = 2L;
auto tmp2 = c10::convert<int64_t>(tmp1);
auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
auto tmp4 = tmp0 < 0;
auto tmp5 = tmp4 ? tmp3 : tmp0;
auto tmp6 = tmp5;
auto tmp7 = c10::convert<int64_t>(tmp6);
AOTI_TORCH_CHECK((0 <= tmp7) & (tmp7 < 2L), "index out of bounds: 0 <= tmp7 < 2L");
auto tmp10 = tmp9 + tmp9;
tmp10.store(out_ptr2 + static_cast<int64_t>(x1 + (32L*tmp5)));
}
}
}
}
''')
# Block until pending compilations finish. The module's globals() are passed
# so that, presumably, pending kernel futures such as the one above are
# replaced by their compiled callables. The compiler handle is then deleted
# because nothing later in this module uses it.
async_compile.wait(globals())
del async_compile
def call(args):
    """Execute the compiled graph on ``args``.

    ``args`` is a two-element list ``[indices, values]``. It is emptied up
    front so the caller's list no longer holds references to the inputs.
    ``values`` is handed to the fused kernel both as its gather source and as
    its scatter target, so it is updated in place.

    Returns an empty tuple: the graph produces no new outputs.
    """
    indices, values = args
    args.clear()
    # Guard the exact layouts the kernel was specialized for.
    assert_size_stride(indices, (1, 2), (2, 1))
    assert_size_stride(values, (2, 32), (32, 1))
    # Scratch space: the shifted int64 indices and the gathered float rows.
    shifted_idx = empty_strided_cpu((1, 2), (2, 1), torch.int64)
    gathered = empty_strided_cpu((1, 2, 32), (64, 32, 1), torch.float32)
    cpp_fused_add_index_index_put_0(indices, values, shifted_idx, gathered, values)
    del indices, values
    return ()
def benchmark_index_input():
    """Return a (1, 2) int64 CPU index tensor that passes the kernel's bounds check.

    The fused kernel adds 10 to each raw index, wraps negative results by
    adding 2, and rejects anything outside [0, 2). Raw values must therefore
    lie in [-12, -8). ``[[-10, -9]]`` maps to the distinct rows 0 and 1, which
    also keeps the scatter deterministic (no duplicate target rows).
    The result is contiguous, i.e. size (1, 2) with stride (2, 1), matching
    the layout ``call`` asserts.
    """
    return torch.arange(-10, -8, dtype=torch.int64, device='cpu').reshape(1, 2)


def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on synthetic inputs and print the measurement.

    Args:
        times: forwarded to ``print_performance`` as ``times``.
        repeat: forwarded to ``print_performance`` as ``repeat``.

    Returns:
        Whatever ``print_performance`` returns.

    Note: ``call`` updates the float input in place, so its values change
    from one run to the next.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    # rand_strided zero-fills integer tensors, and 0 + 10 = 10 fails the
    # kernel's bounds check, so build valid indices explicitly.
    arg0_1 = benchmark_index_input()
    arg1_1 = rand_strided((2, 32), (32, 1), device='cpu', dtype=torch.float32)

    def fn():
        # `call` clears the list it receives, so build a fresh one per run.
        return call([arg0_1, arg1_1])

    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    # Delegate to inductor's shared benchmark driver, handing it this module's
    # benchmark function; the benchmark name is passed as the literal 'None'.
    from torch._inductor import wrapper_benchmark
    wrapper_benchmark.compiled_module_main('None', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment