@leslie-fang-intel
Created March 6, 2025 06:17
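# The two modules below are TorchInductor AOT/freezing output for the CPU
# 'basic_gnn_gin' benchmark (bfloat16 frozen weights, float32 node features).
# As a hedged, minimal sketch (the compiled function is arbitrary; only the
# torch._logging call is assumed to be the stock API), wrapper dumps of this
# shape can typically be captured like so:
#
#     import torch
#     torch._logging.set_logs(output_code=True)  # print generated wrapper code
#
#     @torch.compile
#     def f(x):
#         return torch.relu(x) + 1.0
#
#     f(torch.randn(8, 8))  # compilation logs code similar to the dumps below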
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
_frozen_param16 = None # device(type='cpu') torch.bfloat16 (64,) (1,) 7f81cfb36250
_frozen_param30 = None # device(type='cpu') torch.bfloat16 (64, 64) (1, 0) 7f81cfb37f60
_frozen_param18 = None # device(type='cpu') torch.bfloat16 (64,) (1,) 7f81cfb37330
_frozen_param31 = None # device(type='cpu') torch.bfloat16 (64, 64) (1, 0) 7f81cff64180
_frozen_param21 = None # device(type='cpu') torch.bfloat16 (64,) (1,) 7f81cffb8db0
_frozen_param32 = None # device(type='cpu') torch.bfloat16 (64, 64) (1, 0) 7f81cfb34d60
_frozen_param23 = None # device(type='cpu') torch.bfloat16 (64,) (1,) 7f81cfb376f0
_frozen_param33 = None # device(type='cpu') torch.bfloat16 (64, 64) (1, 0) 7f81cfb34770
_frozen_param26 = None # device(type='cpu') torch.bfloat16 (64,) (1,) 7f81cfb37880
_frozen_param34 = None # device(type='cpu') torch.bfloat16 (64, 64) (1, 0) 7f81cfb35850
_frozen_param28 = None # device(type='cpu') torch.bfloat16 (64,) (1,) 7f81cfb37a60
_frozen_param35 = None # device(type='cpu') torch.bfloat16 (64, 64) (1, 0) 7f81cfb35710
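# Hedged reference for what the kernels below compute (reconstructed from the
# generated code, not taken from the original model source): one GIN-style
# message-passing layer, h = mlp(1.0 * x + sum over edges (src -> dst) of x[src]),
# followed by mkldnn linear (+ relu) layers using the frozen bf16 weights above.
# The helper name and the `mlp` argument are mine.
def _gin_layer_reference(x, edge_index, mlp):
    src, dst = edge_index[0], edge_index[1]
    agg = torch.zeros_like(x)                          # new_zeros
    agg.index_add_(0, dst, x.index_select(0, src))     # index_select + scatter_add
    return mlp((agg + 1.0 * x).to(torch.bfloat16))     # mul + add + _to_copy, then linear/relu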
cpp_fused_index_select_new_zeros_scatter_add_0 = async_compile.cpp_pybinding(['const int64_t*', 'const float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C" void kernel(const int64_t* in_ptr0,
const float* in_ptr1,
float* out_ptr0,
float* out_ptr1)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(640000L); x0+=static_cast<int64_t>(16L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(640000L)))
{
auto tmp0 = static_cast<float>(0.0);
auto tmp1 = at::vec::Vectorized<float>(tmp0);
tmp1.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(200000L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(16L))
{
{
if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(64L)))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x0)];
auto tmp1 = 10000L;
auto tmp2 = c10::convert<int64_t>(tmp1);
auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
auto tmp4 = tmp0 < 0;
auto tmp5 = tmp4 ? tmp3 : tmp0;
auto tmp6 = tmp5;
auto tmp7 = c10::convert<int64_t>(tmp6);
TORCH_CHECK((0 <= tmp7) & (tmp7 < 10000L), "index out of bounds: 0 <= tmp7 < 10000L");
auto tmp9 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x1 + 64L*tmp5), static_cast<int64_t>(16));
tmp9.store(out_ptr1 + static_cast<int64_t>(x1 + 64L*x0));
}
}
}
}
}
}
''')
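# Hedged Python equivalent of cpp_fused_index_select_new_zeros_scatter_add_0
# (helper and variable names are mine; shapes come from call() below). It only
# zero-fills the fp32 accumulator and gathers the per-edge source features; the
# actual scatter-add into the accumulator is issued separately in call() via
# aten.scatter_reduce_. Negative edge indices are normalized by adding 10000,
# mirroring the TORCH_CHECK'd index arithmetic above.
def _fused_kernel_0_reference(edge_index, x):
    out0 = torch.zeros(10000, 64, dtype=torch.float32)            # out_ptr0: new_zeros
    src = torch.where(edge_index[0] < 0, edge_index[0] + 10000, edge_index[0])
    out1 = x.index_select(0, src)                                  # out_ptr1: (200000, 64) messages
    return out0, out1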
cpp_fused__to_copy_add_mul_1 = async_compile.cpp_pybinding(['const float*', 'const float*', 'bfloat16*'], '''
#include "/tmp/torchinductor_leslie/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
bfloat16* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(640000L); x0+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(640000L)))
{
auto tmp0 = at::vec::VectorizedN<float,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
auto tmp1 = at::vec::VectorizedN<float,2>::loadu(in_ptr1 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
auto tmp2 = static_cast<float>(1.0);
auto tmp3 = at::vec::VectorizedN<float,2>(tmp2);
auto tmp4 = tmp3 * tmp1;
auto tmp5 = tmp0 + tmp4;
auto tmp6 = at::vec::convert<bfloat16,1,float,2>(tmp5);
tmp6.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
}
}
}
}
}
''')
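# Hedged Python equivalent of cpp_fused__to_copy_add_mul_1 (helper name is
# mine): combine the scatter-added neighbor sum with the input features scaled
# by 1.0, then downcast the result to bfloat16 for the mkldnn linear layers.
def _fused_kernel_1_reference(agg, x):
    return (agg + 1.0 * x).to(torch.bfloat16)   # mul + add + _to_copy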
cpp_fused_index_select_new_zeros_scatter_add_2 = async_compile.cpp_pybinding(['const int64_t*', 'const bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include "/tmp/torchinductor_leslie/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C" void kernel(const int64_t* in_ptr0,
const bfloat16* in_ptr1,
bfloat16* out_ptr0,
bfloat16* out_ptr1)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(640000L); x0+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(640000L)))
{
auto tmp0 = static_cast<bfloat16>(0.0);
auto tmp1 = at::vec::Vectorized<bfloat16>(tmp0);
tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(200000L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(64L)))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x0)];
auto tmp1 = 10000L;
auto tmp2 = c10::convert<int64_t>(tmp1);
auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
auto tmp4 = tmp0 < 0;
auto tmp5 = tmp4 ? tmp3 : tmp0;
auto tmp6 = tmp5;
auto tmp7 = c10::convert<int64_t>(tmp6);
TORCH_CHECK((0 <= tmp7) & (tmp7 < 10000L), "index out of bounds: 0 <= tmp7 < 10000L");
auto tmp9 = at::vec::Vectorized<bfloat16>::loadu(in_ptr1 + static_cast<int64_t>(x1 + 64L*tmp5), static_cast<int64_t>(32));
tmp9.store(out_ptr1 + static_cast<int64_t>(x1 + 64L*x0), static_cast<int64_t>(32));
}
}
}
}
}
}
''')
cpp_fused__to_copy_add_mul_3 = async_compile.cpp_pybinding(['bfloat16*', 'const bfloat16*'], '''
#include "/tmp/torchinductor_leslie/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C" void kernel(bfloat16* in_out_ptr0,
const bfloat16* in_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(640000L); x0+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(640000L)))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
auto tmp2 = at::vec::Vectorized<bfloat16>::loadu(in_out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
auto tmp1 = at::vec::convert<float,2,bfloat16,1>(tmp0);
auto tmp3 = at::vec::convert<float,2,bfloat16,1>(tmp2);
auto tmp4 = static_cast<float>(1.0);
auto tmp5 = at::vec::VectorizedN<float,2>(tmp4);
auto tmp6 = tmp5 * tmp3;
auto tmp7 = tmp1 + tmp6;
auto tmp8 = at::vec::convert<bfloat16,1,float,2>(tmp7);
tmp8.store(in_out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
}
}
}
}
}
''')
cpp_fused_index_select_new_zeros_scatter_add_4 = async_compile.cpp_pybinding(['const int64_t*', 'const bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include "/tmp/torchinductor_leslie/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C" void kernel(const int64_t* in_ptr0,
const bfloat16* in_ptr1,
bfloat16* out_ptr0,
bfloat16* out_ptr1)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(640000L); x0+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(640000L)))
{
auto tmp0 = static_cast<bfloat16>(0.0);
auto tmp1 = at::vec::Vectorized<bfloat16>(tmp0);
tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(200000L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(64L)))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x0)];
auto tmp1 = 10000L;
auto tmp2 = c10::convert<int64_t>(tmp1);
auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
auto tmp4 = tmp0 < 0;
auto tmp5 = tmp4 ? tmp3 : tmp0;
auto tmp6 = tmp5;
auto tmp7 = c10::convert<int64_t>(tmp6);
TORCH_CHECK((0 <= tmp7) & (tmp7 < 10000L), "index out of bounds: 0 <= tmp7 < 10000L");
auto tmp9 = at::vec::Vectorized<bfloat16>::loadu(in_ptr1 + static_cast<int64_t>(x1 + 64L*tmp5), static_cast<int64_t>(32));
tmp9.store(out_ptr1 + static_cast<int64_t>(x1 + 64L*x0), static_cast<int64_t>(32));
}
}
}
}
}
}
''')
cpp_fused__to_copy_add_mul_5 = async_compile.cpp_pybinding(['bfloat16*', 'const bfloat16*'], '''
#include "/tmp/torchinductor_leslie/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C" void kernel(bfloat16* in_out_ptr0,
const bfloat16* in_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(640000L); x0+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(640000L)))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
auto tmp2 = at::vec::Vectorized<bfloat16>::loadu(in_out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
auto tmp1 = at::vec::convert<float,2,bfloat16,1>(tmp0);
auto tmp3 = at::vec::convert<float,2,bfloat16,1>(tmp2);
auto tmp4 = static_cast<float>(1.0);
auto tmp5 = at::vec::VectorizedN<float,2>(tmp4);
auto tmp6 = tmp5 * tmp3;
auto tmp7 = tmp1 + tmp6;
auto tmp8 = at::vec::convert<bfloat16,1,float,2>(tmp7);
tmp8.store(in_out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
}
}
}
}
}
''')
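# Note: kernels 2-5 above repeat the pattern of kernels 0-1 for the second and
# third aggregation rounds, but on bfloat16 activations; kernels 3 and 5 widen
# the bf16 lanes to float32 for the add/mul and narrow back to bf16 on store.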
async_compile.wait(globals())
del async_compile
def call(args):
    arg15_1, arg16_1 = args
    args.clear()
    assert_size_stride(arg15_1, (10000, 64), (64, 1))
    assert_size_stride(arg16_1, (2, 200000), (200000, 1))
    buf0 = empty_strided_cpu((10000, 64), (64, 1), torch.float32)
    buf1 = empty_strided_cpu((200000, 64), (64, 1), torch.float32)
    cpp_fused_index_select_new_zeros_scatter_add_0(arg16_1, arg15_1, buf0, buf1)
    aten.scatter_reduce_.two(buf0, 0, reinterpret_tensor(arg16_1, (200000, 64), (1, 0), 200000), buf1, reduce='sum', include_self=True)
    del buf1
    buf3 = empty_strided_cpu((10000, 64), (64, 1), torch.bfloat16)
    cpp_fused__to_copy_add_mul_1(buf0, arg15_1, buf3)
    del arg15_1
    del buf0
    # Topologically Sorted Source Nodes: [mul, add_1, linear, relu], Original ATen: [aten.mul, aten.add, aten._to_copy, aten.relu]
    buf4 = torch.ops.mkldnn._linear_pointwise.default(buf3, _frozen_param30, _frozen_param16, 'relu', [-1], '')
    del buf3
    buf5 = buf4
    assert_size_stride(buf5, (10000, 64), (64, 1))
    del buf4
    # Topologically Sorted Source Nodes: [relu_1], Original ATen: [aten.relu]
    buf6 = torch.ops.mkldnn._linear_pointwise.default(buf5, _frozen_param31, _frozen_param18, 'relu', [-1], '')
    buf7 = buf6
    assert_size_stride(buf7, (10000, 64), (64, 1))
    del buf6
    buf8 = buf5; del buf5  # reuse
    buf9 = empty_strided_cpu((200000, 64), (64, 1), torch.bfloat16)
    cpp_fused_index_select_new_zeros_scatter_add_2(arg16_1, buf7, buf8, buf9)
    aten.scatter_reduce_.two(buf8, 0, reinterpret_tensor(arg16_1, (200000, 64), (1, 0), 200000), buf9, reduce='sum', include_self=True)
    buf11 = buf7; del buf7  # reuse
    cpp_fused__to_copy_add_mul_3(buf11, buf8)
    del buf8
    # Topologically Sorted Source Nodes: [mul_1, add_3, linear_2, relu_2], Original ATen: [aten.mul, aten.add, aten._to_copy, aten.relu]
    buf12 = torch.ops.mkldnn._linear_pointwise.default(buf11, _frozen_param32, _frozen_param21, 'relu', [-1], '')
    del buf11
    buf13 = buf12
    assert_size_stride(buf13, (10000, 64), (64, 1))
    del buf12
    # Topologically Sorted Source Nodes: [relu_3], Original ATen: [aten.relu]
    buf14 = torch.ops.mkldnn._linear_pointwise.default(buf13, _frozen_param33, _frozen_param23, 'relu', [-1], '')
    buf15 = buf14
    assert_size_stride(buf15, (10000, 64), (64, 1))
    del buf14
    buf16 = buf13; del buf13  # reuse
    buf17 = buf9; del buf9  # reuse
    cpp_fused_index_select_new_zeros_scatter_add_4(arg16_1, buf15, buf16, buf17)
    aten.scatter_reduce_.two(buf16, 0, reinterpret_tensor(arg16_1, (200000, 64), (1, 0), 200000), buf17, reduce='sum', include_self=True)
    del arg16_1
    del buf17
    buf19 = buf15; del buf15  # reuse
    cpp_fused__to_copy_add_mul_5(buf19, buf16)
    del buf16
    # Topologically Sorted Source Nodes: [mul_2, add_5, linear_4, relu_4], Original ATen: [aten.mul, aten.add, aten._to_copy, aten.relu]
    buf20 = torch.ops.mkldnn._linear_pointwise.default(buf19, _frozen_param34, _frozen_param26, 'relu', [-1], '')
    del buf19
    buf21 = buf20
    assert_size_stride(buf21, (10000, 64), (64, 1))
    del buf20
    # Topologically Sorted Source Nodes: [linear_5], Original ATen: [aten.addmm]
    buf22 = torch.ops.mkldnn._linear_pointwise.default(buf21, _frozen_param35, _frozen_param28, 'none', [-1], '')
    del buf21
    buf23 = buf22
    assert_size_stride(buf23, (10000, 64), (64, 1))
    return (buf23, )
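
# Hedged eager-mode sketch of call() above (the helper name and the pairing of
# frozen weight/bias tensors with layers are my reading of the dump, not stated
# in it): three gather/scatter-add aggregations, each followed by two mkldnn
# linears, with the very last linear left without an activation ('none').
def _call_reference(x, edge_index, mlps):
    # mlps: three ((w1, b1), (w2, b2)) tuples standing in for
    # (_frozen_param30, _frozen_param16), (_frozen_param31, _frozen_param18), ...
    import torch.nn.functional as F
    src, dst = edge_index[0], edge_index[1]
    h = x
    for i, ((w1, b1), (w2, b2)) in enumerate(mlps):
        agg = torch.zeros_like(h)
        agg.index_add_(0, dst, h.index_select(0, src))   # scatter_reduce_ 'sum'
        h = (agg + 1.0 * h).to(torch.bfloat16)            # cpp_fused__to_copy_add_mul_*
        h = F.relu(F.linear(h, w1, b1))                   # _linear_pointwise 'relu'
        h = F.linear(h, w2, b2)                           # 'relu' for blocks 0/1, 'none' for the last
        if i < 2:
            h = F.relu(h)
    return (h,)
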
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    global _frozen_param16
    _frozen_param16 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.bfloat16)
    global _frozen_param30
    _frozen_param30 = rand_strided((64, 64), (1, 0), device='cpu', dtype=torch.bfloat16)
    global _frozen_param18
    _frozen_param18 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.bfloat16)
    global _frozen_param31
    _frozen_param31 = rand_strided((64, 64), (1, 0), device='cpu', dtype=torch.bfloat16)
    global _frozen_param21
    _frozen_param21 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.bfloat16)
    global _frozen_param32
    _frozen_param32 = rand_strided((64, 64), (1, 0), device='cpu', dtype=torch.bfloat16)
    global _frozen_param23
    _frozen_param23 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.bfloat16)
    global _frozen_param33
    _frozen_param33 = rand_strided((64, 64), (1, 0), device='cpu', dtype=torch.bfloat16)
    global _frozen_param26
    _frozen_param26 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.bfloat16)
    global _frozen_param34
    _frozen_param34 = rand_strided((64, 64), (1, 0), device='cpu', dtype=torch.bfloat16)
    global _frozen_param28
    _frozen_param28 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.bfloat16)
    global _frozen_param35
    _frozen_param35 = rand_strided((64, 64), (1, 0), device='cpu', dtype=torch.bfloat16)
    arg15_1 = rand_strided((10000, 64), (64, 1), device='cpu', dtype=torch.float32)
    arg16_1 = rand_strided((2, 200000), (200000, 1), device='cpu', dtype=torch.int64)
    fn = lambda: call([arg15_1, arg16_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('basic_gnn_gin', benchmark_compiled_module)
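
# The second dump below is a variant of the module above in which the
# scatter-add is fused into the generated C++ kernels themselves (the
# cpp_fused__to_copy_add_index_select_mul_new_zeros_scatter_add_* kernels
# accumulate directly into out_ptr0), so its call() issues no separate
# aten.scatter_reduce_.two(...) calls and never materializes the (200000, 64)
# per-edge message buffer. Hedged sketch of the two schedules (names are mine):
def _aggregate_unfused(x, src, dst):
    msgs = x.index_select(0, src)                      # extra (200000, 64) buffer
    agg = torch.zeros_like(x)
    agg.scatter_reduce_(0, dst.unsqueeze(1).expand_as(msgs), msgs,
                        reduce='sum', include_self=True)
    return agg

def _aggregate_fused(x, src, dst):
    agg = torch.zeros_like(x)
    agg.index_add_(0, dst, x.index_select(0, src))     # gather + accumulate in one pass
    return agg
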
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
_frozen_param16 = None # device(type='cpu') torch.bfloat16 (64,) (1,) 7f45e32ab470
_frozen_param30 = None # device(type='cpu') torch.bfloat16 (64, 64) (1, 0) 7f45e32db880
_frozen_param18 = None # device(type='cpu') torch.bfloat16 (64,) (1,) 7f45e32abec0
_frozen_param31 = None # device(type='cpu') torch.bfloat16 (64, 64) (1, 0) 7f45e32db740
_frozen_param21 = None # device(type='cpu') torch.bfloat16 (64,) (1,) 7f45e32a8ef0
_frozen_param32 = None # device(type='cpu') torch.bfloat16 (64, 64) (1, 0) 7f45e32db560
_frozen_param23 = None # device(type='cpu') torch.bfloat16 (64,) (1,) 7f45e32a8400
_frozen_param33 = None # device(type='cpu') torch.bfloat16 (64, 64) (1, 0) 7f45e32db920
_frozen_param26 = None # device(type='cpu') torch.bfloat16 (64,) (1,) 7f45e32a89f0
_frozen_param34 = None # device(type='cpu') torch.bfloat16 (64, 64) (1, 0) 7f45e32db8d0
_frozen_param28 = None # device(type='cpu') torch.bfloat16 (64,) (1,) 7f45e337b3d0
_frozen_param35 = None # device(type='cpu') torch.bfloat16 (64, 64) (1, 0) 7f45e32db5b0
cpp_fused__to_copy_add_index_select_mul_new_zeros_scatter_add_0 = async_compile.cpp_pybinding(['const int64_t*', 'const float*', 'float*', 'bfloat16*'], '''
#include "/tmp/torchinductor_leslie/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C" void kernel(const int64_t* in_ptr0,
const float* in_ptr1,
float* out_ptr0,
bfloat16* out_ptr1)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(640000L); x0+=static_cast<int64_t>(16L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(640000L)))
{
auto tmp0 = static_cast<float>(0.0);
auto tmp1 = at::vec::Vectorized<float>(tmp0);
tmp1.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(200000L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(16L))
{
{
if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(64L)))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(200000L + x0)];
auto tmp4 = in_ptr0[static_cast<int64_t>(x0)];
auto tmp1 = tmp0;
auto tmp2 = c10::convert<int64_t>(tmp1);
TORCH_CHECK((0 <= tmp2) & (tmp2 < 10000L), "index out of bounds: 0 <= tmp2 < 10000L");
auto tmp5 = 10000L;
auto tmp6 = c10::convert<int64_t>(tmp5);
auto tmp7 = decltype(tmp4)(tmp4 + tmp6);
auto tmp8 = tmp4 < 0;
auto tmp9 = tmp8 ? tmp7 : tmp4;
auto tmp10 = tmp9;
auto tmp11 = c10::convert<int64_t>(tmp10);
TORCH_CHECK((0 <= tmp11) & (tmp11 < 10000L), "index out of bounds: 0 <= tmp11 < 10000L");
auto tmp13 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x1 + 64L*tmp9), static_cast<int64_t>(16));
(tmp13 + at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x1 + 64L*tmp0))).store(out_ptr0 + static_cast<int64_t>(x1 + 64L*tmp0));
}
}
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(640000L); x0+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(640000L)))
{
auto tmp0 = at::vec::VectorizedN<float,2>::loadu(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
auto tmp1 = at::vec::VectorizedN<float,2>::loadu(in_ptr1 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
auto tmp2 = static_cast<float>(1.0);
auto tmp3 = at::vec::VectorizedN<float,2>(tmp2);
auto tmp4 = tmp3 * tmp1;
auto tmp5 = tmp0 + tmp4;
auto tmp6 = at::vec::convert<bfloat16,1,float,2>(tmp5);
tmp6.store(out_ptr1 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
}
}
}
}
}
''')
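# Hedged Python equivalent of
# cpp_fused__to_copy_add_index_select_mul_new_zeros_scatter_add_0 (helper name
# is mine): unlike the first module, this kernel does the whole aggregation in
# one pass, zeroing the fp32 accumulator, gathering and scatter-adding per edge,
# then combining with the 1.0-scaled input and downcasting to bfloat16.
def _fused_kernel_0_v2_reference(edge_index, x):
    src, dst = edge_index[0], edge_index[1]
    src = torch.where(src < 0, src + 10000, src)          # index normalization as above
    agg = torch.zeros(10000, 64, dtype=torch.float32)     # out_ptr0: new_zeros
    agg.index_add_(0, dst, x.index_select(0, src))        # index_select + scatter_add
    return (agg + 1.0 * x).to(torch.bfloat16)             # out_ptr1: mul + add + _to_copy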
cpp_fused__to_copy_add_index_select_mul_new_zeros_scatter_add_1 = async_compile.cpp_pybinding(['bfloat16*', 'const int64_t*', 'bfloat16*'], '''
#include "/tmp/torchinductor_leslie/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C" void kernel(bfloat16* in_out_ptr0,
const int64_t* in_ptr0,
bfloat16* out_ptr0)
{
auto in_ptr1 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(640000L); x0+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(640000L)))
{
auto tmp0 = static_cast<bfloat16>(0.0);
auto tmp1 = at::vec::Vectorized<bfloat16>(tmp0);
tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(200000L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(64L)))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(200000L + x0)];
auto tmp4 = in_ptr0[static_cast<int64_t>(x0)];
auto tmp1 = tmp0;
auto tmp2 = c10::convert<int64_t>(tmp1);
TORCH_CHECK((0 <= tmp2) & (tmp2 < 10000L), "index out of bounds: 0 <= tmp2 < 10000L");
auto tmp5 = 10000L;
auto tmp6 = c10::convert<int64_t>(tmp5);
auto tmp7 = decltype(tmp4)(tmp4 + tmp6);
auto tmp8 = tmp4 < 0;
auto tmp9 = tmp8 ? tmp7 : tmp4;
auto tmp10 = tmp9;
auto tmp11 = c10::convert<int64_t>(tmp10);
TORCH_CHECK((0 <= tmp11) & (tmp11 < 10000L), "index out of bounds: 0 <= tmp11 < 10000L");
auto tmp13 = at::vec::Vectorized<bfloat16>::loadu(in_ptr1 + static_cast<int64_t>(x1 + 64L*tmp9), static_cast<int64_t>(32));
(tmp13 + at::vec::Vectorized<bfloat16>::loadu(out_ptr0 + static_cast<int64_t>(x1 + 64L*tmp0), static_cast<int64_t>(32))).store(out_ptr0 + static_cast<int64_t>(x1 + 64L*tmp0), static_cast<int64_t>(32));
}
}
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(640000L); x0+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(640000L)))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
auto tmp2 = at::vec::Vectorized<bfloat16>::loadu(in_out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
auto tmp1 = at::vec::convert<float,2,bfloat16,1>(tmp0);
auto tmp3 = at::vec::convert<float,2,bfloat16,1>(tmp2);
auto tmp4 = static_cast<float>(1.0);
auto tmp5 = at::vec::VectorizedN<float,2>(tmp4);
auto tmp6 = tmp5 * tmp3;
auto tmp7 = tmp1 + tmp6;
auto tmp8 = at::vec::convert<bfloat16,1,float,2>(tmp7);
tmp8.store(in_out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
}
}
}
}
}
''')
cpp_fused__to_copy_add_index_select_mul_new_zeros_scatter_add_2 = async_compile.cpp_pybinding(['bfloat16*', 'const int64_t*', 'bfloat16*'], '''
#include "/tmp/torchinductor_leslie/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C" void kernel(bfloat16* in_out_ptr0,
const int64_t* in_ptr0,
bfloat16* out_ptr0)
{
auto in_ptr1 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(640000L); x0+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(640000L)))
{
auto tmp0 = static_cast<bfloat16>(0.0);
auto tmp1 = at::vec::Vectorized<bfloat16>(tmp0);
tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(200000L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(64L)))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(200000L + x0)];
auto tmp4 = in_ptr0[static_cast<int64_t>(x0)];
auto tmp1 = tmp0;
auto tmp2 = c10::convert<int64_t>(tmp1);
TORCH_CHECK((0 <= tmp2) & (tmp2 < 10000L), "index out of bounds: 0 <= tmp2 < 10000L");
auto tmp5 = 10000L;
auto tmp6 = c10::convert<int64_t>(tmp5);
auto tmp7 = decltype(tmp4)(tmp4 + tmp6);
auto tmp8 = tmp4 < 0;
auto tmp9 = tmp8 ? tmp7 : tmp4;
auto tmp10 = tmp9;
auto tmp11 = c10::convert<int64_t>(tmp10);
TORCH_CHECK((0 <= tmp11) & (tmp11 < 10000L), "index out of bounds: 0 <= tmp11 < 10000L");
auto tmp13 = at::vec::Vectorized<bfloat16>::loadu(in_ptr1 + static_cast<int64_t>(x1 + 64L*tmp9), static_cast<int64_t>(32));
(tmp13 + at::vec::Vectorized<bfloat16>::loadu(out_ptr0 + static_cast<int64_t>(x1 + 64L*tmp0), static_cast<int64_t>(32))).store(out_ptr0 + static_cast<int64_t>(x1 + 64L*tmp0), static_cast<int64_t>(32));
}
}
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(640000L); x0+=static_cast<int64_t>(32L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(640000L)))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
auto tmp2 = at::vec::Vectorized<bfloat16>::loadu(in_out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
auto tmp1 = at::vec::convert<float,2,bfloat16,1>(tmp0);
auto tmp3 = at::vec::convert<float,2,bfloat16,1>(tmp2);
auto tmp4 = static_cast<float>(1.0);
auto tmp5 = at::vec::VectorizedN<float,2>(tmp4);
auto tmp6 = tmp5 * tmp3;
auto tmp7 = tmp1 + tmp6;
auto tmp8 = at::vec::convert<bfloat16,1,float,2>(tmp7);
tmp8.store(in_out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
}
}
}
}
}
''')
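# Note: kernels 1 and 2 above repeat kernel 0's fused zero/gather/scatter-add/
# add pattern for the later aggregation rounds, but on bfloat16 activations and
# in place: in_out_ptr0 serves both as the gathered source (aliased as in_ptr1)
# and as the destination of the final add/_to_copy epilogue.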
async_compile.wait(globals())
del async_compile
def call(args):
    arg15_1, arg16_1 = args
    args.clear()
    assert_size_stride(arg15_1, (10000, 64), (64, 1))
    assert_size_stride(arg16_1, (2, 200000), (200000, 1))
    buf0 = empty_strided_cpu((10000, 64), (64, 1), torch.float32)
    buf2 = empty_strided_cpu((10000, 64), (64, 1), torch.bfloat16)
    cpp_fused__to_copy_add_index_select_mul_new_zeros_scatter_add_0(arg16_1, arg15_1, buf0, buf2)
    del arg15_1
    del buf0
    # Topologically Sorted Source Nodes: [mul, add_1, linear, relu], Original ATen: [aten.mul, aten.add, aten._to_copy, aten.relu]
    buf3 = torch.ops.mkldnn._linear_pointwise.default(buf2, _frozen_param30, _frozen_param16, 'relu', [-1], '')
    del buf2
    buf4 = buf3
    assert_size_stride(buf4, (10000, 64), (64, 1))
    del buf3
    # Topologically Sorted Source Nodes: [relu_1], Original ATen: [aten.relu]
    buf5 = torch.ops.mkldnn._linear_pointwise.default(buf4, _frozen_param31, _frozen_param18, 'relu', [-1], '')
    buf6 = buf5
    assert_size_stride(buf6, (10000, 64), (64, 1))
    del buf5
    buf7 = buf4; del buf4  # reuse
    buf9 = buf6; del buf6  # reuse
    cpp_fused__to_copy_add_index_select_mul_new_zeros_scatter_add_1(buf9, arg16_1, buf7)
    del buf7
    # Topologically Sorted Source Nodes: [mul_1, add_3, linear_2, relu_2], Original ATen: [aten.mul, aten.add, aten._to_copy, aten.relu]
    buf10 = torch.ops.mkldnn._linear_pointwise.default(buf9, _frozen_param32, _frozen_param21, 'relu', [-1], '')
    del buf9
    buf11 = buf10
    assert_size_stride(buf11, (10000, 64), (64, 1))
    del buf10
    # Topologically Sorted Source Nodes: [relu_3], Original ATen: [aten.relu]
    buf12 = torch.ops.mkldnn._linear_pointwise.default(buf11, _frozen_param33, _frozen_param23, 'relu', [-1], '')
    buf13 = buf12
    assert_size_stride(buf13, (10000, 64), (64, 1))
    del buf12
    buf14 = buf11; del buf11  # reuse
    buf16 = buf13; del buf13  # reuse
    cpp_fused__to_copy_add_index_select_mul_new_zeros_scatter_add_2(buf16, arg16_1, buf14)
    del arg16_1
    del buf14
    # Topologically Sorted Source Nodes: [mul_2, add_5, linear_4, relu_4], Original ATen: [aten.mul, aten.add, aten._to_copy, aten.relu]
    buf17 = torch.ops.mkldnn._linear_pointwise.default(buf16, _frozen_param34, _frozen_param26, 'relu', [-1], '')
    del buf16
    buf18 = buf17
    assert_size_stride(buf18, (10000, 64), (64, 1))
    del buf17
    # Topologically Sorted Source Nodes: [linear_5], Original ATen: [aten.addmm]
    buf19 = torch.ops.mkldnn._linear_pointwise.default(buf18, _frozen_param35, _frozen_param28, 'none', [-1], '')
    del buf18
    buf20 = buf19
    assert_size_stride(buf20, (10000, 64), (64, 1))
    return (buf20, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    global _frozen_param16
    _frozen_param16 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.bfloat16)
    global _frozen_param30
    _frozen_param30 = rand_strided((64, 64), (1, 0), device='cpu', dtype=torch.bfloat16)
    global _frozen_param18
    _frozen_param18 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.bfloat16)
    global _frozen_param31
    _frozen_param31 = rand_strided((64, 64), (1, 0), device='cpu', dtype=torch.bfloat16)
    global _frozen_param21
    _frozen_param21 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.bfloat16)
    global _frozen_param32
    _frozen_param32 = rand_strided((64, 64), (1, 0), device='cpu', dtype=torch.bfloat16)
    global _frozen_param23
    _frozen_param23 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.bfloat16)
    global _frozen_param33
    _frozen_param33 = rand_strided((64, 64), (1, 0), device='cpu', dtype=torch.bfloat16)
    global _frozen_param26
    _frozen_param26 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.bfloat16)
    global _frozen_param34
    _frozen_param34 = rand_strided((64, 64), (1, 0), device='cpu', dtype=torch.bfloat16)
    global _frozen_param28
    _frozen_param28 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.bfloat16)
    global _frozen_param35
    _frozen_param35 = rand_strided((64, 64), (1, 0), device='cpu', dtype=torch.bfloat16)
    arg15_1 = rand_strided((10000, 64), (64, 1), device='cpu', dtype=torch.float32)
    arg16_1 = rand_strided((2, 200000), (200000, 1), device='cpu', dtype=torch.int64)
    fn = lambda: call([arg15_1, arg16_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('basic_gnn_gin', benchmark_compiled_module)