Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save leslie-fang-intel/54e5c7b601e55fc44087ec9651e41ea6 to your computer and use it in GitHub Desktop.

Select an option

Save leslie-fang-intel/54e5c7b601e55fc44087ec9651e41ea6 to your computer and use it in GitHub Desktop.
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# Short aliases for the torch operator namespaces referenced by the generated code.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Guard / raw-allocation helpers re-exported from torch._C._dynamo for fast access
# inside the generated `call` function.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compiles the C++ kernel source strings below into callables (possibly in the
# background); each `async_compile.cpp_pybinding(...)` call registers one kernel.
async_compile = AsyncCompile()
# Frozen (constant-folded) parameter placeholders. They are declared as None here
# and are populated by the AOT loader before the compiled graph is invoked; the
# trailing comment on each line records the original tensor's device, dtype,
# shape, stride, and data pointer at export time (strides of 0 indicate the
# weight was stored in a packed/opaque layout).
_frozen_param18 = None # device(type='cpu') torch.float32 (64,) (1,) 7f1e30c112b0
_frozen_param34 = None # device(type='cpu') torch.float32 (64,) (1,) 7f1e30c10ea0
_frozen_param47 = None # device(type='cpu') torch.float32 (128,) (1,) 7f1e30c13dd0
_frozen_param63 = None # device(type='cpu') torch.float32 (128,) (1,) 7f1e30d13ab0
_frozen_param85 = None # device(type='cpu') torch.float32 (16,) (1,) 7f1e30d11a30
_frozen_param87 = None # device(type='cpu') torch.float32 (256,) (1,) 7f1e30d11bc0
_frozen_param101 = None # device(type='cpu') torch.float32 (16,) (1,) 7f1e30d05080
_frozen_param103 = None # device(type='cpu') torch.float32 (256,) (1,) 7f1e30d07470
_frozen_param147 = None # device(type='cpu') torch.float32 (1000,) (1,) 7f1e30d0c8b0
_frozen_param564 = None # device(type='cpu') torch.float32 (24,) (1,) 7f1e1f7b8090
_frozen_param598 = None # device(type='cpu') torch.float32 (24, 3, 3, 3) (1, 0, 0, 0) 7f1e1f7ff8d0
_frozen_param565 = None # device(type='cpu') torch.float32 (32,) (1,) 7f1e1fad4e00
_frozen_param599 = None # device(type='cpu') torch.float32 (32, 24, 3, 3) (1, 0, 0, 0) 7f1e1f6b1c60
_frozen_param566 = None # device(type='cpu') torch.float32 (64,) (1,) 7f1e1f84e660
_frozen_param600 = None # device(type='cpu') torch.float32 (64, 32, 3, 3) (1, 0, 0, 0) 7f1e1f6b1cb0
_frozen_param567 = None # device(type='cpu') torch.float32 (64,) (1,) 7f1e1fb2ebb0
_frozen_param601 = None # device(type='cpu') torch.float32 (64, 64, 1, 1) (1, 0, 0, 0) 7f1e1f6b1c10
_frozen_param568 = None # device(type='cpu') torch.float32 (64,) (1,) 7f1e1fb0c270
_frozen_param602 = None # device(type='cpu') torch.float32 (64, 64, 3, 3) (1, 0, 0, 0) 7f1e1f6b1b70
_frozen_param603 = None # device(type='cpu') torch.float32 (8, 64, 1, 1) (1, 0, 0, 0) 7f1e1f6b1d50
_frozen_param604 = None # device(type='cpu') torch.float32 (64, 8, 1, 1) (1, 0, 0, 0) 7f1e1f6b1da0
_frozen_param569 = None # device(type='cpu') torch.float32 (256,) (1,) 7f1e1f79b0b0
_frozen_param605 = None # device(type='cpu') torch.float32 (256, 64, 1, 1) (1, 0, 0, 0) 7f1e1f6b1df0
_frozen_param570 = None # device(type='cpu') torch.float32 (256,) (1,) 7f1e1fb0e3e0
_frozen_param606 = None # device(type='cpu') torch.float32 (256, 64, 1, 1) (1, 0, 0, 0) 7f1e1f6b1e40
_frozen_param571 = None # device(type='cpu') torch.float32 (64,) (1,) 7f1e1f793420
_frozen_param607 = None # device(type='cpu') torch.float32 (64, 256, 1, 1) (1, 0, 0, 0) 7f1e1f6b1e90
_frozen_param572 = None # device(type='cpu') torch.float32 (64,) (1,) 7f1e1f792980
_frozen_param608 = None # device(type='cpu') torch.float32 (64, 64, 3, 3) (1, 0, 0, 0) 7f1e1f7e3010
_frozen_param609 = None # device(type='cpu') torch.float32 (8, 64, 1, 1) (1, 0, 0, 0) 7f1e1f6b1f30
_frozen_param610 = None # device(type='cpu') torch.float32 (64, 8, 1, 1) (1, 0, 0, 0) 7f1e1fa9bdd0
_frozen_param573 = None # device(type='cpu') torch.float32 (256,) (1,) 7f1e1f92d0d0
_frozen_param611 = None # device(type='cpu') torch.float32 (256, 64, 1, 1) (1, 0, 0, 0) 7f1e1f6b1fd0
_frozen_param574 = None # device(type='cpu') torch.float32 (128,) (1,) 7f1e1f9ce930
_frozen_param612 = None # device(type='cpu') torch.float32 (128, 256, 1, 1) (1, 0, 0, 0) 7f1e1f6b2020
_frozen_param575 = None # device(type='cpu') torch.float32 (128,) (1,) 7f1e1f7fcd60
_frozen_param613 = None # device(type='cpu') torch.float32 (128, 128, 3, 3) (1, 0, 0, 0) 7f1e1f6b2070
_frozen_param614 = None # device(type='cpu') torch.float32 (8, 128, 1, 1) (1, 0, 0, 0) 7f1e1f6b1ee0
_frozen_param615 = None # device(type='cpu') torch.float32 (128, 8, 1, 1) (1, 0, 0, 0) 7f1e1f6b20c0
_frozen_param576 = None # device(type='cpu') torch.float32 (512,) (1,) 7f1e1f7fc680
_frozen_param616 = None # device(type='cpu') torch.float32 (512, 128, 1, 1) (1, 0, 0, 0) 7f1e1f6b1f80
_frozen_param577 = None # device(type='cpu') torch.float32 (512,) (1,) 7f1e1f7fdda0
_frozen_param617 = None # device(type='cpu') torch.float32 (512, 256, 1, 1) (1, 0, 0, 0) 7f1e1f6b2110
_frozen_param578 = None # device(type='cpu') torch.float32 (128,) (1,) 7f1e1f800a40
_frozen_param618 = None # device(type='cpu') torch.float32 (128, 512, 1, 1) (1, 0, 0, 0) 7f1e1f6b2160
_frozen_param579 = None # device(type='cpu') torch.float32 (128,) (1,) 7f1e1f801c60
_frozen_param619 = None # device(type='cpu') torch.float32 (128, 128, 3, 3) (1, 0, 0, 0) 7f1e1f6b21b0
_frozen_param620 = None # device(type='cpu') torch.float32 (8, 128, 1, 1) (1, 0, 0, 0) 7f1e1f6b2200
_frozen_param621 = None # device(type='cpu') torch.float32 (128, 8, 1, 1) (1, 0, 0, 0) 7f1e1f6b2250
_frozen_param580 = None # device(type='cpu') torch.float32 (512,) (1,) 7f1e1f85f600
_frozen_param622 = None # device(type='cpu') torch.float32 (512, 128, 1, 1) (1, 0, 0, 0) 7f1e1f6b22a0
_frozen_param581 = None # device(type='cpu') torch.float32 (128,) (1,) 7f1e1f85f740
_frozen_param623 = None # device(type='cpu') torch.float32 (128, 512, 1, 1) (1, 0, 0, 0) 7f1e1f6b22f0
_frozen_param624 = None # device(type='cpu') torch.float32 (384, 128, 1, 1) (1, 0, 0, 0) 7f1e1f6b2340
_frozen_param625 = None # device(type='cpu') torch.float32 (63, 32) (32, 1) 7f1e1fad6bb0
_frozen_param626 = None # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f1e1f6b23e0
_frozen_param627 = None # device(type='cpu') torch.float32 (63, 32) (32, 1) 7f1e1f7f2200
_frozen_param628 = None # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f1e1f6b2390
_frozen_param305 = None # device(type='cpu') torch.float32 (128, 1, 1) (1, 1, 1) 7f1e1f92f060
_frozen_param306 = None # device(type='cpu') torch.float32 (128, 1, 1) (1, 1, 1) 7f1e1f92ee80
_frozen_param307 = None # device(type='cpu') torch.float32 (128, 1, 1) (1, 1, 1) 7f1e1f92ef70
_frozen_param308 = None # device(type='cpu') torch.float32 (128, 1, 1) (1, 1, 1) 7f1e1f92f010
_frozen_param582 = None # device(type='cpu') torch.float32 (512,) (1,) 7f1e1fb0c9f0
_frozen_param629 = None # device(type='cpu') torch.float32 (512, 128, 1, 1) (1, 0, 0, 0) 7f1e1f6b27f0
_frozen_param583 = None # device(type='cpu') torch.float32 (256,) (1,) 7f1e1f7e7470
_frozen_param630 = None # device(type='cpu') torch.float32 (256, 512, 1, 1) (1, 0, 0, 0) 7f1e1f6b27a0
_frozen_param584 = None # device(type='cpu') torch.float32 (256,) (1,) 7f1e1f85f6f0
_frozen_param631 = None # device(type='cpu') torch.float32 (256, 256, 3, 3) (1, 0, 0, 0) 7f1e1f6b2750
_frozen_param632 = None # device(type='cpu') torch.float32 (16, 256, 1, 1) (1, 0, 0, 0) 7f1e1f6b2700
_frozen_param633 = None # device(type='cpu') torch.float32 (256, 16, 1, 1) (1, 0, 0, 0) 7f1e1f6b2660
_frozen_param585 = None # device(type='cpu') torch.float32 (1024,) (1,) 7f1e1f85f2e0
_frozen_param634 = None # device(type='cpu') torch.float32 (1024, 256, 1, 1) (1, 0, 0, 0) 7f1e1f6b2570
_frozen_param586 = None # device(type='cpu') torch.float32 (1024,) (1,) 7f1e1f85f830
_frozen_param635 = None # device(type='cpu') torch.float32 (1024, 512, 1, 1) (1, 0, 0, 0) 7f1e1f6b2840
_frozen_param587 = None # device(type='cpu') torch.float32 (256,) (1,) 7f1e1f85f880
_frozen_param636 = None # device(type='cpu') torch.float32 (256, 1024, 1, 1) (1, 0, 0, 0) 7f1e1f6b2610
_frozen_param588 = None # device(type='cpu') torch.float32 (256,) (1,) 7f1e1f85f8d0
_frozen_param637 = None # device(type='cpu') torch.float32 (256, 256, 3, 3) (1, 0, 0, 0) 7f1e1f6b25c0
_frozen_param638 = None # device(type='cpu') torch.float32 (16, 256, 1, 1) (1, 0, 0, 0) 7f1e1f6b2520
_frozen_param639 = None # device(type='cpu') torch.float32 (256, 16, 1, 1) (1, 0, 0, 0) 7f1e1f6b2890
_frozen_param589 = None # device(type='cpu') torch.float32 (1024,) (1,) 7f1e1f85f7e0
_frozen_param640 = None # device(type='cpu') torch.float32 (1024, 256, 1, 1) (1, 0, 0, 0) 7f1e1f6b28e0
_frozen_param590 = None # device(type='cpu') torch.float32 (256,) (1,) 7f1e1f85f970
_frozen_param641 = None # device(type='cpu') torch.float32 (256, 1024, 1, 1) (1, 0, 0, 0) 7f1e1f95a610
_frozen_param642 = None # device(type='cpu') torch.float32 (768, 256, 1, 1) (1, 0, 0, 0) 7f1e1f6b2980
_frozen_param643 = None # device(type='cpu') torch.float32 (31, 64) (64, 1) 7f1e1f6b2c00
_frozen_param644 = None # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f1e1f6b2a20
_frozen_param645 = None # device(type='cpu') torch.float32 (31, 64) (64, 1) 7f1e1f6b2930
_frozen_param646 = None # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f1e1f6b2ac0
_frozen_param349 = None # device(type='cpu') torch.float32 (256, 1, 1) (1, 1, 1) 7f1e1f92ff60
_frozen_param350 = None # device(type='cpu') torch.float32 (256, 1, 1) (1, 1, 1) 7f1e1f93c090
_frozen_param351 = None # device(type='cpu') torch.float32 (256, 1, 1) (1, 1, 1) 7f1e1f93c130
_frozen_param352 = None # device(type='cpu') torch.float32 (256, 1, 1) (1, 1, 1) 7f1e1f93c180
_frozen_param591 = None # device(type='cpu') torch.float32 (1024,) (1,) 7f1e1f85f9c0
_frozen_param647 = None # device(type='cpu') torch.float32 (1024, 256, 1, 1) (1, 0, 0, 0) 7f1e1f6b2de0
_frozen_param592 = None # device(type='cpu') torch.float32 (512,) (1,) 7f1e1f85fa10
_frozen_param648 = None # device(type='cpu') torch.float32 (512, 1024, 1, 1) (1, 0, 0, 0) 7f1e1f6b2d90
_frozen_param649 = None # device(type='cpu') torch.float32 (1536, 512, 1, 1) (1, 0, 0, 0) 7f1e1f6b2d40
_frozen_param650 = None # device(type='cpu') torch.float32 (31, 128) (128, 1) 7f1e1f6b2e80
_frozen_param651 = None # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f1e1f6b2ed0
_frozen_param652 = None # device(type='cpu') torch.float32 (31, 128) (128, 1) 7f1e1f6b2bb0
_frozen_param653 = None # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f1e1f6b2ca0
_frozen_param363 = None # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f1e1f93c5e0
_frozen_param364 = None # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f1e1f93c630
_frozen_param365 = None # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f1e1f93c680
_frozen_param366 = None # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f1e1f93c6d0
_frozen_param593 = None # device(type='cpu') torch.float32 (1536,) (1,) 7f1e1f85f6a0
_frozen_param654 = None # device(type='cpu') torch.float32 (1536, 512, 1, 1) (1, 0, 0, 0) 7f1e1f6b30b0
_frozen_param594 = None # device(type='cpu') torch.float32 (1536,) (1,) 7f1e1f85fab0
_frozen_param655 = None # device(type='cpu') torch.float32 (1536, 1024, 1, 1) (1, 0, 0, 0) 7f1e1f6b3060
_frozen_param595 = None # device(type='cpu') torch.float32 (512,) (1,) 7f1e1f85fb00
_frozen_param656 = None # device(type='cpu') torch.float32 (512, 1536, 1, 1) (1, 0, 0, 0) 7f1e1f6b3010
_frozen_param657 = None # device(type='cpu') torch.float32 (1536, 512, 1, 1) (1, 0, 0, 0) 7f1e1f6b2fc0
_frozen_param658 = None # device(type='cpu') torch.float32 (15, 128) (128, 1) 7f1e1f6b31a0
_frozen_param659 = None # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f1e1f6b31f0
_frozen_param660 = None # device(type='cpu') torch.float32 (15, 128) (128, 1) 7f1e1f6b2a70
_frozen_param661 = None # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f1e1f6b2c50
_frozen_param381 = None # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f1e1f93cd10
_frozen_param382 = None # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f1e1f93cc70
_frozen_param383 = None # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f1e1f93cd60
_frozen_param384 = None # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f1e1f93cdb0
_frozen_param596 = None # device(type='cpu') torch.float32 (1536,) (1,) 7f1e1f85fb50
_frozen_param662 = None # device(type='cpu') torch.float32 (1536, 512, 1, 1) (1, 0, 0, 0) 7f1e1f6b33d0
_frozen_param597 = None # device(type='cpu') torch.float32 (1280,) (1,) 7f1e1f85fba0
_frozen_param663 = None # device(type='cpu') torch.float32 (1280, 1536, 1, 1) (1, 0, 0, 0) 7f1e1f6b3380
_frozen_param664 = None # device(type='cpu') torch.float32 (1000, 1280) (1280, 1) 7f1e1f6b3150
_frozen_param665 = None # device(type='cpu') torch.float32 (3490017, 1) (1, 0) 7f1e1f6b3470
# NOTE(review): despite the "silu" in the generated name, this kernel contains no
# activation math. It loads 16-wide float vectors from a (3, 65536) input
# (index x1 + 65536*x0) and uses at::vec::transpose_mxn<float,3,16> to write a
# (65536, 3) interleaved output (index x0 + 3*x1) — i.e. a channels-first to
# channels-last repack of the network input.
cpp_fused_silu_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(3L); x0+=static_cast<int64_t>(3L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(65536L); x1+=static_cast<int64_t>(16L))
{
alignas(16) float tmp1[3*16];
for (long x0_inner = 0; x0_inner < 3; x0_inner++)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1 + (65536L*x0) + (65536L*x0_inner)), 16);
tmp0.store(tmp1 + static_cast<int64_t>(16L*x0_inner));
}
at::vec::transpose_mxn<float,3,16>(tmp1, 16, out_ptr0 + static_cast<int64_t>(x0 + (3L*x1)), static_cast<int64_t>(3L));
}
}
}
}
''')
# Fused reduction kernel in three stages:
#   1. Sum in_ptr0 (laid out as 4096 rows of 64 channels, index x0 + 64*x1) into
#      the 64-element accumulator in_out_ptr0, 16 lanes at a time.
#   2. Divide the accumulator by 4096.0 in place -> per-channel mean (a global
#      average pool over 4096 positions).
#   3. Fill the 8-element out_ptr1 from a constant table baked into the code as
#      a nested ternary tree over the index (a frozen 8-element tensor the
#      compiler constant-folded into the kernel).
cpp_fused_mean_1 = async_compile.cpp_pybinding(['float*', 'const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
float* out_ptr1)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(64L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(4096L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (64L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(64L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(4096.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
{
#pragma omp simd simdlen(8)
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(8L); x0+=static_cast<int64_t>(1L))
{
auto tmp0 = x0;
auto tmp1 = c10::convert<int64_t>(tmp0);
auto tmp2 = static_cast<int64_t>(4);
auto tmp3 = tmp1 < tmp2;
auto tmp4 = static_cast<int64_t>(2);
auto tmp5 = tmp1 < tmp4;
auto tmp6 = static_cast<int64_t>(1);
auto tmp7 = tmp1 < tmp6;
auto tmp8 = static_cast<float>(-0.3303021788597107);
auto tmp9 = static_cast<float>(0.06354581564664841);
auto tmp10 = tmp7 ? tmp8 : tmp9;
auto tmp11 = static_cast<int64_t>(3);
auto tmp12 = tmp1 < tmp11;
auto tmp13 = static_cast<float>(1.774228811264038);
auto tmp14 = static_cast<float>(2.1113927364349365);
auto tmp15 = tmp12 ? tmp13 : tmp14;
auto tmp16 = tmp5 ? tmp10 : tmp15;
auto tmp17 = static_cast<int64_t>(6);
auto tmp18 = tmp1 < tmp17;
auto tmp19 = static_cast<int64_t>(5);
auto tmp20 = tmp1 < tmp19;
auto tmp21 = static_cast<float>(0.32513317465782166);
auto tmp22 = static_cast<float>(1.232210397720337);
auto tmp23 = tmp20 ? tmp21 : tmp22;
auto tmp24 = static_cast<int64_t>(7);
auto tmp25 = tmp1 < tmp24;
auto tmp26 = static_cast<float>(0.7079262137413025);
auto tmp27 = static_cast<float>(0.2353029102087021);
auto tmp28 = tmp25 ? tmp26 : tmp27;
auto tmp29 = tmp18 ? tmp23 : tmp28;
auto tmp30 = tmp3 ? tmp16 : tmp29;
out_ptr1[static_cast<int64_t>(x0)] = tmp30;
}
}
}
''')
# In-place broadcast multiply: each of the 4096 rows (64 channels each) of
# in_out_ptr0 is multiplied elementwise by the 64-element vector in_ptr0.
# NOTE(review): presumably the channel-gating step of a squeeze-excite block
# (scaling features by the pooled/gated vector) — confirm against the graph.
cpp_fused_mul_2 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4096L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x1 + (64L*x0)), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x1 + (64L*x0)));
}
}
}
}
''')
# Elementwise SiLU over 1048576 contiguous floats, 16 lanes at a time:
# out = x * sigmoid(x), with sigmoid computed as 1 / (1 + exp(-x)).
cpp_fused_silu_3 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1048576L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
# Same structure as cpp_fused_mean_1 (sum 4096 rows of 64 channels, divide by
# 4096.0 for the per-channel mean, then emit a baked 8-element constant table
# into out_ptr1) — only the frozen constants in stage 3 differ.
cpp_fused_mean_4 = async_compile.cpp_pybinding(['float*', 'const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
float* out_ptr1)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(64L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(4096L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (64L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(64L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(4096.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
{
#pragma omp simd simdlen(8)
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(8L); x0+=static_cast<int64_t>(1L))
{
auto tmp0 = x0;
auto tmp1 = c10::convert<int64_t>(tmp0);
auto tmp2 = static_cast<int64_t>(4);
auto tmp3 = tmp1 < tmp2;
auto tmp4 = static_cast<int64_t>(2);
auto tmp5 = tmp1 < tmp4;
auto tmp6 = static_cast<int64_t>(1);
auto tmp7 = tmp1 < tmp6;
auto tmp8 = static_cast<float>(0.7516645193099976);
auto tmp9 = static_cast<float>(0.5124969482421875);
auto tmp10 = tmp7 ? tmp8 : tmp9;
auto tmp11 = static_cast<int64_t>(3);
auto tmp12 = tmp1 < tmp11;
auto tmp13 = static_cast<float>(1.2062517404556274);
auto tmp14 = static_cast<float>(0.9069646596908569);
auto tmp15 = tmp12 ? tmp13 : tmp14;
auto tmp16 = tmp5 ? tmp10 : tmp15;
auto tmp17 = static_cast<int64_t>(6);
auto tmp18 = tmp1 < tmp17;
auto tmp19 = static_cast<int64_t>(5);
auto tmp20 = tmp1 < tmp19;
auto tmp21 = static_cast<float>(1.4622137546539307);
auto tmp22 = static_cast<float>(-0.11386305838823318);
auto tmp23 = tmp20 ? tmp21 : tmp22;
auto tmp24 = static_cast<int64_t>(7);
auto tmp25 = tmp1 < tmp24;
auto tmp26 = static_cast<float>(-0.2968502938747406);
auto tmp27 = static_cast<float>(0.884636402130127);
auto tmp28 = tmp25 ? tmp26 : tmp27;
auto tmp29 = tmp18 ? tmp23 : tmp28;
auto tmp30 = tmp3 ? tmp16 : tmp29;
out_ptr1[static_cast<int64_t>(x0)] = tmp30;
}
}
}
''')
# Identical shape to cpp_fused_mul_2: in-place broadcast multiply of 4096 rows
# (64 channels each) of in_out_ptr0 by the 64-element vector in_ptr0.
cpp_fused_mul_5 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4096L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x1 + (64L*x0)), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x1 + (64L*x0)));
}
}
}
}
''')
# Elementwise SiLU (x * sigmoid(x)) over 1048576 contiguous floats — same body
# as cpp_fused_silu_3.
cpp_fused_silu_6 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1048576L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
# Same three-stage pattern as cpp_fused_mean_1, at the next feature resolution:
# sum 1024 rows of 128 channels into in_out_ptr0, divide by 1024.0 for the
# per-channel mean, then write a baked 8-element constant table into out_ptr1.
cpp_fused_mean_7 = async_compile.cpp_pybinding(['float*', 'const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
float* out_ptr1)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(128L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1024L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (128L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(128L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(1024.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
{
#pragma omp simd simdlen(8)
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(8L); x0+=static_cast<int64_t>(1L))
{
auto tmp0 = x0;
auto tmp1 = c10::convert<int64_t>(tmp0);
auto tmp2 = static_cast<int64_t>(4);
auto tmp3 = tmp1 < tmp2;
auto tmp4 = static_cast<int64_t>(2);
auto tmp5 = tmp1 < tmp4;
auto tmp6 = static_cast<int64_t>(1);
auto tmp7 = tmp1 < tmp6;
auto tmp8 = static_cast<float>(0.47172150015830994);
auto tmp9 = static_cast<float>(1.4283583164215088);
auto tmp10 = tmp7 ? tmp8 : tmp9;
auto tmp11 = static_cast<int64_t>(3);
auto tmp12 = tmp1 < tmp11;
auto tmp13 = static_cast<float>(-0.04577525332570076);
auto tmp14 = static_cast<float>(2.043065309524536);
auto tmp15 = tmp12 ? tmp13 : tmp14;
auto tmp16 = tmp5 ? tmp10 : tmp15;
auto tmp17 = static_cast<int64_t>(6);
auto tmp18 = tmp1 < tmp17;
auto tmp19 = static_cast<int64_t>(5);
auto tmp20 = tmp1 < tmp19;
auto tmp21 = static_cast<float>(0.13726529479026794);
auto tmp22 = static_cast<float>(1.1331775188446045);
auto tmp23 = tmp20 ? tmp21 : tmp22;
auto tmp24 = static_cast<int64_t>(7);
auto tmp25 = tmp1 < tmp24;
auto tmp26 = static_cast<float>(-0.11772552132606506);
auto tmp27 = static_cast<float>(0.527721107006073);
auto tmp28 = tmp25 ? tmp26 : tmp27;
auto tmp29 = tmp18 ? tmp23 : tmp28;
auto tmp30 = tmp3 ? tmp16 : tmp29;
out_ptr1[static_cast<int64_t>(x0)] = tmp30;
}
}
}
''')
# In-place broadcast multiply at the 128-channel stage: each of the 1024 rows
# of in_out_ptr0 is multiplied elementwise by the 128-element vector in_ptr0
# (same pattern as cpp_fused_mul_2, larger channel count).
cpp_fused_mul_8 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1024L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(128L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x1 + (128L*x0)), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x1 + (128L*x0)));
}
}
}
}
''')
# Elementwise SiLU (x * sigmoid(x)) over 524288 contiguous floats.
cpp_fused_silu_9 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(524288L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
# Same three-stage pattern as cpp_fused_mean_7 (sum 1024 rows of 128 channels,
# divide by 1024.0 for the per-channel mean, write a baked 8-element constant
# table into out_ptr1) — only the frozen constants in stage 3 differ.
cpp_fused_mean_10 = async_compile.cpp_pybinding(['float*', 'const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
float* out_ptr1)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(128L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1024L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (128L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(128L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(1024.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
{
#pragma omp simd simdlen(8)
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(8L); x0+=static_cast<int64_t>(1L))
{
auto tmp0 = x0;
auto tmp1 = c10::convert<int64_t>(tmp0);
auto tmp2 = static_cast<int64_t>(4);
auto tmp3 = tmp1 < tmp2;
auto tmp4 = static_cast<int64_t>(2);
auto tmp5 = tmp1 < tmp4;
auto tmp6 = static_cast<int64_t>(1);
auto tmp7 = tmp1 < tmp6;
auto tmp8 = static_cast<float>(0.29536592960357666);
auto tmp9 = static_cast<float>(1.1849623918533325);
auto tmp10 = tmp7 ? tmp8 : tmp9;
auto tmp11 = static_cast<int64_t>(3);
auto tmp12 = tmp1 < tmp11;
auto tmp13 = static_cast<float>(0.3150717318058014);
auto tmp14 = static_cast<float>(0.5096337795257568);
auto tmp15 = tmp12 ? tmp13 : tmp14;
auto tmp16 = tmp5 ? tmp10 : tmp15;
auto tmp17 = static_cast<int64_t>(6);
auto tmp18 = tmp1 < tmp17;
auto tmp19 = static_cast<int64_t>(5);
auto tmp20 = tmp1 < tmp19;
auto tmp21 = static_cast<float>(-0.18543481826782227);
auto tmp22 = static_cast<float>(0.3189537227153778);
auto tmp23 = tmp20 ? tmp21 : tmp22;
auto tmp24 = static_cast<int64_t>(7);
auto tmp25 = tmp1 < tmp24;
auto tmp26 = static_cast<float>(-0.7191315293312073);
auto tmp27 = static_cast<float>(0.42770394682884216);
auto tmp28 = tmp25 ? tmp26 : tmp27;
auto tmp29 = tmp18 ? tmp23 : tmp28;
auto tmp30 = tmp3 ? tmp16 : tmp29;
out_ptr1[static_cast<int64_t>(x0)] = tmp30;
}
}
}
''')
# Identical shape to cpp_fused_mul_8: in-place broadcast multiply of 1024 rows
# (128 channels each) of in_out_ptr0 by the 128-element vector in_ptr0.
cpp_fused_mul_11 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1024L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(128L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x1 + (128L*x0)), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x1 + (128L*x0)));
}
}
}
}
''')
# Elementwise SiLU (x * sigmoid(x)) over 524288 contiguous floats — same body
# as cpp_fused_silu_9.
cpp_fused_silu_12 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(524288L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
# Strided-slice + permute copy (a "clone" making a non-contiguous view
# contiguous). Reads in_ptr0[x3 + 32*x0 + 384*x1 + 12288*x2] — only offsets
# 0..127 of each 384-stride row are touched, i.e. one 128-wide chunk of a
# 384-wide fused projection — and writes it contiguously as
# out_ptr0[x3 + 32*x2 + 1024*x1 + 32768*x0], a (4, 32, 32, 32) layout.
# NOTE(review): presumably extracting one head-split q/k/v slice for the
# attention block below — confirm against the call site.
cpp_fused_clone_13 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(32L); x1+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(32L); x2+=static_cast<int64_t>(1L))
{
for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(32L); x3+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (32L*x0) + (384L*x1) + (12288L*x2)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x3 + (32L*x2) + (1024L*x1) + (32768L*x0)));
}
}
}
}
}
}
''')
# Companion permute copy to cpp_fused_clone_13: reads offsets 0..127 of each
# 384-stride row of in_ptr0 (index x2 + 32*x0 + 384*x1) and writes them
# contiguously as out_ptr0[x2 + 32*x1 + 32768*x0] — reshaping the slice from
# (1024, 4, 32) to (4, 1024, 32).
# NOTE(review): presumably the second q/k/v slice of the same fused projection
# — confirm against the call site.
cpp_fused_clone_14 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1024L); x1+=static_cast<int64_t>(1L))
{
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(32L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (32L*x0) + (384L*x1)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x2 + (32L*x1) + (32768L*x0)));
}
}
}
}
}
''')
cpp_fused__softmax_add_mul_15 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0,
float* out_ptr1,
float* out_ptr2,
float* out_ptr3)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1024L); x1+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(1024L); x2+=static_cast<int64_t>(1L))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x2 + (1024L*x1) + (1048576L*x0))];
auto tmp1 = static_cast<float>(0.1767766952966369);
auto tmp2 = decltype(tmp0)(tmp0 * tmp1);
auto tmp3 = 31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L)));
auto tmp4 = c10::convert<int64_t>(tmp3);
auto tmp5 = static_cast<int64_t>(2048);
auto tmp6 = tmp4 < tmp5;
auto tmp7 = [&]
{
auto tmp8 = static_cast<int64_t>((31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L))))) % static_cast<int64_t>(64L);
auto tmp9 = c10::convert<int64_t>(tmp8);
auto tmp10 = static_cast<int64_t>(63);
auto tmp11 = tmp9 < tmp10;
auto tmp12 = [&]
{
auto tmp13 = in_ptr1[static_cast<int64_t>((63L*(c10::div_floor_integer(static_cast<int64_t>((31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L))))), static_cast<int64_t>(64L)))) + (2016L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (64512L*x0) + (static_cast<int64_t>((31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L))))) % static_cast<int64_t>(64L)))];
return tmp13;
}
;
auto tmp14 = tmp11 ? tmp12() : static_cast<decltype(tmp12())>(0.0);
return tmp14;
}
;
auto tmp15 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
auto tmp16 = 31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(32L));
auto tmp17 = c10::convert<int64_t>(tmp16);
auto tmp18 = tmp17 < tmp5;
auto tmp19 = [&]
{
auto tmp20 = static_cast<int64_t>((31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(32L)))) % static_cast<int64_t>(64L);
auto tmp21 = c10::convert<int64_t>(tmp20);
auto tmp22 = static_cast<int64_t>(63);
auto tmp23 = tmp21 < tmp22;
auto tmp24 = [&]
{
auto tmp25 = in_ptr2[static_cast<int64_t>((63L*(c10::div_floor_integer(static_cast<int64_t>((31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(32L)))), static_cast<int64_t>(64L)))) + (2016L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (64512L*x0) + (static_cast<int64_t>((31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(32L)))) % static_cast<int64_t>(64L)))];
return tmp25;
}
;
auto tmp26 = tmp23 ? tmp24() : static_cast<decltype(tmp24())>(0.0);
return tmp26;
}
;
auto tmp27 = tmp18 ? tmp19() : static_cast<decltype(tmp19())>(0.0);
auto tmp28 = decltype(tmp15)(tmp15 + tmp27);
auto tmp29 = decltype(tmp2)(tmp2 + tmp28);
tmp_acc0 = max_propagate_nan(tmp_acc0, tmp29);
}
out_ptr0[static_cast<int64_t>(x1 + (1024L*x0))] = tmp_acc0;
}
#pragma GCC ivdep
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(1024L); x2+=static_cast<int64_t>(1L))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x2 + (1024L*x1) + (1048576L*x0))];
auto tmp30 = out_ptr0[static_cast<int64_t>(x1 + (1024L*x0))];
auto tmp1 = static_cast<float>(0.1767766952966369);
auto tmp2 = decltype(tmp0)(tmp0 * tmp1);
auto tmp3 = 31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L)));
auto tmp4 = c10::convert<int64_t>(tmp3);
auto tmp5 = static_cast<int64_t>(2048);
auto tmp6 = tmp4 < tmp5;
auto tmp7 = [&]
{
auto tmp8 = static_cast<int64_t>((31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L))))) % static_cast<int64_t>(64L);
auto tmp9 = c10::convert<int64_t>(tmp8);
auto tmp10 = static_cast<int64_t>(63);
auto tmp11 = tmp9 < tmp10;
auto tmp12 = [&]
{
auto tmp13 = in_ptr1[static_cast<int64_t>((63L*(c10::div_floor_integer(static_cast<int64_t>((31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L))))), static_cast<int64_t>(64L)))) + (2016L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (64512L*x0) + (static_cast<int64_t>((31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L))))) % static_cast<int64_t>(64L)))];
return tmp13;
}
;
auto tmp14 = tmp11 ? tmp12() : static_cast<decltype(tmp12())>(0.0);
return tmp14;
}
;
auto tmp15 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
auto tmp16 = 31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(32L));
auto tmp17 = c10::convert<int64_t>(tmp16);
auto tmp18 = tmp17 < tmp5;
auto tmp19 = [&]
{
auto tmp20 = static_cast<int64_t>((31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(32L)))) % static_cast<int64_t>(64L);
auto tmp21 = c10::convert<int64_t>(tmp20);
auto tmp22 = static_cast<int64_t>(63);
auto tmp23 = tmp21 < tmp22;
auto tmp24 = [&]
{
auto tmp25 = in_ptr2[static_cast<int64_t>((63L*(c10::div_floor_integer(static_cast<int64_t>((31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(32L)))), static_cast<int64_t>(64L)))) + (2016L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (64512L*x0) + (static_cast<int64_t>((31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(32L)))) % static_cast<int64_t>(64L)))];
return tmp25;
}
;
auto tmp26 = tmp23 ? tmp24() : static_cast<decltype(tmp24())>(0.0);
return tmp26;
}
;
auto tmp27 = tmp18 ? tmp19() : static_cast<decltype(tmp19())>(0.0);
auto tmp28 = decltype(tmp15)(tmp15 + tmp27);
auto tmp29 = decltype(tmp2)(tmp2 + tmp28);
auto tmp31 = decltype(tmp29)(tmp29 - tmp30);
auto tmp32 = std::exp(tmp31);
out_ptr1[static_cast<int64_t>(x2 + (1024L*x1) + (1048576L*x0))] = tmp32;
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4096L); x0+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1024L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (1024L*x0)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr2[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0);
}
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1024L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (1024L*x0)), 16);
auto tmp1 = out_ptr2[static_cast<int64_t>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(out_ptr3 + static_cast<int64_t>(x1 + (1024L*x0)));
}
}
}
}
''')
cpp_fused__native_batch_norm_legit_no_training_silu_16 = async_compile.cpp_pybinding(['float*', 'const float*', 'const float*', 'const float*', 'const float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
const float* in_ptr3,
const float* in_ptr4)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1024L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(128L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>((32L*x0) + (32768L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (static_cast<int64_t>(x1) % static_cast<int64_t>(32L))), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x1), 16);
auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x1), 16);
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(x1), 16);
auto tmp7 = at::vec::Vectorized<float>::loadu(in_ptr4 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 - tmp1;
auto tmp4 = tmp2 * tmp3;
auto tmp6 = tmp4 * tmp5;
auto tmp8 = tmp6 + tmp7;
auto tmp9 = decltype(tmp8)(1)/(decltype(tmp8)(1) + tmp8.neg().exp());
auto tmp10 = tmp8 * tmp9;
tmp10.store(in_out_ptr0 + static_cast<int64_t>(x1 + (128L*x0)));
}
}
}
}
''')
cpp_fused_silu_17 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(524288L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_mean_18 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (256L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(256.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_mul_19 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x1 + (256L*x0)), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x1 + (256L*x0)));
}
}
}
}
''')
cpp_fused_silu_20 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(262144L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_mean_21 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (256L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(256.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_mul_22 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x1 + (256L*x0)), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x1 + (256L*x0)));
}
}
}
}
''')
cpp_fused_silu_23 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(262144L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_clone_24 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L); x1+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(16L); x2+=static_cast<int64_t>(1L))
{
for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(64L); x3+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (64L*x0) + (768L*x1) + (12288L*x2)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x3 + (64L*x2) + (1024L*x1) + (16384L*x0)));
}
}
}
}
}
}
''')
cpp_fused_clone_25 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(1L))
{
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(64L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (64L*x0) + (768L*x1)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x2 + (64L*x1) + (16384L*x0)));
}
}
}
}
}
''')
cpp_fused__softmax_add_mul_26 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0,
float* out_ptr1,
float* out_ptr2,
float* out_ptr3)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(256L); x2+=static_cast<int64_t>(1L))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x2 + (256L*x1) + (65536L*x0))];
auto tmp1 = static_cast<float>(0.125);
auto tmp2 = decltype(tmp0)(tmp0 * tmp1);
auto tmp3 = 15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L)));
auto tmp4 = c10::convert<int64_t>(tmp3);
auto tmp5 = static_cast<int64_t>(512);
auto tmp6 = tmp4 < tmp5;
auto tmp7 = [&]
{
auto tmp8 = static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L);
auto tmp9 = c10::convert<int64_t>(tmp8);
auto tmp10 = static_cast<int64_t>(31);
auto tmp11 = tmp9 < tmp10;
auto tmp12 = [&]
{
auto tmp13 = in_ptr1[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))), static_cast<int64_t>(32L)))) + (496L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L)))];
return tmp13;
}
;
auto tmp14 = tmp11 ? tmp12() : static_cast<decltype(tmp12())>(0.0);
return tmp14;
}
;
auto tmp15 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
auto tmp16 = 15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L));
auto tmp17 = c10::convert<int64_t>(tmp16);
auto tmp18 = tmp17 < tmp5;
auto tmp19 = [&]
{
auto tmp20 = static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L);
auto tmp21 = c10::convert<int64_t>(tmp20);
auto tmp22 = static_cast<int64_t>(31);
auto tmp23 = tmp21 < tmp22;
auto tmp24 = [&]
{
auto tmp25 = in_ptr2[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L)))), static_cast<int64_t>(32L)))) + (496L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L)))];
return tmp25;
}
;
auto tmp26 = tmp23 ? tmp24() : static_cast<decltype(tmp24())>(0.0);
return tmp26;
}
;
auto tmp27 = tmp18 ? tmp19() : static_cast<decltype(tmp19())>(0.0);
auto tmp28 = decltype(tmp15)(tmp15 + tmp27);
auto tmp29 = decltype(tmp2)(tmp2 + tmp28);
tmp_acc0 = max_propagate_nan(tmp_acc0, tmp29);
}
out_ptr0[static_cast<int64_t>(x1 + (256L*x0))] = tmp_acc0;
}
#pragma GCC ivdep
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(256L); x2+=static_cast<int64_t>(1L))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x2 + (256L*x1) + (65536L*x0))];
auto tmp30 = out_ptr0[static_cast<int64_t>(x1 + (256L*x0))];
auto tmp1 = static_cast<float>(0.125);
auto tmp2 = decltype(tmp0)(tmp0 * tmp1);
auto tmp3 = 15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L)));
auto tmp4 = c10::convert<int64_t>(tmp3);
auto tmp5 = static_cast<int64_t>(512);
auto tmp6 = tmp4 < tmp5;
auto tmp7 = [&]
{
auto tmp8 = static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L);
auto tmp9 = c10::convert<int64_t>(tmp8);
auto tmp10 = static_cast<int64_t>(31);
auto tmp11 = tmp9 < tmp10;
auto tmp12 = [&]
{
auto tmp13 = in_ptr1[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))), static_cast<int64_t>(32L)))) + (496L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L)))];
return tmp13;
}
;
auto tmp14 = tmp11 ? tmp12() : static_cast<decltype(tmp12())>(0.0);
return tmp14;
}
;
auto tmp15 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
auto tmp16 = 15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L));
auto tmp17 = c10::convert<int64_t>(tmp16);
auto tmp18 = tmp17 < tmp5;
auto tmp19 = [&]
{
auto tmp20 = static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L);
auto tmp21 = c10::convert<int64_t>(tmp20);
auto tmp22 = static_cast<int64_t>(31);
auto tmp23 = tmp21 < tmp22;
auto tmp24 = [&]
{
auto tmp25 = in_ptr2[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L)))), static_cast<int64_t>(32L)))) + (496L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L)))];
return tmp25;
}
;
auto tmp26 = tmp23 ? tmp24() : static_cast<decltype(tmp24())>(0.0);
return tmp26;
}
;
auto tmp27 = tmp18 ? tmp19() : static_cast<decltype(tmp19())>(0.0);
auto tmp28 = decltype(tmp15)(tmp15 + tmp27);
auto tmp29 = decltype(tmp2)(tmp2 + tmp28);
auto tmp31 = decltype(tmp29)(tmp29 - tmp30);
auto tmp32 = std::exp(tmp31);
out_ptr1[static_cast<int64_t>(x2 + (256L*x1) + (65536L*x0))] = tmp32;
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1024L); x0+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (256L*x0)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr2[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0);
}
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (256L*x0)), 16);
auto tmp1 = out_ptr2[static_cast<int64_t>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(out_ptr3 + static_cast<int64_t>(x1 + (256L*x0)));
}
}
}
}
''')
cpp_fused__native_batch_norm_legit_no_training_silu_27 = async_compile.cpp_pybinding(['float*', 'const float*', 'const float*', 'const float*', 'const float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
const float* in_ptr3,
const float* in_ptr4)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>((64L*x0) + (16384L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(64L)))) + (static_cast<int64_t>(x1) % static_cast<int64_t>(64L))), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x1), 16);
auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x1), 16);
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(x1), 16);
auto tmp7 = at::vec::Vectorized<float>::loadu(in_ptr4 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 - tmp1;
auto tmp4 = tmp2 * tmp3;
auto tmp6 = tmp4 * tmp5;
auto tmp8 = tmp6 + tmp7;
auto tmp9 = decltype(tmp8)(1)/(decltype(tmp8)(1) + tmp8.neg().exp());
auto tmp10 = tmp8 * tmp9;
tmp10.store(in_out_ptr0 + static_cast<int64_t>(x1 + (256L*x0)));
}
}
}
}
''')
cpp_fused_silu_28 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(262144L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_clone_29 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L); x1+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(16L); x2+=static_cast<int64_t>(1L))
{
for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(128L); x3+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (128L*x0) + (1536L*x1) + (24576L*x2)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x3 + (128L*x2) + (2048L*x1) + (32768L*x0)));
}
}
}
}
}
}
''')
cpp_fused_clone_30 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(1L))
{
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(128L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (128L*x0) + (1536L*x1)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x2 + (128L*x1) + (32768L*x0)));
}
}
}
}
}
''')
cpp_fused__softmax_add_mul_31 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0,
float* out_ptr1,
float* out_ptr2,
float* out_ptr3)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(256L); x2+=static_cast<int64_t>(1L))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x2 + (256L*x1) + (65536L*x0))];
auto tmp1 = static_cast<float>(0.08838834764831845);
auto tmp2 = decltype(tmp0)(tmp0 * tmp1);
auto tmp3 = 15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L)));
auto tmp4 = c10::convert<int64_t>(tmp3);
auto tmp5 = static_cast<int64_t>(512);
auto tmp6 = tmp4 < tmp5;
auto tmp7 = [&]
{
auto tmp8 = static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L);
auto tmp9 = c10::convert<int64_t>(tmp8);
auto tmp10 = static_cast<int64_t>(31);
auto tmp11 = tmp9 < tmp10;
auto tmp12 = [&]
{
auto tmp13 = in_ptr1[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))), static_cast<int64_t>(32L)))) + (496L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L)))];
return tmp13;
}
;
auto tmp14 = tmp11 ? tmp12() : static_cast<decltype(tmp12())>(0.0);
return tmp14;
}
;
auto tmp15 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
auto tmp16 = 15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L));
auto tmp17 = c10::convert<int64_t>(tmp16);
auto tmp18 = tmp17 < tmp5;
auto tmp19 = [&]
{
auto tmp20 = static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L);
auto tmp21 = c10::convert<int64_t>(tmp20);
auto tmp22 = static_cast<int64_t>(31);
auto tmp23 = tmp21 < tmp22;
auto tmp24 = [&]
{
auto tmp25 = in_ptr2[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L)))), static_cast<int64_t>(32L)))) + (496L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L)))];
return tmp25;
}
;
auto tmp26 = tmp23 ? tmp24() : static_cast<decltype(tmp24())>(0.0);
return tmp26;
}
;
auto tmp27 = tmp18 ? tmp19() : static_cast<decltype(tmp19())>(0.0);
auto tmp28 = decltype(tmp15)(tmp15 + tmp27);
auto tmp29 = decltype(tmp2)(tmp2 + tmp28);
tmp_acc0 = max_propagate_nan(tmp_acc0, tmp29);
}
out_ptr0[static_cast<int64_t>(x1 + (256L*x0))] = tmp_acc0;
}
#pragma GCC ivdep
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(256L); x2+=static_cast<int64_t>(1L))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x2 + (256L*x1) + (65536L*x0))];
auto tmp30 = out_ptr0[static_cast<int64_t>(x1 + (256L*x0))];
auto tmp1 = static_cast<float>(0.08838834764831845);
auto tmp2 = decltype(tmp0)(tmp0 * tmp1);
auto tmp3 = 15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L)));
auto tmp4 = c10::convert<int64_t>(tmp3);
auto tmp5 = static_cast<int64_t>(512);
auto tmp6 = tmp4 < tmp5;
auto tmp7 = [&]
{
auto tmp8 = static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L);
auto tmp9 = c10::convert<int64_t>(tmp8);
auto tmp10 = static_cast<int64_t>(31);
auto tmp11 = tmp9 < tmp10;
auto tmp12 = [&]
{
auto tmp13 = in_ptr1[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))), static_cast<int64_t>(32L)))) + (496L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L)))];
return tmp13;
}
;
auto tmp14 = tmp11 ? tmp12() : static_cast<decltype(tmp12())>(0.0);
return tmp14;
}
;
auto tmp15 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
auto tmp16 = 15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L));
auto tmp17 = c10::convert<int64_t>(tmp16);
auto tmp18 = tmp17 < tmp5;
auto tmp19 = [&]
{
auto tmp20 = static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L);
auto tmp21 = c10::convert<int64_t>(tmp20);
auto tmp22 = static_cast<int64_t>(31);
auto tmp23 = tmp21 < tmp22;
auto tmp24 = [&]
{
auto tmp25 = in_ptr2[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L)))), static_cast<int64_t>(32L)))) + (496L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L)))];
return tmp25;
}
;
auto tmp26 = tmp23 ? tmp24() : static_cast<decltype(tmp24())>(0.0);
return tmp26;
}
;
auto tmp27 = tmp18 ? tmp19() : static_cast<decltype(tmp19())>(0.0);
auto tmp28 = decltype(tmp15)(tmp15 + tmp27);
auto tmp29 = decltype(tmp2)(tmp2 + tmp28);
auto tmp31 = decltype(tmp29)(tmp29 - tmp30);
auto tmp32 = std::exp(tmp31);
out_ptr1[static_cast<int64_t>(x2 + (256L*x1) + (65536L*x0))] = tmp32;
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1024L); x0+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (256L*x0)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr2[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0);
}
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (256L*x0)), 16);
auto tmp1 = out_ptr2[static_cast<int64_t>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(out_ptr3 + static_cast<int64_t>(x1 + (256L*x0)));
}
}
}
}
''')
# Fused CPU kernel (Inductor-generated): 2x2 average pool -> batch-norm
# (inference affine form) -> SiLU, vectorized 16 floats per iteration.
# The avg-pool sums four strided loads and scales by 0.25; the BN step
# computes ((x - in_ptr1) * in_ptr2) * in_ptr3 + in_ptr4 — presumably
# running_mean / inv_std / weight / bias per the op name, TODO confirm
# against the captured graph — and SiLU is x * sigmoid(x) written as
# x * (1 / (1 + exp(-x))). The C++ body is compiled verbatim by
# cpp_pybinding and must not be edited here.
cpp_fused__native_batch_norm_legit_no_training_avg_pool2d_silu_32 = async_compile.cpp_pybinding(['float*', 'const float*', 'const float*', 'const float*', 'const float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
const float* in_ptr3,
const float* in_ptr4)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(8L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(8L); x1+=static_cast<int64_t>(1L))
{
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(512L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>((256L*x1) + (4096L*x0) + (32768L*(c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(128L)))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(128L))), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(128L + (256L*x1) + (4096L*x0) + (32768L*(c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(128L)))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(128L))), 16);
auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(2048L + (256L*x1) + (4096L*x0) + (32768L*(c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(128L)))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(128L))), 16);
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(2176L + (256L*x1) + (4096L*x0) + (32768L*(c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(128L)))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(128L))), 16);
auto tmp10 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x2), 16);
auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x2), 16);
auto tmp14 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(x2), 16);
auto tmp16 = at::vec::Vectorized<float>::loadu(in_ptr4 + static_cast<int64_t>(x2), 16);
auto tmp2 = tmp1 + tmp0;
auto tmp4 = tmp3 + tmp2;
auto tmp6 = tmp5 + tmp4;
auto tmp7 = static_cast<float>(0.25);
auto tmp8 = at::vec::Vectorized<float>(tmp7);
auto tmp9 = tmp6 * tmp8;
auto tmp11 = tmp9 - tmp10;
auto tmp13 = tmp11 * tmp12;
auto tmp15 = tmp13 * tmp14;
auto tmp17 = tmp15 + tmp16;
auto tmp18 = decltype(tmp17)(1)/(decltype(tmp17)(1) + tmp17.neg().exp());
auto tmp19 = tmp17 * tmp18;
tmp19.store(in_out_ptr0 + static_cast<int64_t>(x2 + (512L*x1) + (4096L*x0)));
}
}
}
}
}
''')
# Fused CPU kernel: elementwise SiLU (x * sigmoid(x)) over a flat buffer
# of 98304 floats, 16-wide vectorized, reading in_ptr0 into out_ptr0.
# Kernel text is compiled verbatim; do not edit the C++ string.
cpp_fused_silu_33 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(98304L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
# Fused CPU kernel: strided copy ("clone") that permutes a 4x8x8x128
# float tile — source strides (128, 1536, 12288, 1) for (x0, x1, x2, x3)
# are rewritten to a contiguous destination with strides
# (8192, 1024, 128, 1), i.e. a layout change ahead of a matmul.
# 16-wide vectorized inner loop over the 128-element fastest dim.
cpp_fused_clone_34 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(8L); x1+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(8L); x2+=static_cast<int64_t>(1L))
{
for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(128L); x3+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (128L*x0) + (1536L*x1) + (12288L*x2)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x3 + (128L*x2) + (1024L*x1) + (8192L*x0)));
}
}
}
}
}
}
''')
# Fused CPU kernel: strided copy ("clone") gathering a 4x64x128 float
# tile — source strides (128, 1536, 1) for (x0, x1, x2) rewritten to a
# contiguous destination with strides (8192, 128, 1). Effectively slices
# one 128-wide channel group out of a 1536-wide packed buffer per x0.
cpp_fused_clone_35 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(1L))
{
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(128L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (128L*x0) + (1536L*x1)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x2 + (128L*x1) + (8192L*x0)));
}
}
}
}
}
''')
# Fused CPU kernel: attention-score softmax with two additive gathered
# bias terms (relative-position style, judging by the index arithmetic —
# TODO confirm against the captured graph). Per (x0, x1) row of the
# 4x64x64 score tensor in_ptr0:
#   score = raw * 0.08838834764831845 (presumably 1/sqrt(head_dim=128))
#           + gather(in_ptr1)  -- guarded by two bounds checks; positions
#           + gather(in_ptr2)     outside [0,128)/[0,15) contribute 0.0
# First pass computes the row max (NaN-propagating) into out_ptr0 for a
# numerically stable softmax; second pass stores exp(score - max) into
# out_ptr1; final section reduces row sums into out_ptr2 (16-wide
# vectorized + scalar tree reduce) and normalizes into out_ptr3.
cpp_fused__softmax_add_mul_36 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0,
float* out_ptr1,
float* out_ptr2,
float* out_ptr3)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(64L); x2+=static_cast<int64_t>(1L))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x2 + (64L*x1) + (4096L*x0))];
auto tmp1 = static_cast<float>(0.08838834764831845);
auto tmp2 = decltype(tmp0)(tmp0 * tmp1);
auto tmp3 = 7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(8L)));
auto tmp4 = c10::convert<int64_t>(tmp3);
auto tmp5 = static_cast<int64_t>(128);
auto tmp6 = tmp4 < tmp5;
auto tmp7 = [&]
{
auto tmp8 = static_cast<int64_t>((7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(8L))))) % static_cast<int64_t>(16L);
auto tmp9 = c10::convert<int64_t>(tmp8);
auto tmp10 = static_cast<int64_t>(15);
auto tmp11 = tmp9 < tmp10;
auto tmp12 = [&]
{
auto tmp13 = in_ptr1[static_cast<int64_t>((15L*(c10::div_floor_integer(static_cast<int64_t>((7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(8L))))), static_cast<int64_t>(16L)))) + (120L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (960L*x0) + (static_cast<int64_t>((7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(8L))))) % static_cast<int64_t>(16L)))];
return tmp13;
}
;
auto tmp14 = tmp11 ? tmp12() : static_cast<decltype(tmp12())>(0.0);
return tmp14;
}
;
auto tmp15 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
auto tmp16 = 7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(8L));
auto tmp17 = c10::convert<int64_t>(tmp16);
auto tmp18 = tmp17 < tmp5;
auto tmp19 = [&]
{
auto tmp20 = static_cast<int64_t>((7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(8L)))) % static_cast<int64_t>(16L);
auto tmp21 = c10::convert<int64_t>(tmp20);
auto tmp22 = static_cast<int64_t>(15);
auto tmp23 = tmp21 < tmp22;
auto tmp24 = [&]
{
auto tmp25 = in_ptr2[static_cast<int64_t>((15L*(c10::div_floor_integer(static_cast<int64_t>((7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(8L)))), static_cast<int64_t>(16L)))) + (120L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (960L*x0) + (static_cast<int64_t>((7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(8L)))) % static_cast<int64_t>(16L)))];
return tmp25;
}
;
auto tmp26 = tmp23 ? tmp24() : static_cast<decltype(tmp24())>(0.0);
return tmp26;
}
;
auto tmp27 = tmp18 ? tmp19() : static_cast<decltype(tmp19())>(0.0);
auto tmp28 = decltype(tmp15)(tmp15 + tmp27);
auto tmp29 = decltype(tmp2)(tmp2 + tmp28);
tmp_acc0 = max_propagate_nan(tmp_acc0, tmp29);
}
out_ptr0[static_cast<int64_t>(x1 + (64L*x0))] = tmp_acc0;
}
#pragma GCC ivdep
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(64L); x2+=static_cast<int64_t>(1L))
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x2 + (64L*x1) + (4096L*x0))];
auto tmp30 = out_ptr0[static_cast<int64_t>(x1 + (64L*x0))];
auto tmp1 = static_cast<float>(0.08838834764831845);
auto tmp2 = decltype(tmp0)(tmp0 * tmp1);
auto tmp3 = 7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(8L)));
auto tmp4 = c10::convert<int64_t>(tmp3);
auto tmp5 = static_cast<int64_t>(128);
auto tmp6 = tmp4 < tmp5;
auto tmp7 = [&]
{
auto tmp8 = static_cast<int64_t>((7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(8L))))) % static_cast<int64_t>(16L);
auto tmp9 = c10::convert<int64_t>(tmp8);
auto tmp10 = static_cast<int64_t>(15);
auto tmp11 = tmp9 < tmp10;
auto tmp12 = [&]
{
auto tmp13 = in_ptr1[static_cast<int64_t>((15L*(c10::div_floor_integer(static_cast<int64_t>((7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(8L))))), static_cast<int64_t>(16L)))) + (120L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (960L*x0) + (static_cast<int64_t>((7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(8L))))) % static_cast<int64_t>(16L)))];
return tmp13;
}
;
auto tmp14 = tmp11 ? tmp12() : static_cast<decltype(tmp12())>(0.0);
return tmp14;
}
;
auto tmp15 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
auto tmp16 = 7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(8L));
auto tmp17 = c10::convert<int64_t>(tmp16);
auto tmp18 = tmp17 < tmp5;
auto tmp19 = [&]
{
auto tmp20 = static_cast<int64_t>((7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(8L)))) % static_cast<int64_t>(16L);
auto tmp21 = c10::convert<int64_t>(tmp20);
auto tmp22 = static_cast<int64_t>(15);
auto tmp23 = tmp21 < tmp22;
auto tmp24 = [&]
{
auto tmp25 = in_ptr2[static_cast<int64_t>((15L*(c10::div_floor_integer(static_cast<int64_t>((7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(8L)))), static_cast<int64_t>(16L)))) + (120L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (960L*x0) + (static_cast<int64_t>((7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(8L)))) % static_cast<int64_t>(16L)))];
return tmp25;
}
;
auto tmp26 = tmp23 ? tmp24() : static_cast<decltype(tmp24())>(0.0);
return tmp26;
}
;
auto tmp27 = tmp18 ? tmp19() : static_cast<decltype(tmp19())>(0.0);
auto tmp28 = decltype(tmp15)(tmp15 + tmp27);
auto tmp29 = decltype(tmp2)(tmp2 + tmp28);
auto tmp31 = decltype(tmp29)(tmp29 - tmp30);
auto tmp32 = std::exp(tmp31);
out_ptr1[static_cast<int64_t>(x2 + (64L*x1) + (4096L*x0))] = tmp32;
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (64L*x0)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr2[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0);
}
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (64L*x0)), 16);
auto tmp1 = out_ptr2[static_cast<int64_t>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(out_ptr3 + static_cast<int64_t>(x1 + (64L*x0)));
}
}
}
}
''')
# Fused CPU kernel: inference batch-norm affine transform followed by
# SiLU, with a source layout remap (128-wide groups gathered from a
# strided input into a contiguous 64x512 output). Computes
# ((x - in_ptr1) * in_ptr2) * in_ptr3 + in_ptr4 — presumably
# mean / inv_std / weight / bias per the op name, TODO confirm —
# then x * sigmoid(x), 16-wide vectorized.
cpp_fused__native_batch_norm_legit_no_training_silu_37 = async_compile.cpp_pybinding(['float*', 'const float*', 'const float*', 'const float*', 'const float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
const float* in_ptr3,
const float* in_ptr4)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(64L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(512L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>((128L*x0) + (8192L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(128L)))) + (static_cast<int64_t>(x1) % static_cast<int64_t>(128L))), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x1), 16);
auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x1), 16);
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(x1), 16);
auto tmp7 = at::vec::Vectorized<float>::loadu(in_ptr4 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 - tmp1;
auto tmp4 = tmp2 * tmp3;
auto tmp6 = tmp4 * tmp5;
auto tmp8 = tmp6 + tmp7;
auto tmp9 = decltype(tmp8)(1)/(decltype(tmp8)(1) + tmp8.neg().exp());
auto tmp10 = tmp8 * tmp9;
tmp10.store(in_out_ptr0 + static_cast<int64_t>(x1 + (512L*x0)));
}
}
}
}
''')
# Fused CPU kernel: elementwise SiLU (x * sigmoid(x)) over a flat buffer
# of 98304 floats, 16-wide vectorized. Identical body to cpp_fused_silu_33
# but Inductor emits one kernel per graph node.
cpp_fused_silu_38 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(98304L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
# Fused CPU kernel: per-channel mean over 64 positions for 1280 channels
# (global average pool over an 8x8 spatial grid, presumably — TODO
# confirm against the captured graph). Pass 1 accumulates the 64 rows of
# the (64, 1280) input 16 channels at a time; pass 2 divides the
# accumulated sums in place by 64.0.
cpp_fused_mean_39 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1280L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (1280L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1280L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(64.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
# Fused CPU kernel: in-place bias add for the final linear layer's
# 1000-element output. Main loop handles the first 992 elements 16 at a
# time; a second 8-wide loop covers the remaining 8-element tail (1000
# is not a multiple of 16).
cpp_fused_addmm_40 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(992L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp2 = tmp0 + tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
for(int64_t x0=static_cast<int64_t>(992L); x0<static_cast<int64_t>(1000L); x0+=static_cast<int64_t>(8L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x0), 8);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 8);
auto tmp2 = tmp0 + tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x0), 8);
}
}
}
''')
# Block until every cpp_pybinding kernel above has finished compiling,
# then drop the compiler object — kernels are now bound in globals().
async_compile.wait(globals())
del async_compile
def call(args):
arg224_1, = args
args.clear()
assert_size_stride(arg224_1, (1, 3, 256, 256), (196608, 65536, 256, 1))
buf0 = empty_strided_cpu((1, 3, 256, 256), (196608, 1, 768, 3), torch.float32)
cpp_fused_silu_0(arg224_1, buf0)
del arg224_1
buf1 = torch.ops.mkldnn._convolution_pointwise.default(buf0, _frozen_param598, _frozen_param564, [1, 1], [2, 2], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf1, (1, 24, 128, 128), (393216, 1, 3072, 24))
del buf0
buf2 = torch.ops.mkldnn._convolution_pointwise.default(buf1, _frozen_param599, _frozen_param565, [1, 1], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf2, (1, 32, 128, 128), (524288, 1, 4096, 32))
del buf1
buf3 = torch.ops.mkldnn._convolution_pointwise.default(buf2, _frozen_param600, _frozen_param566, [1, 1], [2, 2], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf3, (1, 64, 64, 64), (262144, 1, 4096, 64))
del buf2
buf4 = torch.ops.mkldnn._convolution_pointwise.default(buf3, _frozen_param601, _frozen_param567, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf4, (1, 64, 64, 64), (262144, 1, 4096, 64))
buf5 = torch.ops.mkldnn._convolution_pointwise.default(buf4, _frozen_param602, _frozen_param568, [1, 1], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf5, (1, 64, 64, 64), (262144, 1, 4096, 64))
buf6 = empty_strided_cpu((1, 64, 1, 1), (64, 1, 64, 64), torch.float32)
buf7 = reinterpret_tensor(buf6, (1, 64, 1, 1), (64, 1, 1, 1), 0); del buf6 # reuse
buf8 = empty_strided_cpu((8, ), (1, ), torch.float32)
cpp_fused_mean_1(buf7, buf5, buf8)
buf9 = torch.ops.mkldnn._convolution_pointwise.default(buf7, _frozen_param603, buf8, [0, 0], [1, 1], [1, 1], 1, 'relu', [None], '')
assert_size_stride(buf9, (1, 8, 1, 1), (8, 1, 8, 8))
del buf7
del buf8
buf10 = torch.ops.mkldnn._convolution_pointwise.default(buf9, _frozen_param604, _frozen_param18, [0, 0], [1, 1], [1, 1], 1, 'sigmoid', [None], '')
assert_size_stride(buf10, (1, 64, 1, 1), (64, 1, 64, 64))
buf11 = buf5; del buf5 # reuse
cpp_fused_mul_2(buf11, buf10)
buf12 = torch.ops.mkldnn._convolution_pointwise.default(buf11, _frozen_param605, _frozen_param569, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
assert_size_stride(buf12, (1, 256, 64, 64), (1048576, 1, 16384, 256))
del buf11
buf13 = torch.ops.mkldnn._convolution_pointwise_.binary(buf12, buf3, _frozen_param606, _frozen_param570, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
del buf3
del buf4
buf16 = empty_strided_cpu((1, 256, 64, 64), (1048576, 1, 16384, 256), torch.float32)
cpp_fused_silu_3(buf12, buf16)
buf17 = torch.ops.mkldnn._convolution_pointwise.default(buf16, _frozen_param607, _frozen_param571, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf17, (1, 64, 64, 64), (262144, 1, 4096, 64))
buf18 = torch.ops.mkldnn._convolution_pointwise.default(buf17, _frozen_param608, _frozen_param572, [1, 1], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf18, (1, 64, 64, 64), (262144, 1, 4096, 64))
buf19 = buf10; del buf10 # reuse
buf20 = reinterpret_tensor(buf19, (1, 64, 1, 1), (64, 1, 1, 1), 0); del buf19 # reuse
buf21 = reinterpret_tensor(buf9, (8, ), (1, ), 0); del buf9 # reuse
cpp_fused_mean_4(buf20, buf18, buf21)
buf22 = torch.ops.mkldnn._convolution_pointwise.default(buf20, _frozen_param609, buf21, [0, 0], [1, 1], [1, 1], 1, 'relu', [None], '')
assert_size_stride(buf22, (1, 8, 1, 1), (8, 1, 8, 8))
del buf20
del buf21
buf23 = torch.ops.mkldnn._convolution_pointwise.default(buf22, _frozen_param610, _frozen_param34, [0, 0], [1, 1], [1, 1], 1, 'sigmoid', [None], '')
assert_size_stride(buf23, (1, 64, 1, 1), (64, 1, 64, 64))
buf24 = buf18; del buf18 # reuse
cpp_fused_mul_5(buf24, buf23)
del buf23
buf25 = torch.ops.mkldnn._convolution_pointwise_.binary(buf16, buf24, _frozen_param611, _frozen_param573, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
buf28 = buf12; del buf12 # reuse
cpp_fused_silu_6(buf16, buf28)
del buf16
buf29 = torch.ops.mkldnn._convolution_pointwise.default(buf28, _frozen_param612, _frozen_param574, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf29, (1, 128, 64, 64), (524288, 1, 8192, 128))
buf30 = torch.ops.mkldnn._convolution_pointwise.default(buf29, _frozen_param613, _frozen_param575, [1, 1], [2, 2], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf30, (1, 128, 32, 32), (131072, 1, 4096, 128))
buf31 = empty_strided_cpu((1, 128, 1, 1), (128, 1, 128, 128), torch.float32)
buf32 = reinterpret_tensor(buf31, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf31 # reuse
buf33 = reinterpret_tensor(buf22, (8, ), (1, ), 0); del buf22 # reuse
cpp_fused_mean_7(buf32, buf30, buf33)
buf34 = torch.ops.mkldnn._convolution_pointwise.default(buf32, _frozen_param614, buf33, [0, 0], [1, 1], [1, 1], 1, 'relu', [None], '')
assert_size_stride(buf34, (1, 8, 1, 1), (8, 1, 8, 8))
del buf32
del buf33
buf35 = torch.ops.mkldnn._convolution_pointwise.default(buf34, _frozen_param615, _frozen_param47, [0, 0], [1, 1], [1, 1], 1, 'sigmoid', [None], '')
assert_size_stride(buf35, (1, 128, 1, 1), (128, 1, 128, 128))
buf36 = buf30; del buf30 # reuse
cpp_fused_mul_8(buf36, buf35)
buf37 = torch.ops.mkldnn._convolution_pointwise.default(buf36, _frozen_param616, _frozen_param576, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
assert_size_stride(buf37, (1, 512, 32, 32), (524288, 1, 16384, 512))
del buf36
buf38 = torch.ops.mkldnn._convolution_pointwise_.binary(buf37, buf28, _frozen_param617, _frozen_param577, [0, 0], [2, 2], [1, 1], 1, 'add', 1.0, None, [None], None)
del buf28
buf41 = reinterpret_tensor(buf29, (1, 512, 32, 32), (524288, 1, 16384, 512), 0); del buf29 # reuse
cpp_fused_silu_9(buf37, buf41)
buf42 = torch.ops.mkldnn._convolution_pointwise.default(buf41, _frozen_param618, _frozen_param578, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf42, (1, 128, 32, 32), (131072, 1, 4096, 128))
buf43 = torch.ops.mkldnn._convolution_pointwise.default(buf42, _frozen_param619, _frozen_param579, [1, 1], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf43, (1, 128, 32, 32), (131072, 1, 4096, 128))
buf44 = buf35; del buf35 # reuse
buf45 = reinterpret_tensor(buf44, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf44 # reuse
buf46 = reinterpret_tensor(buf34, (8, ), (1, ), 0); del buf34 # reuse
cpp_fused_mean_10(buf45, buf43, buf46)
buf47 = torch.ops.mkldnn._convolution_pointwise.default(buf45, _frozen_param620, buf46, [0, 0], [1, 1], [1, 1], 1, 'relu', [None], '')
assert_size_stride(buf47, (1, 8, 1, 1), (8, 1, 8, 8))
del buf45
del buf46
buf48 = torch.ops.mkldnn._convolution_pointwise.default(buf47, _frozen_param621, _frozen_param63, [0, 0], [1, 1], [1, 1], 1, 'sigmoid', [None], '')
assert_size_stride(buf48, (1, 128, 1, 1), (128, 1, 128, 128))
del buf47
buf49 = buf43; del buf43 # reuse
cpp_fused_mul_11(buf49, buf48)
del buf48
buf50 = torch.ops.mkldnn._convolution_pointwise_.binary(buf41, buf49, _frozen_param622, _frozen_param580, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
buf53 = buf37; del buf37 # reuse
cpp_fused_silu_12(buf41, buf53)
buf54 = torch.ops.mkldnn._convolution_pointwise.default(buf53, _frozen_param623, _frozen_param581, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf54, (1, 128, 32, 32), (131072, 1, 4096, 128))
buf55 = torch.ops.mkldnn._convolution_pointwise.default(buf54, _frozen_param624, None, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
assert_size_stride(buf55, (1, 384, 32, 32), (393216, 1, 12288, 384))
buf56 = empty_strided_cpu((4, 1024, 1024), (1048576, 1024, 1), torch.float32)
# Topologically Sorted Source Nodes: [matmul], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf55, (4, 1024, 32), (32, 384, 1), 0), reinterpret_tensor(buf55, (4, 32, 1024), (32, 1, 384), 128), out=buf56)
buf57 = reinterpret_tensor(buf49, (4, 32, 32, 32), (32768, 1024, 32, 1), 0); del buf49 # reuse
cpp_fused_clone_13(buf55, buf57)
buf58 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf57, (4096, 32), (32, 1), 0), _frozen_param628, _frozen_param627, None, 4096)
buf59 = buf57; del buf57 # reuse
cpp_fused_clone_14(buf55, buf59)
buf60 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf59, (4096, 32), (32, 1), 0), _frozen_param626, _frozen_param625, None, 4096)
buf61 = empty_strided_cpu((4, 1024, 1), (1024, 1, 4096), torch.float32)
buf62 = empty_strided_cpu((4, 1024, 1024), (1048576, 1024, 1), torch.float32)
buf63 = empty_strided_cpu((4, 1024, 1), (1024, 1, 4096), torch.float32)
buf64 = empty_strided_cpu((4, 1024, 1024), (1048576, 1024, 1), torch.float32)
cpp_fused__softmax_add_mul_15(buf56, buf58, buf60, buf61, buf62, buf63, buf64)
del buf56
del buf58
del buf60
del buf61
del buf62
del buf63
buf65 = reinterpret_tensor(buf59, (4, 1024, 32), (32768, 32, 1), 0); del buf59 # reuse
# Topologically Sorted Source Nodes: [attn_1, matmul_3], Original ATen: [aten._softmax, aten.bmm]
extern_kernels.bmm(buf64, reinterpret_tensor(buf55, (4, 1024, 32), (32, 384, 1), 256), out=buf65)
del buf55
del buf64
buf66 = buf42; del buf42 # reuse
buf67 = buf66; del buf66 # reuse
cpp_fused__native_batch_norm_legit_no_training_silu_16(buf67, buf65, _frozen_param305, _frozen_param306, _frozen_param307, _frozen_param308)
del buf65
buf68 = torch.ops.mkldnn._convolution_pointwise_.binary(buf53, buf67, _frozen_param629, _frozen_param582, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
del buf54
buf71 = buf41; del buf41 # reuse
cpp_fused_silu_17(buf53, buf71)
del buf53
buf72 = torch.ops.mkldnn._convolution_pointwise.default(buf71, _frozen_param630, _frozen_param583, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf72, (1, 256, 32, 32), (262144, 1, 8192, 256))
buf73 = torch.ops.mkldnn._convolution_pointwise.default(buf72, _frozen_param631, _frozen_param584, [1, 1], [2, 2], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf73, (1, 256, 16, 16), (65536, 1, 4096, 256))
buf74 = empty_strided_cpu((1, 256, 1, 1), (256, 1, 256, 256), torch.float32)
buf75 = reinterpret_tensor(buf74, (1, 256, 1, 1), (256, 1, 1, 1), 0); del buf74 # reuse
cpp_fused_mean_18(buf75, buf73)
buf76 = torch.ops.mkldnn._convolution_pointwise.default(buf75, _frozen_param632, _frozen_param85, [0, 0], [1, 1], [1, 1], 1, 'relu', [None], '')
assert_size_stride(buf76, (1, 16, 1, 1), (16, 1, 16, 16))
del buf75
buf77 = torch.ops.mkldnn._convolution_pointwise.default(buf76, _frozen_param633, _frozen_param87, [0, 0], [1, 1], [1, 1], 1, 'sigmoid', [None], '')
assert_size_stride(buf77, (1, 256, 1, 1), (256, 1, 256, 256))
del buf76
buf78 = buf73; del buf73 # reuse
cpp_fused_mul_19(buf78, buf77)
buf79 = torch.ops.mkldnn._convolution_pointwise.default(buf78, _frozen_param634, _frozen_param585, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
assert_size_stride(buf79, (1, 1024, 16, 16), (262144, 1, 16384, 1024))
del buf78
buf80 = torch.ops.mkldnn._convolution_pointwise_.binary(buf79, buf71, _frozen_param635, _frozen_param586, [0, 0], [2, 2], [1, 1], 1, 'add', 1.0, None, [None], None)
del buf71
buf83 = reinterpret_tensor(buf72, (1, 1024, 16, 16), (262144, 1, 16384, 1024), 0); del buf72 # reuse
cpp_fused_silu_20(buf79, buf83)
buf84 = torch.ops.mkldnn._convolution_pointwise.default(buf83, _frozen_param636, _frozen_param587, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf84, (1, 256, 16, 16), (65536, 1, 4096, 256))
buf85 = torch.ops.mkldnn._convolution_pointwise.default(buf84, _frozen_param637, _frozen_param588, [1, 1], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf85, (1, 256, 16, 16), (65536, 1, 4096, 256))
buf86 = buf77; del buf77 # reuse
buf87 = reinterpret_tensor(buf86, (1, 256, 1, 1), (256, 1, 1, 1), 0); del buf86 # reuse
cpp_fused_mean_21(buf87, buf85)
buf88 = torch.ops.mkldnn._convolution_pointwise.default(buf87, _frozen_param638, _frozen_param101, [0, 0], [1, 1], [1, 1], 1, 'relu', [None], '')
assert_size_stride(buf88, (1, 16, 1, 1), (16, 1, 16, 16))
buf89 = torch.ops.mkldnn._convolution_pointwise.default(buf88, _frozen_param639, _frozen_param103, [0, 0], [1, 1], [1, 1], 1, 'sigmoid', [None], '')
assert_size_stride(buf89, (1, 256, 1, 1), (256, 1, 256, 256))
del buf88
buf90 = buf85; del buf85 # reuse
cpp_fused_mul_22(buf90, buf89)
buf91 = torch.ops.mkldnn._convolution_pointwise_.binary(buf83, buf90, _frozen_param640, _frozen_param589, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
buf94 = buf79; del buf79 # reuse
cpp_fused_silu_23(buf83, buf94)
buf95 = torch.ops.mkldnn._convolution_pointwise.default(buf94, _frozen_param641, _frozen_param590, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf95, (1, 256, 16, 16), (65536, 1, 4096, 256))
buf96 = torch.ops.mkldnn._convolution_pointwise.default(buf95, _frozen_param642, None, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
assert_size_stride(buf96, (1, 768, 16, 16), (196608, 1, 12288, 768))
buf97 = reinterpret_tensor(buf83, (4, 256, 256), (65536, 256, 1), 0); del buf83 # reuse
# Topologically Sorted Source Nodes: [matmul_4], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf96, (4, 256, 64), (64, 768, 1), 0), reinterpret_tensor(buf96, (4, 64, 256), (64, 1, 768), 256), out=buf97)
buf98 = reinterpret_tensor(buf90, (4, 16, 16, 64), (16384, 1024, 64, 1), 0); del buf90 # reuse
cpp_fused_clone_24(buf96, buf98)
buf99 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf98, (1024, 64), (64, 1), 0), _frozen_param646, _frozen_param645, None, 1024)
buf100 = buf98; del buf98 # reuse
cpp_fused_clone_25(buf96, buf100)
buf101 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf100, (1024, 64), (64, 1), 0), _frozen_param644, _frozen_param643, None, 1024)
buf102 = empty_strided_cpu((4, 256, 1), (256, 1, 1024), torch.float32)
buf103 = reinterpret_tensor(buf24, (4, 256, 256), (65536, 256, 1), 0); del buf24 # reuse
buf104 = empty_strided_cpu((4, 256, 1), (256, 1, 1024), torch.float32)
buf105 = reinterpret_tensor(buf17, (4, 256, 256), (65536, 256, 1), 0); del buf17 # reuse
cpp_fused__softmax_add_mul_26(buf97, buf99, buf101, buf102, buf103, buf104, buf105)
del buf101
del buf99
buf106 = reinterpret_tensor(buf100, (4, 256, 64), (16384, 64, 1), 0); del buf100 # reuse
# Topologically Sorted Source Nodes: [attn_3, matmul_7], Original ATen: [aten._softmax, aten.bmm]
extern_kernels.bmm(buf105, reinterpret_tensor(buf96, (4, 256, 64), (64, 768, 1), 512), out=buf106)
del buf96
buf107 = buf84; del buf84 # reuse
buf108 = buf107; del buf107 # reuse
cpp_fused__native_batch_norm_legit_no_training_silu_27(buf108, buf106, _frozen_param349, _frozen_param350, _frozen_param351, _frozen_param352)
del buf106
buf109 = torch.ops.mkldnn._convolution_pointwise_.binary(buf94, buf108, _frozen_param647, _frozen_param591, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
del buf108
del buf95
buf112 = reinterpret_tensor(buf105, (1, 1024, 16, 16), (262144, 1, 16384, 1024), 0); del buf105 # reuse
cpp_fused_silu_28(buf94, buf112)
buf113 = torch.ops.mkldnn._convolution_pointwise.default(buf112, _frozen_param648, _frozen_param592, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf113, (1, 512, 16, 16), (131072, 1, 8192, 512))
buf114 = torch.ops.mkldnn._convolution_pointwise.default(buf113, _frozen_param649, None, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
assert_size_stride(buf114, (1, 1536, 16, 16), (393216, 1, 24576, 1536))
buf115 = reinterpret_tensor(buf94, (4, 256, 256), (65536, 256, 1), 0); del buf94 # reuse
# Topologically Sorted Source Nodes: [matmul_8], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf114, (4, 256, 128), (128, 1536, 1), 0), reinterpret_tensor(buf114, (4, 128, 256), (128, 1, 1536), 512), out=buf115)
buf116 = reinterpret_tensor(buf67, (4, 16, 16, 128), (32768, 2048, 128, 1), 0); del buf67 # reuse
cpp_fused_clone_29(buf114, buf116)
buf117 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf116, (1024, 128), (128, 1), 0), _frozen_param653, _frozen_param652, None, 1024)
buf118 = buf116; del buf116 # reuse
cpp_fused_clone_30(buf114, buf118)
buf119 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf118, (1024, 128), (128, 1), 0), _frozen_param651, _frozen_param650, None, 1024)
buf120 = buf104; del buf104 # reuse
buf121 = buf97; del buf97 # reuse
buf122 = buf102; del buf102 # reuse
buf123 = buf103; del buf103 # reuse
cpp_fused__softmax_add_mul_31(buf115, buf117, buf119, buf120, buf121, buf122, buf123)
del buf115
del buf117
del buf119
del buf120
del buf121
del buf122
buf124 = reinterpret_tensor(buf118, (4, 256, 128), (32768, 128, 1), 0); del buf118 # reuse
# Topologically Sorted Source Nodes: [attn_5, matmul_11], Original ATen: [aten._softmax, aten.bmm]
extern_kernels.bmm(buf123, reinterpret_tensor(buf114, (4, 256, 128), (128, 1536, 1), 1024), out=buf124)
del buf114
del buf123
buf125 = empty_strided_cpu((1, 512, 8, 8), (32768, 1, 4096, 512), torch.float32)
buf126 = buf125; del buf125 # reuse
cpp_fused__native_batch_norm_legit_no_training_avg_pool2d_silu_32(buf126, buf124, _frozen_param363, _frozen_param364, _frozen_param365, _frozen_param366)
del buf124
buf127 = torch.ops.mkldnn._convolution_pointwise.default(buf126, _frozen_param654, _frozen_param593, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
assert_size_stride(buf127, (1, 1536, 8, 8), (98304, 1, 12288, 1536))
buf128 = torch.ops.mkldnn._convolution_pointwise_.binary(buf127, buf112, _frozen_param655, _frozen_param594, [0, 0], [2, 2], [1, 1], 1, 'add', 1.0, None, [None], None)
del buf112
del buf113
buf131 = empty_strided_cpu((1, 1536, 8, 8), (98304, 1, 12288, 1536), torch.float32)
cpp_fused_silu_33(buf127, buf131)
del buf127
buf132 = torch.ops.mkldnn._convolution_pointwise.default(buf131, _frozen_param656, _frozen_param595, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf132, (1, 512, 8, 8), (32768, 1, 4096, 512))
buf133 = torch.ops.mkldnn._convolution_pointwise.default(buf132, _frozen_param657, None, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
assert_size_stride(buf133, (1, 1536, 8, 8), (98304, 1, 12288, 1536))
buf134 = empty_strided_cpu((4, 64, 64), (4096, 64, 1), torch.float32)
# Topologically Sorted Source Nodes: [matmul_12], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf133, (4, 64, 128), (128, 1536, 1), 0), reinterpret_tensor(buf133, (4, 128, 64), (128, 1, 1536), 512), out=buf134)
buf135 = reinterpret_tensor(buf126, (4, 8, 8, 128), (8192, 1024, 128, 1), 0); del buf126 # reuse
cpp_fused_clone_34(buf133, buf135)
buf136 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf135, (256, 128), (128, 1), 0), _frozen_param661, _frozen_param660, None, 256)
buf137 = buf135; del buf135 # reuse
cpp_fused_clone_35(buf133, buf137)
buf138 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf137, (256, 128), (128, 1), 0), _frozen_param659, _frozen_param658, None, 256)
buf139 = reinterpret_tensor(buf89, (4, 64, 1), (64, 1, 256), 0); del buf89 # reuse
buf140 = empty_strided_cpu((4, 64, 64), (4096, 64, 1), torch.float32)
buf141 = reinterpret_tensor(buf87, (4, 64, 1), (64, 1, 256), 0); del buf87 # reuse
buf142 = empty_strided_cpu((4, 64, 64), (4096, 64, 1), torch.float32)
cpp_fused__softmax_add_mul_36(buf134, buf136, buf138, buf139, buf140, buf141, buf142)
del buf134
del buf136
del buf138
del buf139
del buf140
del buf141
buf143 = reinterpret_tensor(buf137, (4, 64, 128), (8192, 128, 1), 0); del buf137 # reuse
# Topologically Sorted Source Nodes: [attn_7, matmul_15], Original ATen: [aten._softmax, aten.bmm]
extern_kernels.bmm(buf142, reinterpret_tensor(buf133, (4, 64, 128), (128, 1536, 1), 1024), out=buf143)
del buf142
buf144 = empty_strided_cpu((1, 512, 8, 8), (32768, 1, 4096, 512), torch.float32)
buf145 = buf144; del buf144 # reuse
cpp_fused__native_batch_norm_legit_no_training_silu_37(buf145, buf143, _frozen_param381, _frozen_param382, _frozen_param383, _frozen_param384)
del buf143
buf146 = torch.ops.mkldnn._convolution_pointwise_.binary(buf131, buf145, _frozen_param662, _frozen_param596, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
del buf132
del buf145
buf149 = buf133; del buf133 # reuse
cpp_fused_silu_38(buf131, buf149)
del buf131
buf150 = torch.ops.mkldnn._convolution_pointwise.default(buf149, _frozen_param663, _frozen_param597, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
assert_size_stride(buf150, (1, 1280, 8, 8), (81920, 1, 10240, 1280))
del buf149
buf151 = empty_strided_cpu((1, 1280, 1, 1), (1280, 1, 1280, 1280), torch.float32)
buf152 = buf151; del buf151 # reuse
cpp_fused_mean_39(buf152, buf150)
del buf150
buf153 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf152, (1, 1280), (0, 1), 0), _frozen_param665, _frozen_param664, None, 1)
del buf152
buf154 = buf153; del buf153 # reuse
cpp_fused_addmm_40(buf154, _frozen_param147)
return (buf154, )
def benchmark_compiled_module(times=10, repeat=10):
    """Populate every frozen parameter with random data and time ``call``.

    The compiled graph reads its weights from module-level
    ``_frozen_param*`` globals.  This helper fills each of them with a
    ``rand_strided`` tensor of the shape/stride recorded at compile time
    (all float32 on CPU), builds a random example input, and benchmarks
    the compiled entry point via ``print_performance``.

    Args:
        times: iterations per timing measurement.
        repeat: number of timing measurements.

    Returns:
        Whatever ``torch._inductor.utils.print_performance`` returns.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    # (module-global name, shape, stride) for every parameter the graph
    # needs.  The order matches the original generated code exactly so the
    # RNG stream yields identical tensors.
    _param_specs = [
        ("_frozen_param18", (64,), (1,)),
        ("_frozen_param34", (64,), (1,)),
        ("_frozen_param47", (128,), (1,)),
        ("_frozen_param63", (128,), (1,)),
        ("_frozen_param85", (16,), (1,)),
        ("_frozen_param87", (256,), (1,)),
        ("_frozen_param101", (16,), (1,)),
        ("_frozen_param103", (256,), (1,)),
        ("_frozen_param147", (1000,), (1,)),
        ("_frozen_param564", (24,), (1,)),
        ("_frozen_param598", (24, 3, 3, 3), (1, 0, 0, 0)),
        ("_frozen_param565", (32,), (1,)),
        ("_frozen_param599", (32, 24, 3, 3), (1, 0, 0, 0)),
        ("_frozen_param566", (64,), (1,)),
        ("_frozen_param600", (64, 32, 3, 3), (1, 0, 0, 0)),
        ("_frozen_param567", (64,), (1,)),
        ("_frozen_param601", (64, 64, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param568", (64,), (1,)),
        ("_frozen_param602", (64, 64, 3, 3), (1, 0, 0, 0)),
        ("_frozen_param603", (8, 64, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param604", (64, 8, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param569", (256,), (1,)),
        ("_frozen_param605", (256, 64, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param570", (256,), (1,)),
        ("_frozen_param606", (256, 64, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param571", (64,), (1,)),
        ("_frozen_param607", (64, 256, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param572", (64,), (1,)),
        ("_frozen_param608", (64, 64, 3, 3), (1, 0, 0, 0)),
        ("_frozen_param609", (8, 64, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param610", (64, 8, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param573", (256,), (1,)),
        ("_frozen_param611", (256, 64, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param574", (128,), (1,)),
        ("_frozen_param612", (128, 256, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param575", (128,), (1,)),
        ("_frozen_param613", (128, 128, 3, 3), (1, 0, 0, 0)),
        ("_frozen_param614", (8, 128, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param615", (128, 8, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param576", (512,), (1,)),
        ("_frozen_param616", (512, 128, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param577", (512,), (1,)),
        ("_frozen_param617", (512, 256, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param578", (128,), (1,)),
        ("_frozen_param618", (128, 512, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param579", (128,), (1,)),
        ("_frozen_param619", (128, 128, 3, 3), (1, 0, 0, 0)),
        ("_frozen_param620", (8, 128, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param621", (128, 8, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param580", (512,), (1,)),
        ("_frozen_param622", (512, 128, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param581", (128,), (1,)),
        ("_frozen_param623", (128, 512, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param624", (384, 128, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param625", (63, 32), (32, 1)),
        ("_frozen_param626", (1982689, 1), (1, 0)),
        ("_frozen_param627", (63, 32), (32, 1)),
        ("_frozen_param628", (1982689, 1), (1, 0)),
        ("_frozen_param305", (128, 1, 1), (1, 1, 1)),
        ("_frozen_param306", (128, 1, 1), (1, 1, 1)),
        ("_frozen_param307", (128, 1, 1), (1, 1, 1)),
        ("_frozen_param308", (128, 1, 1), (1, 1, 1)),
        ("_frozen_param582", (512,), (1,)),
        ("_frozen_param629", (512, 128, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param583", (256,), (1,)),
        ("_frozen_param630", (256, 512, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param584", (256,), (1,)),
        ("_frozen_param631", (256, 256, 3, 3), (1, 0, 0, 0)),
        ("_frozen_param632", (16, 256, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param633", (256, 16, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param585", (1024,), (1,)),
        ("_frozen_param634", (1024, 256, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param586", (1024,), (1,)),
        ("_frozen_param635", (1024, 512, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param587", (256,), (1,)),
        ("_frozen_param636", (256, 1024, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param588", (256,), (1,)),
        ("_frozen_param637", (256, 256, 3, 3), (1, 0, 0, 0)),
        ("_frozen_param638", (16, 256, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param639", (256, 16, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param589", (1024,), (1,)),
        ("_frozen_param640", (1024, 256, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param590", (256,), (1,)),
        ("_frozen_param641", (256, 1024, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param642", (768, 256, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param643", (31, 64), (64, 1)),
        ("_frozen_param644", (1982689, 1), (1, 0)),
        ("_frozen_param645", (31, 64), (64, 1)),
        ("_frozen_param646", (1982689, 1), (1, 0)),
        ("_frozen_param349", (256, 1, 1), (1, 1, 1)),
        ("_frozen_param350", (256, 1, 1), (1, 1, 1)),
        ("_frozen_param351", (256, 1, 1), (1, 1, 1)),
        ("_frozen_param352", (256, 1, 1), (1, 1, 1)),
        ("_frozen_param591", (1024,), (1,)),
        ("_frozen_param647", (1024, 256, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param592", (512,), (1,)),
        ("_frozen_param648", (512, 1024, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param649", (1536, 512, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param650", (31, 128), (128, 1)),
        ("_frozen_param651", (1982689, 1), (1, 0)),
        ("_frozen_param652", (31, 128), (128, 1)),
        ("_frozen_param653", (1982689, 1), (1, 0)),
        ("_frozen_param363", (512, 1, 1), (1, 1, 1)),
        ("_frozen_param364", (512, 1, 1), (1, 1, 1)),
        ("_frozen_param365", (512, 1, 1), (1, 1, 1)),
        ("_frozen_param366", (512, 1, 1), (1, 1, 1)),
        ("_frozen_param593", (1536,), (1,)),
        ("_frozen_param654", (1536, 512, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param594", (1536,), (1,)),
        ("_frozen_param655", (1536, 1024, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param595", (512,), (1,)),
        ("_frozen_param656", (512, 1536, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param657", (1536, 512, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param658", (15, 128), (128, 1)),
        ("_frozen_param659", (1982689, 1), (1, 0)),
        ("_frozen_param660", (15, 128), (128, 1)),
        ("_frozen_param661", (1982689, 1), (1, 0)),
        ("_frozen_param381", (512, 1, 1), (1, 1, 1)),
        ("_frozen_param382", (512, 1, 1), (1, 1, 1)),
        ("_frozen_param383", (512, 1, 1), (1, 1, 1)),
        ("_frozen_param384", (512, 1, 1), (1, 1, 1)),
        ("_frozen_param596", (1536,), (1,)),
        ("_frozen_param662", (1536, 512, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param597", (1280,), (1,)),
        ("_frozen_param663", (1280, 1536, 1, 1), (1, 0, 0, 0)),
        ("_frozen_param664", (1000, 1280), (1280, 1)),
        ("_frozen_param665", (3490017, 1), (1, 0)),
    ]
    # Writing through globals() is equivalent to the generated
    # `global <name>` + `<name> = rand_strided(...)` pairs: both rebind the
    # module-level names that `call` reads.
    for _name, _shape, _stride in _param_specs:
        globals()[_name] = rand_strided(_shape, _stride, device='cpu', dtype=torch.float32)

    # Example input matching the traced model signature (a 1x3x256x256 image).
    arg224_1 = rand_strided((1, 3, 256, 256), (196608, 65536, 256, 1), device='cpu', dtype=torch.float32)
    fn = lambda: call([arg224_1])
    return print_performance(fn, times=times, repeat=repeat)
# Script entry point: hand this module's benchmark helper to Inductor's
# standard benchmarking harness (only runs when executed directly).
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('sebotnet33ts_256', benchmark_compiled_module)
# NOTE(review): a second, complete copy of the generated AOT module begins
# here — the same "AOT ID" banner, imports and guard aliases repeat, so this
# file appears to be two generator outputs pasted back to back.  Re-importing
# is harmless, but rebinding module-level names (e.g. `async_compile`)
# shadows the bindings from the first copy above.
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# Short aliases for the ATen / Inductor op namespaces and the Dynamo guard
# helpers used by the generated call() function.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
# Placeholders for parameters frozen (constant-folded) at compile time.
# They are None here and are populated later by a load_args-style function
# (the benchmark harness fills them with rand_strided tensors).  The trailing
# comment on each line records device / dtype / shape / strides and the id()
# of the original tensor at compile time.
# NOTE(review): many weight entries record strides like (1, 0, 0, 0) — only
# the first dimension advances.  Presumably these stand in for packed
# (e.g. MKLDNN) weights whose real layout is opaque — confirm before relying
# on the recorded strides.
_frozen_param18 = None  # device(type='cpu') torch.float32 (64,) (1,) 7f8740180d10
_frozen_param34 = None  # device(type='cpu') torch.float32 (64,) (1,) 7f8740183510
_frozen_param47 = None  # device(type='cpu') torch.float32 (128,) (1,) 7f874228fba0
_frozen_param63 = None  # device(type='cpu') torch.float32 (128,) (1,) 7f874228ea20
_frozen_param85 = None  # device(type='cpu') torch.float32 (16,) (1,) 7f8742285300
_frozen_param87 = None  # device(type='cpu') torch.float32 (256,) (1,) 7f8742285f80
_frozen_param101 = None  # device(type='cpu') torch.float32 (16,) (1,) 7f874228b330
_frozen_param103 = None  # device(type='cpu') torch.float32 (256,) (1,) 7f8742289f30
_frozen_param147 = None  # device(type='cpu') torch.float32 (1000,) (1,) 7f8742271df0
_frozen_param564 = None  # device(type='cpu') torch.float32 (24,) (1,) 7f872f028720
_frozen_param598 = None  # device(type='cpu') torch.float32 (24, 3, 3, 3) (1, 0, 0, 0) 7f872f0b36f0
_frozen_param565 = None  # device(type='cpu') torch.float32 (32,) (1,) 7f872ef84cc0
_frozen_param599 = None  # device(type='cpu') torch.float32 (32, 24, 3, 3) (1, 0, 0, 0) 7f872ec544f0
_frozen_param566 = None  # device(type='cpu') torch.float32 (64,) (1,) 7f872f681030
_frozen_param600 = None  # device(type='cpu') torch.float32 (64, 32, 3, 3) (1, 0, 0, 0) 7f872ec54400
_frozen_param567 = None  # device(type='cpu') torch.float32 (64,) (1,) 7f872ed47920
_frozen_param601 = None  # device(type='cpu') torch.float32 (64, 64, 1, 1) (1, 0, 0, 0) 7f872ec544a0
_frozen_param568 = None  # device(type='cpu') torch.float32 (64,) (1,) 7f872f02b420
_frozen_param602 = None  # device(type='cpu') torch.float32 (64, 64, 3, 3) (1, 0, 0, 0) 7f872ec54220
_frozen_param603 = None  # device(type='cpu') torch.float32 (8, 64, 1, 1) (1, 0, 0, 0) 7f872ec54540
_frozen_param604 = None  # device(type='cpu') torch.float32 (64, 8, 1, 1) (1, 0, 0, 0) 7f872ec54590
_frozen_param569 = None  # device(type='cpu') torch.float32 (256,) (1,) 7f872f043a10
_frozen_param605 = None  # device(type='cpu') torch.float32 (256, 64, 1, 1) (1, 0, 0, 0) 7f872ec545e0
_frozen_param570 = None  # device(type='cpu') torch.float32 (256,) (1,) 7f872eec8db0
_frozen_param606 = None  # device(type='cpu') torch.float32 (256, 64, 1, 1) (1, 0, 0, 0) 7f872ec54630
_frozen_param571 = None  # device(type='cpu') torch.float32 (64,) (1,) 7f872ee743b0
_frozen_param607 = None  # device(type='cpu') torch.float32 (64, 256, 1, 1) (1, 0, 0, 0) 7f872ec54680
_frozen_param572 = None  # device(type='cpu') torch.float32 (64,) (1,) 7f872ed90ef0
_frozen_param608 = None  # device(type='cpu') torch.float32 (64, 64, 3, 3) (1, 0, 0, 0) 7f872ec546d0
_frozen_param609 = None  # device(type='cpu') torch.float32 (8, 64, 1, 1) (1, 0, 0, 0) 7f872ec54720
_frozen_param610 = None  # device(type='cpu') torch.float32 (64, 8, 1, 1) (1, 0, 0, 0) 7f872ec54770
_frozen_param573 = None  # device(type='cpu') torch.float32 (256,) (1,) 7f872edd8ae0
_frozen_param611 = None  # device(type='cpu') torch.float32 (256, 64, 1, 1) (1, 0, 0, 0) 7f872ec547c0
_frozen_param574 = None  # device(type='cpu') torch.float32 (128,) (1,) 7f872ecfa250
_frozen_param612 = None  # device(type='cpu') torch.float32 (128, 256, 1, 1) (1, 0, 0, 0) 7f872ec54810
_frozen_param575 = None  # device(type='cpu') torch.float32 (128,) (1,) 7f872eda6700
_frozen_param613 = None  # device(type='cpu') torch.float32 (128, 128, 3, 3) (1, 0, 0, 0) 7f872ec54860
_frozen_param614 = None  # device(type='cpu') torch.float32 (8, 128, 1, 1) (1, 0, 0, 0) 7f872ec548b0
_frozen_param615 = None  # device(type='cpu') torch.float32 (128, 8, 1, 1) (1, 0, 0, 0) 7f872ec54900
_frozen_param576 = None  # device(type='cpu') torch.float32 (512,) (1,) 7f872ed32520
_frozen_param616 = None  # device(type='cpu') torch.float32 (512, 128, 1, 1) (1, 0, 0, 0) 7f872ec54950
_frozen_param577 = None  # device(type='cpu') torch.float32 (512,) (1,) 7f872ebcbe20
_frozen_param617 = None  # device(type='cpu') torch.float32 (512, 256, 1, 1) (1, 0, 0, 0) 7f872ec549a0
_frozen_param578 = None  # device(type='cpu') torch.float32 (128,) (1,) 7f872ee75030
_frozen_param618 = None  # device(type='cpu') torch.float32 (128, 512, 1, 1) (1, 0, 0, 0) 7f872ec549f0
_frozen_param579 = None  # device(type='cpu') torch.float32 (128,) (1,) 7f872ebd4ef0
_frozen_param619 = None  # device(type='cpu') torch.float32 (128, 128, 3, 3) (1, 0, 0, 0) 7f872ec54a40
_frozen_param620 = None  # device(type='cpu') torch.float32 (8, 128, 1, 1) (1, 0, 0, 0) 7f872ec54a90
_frozen_param621 = None  # device(type='cpu') torch.float32 (128, 8, 1, 1) (1, 0, 0, 0) 7f872ec54ae0
_frozen_param580 = None  # device(type='cpu') torch.float32 (512,) (1,) 7f872ed325c0
_frozen_param622 = None  # device(type='cpu') torch.float32 (512, 128, 1, 1) (1, 0, 0, 0) 7f872ec54b30
_frozen_param581 = None  # device(type='cpu') torch.float32 (128,) (1,) 7f872ebf87c0
_frozen_param623 = None  # device(type='cpu') torch.float32 (128, 512, 1, 1) (1, 0, 0, 0) 7f872ec54b80
_frozen_param624 = None  # device(type='cpu') torch.float32 (384, 128, 1, 1) (1, 0, 0, 0) 7f872ec54bd0
_frozen_param625 = None  # device(type='cpu') torch.float32 (63, 32) (32, 1) 7f872ed026b0
_frozen_param626 = None  # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f872ecf8680
_frozen_param627 = None  # device(type='cpu') torch.float32 (63, 32) (32, 1) 7f872ebca610
_frozen_param628 = None  # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f872eda1210
_frozen_param305 = None  # device(type='cpu') torch.float32 (128, 1, 1) (1, 1, 1) 7f872eccc3b0
_frozen_param306 = None  # device(type='cpu') torch.float32 (128, 1, 1) (1, 1, 1) 7f872eccc310
_frozen_param307 = None  # device(type='cpu') torch.float32 (128, 1, 1) (1, 1, 1) 7f872eccc360
_frozen_param308 = None  # device(type='cpu') torch.float32 (128, 1, 1) (1, 1, 1) 7f872eccc400
_frozen_param582 = None  # device(type='cpu') torch.float32 (512,) (1,) 7f872ebf8810
_frozen_param629 = None  # device(type='cpu') torch.float32 (512, 128, 1, 1) (1, 0, 0, 0) 7f872ec54ef0
_frozen_param583 = None  # device(type='cpu') torch.float32 (256,) (1,) 7f872ebf8860
_frozen_param630 = None  # device(type='cpu') torch.float32 (256, 512, 1, 1) (1, 0, 0, 0) 7f872ec54ea0
_frozen_param584 = None  # device(type='cpu') torch.float32 (256,) (1,) 7f872ebf88b0
_frozen_param631 = None  # device(type='cpu') torch.float32 (256, 256, 3, 3) (1, 0, 0, 0) 7f872ec54e50
_frozen_param632 = None  # device(type='cpu') torch.float32 (16, 256, 1, 1) (1, 0, 0, 0) 7f872ec54e00
_frozen_param633 = None  # device(type='cpu') torch.float32 (256, 16, 1, 1) (1, 0, 0, 0) 7f872ec54d60
_frozen_param585 = None  # device(type='cpu') torch.float32 (1024,) (1,) 7f872ebf84a0
_frozen_param634 = None  # device(type='cpu') torch.float32 (1024, 256, 1, 1) (1, 0, 0, 0) 7f872ec54c70
_frozen_param586 = None  # device(type='cpu') torch.float32 (1024,) (1,) 7f872ebf8950
_frozen_param635 = None  # device(type='cpu') torch.float32 (1024, 512, 1, 1) (1, 0, 0, 0) 7f872ec54f40
_frozen_param587 = None  # device(type='cpu') torch.float32 (256,) (1,) 7f872ed329d0
_frozen_param636 = None  # device(type='cpu') torch.float32 (256, 1024, 1, 1) (1, 0, 0, 0) 7f872ec54d10
_frozen_param588 = None  # device(type='cpu') torch.float32 (256,) (1,) 7f872ebf86d0
_frozen_param637 = None  # device(type='cpu') torch.float32 (256, 256, 3, 3) (1, 0, 0, 0) 7f872ec54cc0
_frozen_param638 = None  # device(type='cpu') torch.float32 (16, 256, 1, 1) (1, 0, 0, 0) 7f872ec54c20
_frozen_param639 = None  # device(type='cpu') torch.float32 (256, 16, 1, 1) (1, 0, 0, 0) 7f872ec54f90
_frozen_param589 = None  # device(type='cpu') torch.float32 (1024,) (1,) 7f872ebf8310
_frozen_param640 = None  # device(type='cpu') torch.float32 (1024, 256, 1, 1) (1, 0, 0, 0) 7f872ec54fe0
_frozen_param590 = None  # device(type='cpu') torch.float32 (256,) (1,) 7f872ebf8a40
_frozen_param641 = None  # device(type='cpu') torch.float32 (256, 1024, 1, 1) (1, 0, 0, 0) 7f872ec55120
_frozen_param642 = None  # device(type='cpu') torch.float32 (768, 256, 1, 1) (1, 0, 0, 0) 7f872ec55080
_frozen_param643 = None  # device(type='cpu') torch.float32 (31, 64) (64, 1) 7f872ec55300
_frozen_param644 = None  # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f872ec55350
_frozen_param645 = None  # device(type='cpu') torch.float32 (31, 64) (64, 1) 7f872ec55030
_frozen_param646 = None  # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f872ec551c0
_frozen_param349 = None  # device(type='cpu') torch.float32 (256, 1, 1) (1, 1, 1) 7f872eccd440
_frozen_param350 = None  # device(type='cpu') torch.float32 (256, 1, 1) (1, 1, 1) 7f872eccd3a0
_frozen_param351 = None  # device(type='cpu') torch.float32 (256, 1, 1) (1, 1, 1) 7f872eccd490
_frozen_param352 = None  # device(type='cpu') torch.float32 (256, 1, 1) (1, 1, 1) 7f872eccd260
_frozen_param591 = None  # device(type='cpu') torch.float32 (1024,) (1,) 7f872ebf8a90
_frozen_param647 = None  # device(type='cpu') torch.float32 (1024, 256, 1, 1) (1, 0, 0, 0) 7f872ec55530
_frozen_param592 = None  # device(type='cpu') torch.float32 (512,) (1,) 7f872ebf8ae0
_frozen_param648 = None  # device(type='cpu') torch.float32 (512, 1024, 1, 1) (1, 0, 0, 0) 7f872ec554e0
_frozen_param649 = None  # device(type='cpu') torch.float32 (1536, 512, 1, 1) (1, 0, 0, 0) 7f872ec55490
_frozen_param650 = None  # device(type='cpu') torch.float32 (31, 128) (128, 1) 7f872ec555d0
_frozen_param651 = None  # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f872ec55620
_frozen_param652 = None  # device(type='cpu') torch.float32 (31, 128) (128, 1) 7f872ec552b0
_frozen_param653 = None  # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f872ec553f0
_frozen_param363 = None  # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f872eccd940
_frozen_param364 = None  # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f872eccd990
_frozen_param365 = None  # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f872eccd9e0
_frozen_param366 = None  # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f872eccd8f0
_frozen_param593 = None  # device(type='cpu') torch.float32 (1536,) (1,) 7f872ebf8770
_frozen_param654 = None  # device(type='cpu') torch.float32 (1536, 512, 1, 1) (1, 0, 0, 0) 7f872ec55800
_frozen_param594 = None  # device(type='cpu') torch.float32 (1536,) (1,) 7f872ebf8b80
_frozen_param655 = None  # device(type='cpu') torch.float32 (1536, 1024, 1, 1) (1, 0, 0, 0) 7f872ec557b0
_frozen_param595 = None  # device(type='cpu') torch.float32 (512,) (1,) 7f872ebf8bd0
_frozen_param656 = None  # device(type='cpu') torch.float32 (512, 1536, 1, 1) (1, 0, 0, 0) 7f872ec55760
_frozen_param657 = None  # device(type='cpu') torch.float32 (1536, 512, 1, 1) (1, 0, 0, 0) 7f872ec55710
_frozen_param658 = None  # device(type='cpu') torch.float32 (15, 128) (128, 1) 7f872ec558f0
_frozen_param659 = None  # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f872ec55940
_frozen_param660 = None  # device(type='cpu') torch.float32 (15, 128) (128, 1) 7f872ec55170
_frozen_param661 = None  # device(type='cpu') torch.float32 (1982689, 1) (1, 0) 7f872ec553a0
_frozen_param381 = None  # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f872ecce070
_frozen_param382 = None  # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f872eccdfd0
_frozen_param383 = None  # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f872ecce0c0
_frozen_param384 = None  # device(type='cpu') torch.float32 (512, 1, 1) (1, 1, 1) 7f872eccde90
_frozen_param596 = None  # device(type='cpu') torch.float32 (1536,) (1,) 7f872ebf8c20
_frozen_param662 = None  # device(type='cpu') torch.float32 (1536, 512, 1, 1) (1, 0, 0, 0) 7f872ec55b20
_frozen_param597 = None  # device(type='cpu') torch.float32 (1280,) (1,) 7f872ebf8c70
_frozen_param663 = None  # device(type='cpu') torch.float32 (1280, 1536, 1, 1) (1, 0, 0, 0) 7f872ec55ad0
_frozen_param664 = None  # device(type='cpu') torch.float32 (1000, 1280) (1280, 1) 7f872ec558a0
_frozen_param665 = None  # device(type='cpu') torch.float32 (3490017, 1) (1, 0) 7f872ec55bc0
cpp_fused_silu_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(3L); x0+=static_cast<int64_t>(3L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(65536L); x1+=static_cast<int64_t>(16L))
{
alignas(16) float tmp1[3*16];
for (long x0_inner = 0; x0_inner < 3; x0_inner++)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1 + (65536L*x0) + (65536L*x0_inner)), 16);
tmp0.store(tmp1 + static_cast<int64_t>(16L*x0_inner));
}
at::vec::transpose_mxn<float,3,16>(tmp1, 16, out_ptr0 + static_cast<int64_t>(x0 + (3L*x1)), static_cast<int64_t>(3L));
}
}
}
}
''')
cpp_fused_mean_1 = async_compile.cpp_pybinding(['float*', 'const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
float* out_ptr1)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(64L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(4096L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (64L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(64L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(4096.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
{
#pragma omp simd simdlen(8)
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(8L); x0+=static_cast<int64_t>(1L))
{
auto tmp0 = x0;
auto tmp1 = c10::convert<int64_t>(tmp0);
auto tmp2 = static_cast<int64_t>(4);
auto tmp3 = tmp1 < tmp2;
auto tmp4 = static_cast<int64_t>(2);
auto tmp5 = tmp1 < tmp4;
auto tmp6 = static_cast<int64_t>(1);
auto tmp7 = tmp1 < tmp6;
auto tmp8 = static_cast<float>(-0.3303021788597107);
auto tmp9 = static_cast<float>(0.06354581564664841);
auto tmp10 = tmp7 ? tmp8 : tmp9;
auto tmp11 = static_cast<int64_t>(3);
auto tmp12 = tmp1 < tmp11;
auto tmp13 = static_cast<float>(1.774228811264038);
auto tmp14 = static_cast<float>(2.1113927364349365);
auto tmp15 = tmp12 ? tmp13 : tmp14;
auto tmp16 = tmp5 ? tmp10 : tmp15;
auto tmp17 = static_cast<int64_t>(6);
auto tmp18 = tmp1 < tmp17;
auto tmp19 = static_cast<int64_t>(5);
auto tmp20 = tmp1 < tmp19;
auto tmp21 = static_cast<float>(0.32513317465782166);
auto tmp22 = static_cast<float>(1.232210397720337);
auto tmp23 = tmp20 ? tmp21 : tmp22;
auto tmp24 = static_cast<int64_t>(7);
auto tmp25 = tmp1 < tmp24;
auto tmp26 = static_cast<float>(0.7079262137413025);
auto tmp27 = static_cast<float>(0.2353029102087021);
auto tmp28 = tmp25 ? tmp26 : tmp27;
auto tmp29 = tmp18 ? tmp23 : tmp28;
auto tmp30 = tmp3 ? tmp16 : tmp29;
out_ptr1[static_cast<int64_t>(x0)] = tmp30;
}
}
}
''')
cpp_fused_mul_2 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4096L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x1 + (64L*x0)), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x1 + (64L*x0)));
}
}
}
}
''')
cpp_fused_silu_3 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1048576L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_mean_4 = async_compile.cpp_pybinding(['float*', 'const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
float* out_ptr1)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(64L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(4096L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (64L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(64L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(4096.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
{
#pragma omp simd simdlen(8)
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(8L); x0+=static_cast<int64_t>(1L))
{
auto tmp0 = x0;
auto tmp1 = c10::convert<int64_t>(tmp0);
auto tmp2 = static_cast<int64_t>(4);
auto tmp3 = tmp1 < tmp2;
auto tmp4 = static_cast<int64_t>(2);
auto tmp5 = tmp1 < tmp4;
auto tmp6 = static_cast<int64_t>(1);
auto tmp7 = tmp1 < tmp6;
auto tmp8 = static_cast<float>(0.7516645193099976);
auto tmp9 = static_cast<float>(0.5124969482421875);
auto tmp10 = tmp7 ? tmp8 : tmp9;
auto tmp11 = static_cast<int64_t>(3);
auto tmp12 = tmp1 < tmp11;
auto tmp13 = static_cast<float>(1.2062517404556274);
auto tmp14 = static_cast<float>(0.9069646596908569);
auto tmp15 = tmp12 ? tmp13 : tmp14;
auto tmp16 = tmp5 ? tmp10 : tmp15;
auto tmp17 = static_cast<int64_t>(6);
auto tmp18 = tmp1 < tmp17;
auto tmp19 = static_cast<int64_t>(5);
auto tmp20 = tmp1 < tmp19;
auto tmp21 = static_cast<float>(1.4622137546539307);
auto tmp22 = static_cast<float>(-0.11386305838823318);
auto tmp23 = tmp20 ? tmp21 : tmp22;
auto tmp24 = static_cast<int64_t>(7);
auto tmp25 = tmp1 < tmp24;
auto tmp26 = static_cast<float>(-0.2968502938747406);
auto tmp27 = static_cast<float>(0.884636402130127);
auto tmp28 = tmp25 ? tmp26 : tmp27;
auto tmp29 = tmp18 ? tmp23 : tmp28;
auto tmp30 = tmp3 ? tmp16 : tmp29;
out_ptr1[static_cast<int64_t>(x0)] = tmp30;
}
}
}
''')
cpp_fused_mul_5 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4096L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x1 + (64L*x0)), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x1 + (64L*x0)));
}
}
}
}
''')
cpp_fused_silu_6 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1048576L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_mean_7 = async_compile.cpp_pybinding(['float*', 'const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
float* out_ptr1)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(128L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1024L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (128L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(128L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(1024.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
{
#pragma omp simd simdlen(8)
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(8L); x0+=static_cast<int64_t>(1L))
{
auto tmp0 = x0;
auto tmp1 = c10::convert<int64_t>(tmp0);
auto tmp2 = static_cast<int64_t>(4);
auto tmp3 = tmp1 < tmp2;
auto tmp4 = static_cast<int64_t>(2);
auto tmp5 = tmp1 < tmp4;
auto tmp6 = static_cast<int64_t>(1);
auto tmp7 = tmp1 < tmp6;
auto tmp8 = static_cast<float>(0.47172150015830994);
auto tmp9 = static_cast<float>(1.4283583164215088);
auto tmp10 = tmp7 ? tmp8 : tmp9;
auto tmp11 = static_cast<int64_t>(3);
auto tmp12 = tmp1 < tmp11;
auto tmp13 = static_cast<float>(-0.04577525332570076);
auto tmp14 = static_cast<float>(2.043065309524536);
auto tmp15 = tmp12 ? tmp13 : tmp14;
auto tmp16 = tmp5 ? tmp10 : tmp15;
auto tmp17 = static_cast<int64_t>(6);
auto tmp18 = tmp1 < tmp17;
auto tmp19 = static_cast<int64_t>(5);
auto tmp20 = tmp1 < tmp19;
auto tmp21 = static_cast<float>(0.13726529479026794);
auto tmp22 = static_cast<float>(1.1331775188446045);
auto tmp23 = tmp20 ? tmp21 : tmp22;
auto tmp24 = static_cast<int64_t>(7);
auto tmp25 = tmp1 < tmp24;
auto tmp26 = static_cast<float>(-0.11772552132606506);
auto tmp27 = static_cast<float>(0.527721107006073);
auto tmp28 = tmp25 ? tmp26 : tmp27;
auto tmp29 = tmp18 ? tmp23 : tmp28;
auto tmp30 = tmp3 ? tmp16 : tmp29;
out_ptr1[static_cast<int64_t>(x0)] = tmp30;
}
}
}
''')
cpp_fused_mul_8 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1024L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(128L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x1 + (128L*x0)), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x1 + (128L*x0)));
}
}
}
}
''')
cpp_fused_silu_9 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(524288L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_mean_10 = async_compile.cpp_pybinding(['float*', 'const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
float* out_ptr1)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(128L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1024L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (128L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(128L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(1024.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
{
#pragma omp simd simdlen(8)
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(8L); x0+=static_cast<int64_t>(1L))
{
auto tmp0 = x0;
auto tmp1 = c10::convert<int64_t>(tmp0);
auto tmp2 = static_cast<int64_t>(4);
auto tmp3 = tmp1 < tmp2;
auto tmp4 = static_cast<int64_t>(2);
auto tmp5 = tmp1 < tmp4;
auto tmp6 = static_cast<int64_t>(1);
auto tmp7 = tmp1 < tmp6;
auto tmp8 = static_cast<float>(0.29536592960357666);
auto tmp9 = static_cast<float>(1.1849623918533325);
auto tmp10 = tmp7 ? tmp8 : tmp9;
auto tmp11 = static_cast<int64_t>(3);
auto tmp12 = tmp1 < tmp11;
auto tmp13 = static_cast<float>(0.3150717318058014);
auto tmp14 = static_cast<float>(0.5096337795257568);
auto tmp15 = tmp12 ? tmp13 : tmp14;
auto tmp16 = tmp5 ? tmp10 : tmp15;
auto tmp17 = static_cast<int64_t>(6);
auto tmp18 = tmp1 < tmp17;
auto tmp19 = static_cast<int64_t>(5);
auto tmp20 = tmp1 < tmp19;
auto tmp21 = static_cast<float>(-0.18543481826782227);
auto tmp22 = static_cast<float>(0.3189537227153778);
auto tmp23 = tmp20 ? tmp21 : tmp22;
auto tmp24 = static_cast<int64_t>(7);
auto tmp25 = tmp1 < tmp24;
auto tmp26 = static_cast<float>(-0.7191315293312073);
auto tmp27 = static_cast<float>(0.42770394682884216);
auto tmp28 = tmp25 ? tmp26 : tmp27;
auto tmp29 = tmp18 ? tmp23 : tmp28;
auto tmp30 = tmp3 ? tmp16 : tmp29;
out_ptr1[static_cast<int64_t>(x0)] = tmp30;
}
}
}
''')
cpp_fused_mul_11 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1024L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(128L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x1 + (128L*x0)), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x1 + (128L*x0)));
}
}
}
}
''')
cpp_fused_silu_12 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(524288L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_clone_13 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(32L); x1+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(32L); x2+=static_cast<int64_t>(1L))
{
for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(32L); x3+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (32L*x0) + (384L*x1) + (12288L*x2)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x3 + (32L*x2) + (1024L*x1) + (32768L*x0)));
}
}
}
}
}
}
''')
cpp_fused_clone_14 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1024L); x1+=static_cast<int64_t>(1L))
{
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(32L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (32L*x0) + (384L*x1)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x2 + (32L*x1) + (32768L*x0)));
}
}
}
}
}
''')
cpp_fused__softmax_add_mul_15 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0,
float* out_ptr1,
float* out_ptr2,
float* out_ptr3)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1024L); x1+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(1024L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (1024L*x1) + (1048576L*x0)), 16);
auto tmp1 = static_cast<float>(0.1767766952966369);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 * tmp2;
auto tmp4 = 31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L)));
auto tmp5 = c10::convert<int64_t>(tmp4);
auto tmp6 = static_cast<int64_t>(2048);
auto tmp7 = tmp5 < tmp6;
auto tmp8 = [&]
{
auto tmp9 = static_cast<int64_t>((31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L))))) % static_cast<int64_t>(64L);
auto tmp10 = c10::convert<int64_t>(tmp9);
auto tmp11 = static_cast<int64_t>(63);
auto tmp12 = tmp10 < tmp11;
auto tmp14 = tmp12 & tmp7;
auto tmp13 = [&]
{
auto tmp15 = in_ptr1[static_cast<int64_t>((63L*(c10::div_floor_integer(static_cast<int64_t>((31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L))))), static_cast<int64_t>(64L)))) + (2016L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (64512L*x0) + (static_cast<int64_t>((31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L))))) % static_cast<int64_t>(64L)))];
return tmp15;
}
;
auto tmp16 = tmp12 ? tmp13() : static_cast<float>(0.0);
return tmp16;
}
;
auto tmp17 = tmp7 ? tmp8() : static_cast<float>(0.0);
auto tmp18 = 31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(32L));
auto tmp19 = c10::convert<int64_t>(tmp18);
auto tmp20 = at::vec::VectorizedN<int64_t,2>::arange(tmp19, 1);
auto tmp21 = at::vec::VectorizedN<int64_t,2>(tmp6);
auto tmp22 = at::vec::VecMask<int64_t,2>(tmp20 < tmp21);
auto tmp23 = [&]
{
auto tmp24 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp22.is_masked(x2_inner))
{
tmpbuf[x2_inner] = static_cast<int64_t>(static_cast<int64_t>((31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(32L)))) % static_cast<int64_t>(64L));
}
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp25 = static_cast<int64_t>(63);
auto tmp26 = at::vec::VectorizedN<int64_t,2>(tmp25);
auto tmp27 = at::vec::VecMask<int64_t,2>(tmp24 < tmp26);
auto tmp29 = tmp27 & tmp22;
auto tmp28 = [&]
{
auto tmp30 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp29.is_masked(x2_inner))
{
tmpbuf[x2_inner] = in_ptr2[static_cast<int64_t>((63L*(c10::div_floor_integer(static_cast<int64_t>((31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(32L)))), static_cast<int64_t>(64L)))) + (2016L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (64512L*x0) + (static_cast<int64_t>((31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(32L)))) % static_cast<int64_t>(64L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
}
()
;
return tmp30;
}
;
auto tmp33 =
[&]
{
if (tmp29.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp31 = tmp28();
auto tmp32 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp31)::blendv(tmp32, tmp31, tmp29.template cast<float,1>());
}
}
()
;
return tmp33;
}
;
auto tmp36 =
[&]
{
if (tmp22.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp34 = tmp23();
auto tmp35 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp34)::blendv(tmp35, tmp34, tmp22.template cast<float,1>());
}
}
()
;
auto tmp37 = at::vec::Vectorized<float>(tmp17);
auto tmp38 = tmp37 + tmp36;
auto tmp39 = tmp3 + tmp38;
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp39);
}
tmp_acc0 = max_propagate_nan(tmp_acc0, at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return at::vec::maximum(x, y); }, tmp_acc0_vec));
out_ptr0[static_cast<int64_t>(x1 + (1024L*x0))] = static_cast<float>(tmp_acc0);
}
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(1024L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (1024L*x1) + (1048576L*x0)), 16);
auto tmp40 = out_ptr0[static_cast<int64_t>(x1 + (1024L*x0))];
auto tmp1 = static_cast<float>(0.1767766952966369);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 * tmp2;
auto tmp4 = 31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L)));
auto tmp5 = c10::convert<int64_t>(tmp4);
auto tmp6 = static_cast<int64_t>(2048);
auto tmp7 = tmp5 < tmp6;
auto tmp8 = [&]
{
auto tmp9 = static_cast<int64_t>((31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L))))) % static_cast<int64_t>(64L);
auto tmp10 = c10::convert<int64_t>(tmp9);
auto tmp11 = static_cast<int64_t>(63);
auto tmp12 = tmp10 < tmp11;
auto tmp14 = tmp12 & tmp7;
auto tmp13 = [&]
{
auto tmp15 = in_ptr1[static_cast<int64_t>((63L*(c10::div_floor_integer(static_cast<int64_t>((31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L))))), static_cast<int64_t>(64L)))) + (2016L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (64512L*x0) + (static_cast<int64_t>((31L + (63L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(32L))))) % static_cast<int64_t>(64L)))];
return tmp15;
}
;
auto tmp16 = tmp12 ? tmp13() : static_cast<float>(0.0);
return tmp16;
}
;
auto tmp17 = tmp7 ? tmp8() : static_cast<float>(0.0);
auto tmp18 = 31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(32L));
auto tmp19 = c10::convert<int64_t>(tmp18);
auto tmp20 = at::vec::VectorizedN<int64_t,2>::arange(tmp19, 1);
auto tmp21 = at::vec::VectorizedN<int64_t,2>(tmp6);
auto tmp22 = at::vec::VecMask<int64_t,2>(tmp20 < tmp21);
auto tmp23 = [&]
{
auto tmp24 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp22.is_masked(x2_inner))
{
tmpbuf[x2_inner] = static_cast<int64_t>(static_cast<int64_t>((31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(32L)))) % static_cast<int64_t>(64L));
}
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp25 = static_cast<int64_t>(63);
auto tmp26 = at::vec::VectorizedN<int64_t,2>(tmp25);
auto tmp27 = at::vec::VecMask<int64_t,2>(tmp24 < tmp26);
auto tmp29 = tmp27 & tmp22;
auto tmp28 = [&]
{
auto tmp30 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp29.is_masked(x2_inner))
{
tmpbuf[x2_inner] = in_ptr2[static_cast<int64_t>((63L*(c10::div_floor_integer(static_cast<int64_t>((31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(32L)))), static_cast<int64_t>(64L)))) + (2016L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (64512L*x0) + (static_cast<int64_t>((31L + (63L*(static_cast<int64_t>(x1) % static_cast<int64_t>(32L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(32L)))) % static_cast<int64_t>(64L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
}
()
;
return tmp30;
}
;
auto tmp33 =
[&]
{
if (tmp29.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp31 = tmp28();
auto tmp32 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp31)::blendv(tmp32, tmp31, tmp29.template cast<float,1>());
}
}
()
;
return tmp33;
}
;
auto tmp36 =
[&]
{
if (tmp22.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp34 = tmp23();
auto tmp35 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp34)::blendv(tmp35, tmp34, tmp22.template cast<float,1>());
}
}
()
;
auto tmp37 = at::vec::Vectorized<float>(tmp17);
auto tmp38 = tmp37 + tmp36;
auto tmp39 = tmp3 + tmp38;
auto tmp41 = at::vec::Vectorized<float>(tmp40);
auto tmp42 = tmp39 - tmp41;
auto tmp43 = tmp42.exp();
tmp43.store(out_ptr1 + static_cast<int64_t>(x2 + (1024L*x1) + (1048576L*x0)));
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4096L); x0+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1024L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (1024L*x0)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr2[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0);
}
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(1024L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (1024L*x0)), 16);
auto tmp1 = out_ptr2[static_cast<int64_t>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(out_ptr3 + static_cast<int64_t>(x1 + (1024L*x0)));
}
}
}
}
''')
cpp_fused__native_batch_norm_legit_no_training_silu_16 = async_compile.cpp_pybinding(['float*', 'const float*', 'const float*', 'const float*', 'const float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
const float* in_ptr3,
const float* in_ptr4)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1024L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(128L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>((32L*x0) + (32768L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(32L)))) + (static_cast<int64_t>(x1) % static_cast<int64_t>(32L))), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x1), 16);
auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x1), 16);
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(x1), 16);
auto tmp7 = at::vec::Vectorized<float>::loadu(in_ptr4 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 - tmp1;
auto tmp4 = tmp2 * tmp3;
auto tmp6 = tmp4 * tmp5;
auto tmp8 = tmp6 + tmp7;
auto tmp9 = decltype(tmp8)(1)/(decltype(tmp8)(1) + tmp8.neg().exp());
auto tmp10 = tmp8 * tmp9;
tmp10.store(in_out_ptr0 + static_cast<int64_t>(x1 + (128L*x0)));
}
}
}
}
''')
cpp_fused_silu_17 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(524288L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_mean_18 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (256L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(256.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_mul_19 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x1 + (256L*x0)), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x1 + (256L*x0)));
}
}
}
}
''')
cpp_fused_silu_20 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(262144L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_mean_21 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (256L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(256.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_mul_22 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x1 + (256L*x0)), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 * tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x1 + (256L*x0)));
}
}
}
}
''')
cpp_fused_silu_23 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(262144L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_clone_24 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L); x1+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(16L); x2+=static_cast<int64_t>(1L))
{
for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(64L); x3+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (64L*x0) + (768L*x1) + (12288L*x2)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x3 + (64L*x2) + (1024L*x1) + (16384L*x0)));
}
}
}
}
}
}
''')
cpp_fused_clone_25 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(1L))
{
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(64L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (64L*x0) + (768L*x1)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x2 + (64L*x1) + (16384L*x0)));
}
}
}
}
}
''')
cpp_fused__softmax_add_mul_26 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0,
float* out_ptr1,
float* out_ptr2,
float* out_ptr3)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(256L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (256L*x1) + (65536L*x0)), 16);
auto tmp1 = static_cast<float>(0.125);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 * tmp2;
auto tmp4 = 15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L)));
auto tmp5 = c10::convert<int64_t>(tmp4);
auto tmp6 = static_cast<int64_t>(512);
auto tmp7 = tmp5 < tmp6;
auto tmp8 = [&]
{
auto tmp9 = static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L);
auto tmp10 = c10::convert<int64_t>(tmp9);
auto tmp11 = static_cast<int64_t>(31);
auto tmp12 = tmp10 < tmp11;
auto tmp14 = tmp12 & tmp7;
auto tmp13 = [&]
{
auto tmp15 = in_ptr1[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))), static_cast<int64_t>(32L)))) + (496L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L)))];
return tmp15;
}
;
auto tmp16 = tmp12 ? tmp13() : static_cast<float>(0.0);
return tmp16;
}
;
auto tmp17 = tmp7 ? tmp8() : static_cast<float>(0.0);
auto tmp18 = 15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L));
auto tmp19 = c10::convert<int64_t>(tmp18);
auto tmp20 = at::vec::VectorizedN<int64_t,2>::arange(tmp19, 1);
auto tmp21 = at::vec::VectorizedN<int64_t,2>(tmp6);
auto tmp22 = at::vec::VecMask<int64_t,2>(tmp20 < tmp21);
auto tmp23 = [&]
{
auto tmp24 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp22.is_masked(x2_inner))
{
tmpbuf[x2_inner] = static_cast<int64_t>(static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L));
}
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp25 = static_cast<int64_t>(31);
auto tmp26 = at::vec::VectorizedN<int64_t,2>(tmp25);
auto tmp27 = at::vec::VecMask<int64_t,2>(tmp24 < tmp26);
auto tmp29 = tmp27 & tmp22;
auto tmp28 = [&]
{
auto tmp30 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp29.is_masked(x2_inner))
{
tmpbuf[x2_inner] = in_ptr2[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(16L)))), static_cast<int64_t>(32L)))) + (496L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
}
()
;
return tmp30;
}
;
auto tmp33 =
[&]
{
if (tmp29.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp31 = tmp28();
auto tmp32 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp31)::blendv(tmp32, tmp31, tmp29.template cast<float,1>());
}
}
()
;
return tmp33;
}
;
auto tmp36 =
[&]
{
if (tmp22.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp34 = tmp23();
auto tmp35 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp34)::blendv(tmp35, tmp34, tmp22.template cast<float,1>());
}
}
()
;
auto tmp37 = at::vec::Vectorized<float>(tmp17);
auto tmp38 = tmp37 + tmp36;
auto tmp39 = tmp3 + tmp38;
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp39);
}
tmp_acc0 = max_propagate_nan(tmp_acc0, at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return at::vec::maximum(x, y); }, tmp_acc0_vec));
out_ptr0[static_cast<int64_t>(x1 + (256L*x0))] = static_cast<float>(tmp_acc0);
}
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(256L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (256L*x1) + (65536L*x0)), 16);
auto tmp40 = out_ptr0[static_cast<int64_t>(x1 + (256L*x0))];
auto tmp1 = static_cast<float>(0.125);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 * tmp2;
auto tmp4 = 15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L)));
auto tmp5 = c10::convert<int64_t>(tmp4);
auto tmp6 = static_cast<int64_t>(512);
auto tmp7 = tmp5 < tmp6;
auto tmp8 = [&]
{
auto tmp9 = static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L);
auto tmp10 = c10::convert<int64_t>(tmp9);
auto tmp11 = static_cast<int64_t>(31);
auto tmp12 = tmp10 < tmp11;
auto tmp14 = tmp12 & tmp7;
auto tmp13 = [&]
{
auto tmp15 = in_ptr1[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))), static_cast<int64_t>(32L)))) + (496L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L)))];
return tmp15;
}
;
auto tmp16 = tmp12 ? tmp13() : static_cast<float>(0.0);
return tmp16;
}
;
auto tmp17 = tmp7 ? tmp8() : static_cast<float>(0.0);
auto tmp18 = 15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L));
auto tmp19 = c10::convert<int64_t>(tmp18);
auto tmp20 = at::vec::VectorizedN<int64_t,2>::arange(tmp19, 1);
auto tmp21 = at::vec::VectorizedN<int64_t,2>(tmp6);
auto tmp22 = at::vec::VecMask<int64_t,2>(tmp20 < tmp21);
auto tmp23 = [&]
{
auto tmp24 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp22.is_masked(x2_inner))
{
tmpbuf[x2_inner] = static_cast<int64_t>(static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L));
}
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp25 = static_cast<int64_t>(31);
auto tmp26 = at::vec::VectorizedN<int64_t,2>(tmp25);
auto tmp27 = at::vec::VecMask<int64_t,2>(tmp24 < tmp26);
auto tmp29 = tmp27 & tmp22;
auto tmp28 = [&]
{
auto tmp30 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp29.is_masked(x2_inner))
{
tmpbuf[x2_inner] = in_ptr2[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(16L)))), static_cast<int64_t>(32L)))) + (496L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
}
()
;
return tmp30;
}
;
auto tmp33 =
[&]
{
if (tmp29.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp31 = tmp28();
auto tmp32 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp31)::blendv(tmp32, tmp31, tmp29.template cast<float,1>());
}
}
()
;
return tmp33;
}
;
auto tmp36 =
[&]
{
if (tmp22.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp34 = tmp23();
auto tmp35 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp34)::blendv(tmp35, tmp34, tmp22.template cast<float,1>());
}
}
()
;
auto tmp37 = at::vec::Vectorized<float>(tmp17);
auto tmp38 = tmp37 + tmp36;
auto tmp39 = tmp3 + tmp38;
auto tmp41 = at::vec::Vectorized<float>(tmp40);
auto tmp42 = tmp39 - tmp41;
auto tmp43 = tmp42.exp();
tmp43.store(out_ptr1 + static_cast<int64_t>(x2 + (256L*x1) + (65536L*x0)));
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1024L); x0+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (256L*x0)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr2[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0);
}
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (256L*x0)), 16);
auto tmp1 = out_ptr2[static_cast<int64_t>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(out_ptr3 + static_cast<int64_t>(x1 + (256L*x0)));
}
}
}
}
''')
cpp_fused__native_batch_norm_legit_no_training_silu_27 = async_compile.cpp_pybinding(['float*', 'const float*', 'const float*', 'const float*', 'const float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
const float* in_ptr3,
const float* in_ptr4)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>((64L*x0) + (16384L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(64L)))) + (static_cast<int64_t>(x1) % static_cast<int64_t>(64L))), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x1), 16);
auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x1), 16);
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(x1), 16);
auto tmp7 = at::vec::Vectorized<float>::loadu(in_ptr4 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 - tmp1;
auto tmp4 = tmp2 * tmp3;
auto tmp6 = tmp4 * tmp5;
auto tmp8 = tmp6 + tmp7;
auto tmp9 = decltype(tmp8)(1)/(decltype(tmp8)(1) + tmp8.neg().exp());
auto tmp10 = tmp8 * tmp9;
tmp10.store(in_out_ptr0 + static_cast<int64_t>(x1 + (256L*x0)));
}
}
}
}
''')
cpp_fused_silu_28 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(262144L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_clone_29 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L); x1+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(16L); x2+=static_cast<int64_t>(1L))
{
for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(128L); x3+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (128L*x0) + (1536L*x1) + (24576L*x2)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x3 + (128L*x2) + (2048L*x1) + (32768L*x0)));
}
}
}
}
}
}
''')
cpp_fused_clone_30 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(1L))
{
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(128L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (128L*x0) + (1536L*x1)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x2 + (128L*x1) + (32768L*x0)));
}
}
}
}
}
''')
cpp_fused__softmax_add_mul_31 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0,
float* out_ptr1,
float* out_ptr2,
float* out_ptr3)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(256L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (256L*x1) + (65536L*x0)), 16);
auto tmp1 = static_cast<float>(0.08838834764831845);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 * tmp2;
auto tmp4 = 15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L)));
auto tmp5 = c10::convert<int64_t>(tmp4);
auto tmp6 = static_cast<int64_t>(512);
auto tmp7 = tmp5 < tmp6;
auto tmp8 = [&]
{
auto tmp9 = static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L);
auto tmp10 = c10::convert<int64_t>(tmp9);
auto tmp11 = static_cast<int64_t>(31);
auto tmp12 = tmp10 < tmp11;
auto tmp14 = tmp12 & tmp7;
auto tmp13 = [&]
{
auto tmp15 = in_ptr1[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))), static_cast<int64_t>(32L)))) + (496L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L)))];
return tmp15;
}
;
auto tmp16 = tmp12 ? tmp13() : static_cast<float>(0.0);
return tmp16;
}
;
auto tmp17 = tmp7 ? tmp8() : static_cast<float>(0.0);
auto tmp18 = 15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L));
auto tmp19 = c10::convert<int64_t>(tmp18);
auto tmp20 = at::vec::VectorizedN<int64_t,2>::arange(tmp19, 1);
auto tmp21 = at::vec::VectorizedN<int64_t,2>(tmp6);
auto tmp22 = at::vec::VecMask<int64_t,2>(tmp20 < tmp21);
auto tmp23 = [&]
{
auto tmp24 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp22.is_masked(x2_inner))
{
tmpbuf[x2_inner] = static_cast<int64_t>(static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L));
}
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp25 = static_cast<int64_t>(31);
auto tmp26 = at::vec::VectorizedN<int64_t,2>(tmp25);
auto tmp27 = at::vec::VecMask<int64_t,2>(tmp24 < tmp26);
auto tmp29 = tmp27 & tmp22;
auto tmp28 = [&]
{
auto tmp30 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp29.is_masked(x2_inner))
{
tmpbuf[x2_inner] = in_ptr2[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(16L)))), static_cast<int64_t>(32L)))) + (496L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
}
()
;
return tmp30;
}
;
auto tmp33 =
[&]
{
if (tmp29.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp31 = tmp28();
auto tmp32 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp31)::blendv(tmp32, tmp31, tmp29.template cast<float,1>());
}
}
()
;
return tmp33;
}
;
auto tmp36 =
[&]
{
if (tmp22.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp34 = tmp23();
auto tmp35 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp34)::blendv(tmp35, tmp34, tmp22.template cast<float,1>());
}
}
()
;
auto tmp37 = at::vec::Vectorized<float>(tmp17);
auto tmp38 = tmp37 + tmp36;
auto tmp39 = tmp3 + tmp38;
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp39);
}
tmp_acc0 = max_propagate_nan(tmp_acc0, at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return at::vec::maximum(x, y); }, tmp_acc0_vec));
out_ptr0[static_cast<int64_t>(x1 + (256L*x0))] = static_cast<float>(tmp_acc0);
}
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(256L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (256L*x1) + (65536L*x0)), 16);
auto tmp40 = out_ptr0[static_cast<int64_t>(x1 + (256L*x0))];
auto tmp1 = static_cast<float>(0.08838834764831845);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 * tmp2;
auto tmp4 = 15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L)));
auto tmp5 = c10::convert<int64_t>(tmp4);
auto tmp6 = static_cast<int64_t>(512);
auto tmp7 = tmp5 < tmp6;
auto tmp8 = [&]
{
auto tmp9 = static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L);
auto tmp10 = c10::convert<int64_t>(tmp9);
auto tmp11 = static_cast<int64_t>(31);
auto tmp12 = tmp10 < tmp11;
auto tmp14 = tmp12 & tmp7;
auto tmp13 = [&]
{
auto tmp15 = in_ptr1[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))), static_cast<int64_t>(32L)))) + (496L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(16L))))) % static_cast<int64_t>(32L)))];
return tmp15;
}
;
auto tmp16 = tmp12 ? tmp13() : static_cast<float>(0.0);
return tmp16;
}
;
auto tmp17 = tmp7 ? tmp8() : static_cast<float>(0.0);
auto tmp18 = 15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(16L));
auto tmp19 = c10::convert<int64_t>(tmp18);
auto tmp20 = at::vec::VectorizedN<int64_t,2>::arange(tmp19, 1);
auto tmp21 = at::vec::VectorizedN<int64_t,2>(tmp6);
auto tmp22 = at::vec::VecMask<int64_t,2>(tmp20 < tmp21);
auto tmp23 = [&]
{
auto tmp24 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp22.is_masked(x2_inner))
{
tmpbuf[x2_inner] = static_cast<int64_t>(static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L));
}
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp25 = static_cast<int64_t>(31);
auto tmp26 = at::vec::VectorizedN<int64_t,2>(tmp25);
auto tmp27 = at::vec::VecMask<int64_t,2>(tmp24 < tmp26);
auto tmp29 = tmp27 & tmp22;
auto tmp28 = [&]
{
auto tmp30 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp29.is_masked(x2_inner))
{
tmpbuf[x2_inner] = in_ptr2[static_cast<int64_t>((31L*(c10::div_floor_integer(static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(16L)))), static_cast<int64_t>(32L)))) + (496L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(16L)))) + (7936L*x0) + (static_cast<int64_t>((15L + (31L*(static_cast<int64_t>(x1) % static_cast<int64_t>(16L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(16L)))) % static_cast<int64_t>(32L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
}
()
;
return tmp30;
}
;
auto tmp33 =
[&]
{
if (tmp29.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp31 = tmp28();
auto tmp32 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp31)::blendv(tmp32, tmp31, tmp29.template cast<float,1>());
}
}
()
;
return tmp33;
}
;
auto tmp36 =
[&]
{
if (tmp22.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp34 = tmp23();
auto tmp35 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp34)::blendv(tmp35, tmp34, tmp22.template cast<float,1>());
}
}
()
;
auto tmp37 = at::vec::Vectorized<float>(tmp17);
auto tmp38 = tmp37 + tmp36;
auto tmp39 = tmp3 + tmp38;
auto tmp41 = at::vec::Vectorized<float>(tmp40);
auto tmp42 = tmp39 - tmp41;
auto tmp43 = tmp42.exp();
tmp43.store(out_ptr1 + static_cast<int64_t>(x2 + (256L*x1) + (65536L*x0)));
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1024L); x0+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (256L*x0)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr2[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0);
}
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(256L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (256L*x0)), 16);
auto tmp1 = out_ptr2[static_cast<int64_t>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(out_ptr3 + static_cast<int64_t>(x1 + (256L*x0)));
}
}
}
}
''')
cpp_fused__native_batch_norm_legit_no_training_avg_pool2d_silu_32 = async_compile.cpp_pybinding(['float*', 'const float*', 'const float*', 'const float*', 'const float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
const float* in_ptr3,
const float* in_ptr4)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(8L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(8L); x1+=static_cast<int64_t>(1L))
{
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(512L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>((256L*x1) + (4096L*x0) + (32768L*(c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(128L)))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(128L))), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(128L + (256L*x1) + (4096L*x0) + (32768L*(c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(128L)))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(128L))), 16);
auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(2048L + (256L*x1) + (4096L*x0) + (32768L*(c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(128L)))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(128L))), 16);
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(2176L + (256L*x1) + (4096L*x0) + (32768L*(c10::div_floor_integer(static_cast<int64_t>(x2), static_cast<int64_t>(128L)))) + (static_cast<int64_t>(x2) % static_cast<int64_t>(128L))), 16);
auto tmp10 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x2), 16);
auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x2), 16);
auto tmp14 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(x2), 16);
auto tmp16 = at::vec::Vectorized<float>::loadu(in_ptr4 + static_cast<int64_t>(x2), 16);
auto tmp2 = tmp1 + tmp0;
auto tmp4 = tmp3 + tmp2;
auto tmp6 = tmp5 + tmp4;
auto tmp7 = static_cast<float>(0.25);
auto tmp8 = at::vec::Vectorized<float>(tmp7);
auto tmp9 = tmp6 * tmp8;
auto tmp11 = tmp9 - tmp10;
auto tmp13 = tmp11 * tmp12;
auto tmp15 = tmp13 * tmp14;
auto tmp17 = tmp15 + tmp16;
auto tmp18 = decltype(tmp17)(1)/(decltype(tmp17)(1) + tmp17.neg().exp());
auto tmp19 = tmp17 * tmp18;
tmp19.store(in_out_ptr0 + static_cast<int64_t>(x2 + (512L*x1) + (4096L*x0)));
}
}
}
}
}
''')
cpp_fused_silu_33 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(98304L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
cpp_fused_clone_34 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(8L); x1+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(8L); x2+=static_cast<int64_t>(1L))
{
for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(128L); x3+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (128L*x0) + (1536L*x1) + (12288L*x2)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x3 + (128L*x2) + (1024L*x1) + (8192L*x0)));
}
}
}
}
}
}
''')
cpp_fused_clone_35 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(1L))
{
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(128L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (128L*x0) + (1536L*x1)), 16);
tmp0.store(out_ptr0 + static_cast<int64_t>(x2 + (128L*x1) + (8192L*x0)));
}
}
}
}
}
''')
cpp_fused__softmax_add_mul_36 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0,
float* out_ptr1,
float* out_ptr2,
float* out_ptr3)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(4L); x0+=static_cast<int64_t>(1L))
{
#pragma GCC ivdep
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(64L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (64L*x1) + (4096L*x0)), 16);
auto tmp1 = static_cast<float>(0.08838834764831845);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 * tmp2;
auto tmp4 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
tmpbuf[x2_inner] = static_cast<int64_t>(7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>((x2 + x2_inner)), static_cast<int64_t>(8L))));
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp5 = static_cast<int64_t>(128);
auto tmp6 = at::vec::VectorizedN<int64_t,2>(tmp5);
auto tmp7 = at::vec::VecMask<int64_t,2>(tmp4 < tmp6);
auto tmp8 = [&]
{
auto tmp9 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp7.is_masked(x2_inner))
{
tmpbuf[x2_inner] = static_cast<int64_t>(static_cast<int64_t>((7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>((x2 + x2_inner)), static_cast<int64_t>(8L))))) % static_cast<int64_t>(16L));
}
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp10 = static_cast<int64_t>(15);
auto tmp11 = at::vec::VectorizedN<int64_t,2>(tmp10);
auto tmp12 = at::vec::VecMask<int64_t,2>(tmp9 < tmp11);
auto tmp14 = tmp12 & tmp7;
auto tmp13 = [&]
{
auto tmp15 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp14.is_masked(x2_inner))
{
tmpbuf[x2_inner] = in_ptr1[static_cast<int64_t>((15L*(c10::div_floor_integer(static_cast<int64_t>((7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>((x2 + x2_inner)), static_cast<int64_t>(8L))))), static_cast<int64_t>(16L)))) + (120L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (960L*x0) + (static_cast<int64_t>((7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>((x2 + x2_inner)), static_cast<int64_t>(8L))))) % static_cast<int64_t>(16L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
}
()
;
return tmp15;
}
;
auto tmp18 =
[&]
{
if (tmp14.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp16 = tmp13();
auto tmp17 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp16)::blendv(tmp17, tmp16, tmp14.template cast<float,1>());
}
}
()
;
return tmp18;
}
;
auto tmp21 =
[&]
{
if (tmp7.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp19 = tmp8();
auto tmp20 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp19)::blendv(tmp20, tmp19, tmp7.template cast<float,1>());
}
}
()
;
auto tmp22 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
tmpbuf[x2_inner] = static_cast<int64_t>(7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(8L)));
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp23 = at::vec::VecMask<int64_t,2>(tmp22 < tmp6);
auto tmp24 = [&]
{
auto tmp25 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp23.is_masked(x2_inner))
{
tmpbuf[x2_inner] = static_cast<int64_t>(static_cast<int64_t>((7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(8L)))) % static_cast<int64_t>(16L));
}
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp26 = static_cast<int64_t>(15);
auto tmp27 = at::vec::VectorizedN<int64_t,2>(tmp26);
auto tmp28 = at::vec::VecMask<int64_t,2>(tmp25 < tmp27);
auto tmp30 = tmp28 & tmp23;
auto tmp29 = [&]
{
auto tmp31 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp30.is_masked(x2_inner))
{
tmpbuf[x2_inner] = in_ptr2[static_cast<int64_t>((15L*(c10::div_floor_integer(static_cast<int64_t>((7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(8L)))), static_cast<int64_t>(16L)))) + (120L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (960L*x0) + (static_cast<int64_t>((7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(8L)))) % static_cast<int64_t>(16L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
}
()
;
return tmp31;
}
;
auto tmp34 =
[&]
{
if (tmp30.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp32 = tmp29();
auto tmp33 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp32)::blendv(tmp33, tmp32, tmp30.template cast<float,1>());
}
}
()
;
return tmp34;
}
;
auto tmp37 =
[&]
{
if (tmp23.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp35 = tmp24();
auto tmp36 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp35)::blendv(tmp36, tmp35, tmp23.template cast<float,1>());
}
}
()
;
auto tmp38 = tmp21 + tmp37;
auto tmp39 = tmp3 + tmp38;
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp39);
}
tmp_acc0 = max_propagate_nan(tmp_acc0, at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return at::vec::maximum(x, y); }, tmp_acc0_vec));
out_ptr0[static_cast<int64_t>(x1 + (64L*x0))] = static_cast<float>(tmp_acc0);
}
for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(64L); x2+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + (64L*x1) + (4096L*x0)), 16);
auto tmp40 = out_ptr0[static_cast<int64_t>(x1 + (64L*x0))];
auto tmp1 = static_cast<float>(0.08838834764831845);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 * tmp2;
auto tmp4 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
tmpbuf[x2_inner] = static_cast<int64_t>(7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>((x2 + x2_inner)), static_cast<int64_t>(8L))));
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp5 = static_cast<int64_t>(128);
auto tmp6 = at::vec::VectorizedN<int64_t,2>(tmp5);
auto tmp7 = at::vec::VecMask<int64_t,2>(tmp4 < tmp6);
auto tmp8 = [&]
{
auto tmp9 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp7.is_masked(x2_inner))
{
tmpbuf[x2_inner] = static_cast<int64_t>(static_cast<int64_t>((7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>((x2 + x2_inner)), static_cast<int64_t>(8L))))) % static_cast<int64_t>(16L));
}
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp10 = static_cast<int64_t>(15);
auto tmp11 = at::vec::VectorizedN<int64_t,2>(tmp10);
auto tmp12 = at::vec::VecMask<int64_t,2>(tmp9 < tmp11);
auto tmp14 = tmp12 & tmp7;
auto tmp13 = [&]
{
auto tmp15 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp14.is_masked(x2_inner))
{
tmpbuf[x2_inner] = in_ptr1[static_cast<int64_t>((15L*(c10::div_floor_integer(static_cast<int64_t>((7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>((x2 + x2_inner)), static_cast<int64_t>(8L))))), static_cast<int64_t>(16L)))) + (120L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (960L*x0) + (static_cast<int64_t>((7L + (15L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (c10::div_floor_integer(static_cast<int64_t>((x2 + x2_inner)), static_cast<int64_t>(8L))))) % static_cast<int64_t>(16L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
}
()
;
return tmp15;
}
;
auto tmp18 =
[&]
{
if (tmp14.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp16 = tmp13();
auto tmp17 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp16)::blendv(tmp17, tmp16, tmp14.template cast<float,1>());
}
}
()
;
return tmp18;
}
;
auto tmp21 =
[&]
{
if (tmp7.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp19 = tmp8();
auto tmp20 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp19)::blendv(tmp20, tmp19, tmp7.template cast<float,1>());
}
}
()
;
auto tmp22 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
tmpbuf[x2_inner] = static_cast<int64_t>(7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(8L)));
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp23 = at::vec::VecMask<int64_t,2>(tmp22 < tmp6);
auto tmp24 = [&]
{
auto tmp25 =
[&]
{
__at_align__ std::array<int64_t, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp23.is_masked(x2_inner))
{
tmpbuf[x2_inner] = static_cast<int64_t>(static_cast<int64_t>((7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(8L)))) % static_cast<int64_t>(16L));
}
}
return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
}
()
;
auto tmp26 = static_cast<int64_t>(15);
auto tmp27 = at::vec::VectorizedN<int64_t,2>(tmp26);
auto tmp28 = at::vec::VecMask<int64_t,2>(tmp25 < tmp27);
auto tmp30 = tmp28 & tmp23;
auto tmp29 = [&]
{
auto tmp31 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x2_inner = 0; x2_inner < 16; x2_inner++)
{
if (tmp30.is_masked(x2_inner))
{
tmpbuf[x2_inner] = in_ptr2[static_cast<int64_t>((15L*(c10::div_floor_integer(static_cast<int64_t>((7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(8L)))), static_cast<int64_t>(16L)))) + (120L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(8L)))) + (960L*x0) + (static_cast<int64_t>((7L + (15L*(static_cast<int64_t>(x1) % static_cast<int64_t>(8L))) + (static_cast<int64_t>((x2 + x2_inner)) % static_cast<int64_t>(8L)))) % static_cast<int64_t>(16L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
}
()
;
return tmp31;
}
;
auto tmp34 =
[&]
{
if (tmp30.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp32 = tmp29();
auto tmp33 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp32)::blendv(tmp33, tmp32, tmp30.template cast<float,1>());
}
}
()
;
return tmp34;
}
;
auto tmp37 =
[&]
{
if (tmp23.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
auto tmp35 = tmp24();
auto tmp36 = at::vec::Vectorized<float>(static_cast<float>(0.0));
return decltype(tmp35)::blendv(tmp36, tmp35, tmp23.template cast<float,1>());
}
}
()
;
auto tmp38 = tmp21 + tmp37;
auto tmp39 = tmp3 + tmp38;
auto tmp41 = at::vec::Vectorized<float>(tmp40);
auto tmp42 = tmp39 - tmp41;
auto tmp43 = tmp42.exp();
tmp43.store(out_ptr1 + static_cast<int64_t>(x2 + (64L*x1) + (4096L*x0)));
}
}
}
}
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(256L); x0+=static_cast<int64_t>(1L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (64L*x0)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr2[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0);
}
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + static_cast<int64_t>(x1 + (64L*x0)), 16);
auto tmp1 = out_ptr2[static_cast<int64_t>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(out_ptr3 + static_cast<int64_t>(x1 + (64L*x0)));
}
}
}
}
''')
# Fused inference batch-norm + SiLU kernel for a 64x512 activation block.
# For each of 64 rows it gathers 512 channel values stored as four
# 128-channel groups (group stride 8192, row stride 128 in in_ptr0),
# computes (x - in_ptr1) * in_ptr2 * in_ptr3 + in_ptr4 per channel
# (presumably mean / rsqrt(var+eps) / weight / bias, per the kernel name —
# confirm against the graph), then applies SiLU x * sigmoid(x) and stores
# the result contiguously into in_out_ptr0. Auto-generated TorchInductor
# C++ source; do not hand-edit the string body.
cpp_fused__native_batch_norm_legit_no_training_silu_37 = async_compile.cpp_pybinding(['float*', 'const float*', 'const float*', 'const float*', 'const float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
const float* in_ptr3,
const float* in_ptr4)
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(64L); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(512L); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>((128L*x0) + (8192L*(c10::div_floor_integer(static_cast<int64_t>(x1), static_cast<int64_t>(128L)))) + (static_cast<int64_t>(x1) % static_cast<int64_t>(128L))), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x1), 16);
auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x1), 16);
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(x1), 16);
auto tmp7 = at::vec::Vectorized<float>::loadu(in_ptr4 + static_cast<int64_t>(x1), 16);
auto tmp2 = tmp0 - tmp1;
auto tmp4 = tmp2 * tmp3;
auto tmp6 = tmp4 * tmp5;
auto tmp8 = tmp6 + tmp7;
auto tmp9 = decltype(tmp8)(1)/(decltype(tmp8)(1) + tmp8.neg().exp());
auto tmp10 = tmp8 * tmp9;
tmp10.store(in_out_ptr0 + static_cast<int64_t>(x1 + (512L*x0)));
}
}
}
}
''')
# Elementwise SiLU over a contiguous 98304-float buffer, 16 lanes per
# vector step: out_ptr0[i] = in_ptr0[i] * sigmoid(in_ptr0[i]).
# 98304 = 1536 channels * 8 * 8 spatial at the call site in call().
# Auto-generated TorchInductor C++ source; do not hand-edit the string body.
cpp_fused_silu_38 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(98304L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = decltype(tmp0)(1)/(decltype(tmp0)(1) + tmp0.neg().exp());
auto tmp2 = tmp0 * tmp1;
tmp2.store(out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
# Global-average-pool kernel: for each of 1280 channels, sums 64 values
# (channel-last layout: in_ptr0[x0 + 1280*x1], x1 over the 8x8 spatial grid
# at the call site) into in_out_ptr0, then a second pass divides by 64.0
# in place to produce the mean. Auto-generated TorchInductor C++ source;
# do not hand-edit the string body.
cpp_fused_mean_39 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
auto out_ptr0 = in_out_ptr0;
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1280L); x0+=static_cast<int64_t>(16L))
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(64L); x1+=static_cast<int64_t>(1L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0 + (1280L*x1)), 16);
tmp_acc0_vec = tmp_acc0_vec + tmp0;
}
tmp_acc0_vec.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1280L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = static_cast<float>(64.0);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 / tmp2;
tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
}
}
''')
# In-place bias add for the 1000-element classifier output:
# in_out_ptr0[i] += in_ptr0[i]. The first 992 elements are processed with
# 16-wide vector loads; the 992..1000 remainder uses 8-wide loads so the
# non-multiple-of-16 tail stays vectorized. Auto-generated TorchInductor
# C++ source; do not hand-edit the string body.
cpp_fused_addmm_40 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include "/tmp/torchinductor_leslie/j2/cj22tgrdgh42wbunl7gdptg2lintcziox2kmr7rdbcc6n2njrhgx.h"
extern "C" void kernel(float* in_out_ptr0,
const float* in_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(992L); x0+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 16);
auto tmp2 = tmp0 + tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x0));
}
for(int64_t x0=static_cast<int64_t>(992L); x0<static_cast<int64_t>(1000L); x0+=static_cast<int64_t>(8L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<int64_t>(x0), 8);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), 8);
auto tmp2 = tmp0 + tmp1;
tmp2.store(in_out_ptr0 + static_cast<int64_t>(x0), 8);
}
}
}
''')
# Block until every cpp_pybinding kernel defined above has finished
# compiling (binding the callables into this module's globals), then drop
# the compiler handle — it is not needed once the kernels are resolved.
async_compile.wait(globals())
del async_compile
def call(args):
    """Run the compiled inference graph on a single (1, 3, 256, 256) input.

    ``args`` is a single-element list holding the input tensor; the list is
    cleared (ownership transferred) so intermediates can be freed eagerly.
    Returns a 1-tuple with the final (1, 1000) classifier output.

    Relies on the module-level ``_frozen_param*`` constants being populated
    before invocation. Statement order and the ``del``/``# reuse`` buffer
    aliasing are load-bearing (generated by TorchInductor's memory planner);
    do not reorder.
    """
    arg224_1, = args
    args.clear()
    assert_size_stride(arg224_1, (1, 3, 256, 256), (196608, 65536, 256, 1))
    # Repack the input into channels-last strides with a fused SiLU pass.
    buf0 = empty_strided_cpu((1, 3, 256, 256), (196608, 1, 768, 3), torch.float32)
    cpp_fused_silu_0(arg224_1, buf0)
    del arg224_1
    # Stem: conv+swish stages reducing spatial size 256 -> 128 -> 64.
    buf1 = torch.ops.mkldnn._convolution_pointwise.default(buf0, _frozen_param598, _frozen_param564, [1, 1], [2, 2], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf1, (1, 24, 128, 128), (393216, 1, 3072, 24))
    del buf0
    buf2 = torch.ops.mkldnn._convolution_pointwise.default(buf1, _frozen_param599, _frozen_param565, [1, 1], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf2, (1, 32, 128, 128), (524288, 1, 4096, 32))
    del buf1
    buf3 = torch.ops.mkldnn._convolution_pointwise.default(buf2, _frozen_param600, _frozen_param566, [1, 1], [2, 2], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf3, (1, 64, 64, 64), (262144, 1, 4096, 64))
    del buf2
    buf4 = torch.ops.mkldnn._convolution_pointwise.default(buf3, _frozen_param601, _frozen_param567, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf4, (1, 64, 64, 64), (262144, 1, 4096, 64))
    buf5 = torch.ops.mkldnn._convolution_pointwise.default(buf4, _frozen_param602, _frozen_param568, [1, 1], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf5, (1, 64, 64, 64), (262144, 1, 4096, 64))
    # Squeeze-excitation-style gating: global mean -> 1x1 conv+relu ->
    # 1x1 conv+sigmoid -> channel-wise multiply back into the activation.
    buf6 = empty_strided_cpu((1, 64, 1, 1), (64, 1, 64, 64), torch.float32)
    buf7 = reinterpret_tensor(buf6, (1, 64, 1, 1), (64, 1, 1, 1), 0); del buf6  # reuse
    buf8 = empty_strided_cpu((8, ), (1, ), torch.float32)
    cpp_fused_mean_1(buf7, buf5, buf8)
    buf9 = torch.ops.mkldnn._convolution_pointwise.default(buf7, _frozen_param603, buf8, [0, 0], [1, 1], [1, 1], 1, 'relu', [None], '')
    assert_size_stride(buf9, (1, 8, 1, 1), (8, 1, 8, 8))
    del buf7
    del buf8
    buf10 = torch.ops.mkldnn._convolution_pointwise.default(buf9, _frozen_param604, _frozen_param18, [0, 0], [1, 1], [1, 1], 1, 'sigmoid', [None], '')
    assert_size_stride(buf10, (1, 64, 1, 1), (64, 1, 64, 64))
    buf11 = buf5; del buf5  # reuse
    cpp_fused_mul_2(buf11, buf10)
    # Expand to 256 channels and add the residual branch via a fused
    # conv+binary-add on buf3.
    buf12 = torch.ops.mkldnn._convolution_pointwise.default(buf11, _frozen_param605, _frozen_param569, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
    assert_size_stride(buf12, (1, 256, 64, 64), (1048576, 1, 16384, 256))
    del buf11
    buf13 = torch.ops.mkldnn._convolution_pointwise_.binary(buf12, buf3, _frozen_param606, _frozen_param570, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
    del buf3
    del buf4
    buf16 = empty_strided_cpu((1, 256, 64, 64), (1048576, 1, 16384, 256), torch.float32)
    cpp_fused_silu_3(buf12, buf16)
    # Second bottleneck block at 64x64 (conv+swish x2, SE gate, residual add).
    buf17 = torch.ops.mkldnn._convolution_pointwise.default(buf16, _frozen_param607, _frozen_param571, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf17, (1, 64, 64, 64), (262144, 1, 4096, 64))
    buf18 = torch.ops.mkldnn._convolution_pointwise.default(buf17, _frozen_param608, _frozen_param572, [1, 1], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf18, (1, 64, 64, 64), (262144, 1, 4096, 64))
    buf19 = buf10; del buf10  # reuse
    buf20 = reinterpret_tensor(buf19, (1, 64, 1, 1), (64, 1, 1, 1), 0); del buf19  # reuse
    buf21 = reinterpret_tensor(buf9, (8, ), (1, ), 0); del buf9  # reuse
    cpp_fused_mean_4(buf20, buf18, buf21)
    buf22 = torch.ops.mkldnn._convolution_pointwise.default(buf20, _frozen_param609, buf21, [0, 0], [1, 1], [1, 1], 1, 'relu', [None], '')
    assert_size_stride(buf22, (1, 8, 1, 1), (8, 1, 8, 8))
    del buf20
    del buf21
    buf23 = torch.ops.mkldnn._convolution_pointwise.default(buf22, _frozen_param610, _frozen_param34, [0, 0], [1, 1], [1, 1], 1, 'sigmoid', [None], '')
    assert_size_stride(buf23, (1, 64, 1, 1), (64, 1, 64, 64))
    buf24 = buf18; del buf18  # reuse
    cpp_fused_mul_5(buf24, buf23)
    del buf23
    buf25 = torch.ops.mkldnn._convolution_pointwise_.binary(buf16, buf24, _frozen_param611, _frozen_param573, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
    buf28 = buf12; del buf12  # reuse
    cpp_fused_silu_6(buf16, buf28)
    del buf16
    # Downsample stage to 32x32 with a 128-channel bottleneck + SE gate.
    buf29 = torch.ops.mkldnn._convolution_pointwise.default(buf28, _frozen_param612, _frozen_param574, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf29, (1, 128, 64, 64), (524288, 1, 8192, 128))
    buf30 = torch.ops.mkldnn._convolution_pointwise.default(buf29, _frozen_param613, _frozen_param575, [1, 1], [2, 2], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf30, (1, 128, 32, 32), (131072, 1, 4096, 128))
    buf31 = empty_strided_cpu((1, 128, 1, 1), (128, 1, 128, 128), torch.float32)
    buf32 = reinterpret_tensor(buf31, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf31  # reuse
    buf33 = reinterpret_tensor(buf22, (8, ), (1, ), 0); del buf22  # reuse
    cpp_fused_mean_7(buf32, buf30, buf33)
    buf34 = torch.ops.mkldnn._convolution_pointwise.default(buf32, _frozen_param614, buf33, [0, 0], [1, 1], [1, 1], 1, 'relu', [None], '')
    assert_size_stride(buf34, (1, 8, 1, 1), (8, 1, 8, 8))
    del buf32
    del buf33
    buf35 = torch.ops.mkldnn._convolution_pointwise.default(buf34, _frozen_param615, _frozen_param47, [0, 0], [1, 1], [1, 1], 1, 'sigmoid', [None], '')
    assert_size_stride(buf35, (1, 128, 1, 1), (128, 1, 128, 128))
    buf36 = buf30; del buf30  # reuse
    cpp_fused_mul_8(buf36, buf35)
    buf37 = torch.ops.mkldnn._convolution_pointwise.default(buf36, _frozen_param616, _frozen_param576, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
    assert_size_stride(buf37, (1, 512, 32, 32), (524288, 1, 16384, 512))
    del buf36
    # Strided residual shortcut (stride 2) fused with the add.
    buf38 = torch.ops.mkldnn._convolution_pointwise_.binary(buf37, buf28, _frozen_param617, _frozen_param577, [0, 0], [2, 2], [1, 1], 1, 'add', 1.0, None, [None], None)
    del buf28
    buf41 = reinterpret_tensor(buf29, (1, 512, 32, 32), (524288, 1, 16384, 512), 0); del buf29  # reuse
    cpp_fused_silu_9(buf37, buf41)
    buf42 = torch.ops.mkldnn._convolution_pointwise.default(buf41, _frozen_param618, _frozen_param578, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf42, (1, 128, 32, 32), (131072, 1, 4096, 128))
    buf43 = torch.ops.mkldnn._convolution_pointwise.default(buf42, _frozen_param619, _frozen_param579, [1, 1], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf43, (1, 128, 32, 32), (131072, 1, 4096, 128))
    buf44 = buf35; del buf35  # reuse
    buf45 = reinterpret_tensor(buf44, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf44  # reuse
    buf46 = reinterpret_tensor(buf34, (8, ), (1, ), 0); del buf34  # reuse
    cpp_fused_mean_10(buf45, buf43, buf46)
    buf47 = torch.ops.mkldnn._convolution_pointwise.default(buf45, _frozen_param620, buf46, [0, 0], [1, 1], [1, 1], 1, 'relu', [None], '')
    assert_size_stride(buf47, (1, 8, 1, 1), (8, 1, 8, 8))
    del buf45
    del buf46
    buf48 = torch.ops.mkldnn._convolution_pointwise.default(buf47, _frozen_param621, _frozen_param63, [0, 0], [1, 1], [1, 1], 1, 'sigmoid', [None], '')
    assert_size_stride(buf48, (1, 128, 1, 1), (128, 1, 128, 128))
    del buf47
    buf49 = buf43; del buf43  # reuse
    cpp_fused_mul_11(buf49, buf48)
    del buf48
    buf50 = torch.ops.mkldnn._convolution_pointwise_.binary(buf41, buf49, _frozen_param622, _frozen_param580, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
    buf53 = buf37; del buf37  # reuse
    cpp_fused_silu_12(buf41, buf53)
    # First self-attention block at 32x32: buf55 packs q/k/v (3 x 128ch,
    # 4 heads of 32); bmm(q, k^T), relative-position terms via mkl linears,
    # softmax, bmm with v.
    buf54 = torch.ops.mkldnn._convolution_pointwise.default(buf53, _frozen_param623, _frozen_param581, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf54, (1, 128, 32, 32), (131072, 1, 4096, 128))
    buf55 = torch.ops.mkldnn._convolution_pointwise.default(buf54, _frozen_param624, None, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
    assert_size_stride(buf55, (1, 384, 32, 32), (393216, 1, 12288, 384))
    buf56 = empty_strided_cpu((4, 1024, 1024), (1048576, 1024, 1), torch.float32)
    # Topologically Sorted Source Nodes: [matmul], Original ATen: [aten.bmm]
    extern_kernels.bmm(reinterpret_tensor(buf55, (4, 1024, 32), (32, 384, 1), 0), reinterpret_tensor(buf55, (4, 32, 1024), (32, 1, 384), 128), out=buf56)
    buf57 = reinterpret_tensor(buf49, (4, 32, 32, 32), (32768, 1024, 32, 1), 0); del buf49  # reuse
    cpp_fused_clone_13(buf55, buf57)
    buf58 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf57, (4096, 32), (32, 1), 0), _frozen_param628, _frozen_param627, None, 4096)
    buf59 = buf57; del buf57  # reuse
    cpp_fused_clone_14(buf55, buf59)
    buf60 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf59, (4096, 32), (32, 1), 0), _frozen_param626, _frozen_param625, None, 4096)
    buf61 = empty_strided_cpu((4, 1024, 1), (1024, 1, 4096), torch.float32)
    buf62 = empty_strided_cpu((4, 1024, 1024), (1048576, 1024, 1), torch.float32)
    buf63 = empty_strided_cpu((4, 1024, 1), (1024, 1, 4096), torch.float32)
    buf64 = empty_strided_cpu((4, 1024, 1024), (1048576, 1024, 1), torch.float32)
    cpp_fused__softmax_add_mul_15(buf56, buf58, buf60, buf61, buf62, buf63, buf64)
    del buf56
    del buf58
    del buf60
    del buf61
    del buf62
    del buf63
    buf65 = reinterpret_tensor(buf59, (4, 1024, 32), (32768, 32, 1), 0); del buf59  # reuse
    # Topologically Sorted Source Nodes: [attn_1, matmul_3], Original ATen: [aten._softmax, aten.bmm]
    extern_kernels.bmm(buf64, reinterpret_tensor(buf55, (4, 1024, 32), (32, 384, 1), 256), out=buf65)
    del buf55
    del buf64
    buf66 = buf42; del buf42  # reuse
    buf67 = buf66; del buf66  # reuse
    cpp_fused__native_batch_norm_legit_no_training_silu_16(buf67, buf65, _frozen_param305, _frozen_param306, _frozen_param307, _frozen_param308)
    del buf65
    buf68 = torch.ops.mkldnn._convolution_pointwise_.binary(buf53, buf67, _frozen_param629, _frozen_param582, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
    del buf54
    buf71 = buf41; del buf41  # reuse
    cpp_fused_silu_17(buf53, buf71)
    del buf53
    # Downsample stage to 16x16 with a 256-channel bottleneck + SE gate.
    buf72 = torch.ops.mkldnn._convolution_pointwise.default(buf71, _frozen_param630, _frozen_param583, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf72, (1, 256, 32, 32), (262144, 1, 8192, 256))
    buf73 = torch.ops.mkldnn._convolution_pointwise.default(buf72, _frozen_param631, _frozen_param584, [1, 1], [2, 2], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf73, (1, 256, 16, 16), (65536, 1, 4096, 256))
    buf74 = empty_strided_cpu((1, 256, 1, 1), (256, 1, 256, 256), torch.float32)
    buf75 = reinterpret_tensor(buf74, (1, 256, 1, 1), (256, 1, 1, 1), 0); del buf74  # reuse
    cpp_fused_mean_18(buf75, buf73)
    buf76 = torch.ops.mkldnn._convolution_pointwise.default(buf75, _frozen_param632, _frozen_param85, [0, 0], [1, 1], [1, 1], 1, 'relu', [None], '')
    assert_size_stride(buf76, (1, 16, 1, 1), (16, 1, 16, 16))
    del buf75
    buf77 = torch.ops.mkldnn._convolution_pointwise.default(buf76, _frozen_param633, _frozen_param87, [0, 0], [1, 1], [1, 1], 1, 'sigmoid', [None], '')
    assert_size_stride(buf77, (1, 256, 1, 1), (256, 1, 256, 256))
    del buf76
    buf78 = buf73; del buf73  # reuse
    cpp_fused_mul_19(buf78, buf77)
    buf79 = torch.ops.mkldnn._convolution_pointwise.default(buf78, _frozen_param634, _frozen_param585, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
    assert_size_stride(buf79, (1, 1024, 16, 16), (262144, 1, 16384, 1024))
    del buf78
    buf80 = torch.ops.mkldnn._convolution_pointwise_.binary(buf79, buf71, _frozen_param635, _frozen_param586, [0, 0], [2, 2], [1, 1], 1, 'add', 1.0, None, [None], None)
    del buf71
    buf83 = reinterpret_tensor(buf72, (1, 1024, 16, 16), (262144, 1, 16384, 1024), 0); del buf72  # reuse
    cpp_fused_silu_20(buf79, buf83)
    buf84 = torch.ops.mkldnn._convolution_pointwise.default(buf83, _frozen_param636, _frozen_param587, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf84, (1, 256, 16, 16), (65536, 1, 4096, 256))
    buf85 = torch.ops.mkldnn._convolution_pointwise.default(buf84, _frozen_param637, _frozen_param588, [1, 1], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf85, (1, 256, 16, 16), (65536, 1, 4096, 256))
    buf86 = buf77; del buf77  # reuse
    buf87 = reinterpret_tensor(buf86, (1, 256, 1, 1), (256, 1, 1, 1), 0); del buf86  # reuse
    cpp_fused_mean_21(buf87, buf85)
    buf88 = torch.ops.mkldnn._convolution_pointwise.default(buf87, _frozen_param638, _frozen_param101, [0, 0], [1, 1], [1, 1], 1, 'relu', [None], '')
    assert_size_stride(buf88, (1, 16, 1, 1), (16, 1, 16, 16))
    buf89 = torch.ops.mkldnn._convolution_pointwise.default(buf88, _frozen_param639, _frozen_param103, [0, 0], [1, 1], [1, 1], 1, 'sigmoid', [None], '')
    assert_size_stride(buf89, (1, 256, 1, 1), (256, 1, 256, 256))
    del buf88
    buf90 = buf85; del buf85  # reuse
    cpp_fused_mul_22(buf90, buf89)
    buf91 = torch.ops.mkldnn._convolution_pointwise_.binary(buf83, buf90, _frozen_param640, _frozen_param589, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
    buf94 = buf79; del buf79  # reuse
    cpp_fused_silu_23(buf83, buf94)
    # Second self-attention block at 16x16 (4 heads of 64 channels, q/k/v
    # packed into the 768-channel buf96).
    buf95 = torch.ops.mkldnn._convolution_pointwise.default(buf94, _frozen_param641, _frozen_param590, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf95, (1, 256, 16, 16), (65536, 1, 4096, 256))
    buf96 = torch.ops.mkldnn._convolution_pointwise.default(buf95, _frozen_param642, None, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
    assert_size_stride(buf96, (1, 768, 16, 16), (196608, 1, 12288, 768))
    buf97 = reinterpret_tensor(buf83, (4, 256, 256), (65536, 256, 1), 0); del buf83  # reuse
    # Topologically Sorted Source Nodes: [matmul_4], Original ATen: [aten.bmm]
    extern_kernels.bmm(reinterpret_tensor(buf96, (4, 256, 64), (64, 768, 1), 0), reinterpret_tensor(buf96, (4, 64, 256), (64, 1, 768), 256), out=buf97)
    buf98 = reinterpret_tensor(buf90, (4, 16, 16, 64), (16384, 1024, 64, 1), 0); del buf90  # reuse
    cpp_fused_clone_24(buf96, buf98)
    buf99 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf98, (1024, 64), (64, 1), 0), _frozen_param646, _frozen_param645, None, 1024)
    buf100 = buf98; del buf98  # reuse
    cpp_fused_clone_25(buf96, buf100)
    buf101 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf100, (1024, 64), (64, 1), 0), _frozen_param644, _frozen_param643, None, 1024)
    buf102 = empty_strided_cpu((4, 256, 1), (256, 1, 1024), torch.float32)
    buf103 = reinterpret_tensor(buf24, (4, 256, 256), (65536, 256, 1), 0); del buf24  # reuse
    buf104 = empty_strided_cpu((4, 256, 1), (256, 1, 1024), torch.float32)
    buf105 = reinterpret_tensor(buf17, (4, 256, 256), (65536, 256, 1), 0); del buf17  # reuse
    cpp_fused__softmax_add_mul_26(buf97, buf99, buf101, buf102, buf103, buf104, buf105)
    del buf101
    del buf99
    buf106 = reinterpret_tensor(buf100, (4, 256, 64), (16384, 64, 1), 0); del buf100  # reuse
    # Topologically Sorted Source Nodes: [attn_3, matmul_7], Original ATen: [aten._softmax, aten.bmm]
    extern_kernels.bmm(buf105, reinterpret_tensor(buf96, (4, 256, 64), (64, 768, 1), 512), out=buf106)
    del buf96
    buf107 = buf84; del buf84  # reuse
    buf108 = buf107; del buf107  # reuse
    cpp_fused__native_batch_norm_legit_no_training_silu_27(buf108, buf106, _frozen_param349, _frozen_param350, _frozen_param351, _frozen_param352)
    del buf106
    buf109 = torch.ops.mkldnn._convolution_pointwise_.binary(buf94, buf108, _frozen_param647, _frozen_param591, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
    del buf108
    del buf95
    buf112 = reinterpret_tensor(buf105, (1, 1024, 16, 16), (262144, 1, 16384, 1024), 0); del buf105  # reuse
    cpp_fused_silu_28(buf94, buf112)
    # Third self-attention block (4 heads of 128 channels, 1536-channel
    # q/k/v pack), followed by a fused BN + avg-pool + SiLU to 8x8.
    buf113 = torch.ops.mkldnn._convolution_pointwise.default(buf112, _frozen_param648, _frozen_param592, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf113, (1, 512, 16, 16), (131072, 1, 8192, 512))
    buf114 = torch.ops.mkldnn._convolution_pointwise.default(buf113, _frozen_param649, None, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
    assert_size_stride(buf114, (1, 1536, 16, 16), (393216, 1, 24576, 1536))
    buf115 = reinterpret_tensor(buf94, (4, 256, 256), (65536, 256, 1), 0); del buf94  # reuse
    # Topologically Sorted Source Nodes: [matmul_8], Original ATen: [aten.bmm]
    extern_kernels.bmm(reinterpret_tensor(buf114, (4, 256, 128), (128, 1536, 1), 0), reinterpret_tensor(buf114, (4, 128, 256), (128, 1, 1536), 512), out=buf115)
    buf116 = reinterpret_tensor(buf67, (4, 16, 16, 128), (32768, 2048, 128, 1), 0); del buf67  # reuse
    cpp_fused_clone_29(buf114, buf116)
    buf117 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf116, (1024, 128), (128, 1), 0), _frozen_param653, _frozen_param652, None, 1024)
    buf118 = buf116; del buf116  # reuse
    cpp_fused_clone_30(buf114, buf118)
    buf119 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf118, (1024, 128), (128, 1), 0), _frozen_param651, _frozen_param650, None, 1024)
    buf120 = buf104; del buf104  # reuse
    buf121 = buf97; del buf97  # reuse
    buf122 = buf102; del buf102  # reuse
    buf123 = buf103; del buf103  # reuse
    cpp_fused__softmax_add_mul_31(buf115, buf117, buf119, buf120, buf121, buf122, buf123)
    del buf115
    del buf117
    del buf119
    del buf120
    del buf121
    del buf122
    buf124 = reinterpret_tensor(buf118, (4, 256, 128), (32768, 128, 1), 0); del buf118  # reuse
    # Topologically Sorted Source Nodes: [attn_5, matmul_11], Original ATen: [aten._softmax, aten.bmm]
    extern_kernels.bmm(buf123, reinterpret_tensor(buf114, (4, 256, 128), (128, 1536, 1), 1024), out=buf124)
    del buf114
    del buf123
    buf125 = empty_strided_cpu((1, 512, 8, 8), (32768, 1, 4096, 512), torch.float32)
    buf126 = buf125; del buf125  # reuse
    cpp_fused__native_batch_norm_legit_no_training_avg_pool2d_silu_32(buf126, buf124, _frozen_param363, _frozen_param364, _frozen_param365, _frozen_param366)
    del buf124
    buf127 = torch.ops.mkldnn._convolution_pointwise.default(buf126, _frozen_param654, _frozen_param593, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
    assert_size_stride(buf127, (1, 1536, 8, 8), (98304, 1, 12288, 1536))
    buf128 = torch.ops.mkldnn._convolution_pointwise_.binary(buf127, buf112, _frozen_param655, _frozen_param594, [0, 0], [2, 2], [1, 1], 1, 'add', 1.0, None, [None], None)
    del buf112
    del buf113
    buf131 = empty_strided_cpu((1, 1536, 8, 8), (98304, 1, 12288, 1536), torch.float32)
    cpp_fused_silu_33(buf127, buf131)
    del buf127
    # Final self-attention block at 8x8 (4 heads of 128 channels).
    buf132 = torch.ops.mkldnn._convolution_pointwise.default(buf131, _frozen_param656, _frozen_param595, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf132, (1, 512, 8, 8), (32768, 1, 4096, 512))
    buf133 = torch.ops.mkldnn._convolution_pointwise.default(buf132, _frozen_param657, None, [0, 0], [1, 1], [1, 1], 1, 'none', [None], '')
    assert_size_stride(buf133, (1, 1536, 8, 8), (98304, 1, 12288, 1536))
    buf134 = empty_strided_cpu((4, 64, 64), (4096, 64, 1), torch.float32)
    # Topologically Sorted Source Nodes: [matmul_12], Original ATen: [aten.bmm]
    extern_kernels.bmm(reinterpret_tensor(buf133, (4, 64, 128), (128, 1536, 1), 0), reinterpret_tensor(buf133, (4, 128, 64), (128, 1, 1536), 512), out=buf134)
    buf135 = reinterpret_tensor(buf126, (4, 8, 8, 128), (8192, 1024, 128, 1), 0); del buf126  # reuse
    cpp_fused_clone_34(buf133, buf135)
    buf136 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf135, (256, 128), (128, 1), 0), _frozen_param661, _frozen_param660, None, 256)
    buf137 = buf135; del buf135  # reuse
    cpp_fused_clone_35(buf133, buf137)
    buf138 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf137, (256, 128), (128, 1), 0), _frozen_param659, _frozen_param658, None, 256)
    buf139 = reinterpret_tensor(buf89, (4, 64, 1), (64, 1, 256), 0); del buf89  # reuse
    buf140 = empty_strided_cpu((4, 64, 64), (4096, 64, 1), torch.float32)
    buf141 = reinterpret_tensor(buf87, (4, 64, 1), (64, 1, 256), 0); del buf87  # reuse
    buf142 = empty_strided_cpu((4, 64, 64), (4096, 64, 1), torch.float32)
    cpp_fused__softmax_add_mul_36(buf134, buf136, buf138, buf139, buf140, buf141, buf142)
    del buf134
    del buf136
    del buf138
    del buf139
    del buf140
    del buf141
    buf143 = reinterpret_tensor(buf137, (4, 64, 128), (8192, 128, 1), 0); del buf137  # reuse
    # Topologically Sorted Source Nodes: [attn_7, matmul_15], Original ATen: [aten._softmax, aten.bmm]
    extern_kernels.bmm(buf142, reinterpret_tensor(buf133, (4, 64, 128), (128, 1536, 1), 1024), out=buf143)
    del buf142
    buf144 = empty_strided_cpu((1, 512, 8, 8), (32768, 1, 4096, 512), torch.float32)
    buf145 = buf144; del buf144  # reuse
    cpp_fused__native_batch_norm_legit_no_training_silu_37(buf145, buf143, _frozen_param381, _frozen_param382, _frozen_param383, _frozen_param384)
    del buf143
    buf146 = torch.ops.mkldnn._convolution_pointwise_.binary(buf131, buf145, _frozen_param662, _frozen_param596, [0, 0], [1, 1], [1, 1], 1, 'add', 1.0, None, [None], None)
    del buf132
    del buf145
    buf149 = buf133; del buf133  # reuse
    cpp_fused_silu_38(buf131, buf149)
    del buf131
    # Head: 1x1 conv+swish to 1280 channels, global average pool, final
    # linear, then fused bias add (cpp_fused_addmm_40) into the logits.
    buf150 = torch.ops.mkldnn._convolution_pointwise.default(buf149, _frozen_param663, _frozen_param597, [0, 0], [1, 1], [1, 1], 1, 'swish', [None], '')
    assert_size_stride(buf150, (1, 1280, 8, 8), (81920, 1, 10240, 1280))
    del buf149
    buf151 = empty_strided_cpu((1, 1280, 1, 1), (1280, 1, 1280, 1280), torch.float32)
    buf152 = buf151; del buf151  # reuse
    cpp_fused_mean_39(buf152, buf150)
    del buf150
    buf153 = torch.ops.mkl._mkl_linear.default(reinterpret_tensor(buf152, (1, 1280), (0, 1), 0), _frozen_param665, _frozen_param664, None, 1)
    del buf152
    buf154 = buf153; del buf153  # reuse
    cpp_fused_addmm_40(buf154, _frozen_param147)
    return (buf154, )
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
global _frozen_param18
_frozen_param18 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param34
_frozen_param34 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param47
_frozen_param47 = rand_strided((128, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param63
_frozen_param63 = rand_strided((128, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param85
_frozen_param85 = rand_strided((16, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param87
_frozen_param87 = rand_strided((256, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param101
_frozen_param101 = rand_strided((16, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param103
_frozen_param103 = rand_strided((256, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param147
_frozen_param147 = rand_strided((1000, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param564
_frozen_param564 = rand_strided((24, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param598
_frozen_param598 = rand_strided((24, 3, 3, 3), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param565
_frozen_param565 = rand_strided((32, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param599
_frozen_param599 = rand_strided((32, 24, 3, 3), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param566
_frozen_param566 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param600
_frozen_param600 = rand_strided((64, 32, 3, 3), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param567
_frozen_param567 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param601
_frozen_param601 = rand_strided((64, 64, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param568
_frozen_param568 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param602
_frozen_param602 = rand_strided((64, 64, 3, 3), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param603
_frozen_param603 = rand_strided((8, 64, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param604
_frozen_param604 = rand_strided((64, 8, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param569
_frozen_param569 = rand_strided((256, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param605
_frozen_param605 = rand_strided((256, 64, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param570
_frozen_param570 = rand_strided((256, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param606
_frozen_param606 = rand_strided((256, 64, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param571
_frozen_param571 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param607
_frozen_param607 = rand_strided((64, 256, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param572
_frozen_param572 = rand_strided((64, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param608
_frozen_param608 = rand_strided((64, 64, 3, 3), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param609
_frozen_param609 = rand_strided((8, 64, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param610
_frozen_param610 = rand_strided((64, 8, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param573
_frozen_param573 = rand_strided((256, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param611
_frozen_param611 = rand_strided((256, 64, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param574
_frozen_param574 = rand_strided((128, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param612
_frozen_param612 = rand_strided((128, 256, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param575
_frozen_param575 = rand_strided((128, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param613
_frozen_param613 = rand_strided((128, 128, 3, 3), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param614
_frozen_param614 = rand_strided((8, 128, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param615
_frozen_param615 = rand_strided((128, 8, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param576
_frozen_param576 = rand_strided((512, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param616
_frozen_param616 = rand_strided((512, 128, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param577
_frozen_param577 = rand_strided((512, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param617
_frozen_param617 = rand_strided((512, 256, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param578
_frozen_param578 = rand_strided((128, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param618
_frozen_param618 = rand_strided((128, 512, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param579
_frozen_param579 = rand_strided((128, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param619
_frozen_param619 = rand_strided((128, 128, 3, 3), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param620
_frozen_param620 = rand_strided((8, 128, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param621
_frozen_param621 = rand_strided((128, 8, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param580
_frozen_param580 = rand_strided((512, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param622
_frozen_param622 = rand_strided((512, 128, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param581
_frozen_param581 = rand_strided((128, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param623
_frozen_param623 = rand_strided((128, 512, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param624
_frozen_param624 = rand_strided((384, 128, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param625
_frozen_param625 = rand_strided((63, 32), (32, 1), device='cpu', dtype=torch.float32)
global _frozen_param626
_frozen_param626 = rand_strided((1982689, 1), (1, 0), device='cpu', dtype=torch.float32)
global _frozen_param627
_frozen_param627 = rand_strided((63, 32), (32, 1), device='cpu', dtype=torch.float32)
global _frozen_param628
_frozen_param628 = rand_strided((1982689, 1), (1, 0), device='cpu', dtype=torch.float32)
global _frozen_param305
_frozen_param305 = rand_strided((128, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param306
_frozen_param306 = rand_strided((128, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param307
_frozen_param307 = rand_strided((128, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param308
_frozen_param308 = rand_strided((128, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param582
_frozen_param582 = rand_strided((512, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param629
_frozen_param629 = rand_strided((512, 128, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param583
_frozen_param583 = rand_strided((256, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param630
_frozen_param630 = rand_strided((256, 512, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param584
_frozen_param584 = rand_strided((256, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param631
_frozen_param631 = rand_strided((256, 256, 3, 3), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param632
_frozen_param632 = rand_strided((16, 256, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param633
_frozen_param633 = rand_strided((256, 16, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param585
_frozen_param585 = rand_strided((1024, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param634
_frozen_param634 = rand_strided((1024, 256, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param586
_frozen_param586 = rand_strided((1024, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param635
_frozen_param635 = rand_strided((1024, 512, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param587
_frozen_param587 = rand_strided((256, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param636
_frozen_param636 = rand_strided((256, 1024, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param588
_frozen_param588 = rand_strided((256, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param637
_frozen_param637 = rand_strided((256, 256, 3, 3), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param638
_frozen_param638 = rand_strided((16, 256, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param639
_frozen_param639 = rand_strided((256, 16, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param589
_frozen_param589 = rand_strided((1024, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param640
_frozen_param640 = rand_strided((1024, 256, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param590
_frozen_param590 = rand_strided((256, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param641
_frozen_param641 = rand_strided((256, 1024, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param642
_frozen_param642 = rand_strided((768, 256, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param643
_frozen_param643 = rand_strided((31, 64), (64, 1), device='cpu', dtype=torch.float32)
global _frozen_param644
_frozen_param644 = rand_strided((1982689, 1), (1, 0), device='cpu', dtype=torch.float32)
global _frozen_param645
_frozen_param645 = rand_strided((31, 64), (64, 1), device='cpu', dtype=torch.float32)
global _frozen_param646
_frozen_param646 = rand_strided((1982689, 1), (1, 0), device='cpu', dtype=torch.float32)
global _frozen_param349
_frozen_param349 = rand_strided((256, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param350
_frozen_param350 = rand_strided((256, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param351
_frozen_param351 = rand_strided((256, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param352
_frozen_param352 = rand_strided((256, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param591
_frozen_param591 = rand_strided((1024, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param647
_frozen_param647 = rand_strided((1024, 256, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param592
_frozen_param592 = rand_strided((512, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param648
_frozen_param648 = rand_strided((512, 1024, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param649
_frozen_param649 = rand_strided((1536, 512, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param650
_frozen_param650 = rand_strided((31, 128), (128, 1), device='cpu', dtype=torch.float32)
global _frozen_param651
_frozen_param651 = rand_strided((1982689, 1), (1, 0), device='cpu', dtype=torch.float32)
global _frozen_param652
_frozen_param652 = rand_strided((31, 128), (128, 1), device='cpu', dtype=torch.float32)
global _frozen_param653
_frozen_param653 = rand_strided((1982689, 1), (1, 0), device='cpu', dtype=torch.float32)
global _frozen_param363
_frozen_param363 = rand_strided((512, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param364
_frozen_param364 = rand_strided((512, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param365
_frozen_param365 = rand_strided((512, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param366
_frozen_param366 = rand_strided((512, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param593
_frozen_param593 = rand_strided((1536, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param654
_frozen_param654 = rand_strided((1536, 512, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param594
_frozen_param594 = rand_strided((1536, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param655
_frozen_param655 = rand_strided((1536, 1024, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param595
_frozen_param595 = rand_strided((512, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param656
_frozen_param656 = rand_strided((512, 1536, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param657
_frozen_param657 = rand_strided((1536, 512, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param658
_frozen_param658 = rand_strided((15, 128), (128, 1), device='cpu', dtype=torch.float32)
global _frozen_param659
_frozen_param659 = rand_strided((1982689, 1), (1, 0), device='cpu', dtype=torch.float32)
global _frozen_param660
_frozen_param660 = rand_strided((15, 128), (128, 1), device='cpu', dtype=torch.float32)
global _frozen_param661
_frozen_param661 = rand_strided((1982689, 1), (1, 0), device='cpu', dtype=torch.float32)
global _frozen_param381
_frozen_param381 = rand_strided((512, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param382
_frozen_param382 = rand_strided((512, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param383
_frozen_param383 = rand_strided((512, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param384
_frozen_param384 = rand_strided((512, 1, 1), (1, 1, 1), device='cpu', dtype=torch.float32)
global _frozen_param596
_frozen_param596 = rand_strided((1536, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param662
_frozen_param662 = rand_strided((1536, 512, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param597
_frozen_param597 = rand_strided((1280, ), (1, ), device='cpu', dtype=torch.float32)
global _frozen_param663
_frozen_param663 = rand_strided((1280, 1536, 1, 1), (1, 0, 0, 0), device='cpu', dtype=torch.float32)
global _frozen_param664
_frozen_param664 = rand_strided((1000, 1280), (1280, 1), device='cpu', dtype=torch.float32)
global _frozen_param665
_frozen_param665 = rand_strided((3490017, 1), (1, 0), device='cpu', dtype=torch.float32)
arg224_1 = rand_strided((1, 3, 256, 256), (196608, 65536, 256, 1), device='cpu', dtype=torch.float32)
fn = lambda: call([arg224_1])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main('sebotnet33ts_256', benchmark_compiled_module)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment