# Gist: leslie-fang-intel/f2d4de4b4d14875a40d6c09f0fa5fbb3 (last active December 13, 2024)
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import (
grid,
split_scan_grid,
grid_combo_kernels,
start_graph,
end_graph,
cooperative_reduction_grid,
)
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
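# Editor's note (assumption): the `_frozen_param*` globals below appear to be
# placeholders created by inductor's freezing pass; they are expected to be
# bound to the model's constant tensors when the compiled artifact is loaded.
# The trailing comments record device, dtype, shape, stride, and the original
# tensor's id() for debugging.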
_frozen_param0 = None # device(type='cuda', index=0) torch.bfloat16 (32000, 512) (512, 1) 7fe5e12f9cb0
_frozen_param1 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12fb4c0
_frozen_param6 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12f9580
_frozen_param10 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12fbce0
_frozen_param15 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc0125c0
_frozen_param19 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc0127a0
_frozen_param24 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc012a20
_frozen_param28 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12fa930
_frozen_param33 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc0113a0
_frozen_param37 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12f9fd0
_frozen_param42 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc011f80
_frozen_param46 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5ed46b8d0
_frozen_param51 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc013790
_frozen_param55 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12f99e0
_frozen_param60 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12f9bc0
_frozen_param64 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc010d10
_frozen_param69 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc012bb0
_frozen_param73 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12fb470
_frozen_param135 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c3bec0
_frozen_param79 = None # device(type='cuda', index=0) torch.complex64 (1, 32, 1, 32) (1024, 32, 32, 1) 7fe55837e520
_frozen_param80 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d82f20
_frozen_param136 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c3a4d0
_frozen_param83 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe55821f880
_frozen_param137 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c3bb00
_frozen_param87 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557e9ad90
_frozen_param138 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557df2f70
_frozen_param90 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d88180
_frozen_param139 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c3b0b0
_frozen_param94 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d89800
_frozen_param140 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c3a1b0
_frozen_param97 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d898f0
_frozen_param141 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c3ade0
_frozen_param101 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d89a30
_frozen_param142 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c3bab0
_frozen_param104 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d89b20
_frozen_param143 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c38e50
_frozen_param108 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d89c60
_frozen_param144 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c39e40
_frozen_param111 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d89d50
_frozen_param145 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557d82d90
_frozen_param115 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d89e90
_frozen_param146 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c39260
_frozen_param118 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d89f80
_frozen_param147 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c38590
_frozen_param122 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d8a0c0
_frozen_param148 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c38d60
_frozen_param125 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d8a1b0
_frozen_param149 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c38bd0
_frozen_param129 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d8a2f0
_frozen_param150 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c39c10
_frozen_param132 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d8a3e0
_frozen_param134 = None # device(type='cuda', index=0) torch.bfloat16 (512, 32000) (1, 512) 7fe557e302c0
# kernel path: /tmp/torchinductor_t/jq/cjqkz4jt5d5lkuk3r2wju344fss2vw6bnduyr5hqduvhlizskx43.py
# Topologically Sorted Source Nodes: [h, float_1, pow_1, mean, add, rsqrt, mul, output, mul_1], Original ATen: [aten.embedding, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add => add
# float_1 => convert_element_type_1
# h => embedding
# mean => mean
# mul => mul
# mul_1 => mul_1
# output => convert_element_type_2
# pow_1 => pow_1
# rsqrt => rsqrt
# Graph fragment:
# %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {})
# %convert_element_type_1 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%embedding, torch.float32), kwargs = {})
# %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_1, 2), kwargs = {})
# %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-05), kwargs = {})
# %rsqrt : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1, %rsqrt), kwargs = {})
# %convert_element_type_2 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %arg1_1), kwargs = {})
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 1024
rnumel = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
_tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp1 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
tmp2 = tmp0 + tmp1
tmp3 = tmp0 < 0
tmp4 = tl.where(tmp3, tmp2, tmp0)
tl.device_assert(((0 <= tmp4) & (tmp4 < 32000)) | ~(xmask), "index out of bounds: 0 <= tmp4 < 32000")
tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp7 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(rmask & xmask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp26 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp12 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
tmp13 = tmp0 + tmp12
tmp14 = tmp0 < 0
tmp15 = tl.where(tmp14, tmp13, tmp0)
tl.device_assert(((0 <= tmp15) & (tmp15 < 32000)) | ~(xmask), "index out of bounds: 0 <= tmp15 < 32000")
tmp17 = tl.load(in_ptr1 + (r1 + 512*tmp15), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp18 = tmp17.to(tl.float32)
tmp19 = 512.0
tmp20 = (tmp10 / tmp19).to(tl.float32)
tmp21 = 1e-05
tmp22 = tmp20 + tmp21
tmp23 = libdevice.rsqrt(tmp22)
tmp24 = tmp18 * tmp23
tmp25 = tmp24.to(tl.float32)
tmp27 = tmp25 * tmp26
tl.store(out_ptr1 + (r1 + 512*x0), tmp27, rmask & xmask)
''', device_str='cuda')
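# Editor's note (assumption): for reference, the kernel above fuses a token
# embedding lookup with an RMSNorm. A minimal eager-mode sketch of the same
# math follows; `tokens`, `tok_emb` (the (32000, 512) embedding table), and
# `weight` (the (512,) norm scale) are illustrative names, not from this file.
def _reference_embedding_rmsnorm(tokens, tok_emb, weight, eps=1e-05):
    h = tok_emb[tokens]                       # aten.embedding
    x = h.to(torch.float32)                   # norm statistics computed in fp32
    rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (x * rms).to(torch.bfloat16) * weight  # rescale in bf16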
# kernel path: /tmp/torchinductor_t/xc/cxcmzl6ilmlg65glf4pzjcxwxpduicy2vfmzzjio5qccmdqzy5s3.py
# Topologically Sorted Source Nodes: [setitem_1], Original ATen: [aten.copy]
# Source node to ATen node mapping:
# setitem_1 => copy_1
# Graph fragment:
# %copy_1 : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_4, %view_8), kwargs = {})
# %copy__default_1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%slice_tensor_1, %copy_1), kwargs = {})
triton_poi_fused_copy_1 = async_compile.triton('triton_poi_fused_copy_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 524288},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_copy_1', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_copy_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 524288
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = (xindex % 512)
x3 = xindex // 512
x2 = xindex // 16384
x4 = (xindex % 16384)
tmp0 = tl.load(in_ptr0 + (x0 + 1536*x3), None).to(tl.float32)
tl.store(out_ptr0 + (512 + x4 + 524288*x2), tmp0, None)
''', device_str='cuda')
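# Editor's note (assumption): this kernel realizes the in-place slice update
# (`setitem_1` via aten.copy_) that writes new entries into a persistent cache
# buffer; the constant 512 in the store offset is the flattened start position
# baked into this trace. A hedged eager one-liner for the same pattern, with
# `cache`, `v`, `start_pos`, and `seqlen` as illustrative names:
#     cache[:, start_pos : start_pos + seqlen] = v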
# kernel path: /tmp/torchinductor_t/7m/c7m45i7porfguqnzfi4evudvthp6s46rtboxlqxs2ukmts7oogqy.py
# Topologically Sorted Source Nodes: [xq_], Original ATen: [aten.view_as_complex]
# Source node to ATen node mapping:
# xq_ => view_as_complex
# Graph fragment:
# %view_as_complex : [num_users=1] = call_function[target=torch.ops.aten.view_as_complex.default](args = (%view_9,), kwargs = {})
triton_poi_fused_view_as_complex_2 = async_compile.triton('triton_poi_fused_view_as_complex_2', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 524288},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_view_as_complex_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_view_as_complex_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 524288
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = (xindex % 512)
x1 = xindex // 512
x2 = xindex
tmp0 = tl.load(in_ptr0 + (512 + x0 + 1536*x1), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x2), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_t/65/c65ekeiwoihqn3z3eozs5kfs2w7fuc6dtsursyyedmm5n3idz6aw.py
# Topologically Sorted Source Nodes: [xk_], Original ATen: [aten.view_as_complex]
# Source node to ATen node mapping:
# xk_ => view_as_complex_1
# Graph fragment:
# %view_as_complex_1 : [num_users=1] = call_function[target=torch.ops.aten.view_as_complex.default](args = (%view_10,), kwargs = {})
triton_poi_fused_view_as_complex_3 = async_compile.triton('triton_poi_fused_view_as_complex_3', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 524288},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_view_as_complex_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_view_as_complex_3(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 524288
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = (xindex % 512)
x1 = xindex // 512
x2 = xindex
tmp0 = tl.load(in_ptr0 + (1024 + x0 + 1536*x1), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x2), tmp1, None)
''', device_str='cuda')
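# Editor's note (assumption): the two kernels above slice xq_ and xk_ out of a
# fused (..., 1536)-wide projection (element offsets 512 and 1024) and upcast
# bf16 -> fp32, since torch.view_as_complex needs a float tensor whose last
# dimension has size 2. A sketch of the llama-style rotary-embedding prep this
# feeds, with `x_f32` as an illustrative name:
def _reference_view_as_complex(x_f32):
    # pair up the last dimension as (real, imag) couples
    return torch.view_as_complex(x_f32.reshape(*x_f32.shape[:-1], -1, 2))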
# kernel path: /tmp/torchinductor_t/fr/cfrphf4roqxuzvwmpfjsqykxw7ykp6blxuab7a762yx75y5ssqwv.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
# %_scaled_dot_product_flash_attention_default_7 : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%permute_default_21, %permute_default_22, %permute_default_23), kwargs = {scale: 0.125})
triton_poi_fused_4 = async_compile.triton('triton_poi_fused_4', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 524288},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_4(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 524288
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
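# Editor's note: the kernel above is a plain fp32 -> bf16 cast feeding the
# flash-attention call in the preceding graph fragment. The explicit scale
# 0.125 equals 1/sqrt(64), the usual 1/sqrt(head_dim) softmax scaling for
# head_dim = 64. A hedged eager equivalent via the public SDPA API, with q, k,
# v as illustrative names:
#     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=0.125)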
# kernel path: /tmp/torchinductor_t/pm/cpmwz4lr5hqfizv3wxqz5p5i5tjfamfs6btvjdkmpwz6sixmisvp.py
# Topologically Sorted Source Nodes: [xk_2, setitem], Original ATen: [aten._to_copy, aten.copy]
# Source node to ATen node mapping:
# setitem => copy
# xk_2 => convert_element_type_12
# Graph fragment:
# %convert_element_type_12 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_13, torch.bfloat16), kwargs = {})
# %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %convert_element_type_12), kwargs = {})
# %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%slice_tensor, %copy), kwargs = {})
triton_poi_fused__to_copy_copy_5 = async_compile.triton('triton_poi_fused__to_copy_copy_5', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 524288},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_copy_5', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_copy_5(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 524288
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x2 = xindex
x0 = (xindex % 16384)
x1 = xindex // 16384
tmp0 = tl.load(in_ptr0 + (x2), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (512 + x0 + 524288*x1), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_t/dr/cdrrgy3ocqlbwxy4cv6ytxs3apsjv37qs3l2swp2p2sgajmtmocx.py
# Topologically Sorted Source Nodes: [h, h_1, float_5, pow_2, mean_1, add_2, rsqrt_1, mul_4, output_3, mul_5], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_2 => add_2
# float_5 => convert_element_type_21
# h => embedding
# h_1 => add_1
# mean_1 => mean_1
# mul_4 => mul_4
# mul_5 => mul_5
# output_3 => convert_element_type_22
# pow_2 => pow_2
# rsqrt_1 => rsqrt_1
# Graph fragment:
# %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {})
# %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %view_22), kwargs = {})
# %convert_element_type_21 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_1, torch.float32), kwargs = {})
# %pow_2 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_21, 2), kwargs = {})
# %mean_1 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
# %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-05), kwargs = {})
# %rsqrt_1 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_2,), kwargs = {})
# %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_21, %rsqrt_1), kwargs = {})
# %convert_element_type_22 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_4, torch.bfloat16), kwargs = {})
# %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_22, %arg6_1), kwargs = {})
triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6 = async_compile.triton('triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
x0 = xindex
r1 = rindex
tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
tmp7 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32)
tmp21 = tl.load(in_ptr3 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp1 = tl.full([RBLOCK], 32000, tl.int32)
tmp2 = tmp0 + tmp1
tmp3 = tmp0 < 0
tmp4 = tl.where(tmp3, tmp2, tmp0)
tl.device_assert((0 <= tmp4) & (tmp4 < 32000), "index out of bounds: 0 <= tmp4 < 32000")
tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), None).to(tl.float32)
tmp8 = tmp6 + tmp7
tmp9 = tmp8.to(tl.float32)
tmp10 = tmp9 * tmp9
tmp11 = tl.broadcast_to(tmp10, [RBLOCK])
tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0))
tmp14 = 512.0
tmp15 = (tmp13 / tmp14).to(tl.float32)
tmp16 = 1e-05
tmp17 = tmp15 + tmp16
tmp18 = libdevice.rsqrt(tmp17)
tmp19 = tmp9 * tmp18
tmp20 = tmp19.to(tl.float32)
tmp22 = tmp20 * tmp21
tl.store(out_ptr1 + (r1 + 512*x0), tmp22, None)
''', device_str='cuda')
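# Editor's note: unlike kernel 0, which loops over the reduction dimension,
# the kernel above is a *persistent* reduction: RBLOCK == rnumel == 512, so
# each program holds an entire row in registers and reduces it in one pass.
# It also fuses the residual add (h_1 = embedding + attention output) into the
# RMSNorm instead of materializing the sum separately.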
# kernel path: /tmp/torchinductor_t/37/c37cieheeo7zfl4rbysqwzkbdr7pwfgmv5au7ifpl2r5xifjrj5v.py
# Topologically Sorted Source Nodes: [silu, mul_6], Original ATen: [aten.silu, aten.mul]
# Source node to ATen node mapping:
# mul_6 => mul_7
# silu => convert_element_type_25, convert_element_type_26, mul_6, sigmoid
# Graph fragment:
# %convert_element_type_25 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_24, torch.float32), kwargs = {})
# %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type_25,), kwargs = {})
# %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_25, %sigmoid), kwargs = {})
# %convert_element_type_26 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_6, torch.bfloat16), kwargs = {})
# %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_26, %view_26), kwargs = {})
triton_poi_fused_mul_silu_7 = async_compile.triton('triton_poi_fused_mul_silu_7', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 2097152},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_7', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_silu_7(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1572864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = (xindex % 1536)
x1 = xindex // 1536
x2 = xindex
tmp0 = tl.load(in_ptr0 + (1536 + x0 + 3072*x1), None).to(tl.float32)
tmp5 = tl.load(in_ptr0 + (x0 + 3072*x1), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp2 = tl.sigmoid(tmp1)
tmp3 = tmp1 * tmp2
tmp4 = tmp3.to(tl.float32)
tmp6 = tmp4 * tmp5
tl.store(out_ptr0 + (x2), tmp6, None)
''', device_str='cuda')
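# Editor's note (assumption): the kernel above is the SwiGLU gate of the MLP,
# with w1 and w3 fused into one (512, 3072) projection. It applies SiLU in
# fp32 to the upper 1536 columns and multiplies by the lower 1536 in bf16.
# A minimal eager sketch; `w13_out` is an illustrative name:
def _reference_swiglu(w13_out):
    # column offsets match the loads above: gate at [1536:3072], value at [0:1536]
    x1, x3 = w13_out[..., 1536:], w13_out[..., :1536]
    return torch.nn.functional.silu(x1.float()).to(torch.bfloat16) * x3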
# kernel path: /tmp/torchinductor_t/oe/coezydc4r3hd4ebv3eien4ar5zw7e5oum674ezorliyuknllnim6.py
# Topologically Sorted Source Nodes: [h, h_1, out, float_6, pow_3, mean_2, add_4, rsqrt_2, mul_7, output_4, mul_8], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_4 => add_4
# float_6 => convert_element_type_31
# h => embedding
# h_1 => add_1
# mean_2 => mean_2
# mul_7 => mul_8
# mul_8 => mul_9
# out => add_3
# output_4 => convert_element_type_32
# pow_3 => pow_3
# rsqrt_2 => rsqrt_2
# Graph fragment:
# %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {})
# %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %view_22), kwargs = {})
# %add_3 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, %view_28), kwargs = {})
# %convert_element_type_31 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_3, torch.float32), kwargs = {})
# %pow_3 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_31, 2), kwargs = {})
# %mean_2 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {})
# %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-05), kwargs = {})
# %rsqrt_2 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_4,), kwargs = {})
# %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_31, %rsqrt_2), kwargs = {})
# %convert_element_type_32 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_8, torch.bfloat16), kwargs = {})
# %mul_9 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_32, %arg10_1), kwargs = {})
triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8 = async_compile.triton('triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
x0 = xindex
r1 = rindex
tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
tmp7 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32)
tmp9 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32)
tmp23 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp1 = tl.full([RBLOCK], 32000, tl.int32)
tmp2 = tmp0 + tmp1
tmp3 = tmp0 < 0
tmp4 = tl.where(tmp3, tmp2, tmp0)
tl.device_assert((0 <= tmp4) & (tmp4 < 32000), "index out of bounds: 0 <= tmp4 < 32000")
tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), None).to(tl.float32)
tmp8 = tmp6 + tmp7
tmp10 = tmp8 + tmp9
tmp11 = tmp10.to(tl.float32)
tmp12 = tmp11 * tmp11
tmp13 = tl.broadcast_to(tmp12, [RBLOCK])
tmp15 = triton_helpers.promote_to_tensor(tl.sum(tmp13, 0))
tmp16 = 512.0
tmp17 = (tmp15 / tmp16).to(tl.float32)
tmp18 = 1e-05
tmp19 = tmp17 + tmp18
tmp20 = libdevice.rsqrt(tmp19)
tmp21 = tmp11 * tmp20
tmp22 = tmp21.to(tl.float32)
tmp24 = tmp22 * tmp23
tl.store(out_ptr1 + (r1 + 512*x0), tmp24, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_t/hg/chg35he4i3wb65f2xla6t2f46hl3nvqroelcorgcij2tuqcakabg.py
# Topologically Sorted Source Nodes: [h, h_1, out, h_2, float_10, pow_4, mean_3, add_6, rsqrt_3, mul_11, output_7, mul_12], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_6 => add_6
# float_10 => convert_element_type_51
# h => embedding
# h_1 => add_1
# h_2 => add_5
# mean_3 => mean_3
# mul_11 => mul_12
# mul_12 => mul_13
# out => add_3
# output_7 => convert_element_type_52
# pow_4 => pow_4
# rsqrt_3 => rsqrt_3
# Graph fragment:
# %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {})
# %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %view_22), kwargs = {})
# %add_3 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, %view_28), kwargs = {})
# %add_5 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_3, %view_51), kwargs = {})
# %convert_element_type_51 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_5, torch.float32), kwargs = {})
# %pow_4 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_51, 2), kwargs = {})
# %mean_3 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_4, [-1], True), kwargs = {})
# %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_3, 1e-05), kwargs = {})
# %rsqrt_3 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_6,), kwargs = {})
# %mul_12 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_51, %rsqrt_3), kwargs = {})
# %convert_element_type_52 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_12, torch.bfloat16), kwargs = {})
# %mul_13 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_52, %arg15_1), kwargs = {})
triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9 = async_compile.triton('triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
x0 = xindex
r1 = rindex
tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
tmp7 = tl.load(in_out_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp9 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32)
tmp11 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32)
tmp25 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp1 = tl.full([RBLOCK], 32000, tl.int32)
tmp2 = tmp0 + tmp1
tmp3 = tmp0 < 0
tmp4 = tl.where(tmp3, tmp2, tmp0)
tl.device_assert((0 <= tmp4) & (tmp4 < 32000), "index out of bounds: 0 <= tmp4 < 32000")
tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), None).to(tl.float32)
tmp8 = tmp6 + tmp7
tmp10 = tmp8 + tmp9
tmp12 = tmp10 + tmp11
tmp13 = tmp12.to(tl.float32)
tmp14 = tmp13 * tmp13
tmp15 = tl.broadcast_to(tmp14, [RBLOCK])
tmp17 = triton_helpers.promote_to_tensor(tl.sum(tmp15, 0))
tmp18 = 512.0
tmp19 = (tmp17 / tmp18).to(tl.float32)
tmp20 = 1e-05
tmp21 = tmp19 + tmp20
tmp22 = libdevice.rsqrt(tmp21)
tmp23 = tmp13 * tmp22
tmp24 = tmp23.to(tl.float32)
tmp26 = tmp24 * tmp25
tl.store(in_out_ptr0 + (r1 + 512*x0), tmp12, None)
tl.store(out_ptr1 + (r1 + 512*x0), tmp26, None)
''', device_str='cuda')
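# Editor's note: the kernel above writes through `in_out_ptr0`: the
# accumulated residual (embedding plus three layer outputs, tmp12) is stored
# back in place alongside the normalized output, so later layers re-read one
# materialized tensor instead of re-summing the whole residual chain.
# `mutated_arg_names` in the metadata records that in-place write.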
# kernel path: /tmp/torchinductor_t/iy/ciy4z4yq6ins6l5ivop2w7p7zdinblxm2voqklkpqk7pgxbr7pil.py
# Topologically Sorted Source Nodes: [out_1, float_11, pow_5, mean_4, add_8, rsqrt_4, mul_14, output_8, mul_15], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_8 => add_8
# float_11 => convert_element_type_61
# mean_4 => mean_4
# mul_14 => mul_16
# mul_15 => mul_17
# out_1 => add_7
# output_8 => convert_element_type_62
# pow_5 => pow_5
# rsqrt_4 => rsqrt_4
# Graph fragment:
# %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {})
# %convert_element_type_61 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_7, torch.float32), kwargs = {})
# %pow_5 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_61, 2), kwargs = {})
# %mean_4 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_5, [-1], True), kwargs = {})
# %add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_4, 1e-05), kwargs = {})
# %rsqrt_4 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_8,), kwargs = {})
# %mul_16 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_61, %rsqrt_4), kwargs = {})
# %convert_element_type_62 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {})
# %mul_17 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_62, %arg19_1), kwargs = {})
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32)
tmp15 = tl.load(in_ptr2 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp2 = tmp0 + tmp1
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp3 * tmp3
tmp5 = tl.broadcast_to(tmp4, [RBLOCK])
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
tmp8 = 512.0
tmp9 = (tmp7 / tmp8).to(tl.float32)
tmp10 = 1e-05
tmp11 = tmp9 + tmp10
tmp12 = libdevice.rsqrt(tmp11)
tmp13 = tmp3 * tmp12
tmp14 = tmp13.to(tl.float32)
tmp16 = tmp14 * tmp15
tl.store(out_ptr1 + (r1 + 512*x0), tmp16, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_t/pn/cpnh5dfi4oms6ccoppdagviyr73sfnfnxsb7fxqww3qrixi7lyh4.py
# Topologically Sorted Source Nodes: [out_1, h_3, float_15, pow_6, mean_5, add_10, rsqrt_5, mul_18, output_11, mul_19], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_10 => add_10
# float_15 => convert_element_type_81
# h_3 => add_9
# mean_5 => mean_5
# mul_18 => mul_20
# mul_19 => mul_21
# out_1 => add_7
# output_11 => convert_element_type_82
# pow_6 => pow_6
# rsqrt_5 => rsqrt_5
# Graph fragment:
# %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {})
# %add_9 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_80), kwargs = {})
# %convert_element_type_81 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_9, torch.float32), kwargs = {})
# %pow_6 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_81, 2), kwargs = {})
# %mean_5 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_6, [-1], True), kwargs = {})
# %add_10 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_5, 1e-05), kwargs = {})
# %rsqrt_5 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_10,), kwargs = {})
# %mul_20 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_81, %rsqrt_5), kwargs = {})
# %convert_element_type_82 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_20, torch.bfloat16), kwargs = {})
# %mul_21 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_82, %arg24_1), kwargs = {})
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32)
tmp17 = tl.load(in_ptr3 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp2 = tmp0 + tmp1
tmp4 = tmp2 + tmp3
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp5 * tmp5
tmp7 = tl.broadcast_to(tmp6, [RBLOCK])
tmp9 = triton_helpers.promote_to_tensor(tl.sum(tmp7, 0))
tmp10 = 512.0
tmp11 = (tmp9 / tmp10).to(tl.float32)
tmp12 = 1e-05
tmp13 = tmp11 + tmp12
tmp14 = libdevice.rsqrt(tmp13)
tmp15 = tmp5 * tmp14
tmp16 = tmp15.to(tl.float32)
tmp18 = tmp16 * tmp17
tl.store(out_ptr1 + (r1 + 512*x0), tmp18, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_t/dw/cdwfdyb3igstcir6uoaelj2eyiyqrmimqovbk6l57r4zvxcvjy7k.py
# Topologically Sorted Source Nodes: [out_1, h_3, out_2, float_16, pow_7, mean_6, add_12, rsqrt_6, mul_21, output_12, mul_22], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_12 => add_12
# float_16 => convert_element_type_91
# h_3 => add_9
# mean_6 => mean_6
# mul_21 => mul_24
# mul_22 => mul_25
# out_1 => add_7
# out_2 => add_11
# output_12 => convert_element_type_92
# pow_7 => pow_7
# rsqrt_6 => rsqrt_6
# Graph fragment:
# %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {})
# %add_9 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_80), kwargs = {})
# %add_11 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_9, %view_86), kwargs = {})
# %convert_element_type_91 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_11, torch.float32), kwargs = {})
# %pow_7 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_91, 2), kwargs = {})
# %mean_6 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_7, [-1], True), kwargs = {})
# %add_12 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_6, 1e-05), kwargs = {})
# %rsqrt_6 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_12,), kwargs = {})
# %mul_24 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_91, %rsqrt_6), kwargs = {})
# %convert_element_type_92 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_24, torch.bfloat16), kwargs = {})
# %mul_25 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_92, %arg28_1), kwargs = {})
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
    # Load the four bf16 residual-stream terms for this 512-wide row, plus the
    # RMSNorm weight (shared across rows, hence eviction_policy='evict_last').
    tmp0 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32)
    tmp3 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32)
    tmp5 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32)
    tmp19 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32)
    # Residual sum: corresponds to add_7 / add_9 / add_11 in the graph fragment above.
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp6 = tmp4 + tmp5
    # RMSNorm in float32: mean of squares over the 512-element row, eps = 1e-5, rsqrt.
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tmp7 * tmp7
    tmp9 = tl.broadcast_to(tmp8, [RBLOCK])
    tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
    tmp12 = 512.0
    tmp13 = (tmp11 / tmp12).to(tl.float32)
    tmp14 = 1e-05
    tmp15 = tmp13 + tmp14
    tmp16 = libdevice.rsqrt(tmp15)
    tmp17 = tmp7 * tmp16
    tmp18 = tmp17.to(tl.float32)
    # Scale by the norm weight; the store through the bf16 out_ptr1 downcasts.
    tmp20 = tmp18 * tmp19
    tl.store(out_ptr1 + (r1 + 512*x0), tmp20, None)
''', device_str='cuda')
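# Editor's note (sketch, not Inductor output): the kernel above fuses the
# residual-stream additions with an RMSNorm, matching the graph fragment it was
# generated from. Assuming bf16 residual terms a, b, c, d and norm weight w, the
# eager-mode equivalent is roughly:
#     x = (a + b + c + d).float()
#     out = (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)).to(torch.bfloat16) * w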
# kernel path: /tmp/torchinductor_t/z3/cz3e34micft5wnsnm2gnpqmih6vve6ujd2ccrnpkysa5obrpxhpi.py
# Topologically Sorted Source Nodes: [out_1, h_3, out_2, h_4, float_20, pow_8, mean_7, add_14, rsqrt_7, mul_25, output_15, mul_26], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_14 => add_14
# float_20 => convert_element_type_111
# h_3 => add_9
# h_4 => add_13
# mean_7 => mean_7
# mul_25 => mul_28
# mul_26 => mul_29
# out_1 => add_7
# out_2 => add_11
# output_15 => convert_element_type_112
# pow_8 => pow_8
# rsqrt_7 => rsqrt_7
# Graph fragment:
# %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {})
# %add_9 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_80), kwargs = {})
# %add_11 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_9, %view_86), kwargs = {})
# %add_13 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_11, %view_109), kwargs = {})
# %convert_element_type_111 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_13, torch.float32), kwargs = {})
# %pow_8 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_111, 2), kwargs = {})
# %mean_7 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_8, [-1], True), kwargs = {})
# %add_14 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_7, 1e-05), kwargs = {})
# %rsqrt_7 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_14,), kwargs = {})
# %mul_28 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_111, %rsqrt_7), kwargs = {})
# %convert_element_type_112 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_28, torch.bfloat16), kwargs = {})
# %mul_29 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_112, %arg33_1), kwargs = {})
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 6, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32)
tmp7 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32)
tmp21 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp2 = tmp0 + tmp1
tmp4 = tmp2 + tmp3
tmp6 = tmp4 + tmp5
tmp8 = tmp6 + tmp7
tmp9 = tmp8.to(tl.float32)
tmp10 = tmp9 * tmp9
tmp11 = tl.broadcast_to(tmp10, [RBLOCK])
tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0))
tmp14 = 512.0
tmp15 = (tmp13 / tmp14).to(tl.float32)
tmp16 = 1e-05
tmp17 = tmp15 + tmp16
tmp18 = libdevice.rsqrt(tmp17)
tmp19 = tmp9 * tmp18
tmp20 = tmp19.to(tl.float32)
tmp22 = tmp20 * tmp21
tl.store(in_out_ptr0 + (r1 + 512*x0), tmp8, None)
tl.store(out_ptr1 + (r1 + 512*x0), tmp22, None)
''', device_str='cuda')
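# Editor's note: same fused residual-sum + RMSNorm as kernel 12 above, with one more
# addend; here the accumulated residual (tmp8) is also written back through
# in_out_ptr0, materializing the running hidden state so later layers can reuse it
# instead of re-summing the whole residual chain.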
# kernel path: /tmp/torchinductor_t/lv/clv7iqdidc6yovjvhgzra2tlgehvzuyjf6c4drcdk3cgzbb7tbn4.py
# Topologically Sorted Source Nodes: [out_7, float_41, pow_17, mean_16, add_32, rsqrt_16, mul_56, output_32, h_9], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_32 => add_32
# float_41 => convert_element_type_241
# h_9 => mul_65
# mean_16 => mean_16
# mul_56 => mul_64
# out_7 => add_31
# output_32 => convert_element_type_242
# pow_17 => pow_17
# rsqrt_16 => rsqrt_16
# Graph fragment:
# %add_31 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_29, %view_231), kwargs = {})
# %convert_element_type_241 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_31, torch.float32), kwargs = {})
# %pow_17 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_241, 2), kwargs = {})
# %mean_16 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_17, [-1], True), kwargs = {})
# %add_32 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_16, 1e-05), kwargs = {})
# %rsqrt_16 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_32,), kwargs = {})
# %mul_64 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_241, %rsqrt_16), kwargs = {})
# %convert_element_type_242 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_64, torch.bfloat16), kwargs = {})
# %mul_65 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_242, %arg73_1), kwargs = {})
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14(in_out_ptr0, in_ptr0, in_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp15 = tl.load(in_ptr1 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp2 = tmp0 + tmp1
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp3 * tmp3
tmp5 = tl.broadcast_to(tmp4, [RBLOCK])
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
tmp8 = 512.0
tmp9 = (tmp7 / tmp8).to(tl.float32)
tmp10 = 1e-05
tmp11 = tmp9 + tmp10
tmp12 = libdevice.rsqrt(tmp11)
tmp13 = tmp3 * tmp12
tmp14 = tmp13.to(tl.float32)
tmp16 = tmp14 * tmp15
tl.store(in_out_ptr0 + (r1 + 512*x0), tmp16, None)
''', device_str='cuda')
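# Editor's note: final-norm variant with a single residual add; the weight-scaled
# RMSNorm result (tmp16) overwrites in_out_ptr0 in place, since only the normalized
# hidden state is needed by the output projection that follows.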
# kernel path: /tmp/torchinductor_t/uy/cuyzkjsgklqarsm2yt3x6ruv2di7jsrafhofrtw3gdor4jvuttkh.py
# Topologically Sorted Source Nodes: [float_42], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# float_42 => convert_element_type_245
# Graph fragment:
# %convert_element_type_245 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_56, torch.float32), kwargs = {})
triton_poi_fused__to_copy_15 = async_compile.triton('triton_poi_fused__to_copy_15', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 1048576},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_15', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_15(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1024000
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
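# Editor's note: plain pointwise upcast of the 32 x 32000 bf16 logits
# (xnumel = 1024000 = 32 * 32000); the eager equivalent is roughly `out = x.float()`.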
async_compile.wait(globals())
del async_compile
def call(args):
arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1 = args
args.clear()
assert_size_stride(arg76_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg77_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg78_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg79_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg80_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg81_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg82_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg83_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg84_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg85_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg86_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg87_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg88_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg89_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg90_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg91_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg92_1, (32, 32), (32, 1))
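    # Editor's note (inference from shapes): arg76_1..arg91_1 appear to be per-layer
    # K/V cache tensors (8 blocks x {K, V}, each (batch=32, max_seq=1024, n_heads=8,
    # head_dim=64)), and arg92_1 the (32, 32) token-id batch fed to the embedding
    # kernel below.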
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf1 = empty_strided_cuda((32, 32, 512), (16384, 512, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [h, float_1, pow_1, mean, add, rsqrt, mul, output, mul_1], Original ATen: [aten.embedding, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0.run(arg92_1, _frozen_param0, _frozen_param1, buf1, 1024, 512, grid=grid(1024), stream=stream0)
buf2 = empty_strided_cuda((1024, 1536), (1536, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1, (1024, 512), (512, 1), 0), _frozen_param135, out=buf2)
# Topologically Sorted Source Nodes: [setitem_1], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf2, arg77_1, 524288, grid=grid(524288), stream=stream0)
buf3 = empty_strided_cuda((32, 32, 8, 32, 2), (16384, 512, 64, 2, 1), torch.float32)
# Topologically Sorted Source Nodes: [xq_], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf2, buf3, 524288, grid=grid(524288), stream=stream0)
buf10 = empty_strided_cuda((32, 32, 8, 32, 2), (16384, 512, 64, 2, 1), torch.float32)
# Topologically Sorted Source Nodes: [xk_], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf2, buf10, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq_], Original ATen: [aten.view_as_complex]
buf4 = torch.ops.aten.view_as_complex.default(buf3)
buf5 = buf4
# Topologically Sorted Source Nodes: [mul_2], Original ATen: [aten.mul]
buf6 = torch.ops.aten.mul.Tensor(buf5, _frozen_param79)
del buf4
del buf5
buf7 = buf6
del buf6
# Topologically Sorted Source Nodes: [view_as_real], Original ATen: [aten.view_as_real]
buf8 = torch.ops.aten.view_as_real.default(buf7)
buf9 = buf8
buf19 = reinterpret_tensor(buf1, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf1 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf9, buf19, 524288, grid=grid(524288), stream=stream0)
del buf7
del buf8
del buf9
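        # Editor's note (inference): the view_as_complex -> complex multiply by
        # _frozen_param79 -> view_as_real sequence above applies rotary position
        # embeddings (RoPE) to the query heads; the same pattern repeats for the keys
        # just below and once per layer. Eager sketch, with freqs_cis = _frozen_param79
        # and xq viewed as (..., head_dim // 2, 2):
        #     xq_out = torch.view_as_real(torch.view_as_complex(xq) * freqs_cis).flatten(-2)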
# Topologically Sorted Source Nodes: [xk_], Original ATen: [aten.view_as_complex]
buf11 = torch.ops.aten.view_as_complex.default(buf10)
buf12 = buf11
# Topologically Sorted Source Nodes: [mul_3], Original ATen: [aten.mul]
buf13 = torch.ops.aten.mul.Tensor(buf12, _frozen_param79)
del buf11
del buf12
buf14 = buf13
del buf13
# Topologically Sorted Source Nodes: [view_as_real_1], Original ATen: [aten.view_as_real]
buf15 = torch.ops.aten.view_as_real.default(buf14)
buf16 = buf15
# Topologically Sorted Source Nodes: [xk_2, setitem], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf16, arg76_1, 524288, grid=grid(524288), stream=stream0)
del buf14
del buf15
del buf16
# Topologically Sorted Source Nodes: [], Original ATen: []
buf20 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf19, reinterpret_tensor(arg76_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg77_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf21 = buf20[0]
del buf20
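        # Editor's note: flash attention over 33 cached key/value positions per head;
        # scale=0.125 is 1/sqrt(head_dim) for head_dim=64.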
buf26 = reinterpret_tensor(buf19, (1024, 512), (512, 1), 0); del buf19 # reuse
# Topologically Sorted Source Nodes: [linear_3], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf21, (1024, 512), (512, 1), 0), _frozen_param80, out=buf26)
buf28 = reinterpret_tensor(buf21, (32, 32, 512), (16384, 512, 1), 0); del buf21 # reuse
# Topologically Sorted Source Nodes: [h, h_1, float_5, pow_2, mean_1, add_2, rsqrt_1, mul_4, output_3, mul_5], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6.run(arg92_1, _frozen_param0, buf26, _frozen_param6, buf28, 1024, 512, grid=grid(1024), stream=stream0)
buf29 = empty_strided_cuda((1024, 3072), (3072, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf28, (1024, 512), (512, 1), 0), _frozen_param136, out=buf29)
buf30 = reinterpret_tensor(buf2, (32, 32, 1536), (49152, 1536, 1), 0); del buf2 # reuse
# Topologically Sorted Source Nodes: [silu, mul_6], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf29, buf30, 1572864, grid=grid(1572864), stream=stream0)
buf31 = reinterpret_tensor(buf28, (1024, 512), (512, 1), 0); del buf28 # reuse
# Topologically Sorted Source Nodes: [linear_6], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf30, (1024, 1536), (1536, 1), 0), _frozen_param83, out=buf31)
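        # Editor's note (inference): the mm -> fused silu*mul -> mm sequence above is a
        # SwiGLU-style MLP: the first matmul packs gate and up projections into a
        # 3072-wide buffer, triton_poi_fused_mul_silu_7 computes silu(gate) * up into
        # 1536 columns, and the last matmul projects back down to 512.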
buf33 = empty_strided_cuda((32, 32, 512), (16384, 512, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [h, h_1, out, float_6, pow_3, mean_2, add_4, rsqrt_2, mul_7, output_4, mul_8], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8.run(arg92_1, _frozen_param0, buf26, buf31, _frozen_param10, buf33, 1024, 512, grid=grid(1024), stream=stream0)
buf34 = reinterpret_tensor(buf30, (1024, 1536), (1536, 1), 0); del buf30 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf33, (1024, 512), (512, 1), 0), _frozen_param137, out=buf34)
# Topologically Sorted Source Nodes: [setitem_3], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf34, arg79_1, 524288, grid=grid(524288), stream=stream0)
buf35 = buf10; del buf10 # reuse
# Topologically Sorted Source Nodes: [xq__1], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf34, buf35, 524288, grid=grid(524288), stream=stream0)
buf42 = buf3; del buf3 # reuse
# Topologically Sorted Source Nodes: [xk__1], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf34, buf42, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__1], Original ATen: [aten.view_as_complex]
buf36 = torch.ops.aten.view_as_complex.default(buf35)
buf37 = buf36
# Topologically Sorted Source Nodes: [mul_9], Original ATen: [aten.mul]
buf38 = torch.ops.aten.mul.Tensor(buf37, _frozen_param79)
del buf36
del buf37
buf39 = buf38
del buf38
# Topologically Sorted Source Nodes: [view_as_real_2], Original ATen: [aten.view_as_real]
buf40 = torch.ops.aten.view_as_real.default(buf39)
buf41 = buf40
buf51 = reinterpret_tensor(buf33, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf33 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf41, buf51, 524288, grid=grid(524288), stream=stream0)
del buf39
del buf40
del buf41
# Topologically Sorted Source Nodes: [xk__1], Original ATen: [aten.view_as_complex]
buf43 = torch.ops.aten.view_as_complex.default(buf42)
buf44 = buf43
# Topologically Sorted Source Nodes: [mul_10], Original ATen: [aten.mul]
buf45 = torch.ops.aten.mul.Tensor(buf44, _frozen_param79)
del buf43
del buf44
buf46 = buf45
del buf45
# Topologically Sorted Source Nodes: [view_as_real_3], Original ATen: [aten.view_as_real]
buf47 = torch.ops.aten.view_as_real.default(buf46)
buf48 = buf47
# Topologically Sorted Source Nodes: [xk_5, setitem_2], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf48, arg78_1, 524288, grid=grid(524288), stream=stream0)
del buf46
del buf47
del buf48
# Topologically Sorted Source Nodes: [], Original ATen: []
buf52 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf51, reinterpret_tensor(arg78_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg79_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf53 = buf52[0]
del buf52
buf58 = reinterpret_tensor(buf51, (1024, 512), (512, 1), 0); del buf51 # reuse
# Topologically Sorted Source Nodes: [linear_10], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf53, (1024, 512), (512, 1), 0), _frozen_param87, out=buf58)
buf59 = reinterpret_tensor(buf26, (32, 32, 512), (16384, 512, 1), 0); del buf26 # reuse
buf61 = reinterpret_tensor(buf53, (32, 32, 512), (16384, 512, 1), 0); del buf53 # reuse
# Topologically Sorted Source Nodes: [h, h_1, out, h_2, float_10, pow_4, mean_3, add_6, rsqrt_3, mul_11, output_7, mul_12], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9.run(buf59, arg92_1, _frozen_param0, buf31, buf58, _frozen_param15, buf61, 1024, 512, grid=grid(1024), stream=stream0)
del arg92_1
buf62 = buf29; del buf29 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf61, (1024, 512), (512, 1), 0), _frozen_param138, out=buf62)
buf63 = reinterpret_tensor(buf34, (32, 32, 1536), (49152, 1536, 1), 0); del buf34 # reuse
# Topologically Sorted Source Nodes: [silu_1, mul_13], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf62, buf63, 1572864, grid=grid(1572864), stream=stream0)
buf64 = reinterpret_tensor(buf61, (1024, 512), (512, 1), 0); del buf61 # reuse
# Topologically Sorted Source Nodes: [linear_13], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf63, (1024, 1536), (1536, 1), 0), _frozen_param90, out=buf64)
buf66 = reinterpret_tensor(buf58, (32, 32, 512), (16384, 512, 1), 0); del buf58 # reuse
# Topologically Sorted Source Nodes: [out_1, float_11, pow_5, mean_4, add_8, rsqrt_4, mul_14, output_8, mul_15], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf59, buf64, _frozen_param19, buf66, 1024, 512, grid=grid(1024), stream=stream0)
buf67 = reinterpret_tensor(buf63, (1024, 1536), (1536, 1), 0); del buf63 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf66, (1024, 512), (512, 1), 0), _frozen_param139, out=buf67)
# Topologically Sorted Source Nodes: [setitem_5], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf67, arg81_1, 524288, grid=grid(524288), stream=stream0)
buf68 = buf42; del buf42 # reuse
# Topologically Sorted Source Nodes: [xq__2], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf67, buf68, 524288, grid=grid(524288), stream=stream0)
buf75 = buf35; del buf35 # reuse
# Topologically Sorted Source Nodes: [xk__2], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf67, buf75, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__2], Original ATen: [aten.view_as_complex]
buf69 = torch.ops.aten.view_as_complex.default(buf68)
buf70 = buf69
# Topologically Sorted Source Nodes: [mul_16], Original ATen: [aten.mul]
buf71 = torch.ops.aten.mul.Tensor(buf70, _frozen_param79)
del buf69
del buf70
buf72 = buf71
del buf71
# Topologically Sorted Source Nodes: [view_as_real_4], Original ATen: [aten.view_as_real]
buf73 = torch.ops.aten.view_as_real.default(buf72)
buf74 = buf73
buf84 = reinterpret_tensor(buf66, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf66 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf74, buf84, 524288, grid=grid(524288), stream=stream0)
del buf72
del buf73
del buf74
# Topologically Sorted Source Nodes: [xk__2], Original ATen: [aten.view_as_complex]
buf76 = torch.ops.aten.view_as_complex.default(buf75)
buf77 = buf76
# Topologically Sorted Source Nodes: [mul_17], Original ATen: [aten.mul]
buf78 = torch.ops.aten.mul.Tensor(buf77, _frozen_param79)
del buf76
del buf77
buf79 = buf78
del buf78
# Topologically Sorted Source Nodes: [view_as_real_5], Original ATen: [aten.view_as_real]
buf80 = torch.ops.aten.view_as_real.default(buf79)
buf81 = buf80
# Topologically Sorted Source Nodes: [xk_8, setitem_4], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf81, arg80_1, 524288, grid=grid(524288), stream=stream0)
del buf79
del buf80
del buf81
# Topologically Sorted Source Nodes: [], Original ATen: []
buf85 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf84, reinterpret_tensor(arg80_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg81_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf86 = buf85[0]
del buf85
buf91 = reinterpret_tensor(buf84, (1024, 512), (512, 1), 0); del buf84 # reuse
# Topologically Sorted Source Nodes: [linear_17], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf86, (1024, 512), (512, 1), 0), _frozen_param94, out=buf91)
buf93 = reinterpret_tensor(buf86, (32, 32, 512), (16384, 512, 1), 0); del buf86 # reuse
# Topologically Sorted Source Nodes: [out_1, h_3, float_15, pow_6, mean_5, add_10, rsqrt_5, mul_18, output_11, mul_19], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf59, buf64, buf91, _frozen_param24, buf93, 1024, 512, grid=grid(1024), stream=stream0)
buf94 = buf62; del buf62 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf93, (1024, 512), (512, 1), 0), _frozen_param140, out=buf94)
buf95 = reinterpret_tensor(buf67, (32, 32, 1536), (49152, 1536, 1), 0); del buf67 # reuse
# Topologically Sorted Source Nodes: [silu_2, mul_20], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf94, buf95, 1572864, grid=grid(1572864), stream=stream0)
buf96 = reinterpret_tensor(buf93, (1024, 512), (512, 1), 0); del buf93 # reuse
# Topologically Sorted Source Nodes: [linear_20], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf95, (1024, 1536), (1536, 1), 0), _frozen_param97, out=buf96)
buf98 = reinterpret_tensor(buf31, (32, 32, 512), (16384, 512, 1), 0); del buf31 # reuse
# Topologically Sorted Source Nodes: [out_1, h_3, out_2, float_16, pow_7, mean_6, add_12, rsqrt_6, mul_21, output_12, mul_22], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf59, buf64, buf91, buf96, _frozen_param28, buf98, 1024, 512, grid=grid(1024), stream=stream0)
buf99 = reinterpret_tensor(buf95, (1024, 1536), (1536, 1), 0); del buf95 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf98, (1024, 512), (512, 1), 0), _frozen_param141, out=buf99)
# Topologically Sorted Source Nodes: [setitem_7], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf99, arg83_1, 524288, grid=grid(524288), stream=stream0)
buf100 = buf75; del buf75 # reuse
# Topologically Sorted Source Nodes: [xq__3], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf99, buf100, 524288, grid=grid(524288), stream=stream0)
buf107 = buf68; del buf68 # reuse
# Topologically Sorted Source Nodes: [xk__3], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf99, buf107, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__3], Original ATen: [aten.view_as_complex]
buf101 = torch.ops.aten.view_as_complex.default(buf100)
buf102 = buf101
# Topologically Sorted Source Nodes: [mul_23], Original ATen: [aten.mul]
buf103 = torch.ops.aten.mul.Tensor(buf102, _frozen_param79)
del buf101
del buf102
buf104 = buf103
del buf103
# Topologically Sorted Source Nodes: [view_as_real_6], Original ATen: [aten.view_as_real]
buf105 = torch.ops.aten.view_as_real.default(buf104)
buf106 = buf105
buf116 = reinterpret_tensor(buf98, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf98 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf106, buf116, 524288, grid=grid(524288), stream=stream0)
del buf104
del buf105
del buf106
# Topologically Sorted Source Nodes: [xk__3], Original ATen: [aten.view_as_complex]
buf108 = torch.ops.aten.view_as_complex.default(buf107)
buf109 = buf108
# Topologically Sorted Source Nodes: [mul_24], Original ATen: [aten.mul]
buf110 = torch.ops.aten.mul.Tensor(buf109, _frozen_param79)
del buf108
del buf109
buf111 = buf110
del buf110
# Topologically Sorted Source Nodes: [view_as_real_7], Original ATen: [aten.view_as_real]
buf112 = torch.ops.aten.view_as_real.default(buf111)
buf113 = buf112
# Topologically Sorted Source Nodes: [xk_11, setitem_6], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf113, arg82_1, 524288, grid=grid(524288), stream=stream0)
del buf111
del buf112
del buf113
# Topologically Sorted Source Nodes: [], Original ATen: []
buf117 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf116, reinterpret_tensor(arg82_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg83_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf118 = buf117[0]
del buf117
buf123 = reinterpret_tensor(buf116, (1024, 512), (512, 1), 0); del buf116 # reuse
# Topologically Sorted Source Nodes: [linear_24], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf118, (1024, 512), (512, 1), 0), _frozen_param101, out=buf123)
buf124 = buf59; del buf59 # reuse
buf126 = reinterpret_tensor(buf118, (32, 32, 512), (16384, 512, 1), 0); del buf118 # reuse
# Topologically Sorted Source Nodes: [out_1, h_3, out_2, h_4, float_20, pow_8, mean_7, add_14, rsqrt_7, mul_25, output_15, mul_26], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf124, buf64, buf91, buf96, buf123, _frozen_param33, buf126, 1024, 512, grid=grid(1024), stream=stream0)
del buf123
del buf64
buf127 = buf94; del buf94 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf126, (1024, 512), (512, 1), 0), _frozen_param142, out=buf127)
buf128 = reinterpret_tensor(buf99, (32, 32, 1536), (49152, 1536, 1), 0); del buf99 # reuse
# Topologically Sorted Source Nodes: [silu_3, mul_27], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf127, buf128, 1572864, grid=grid(1572864), stream=stream0)
buf129 = reinterpret_tensor(buf126, (1024, 512), (512, 1), 0); del buf126 # reuse
# Topologically Sorted Source Nodes: [linear_27], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf128, (1024, 1536), (1536, 1), 0), _frozen_param104, out=buf129)
buf131 = reinterpret_tensor(buf96, (32, 32, 512), (16384, 512, 1), 0); del buf96 # reuse
# Topologically Sorted Source Nodes: [out_3, float_21, pow_9, mean_8, add_16, rsqrt_8, mul_28, output_16, mul_29], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf124, buf129, _frozen_param37, buf131, 1024, 512, grid=grid(1024), stream=stream0)
buf132 = reinterpret_tensor(buf128, (1024, 1536), (1536, 1), 0); del buf128 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf131, (1024, 512), (512, 1), 0), _frozen_param143, out=buf132)
# Topologically Sorted Source Nodes: [setitem_9], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf132, arg85_1, 524288, grid=grid(524288), stream=stream0)
buf133 = buf107; del buf107 # reuse
# Topologically Sorted Source Nodes: [xq__4], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf132, buf133, 524288, grid=grid(524288), stream=stream0)
buf140 = buf100; del buf100 # reuse
# Topologically Sorted Source Nodes: [xk__4], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf132, buf140, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__4], Original ATen: [aten.view_as_complex]
buf134 = torch.ops.aten.view_as_complex.default(buf133)
buf135 = buf134
# Topologically Sorted Source Nodes: [mul_30], Original ATen: [aten.mul]
buf136 = torch.ops.aten.mul.Tensor(buf135, _frozen_param79)
del buf134
del buf135
buf137 = buf136
del buf136
# Topologically Sorted Source Nodes: [view_as_real_8], Original ATen: [aten.view_as_real]
buf138 = torch.ops.aten.view_as_real.default(buf137)
buf139 = buf138
buf149 = reinterpret_tensor(buf131, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf131 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf139, buf149, 524288, grid=grid(524288), stream=stream0)
del buf137
del buf138
del buf139
# Topologically Sorted Source Nodes: [xk__4], Original ATen: [aten.view_as_complex]
buf141 = torch.ops.aten.view_as_complex.default(buf140)
buf142 = buf141
# Topologically Sorted Source Nodes: [mul_31], Original ATen: [aten.mul]
buf143 = torch.ops.aten.mul.Tensor(buf142, _frozen_param79)
del buf141
del buf142
buf144 = buf143
del buf143
# Topologically Sorted Source Nodes: [view_as_real_9], Original ATen: [aten.view_as_real]
buf145 = torch.ops.aten.view_as_real.default(buf144)
buf146 = buf145
# Topologically Sorted Source Nodes: [xk_14, setitem_8], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf146, arg84_1, 524288, grid=grid(524288), stream=stream0)
del buf144
del buf145
del buf146
# Topologically Sorted Source Nodes: [], Original ATen: []
buf150 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf149, reinterpret_tensor(arg84_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg85_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf151 = buf150[0]
del buf150
buf156 = reinterpret_tensor(buf149, (1024, 512), (512, 1), 0); del buf149 # reuse
# Topologically Sorted Source Nodes: [linear_31], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf151, (1024, 512), (512, 1), 0), _frozen_param108, out=buf156)
buf158 = reinterpret_tensor(buf151, (32, 32, 512), (16384, 512, 1), 0); del buf151 # reuse
# Topologically Sorted Source Nodes: [out_3, h_5, float_25, pow_10, mean_9, add_18, rsqrt_9, mul_32, output_19, mul_33], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf124, buf129, buf156, _frozen_param42, buf158, 1024, 512, grid=grid(1024), stream=stream0)
buf159 = buf127; del buf127 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf158, (1024, 512), (512, 1), 0), _frozen_param144, out=buf159)
buf160 = reinterpret_tensor(buf132, (32, 32, 1536), (49152, 1536, 1), 0); del buf132 # reuse
# Topologically Sorted Source Nodes: [silu_4, mul_34], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf159, buf160, 1572864, grid=grid(1572864), stream=stream0)
buf161 = reinterpret_tensor(buf158, (1024, 512), (512, 1), 0); del buf158 # reuse
# Topologically Sorted Source Nodes: [linear_34], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf160, (1024, 1536), (1536, 1), 0), _frozen_param111, out=buf161)
buf163 = reinterpret_tensor(buf91, (32, 32, 512), (16384, 512, 1), 0); del buf91 # reuse
# Topologically Sorted Source Nodes: [out_3, h_5, out_4, float_26, pow_11, mean_10, add_20, rsqrt_10, mul_35, output_20, mul_36], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf124, buf129, buf156, buf161, _frozen_param46, buf163, 1024, 512, grid=grid(1024), stream=stream0)
buf164 = reinterpret_tensor(buf160, (1024, 1536), (1536, 1), 0); del buf160 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf163, (1024, 512), (512, 1), 0), _frozen_param145, out=buf164)
# Topologically Sorted Source Nodes: [setitem_11], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf164, arg87_1, 524288, grid=grid(524288), stream=stream0)
buf165 = buf140; del buf140 # reuse
# Topologically Sorted Source Nodes: [xq__5], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf164, buf165, 524288, grid=grid(524288), stream=stream0)
buf172 = buf133; del buf133 # reuse
# Topologically Sorted Source Nodes: [xk__5], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf164, buf172, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__5], Original ATen: [aten.view_as_complex]
buf166 = torch.ops.aten.view_as_complex.default(buf165)
buf167 = buf166
# Topologically Sorted Source Nodes: [mul_37], Original ATen: [aten.mul]
buf168 = torch.ops.aten.mul.Tensor(buf167, _frozen_param79)
del buf166
del buf167
buf169 = buf168
del buf168
# Topologically Sorted Source Nodes: [view_as_real_10], Original ATen: [aten.view_as_real]
buf170 = torch.ops.aten.view_as_real.default(buf169)
buf171 = buf170
buf181 = reinterpret_tensor(buf163, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf163 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf171, buf181, 524288, grid=grid(524288), stream=stream0)
del buf169
del buf170
del buf171
# Topologically Sorted Source Nodes: [xk__5], Original ATen: [aten.view_as_complex]
buf173 = torch.ops.aten.view_as_complex.default(buf172)
buf174 = buf173
# Topologically Sorted Source Nodes: [mul_38], Original ATen: [aten.mul]
buf175 = torch.ops.aten.mul.Tensor(buf174, _frozen_param79)
del buf173
del buf174
buf176 = buf175
del buf175
# Topologically Sorted Source Nodes: [view_as_real_11], Original ATen: [aten.view_as_real]
buf177 = torch.ops.aten.view_as_real.default(buf176)
buf178 = buf177
# Topologically Sorted Source Nodes: [xk_17, setitem_10], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf178, arg86_1, 524288, grid=grid(524288), stream=stream0)
del buf176
del buf177
del buf178
# Topologically Sorted Source Nodes: [], Original ATen: []
buf182 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf181, reinterpret_tensor(arg86_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg87_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf183 = buf182[0]
del buf182
buf188 = reinterpret_tensor(buf181, (1024, 512), (512, 1), 0); del buf181 # reuse
# Topologically Sorted Source Nodes: [linear_38], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf183, (1024, 512), (512, 1), 0), _frozen_param115, out=buf188)
buf189 = buf124; del buf124 # reuse
buf191 = reinterpret_tensor(buf183, (32, 32, 512), (16384, 512, 1), 0); del buf183 # reuse
# Topologically Sorted Source Nodes: [out_3, h_5, out_4, h_6, float_30, pow_12, mean_11, add_22, rsqrt_11, mul_39, output_23, mul_40], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf189, buf129, buf156, buf161, buf188, _frozen_param51, buf191, 1024, 512, grid=grid(1024), stream=stream0)
del buf129
del buf156
buf192 = buf159; del buf159 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf191, (1024, 512), (512, 1), 0), _frozen_param146, out=buf192)
buf193 = reinterpret_tensor(buf164, (32, 32, 1536), (49152, 1536, 1), 0); del buf164 # reuse
# Topologically Sorted Source Nodes: [silu_5, mul_41], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf192, buf193, 1572864, grid=grid(1572864), stream=stream0)
buf194 = reinterpret_tensor(buf191, (1024, 512), (512, 1), 0); del buf191 # reuse
# Topologically Sorted Source Nodes: [linear_41], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf193, (1024, 1536), (1536, 1), 0), _frozen_param118, out=buf194)
buf196 = reinterpret_tensor(buf188, (32, 32, 512), (16384, 512, 1), 0); del buf188 # reuse
# Topologically Sorted Source Nodes: [out_5, float_31, pow_13, mean_12, add_24, rsqrt_12, mul_42, output_24, mul_43], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf189, buf194, _frozen_param55, buf196, 1024, 512, grid=grid(1024), stream=stream0)
buf197 = reinterpret_tensor(buf193, (1024, 1536), (1536, 1), 0); del buf193 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf196, (1024, 512), (512, 1), 0), _frozen_param147, out=buf197)
# Topologically Sorted Source Nodes: [setitem_13], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf197, arg89_1, 524288, grid=grid(524288), stream=stream0)
buf198 = buf172; del buf172 # reuse
# Topologically Sorted Source Nodes: [xq__6], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf197, buf198, 524288, grid=grid(524288), stream=stream0)
buf205 = buf165; del buf165 # reuse
# Topologically Sorted Source Nodes: [xk__6], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf197, buf205, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__6], Original ATen: [aten.view_as_complex]
buf199 = torch.ops.aten.view_as_complex.default(buf198)
buf200 = buf199
# Topologically Sorted Source Nodes: [mul_44], Original ATen: [aten.mul]
buf201 = torch.ops.aten.mul.Tensor(buf200, _frozen_param79)
del buf199
del buf200
buf202 = buf201
del buf201
# Topologically Sorted Source Nodes: [view_as_real_12], Original ATen: [aten.view_as_real]
buf203 = torch.ops.aten.view_as_real.default(buf202)
buf204 = buf203
buf214 = reinterpret_tensor(buf196, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf196 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf204, buf214, 524288, grid=grid(524288), stream=stream0)
del buf202
del buf203
del buf204
# Topologically Sorted Source Nodes: [xk__6], Original ATen: [aten.view_as_complex]
buf206 = torch.ops.aten.view_as_complex.default(buf205)
buf207 = buf206
# Topologically Sorted Source Nodes: [mul_45], Original ATen: [aten.mul]
buf208 = torch.ops.aten.mul.Tensor(buf207, _frozen_param79)
del buf206
del buf207
buf209 = buf208
del buf208
# Topologically Sorted Source Nodes: [view_as_real_13], Original ATen: [aten.view_as_real]
buf210 = torch.ops.aten.view_as_real.default(buf209)
buf211 = buf210
# Topologically Sorted Source Nodes: [xk_20, setitem_12], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf211, arg88_1, 524288, grid=grid(524288), stream=stream0)
del buf209
del buf210
del buf211
# Topologically Sorted Source Nodes: [], Original ATen: []
buf215 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf214, reinterpret_tensor(arg88_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg89_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf216 = buf215[0]
del buf215
buf221 = reinterpret_tensor(buf214, (1024, 512), (512, 1), 0); del buf214 # reuse
# Topologically Sorted Source Nodes: [linear_45], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf216, (1024, 512), (512, 1), 0), _frozen_param122, out=buf221)
buf223 = reinterpret_tensor(buf216, (32, 32, 512), (16384, 512, 1), 0); del buf216 # reuse
# Topologically Sorted Source Nodes: [out_5, h_7, float_35, pow_14, mean_13, add_26, rsqrt_13, mul_46, output_27, mul_47], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf189, buf194, buf221, _frozen_param60, buf223, 1024, 512, grid=grid(1024), stream=stream0)
buf224 = buf192; del buf192 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf223, (1024, 512), (512, 1), 0), _frozen_param148, out=buf224)
buf225 = reinterpret_tensor(buf197, (32, 32, 1536), (49152, 1536, 1), 0); del buf197 # reuse
# Topologically Sorted Source Nodes: [silu_6, mul_48], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf224, buf225, 1572864, grid=grid(1572864), stream=stream0)
buf226 = reinterpret_tensor(buf223, (1024, 512), (512, 1), 0); del buf223 # reuse
# Topologically Sorted Source Nodes: [linear_48], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf225, (1024, 1536), (1536, 1), 0), _frozen_param125, out=buf226)
buf228 = reinterpret_tensor(buf161, (32, 32, 512), (16384, 512, 1), 0); del buf161 # reuse
# Topologically Sorted Source Nodes: [out_5, h_7, out_6, float_36, pow_15, mean_14, add_28, rsqrt_14, mul_49, output_28, mul_50], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf189, buf194, buf221, buf226, _frozen_param64, buf228, 1024, 512, grid=grid(1024), stream=stream0)
buf229 = reinterpret_tensor(buf225, (1024, 1536), (1536, 1), 0); del buf225 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf228, (1024, 512), (512, 1), 0), _frozen_param149, out=buf229)
# Topologically Sorted Source Nodes: [setitem_15], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf229, arg91_1, 524288, grid=grid(524288), stream=stream0)
buf230 = buf205; del buf205 # reuse
# Topologically Sorted Source Nodes: [xq__7], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf229, buf230, 524288, grid=grid(524288), stream=stream0)
buf237 = buf198; del buf198 # reuse
# Topologically Sorted Source Nodes: [xk__7], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf229, buf237, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__7], Original ATen: [aten.view_as_complex]
buf231 = torch.ops.aten.view_as_complex.default(buf230)
buf232 = buf231
# Topologically Sorted Source Nodes: [mul_51], Original ATen: [aten.mul]
buf233 = torch.ops.aten.mul.Tensor(buf232, _frozen_param79)
del buf230
del buf231
del buf232
buf234 = buf233
del buf233
# Topologically Sorted Source Nodes: [view_as_real_14], Original ATen: [aten.view_as_real]
buf235 = torch.ops.aten.view_as_real.default(buf234)
buf236 = buf235
buf246 = reinterpret_tensor(buf228, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf228 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf236, buf246, 524288, grid=grid(524288), stream=stream0)
del buf234
del buf235
del buf236
# Topologically Sorted Source Nodes: [xk__7], Original ATen: [aten.view_as_complex]
buf238 = torch.ops.aten.view_as_complex.default(buf237)
buf239 = buf238
# Topologically Sorted Source Nodes: [mul_52], Original ATen: [aten.mul]
buf240 = torch.ops.aten.mul.Tensor(buf239, _frozen_param79)
del buf237
del buf238
del buf239
buf241 = buf240
del buf240
# Topologically Sorted Source Nodes: [view_as_real_15], Original ATen: [aten.view_as_real]
buf242 = torch.ops.aten.view_as_real.default(buf241)
buf243 = buf242
# Topologically Sorted Source Nodes: [xk_23, setitem_14], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf243, arg90_1, 524288, grid=grid(524288), stream=stream0)
del buf241
del buf242
del buf243
# Topologically Sorted Source Nodes: [], Original ATen: []
buf247 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf246, reinterpret_tensor(arg90_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg91_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf248 = buf247[0]
del buf247
buf253 = reinterpret_tensor(buf246, (1024, 512), (512, 1), 0); del buf246 # reuse
# Topologically Sorted Source Nodes: [linear_52], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf248, (1024, 512), (512, 1), 0), _frozen_param129, out=buf253)
buf254 = buf189; del buf189 # reuse
buf256 = reinterpret_tensor(buf248, (32, 32, 512), (16384, 512, 1), 0); del buf248 # reuse
# Topologically Sorted Source Nodes: [out_5, h_7, out_6, h_8, float_40, pow_16, mean_15, add_30, rsqrt_15, mul_53, output_31, mul_54], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf254, buf194, buf221, buf226, buf253, _frozen_param69, buf256, 1024, 512, grid=grid(1024), stream=stream0)
del buf194
del buf221
del buf226
del buf253
buf257 = buf224; del buf224 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf256, (1024, 512), (512, 1), 0), _frozen_param150, out=buf257)
buf258 = reinterpret_tensor(buf229, (32, 32, 1536), (49152, 1536, 1), 0); del buf229 # reuse
# Topologically Sorted Source Nodes: [silu_7, mul_55], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf257, buf258, 1572864, grid=grid(1572864), stream=stream0)
del buf257
buf259 = reinterpret_tensor(buf256, (1024, 512), (512, 1), 0); del buf256 # reuse
# Topologically Sorted Source Nodes: [linear_55], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf258, (1024, 1536), (1536, 1), 0), _frozen_param132, out=buf259)
del buf258
buf261 = buf254; del buf254 # reuse
# Topologically Sorted Source Nodes: [out_7, float_41, pow_17, mean_16, add_32, rsqrt_16, mul_56, output_32, h_9], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf261, buf259, _frozen_param73, 1024, 512, grid=grid(1024), stream=stream0)
del buf259
buf262 = empty_strided_cuda((32, 32000), (32000, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [output_33], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf261, (32, 512), (16384, 1), 15872), _frozen_param134, out=buf262)
del buf261
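        # Editor's note: the reinterpret_tensor above reads buf261 at element offset
        # 15872 = 31 * 512 with row stride 16384, i.e. only the last of the 32
        # positions in each sequence is projected to vocabulary logits.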
buf263 = empty_strided_cuda((32, 32000), (32000, 1), torch.float32)
# Topologically Sorted Source Nodes: [float_42], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_15.run(buf262, buf263, 1024000, grid=grid(1024000), stream=stream0)
del buf262
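    # Editor's note: call() returns the fp32 last-token logits followed by the K/V
    # cache tensors, which the setitem kernels above updated in place.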
return (buf263, arg77_1, arg76_1, arg79_1, arg78_1, arg81_1, arg80_1, arg83_1, arg82_1, arg85_1, arg84_1, arg87_1, arg86_1, arg89_1, arg88_1, arg91_1, arg90_1, )
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
global _frozen_param0
_frozen_param0 = rand_strided((32000, 512), (512, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param1
_frozen_param1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param6
_frozen_param6 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param10
_frozen_param10 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param15
_frozen_param15 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param19
_frozen_param19 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param24
_frozen_param24 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param28
_frozen_param28 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param33
_frozen_param33 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param37
_frozen_param37 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param42
_frozen_param42 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param46
_frozen_param46 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param51
_frozen_param51 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param55
_frozen_param55 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param60
_frozen_param60 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param64
_frozen_param64 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param69
_frozen_param69 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param73
_frozen_param73 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param135
_frozen_param135 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param79
_frozen_param79 = rand_strided((1, 32, 1, 32), (1024, 32, 32, 1), device='cuda:0', dtype=torch.complex64)
global _frozen_param80
_frozen_param80 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param136
_frozen_param136 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param83
_frozen_param83 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param137
_frozen_param137 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param87
_frozen_param87 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param138
_frozen_param138 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param90
_frozen_param90 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param139
_frozen_param139 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param94
_frozen_param94 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param140
_frozen_param140 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param97
_frozen_param97 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param141
_frozen_param141 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param101
_frozen_param101 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param142
_frozen_param142 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param104
_frozen_param104 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param143
_frozen_param143 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param108
_frozen_param108 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param144
_frozen_param144 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param111
_frozen_param111 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param145
_frozen_param145 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param115
_frozen_param115 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param146
_frozen_param146 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param118
_frozen_param118 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param147
_frozen_param147 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param122
_frozen_param122 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param148
_frozen_param148 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param125
_frozen_param125 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param149
_frozen_param149 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param129
_frozen_param129 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param150
_frozen_param150 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param132
_frozen_param132 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param134
_frozen_param134 = rand_strided((512, 32000), (1, 512), device='cuda:0', dtype=torch.bfloat16)
arg76_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg77_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg78_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg79_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg80_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg81_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg82_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg83_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg84_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg85_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg86_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg87_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg88_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg89_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg90_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg91_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg92_1 = rand_strided((32, 32), (32, 1), device='cuda:0', dtype=torch.int32)
fn = lambda: call([arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1])
return print_performance(fn, times=times, repeat=repeat)
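# Running this file directly benchmarks the compiled module: rand_strided fills in
# stand-in frozen parameters and inputs with the recorded shapes/strides/dtypes,
# and print_performance reports timing over `times` runs repeated `repeat` times.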
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main('llama', benchmark_compiled_module)
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import (
grid,
split_scan_grid,
grid_combo_kernels,
start_graph,
end_graph,
cooperative_reduction_grid,
)
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
_frozen_param0 = None # device(type='cuda', index=0) torch.bfloat16 (32000, 512) (512, 1) 7f03b04ed800
_frozen_param1 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b0477ce0
_frozen_param6 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04ef650
_frozen_param10 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04edf80
_frozen_param15 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04ede90
_frozen_param19 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b0477dd0
_frozen_param24 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04a2390
_frozen_param28 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04a0ae0
_frozen_param33 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b05b3c90
_frozen_param37 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b05b3ba0
_frozen_param42 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b05422a0
_frozen_param46 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04e34c0
_frozen_param51 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04a38d0
_frozen_param55 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04e0f40
_frozen_param60 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04a0400
_frozen_param64 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04e2480
_frozen_param69 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04a23e0
_frozen_param73 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04ef9c0
_frozen_param135 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f03241ca6b0
_frozen_param79 = None # device(type='cuda', index=0) torch.complex64 (1, 32, 1, 32) (1024, 32, 32, 1) 7f032416f060
_frozen_param80 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324131530
_frozen_param136 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f03241bf0b0
_frozen_param83 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f032427df30
_frozen_param137 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f03241bda80
_frozen_param87 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324174b80
_frozen_param138 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f03241bfa60
_frozen_param90 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f0324174c70
_frozen_param139 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f03243bc3b0
_frozen_param94 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324174db0
_frozen_param140 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f0324003830
_frozen_param97 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f0324174ea0
_frozen_param141 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f03243bc770
_frozen_param101 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324174fe0
_frozen_param142 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f0324204950
_frozen_param104 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f03241750d0
_frozen_param143 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f032418ff10
_frozen_param108 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324175210
_frozen_param144 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f032418e1b0
_frozen_param111 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f0324175300
_frozen_param145 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f032418fa10
_frozen_param115 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324175440
_frozen_param146 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f032418ef20
_frozen_param118 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f0324175530
_frozen_param147 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f032418fdd0
_frozen_param122 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324175670
_frozen_param148 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f032401c450
_frozen_param125 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f0324175760
_frozen_param149 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f032401f830
_frozen_param129 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f03241758a0
_frozen_param150 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f032401fbf0
_frozen_param132 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f0324175990
_frozen_param134 = None # device(type='cuda', index=0) torch.bfloat16 (512, 32000) (1, 512) 7f0324130c70
# kernel path: /tmp/torchinductor_t/jq/cjqkz4jt5d5lkuk3r2wju344fss2vw6bnduyr5hqduvhlizskx43.py
# Topologically Sorted Source Nodes: [h, float_1, pow_1, mean, add, rsqrt, mul, output, mul_1], Original ATen: [aten.embedding, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add => add
# float_1 => convert_element_type_1
# h => embedding
# mean => mean
# mul => mul
# mul_1 => mul_1
# output => convert_element_type_2
# pow_1 => pow_1
# rsqrt => rsqrt
# Graph fragment:
# %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {})
# %convert_element_type_1 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%embedding, torch.float32), kwargs = {})
# %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_1, 2), kwargs = {})
# %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-05), kwargs = {})
# %rsqrt : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1, %rsqrt), kwargs = {})
# %convert_element_type_2 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %arg1_1), kwargs = {})
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 1024
rnumel = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
_tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp1 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
tmp2 = tmp0 + tmp1
tmp3 = tmp0 < 0
tmp4 = tl.where(tmp3, tmp2, tmp0)
tl.device_assert(((0 <= tmp4) & (tmp4 < 32000)) | ~(xmask), "index out of bounds: 0 <= tmp4 < 32000")
tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp7 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(rmask & xmask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp26 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp12 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32)
tmp13 = tmp0 + tmp12
tmp14 = tmp0 < 0
tmp15 = tl.where(tmp14, tmp13, tmp0)
tl.device_assert(((0 <= tmp15) & (tmp15 < 32000)) | ~(xmask), "index out of bounds: 0 <= tmp15 < 32000")
tmp17 = tl.load(in_ptr1 + (r1 + 512*tmp15), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp18 = tmp17.to(tl.float32)
tmp19 = 512.0
tmp20 = (tmp10 / tmp19).to(tl.float32)
tmp21 = 1e-05
tmp22 = tmp20 + tmp21
tmp23 = libdevice.rsqrt(tmp22)
tmp24 = tmp18 * tmp23
tmp25 = tmp24.to(tl.float32)
tmp27 = tmp25 * tmp26
tl.store(out_ptr1 + (r1 + 512*x0), tmp27, rmask & xmask)
''', device_str='cuda')
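# [Sketch] Eager reference for the fused RMSNorm pattern this kernel computes (the
# embedding gather aside): square -> mean over the 512-wide row -> + 1e-05 ->
# rsqrt -> scale -> cast back to bf16 -> multiply by the norm weight vector.
# Hypothetical helper for orientation; not used by the generated module.
def _sketch_rms_norm(x, weight, eps=1e-05):
    x32 = x.float()  # statistics in fp32, as the kernel does
    inv = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (x32 * inv).to(x.dtype) * weight  # bf16 result scaled by the weight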
# kernel path: /tmp/torchinductor_t/xc/cxcmzl6ilmlg65glf4pzjcxwxpduicy2vfmzzjio5qccmdqzy5s3.py
# Topologically Sorted Source Nodes: [setitem_1], Original ATen: [aten.copy]
# Source node to ATen node mapping:
# setitem_1 => copy_1
# Graph fragment:
# %copy_1 : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_4, %view_8), kwargs = {})
# %copy__default_1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%slice_tensor_1, %copy_1), kwargs = {})
triton_poi_fused_copy_1 = async_compile.triton('triton_poi_fused_copy_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 524288},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_copy_1', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_copy_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 524288
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = (xindex % 512)
x3 = xindex // 512
x2 = xindex // 16384
x4 = (xindex % 16384)
tmp0 = tl.load(in_ptr0 + (x0 + 1536*x3), None).to(tl.float32)
tl.store(out_ptr0 + (512 + x4 + 524288*x2), tmp0, None)
''', device_str='cuda')
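# Note: the copy kernel above (and triton_poi_fused__to_copy_copy_5 below) writes a
# 512-wide slice per (batch, position) into the 524288-strided cache buffers, i.e.
# an in-place update of the form cache[:, start : start + seqlen] = new_values. The
# fixed "+ 512" output offset reflects the start position captured at trace time;
# the exact slice bounds are an interpretation hedged from the offsets alone.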
# kernel path: /tmp/torchinductor_t/7m/c7m45i7porfguqnzfi4evudvthp6s46rtboxlqxs2ukmts7oogqy.py
# Topologically Sorted Source Nodes: [xq_], Original ATen: [aten.view_as_complex]
# Source node to ATen node mapping:
# xq_ => view_as_complex
# Graph fragment:
# %view_as_complex : [num_users=1] = call_function[target=torch.ops.aten.view_as_complex.default](args = (%view_9,), kwargs = {})
triton_poi_fused_view_as_complex_2 = async_compile.triton('triton_poi_fused_view_as_complex_2', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 524288},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_view_as_complex_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_view_as_complex_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 524288
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = (xindex % 512)
x1 = xindex // 512
x2 = xindex
tmp0 = tl.load(in_ptr0 + (512 + x0 + 1536*x1), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x2), tmp1, None)
''', device_str='cuda')
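# [Sketch] What the view_as_complex kernels prepare: bf16 query/key projections are
# upcast to fp32 so adjacent float pairs can be reinterpreted as complex numbers
# for rotary embedding (cf. _frozen_param79, a complex64 (1, 32, 1, 32) table).
# Hypothetical helper for orientation; not used by the generated module.
def _sketch_as_complex(x):
    # x: (..., head_dim) bf16; view_as_complex needs fp32 and a trailing dim of 2
    return torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))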
# kernel path: /tmp/torchinductor_t/65/c65ekeiwoihqn3z3eozs5kfs2w7fuc6dtsursyyedmm5n3idz6aw.py
# Topologically Sorted Source Nodes: [xk_], Original ATen: [aten.view_as_complex]
# Source node to ATen node mapping:
# xk_ => view_as_complex_1
# Graph fragment:
# %view_as_complex_1 : [num_users=1] = call_function[target=torch.ops.aten.view_as_complex.default](args = (%view_10,), kwargs = {})
triton_poi_fused_view_as_complex_3 = async_compile.triton('triton_poi_fused_view_as_complex_3', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 524288},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_view_as_complex_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_view_as_complex_3(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 524288
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = (xindex % 512)
x1 = xindex // 512
x2 = xindex
tmp0 = tl.load(in_ptr0 + (1024 + x0 + 1536*x1), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x2), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_t/fr/cfrphf4roqxuzvwmpfjsqykxw7ykp6blxuab7a762yx75y5ssqwv.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
# %_scaled_dot_product_flash_attention_default_7 : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%permute_default_21, %permute_default_22, %permute_default_23), kwargs = {scale: 0.125})
triton_poi_fused_4 = async_compile.triton('triton_poi_fused_4', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 524288},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_4(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 524288
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
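# Note: this conversion kernel materializes the fp32 RoPE results back to bf16 so
# they can feed flash attention; the graph fragment above calls
# _scaled_dot_product_flash_attention with an explicit scale of 0.125, i.e.
# 1/sqrt(64) for head_dim 64. A hedged eager equivalent, assuming q/k/v are
# already laid out as (batch, heads, seq, head_dim):
#   out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=0.125)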
# kernel path: /tmp/torchinductor_t/pm/cpmwz4lr5hqfizv3wxqz5p5i5tjfamfs6btvjdkmpwz6sixmisvp.py
# Topologically Sorted Source Nodes: [xk_2, setitem], Original ATen: [aten._to_copy, aten.copy]
# Source node to ATen node mapping:
# setitem => copy
# xk_2 => convert_element_type_12
# Graph fragment:
# %convert_element_type_12 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_13, torch.bfloat16), kwargs = {})
# %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %convert_element_type_12), kwargs = {})
# %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%slice_tensor, %copy), kwargs = {})
triton_poi_fused__to_copy_copy_5 = async_compile.triton('triton_poi_fused__to_copy_copy_5', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 524288},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_copy_5', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_copy_5(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 524288
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x2 = xindex
x0 = (xindex % 16384)
x1 = xindex // 16384
tmp0 = tl.load(in_ptr0 + (x2), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (512 + x0 + 524288*x1), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_t/dr/cdrrgy3ocqlbwxy4cv6ytxs3apsjv37qs3l2swp2p2sgajmtmocx.py
# Topologically Sorted Source Nodes: [h, h_1, float_5, pow_2, mean_1, add_2, rsqrt_1, mul_4, output_3, mul_5], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_2 => add_2
# float_5 => convert_element_type_21
# h => embedding
# h_1 => add_1
# mean_1 => mean_1
# mul_4 => mul_4
# mul_5 => mul_5
# output_3 => convert_element_type_22
# pow_2 => pow_2
# rsqrt_1 => rsqrt_1
# Graph fragment:
# %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {})
# %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %view_22), kwargs = {})
# %convert_element_type_21 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_1, torch.float32), kwargs = {})
# %pow_2 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_21, 2), kwargs = {})
# %mean_1 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
# %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-05), kwargs = {})
# %rsqrt_1 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_2,), kwargs = {})
# %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_21, %rsqrt_1), kwargs = {})
# %convert_element_type_22 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_4, torch.bfloat16), kwargs = {})
# %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_22, %arg6_1), kwargs = {})
triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6 = async_compile.triton('triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
x0 = xindex
r1 = rindex
tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
tmp7 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32)
tmp21 = tl.load(in_ptr3 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp1 = tl.full([RBLOCK], 32000, tl.int32)
tmp2 = tmp0 + tmp1
tmp3 = tmp0 < 0
tmp4 = tl.where(tmp3, tmp2, tmp0)
tl.device_assert((0 <= tmp4) & (tmp4 < 32000), "index out of bounds: 0 <= tmp4 < 32000")
tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), None).to(tl.float32)
tmp8 = tmp6 + tmp7
tmp9 = tmp8.to(tl.float32)
tmp10 = tmp9 * tmp9
tmp11 = tl.broadcast_to(tmp10, [RBLOCK])
tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0))
tmp14 = 512.0
tmp15 = (tmp13 / tmp14).to(tl.float32)
tmp16 = 1e-05
tmp17 = tmp15 + tmp16
tmp18 = libdevice.rsqrt(tmp17)
tmp19 = tmp9 * tmp18
tmp20 = tmp19.to(tl.float32)
tmp22 = tmp20 * tmp21
tl.store(out_ptr1 + (r1 + 512*x0), tmp22, None)
''', device_str='cuda')
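# Note: this is the persistent-reduction variant of the RMSNorm fusion: with
# RBLOCK fixed at 512 the whole row fits in one block, so the kernel loads it once
# and reduces in a single pass, unlike triton_red_fused..._0 above, which loops
# over the row in RBLOCK chunks and re-reads the embedding in a second pass. It
# also folds the residual add (h_1 = embedding + attention output) into the norm.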
# kernel path: /tmp/torchinductor_t/37/c37cieheeo7zfl4rbysqwzkbdr7pwfgmv5au7ifpl2r5xifjrj5v.py
# Topologically Sorted Source Nodes: [silu, mul_6], Original ATen: [aten.silu, aten.mul]
# Source node to ATen node mapping:
# mul_6 => mul_7
# silu => convert_element_type_25, convert_element_type_26, mul_6, sigmoid
# Graph fragment:
# %convert_element_type_25 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_24, torch.float32), kwargs = {})
# %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type_25,), kwargs = {})
# %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_25, %sigmoid), kwargs = {})
# %convert_element_type_26 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_6, torch.bfloat16), kwargs = {})
# %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_26, %view_26), kwargs = {})
triton_poi_fused_mul_silu_7 = async_compile.triton('triton_poi_fused_mul_silu_7', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 2097152},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_7', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_silu_7(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1572864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = (xindex % 1536)
x1 = xindex // 1536
x2 = xindex
tmp0 = tl.load(in_ptr0 + (1536 + x0 + 3072*x1), None).to(tl.float32)
tmp5 = tl.load(in_ptr0 + (x0 + 3072*x1), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp2 = tl.sigmoid(tmp1)
tmp3 = tmp1 * tmp2
tmp4 = tmp3.to(tl.float32)
tmp6 = tmp4 * tmp5
tl.store(out_ptr0 + (x2), tmp6, None)
''', device_str='cuda')
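# [Sketch] Eager reference for the fused SwiGLU gate above: the combined
# (..., 3072) projection is split into two 1536-wide halves; silu is applied in
# fp32 to the half at offset 1536 and the result multiplies the other half.
# Hypothetical helper for orientation; not used by the generated module.
def _sketch_silu_mul(packed):
    up, gate = packed[..., :1536], packed[..., 1536:]  # halves of the fused matmul
    g = gate.float()
    return (g * torch.sigmoid(g)).to(packed.dtype) * up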
# kernel path: /tmp/torchinductor_t/oe/coezydc4r3hd4ebv3eien4ar5zw7e5oum674ezorliyuknllnim6.py
# Topologically Sorted Source Nodes: [h, h_1, out, float_6, pow_3, mean_2, add_4, rsqrt_2, mul_7, output_4, mul_8], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_4 => add_4
# float_6 => convert_element_type_31
# h => embedding
# h_1 => add_1
# mean_2 => mean_2
# mul_7 => mul_8
# mul_8 => mul_9
# out => add_3
# output_4 => convert_element_type_32
# pow_3 => pow_3
# rsqrt_2 => rsqrt_2
# Graph fragment:
# %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {})
# %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %view_22), kwargs = {})
# %add_3 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, %view_28), kwargs = {})
# %convert_element_type_31 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_3, torch.float32), kwargs = {})
# %pow_3 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_31, 2), kwargs = {})
# %mean_2 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {})
# %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-05), kwargs = {})
# %rsqrt_2 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_4,), kwargs = {})
# %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_31, %rsqrt_2), kwargs = {})
# %convert_element_type_32 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_8, torch.bfloat16), kwargs = {})
# %mul_9 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_32, %arg10_1), kwargs = {})
triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8 = async_compile.triton('triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
x0 = xindex
r1 = rindex
tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
tmp7 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32)
tmp9 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32)
tmp23 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp1 = tl.full([RBLOCK], 32000, tl.int32)
tmp2 = tmp0 + tmp1
tmp3 = tmp0 < 0
tmp4 = tl.where(tmp3, tmp2, tmp0)
tl.device_assert((0 <= tmp4) & (tmp4 < 32000), "index out of bounds: 0 <= tmp4 < 32000")
tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), None).to(tl.float32)
tmp8 = tmp6 + tmp7
tmp10 = tmp8 + tmp9
tmp11 = tmp10.to(tl.float32)
tmp12 = tmp11 * tmp11
tmp13 = tl.broadcast_to(tmp12, [RBLOCK])
tmp15 = triton_helpers.promote_to_tensor(tl.sum(tmp13, 0))
tmp16 = 512.0
tmp17 = (tmp15 / tmp16).to(tl.float32)
tmp18 = 1e-05
tmp19 = tmp17 + tmp18
tmp20 = libdevice.rsqrt(tmp19)
tmp21 = tmp11 * tmp20
tmp22 = tmp21.to(tl.float32)
tmp24 = tmp22 * tmp23
tl.store(out_ptr1 + (r1 + 512*x0), tmp24, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_t/hg/chg35he4i3wb65f2xla6t2f46hl3nvqroelcorgcij2tuqcakabg.py
# Topologically Sorted Source Nodes: [h, h_1, out, h_2, float_10, pow_4, mean_3, add_6, rsqrt_3, mul_11, output_7, mul_12], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_6 => add_6
# float_10 => convert_element_type_51
# h => embedding
# h_1 => add_1
# h_2 => add_5
# mean_3 => mean_3
# mul_11 => mul_12
# mul_12 => mul_13
# out => add_3
# output_7 => convert_element_type_52
# pow_4 => pow_4
# rsqrt_3 => rsqrt_3
# Graph fragment:
# %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {})
# %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %view_22), kwargs = {})
# %add_3 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, %view_28), kwargs = {})
# %add_5 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_3, %view_51), kwargs = {})
# %convert_element_type_51 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_5, torch.float32), kwargs = {})
# %pow_4 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_51, 2), kwargs = {})
# %mean_3 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_4, [-1], True), kwargs = {})
# %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_3, 1e-05), kwargs = {})
# %rsqrt_3 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_6,), kwargs = {})
# %mul_12 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_51, %rsqrt_3), kwargs = {})
# %convert_element_type_52 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_12, torch.bfloat16), kwargs = {})
# %mul_13 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_52, %arg15_1), kwargs = {})
triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9 = async_compile.triton('triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
x0 = xindex
r1 = rindex
tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
tmp7 = tl.load(in_out_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp9 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32)
tmp11 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32)
tmp25 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp1 = tl.full([RBLOCK], 32000, tl.int32)
tmp2 = tmp0 + tmp1
tmp3 = tmp0 < 0
tmp4 = tl.where(tmp3, tmp2, tmp0)
tl.device_assert((0 <= tmp4) & (tmp4 < 32000), "index out of bounds: 0 <= tmp4 < 32000")
tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), None).to(tl.float32)
tmp8 = tmp6 + tmp7
tmp10 = tmp8 + tmp9
tmp12 = tmp10 + tmp11
tmp13 = tmp12.to(tl.float32)
tmp14 = tmp13 * tmp13
tmp15 = tl.broadcast_to(tmp14, [RBLOCK])
tmp17 = triton_helpers.promote_to_tensor(tl.sum(tmp15, 0))
tmp18 = 512.0
tmp19 = (tmp17 / tmp18).to(tl.float32)
tmp20 = 1e-05
tmp21 = tmp19 + tmp20
tmp22 = libdevice.rsqrt(tmp21)
tmp23 = tmp13 * tmp22
tmp24 = tmp23.to(tl.float32)
tmp26 = tmp24 * tmp25
tl.store(in_out_ptr0 + (r1 + 512*x0), tmp12, None)
tl.store(out_ptr1 + (r1 + 512*x0), tmp26, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_t/iy/ciy4z4yq6ins6l5ivop2w7p7zdinblxm2voqklkpqk7pgxbr7pil.py
# Topologically Sorted Source Nodes: [out_1, float_11, pow_5, mean_4, add_8, rsqrt_4, mul_14, output_8, mul_15], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_8 => add_8
# float_11 => convert_element_type_61
# mean_4 => mean_4
# mul_14 => mul_16
# mul_15 => mul_17
# out_1 => add_7
# output_8 => convert_element_type_62
# pow_5 => pow_5
# rsqrt_4 => rsqrt_4
# Graph fragment:
# %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {})
# %convert_element_type_61 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_7, torch.float32), kwargs = {})
# %pow_5 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_61, 2), kwargs = {})
# %mean_4 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_5, [-1], True), kwargs = {})
# %add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_4, 1e-05), kwargs = {})
# %rsqrt_4 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_8,), kwargs = {})
# %mul_16 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_61, %rsqrt_4), kwargs = {})
# %convert_element_type_62 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {})
# %mul_17 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_62, %arg19_1), kwargs = {})
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32)
tmp15 = tl.load(in_ptr2 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp2 = tmp0 + tmp1
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp3 * tmp3
tmp5 = tl.broadcast_to(tmp4, [RBLOCK])
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
tmp8 = 512.0
tmp9 = (tmp7 / tmp8).to(tl.float32)
tmp10 = 1e-05
tmp11 = tmp9 + tmp10
tmp12 = libdevice.rsqrt(tmp11)
tmp13 = tmp3 * tmp12
tmp14 = tmp13.to(tl.float32)
tmp16 = tmp14 * tmp15
tl.store(out_ptr1 + (r1 + 512*x0), tmp16, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_t/pn/cpnh5dfi4oms6ccoppdagviyr73sfnfnxsb7fxqww3qrixi7lyh4.py
# Topologically Sorted Source Nodes: [out_1, h_3, float_15, pow_6, mean_5, add_10, rsqrt_5, mul_18, output_11, mul_19], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_10 => add_10
# float_15 => convert_element_type_81
# h_3 => add_9
# mean_5 => mean_5
# mul_18 => mul_20
# mul_19 => mul_21
# out_1 => add_7
# output_11 => convert_element_type_82
# pow_6 => pow_6
# rsqrt_5 => rsqrt_5
# Graph fragment:
# %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {})
# %add_9 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_80), kwargs = {})
# %convert_element_type_81 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_9, torch.float32), kwargs = {})
# %pow_6 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_81, 2), kwargs = {})
# %mean_5 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_6, [-1], True), kwargs = {})
# %add_10 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_5, 1e-05), kwargs = {})
# %rsqrt_5 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_10,), kwargs = {})
# %mul_20 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_81, %rsqrt_5), kwargs = {})
# %convert_element_type_82 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_20, torch.bfloat16), kwargs = {})
# %mul_21 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_82, %arg24_1), kwargs = {})
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32)
tmp17 = tl.load(in_ptr3 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp2 = tmp0 + tmp1
tmp4 = tmp2 + tmp3
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp5 * tmp5
tmp7 = tl.broadcast_to(tmp6, [RBLOCK])
tmp9 = triton_helpers.promote_to_tensor(tl.sum(tmp7, 0))
tmp10 = 512.0
tmp11 = (tmp9 / tmp10).to(tl.float32)
tmp12 = 1e-05
tmp13 = tmp11 + tmp12
tmp14 = libdevice.rsqrt(tmp13)
tmp15 = tmp5 * tmp14
tmp16 = tmp15.to(tl.float32)
tmp18 = tmp16 * tmp17
tl.store(out_ptr1 + (r1 + 512*x0), tmp18, None)
''', device_str='cuda')
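# NOTE (annotation, not compiler output): a minimal eager-mode sketch of what the
# fused kernel above computes, for readability. The function and argument names
# are illustrative; the recipe follows the graph fragment: residual adds in bf16,
# an RMSNorm reduction in fp32, and a cast back to bf16 before the learned
# per-channel scale.
def _sketch_fused_add_rmsnorm(x, a, b, weight, eps=1e-05):
    # x, a, b: (..., 512) bf16 residual-stream tensors; weight: (512,) bf16
    h = (x + a + b).to(torch.float32)                         # add_7, add_9, float_15
    inv = torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)  # pow_6, mean_5, add_10, rsqrt_5
    return (h * inv).to(torch.bfloat16) * weight              # mul_18, output_11, mul_19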
# kernel path: /tmp/torchinductor_t/dw/cdwfdyb3igstcir6uoaelj2eyiyqrmimqovbk6l57r4zvxcvjy7k.py
# Topologically Sorted Source Nodes: [out_1, h_3, out_2, float_16, pow_7, mean_6, add_12, rsqrt_6, mul_21, output_12, mul_22], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_12 => add_12
# float_16 => convert_element_type_91
# h_3 => add_9
# mean_6 => mean_6
# mul_21 => mul_24
# mul_22 => mul_25
# out_1 => add_7
# out_2 => add_11
# output_12 => convert_element_type_92
# pow_7 => pow_7
# rsqrt_6 => rsqrt_6
# Graph fragment:
# %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {})
# %add_9 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_80), kwargs = {})
# %add_11 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_9, %view_86), kwargs = {})
# %convert_element_type_91 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_11, torch.float32), kwargs = {})
# %pow_7 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_91, 2), kwargs = {})
# %mean_6 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_7, [-1], True), kwargs = {})
# %add_12 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_6, 1e-05), kwargs = {})
# %rsqrt_6 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_12,), kwargs = {})
# %mul_24 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_91, %rsqrt_6), kwargs = {})
# %convert_element_type_92 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_24, torch.bfloat16), kwargs = {})
# %mul_25 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_92, %arg28_1), kwargs = {})
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32)
tmp19 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp2 = tmp0 + tmp1
tmp4 = tmp2 + tmp3
tmp6 = tmp4 + tmp5
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp7 * tmp7
tmp9 = tl.broadcast_to(tmp8, [RBLOCK])
tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
tmp12 = 512.0
tmp13 = (tmp11 / tmp12).to(tl.float32)
tmp14 = 1e-05
tmp15 = tmp13 + tmp14
tmp16 = libdevice.rsqrt(tmp15)
tmp17 = tmp7 * tmp16
tmp18 = tmp17.to(tl.float32)
tmp20 = tmp18 * tmp19
tl.store(out_ptr1 + (r1 + 512*x0), tmp20, None)
''', device_str='cuda')
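# NOTE (annotation, not compiler output): same fused add + RMSNorm pattern as
# triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11 above, extended to four
# summed inputs (the out_1 + h_3 + out_2 residual chain) before the reduction.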
# kernel path: /tmp/torchinductor_t/z3/cz3e34micft5wnsnm2gnpqmih6vve6ujd2ccrnpkysa5obrpxhpi.py
# Topologically Sorted Source Nodes: [out_1, h_3, out_2, h_4, float_20, pow_8, mean_7, add_14, rsqrt_7, mul_25, output_15, mul_26], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_14 => add_14
# float_20 => convert_element_type_111
# h_3 => add_9
# h_4 => add_13
# mean_7 => mean_7
# mul_25 => mul_28
# mul_26 => mul_29
# out_1 => add_7
# out_2 => add_11
# output_15 => convert_element_type_112
# pow_8 => pow_8
# rsqrt_7 => rsqrt_7
# Graph fragment:
# %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {})
# %add_9 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_80), kwargs = {})
# %add_11 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_9, %view_86), kwargs = {})
# %add_13 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_11, %view_109), kwargs = {})
# %convert_element_type_111 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_13, torch.float32), kwargs = {})
# %pow_8 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_111, 2), kwargs = {})
# %mean_7 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_8, [-1], True), kwargs = {})
# %add_14 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_7, 1e-05), kwargs = {})
# %rsqrt_7 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_14,), kwargs = {})
# %mul_28 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_111, %rsqrt_7), kwargs = {})
# %convert_element_type_112 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_28, torch.bfloat16), kwargs = {})
# %mul_29 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_112, %arg33_1), kwargs = {})
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 6, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32)
tmp7 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32)
tmp21 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp2 = tmp0 + tmp1
tmp4 = tmp2 + tmp3
tmp6 = tmp4 + tmp5
tmp8 = tmp6 + tmp7
tmp9 = tmp8.to(tl.float32)
tmp10 = tmp9 * tmp9
tmp11 = tl.broadcast_to(tmp10, [RBLOCK])
tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0))
tmp14 = 512.0
tmp15 = (tmp13 / tmp14).to(tl.float32)
tmp16 = 1e-05
tmp17 = tmp15 + tmp16
tmp18 = libdevice.rsqrt(tmp17)
tmp19 = tmp9 * tmp18
tmp20 = tmp19.to(tl.float32)
tmp22 = tmp20 * tmp21
tl.store(in_out_ptr0 + (r1 + 512*x0), tmp8, None)
tl.store(out_ptr1 + (r1 + 512*x0), tmp22, None)
''', device_str='cuda')
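# NOTE (annotation, not compiler output): this variant sums five inputs and,
# unlike _11/_12, also writes the accumulated residual (tmp8) back through
# in_out_ptr0 (see mutated_arg_names). Materializing the running sum here lets
# later layers load a single tensor instead of re-adding the whole chain.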
# kernel path: /tmp/torchinductor_t/lv/clv7iqdidc6yovjvhgzra2tlgehvzuyjf6c4drcdk3cgzbb7tbn4.py
# Topologically Sorted Source Nodes: [out_7, float_41, pow_17, mean_16, add_32, rsqrt_16, mul_56, output_32, h_9], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_32 => add_32
# float_41 => convert_element_type_241
# h_9 => mul_65
# mean_16 => mean_16
# mul_56 => mul_64
# out_7 => add_31
# output_32 => convert_element_type_242
# pow_17 => pow_17
# rsqrt_16 => rsqrt_16
# Graph fragment:
# %add_31 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_29, %view_231), kwargs = {})
# %convert_element_type_241 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_31, torch.float32), kwargs = {})
# %pow_17 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_241, 2), kwargs = {})
# %mean_16 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_17, [-1], True), kwargs = {})
# %add_32 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_16, 1e-05), kwargs = {})
# %rsqrt_16 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_32,), kwargs = {})
# %mul_64 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_241, %rsqrt_16), kwargs = {})
# %convert_element_type_242 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_64, torch.bfloat16), kwargs = {})
# %mul_65 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_242, %arg73_1), kwargs = {})
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1024, 'r': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14(in_out_ptr0, in_ptr0, in_ptr1, xnumel, rnumel):
xnumel = 1024
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = tl.full([RBLOCK], True, tl.int1)
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32)
tmp15 = tl.load(in_ptr1 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp2 = tmp0 + tmp1
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp3 * tmp3
tmp5 = tl.broadcast_to(tmp4, [RBLOCK])
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
tmp8 = 512.0
tmp9 = (tmp7 / tmp8).to(tl.float32)
tmp10 = 1e-05
tmp11 = tmp9 + tmp10
tmp12 = libdevice.rsqrt(tmp11)
tmp13 = tmp3 * tmp12
tmp14 = tmp13.to(tl.float32)
tmp16 = tmp14 * tmp15
tl.store(in_out_ptr0 + (r1 + 512*x0), tmp16, None)
''', device_str='cuda')
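# NOTE (annotation, not compiler output): the final-norm variant: one residual
# add (out_7), the same RMSNorm, and the scaled result stored in place over
# in_out_ptr0, since the pre-norm activation is not needed afterwards.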
# kernel path: /tmp/torchinductor_t/uy/cuyzkjsgklqarsm2yt3x6ruv2di7jsrafhofrtw3gdor4jvuttkh.py
# Topologically Sorted Source Nodes: [float_42], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# float_42 => convert_element_type_245
# Graph fragment:
# %convert_element_type_245 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_56, torch.float32), kwargs = {})
triton_poi_fused__to_copy_15 = async_compile.triton('triton_poi_fused__to_copy_15', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 1048576},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_15', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_15(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1024000
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
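# NOTE (annotation, not compiler output): a plain elementwise upcast of the
# (32, 32000) bf16 logits to fp32; equivalent to buf263 = buf262.to(torch.float32).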
async_compile.wait(globals())
del async_compile
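# NOTE (annotation, not compiler output): judging from the buffer shapes below,
# `call` runs an 8-block decoder forward pass: token embedding fused with
# RMSNorm, then per block a packed QKV GEMM (512 -> 1536), rotary embedding,
# flash attention against preallocated KV caches (arg76_1..arg91_1, each
# (32, 1024, 8, 64)), an output projection, and a SwiGLU MLP, finishing with a
# final RMSNorm and a 512 -> 32000 logits projection on the last token of each
# sequence.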
def call(args):
arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1 = args
args.clear()
assert_size_stride(arg76_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg77_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg78_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg79_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg80_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg81_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg82_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg83_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg84_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg85_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg86_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg87_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg88_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg89_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg90_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg91_1, (32, 1024, 8, 64), (524288, 512, 64, 1))
assert_size_stride(arg92_1, (32, 32), (32, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf1 = empty_strided_cuda((32, 32, 512), (16384, 512, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [h, float_1, pow_1, mean, add, rsqrt, mul, output, mul_1], Original ATen: [aten.embedding, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0.run(arg92_1, _frozen_param0, _frozen_param1, buf1, 1024, 512, grid=grid(1024), stream=stream0)
buf2 = empty_strided_cuda((1024, 1536), (1536, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf1, (1024, 512), (512, 1), 0), _frozen_param135, out=buf2)
# Topologically Sorted Source Nodes: [setitem_1], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf2, arg77_1, 524288, grid=grid(524288), stream=stream0)
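# NOTE (annotation, not compiler output): the setitem/copy kernels scatter the
# freshly projected V rows (and, below, the rotated K rows) into this block's
# preallocated cache tensors (arg77_1 and arg76_1), which the flash-attention
# call then reads over 33 cached positions.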
buf3 = empty_strided_cuda((32, 32, 8, 32, 2), (16384, 512, 64, 2, 1), torch.float32)
# Topologically Sorted Source Nodes: [xq_], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf2, buf3, 524288, grid=grid(524288), stream=stream0)
buf10 = empty_strided_cuda((32, 32, 8, 32, 2), (16384, 512, 64, 2, 1), torch.float32)
# Topologically Sorted Source Nodes: [xk_], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf2, buf10, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq_], Original ATen: [aten.view_as_complex]
buf4 = torch.ops.aten.view_as_complex.default(buf3)
buf5 = buf4
# Topologically Sorted Source Nodes: [mul_2], Original ATen: [aten.mul]
buf6 = torch.ops.aten.mul.Tensor(buf5, _frozen_param79)
del buf4
del buf5
buf7 = buf6
del buf6
# Topologically Sorted Source Nodes: [view_as_real], Original ATen: [aten.view_as_real]
buf8 = torch.ops.aten.view_as_real.default(buf7)
buf9 = buf8
buf19 = reinterpret_tensor(buf1, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf1 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf9, buf19, 524288, grid=grid(524288), stream=stream0)
del buf7
del buf8
del buf9
# Topologically Sorted Source Nodes: [xk_], Original ATen: [aten.view_as_complex]
buf11 = torch.ops.aten.view_as_complex.default(buf10)
buf12 = buf11
# Topologically Sorted Source Nodes: [mul_3], Original ATen: [aten.mul]
buf13 = torch.ops.aten.mul.Tensor(buf12, _frozen_param79)
del buf11
del buf12
buf14 = buf13
del buf13
# Topologically Sorted Source Nodes: [view_as_real_1], Original ATen: [aten.view_as_real]
buf15 = torch.ops.aten.view_as_real.default(buf14)
buf16 = buf15
# Topologically Sorted Source Nodes: [xk_2, setitem], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf16, arg76_1, 524288, grid=grid(524288), stream=stream0)
del buf14
del buf15
del buf16
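# NOTE (annotation, not compiler output): the view_as_complex / mul /
# view_as_real sequence above is rotary position embedding applied to the query
# and key projections. A hedged eager sketch of the same step, with illustrative
# names (`freqs_cis` plays the role of _frozen_param79):
#
#   xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
#   xq_rot = torch.view_as_real(xq_c * freqs_cis).flatten(-2).type_as(xq)
#
# The same rotation is applied to xk, whose result is then written into the
# K cache (arg76_1) by the setitem kernel above.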
# Topologically Sorted Source Nodes: [], Original ATen: []
buf20 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf19, reinterpret_tensor(arg76_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg77_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf21 = buf20[0]
del buf20
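# NOTE (annotation, not compiler output): scale=0.125 is 1/sqrt(head_dim) with
# head_dim=64. _scaled_dot_product_flash_attention returns a tuple; only the
# attention output at index [0] is kept, and the auxiliary outputs (logsumexp,
# RNG state, etc.) are dropped.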
buf26 = reinterpret_tensor(buf19, (1024, 512), (512, 1), 0); del buf19 # reuse
# Topologically Sorted Source Nodes: [linear_3], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf21, (1024, 512), (512, 1), 0), _frozen_param80, out=buf26)
buf28 = reinterpret_tensor(buf21, (32, 32, 512), (16384, 512, 1), 0); del buf21 # reuse
# Topologically Sorted Source Nodes: [h, h_1, float_5, pow_2, mean_1, add_2, rsqrt_1, mul_4, output_3, mul_5], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6.run(arg92_1, _frozen_param0, buf26, _frozen_param6, buf28, 1024, 512, grid=grid(1024), stream=stream0)
buf29 = empty_strided_cuda((1024, 3072), (3072, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf28, (1024, 512), (512, 1), 0), _frozen_param136, out=buf29)
buf30 = reinterpret_tensor(buf2, (32, 32, 1536), (49152, 1536, 1), 0); del buf2 # reuse
# Topologically Sorted Source Nodes: [silu, mul_6], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf29, buf30, 1572864, grid=grid(1572864), stream=stream0)
buf31 = reinterpret_tensor(buf28, (1024, 512), (512, 1), 0); del buf28 # reuse
# Topologically Sorted Source Nodes: [linear_6], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf30, (1024, 1536), (1536, 1), 0), _frozen_param83, out=buf31)
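# NOTE (annotation, not compiler output): this block is a SwiGLU MLP. The shapes
# suggest _frozen_param136 (512 x 3072) packs the gate and up projections into a
# single GEMM; triton_poi_fused_mul_silu_7 then forms silu(gate) * up into the
# 1536-wide buf30, and linear_6 projects back down to 512. Eager sketch with
# illustrative names: w2(F.silu(w1(x)) * w3(x)).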
buf33 = empty_strided_cuda((32, 32, 512), (16384, 512, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [h, h_1, out, float_6, pow_3, mean_2, add_4, rsqrt_2, mul_7, output_4, mul_8], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8.run(arg92_1, _frozen_param0, buf26, buf31, _frozen_param10, buf33, 1024, 512, grid=grid(1024), stream=stream0)
buf34 = reinterpret_tensor(buf30, (1024, 1536), (1536, 1), 0); del buf30 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf33, (1024, 512), (512, 1), 0), _frozen_param137, out=buf34)
# Topologically Sorted Source Nodes: [setitem_3], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf34, arg79_1, 524288, grid=grid(524288), stream=stream0)
buf35 = buf10; del buf10 # reuse
# Topologically Sorted Source Nodes: [xq__1], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf34, buf35, 524288, grid=grid(524288), stream=stream0)
buf42 = buf3; del buf3 # reuse
# Topologically Sorted Source Nodes: [xk__1], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf34, buf42, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__1], Original ATen: [aten.view_as_complex]
buf36 = torch.ops.aten.view_as_complex.default(buf35)
buf37 = buf36
# Topologically Sorted Source Nodes: [mul_9], Original ATen: [aten.mul]
buf38 = torch.ops.aten.mul.Tensor(buf37, _frozen_param79)
del buf36
del buf37
buf39 = buf38
del buf38
# Topologically Sorted Source Nodes: [view_as_real_2], Original ATen: [aten.view_as_real]
buf40 = torch.ops.aten.view_as_real.default(buf39)
buf41 = buf40
buf51 = reinterpret_tensor(buf33, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf33 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf41, buf51, 524288, grid=grid(524288), stream=stream0)
del buf39
del buf40
del buf41
# Topologically Sorted Source Nodes: [xk__1], Original ATen: [aten.view_as_complex]
buf43 = torch.ops.aten.view_as_complex.default(buf42)
buf44 = buf43
# Topologically Sorted Source Nodes: [mul_10], Original ATen: [aten.mul]
buf45 = torch.ops.aten.mul.Tensor(buf44, _frozen_param79)
del buf43
del buf44
buf46 = buf45
del buf45
# Topologically Sorted Source Nodes: [view_as_real_3], Original ATen: [aten.view_as_real]
buf47 = torch.ops.aten.view_as_real.default(buf46)
buf48 = buf47
# Topologically Sorted Source Nodes: [xk_5, setitem_2], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf48, arg78_1, 524288, grid=grid(524288), stream=stream0)
del buf46
del buf47
del buf48
# Topologically Sorted Source Nodes: [], Original ATen: []
buf52 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf51, reinterpret_tensor(arg78_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg79_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf53 = buf52[0]
del buf52
buf58 = reinterpret_tensor(buf51, (1024, 512), (512, 1), 0); del buf51 # reuse
# Topologically Sorted Source Nodes: [linear_10], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf53, (1024, 512), (512, 1), 0), _frozen_param87, out=buf58)
buf59 = reinterpret_tensor(buf26, (32, 32, 512), (16384, 512, 1), 0); del buf26 # reuse
buf61 = reinterpret_tensor(buf53, (32, 32, 512), (16384, 512, 1), 0); del buf53 # reuse
# Topologically Sorted Source Nodes: [h, h_1, out, h_2, float_10, pow_4, mean_3, add_6, rsqrt_3, mul_11, output_7, mul_12], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9.run(buf59, arg92_1, _frozen_param0, buf31, buf58, _frozen_param15, buf61, 1024, 512, grid=grid(1024), stream=stream0)
del arg92_1
buf62 = buf29; del buf29 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf61, (1024, 512), (512, 1), 0), _frozen_param138, out=buf62)
buf63 = reinterpret_tensor(buf34, (32, 32, 1536), (49152, 1536, 1), 0); del buf34 # reuse
# Topologically Sorted Source Nodes: [silu_1, mul_13], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf62, buf63, 1572864, grid=grid(1572864), stream=stream0)
buf64 = reinterpret_tensor(buf61, (1024, 512), (512, 1), 0); del buf61 # reuse
# Topologically Sorted Source Nodes: [linear_13], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf63, (1024, 1536), (1536, 1), 0), _frozen_param90, out=buf64)
buf66 = reinterpret_tensor(buf58, (32, 32, 512), (16384, 512, 1), 0); del buf58 # reuse
# Topologically Sorted Source Nodes: [out_1, float_11, pow_5, mean_4, add_8, rsqrt_4, mul_14, output_8, mul_15], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf59, buf64, _frozen_param19, buf66, 1024, 512, grid=grid(1024), stream=stream0)
buf67 = reinterpret_tensor(buf63, (1024, 1536), (1536, 1), 0); del buf63 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf66, (1024, 512), (512, 1), 0), _frozen_param139, out=buf67)
# Topologically Sorted Source Nodes: [setitem_5], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf67, arg81_1, 524288, grid=grid(524288), stream=stream0)
buf68 = buf42; del buf42 # reuse
# Topologically Sorted Source Nodes: [xq__2], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf67, buf68, 524288, grid=grid(524288), stream=stream0)
buf75 = buf35; del buf35 # reuse
# Topologically Sorted Source Nodes: [xk__2], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf67, buf75, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__2], Original ATen: [aten.view_as_complex]
buf69 = torch.ops.aten.view_as_complex.default(buf68)
buf70 = buf69
# Topologically Sorted Source Nodes: [mul_16], Original ATen: [aten.mul]
buf71 = torch.ops.aten.mul.Tensor(buf70, _frozen_param79)
del buf69
del buf70
buf72 = buf71
del buf71
# Topologically Sorted Source Nodes: [view_as_real_4], Original ATen: [aten.view_as_real]
buf73 = torch.ops.aten.view_as_real.default(buf72)
buf74 = buf73
buf84 = reinterpret_tensor(buf66, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf66 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf74, buf84, 524288, grid=grid(524288), stream=stream0)
del buf72
del buf73
del buf74
# Topologically Sorted Source Nodes: [xk__2], Original ATen: [aten.view_as_complex]
buf76 = torch.ops.aten.view_as_complex.default(buf75)
buf77 = buf76
# Topologically Sorted Source Nodes: [mul_17], Original ATen: [aten.mul]
buf78 = torch.ops.aten.mul.Tensor(buf77, _frozen_param79)
del buf76
del buf77
buf79 = buf78
del buf78
# Topologically Sorted Source Nodes: [view_as_real_5], Original ATen: [aten.view_as_real]
buf80 = torch.ops.aten.view_as_real.default(buf79)
buf81 = buf80
# Topologically Sorted Source Nodes: [xk_8, setitem_4], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf81, arg80_1, 524288, grid=grid(524288), stream=stream0)
del buf79
del buf80
del buf81
# Topologically Sorted Source Nodes: [], Original ATen: []
buf85 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf84, reinterpret_tensor(arg80_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg81_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf86 = buf85[0]
del buf85
buf91 = reinterpret_tensor(buf84, (1024, 512), (512, 1), 0); del buf84 # reuse
# Topologically Sorted Source Nodes: [linear_17], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf86, (1024, 512), (512, 1), 0), _frozen_param94, out=buf91)
buf93 = reinterpret_tensor(buf86, (32, 32, 512), (16384, 512, 1), 0); del buf86 # reuse
# Topologically Sorted Source Nodes: [out_1, h_3, float_15, pow_6, mean_5, add_10, rsqrt_5, mul_18, output_11, mul_19], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf59, buf64, buf91, _frozen_param24, buf93, 1024, 512, grid=grid(1024), stream=stream0)
buf94 = buf62; del buf62 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf93, (1024, 512), (512, 1), 0), _frozen_param140, out=buf94)
buf95 = reinterpret_tensor(buf67, (32, 32, 1536), (49152, 1536, 1), 0); del buf67 # reuse
# Topologically Sorted Source Nodes: [silu_2, mul_20], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf94, buf95, 1572864, grid=grid(1572864), stream=stream0)
buf96 = reinterpret_tensor(buf93, (1024, 512), (512, 1), 0); del buf93 # reuse
# Topologically Sorted Source Nodes: [linear_20], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf95, (1024, 1536), (1536, 1), 0), _frozen_param97, out=buf96)
buf98 = reinterpret_tensor(buf31, (32, 32, 512), (16384, 512, 1), 0); del buf31 # reuse
# Topologically Sorted Source Nodes: [out_1, h_3, out_2, float_16, pow_7, mean_6, add_12, rsqrt_6, mul_21, output_12, mul_22], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf59, buf64, buf91, buf96, _frozen_param28, buf98, 1024, 512, grid=grid(1024), stream=stream0)
buf99 = reinterpret_tensor(buf95, (1024, 1536), (1536, 1), 0); del buf95 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf98, (1024, 512), (512, 1), 0), _frozen_param141, out=buf99)
# Topologically Sorted Source Nodes: [setitem_7], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf99, arg83_1, 524288, grid=grid(524288), stream=stream0)
buf100 = buf75; del buf75 # reuse
# Topologically Sorted Source Nodes: [xq__3], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf99, buf100, 524288, grid=grid(524288), stream=stream0)
buf107 = buf68; del buf68 # reuse
# Topologically Sorted Source Nodes: [xk__3], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf99, buf107, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__3], Original ATen: [aten.view_as_complex]
buf101 = torch.ops.aten.view_as_complex.default(buf100)
buf102 = buf101
# Topologically Sorted Source Nodes: [mul_23], Original ATen: [aten.mul]
buf103 = torch.ops.aten.mul.Tensor(buf102, _frozen_param79)
del buf101
del buf102
buf104 = buf103
del buf103
# Topologically Sorted Source Nodes: [view_as_real_6], Original ATen: [aten.view_as_real]
buf105 = torch.ops.aten.view_as_real.default(buf104)
buf106 = buf105
buf116 = reinterpret_tensor(buf98, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf98 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf106, buf116, 524288, grid=grid(524288), stream=stream0)
del buf104
del buf105
del buf106
# Topologically Sorted Source Nodes: [xk__3], Original ATen: [aten.view_as_complex]
buf108 = torch.ops.aten.view_as_complex.default(buf107)
buf109 = buf108
# Topologically Sorted Source Nodes: [mul_24], Original ATen: [aten.mul]
buf110 = torch.ops.aten.mul.Tensor(buf109, _frozen_param79)
del buf108
del buf109
buf111 = buf110
del buf110
# Topologically Sorted Source Nodes: [view_as_real_7], Original ATen: [aten.view_as_real]
buf112 = torch.ops.aten.view_as_real.default(buf111)
buf113 = buf112
# Topologically Sorted Source Nodes: [xk_11, setitem_6], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf113, arg82_1, 524288, grid=grid(524288), stream=stream0)
del buf111
del buf112
del buf113
# Topologically Sorted Source Nodes: [], Original ATen: []
buf117 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf116, reinterpret_tensor(arg82_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg83_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf118 = buf117[0]
del buf117
buf123 = reinterpret_tensor(buf116, (1024, 512), (512, 1), 0); del buf116 # reuse
# Topologically Sorted Source Nodes: [linear_24], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf118, (1024, 512), (512, 1), 0), _frozen_param101, out=buf123)
buf124 = buf59; del buf59 # reuse
buf126 = reinterpret_tensor(buf118, (32, 32, 512), (16384, 512, 1), 0); del buf118 # reuse
# Topologically Sorted Source Nodes: [out_1, h_3, out_2, h_4, float_20, pow_8, mean_7, add_14, rsqrt_7, mul_25, output_15, mul_26], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf124, buf64, buf91, buf96, buf123, _frozen_param33, buf126, 1024, 512, grid=grid(1024), stream=stream0)
del buf123
del buf64
buf127 = buf94; del buf94 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf126, (1024, 512), (512, 1), 0), _frozen_param142, out=buf127)
buf128 = reinterpret_tensor(buf99, (32, 32, 1536), (49152, 1536, 1), 0); del buf99 # reuse
# Topologically Sorted Source Nodes: [silu_3, mul_27], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf127, buf128, 1572864, grid=grid(1572864), stream=stream0)
buf129 = reinterpret_tensor(buf126, (1024, 512), (512, 1), 0); del buf126 # reuse
# Topologically Sorted Source Nodes: [linear_27], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf128, (1024, 1536), (1536, 1), 0), _frozen_param104, out=buf129)
buf131 = reinterpret_tensor(buf96, (32, 32, 512), (16384, 512, 1), 0); del buf96 # reuse
# Topologically Sorted Source Nodes: [out_3, float_21, pow_9, mean_8, add_16, rsqrt_8, mul_28, output_16, mul_29], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf124, buf129, _frozen_param37, buf131, 1024, 512, grid=grid(1024), stream=stream0)
buf132 = reinterpret_tensor(buf128, (1024, 1536), (1536, 1), 0); del buf128 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf131, (1024, 512), (512, 1), 0), _frozen_param143, out=buf132)
# Topologically Sorted Source Nodes: [setitem_9], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf132, arg85_1, 524288, grid=grid(524288), stream=stream0)
buf133 = buf107; del buf107 # reuse
# Topologically Sorted Source Nodes: [xq__4], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf132, buf133, 524288, grid=grid(524288), stream=stream0)
buf140 = buf100; del buf100 # reuse
# Topologically Sorted Source Nodes: [xk__4], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf132, buf140, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__4], Original ATen: [aten.view_as_complex]
buf134 = torch.ops.aten.view_as_complex.default(buf133)
buf135 = buf134
# Topologically Sorted Source Nodes: [mul_30], Original ATen: [aten.mul]
buf136 = torch.ops.aten.mul.Tensor(buf135, _frozen_param79)
del buf134
del buf135
buf137 = buf136
del buf136
# Topologically Sorted Source Nodes: [view_as_real_8], Original ATen: [aten.view_as_real]
buf138 = torch.ops.aten.view_as_real.default(buf137)
buf139 = buf138
buf149 = reinterpret_tensor(buf131, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf131 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf139, buf149, 524288, grid=grid(524288), stream=stream0)
del buf137
del buf138
del buf139
# Topologically Sorted Source Nodes: [xk__4], Original ATen: [aten.view_as_complex]
buf141 = torch.ops.aten.view_as_complex.default(buf140)
buf142 = buf141
# Topologically Sorted Source Nodes: [mul_31], Original ATen: [aten.mul]
buf143 = torch.ops.aten.mul.Tensor(buf142, _frozen_param79)
del buf141
del buf142
buf144 = buf143
del buf143
# Topologically Sorted Source Nodes: [view_as_real_9], Original ATen: [aten.view_as_real]
buf145 = torch.ops.aten.view_as_real.default(buf144)
buf146 = buf145
# Topologically Sorted Source Nodes: [xk_14, setitem_8], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf146, arg84_1, 524288, grid=grid(524288), stream=stream0)
del buf144
del buf145
del buf146
# Topologically Sorted Source Nodes: [], Original ATen: []
buf150 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf149, reinterpret_tensor(arg84_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg85_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf151 = buf150[0]
del buf150
buf156 = reinterpret_tensor(buf149, (1024, 512), (512, 1), 0); del buf149 # reuse
# Topologically Sorted Source Nodes: [linear_31], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf151, (1024, 512), (512, 1), 0), _frozen_param108, out=buf156)
buf158 = reinterpret_tensor(buf151, (32, 32, 512), (16384, 512, 1), 0); del buf151 # reuse
# Topologically Sorted Source Nodes: [out_3, h_5, float_25, pow_10, mean_9, add_18, rsqrt_9, mul_32, output_19, mul_33], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf124, buf129, buf156, _frozen_param42, buf158, 1024, 512, grid=grid(1024), stream=stream0)
buf159 = buf127; del buf127 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf158, (1024, 512), (512, 1), 0), _frozen_param144, out=buf159)
buf160 = reinterpret_tensor(buf132, (32, 32, 1536), (49152, 1536, 1), 0); del buf132 # reuse
# Topologically Sorted Source Nodes: [silu_4, mul_34], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf159, buf160, 1572864, grid=grid(1572864), stream=stream0)
buf161 = reinterpret_tensor(buf158, (1024, 512), (512, 1), 0); del buf158 # reuse
# Topologically Sorted Source Nodes: [linear_34], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf160, (1024, 1536), (1536, 1), 0), _frozen_param111, out=buf161)
buf163 = reinterpret_tensor(buf91, (32, 32, 512), (16384, 512, 1), 0); del buf91 # reuse
# Topologically Sorted Source Nodes: [out_3, h_5, out_4, float_26, pow_11, mean_10, add_20, rsqrt_10, mul_35, output_20, mul_36], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf124, buf129, buf156, buf161, _frozen_param46, buf163, 1024, 512, grid=grid(1024), stream=stream0)
buf164 = reinterpret_tensor(buf160, (1024, 1536), (1536, 1), 0); del buf160 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf163, (1024, 512), (512, 1), 0), _frozen_param145, out=buf164)
# Topologically Sorted Source Nodes: [setitem_11], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf164, arg87_1, 524288, grid=grid(524288), stream=stream0)
buf165 = buf140; del buf140 # reuse
# Topologically Sorted Source Nodes: [xq__5], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf164, buf165, 524288, grid=grid(524288), stream=stream0)
buf172 = buf133; del buf133 # reuse
# Topologically Sorted Source Nodes: [xk__5], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf164, buf172, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__5], Original ATen: [aten.view_as_complex]
buf166 = torch.ops.aten.view_as_complex.default(buf165)
buf167 = buf166
# Topologically Sorted Source Nodes: [mul_37], Original ATen: [aten.mul]
buf168 = torch.ops.aten.mul.Tensor(buf167, _frozen_param79)
del buf166
del buf167
buf169 = buf168
del buf168
# Topologically Sorted Source Nodes: [view_as_real_10], Original ATen: [aten.view_as_real]
buf170 = torch.ops.aten.view_as_real.default(buf169)
buf171 = buf170
buf181 = reinterpret_tensor(buf163, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf163 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf171, buf181, 524288, grid=grid(524288), stream=stream0)
del buf169
del buf170
del buf171
# Topologically Sorted Source Nodes: [xk__5], Original ATen: [aten.view_as_complex]
buf173 = torch.ops.aten.view_as_complex.default(buf172)
buf174 = buf173
# Topologically Sorted Source Nodes: [mul_38], Original ATen: [aten.mul]
buf175 = torch.ops.aten.mul.Tensor(buf174, _frozen_param79)
del buf173
del buf174
buf176 = buf175
del buf175
# Topologically Sorted Source Nodes: [view_as_real_11], Original ATen: [aten.view_as_real]
buf177 = torch.ops.aten.view_as_real.default(buf176)
buf178 = buf177
# Topologically Sorted Source Nodes: [xk_17, setitem_10], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf178, arg86_1, 524288, grid=grid(524288), stream=stream0)
del buf176
del buf177
del buf178
# Topologically Sorted Source Nodes: [], Original ATen: []
buf182 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf181, reinterpret_tensor(arg86_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg87_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf183 = buf182[0]
del buf182
buf188 = reinterpret_tensor(buf181, (1024, 512), (512, 1), 0); del buf181 # reuse
# Topologically Sorted Source Nodes: [linear_38], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf183, (1024, 512), (512, 1), 0), _frozen_param115, out=buf188)
buf189 = buf124; del buf124 # reuse
buf191 = reinterpret_tensor(buf183, (32, 32, 512), (16384, 512, 1), 0); del buf183 # reuse
# Topologically Sorted Source Nodes: [out_3, h_5, out_4, h_6, float_30, pow_12, mean_11, add_22, rsqrt_11, mul_39, output_23, mul_40], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf189, buf129, buf156, buf161, buf188, _frozen_param51, buf191, 1024, 512, grid=grid(1024), stream=stream0)
del buf129
del buf156
buf192 = buf159; del buf159 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf191, (1024, 512), (512, 1), 0), _frozen_param146, out=buf192)
buf193 = reinterpret_tensor(buf164, (32, 32, 1536), (49152, 1536, 1), 0); del buf164 # reuse
# Topologically Sorted Source Nodes: [silu_5, mul_41], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf192, buf193, 1572864, grid=grid(1572864), stream=stream0)
buf194 = reinterpret_tensor(buf191, (1024, 512), (512, 1), 0); del buf191 # reuse
# Topologically Sorted Source Nodes: [linear_41], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf193, (1024, 1536), (1536, 1), 0), _frozen_param118, out=buf194)
buf196 = reinterpret_tensor(buf188, (32, 32, 512), (16384, 512, 1), 0); del buf188 # reuse
# Topologically Sorted Source Nodes: [out_5, float_31, pow_13, mean_12, add_24, rsqrt_12, mul_42, output_24, mul_43], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf189, buf194, _frozen_param55, buf196, 1024, 512, grid=grid(1024), stream=stream0)
buf197 = reinterpret_tensor(buf193, (1024, 1536), (1536, 1), 0); del buf193 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf196, (1024, 512), (512, 1), 0), _frozen_param147, out=buf197)
# Topologically Sorted Source Nodes: [setitem_13], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf197, arg89_1, 524288, grid=grid(524288), stream=stream0)
buf198 = buf172; del buf172 # reuse
# Topologically Sorted Source Nodes: [xq__6], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf197, buf198, 524288, grid=grid(524288), stream=stream0)
buf205 = buf165; del buf165 # reuse
# Topologically Sorted Source Nodes: [xk__6], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf197, buf205, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__6], Original ATen: [aten.view_as_complex]
buf199 = torch.ops.aten.view_as_complex.default(buf198)
buf200 = buf199
# Topologically Sorted Source Nodes: [mul_44], Original ATen: [aten.mul]
buf201 = torch.ops.aten.mul.Tensor(buf200, _frozen_param79)
del buf199
del buf200
buf202 = buf201
del buf201
# Topologically Sorted Source Nodes: [view_as_real_12], Original ATen: [aten.view_as_real]
buf203 = torch.ops.aten.view_as_real.default(buf202)
buf204 = buf203
buf214 = reinterpret_tensor(buf196, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf196 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf204, buf214, 524288, grid=grid(524288), stream=stream0)
del buf202
del buf203
del buf204
# Topologically Sorted Source Nodes: [xk__6], Original ATen: [aten.view_as_complex]
buf206 = torch.ops.aten.view_as_complex.default(buf205)
buf207 = buf206
# Topologically Sorted Source Nodes: [mul_45], Original ATen: [aten.mul]
buf208 = torch.ops.aten.mul.Tensor(buf207, _frozen_param79)
del buf206
del buf207
buf209 = buf208
del buf208
# Topologically Sorted Source Nodes: [view_as_real_13], Original ATen: [aten.view_as_real]
buf210 = torch.ops.aten.view_as_real.default(buf209)
buf211 = buf210
# Topologically Sorted Source Nodes: [xk_20, setitem_12], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf211, arg88_1, 524288, grid=grid(524288), stream=stream0)
del buf209
del buf210
del buf211
# Topologically Sorted Source Nodes: [], Original ATen: []
buf215 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf214, reinterpret_tensor(arg88_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg89_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf216 = buf215[0]
del buf215
buf221 = reinterpret_tensor(buf214, (1024, 512), (512, 1), 0); del buf214 # reuse
# Topologically Sorted Source Nodes: [linear_45], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf216, (1024, 512), (512, 1), 0), _frozen_param122, out=buf221)
buf223 = reinterpret_tensor(buf216, (32, 32, 512), (16384, 512, 1), 0); del buf216 # reuse
# Topologically Sorted Source Nodes: [out_5, h_7, float_35, pow_14, mean_13, add_26, rsqrt_13, mul_46, output_27, mul_47], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf189, buf194, buf221, _frozen_param60, buf223, 1024, 512, grid=grid(1024), stream=stream0)
buf224 = buf192; del buf192 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf223, (1024, 512), (512, 1), 0), _frozen_param148, out=buf224)
buf225 = reinterpret_tensor(buf197, (32, 32, 1536), (49152, 1536, 1), 0); del buf197 # reuse
# Topologically Sorted Source Nodes: [silu_6, mul_48], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf224, buf225, 1572864, grid=grid(1572864), stream=stream0)
buf226 = reinterpret_tensor(buf223, (1024, 512), (512, 1), 0); del buf223 # reuse
# Topologically Sorted Source Nodes: [linear_48], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf225, (1024, 1536), (1536, 1), 0), _frozen_param125, out=buf226)
buf228 = reinterpret_tensor(buf161, (32, 32, 512), (16384, 512, 1), 0); del buf161 # reuse
# Topologically Sorted Source Nodes: [out_5, h_7, out_6, float_36, pow_15, mean_14, add_28, rsqrt_14, mul_49, output_28, mul_50], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf189, buf194, buf221, buf226, _frozen_param64, buf228, 1024, 512, grid=grid(1024), stream=stream0)
buf229 = reinterpret_tensor(buf225, (1024, 1536), (1536, 1), 0); del buf225 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf228, (1024, 512), (512, 1), 0), _frozen_param149, out=buf229)
# Topologically Sorted Source Nodes: [setitem_15], Original ATen: [aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused_copy_1.run(buf229, arg91_1, 524288, grid=grid(524288), stream=stream0)
buf230 = buf205; del buf205 # reuse
# Topologically Sorted Source Nodes: [xq__7], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_2.run(buf229, buf230, 524288, grid=grid(524288), stream=stream0)
buf237 = buf198; del buf198 # reuse
# Topologically Sorted Source Nodes: [xk__7], Original ATen: [aten.view_as_complex]
stream0 = get_raw_stream(0)
triton_poi_fused_view_as_complex_3.run(buf229, buf237, 524288, grid=grid(524288), stream=stream0)
# Topologically Sorted Source Nodes: [xq__7], Original ATen: [aten.view_as_complex]
buf231 = torch.ops.aten.view_as_complex.default(buf230)
buf232 = buf231
# Topologically Sorted Source Nodes: [mul_51], Original ATen: [aten.mul]
buf233 = torch.ops.aten.mul.Tensor(buf232, _frozen_param79)
del buf230
del buf231
del buf232
buf234 = buf233
del buf233
# Topologically Sorted Source Nodes: [view_as_real_14], Original ATen: [aten.view_as_real]
buf235 = torch.ops.aten.view_as_real.default(buf234)
buf236 = buf235
buf246 = reinterpret_tensor(buf228, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf228 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
stream0 = get_raw_stream(0)
triton_poi_fused_4.run(buf236, buf246, 524288, grid=grid(524288), stream=stream0)
del buf234
del buf235
del buf236
# Topologically Sorted Source Nodes: [xk__7], Original ATen: [aten.view_as_complex]
buf238 = torch.ops.aten.view_as_complex.default(buf237)
buf239 = buf238
# Topologically Sorted Source Nodes: [mul_52], Original ATen: [aten.mul]
buf240 = torch.ops.aten.mul.Tensor(buf239, _frozen_param79)
del buf237
del buf238
del buf239
buf241 = buf240
del buf240
# Topologically Sorted Source Nodes: [view_as_real_15], Original ATen: [aten.view_as_real]
buf242 = torch.ops.aten.view_as_real.default(buf241)
buf243 = buf242
# Topologically Sorted Source Nodes: [xk_23, setitem_14], Original ATen: [aten._to_copy, aten.copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_copy_5.run(buf243, arg90_1, 524288, grid=grid(524288), stream=stream0)
del buf241
del buf242
del buf243
# Topologically Sorted Source Nodes: [], Original ATen: []
buf247 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf246, reinterpret_tensor(arg90_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg91_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125)
buf248 = buf247[0]
del buf247
buf253 = reinterpret_tensor(buf246, (1024, 512), (512, 1), 0); del buf246 # reuse
# Topologically Sorted Source Nodes: [linear_52], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf248, (1024, 512), (512, 1), 0), _frozen_param129, out=buf253)
buf254 = buf189; del buf189 # reuse
buf256 = reinterpret_tensor(buf248, (32, 32, 512), (16384, 512, 1), 0); del buf248 # reuse
# Topologically Sorted Source Nodes: [out_5, h_7, out_6, h_8, float_40, pow_16, mean_15, add_30, rsqrt_15, mul_53, output_31, mul_54], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf254, buf194, buf221, buf226, buf253, _frozen_param69, buf256, 1024, 512, grid=grid(1024), stream=stream0)
del buf194
del buf221
del buf226
del buf253
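# This fused kernel appears to fold the layer's full residual stream
# (the out_5 / h_7 / out_6 / h_8 adds, i.e. the attention output buf253
# added onto the accumulated residuals in buf194/buf221/buf226) together
# with the pre-FFN RMSNorm scaled by _frozen_param69, writing the updated
# residual into buf254 and the normalized activations into buf256.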
buf257 = buf224; del buf224 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf256, (1024, 512), (512, 1), 0), _frozen_param150, out=buf257)
buf258 = reinterpret_tensor(buf229, (32, 32, 1536), (49152, 1536, 1), 0); del buf229 # reuse
# Topologically Sorted Source Nodes: [silu_7, mul_55], Original ATen: [aten.silu, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_silu_7.run(buf257, buf258, 1572864, grid=grid(1572864), stream=stream0)
del buf257
buf259 = reinterpret_tensor(buf256, (1024, 512), (512, 1), 0); del buf256 # reuse
# Topologically Sorted Source Nodes: [linear_55], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf258, (1024, 1536), (1536, 1), 0), _frozen_param132, out=buf259)
del buf258
buf261 = buf254; del buf254 # reuse
# Topologically Sorted Source Nodes: [out_7, float_41, pow_17, mean_16, add_32, rsqrt_16, mul_56, output_32, h_9], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf261, buf259, _frozen_param73, 1024, 512, grid=grid(1024), stream=stream0)
del buf259
buf262 = empty_strided_cuda((32, 32000), (32000, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [output_33], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf261, (32, 512), (16384, 1), 15872), _frozen_param134, out=buf262)
del buf261
buf263 = empty_strided_cuda((32, 32000), (32000, 1), torch.float32)
# Topologically Sorted Source Nodes: [float_42], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_15.run(buf262, buf263, 1024000, grid=grid(1024000), stream=stream0)
del buf262
return (buf263, arg77_1, arg76_1, arg79_1, arg78_1, arg81_1, arg80_1, arg83_1, arg82_1, arg85_1, arg84_1, arg87_1, arg86_1, arg89_1, arg88_1, arg91_1, arg90_1, )
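# Standalone benchmark entry point: the frozen parameters referenced above
# are re-created as random tensors with matching shape/stride/dtype (their
# values do not matter for timing), the runtime inputs (KV caches and token
# ids) are randomized, and the compiled `call` is timed via
# print_performance.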
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
global _frozen_param0
_frozen_param0 = rand_strided((32000, 512), (512, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param1
_frozen_param1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param6
_frozen_param6 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param10
_frozen_param10 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param15
_frozen_param15 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param19
_frozen_param19 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param24
_frozen_param24 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param28
_frozen_param28 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param33
_frozen_param33 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param37
_frozen_param37 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param42
_frozen_param42 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param46
_frozen_param46 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param51
_frozen_param51 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param55
_frozen_param55 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param60
_frozen_param60 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param64
_frozen_param64 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param69
_frozen_param69 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param73
_frozen_param73 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param135
_frozen_param135 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param79
_frozen_param79 = rand_strided((1, 32, 1, 32), (1024, 32, 32, 1), device='cuda:0', dtype=torch.complex64)
global _frozen_param80
_frozen_param80 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param136
_frozen_param136 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param83
_frozen_param83 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param137
_frozen_param137 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param87
_frozen_param87 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param138
_frozen_param138 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param90
_frozen_param90 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param139
_frozen_param139 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param94
_frozen_param94 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param140
_frozen_param140 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param97
_frozen_param97 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param141
_frozen_param141 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param101
_frozen_param101 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param142
_frozen_param142 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param104
_frozen_param104 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param143
_frozen_param143 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param108
_frozen_param108 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param144
_frozen_param144 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param111
_frozen_param111 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param145
_frozen_param145 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param115
_frozen_param115 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param146
_frozen_param146 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param118
_frozen_param118 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param147
_frozen_param147 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param122
_frozen_param122 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param148
_frozen_param148 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param125
_frozen_param125 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param149
_frozen_param149 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param129
_frozen_param129 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param150
_frozen_param150 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param132
_frozen_param132 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16)
global _frozen_param134
_frozen_param134 = rand_strided((512, 32000), (1, 512), device='cuda:0', dtype=torch.bfloat16)
arg76_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg77_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg78_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg79_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg80_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg81_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg82_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg83_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg84_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg85_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg86_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg87_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg88_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg89_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg90_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg91_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16)
arg92_1 = rand_strided((32, 32), (32, 1), device='cuda:0', dtype=torch.int32)
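# arg76_1..arg91_1 are the per-layer key/value caches: logical shape
# (bsz=32, max_seq_len=1024, n_heads=8, head_dim=64) with strides
# (524288, 512, 64, 1), and arg92_1 holds the (32, 32) int32 token ids.
# The batch/seq/head interpretation is inferred from the strides, not
# stated in this file.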
fn = lambda: call([arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main('llama', benchmark_compiled_module)