Last active
December 13, 2024 11:36
-
-
Save leslie-fang-intel/f2d4de4b4d14875a40d6c09f0fa5fbb3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # AOT ID: ['0_inference'] | |
| from ctypes import c_void_p, c_long, c_int | |
| import torch | |
| import math | |
| import random | |
| import os | |
| import tempfile | |
| from math import inf, nan | |
| from torch._inductor.hooks import run_intermediate_hooks | |
| from torch._inductor.utils import maybe_profile | |
| from torch._inductor.codegen.memory_planning import _align as align | |
| from torch import device, empty_strided | |
| from torch._inductor.async_compile import AsyncCompile | |
| from torch._inductor.select_algorithm import extern_kernels | |
| from torch._inductor.codegen.multi_kernel import MultiKernelCall | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime.triton_heuristics import ( | |
| grid, | |
| split_scan_grid, | |
| grid_combo_kernels, | |
| start_graph, | |
| end_graph, | |
| cooperative_reduction_grid, | |
| ) | |
| from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
| from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
| aten = torch.ops.aten | |
| inductor_ops = torch.ops.inductor | |
| _quantized = torch.ops._quantized | |
| assert_size_stride = torch._C._dynamo.guards.assert_size_stride | |
| empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu | |
| empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda | |
| empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu | |
| reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor | |
| alloc_from_pool = torch.ops.inductor._alloc_from_pool | |
| async_compile = AsyncCompile() | |
| empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p | |
| _frozen_param0 = None # device(type='cuda', index=0) torch.bfloat16 (32000, 512) (512, 1) 7fe5e12f9cb0 | |
| _frozen_param1 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12fb4c0 | |
| _frozen_param6 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12f9580 | |
| _frozen_param10 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12fbce0 | |
| _frozen_param15 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc0125c0 | |
| _frozen_param19 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc0127a0 | |
| _frozen_param24 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc012a20 | |
| _frozen_param28 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12fa930 | |
| _frozen_param33 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc0113a0 | |
| _frozen_param37 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12f9fd0 | |
| _frozen_param42 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc011f80 | |
| _frozen_param46 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5ed46b8d0 | |
| _frozen_param51 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc013790 | |
| _frozen_param55 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12f99e0 | |
| _frozen_param60 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12f9bc0 | |
| _frozen_param64 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc010d10 | |
| _frozen_param69 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5cc012bb0 | |
| _frozen_param73 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7fe5e12fb470 | |
| _frozen_param135 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c3bec0 | |
| _frozen_param79 = None # device(type='cuda', index=0) torch.complex64 (1, 32, 1, 32) (1024, 32, 32, 1) 7fe55837e520 | |
| _frozen_param80 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d82f20 | |
| _frozen_param136 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c3a4d0 | |
| _frozen_param83 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe55821f880 | |
| _frozen_param137 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c3bb00 | |
| _frozen_param87 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557e9ad90 | |
| _frozen_param138 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557df2f70 | |
| _frozen_param90 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d88180 | |
| _frozen_param139 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c3b0b0 | |
| _frozen_param94 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d89800 | |
| _frozen_param140 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c3a1b0 | |
| _frozen_param97 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d898f0 | |
| _frozen_param141 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c3ade0 | |
| _frozen_param101 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d89a30 | |
| _frozen_param142 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c3bab0 | |
| _frozen_param104 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d89b20 | |
| _frozen_param143 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c38e50 | |
| _frozen_param108 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d89c60 | |
| _frozen_param144 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c39e40 | |
| _frozen_param111 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d89d50 | |
| _frozen_param145 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557d82d90 | |
| _frozen_param115 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d89e90 | |
| _frozen_param146 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c39260 | |
| _frozen_param118 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d89f80 | |
| _frozen_param147 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c38590 | |
| _frozen_param122 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d8a0c0 | |
| _frozen_param148 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c38d60 | |
| _frozen_param125 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d8a1b0 | |
| _frozen_param149 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7fe557c38bd0 | |
| _frozen_param129 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7fe557d8a2f0 | |
| _frozen_param150 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7fe557c39c10 | |
| _frozen_param132 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7fe557d8a3e0 | |
| _frozen_param134 = None # device(type='cuda', index=0) torch.bfloat16 (512, 32000) (1, 512) 7fe557e302c0 | |
| # kernel path: /tmp/torchinductor_t/jq/cjqkz4jt5d5lkuk3r2wju344fss2vw6bnduyr5hqduvhlizskx43.py | |
| # Topologically Sorted Source Nodes: [h, float_1, pow_1, mean, add, rsqrt, mul, output, mul_1], Original ATen: [aten.embedding, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add => add | |
| # float_1 => convert_element_type_1 | |
| # h => embedding | |
| # mean => mean | |
| # mul => mul | |
| # mul_1 => mul_1 | |
| # output => convert_element_type_2 | |
| # pow_1 => pow_1 | |
| # rsqrt => rsqrt | |
| # Graph fragment: | |
| # %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {}) | |
| # %convert_element_type_1 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%embedding, torch.float32), kwargs = {}) | |
| # %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_1, 2), kwargs = {}) | |
| # %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {}) | |
| # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-05), kwargs = {}) | |
| # %rsqrt : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {}) | |
| # %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1, %rsqrt), kwargs = {}) | |
| # %convert_element_type_2 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
| # %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %arg1_1), kwargs = {}) | |
| triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.DEFAULT, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
| xnumel = 1024 | |
| rnumel = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
| xmask = xindex < xnumel | |
| rbase = tl.arange(0, RBLOCK)[None, :] | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last') | |
| _tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
| for roffset in range(0, rnumel, RBLOCK): | |
| rindex = roffset + rbase | |
| rmask = rindex < rnumel | |
| r1 = rindex | |
| tmp1 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp3 = tmp0 < 0 | |
| tmp4 = tl.where(tmp3, tmp2, tmp0) | |
| tl.device_assert(((0 <= tmp4) & (tmp4 < 32000)) | ~(xmask), "index out of bounds: 0 <= tmp4 < 32000") | |
| tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
| tmp7 = tmp6.to(tl.float32) | |
| tmp8 = tmp7 * tmp7 | |
| tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK]) | |
| tmp11 = _tmp10 + tmp9 | |
| _tmp10 = tl.where(rmask & xmask, tmp11, _tmp10) | |
| tmp10 = tl.sum(_tmp10, 1)[:, None] | |
| for roffset in range(0, rnumel, RBLOCK): | |
| rindex = roffset + rbase | |
| rmask = rindex < rnumel | |
| r1 = rindex | |
| tmp26 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
| tmp12 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32) | |
| tmp13 = tmp0 + tmp12 | |
| tmp14 = tmp0 < 0 | |
| tmp15 = tl.where(tmp14, tmp13, tmp0) | |
| tl.device_assert(((0 <= tmp15) & (tmp15 < 32000)) | ~(xmask), "index out of bounds: 0 <= tmp15 < 32000") | |
| tmp17 = tl.load(in_ptr1 + (r1 + 512*tmp15), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
| tmp18 = tmp17.to(tl.float32) | |
| tmp19 = 512.0 | |
| tmp20 = (tmp10 / tmp19).to(tl.float32) | |
| tmp21 = 1e-05 | |
| tmp22 = tmp20 + tmp21 | |
| tmp23 = libdevice.rsqrt(tmp22) | |
| tmp24 = tmp18 * tmp23 | |
| tmp25 = tmp24.to(tl.float32) | |
| tmp27 = tmp25 * tmp26 | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp27, rmask & xmask) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/xc/cxcmzl6ilmlg65glf4pzjcxwxpduicy2vfmzzjio5qccmdqzy5s3.py | |
| # Topologically Sorted Source Nodes: [setitem_1], Original ATen: [aten.copy] | |
| # Source node to ATen node mapping: | |
| # setitem_1 => copy_1 | |
| # Graph fragment: | |
| # %copy_1 : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_4, %view_8), kwargs = {}) | |
| # %copy__default_1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%slice_tensor_1, %copy_1), kwargs = {}) | |
| triton_poi_fused_copy_1 = async_compile.triton('triton_poi_fused_copy_1', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 524288}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_copy_1', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_copy_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 524288 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x0 = (xindex % 512) | |
| x3 = xindex // 512 | |
| x2 = xindex // 16384 | |
| x4 = (xindex % 16384) | |
| tmp0 = tl.load(in_ptr0 + (x0 + 1536*x3), None).to(tl.float32) | |
| tl.store(out_ptr0 + (512 + x4 + 524288*x2), tmp0, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/7m/c7m45i7porfguqnzfi4evudvthp6s46rtboxlqxs2ukmts7oogqy.py | |
| # Topologically Sorted Source Nodes: [xq_], Original ATen: [aten.view_as_complex] | |
| # Source node to ATen node mapping: | |
| # xq_ => view_as_complex | |
| # Graph fragment: | |
| # %view_as_complex : [num_users=1] = call_function[target=torch.ops.aten.view_as_complex.default](args = (%view_9,), kwargs = {}) | |
| triton_poi_fused_view_as_complex_2 = async_compile.triton('triton_poi_fused_view_as_complex_2', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 524288}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_view_as_complex_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_view_as_complex_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 524288 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x0 = (xindex % 512) | |
| x1 = xindex // 512 | |
| x2 = xindex | |
| tmp0 = tl.load(in_ptr0 + (512 + x0 + 1536*x1), None).to(tl.float32) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (x2), tmp1, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/65/c65ekeiwoihqn3z3eozs5kfs2w7fuc6dtsursyyedmm5n3idz6aw.py | |
| # Topologically Sorted Source Nodes: [xk_], Original ATen: [aten.view_as_complex] | |
| # Source node to ATen node mapping: | |
| # xk_ => view_as_complex_1 | |
| # Graph fragment: | |
| # %view_as_complex_1 : [num_users=1] = call_function[target=torch.ops.aten.view_as_complex.default](args = (%view_10,), kwargs = {}) | |
| triton_poi_fused_view_as_complex_3 = async_compile.triton('triton_poi_fused_view_as_complex_3', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 524288}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_view_as_complex_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_view_as_complex_3(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 524288 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x0 = (xindex % 512) | |
| x1 = xindex // 512 | |
| x2 = xindex | |
| tmp0 = tl.load(in_ptr0 + (1024 + x0 + 1536*x1), None).to(tl.float32) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (x2), tmp1, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/fr/cfrphf4roqxuzvwmpfjsqykxw7ykp6blxuab7a762yx75y5ssqwv.py | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| # Source node to ATen node mapping: | |
| # Graph fragment: | |
| # %_scaled_dot_product_flash_attention_default_7 : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%permute_default_21, %permute_default_22, %permute_default_23), kwargs = {scale: 0.125}) | |
| triton_poi_fused_4 = async_compile.triton('triton_poi_fused_4', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 524288}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_4(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 524288 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (x0), None) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (x0), tmp1, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/pm/cpmwz4lr5hqfizv3wxqz5p5i5tjfamfs6btvjdkmpwz6sixmisvp.py | |
| # Topologically Sorted Source Nodes: [xk_2, setitem], Original ATen: [aten._to_copy, aten.copy] | |
| # Source node to ATen node mapping: | |
| # setitem => copy | |
| # xk_2 => convert_element_type_12 | |
| # Graph fragment: | |
| # %convert_element_type_12 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_13, torch.bfloat16), kwargs = {}) | |
| # %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %convert_element_type_12), kwargs = {}) | |
| # %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%slice_tensor, %copy), kwargs = {}) | |
| triton_poi_fused__to_copy_copy_5 = async_compile.triton('triton_poi_fused__to_copy_copy_5', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 524288}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_copy_5', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused__to_copy_copy_5(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 524288 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x2 = xindex | |
| x0 = (xindex % 16384) | |
| x1 = xindex // 16384 | |
| tmp0 = tl.load(in_ptr0 + (x2), None) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (512 + x0 + 524288*x1), tmp1, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/dr/cdrrgy3ocqlbwxy4cv6ytxs3apsjv37qs3l2swp2p2sgajmtmocx.py | |
| # Topologically Sorted Source Nodes: [h, h_1, float_5, pow_2, mean_1, add_2, rsqrt_1, mul_4, output_3, mul_5], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_2 => add_2 | |
| # float_5 => convert_element_type_21 | |
| # h => embedding | |
| # h_1 => add_1 | |
| # mean_1 => mean_1 | |
| # mul_4 => mul_4 | |
| # mul_5 => mul_5 | |
| # output_3 => convert_element_type_22 | |
| # pow_2 => pow_2 | |
| # rsqrt_1 => rsqrt_1 | |
| # Graph fragment: | |
| # %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {}) | |
| # %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %view_22), kwargs = {}) | |
| # %convert_element_type_21 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_1, torch.float32), kwargs = {}) | |
| # %pow_2 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_21, 2), kwargs = {}) | |
| # %mean_1 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {}) | |
| # %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-05), kwargs = {}) | |
| # %rsqrt_1 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_2,), kwargs = {}) | |
| # %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_21, %rsqrt_1), kwargs = {}) | |
| # %convert_element_type_22 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_4, torch.bfloat16), kwargs = {}) | |
| # %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_22, %arg6_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6 = async_compile.triton('triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| x0 = xindex | |
| r1 = rindex | |
| tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last') | |
| tmp7 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp21 = tl.load(in_ptr3 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp1 = tl.full([RBLOCK], 32000, tl.int32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp3 = tmp0 < 0 | |
| tmp4 = tl.where(tmp3, tmp2, tmp0) | |
| tl.device_assert((0 <= tmp4) & (tmp4 < 32000), "index out of bounds: 0 <= tmp4 < 32000") | |
| tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), None).to(tl.float32) | |
| tmp8 = tmp6 + tmp7 | |
| tmp9 = tmp8.to(tl.float32) | |
| tmp10 = tmp9 * tmp9 | |
| tmp11 = tl.broadcast_to(tmp10, [RBLOCK]) | |
| tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0)) | |
| tmp14 = 512.0 | |
| tmp15 = (tmp13 / tmp14).to(tl.float32) | |
| tmp16 = 1e-05 | |
| tmp17 = tmp15 + tmp16 | |
| tmp18 = libdevice.rsqrt(tmp17) | |
| tmp19 = tmp9 * tmp18 | |
| tmp20 = tmp19.to(tl.float32) | |
| tmp22 = tmp20 * tmp21 | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp22, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/37/c37cieheeo7zfl4rbysqwzkbdr7pwfgmv5au7ifpl2r5xifjrj5v.py | |
| # Topologically Sorted Source Nodes: [silu, mul_6], Original ATen: [aten.silu, aten.mul] | |
| # Source node to ATen node mapping: | |
| # mul_6 => mul_7 | |
| # silu => convert_element_type_25, convert_element_type_26, mul_6, sigmoid | |
| # Graph fragment: | |
| # %convert_element_type_25 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_24, torch.float32), kwargs = {}) | |
| # %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type_25,), kwargs = {}) | |
| # %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_25, %sigmoid), kwargs = {}) | |
| # %convert_element_type_26 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_6, torch.bfloat16), kwargs = {}) | |
| # %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_26, %view_26), kwargs = {}) | |
| triton_poi_fused_mul_silu_7 = async_compile.triton('triton_poi_fused_mul_silu_7', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 2097152}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_7', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_mul_silu_7(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 1572864 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x0 = (xindex % 1536) | |
| x1 = xindex // 1536 | |
| x2 = xindex | |
| tmp0 = tl.load(in_ptr0 + (1536 + x0 + 3072*x1), None).to(tl.float32) | |
| tmp5 = tl.load(in_ptr0 + (x0 + 3072*x1), None).to(tl.float32) | |
| tmp1 = tmp0.to(tl.float32) | |
| tmp2 = tl.sigmoid(tmp1) | |
| tmp3 = tmp1 * tmp2 | |
| tmp4 = tmp3.to(tl.float32) | |
| tmp6 = tmp4 * tmp5 | |
| tl.store(out_ptr0 + (x2), tmp6, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/oe/coezydc4r3hd4ebv3eien4ar5zw7e5oum674ezorliyuknllnim6.py | |
| # Topologically Sorted Source Nodes: [h, h_1, out, float_6, pow_3, mean_2, add_4, rsqrt_2, mul_7, output_4, mul_8], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_4 => add_4 | |
| # float_6 => convert_element_type_31 | |
| # h => embedding | |
| # h_1 => add_1 | |
| # mean_2 => mean_2 | |
| # mul_7 => mul_8 | |
| # mul_8 => mul_9 | |
| # out => add_3 | |
| # output_4 => convert_element_type_32 | |
| # pow_3 => pow_3 | |
| # rsqrt_2 => rsqrt_2 | |
| # Graph fragment: | |
| # %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {}) | |
| # %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %view_22), kwargs = {}) | |
| # %add_3 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, %view_28), kwargs = {}) | |
| # %convert_element_type_31 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_3, torch.float32), kwargs = {}) | |
| # %pow_3 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_31, 2), kwargs = {}) | |
| # %mean_2 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {}) | |
| # %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-05), kwargs = {}) | |
| # %rsqrt_2 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_4,), kwargs = {}) | |
| # %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_31, %rsqrt_2), kwargs = {}) | |
| # %convert_element_type_32 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_8, torch.bfloat16), kwargs = {}) | |
| # %mul_9 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_32, %arg10_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8 = async_compile.triton('triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| x0 = xindex | |
| r1 = rindex | |
| tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last') | |
| tmp7 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp9 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp23 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp1 = tl.full([RBLOCK], 32000, tl.int32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp3 = tmp0 < 0 | |
| tmp4 = tl.where(tmp3, tmp2, tmp0) | |
| tl.device_assert((0 <= tmp4) & (tmp4 < 32000), "index out of bounds: 0 <= tmp4 < 32000") | |
| tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), None).to(tl.float32) | |
| tmp8 = tmp6 + tmp7 | |
| tmp10 = tmp8 + tmp9 | |
| tmp11 = tmp10.to(tl.float32) | |
| tmp12 = tmp11 * tmp11 | |
| tmp13 = tl.broadcast_to(tmp12, [RBLOCK]) | |
| tmp15 = triton_helpers.promote_to_tensor(tl.sum(tmp13, 0)) | |
| tmp16 = 512.0 | |
| tmp17 = (tmp15 / tmp16).to(tl.float32) | |
| tmp18 = 1e-05 | |
| tmp19 = tmp17 + tmp18 | |
| tmp20 = libdevice.rsqrt(tmp19) | |
| tmp21 = tmp11 * tmp20 | |
| tmp22 = tmp21.to(tl.float32) | |
| tmp24 = tmp22 * tmp23 | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp24, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/hg/chg35he4i3wb65f2xla6t2f46hl3nvqroelcorgcij2tuqcakabg.py | |
| # Topologically Sorted Source Nodes: [h, h_1, out, h_2, float_10, pow_4, mean_3, add_6, rsqrt_3, mul_11, output_7, mul_12], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_6 => add_6 | |
| # float_10 => convert_element_type_51 | |
| # h => embedding | |
| # h_1 => add_1 | |
| # h_2 => add_5 | |
| # mean_3 => mean_3 | |
| # mul_11 => mul_12 | |
| # mul_12 => mul_13 | |
| # out => add_3 | |
| # output_7 => convert_element_type_52 | |
| # pow_4 => pow_4 | |
| # rsqrt_3 => rsqrt_3 | |
| # Graph fragment: | |
| # %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {}) | |
| # %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %view_22), kwargs = {}) | |
| # %add_3 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, %view_28), kwargs = {}) | |
| # %add_5 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_3, %view_51), kwargs = {}) | |
| # %convert_element_type_51 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_5, torch.float32), kwargs = {}) | |
| # %pow_4 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_51, 2), kwargs = {}) | |
| # %mean_3 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_4, [-1], True), kwargs = {}) | |
| # %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_3, 1e-05), kwargs = {}) | |
| # %rsqrt_3 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_6,), kwargs = {}) | |
| # %mul_12 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_51, %rsqrt_3), kwargs = {}) | |
| # %convert_element_type_52 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_12, torch.bfloat16), kwargs = {}) | |
| # %mul_13 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_52, %arg15_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9 = async_compile.triton('triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| x0 = xindex | |
| r1 = rindex | |
| tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last') | |
| tmp7 = tl.load(in_out_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp9 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp11 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp25 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp1 = tl.full([RBLOCK], 32000, tl.int32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp3 = tmp0 < 0 | |
| tmp4 = tl.where(tmp3, tmp2, tmp0) | |
| tl.device_assert((0 <= tmp4) & (tmp4 < 32000), "index out of bounds: 0 <= tmp4 < 32000") | |
| tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), None).to(tl.float32) | |
| tmp8 = tmp6 + tmp7 | |
| tmp10 = tmp8 + tmp9 | |
| tmp12 = tmp10 + tmp11 | |
| tmp13 = tmp12.to(tl.float32) | |
| tmp14 = tmp13 * tmp13 | |
| tmp15 = tl.broadcast_to(tmp14, [RBLOCK]) | |
| tmp17 = triton_helpers.promote_to_tensor(tl.sum(tmp15, 0)) | |
| tmp18 = 512.0 | |
| tmp19 = (tmp17 / tmp18).to(tl.float32) | |
| tmp20 = 1e-05 | |
| tmp21 = tmp19 + tmp20 | |
| tmp22 = libdevice.rsqrt(tmp21) | |
| tmp23 = tmp13 * tmp22 | |
| tmp24 = tmp23.to(tl.float32) | |
| tmp26 = tmp24 * tmp25 | |
| tl.store(in_out_ptr0 + (r1 + 512*x0), tmp12, None) | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp26, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/iy/ciy4z4yq6ins6l5ivop2w7p7zdinblxm2voqklkpqk7pgxbr7pil.py | |
| # Topologically Sorted Source Nodes: [out_1, float_11, pow_5, mean_4, add_8, rsqrt_4, mul_14, output_8, mul_15], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_8 => add_8 | |
| # float_11 => convert_element_type_61 | |
| # mean_4 => mean_4 | |
| # mul_14 => mul_16 | |
| # mul_15 => mul_17 | |
| # out_1 => add_7 | |
| # output_8 => convert_element_type_62 | |
| # pow_5 => pow_5 | |
| # rsqrt_4 => rsqrt_4 | |
| # Graph fragment: | |
| # %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {}) | |
| # %convert_element_type_61 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_7, torch.float32), kwargs = {}) | |
| # %pow_5 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_61, 2), kwargs = {}) | |
| # %mean_4 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_5, [-1], True), kwargs = {}) | |
| # %add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_4, 1e-05), kwargs = {}) | |
| # %rsqrt_4 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_8,), kwargs = {}) | |
| # %mul_16 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_61, %rsqrt_4), kwargs = {}) | |
| # %convert_element_type_62 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {}) | |
| # %mul_17 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_62, %arg19_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| r1 = rindex | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp1 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp15 = tl.load(in_ptr2 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp3 = tmp2.to(tl.float32) | |
| tmp4 = tmp3 * tmp3 | |
| tmp5 = tl.broadcast_to(tmp4, [RBLOCK]) | |
| tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0)) | |
| tmp8 = 512.0 | |
| tmp9 = (tmp7 / tmp8).to(tl.float32) | |
| tmp10 = 1e-05 | |
| tmp11 = tmp9 + tmp10 | |
| tmp12 = libdevice.rsqrt(tmp11) | |
| tmp13 = tmp3 * tmp12 | |
| tmp14 = tmp13.to(tl.float32) | |
| tmp16 = tmp14 * tmp15 | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp16, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/pn/cpnh5dfi4oms6ccoppdagviyr73sfnfnxsb7fxqww3qrixi7lyh4.py | |
| # Topologically Sorted Source Nodes: [out_1, h_3, float_15, pow_6, mean_5, add_10, rsqrt_5, mul_18, output_11, mul_19], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_10 => add_10 | |
| # float_15 => convert_element_type_81 | |
| # h_3 => add_9 | |
| # mean_5 => mean_5 | |
| # mul_18 => mul_20 | |
| # mul_19 => mul_21 | |
| # out_1 => add_7 | |
| # output_11 => convert_element_type_82 | |
| # pow_6 => pow_6 | |
| # rsqrt_5 => rsqrt_5 | |
| # Graph fragment: | |
| # %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {}) | |
| # %add_9 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_80), kwargs = {}) | |
| # %convert_element_type_81 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_9, torch.float32), kwargs = {}) | |
| # %pow_6 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_81, 2), kwargs = {}) | |
| # %mean_5 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_6, [-1], True), kwargs = {}) | |
| # %add_10 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_5, 1e-05), kwargs = {}) | |
| # %rsqrt_5 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_10,), kwargs = {}) | |
| # %mul_20 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_81, %rsqrt_5), kwargs = {}) | |
| # %convert_element_type_82 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_20, torch.bfloat16), kwargs = {}) | |
| # %mul_21 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_82, %arg24_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| r1 = rindex | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp1 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp3 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp17 = tl.load(in_ptr3 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp4 = tmp2 + tmp3 | |
| tmp5 = tmp4.to(tl.float32) | |
| tmp6 = tmp5 * tmp5 | |
| tmp7 = tl.broadcast_to(tmp6, [RBLOCK]) | |
| tmp9 = triton_helpers.promote_to_tensor(tl.sum(tmp7, 0)) | |
| tmp10 = 512.0 | |
| tmp11 = (tmp9 / tmp10).to(tl.float32) | |
| tmp12 = 1e-05 | |
| tmp13 = tmp11 + tmp12 | |
| tmp14 = libdevice.rsqrt(tmp13) | |
| tmp15 = tmp5 * tmp14 | |
| tmp16 = tmp15.to(tl.float32) | |
| tmp18 = tmp16 * tmp17 | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp18, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/dw/cdwfdyb3igstcir6uoaelj2eyiyqrmimqovbk6l57r4zvxcvjy7k.py | |
| # Topologically Sorted Source Nodes: [out_1, h_3, out_2, float_16, pow_7, mean_6, add_12, rsqrt_6, mul_21, output_12, mul_22], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_12 => add_12 | |
| # float_16 => convert_element_type_91 | |
| # h_3 => add_9 | |
| # mean_6 => mean_6 | |
| # mul_21 => mul_24 | |
| # mul_22 => mul_25 | |
| # out_1 => add_7 | |
| # out_2 => add_11 | |
| # output_12 => convert_element_type_92 | |
| # pow_7 => pow_7 | |
| # rsqrt_6 => rsqrt_6 | |
| # Graph fragment: | |
| # %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {}) | |
| # %add_9 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_80), kwargs = {}) | |
| # %add_11 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_9, %view_86), kwargs = {}) | |
| # %convert_element_type_91 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_11, torch.float32), kwargs = {}) | |
| # %pow_7 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_91, 2), kwargs = {}) | |
| # %mean_6 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_7, [-1], True), kwargs = {}) | |
| # %add_12 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_6, 1e-05), kwargs = {}) | |
| # %rsqrt_6 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_12,), kwargs = {}) | |
| # %mul_24 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_91, %rsqrt_6), kwargs = {}) | |
| # %convert_element_type_92 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_24, torch.bfloat16), kwargs = {}) | |
| # %mul_25 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_92, %arg28_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| r1 = rindex | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp1 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp3 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp5 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp19 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp4 = tmp2 + tmp3 | |
| tmp6 = tmp4 + tmp5 | |
| tmp7 = tmp6.to(tl.float32) | |
| tmp8 = tmp7 * tmp7 | |
| tmp9 = tl.broadcast_to(tmp8, [RBLOCK]) | |
| tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0)) | |
| tmp12 = 512.0 | |
| tmp13 = (tmp11 / tmp12).to(tl.float32) | |
| tmp14 = 1e-05 | |
| tmp15 = tmp13 + tmp14 | |
| tmp16 = libdevice.rsqrt(tmp15) | |
| tmp17 = tmp7 * tmp16 | |
| tmp18 = tmp17.to(tl.float32) | |
| tmp20 = tmp18 * tmp19 | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp20, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/z3/cz3e34micft5wnsnm2gnpqmih6vve6ujd2ccrnpkysa5obrpxhpi.py | |
| # Topologically Sorted Source Nodes: [out_1, h_3, out_2, h_4, float_20, pow_8, mean_7, add_14, rsqrt_7, mul_25, output_15, mul_26], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_14 => add_14 | |
| # float_20 => convert_element_type_111 | |
| # h_3 => add_9 | |
| # h_4 => add_13 | |
| # mean_7 => mean_7 | |
| # mul_25 => mul_28 | |
| # mul_26 => mul_29 | |
| # out_1 => add_7 | |
| # out_2 => add_11 | |
| # output_15 => convert_element_type_112 | |
| # pow_8 => pow_8 | |
| # rsqrt_7 => rsqrt_7 | |
| # Graph fragment: | |
| # %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {}) | |
| # %add_9 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_80), kwargs = {}) | |
| # %add_11 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_9, %view_86), kwargs = {}) | |
| # %add_13 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_11, %view_109), kwargs = {}) | |
| # %convert_element_type_111 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_13, torch.float32), kwargs = {}) | |
| # %pow_8 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_111, 2), kwargs = {}) | |
| # %mean_7 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_8, [-1], True), kwargs = {}) | |
| # %add_14 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_7, 1e-05), kwargs = {}) | |
| # %rsqrt_7 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_14,), kwargs = {}) | |
| # %mul_28 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_111, %rsqrt_7), kwargs = {}) | |
| # %convert_element_type_112 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_28, torch.bfloat16), kwargs = {}) | |
| # %mul_29 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_112, %arg33_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 6, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| r1 = rindex | |
| x0 = xindex | |
| tmp0 = tl.load(in_out_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp1 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp3 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp5 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp7 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp21 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp4 = tmp2 + tmp3 | |
| tmp6 = tmp4 + tmp5 | |
| tmp8 = tmp6 + tmp7 | |
| tmp9 = tmp8.to(tl.float32) | |
| tmp10 = tmp9 * tmp9 | |
| tmp11 = tl.broadcast_to(tmp10, [RBLOCK]) | |
| tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0)) | |
| tmp14 = 512.0 | |
| tmp15 = (tmp13 / tmp14).to(tl.float32) | |
| tmp16 = 1e-05 | |
| tmp17 = tmp15 + tmp16 | |
| tmp18 = libdevice.rsqrt(tmp17) | |
| tmp19 = tmp9 * tmp18 | |
| tmp20 = tmp19.to(tl.float32) | |
| tmp22 = tmp20 * tmp21 | |
| tl.store(in_out_ptr0 + (r1 + 512*x0), tmp8, None) | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp22, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/lv/clv7iqdidc6yovjvhgzra2tlgehvzuyjf6c4drcdk3cgzbb7tbn4.py | |
| # Topologically Sorted Source Nodes: [out_7, float_41, pow_17, mean_16, add_32, rsqrt_16, mul_56, output_32, h_9], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_32 => add_32 | |
| # float_41 => convert_element_type_241 | |
| # h_9 => mul_65 | |
| # mean_16 => mean_16 | |
| # mul_56 => mul_64 | |
| # out_7 => add_31 | |
| # output_32 => convert_element_type_242 | |
| # pow_17 => pow_17 | |
| # rsqrt_16 => rsqrt_16 | |
| # Graph fragment: | |
| # %add_31 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_29, %view_231), kwargs = {}) | |
| # %convert_element_type_241 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_31, torch.float32), kwargs = {}) | |
| # %pow_17 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_241, 2), kwargs = {}) | |
| # %mean_16 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_17, [-1], True), kwargs = {}) | |
| # %add_32 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_16, 1e-05), kwargs = {}) | |
| # %rsqrt_16 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_32,), kwargs = {}) | |
| # %mul_64 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_241, %rsqrt_16), kwargs = {}) | |
| # %convert_element_type_242 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_64, torch.bfloat16), kwargs = {}) | |
| # %mul_65 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_242, %arg73_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14(in_out_ptr0, in_ptr0, in_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| r1 = rindex | |
| x0 = xindex | |
| tmp0 = tl.load(in_out_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp1 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp15 = tl.load(in_ptr1 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp3 = tmp2.to(tl.float32) | |
| tmp4 = tmp3 * tmp3 | |
| tmp5 = tl.broadcast_to(tmp4, [RBLOCK]) | |
| tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0)) | |
| tmp8 = 512.0 | |
| tmp9 = (tmp7 / tmp8).to(tl.float32) | |
| tmp10 = 1e-05 | |
| tmp11 = tmp9 + tmp10 | |
| tmp12 = libdevice.rsqrt(tmp11) | |
| tmp13 = tmp3 * tmp12 | |
| tmp14 = tmp13.to(tl.float32) | |
| tmp16 = tmp14 * tmp15 | |
| tl.store(in_out_ptr0 + (r1 + 512*x0), tmp16, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/uy/cuyzkjsgklqarsm2yt3x6ruv2di7jsrafhofrtw3gdor4jvuttkh.py | |
| # Topologically Sorted Source Nodes: [float_42], Original ATen: [aten._to_copy] | |
| # Source node to ATen node mapping: | |
| # float_42 => convert_element_type_245 | |
| # Graph fragment: | |
| # %convert_element_type_245 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_56, torch.float32), kwargs = {}) | |
| triton_poi_fused__to_copy_15 = async_compile.triton('triton_poi_fused__to_copy_15', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 1048576}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_15', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused__to_copy_15(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 1024000 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (x0), tmp1, None) | |
| ''', device_str='cuda') | |
| async_compile.wait(globals()) | |
| del async_compile | |
| def call(args): | |
| arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1 = args | |
| args.clear() | |
| assert_size_stride(arg76_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg77_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg78_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg79_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg80_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg81_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg82_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg83_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg84_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg85_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg86_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg87_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg88_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg89_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg90_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg91_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg92_1, (32, 32), (32, 1)) | |
| with torch.cuda._DeviceGuard(0): | |
| torch.cuda.set_device(0) | |
| buf1 = empty_strided_cuda((32, 32, 512), (16384, 512, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [h, float_1, pow_1, mean, add, rsqrt, mul, output, mul_1], Original ATen: [aten.embedding, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0.run(arg92_1, _frozen_param0, _frozen_param1, buf1, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf2 = empty_strided_cuda((1024, 1536), (1536, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf1, (1024, 512), (512, 1), 0), _frozen_param135, out=buf2) | |
| # Topologically Sorted Source Nodes: [setitem_1], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf2, arg77_1, 524288, grid=grid(524288), stream=stream0) | |
| buf3 = empty_strided_cuda((32, 32, 8, 32, 2), (16384, 512, 64, 2, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [xq_], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf2, buf3, 524288, grid=grid(524288), stream=stream0) | |
| buf10 = empty_strided_cuda((32, 32, 8, 32, 2), (16384, 512, 64, 2, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [xk_], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf2, buf10, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq_], Original ATen: [aten.view_as_complex] | |
| buf4 = torch.ops.aten.view_as_complex.default(buf3) | |
| buf5 = buf4 | |
| # Topologically Sorted Source Nodes: [mul_2], Original ATen: [aten.mul] | |
| buf6 = torch.ops.aten.mul.Tensor(buf5, _frozen_param79) | |
| del buf4 | |
| del buf5 | |
| buf7 = buf6 | |
| del buf6 | |
| # Topologically Sorted Source Nodes: [view_as_real], Original ATen: [aten.view_as_real] | |
| buf8 = torch.ops.aten.view_as_real.default(buf7) | |
| buf9 = buf8 | |
| buf19 = reinterpret_tensor(buf1, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf1 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf9, buf19, 524288, grid=grid(524288), stream=stream0) | |
| del buf7 | |
| del buf8 | |
| del buf9 | |
| # Topologically Sorted Source Nodes: [xk_], Original ATen: [aten.view_as_complex] | |
| buf11 = torch.ops.aten.view_as_complex.default(buf10) | |
| buf12 = buf11 | |
| # Topologically Sorted Source Nodes: [mul_3], Original ATen: [aten.mul] | |
| buf13 = torch.ops.aten.mul.Tensor(buf12, _frozen_param79) | |
| del buf11 | |
| del buf12 | |
| buf14 = buf13 | |
| del buf13 | |
| # Topologically Sorted Source Nodes: [view_as_real_1], Original ATen: [aten.view_as_real] | |
| buf15 = torch.ops.aten.view_as_real.default(buf14) | |
| buf16 = buf15 | |
| # Topologically Sorted Source Nodes: [xk_2, setitem], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf16, arg76_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf14 | |
| del buf15 | |
| del buf16 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf20 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf19, reinterpret_tensor(arg76_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg77_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf21 = buf20[0] | |
| del buf20 | |
| buf26 = reinterpret_tensor(buf19, (1024, 512), (512, 1), 0); del buf19 # reuse | |
| # Topologically Sorted Source Nodes: [linear_3], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf21, (1024, 512), (512, 1), 0), _frozen_param80, out=buf26) | |
| buf28 = reinterpret_tensor(buf21, (32, 32, 512), (16384, 512, 1), 0); del buf21 # reuse | |
| # Topologically Sorted Source Nodes: [h, h_1, float_5, pow_2, mean_1, add_2, rsqrt_1, mul_4, output_3, mul_5], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6.run(arg92_1, _frozen_param0, buf26, _frozen_param6, buf28, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf29 = empty_strided_cuda((1024, 3072), (3072, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf28, (1024, 512), (512, 1), 0), _frozen_param136, out=buf29) | |
| buf30 = reinterpret_tensor(buf2, (32, 32, 1536), (49152, 1536, 1), 0); del buf2 # reuse | |
| # Topologically Sorted Source Nodes: [silu, mul_6], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf29, buf30, 1572864, grid=grid(1572864), stream=stream0) | |
| buf31 = reinterpret_tensor(buf28, (1024, 512), (512, 1), 0); del buf28 # reuse | |
| # Topologically Sorted Source Nodes: [linear_6], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf30, (1024, 1536), (1536, 1), 0), _frozen_param83, out=buf31) | |
| buf33 = empty_strided_cuda((32, 32, 512), (16384, 512, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [h, h_1, out, float_6, pow_3, mean_2, add_4, rsqrt_2, mul_7, output_4, mul_8], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8.run(arg92_1, _frozen_param0, buf26, buf31, _frozen_param10, buf33, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf34 = reinterpret_tensor(buf30, (1024, 1536), (1536, 1), 0); del buf30 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf33, (1024, 512), (512, 1), 0), _frozen_param137, out=buf34) | |
| # Topologically Sorted Source Nodes: [setitem_3], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf34, arg79_1, 524288, grid=grid(524288), stream=stream0) | |
| buf35 = buf10; del buf10 # reuse | |
| # Topologically Sorted Source Nodes: [xq__1], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf34, buf35, 524288, grid=grid(524288), stream=stream0) | |
| buf42 = buf3; del buf3 # reuse | |
| # Topologically Sorted Source Nodes: [xk__1], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf34, buf42, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__1], Original ATen: [aten.view_as_complex] | |
| buf36 = torch.ops.aten.view_as_complex.default(buf35) | |
| buf37 = buf36 | |
| # Topologically Sorted Source Nodes: [mul_9], Original ATen: [aten.mul] | |
| buf38 = torch.ops.aten.mul.Tensor(buf37, _frozen_param79) | |
| del buf36 | |
| del buf37 | |
| buf39 = buf38 | |
| del buf38 | |
| # Topologically Sorted Source Nodes: [view_as_real_2], Original ATen: [aten.view_as_real] | |
| buf40 = torch.ops.aten.view_as_real.default(buf39) | |
| buf41 = buf40 | |
| buf51 = reinterpret_tensor(buf33, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf33 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf41, buf51, 524288, grid=grid(524288), stream=stream0) | |
| del buf39 | |
| del buf40 | |
| del buf41 | |
| # Topologically Sorted Source Nodes: [xk__1], Original ATen: [aten.view_as_complex] | |
| buf43 = torch.ops.aten.view_as_complex.default(buf42) | |
| buf44 = buf43 | |
| # Topologically Sorted Source Nodes: [mul_10], Original ATen: [aten.mul] | |
| buf45 = torch.ops.aten.mul.Tensor(buf44, _frozen_param79) | |
| del buf43 | |
| del buf44 | |
| buf46 = buf45 | |
| del buf45 | |
| # Topologically Sorted Source Nodes: [view_as_real_3], Original ATen: [aten.view_as_real] | |
| buf47 = torch.ops.aten.view_as_real.default(buf46) | |
| buf48 = buf47 | |
| # Topologically Sorted Source Nodes: [xk_5, setitem_2], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf48, arg78_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf46 | |
| del buf47 | |
| del buf48 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf52 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf51, reinterpret_tensor(arg78_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg79_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf53 = buf52[0] | |
| del buf52 | |
| buf58 = reinterpret_tensor(buf51, (1024, 512), (512, 1), 0); del buf51 # reuse | |
| # Topologically Sorted Source Nodes: [linear_10], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf53, (1024, 512), (512, 1), 0), _frozen_param87, out=buf58) | |
| buf59 = reinterpret_tensor(buf26, (32, 32, 512), (16384, 512, 1), 0); del buf26 # reuse | |
| buf61 = reinterpret_tensor(buf53, (32, 32, 512), (16384, 512, 1), 0); del buf53 # reuse | |
| # Topologically Sorted Source Nodes: [h, h_1, out, h_2, float_10, pow_4, mean_3, add_6, rsqrt_3, mul_11, output_7, mul_12], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9.run(buf59, arg92_1, _frozen_param0, buf31, buf58, _frozen_param15, buf61, 1024, 512, grid=grid(1024), stream=stream0) | |
| del arg92_1 | |
| buf62 = buf29; del buf29 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf61, (1024, 512), (512, 1), 0), _frozen_param138, out=buf62) | |
| buf63 = reinterpret_tensor(buf34, (32, 32, 1536), (49152, 1536, 1), 0); del buf34 # reuse | |
| # Topologically Sorted Source Nodes: [silu_1, mul_13], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf62, buf63, 1572864, grid=grid(1572864), stream=stream0) | |
| buf64 = reinterpret_tensor(buf61, (1024, 512), (512, 1), 0); del buf61 # reuse | |
| # Topologically Sorted Source Nodes: [linear_13], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf63, (1024, 1536), (1536, 1), 0), _frozen_param90, out=buf64) | |
| buf66 = reinterpret_tensor(buf58, (32, 32, 512), (16384, 512, 1), 0); del buf58 # reuse | |
| # Topologically Sorted Source Nodes: [out_1, float_11, pow_5, mean_4, add_8, rsqrt_4, mul_14, output_8, mul_15], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf59, buf64, _frozen_param19, buf66, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf67 = reinterpret_tensor(buf63, (1024, 1536), (1536, 1), 0); del buf63 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf66, (1024, 512), (512, 1), 0), _frozen_param139, out=buf67) | |
| # Topologically Sorted Source Nodes: [setitem_5], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf67, arg81_1, 524288, grid=grid(524288), stream=stream0) | |
| buf68 = buf42; del buf42 # reuse | |
| # Topologically Sorted Source Nodes: [xq__2], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf67, buf68, 524288, grid=grid(524288), stream=stream0) | |
| buf75 = buf35; del buf35 # reuse | |
| # Topologically Sorted Source Nodes: [xk__2], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf67, buf75, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__2], Original ATen: [aten.view_as_complex] | |
| buf69 = torch.ops.aten.view_as_complex.default(buf68) | |
| buf70 = buf69 | |
| # Topologically Sorted Source Nodes: [mul_16], Original ATen: [aten.mul] | |
| buf71 = torch.ops.aten.mul.Tensor(buf70, _frozen_param79) | |
| del buf69 | |
| del buf70 | |
| buf72 = buf71 | |
| del buf71 | |
| # Topologically Sorted Source Nodes: [view_as_real_4], Original ATen: [aten.view_as_real] | |
| buf73 = torch.ops.aten.view_as_real.default(buf72) | |
| buf74 = buf73 | |
| buf84 = reinterpret_tensor(buf66, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf66 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf74, buf84, 524288, grid=grid(524288), stream=stream0) | |
| del buf72 | |
| del buf73 | |
| del buf74 | |
| # Topologically Sorted Source Nodes: [xk__2], Original ATen: [aten.view_as_complex] | |
| buf76 = torch.ops.aten.view_as_complex.default(buf75) | |
| buf77 = buf76 | |
| # Topologically Sorted Source Nodes: [mul_17], Original ATen: [aten.mul] | |
| buf78 = torch.ops.aten.mul.Tensor(buf77, _frozen_param79) | |
| del buf76 | |
| del buf77 | |
| buf79 = buf78 | |
| del buf78 | |
| # Topologically Sorted Source Nodes: [view_as_real_5], Original ATen: [aten.view_as_real] | |
| buf80 = torch.ops.aten.view_as_real.default(buf79) | |
| buf81 = buf80 | |
| # Topologically Sorted Source Nodes: [xk_8, setitem_4], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf81, arg80_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf79 | |
| del buf80 | |
| del buf81 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf85 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf84, reinterpret_tensor(arg80_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg81_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf86 = buf85[0] | |
| del buf85 | |
| buf91 = reinterpret_tensor(buf84, (1024, 512), (512, 1), 0); del buf84 # reuse | |
| # Topologically Sorted Source Nodes: [linear_17], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf86, (1024, 512), (512, 1), 0), _frozen_param94, out=buf91) | |
| buf93 = reinterpret_tensor(buf86, (32, 32, 512), (16384, 512, 1), 0); del buf86 # reuse | |
| # Topologically Sorted Source Nodes: [out_1, h_3, float_15, pow_6, mean_5, add_10, rsqrt_5, mul_18, output_11, mul_19], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf59, buf64, buf91, _frozen_param24, buf93, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf94 = buf62; del buf62 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf93, (1024, 512), (512, 1), 0), _frozen_param140, out=buf94) | |
| buf95 = reinterpret_tensor(buf67, (32, 32, 1536), (49152, 1536, 1), 0); del buf67 # reuse | |
| # Topologically Sorted Source Nodes: [silu_2, mul_20], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf94, buf95, 1572864, grid=grid(1572864), stream=stream0) | |
| buf96 = reinterpret_tensor(buf93, (1024, 512), (512, 1), 0); del buf93 # reuse | |
| # Topologically Sorted Source Nodes: [linear_20], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf95, (1024, 1536), (1536, 1), 0), _frozen_param97, out=buf96) | |
| buf98 = reinterpret_tensor(buf31, (32, 32, 512), (16384, 512, 1), 0); del buf31 # reuse | |
| # Topologically Sorted Source Nodes: [out_1, h_3, out_2, float_16, pow_7, mean_6, add_12, rsqrt_6, mul_21, output_12, mul_22], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf59, buf64, buf91, buf96, _frozen_param28, buf98, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf99 = reinterpret_tensor(buf95, (1024, 1536), (1536, 1), 0); del buf95 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf98, (1024, 512), (512, 1), 0), _frozen_param141, out=buf99) | |
| # Topologically Sorted Source Nodes: [setitem_7], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf99, arg83_1, 524288, grid=grid(524288), stream=stream0) | |
| buf100 = buf75; del buf75 # reuse | |
| # Topologically Sorted Source Nodes: [xq__3], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf99, buf100, 524288, grid=grid(524288), stream=stream0) | |
| buf107 = buf68; del buf68 # reuse | |
| # Topologically Sorted Source Nodes: [xk__3], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf99, buf107, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__3], Original ATen: [aten.view_as_complex] | |
| buf101 = torch.ops.aten.view_as_complex.default(buf100) | |
| buf102 = buf101 | |
| # Topologically Sorted Source Nodes: [mul_23], Original ATen: [aten.mul] | |
| buf103 = torch.ops.aten.mul.Tensor(buf102, _frozen_param79) | |
| del buf101 | |
| del buf102 | |
| buf104 = buf103 | |
| del buf103 | |
| # Topologically Sorted Source Nodes: [view_as_real_6], Original ATen: [aten.view_as_real] | |
| buf105 = torch.ops.aten.view_as_real.default(buf104) | |
| buf106 = buf105 | |
| buf116 = reinterpret_tensor(buf98, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf98 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf106, buf116, 524288, grid=grid(524288), stream=stream0) | |
| del buf104 | |
| del buf105 | |
| del buf106 | |
| # Topologically Sorted Source Nodes: [xk__3], Original ATen: [aten.view_as_complex] | |
| buf108 = torch.ops.aten.view_as_complex.default(buf107) | |
| buf109 = buf108 | |
| # Topologically Sorted Source Nodes: [mul_24], Original ATen: [aten.mul] | |
| buf110 = torch.ops.aten.mul.Tensor(buf109, _frozen_param79) | |
| del buf108 | |
| del buf109 | |
| buf111 = buf110 | |
| del buf110 | |
| # Topologically Sorted Source Nodes: [view_as_real_7], Original ATen: [aten.view_as_real] | |
| buf112 = torch.ops.aten.view_as_real.default(buf111) | |
| buf113 = buf112 | |
| # Topologically Sorted Source Nodes: [xk_11, setitem_6], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf113, arg82_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf111 | |
| del buf112 | |
| del buf113 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf117 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf116, reinterpret_tensor(arg82_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg83_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf118 = buf117[0] | |
| del buf117 | |
| buf123 = reinterpret_tensor(buf116, (1024, 512), (512, 1), 0); del buf116 # reuse | |
| # Topologically Sorted Source Nodes: [linear_24], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf118, (1024, 512), (512, 1), 0), _frozen_param101, out=buf123) | |
| buf124 = buf59; del buf59 # reuse | |
| buf126 = reinterpret_tensor(buf118, (32, 32, 512), (16384, 512, 1), 0); del buf118 # reuse | |
| # Topologically Sorted Source Nodes: [out_1, h_3, out_2, h_4, float_20, pow_8, mean_7, add_14, rsqrt_7, mul_25, output_15, mul_26], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf124, buf64, buf91, buf96, buf123, _frozen_param33, buf126, 1024, 512, grid=grid(1024), stream=stream0) | |
| del buf123 | |
| del buf64 | |
| buf127 = buf94; del buf94 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf126, (1024, 512), (512, 1), 0), _frozen_param142, out=buf127) | |
| buf128 = reinterpret_tensor(buf99, (32, 32, 1536), (49152, 1536, 1), 0); del buf99 # reuse | |
| # Topologically Sorted Source Nodes: [silu_3, mul_27], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf127, buf128, 1572864, grid=grid(1572864), stream=stream0) | |
| buf129 = reinterpret_tensor(buf126, (1024, 512), (512, 1), 0); del buf126 # reuse | |
| # Topologically Sorted Source Nodes: [linear_27], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf128, (1024, 1536), (1536, 1), 0), _frozen_param104, out=buf129) | |
| buf131 = reinterpret_tensor(buf96, (32, 32, 512), (16384, 512, 1), 0); del buf96 # reuse | |
| # Topologically Sorted Source Nodes: [out_3, float_21, pow_9, mean_8, add_16, rsqrt_8, mul_28, output_16, mul_29], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf124, buf129, _frozen_param37, buf131, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf132 = reinterpret_tensor(buf128, (1024, 1536), (1536, 1), 0); del buf128 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf131, (1024, 512), (512, 1), 0), _frozen_param143, out=buf132) | |
| # Topologically Sorted Source Nodes: [setitem_9], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf132, arg85_1, 524288, grid=grid(524288), stream=stream0) | |
| buf133 = buf107; del buf107 # reuse | |
| # Topologically Sorted Source Nodes: [xq__4], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf132, buf133, 524288, grid=grid(524288), stream=stream0) | |
| buf140 = buf100; del buf100 # reuse | |
| # Topologically Sorted Source Nodes: [xk__4], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf132, buf140, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__4], Original ATen: [aten.view_as_complex] | |
| buf134 = torch.ops.aten.view_as_complex.default(buf133) | |
| buf135 = buf134 | |
| # Topologically Sorted Source Nodes: [mul_30], Original ATen: [aten.mul] | |
| buf136 = torch.ops.aten.mul.Tensor(buf135, _frozen_param79) | |
| del buf134 | |
| del buf135 | |
| buf137 = buf136 | |
| del buf136 | |
| # Topologically Sorted Source Nodes: [view_as_real_8], Original ATen: [aten.view_as_real] | |
| buf138 = torch.ops.aten.view_as_real.default(buf137) | |
| buf139 = buf138 | |
| buf149 = reinterpret_tensor(buf131, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf131 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf139, buf149, 524288, grid=grid(524288), stream=stream0) | |
| del buf137 | |
| del buf138 | |
| del buf139 | |
| # Topologically Sorted Source Nodes: [xk__4], Original ATen: [aten.view_as_complex] | |
| buf141 = torch.ops.aten.view_as_complex.default(buf140) | |
| buf142 = buf141 | |
| # Topologically Sorted Source Nodes: [mul_31], Original ATen: [aten.mul] | |
| buf143 = torch.ops.aten.mul.Tensor(buf142, _frozen_param79) | |
| del buf141 | |
| del buf142 | |
| buf144 = buf143 | |
| del buf143 | |
| # Topologically Sorted Source Nodes: [view_as_real_9], Original ATen: [aten.view_as_real] | |
| buf145 = torch.ops.aten.view_as_real.default(buf144) | |
| buf146 = buf145 | |
| # Topologically Sorted Source Nodes: [xk_14, setitem_8], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf146, arg84_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf144 | |
| del buf145 | |
| del buf146 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf150 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf149, reinterpret_tensor(arg84_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg85_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf151 = buf150[0] | |
| del buf150 | |
| buf156 = reinterpret_tensor(buf149, (1024, 512), (512, 1), 0); del buf149 # reuse | |
| # Topologically Sorted Source Nodes: [linear_31], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf151, (1024, 512), (512, 1), 0), _frozen_param108, out=buf156) | |
| buf158 = reinterpret_tensor(buf151, (32, 32, 512), (16384, 512, 1), 0); del buf151 # reuse | |
| # Topologically Sorted Source Nodes: [out_3, h_5, float_25, pow_10, mean_9, add_18, rsqrt_9, mul_32, output_19, mul_33], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf124, buf129, buf156, _frozen_param42, buf158, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf159 = buf127; del buf127 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf158, (1024, 512), (512, 1), 0), _frozen_param144, out=buf159) | |
| buf160 = reinterpret_tensor(buf132, (32, 32, 1536), (49152, 1536, 1), 0); del buf132 # reuse | |
| # Topologically Sorted Source Nodes: [silu_4, mul_34], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf159, buf160, 1572864, grid=grid(1572864), stream=stream0) | |
| buf161 = reinterpret_tensor(buf158, (1024, 512), (512, 1), 0); del buf158 # reuse | |
| # Topologically Sorted Source Nodes: [linear_34], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf160, (1024, 1536), (1536, 1), 0), _frozen_param111, out=buf161) | |
| buf163 = reinterpret_tensor(buf91, (32, 32, 512), (16384, 512, 1), 0); del buf91 # reuse | |
| # Topologically Sorted Source Nodes: [out_3, h_5, out_4, float_26, pow_11, mean_10, add_20, rsqrt_10, mul_35, output_20, mul_36], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf124, buf129, buf156, buf161, _frozen_param46, buf163, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf164 = reinterpret_tensor(buf160, (1024, 1536), (1536, 1), 0); del buf160 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf163, (1024, 512), (512, 1), 0), _frozen_param145, out=buf164) | |
| # Topologically Sorted Source Nodes: [setitem_11], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf164, arg87_1, 524288, grid=grid(524288), stream=stream0) | |
| buf165 = buf140; del buf140 # reuse | |
| # Topologically Sorted Source Nodes: [xq__5], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf164, buf165, 524288, grid=grid(524288), stream=stream0) | |
| buf172 = buf133; del buf133 # reuse | |
| # Topologically Sorted Source Nodes: [xk__5], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf164, buf172, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__5], Original ATen: [aten.view_as_complex] | |
| buf166 = torch.ops.aten.view_as_complex.default(buf165) | |
| buf167 = buf166 | |
| # Topologically Sorted Source Nodes: [mul_37], Original ATen: [aten.mul] | |
| buf168 = torch.ops.aten.mul.Tensor(buf167, _frozen_param79) | |
| del buf166 | |
| del buf167 | |
| buf169 = buf168 | |
| del buf168 | |
| # Topologically Sorted Source Nodes: [view_as_real_10], Original ATen: [aten.view_as_real] | |
| buf170 = torch.ops.aten.view_as_real.default(buf169) | |
| buf171 = buf170 | |
| buf181 = reinterpret_tensor(buf163, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf163 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf171, buf181, 524288, grid=grid(524288), stream=stream0) | |
| del buf169 | |
| del buf170 | |
| del buf171 | |
| # Topologically Sorted Source Nodes: [xk__5], Original ATen: [aten.view_as_complex] | |
| buf173 = torch.ops.aten.view_as_complex.default(buf172) | |
| buf174 = buf173 | |
| # Topologically Sorted Source Nodes: [mul_38], Original ATen: [aten.mul] | |
| buf175 = torch.ops.aten.mul.Tensor(buf174, _frozen_param79) | |
| del buf173 | |
| del buf174 | |
| buf176 = buf175 | |
| del buf175 | |
| # Topologically Sorted Source Nodes: [view_as_real_11], Original ATen: [aten.view_as_real] | |
| buf177 = torch.ops.aten.view_as_real.default(buf176) | |
| buf178 = buf177 | |
| # Topologically Sorted Source Nodes: [xk_17, setitem_10], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf178, arg86_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf176 | |
| del buf177 | |
| del buf178 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf182 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf181, reinterpret_tensor(arg86_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg87_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf183 = buf182[0] | |
| del buf182 | |
| buf188 = reinterpret_tensor(buf181, (1024, 512), (512, 1), 0); del buf181 # reuse | |
| # Topologically Sorted Source Nodes: [linear_38], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf183, (1024, 512), (512, 1), 0), _frozen_param115, out=buf188) | |
| buf189 = buf124; del buf124 # reuse | |
| buf191 = reinterpret_tensor(buf183, (32, 32, 512), (16384, 512, 1), 0); del buf183 # reuse | |
| # Topologically Sorted Source Nodes: [out_3, h_5, out_4, h_6, float_30, pow_12, mean_11, add_22, rsqrt_11, mul_39, output_23, mul_40], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf189, buf129, buf156, buf161, buf188, _frozen_param51, buf191, 1024, 512, grid=grid(1024), stream=stream0) | |
| del buf129 | |
| del buf156 | |
| buf192 = buf159; del buf159 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf191, (1024, 512), (512, 1), 0), _frozen_param146, out=buf192) | |
| buf193 = reinterpret_tensor(buf164, (32, 32, 1536), (49152, 1536, 1), 0); del buf164 # reuse | |
| # Topologically Sorted Source Nodes: [silu_5, mul_41], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf192, buf193, 1572864, grid=grid(1572864), stream=stream0) | |
| buf194 = reinterpret_tensor(buf191, (1024, 512), (512, 1), 0); del buf191 # reuse | |
| # Topologically Sorted Source Nodes: [linear_41], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf193, (1024, 1536), (1536, 1), 0), _frozen_param118, out=buf194) | |
| buf196 = reinterpret_tensor(buf188, (32, 32, 512), (16384, 512, 1), 0); del buf188 # reuse | |
| # Topologically Sorted Source Nodes: [out_5, float_31, pow_13, mean_12, add_24, rsqrt_12, mul_42, output_24, mul_43], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf189, buf194, _frozen_param55, buf196, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf197 = reinterpret_tensor(buf193, (1024, 1536), (1536, 1), 0); del buf193 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf196, (1024, 512), (512, 1), 0), _frozen_param147, out=buf197) | |
| # Topologically Sorted Source Nodes: [setitem_13], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf197, arg89_1, 524288, grid=grid(524288), stream=stream0) | |
| buf198 = buf172; del buf172 # reuse | |
| # Topologically Sorted Source Nodes: [xq__6], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf197, buf198, 524288, grid=grid(524288), stream=stream0) | |
| buf205 = buf165; del buf165 # reuse | |
| # Topologically Sorted Source Nodes: [xk__6], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf197, buf205, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__6], Original ATen: [aten.view_as_complex] | |
| buf199 = torch.ops.aten.view_as_complex.default(buf198) | |
| buf200 = buf199 | |
| # Topologically Sorted Source Nodes: [mul_44], Original ATen: [aten.mul] | |
| buf201 = torch.ops.aten.mul.Tensor(buf200, _frozen_param79) | |
| del buf199 | |
| del buf200 | |
| buf202 = buf201 | |
| del buf201 | |
| # Topologically Sorted Source Nodes: [view_as_real_12], Original ATen: [aten.view_as_real] | |
| buf203 = torch.ops.aten.view_as_real.default(buf202) | |
| buf204 = buf203 | |
| buf214 = reinterpret_tensor(buf196, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf196 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf204, buf214, 524288, grid=grid(524288), stream=stream0) | |
| del buf202 | |
| del buf203 | |
| del buf204 | |
| # Topologically Sorted Source Nodes: [xk__6], Original ATen: [aten.view_as_complex] | |
| buf206 = torch.ops.aten.view_as_complex.default(buf205) | |
| buf207 = buf206 | |
| # Topologically Sorted Source Nodes: [mul_45], Original ATen: [aten.mul] | |
| buf208 = torch.ops.aten.mul.Tensor(buf207, _frozen_param79) | |
| del buf206 | |
| del buf207 | |
| buf209 = buf208 | |
| del buf208 | |
| # Topologically Sorted Source Nodes: [view_as_real_13], Original ATen: [aten.view_as_real] | |
| buf210 = torch.ops.aten.view_as_real.default(buf209) | |
| buf211 = buf210 | |
| # Topologically Sorted Source Nodes: [xk_20, setitem_12], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf211, arg88_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf209 | |
| del buf210 | |
| del buf211 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf215 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf214, reinterpret_tensor(arg88_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg89_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf216 = buf215[0] | |
| del buf215 | |
| buf221 = reinterpret_tensor(buf214, (1024, 512), (512, 1), 0); del buf214 # reuse | |
| # Topologically Sorted Source Nodes: [linear_45], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf216, (1024, 512), (512, 1), 0), _frozen_param122, out=buf221) | |
| buf223 = reinterpret_tensor(buf216, (32, 32, 512), (16384, 512, 1), 0); del buf216 # reuse | |
| # Topologically Sorted Source Nodes: [out_5, h_7, float_35, pow_14, mean_13, add_26, rsqrt_13, mul_46, output_27, mul_47], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf189, buf194, buf221, _frozen_param60, buf223, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf224 = buf192; del buf192 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf223, (1024, 512), (512, 1), 0), _frozen_param148, out=buf224) | |
| buf225 = reinterpret_tensor(buf197, (32, 32, 1536), (49152, 1536, 1), 0); del buf197 # reuse | |
| # Topologically Sorted Source Nodes: [silu_6, mul_48], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf224, buf225, 1572864, grid=grid(1572864), stream=stream0) | |
| buf226 = reinterpret_tensor(buf223, (1024, 512), (512, 1), 0); del buf223 # reuse | |
| # Topologically Sorted Source Nodes: [linear_48], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf225, (1024, 1536), (1536, 1), 0), _frozen_param125, out=buf226) | |
| buf228 = reinterpret_tensor(buf161, (32, 32, 512), (16384, 512, 1), 0); del buf161 # reuse | |
| # Topologically Sorted Source Nodes: [out_5, h_7, out_6, float_36, pow_15, mean_14, add_28, rsqrt_14, mul_49, output_28, mul_50], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf189, buf194, buf221, buf226, _frozen_param64, buf228, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf229 = reinterpret_tensor(buf225, (1024, 1536), (1536, 1), 0); del buf225 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf228, (1024, 512), (512, 1), 0), _frozen_param149, out=buf229) | |
| # Topologically Sorted Source Nodes: [setitem_15], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf229, arg91_1, 524288, grid=grid(524288), stream=stream0) | |
| buf230 = buf205; del buf205 # reuse | |
| # Topologically Sorted Source Nodes: [xq__7], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf229, buf230, 524288, grid=grid(524288), stream=stream0) | |
| buf237 = buf198; del buf198 # reuse | |
| # Topologically Sorted Source Nodes: [xk__7], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf229, buf237, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__7], Original ATen: [aten.view_as_complex] | |
| buf231 = torch.ops.aten.view_as_complex.default(buf230) | |
| buf232 = buf231 | |
| # Topologically Sorted Source Nodes: [mul_51], Original ATen: [aten.mul] | |
| buf233 = torch.ops.aten.mul.Tensor(buf232, _frozen_param79) | |
| del buf230 | |
| del buf231 | |
| del buf232 | |
| buf234 = buf233 | |
| del buf233 | |
| # Topologically Sorted Source Nodes: [view_as_real_14], Original ATen: [aten.view_as_real] | |
| buf235 = torch.ops.aten.view_as_real.default(buf234) | |
| buf236 = buf235 | |
| buf246 = reinterpret_tensor(buf228, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf228 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf236, buf246, 524288, grid=grid(524288), stream=stream0) | |
| del buf234 | |
| del buf235 | |
| del buf236 | |
| # Topologically Sorted Source Nodes: [xk__7], Original ATen: [aten.view_as_complex] | |
| buf238 = torch.ops.aten.view_as_complex.default(buf237) | |
| buf239 = buf238 | |
| # Topologically Sorted Source Nodes: [mul_52], Original ATen: [aten.mul] | |
| buf240 = torch.ops.aten.mul.Tensor(buf239, _frozen_param79) | |
| del buf237 | |
| del buf238 | |
| del buf239 | |
| buf241 = buf240 | |
| del buf240 | |
| # Topologically Sorted Source Nodes: [view_as_real_15], Original ATen: [aten.view_as_real] | |
| buf242 = torch.ops.aten.view_as_real.default(buf241) | |
| buf243 = buf242 | |
| # Topologically Sorted Source Nodes: [xk_23, setitem_14], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf243, arg90_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf241 | |
| del buf242 | |
| del buf243 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf247 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf246, reinterpret_tensor(arg90_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg91_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf248 = buf247[0] | |
| del buf247 | |
| buf253 = reinterpret_tensor(buf246, (1024, 512), (512, 1), 0); del buf246 # reuse | |
| # Topologically Sorted Source Nodes: [linear_52], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf248, (1024, 512), (512, 1), 0), _frozen_param129, out=buf253) | |
| buf254 = buf189; del buf189 # reuse | |
| buf256 = reinterpret_tensor(buf248, (32, 32, 512), (16384, 512, 1), 0); del buf248 # reuse | |
| # Topologically Sorted Source Nodes: [out_5, h_7, out_6, h_8, float_40, pow_16, mean_15, add_30, rsqrt_15, mul_53, output_31, mul_54], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf254, buf194, buf221, buf226, buf253, _frozen_param69, buf256, 1024, 512, grid=grid(1024), stream=stream0) | |
| del buf194 | |
| del buf221 | |
| del buf226 | |
| del buf253 | |
| buf257 = buf224; del buf224 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf256, (1024, 512), (512, 1), 0), _frozen_param150, out=buf257) | |
| buf258 = reinterpret_tensor(buf229, (32, 32, 1536), (49152, 1536, 1), 0); del buf229 # reuse | |
| # Topologically Sorted Source Nodes: [silu_7, mul_55], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf257, buf258, 1572864, grid=grid(1572864), stream=stream0) | |
| del buf257 | |
| buf259 = reinterpret_tensor(buf256, (1024, 512), (512, 1), 0); del buf256 # reuse | |
| # Topologically Sorted Source Nodes: [linear_55], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf258, (1024, 1536), (1536, 1), 0), _frozen_param132, out=buf259) | |
| del buf258 | |
| buf261 = buf254; del buf254 # reuse | |
| # Topologically Sorted Source Nodes: [out_7, float_41, pow_17, mean_16, add_32, rsqrt_16, mul_56, output_32, h_9], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf261, buf259, _frozen_param73, 1024, 512, grid=grid(1024), stream=stream0) | |
| del buf259 | |
| buf262 = empty_strided_cuda((32, 32000), (32000, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [output_33], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf261, (32, 512), (16384, 1), 15872), _frozen_param134, out=buf262) | |
| del buf261 | |
| buf263 = empty_strided_cuda((32, 32000), (32000, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [float_42], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_15.run(buf262, buf263, 1024000, grid=grid(1024000), stream=stream0) | |
| del buf262 | |
| return (buf263, arg77_1, arg76_1, arg79_1, arg78_1, arg81_1, arg80_1, arg83_1, arg82_1, arg85_1, arg84_1, arg87_1, arg86_1, arg89_1, arg88_1, arg91_1, arg90_1, ) | |
| def benchmark_compiled_module(times=10, repeat=10): | |
| from torch._dynamo.testing import rand_strided | |
| from torch._inductor.utils import print_performance | |
| global _frozen_param0 | |
| _frozen_param0 = rand_strided((32000, 512), (512, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param1 | |
| _frozen_param1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param6 | |
| _frozen_param6 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param10 | |
| _frozen_param10 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param15 | |
| _frozen_param15 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param19 | |
| _frozen_param19 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param24 | |
| _frozen_param24 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param28 | |
| _frozen_param28 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param33 | |
| _frozen_param33 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param37 | |
| _frozen_param37 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param42 | |
| _frozen_param42 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param46 | |
| _frozen_param46 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param51 | |
| _frozen_param51 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param55 | |
| _frozen_param55 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param60 | |
| _frozen_param60 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param64 | |
| _frozen_param64 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param69 | |
| _frozen_param69 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param73 | |
| _frozen_param73 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param135 | |
| _frozen_param135 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param79 | |
| _frozen_param79 = rand_strided((1, 32, 1, 32), (1024, 32, 32, 1), device='cuda:0', dtype=torch.complex64) | |
| global _frozen_param80 | |
| _frozen_param80 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param136 | |
| _frozen_param136 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param83 | |
| _frozen_param83 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param137 | |
| _frozen_param137 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param87 | |
| _frozen_param87 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param138 | |
| _frozen_param138 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param90 | |
| _frozen_param90 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param139 | |
| _frozen_param139 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param94 | |
| _frozen_param94 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param140 | |
| _frozen_param140 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param97 | |
| _frozen_param97 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param141 | |
| _frozen_param141 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param101 | |
| _frozen_param101 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param142 | |
| _frozen_param142 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param104 | |
| _frozen_param104 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param143 | |
| _frozen_param143 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param108 | |
| _frozen_param108 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param144 | |
| _frozen_param144 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param111 | |
| _frozen_param111 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param145 | |
| _frozen_param145 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param115 | |
| _frozen_param115 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param146 | |
| _frozen_param146 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param118 | |
| _frozen_param118 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param147 | |
| _frozen_param147 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param122 | |
| _frozen_param122 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param148 | |
| _frozen_param148 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param125 | |
| _frozen_param125 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param149 | |
| _frozen_param149 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param129 | |
| _frozen_param129 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param150 | |
| _frozen_param150 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param132 | |
| _frozen_param132 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param134 | |
| _frozen_param134 = rand_strided((512, 32000), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| arg76_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg77_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg78_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg79_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg80_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg81_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg82_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg83_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg84_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg85_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg86_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg87_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg88_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg89_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg90_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg91_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg92_1 = rand_strided((32, 32), (32, 1), device='cuda:0', dtype=torch.int32) | |
| fn = lambda: call([arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1]) | |
| return print_performance(fn, times=times, repeat=repeat) | |
| if __name__ == "__main__": | |
| from torch._inductor.wrapper_benchmark import compiled_module_main | |
| compiled_module_main('llama', benchmark_compiled_module) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # AOT ID: ['0_inference'] | |
| from ctypes import c_void_p, c_long, c_int | |
| import torch | |
| import math | |
| import random | |
| import os | |
| import tempfile | |
| from math import inf, nan | |
| from torch._inductor.hooks import run_intermediate_hooks | |
| from torch._inductor.utils import maybe_profile | |
| from torch._inductor.codegen.memory_planning import _align as align | |
| from torch import device, empty_strided | |
| from torch._inductor.async_compile import AsyncCompile | |
| from torch._inductor.select_algorithm import extern_kernels | |
| from torch._inductor.codegen.multi_kernel import MultiKernelCall | |
| import triton | |
| import triton.language as tl | |
| from torch._inductor.runtime.triton_heuristics import ( | |
| grid, | |
| split_scan_grid, | |
| grid_combo_kernels, | |
| start_graph, | |
| end_graph, | |
| cooperative_reduction_grid, | |
| ) | |
| from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
| from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
| aten = torch.ops.aten | |
| inductor_ops = torch.ops.inductor | |
| _quantized = torch.ops._quantized | |
| assert_size_stride = torch._C._dynamo.guards.assert_size_stride | |
| empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu | |
| empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda | |
| empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu | |
| reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor | |
| alloc_from_pool = torch.ops.inductor._alloc_from_pool | |
| async_compile = AsyncCompile() | |
| empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p | |
| _frozen_param0 = None # device(type='cuda', index=0) torch.bfloat16 (32000, 512) (512, 1) 7f03b04ed800 | |
| _frozen_param1 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b0477ce0 | |
| _frozen_param6 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04ef650 | |
| _frozen_param10 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04edf80 | |
| _frozen_param15 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04ede90 | |
| _frozen_param19 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b0477dd0 | |
| _frozen_param24 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04a2390 | |
| _frozen_param28 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04a0ae0 | |
| _frozen_param33 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b05b3c90 | |
| _frozen_param37 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b05b3ba0 | |
| _frozen_param42 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b05422a0 | |
| _frozen_param46 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04e34c0 | |
| _frozen_param51 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04a38d0 | |
| _frozen_param55 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04e0f40 | |
| _frozen_param60 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04a0400 | |
| _frozen_param64 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04e2480 | |
| _frozen_param69 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04a23e0 | |
| _frozen_param73 = None # device(type='cuda', index=0) torch.bfloat16 (512,) (1,) 7f03b04ef9c0 | |
| _frozen_param135 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f03241ca6b0 | |
| _frozen_param79 = None # device(type='cuda', index=0) torch.complex64 (1, 32, 1, 32) (1024, 32, 32, 1) 7f032416f060 | |
| _frozen_param80 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324131530 | |
| _frozen_param136 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f03241bf0b0 | |
| _frozen_param83 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f032427df30 | |
| _frozen_param137 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f03241bda80 | |
| _frozen_param87 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324174b80 | |
| _frozen_param138 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f03241bfa60 | |
| _frozen_param90 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f0324174c70 | |
| _frozen_param139 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f03243bc3b0 | |
| _frozen_param94 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324174db0 | |
| _frozen_param140 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f0324003830 | |
| _frozen_param97 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f0324174ea0 | |
| _frozen_param141 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f03243bc770 | |
| _frozen_param101 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324174fe0 | |
| _frozen_param142 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f0324204950 | |
| _frozen_param104 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f03241750d0 | |
| _frozen_param143 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f032418ff10 | |
| _frozen_param108 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324175210 | |
| _frozen_param144 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f032418e1b0 | |
| _frozen_param111 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f0324175300 | |
| _frozen_param145 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f032418fa10 | |
| _frozen_param115 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324175440 | |
| _frozen_param146 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f032418ef20 | |
| _frozen_param118 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f0324175530 | |
| _frozen_param147 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f032418fdd0 | |
| _frozen_param122 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f0324175670 | |
| _frozen_param148 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f032401c450 | |
| _frozen_param125 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f0324175760 | |
| _frozen_param149 = None # device(type='cuda', index=0) torch.bfloat16 (512, 1536) (1536, 1) 7f032401f830 | |
| _frozen_param129 = None # device(type='cuda', index=0) torch.bfloat16 (512, 512) (1, 512) 7f03241758a0 | |
| _frozen_param150 = None # device(type='cuda', index=0) torch.bfloat16 (512, 3072) (3072, 1) 7f032401fbf0 | |
| _frozen_param132 = None # device(type='cuda', index=0) torch.bfloat16 (1536, 512) (1, 1536) 7f0324175990 | |
| _frozen_param134 = None # device(type='cuda', index=0) torch.bfloat16 (512, 32000) (1, 512) 7f0324130c70 | |
| # kernel path: /tmp/torchinductor_t/jq/cjqkz4jt5d5lkuk3r2wju344fss2vw6bnduyr5hqduvhlizskx43.py | |
| # Topologically Sorted Source Nodes: [h, float_1, pow_1, mean, add, rsqrt, mul, output, mul_1], Original ATen: [aten.embedding, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add => add | |
| # float_1 => convert_element_type_1 | |
| # h => embedding | |
| # mean => mean | |
| # mul => mul | |
| # mul_1 => mul_1 | |
| # output => convert_element_type_2 | |
| # pow_1 => pow_1 | |
| # rsqrt => rsqrt | |
| # Graph fragment: | |
| # %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {}) | |
| # %convert_element_type_1 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%embedding, torch.float32), kwargs = {}) | |
| # %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_1, 2), kwargs = {}) | |
| # %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {}) | |
| # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-05), kwargs = {}) | |
| # %rsqrt : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {}) | |
| # %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1, %rsqrt), kwargs = {}) | |
| # %convert_element_type_2 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
| # %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %arg1_1), kwargs = {}) | |
| triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.DEFAULT, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): | |
| xnumel = 1024 | |
| rnumel = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
| xmask = xindex < xnumel | |
| rbase = tl.arange(0, RBLOCK)[None, :] | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last') | |
| _tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) | |
| for roffset in range(0, rnumel, RBLOCK): | |
| rindex = roffset + rbase | |
| rmask = rindex < rnumel | |
| r1 = rindex | |
| tmp1 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp3 = tmp0 < 0 | |
| tmp4 = tl.where(tmp3, tmp2, tmp0) | |
| tl.device_assert(((0 <= tmp4) & (tmp4 < 32000)) | ~(xmask), "index out of bounds: 0 <= tmp4 < 32000") | |
| tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
| tmp7 = tmp6.to(tl.float32) | |
| tmp8 = tmp7 * tmp7 | |
| tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK]) | |
| tmp11 = _tmp10 + tmp9 | |
| _tmp10 = tl.where(rmask & xmask, tmp11, _tmp10) | |
| tmp10 = tl.sum(_tmp10, 1)[:, None] | |
| for roffset in range(0, rnumel, RBLOCK): | |
| rindex = roffset + rbase | |
| rmask = rindex < rnumel | |
| r1 = rindex | |
| tmp26 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
| tmp12 = tl.full([XBLOCK, RBLOCK], 32000, tl.int32) | |
| tmp13 = tmp0 + tmp12 | |
| tmp14 = tmp0 < 0 | |
| tmp15 = tl.where(tmp14, tmp13, tmp0) | |
| tl.device_assert(((0 <= tmp15) & (tmp15 < 32000)) | ~(xmask), "index out of bounds: 0 <= tmp15 < 32000") | |
| tmp17 = tl.load(in_ptr1 + (r1 + 512*tmp15), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
| tmp18 = tmp17.to(tl.float32) | |
| tmp19 = 512.0 | |
| tmp20 = (tmp10 / tmp19).to(tl.float32) | |
| tmp21 = 1e-05 | |
| tmp22 = tmp20 + tmp21 | |
| tmp23 = libdevice.rsqrt(tmp22) | |
| tmp24 = tmp18 * tmp23 | |
| tmp25 = tmp24.to(tl.float32) | |
| tmp27 = tmp25 * tmp26 | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp27, rmask & xmask) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/xc/cxcmzl6ilmlg65glf4pzjcxwxpduicy2vfmzzjio5qccmdqzy5s3.py | |
| # Topologically Sorted Source Nodes: [setitem_1], Original ATen: [aten.copy] | |
| # Source node to ATen node mapping: | |
| # setitem_1 => copy_1 | |
| # Graph fragment: | |
| # %copy_1 : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_4, %view_8), kwargs = {}) | |
| # %copy__default_1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%slice_tensor_1, %copy_1), kwargs = {}) | |
| triton_poi_fused_copy_1 = async_compile.triton('triton_poi_fused_copy_1', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 524288}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_copy_1', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_copy_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 524288 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x0 = (xindex % 512) | |
| x3 = xindex // 512 | |
| x2 = xindex // 16384 | |
| x4 = (xindex % 16384) | |
| tmp0 = tl.load(in_ptr0 + (x0 + 1536*x3), None).to(tl.float32) | |
| tl.store(out_ptr0 + (512 + x4 + 524288*x2), tmp0, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/7m/c7m45i7porfguqnzfi4evudvthp6s46rtboxlqxs2ukmts7oogqy.py | |
| # Topologically Sorted Source Nodes: [xq_], Original ATen: [aten.view_as_complex] | |
| # Source node to ATen node mapping: | |
| # xq_ => view_as_complex | |
| # Graph fragment: | |
| # %view_as_complex : [num_users=1] = call_function[target=torch.ops.aten.view_as_complex.default](args = (%view_9,), kwargs = {}) | |
| triton_poi_fused_view_as_complex_2 = async_compile.triton('triton_poi_fused_view_as_complex_2', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 524288}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_view_as_complex_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_view_as_complex_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 524288 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x0 = (xindex % 512) | |
| x1 = xindex // 512 | |
| x2 = xindex | |
| tmp0 = tl.load(in_ptr0 + (512 + x0 + 1536*x1), None).to(tl.float32) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (x2), tmp1, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/65/c65ekeiwoihqn3z3eozs5kfs2w7fuc6dtsursyyedmm5n3idz6aw.py | |
| # Topologically Sorted Source Nodes: [xk_], Original ATen: [aten.view_as_complex] | |
| # Source node to ATen node mapping: | |
| # xk_ => view_as_complex_1 | |
| # Graph fragment: | |
| # %view_as_complex_1 : [num_users=1] = call_function[target=torch.ops.aten.view_as_complex.default](args = (%view_10,), kwargs = {}) | |
| triton_poi_fused_view_as_complex_3 = async_compile.triton('triton_poi_fused_view_as_complex_3', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 524288}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_view_as_complex_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_view_as_complex_3(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 524288 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x0 = (xindex % 512) | |
| x1 = xindex // 512 | |
| x2 = xindex | |
| tmp0 = tl.load(in_ptr0 + (1024 + x0 + 1536*x1), None).to(tl.float32) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (x2), tmp1, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/fr/cfrphf4roqxuzvwmpfjsqykxw7ykp6blxuab7a762yx75y5ssqwv.py | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| # Source node to ATen node mapping: | |
| # Graph fragment: | |
| # %_scaled_dot_product_flash_attention_default_7 : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%permute_default_21, %permute_default_22, %permute_default_23), kwargs = {scale: 0.125}) | |
| triton_poi_fused_4 = async_compile.triton('triton_poi_fused_4', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 524288}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_4(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 524288 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (x0), None) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (x0), tmp1, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/pm/cpmwz4lr5hqfizv3wxqz5p5i5tjfamfs6btvjdkmpwz6sixmisvp.py | |
| # Topologically Sorted Source Nodes: [xk_2, setitem], Original ATen: [aten._to_copy, aten.copy] | |
| # Source node to ATen node mapping: | |
| # setitem => copy | |
| # xk_2 => convert_element_type_12 | |
| # Graph fragment: | |
| # %convert_element_type_12 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_13, torch.bfloat16), kwargs = {}) | |
| # %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %convert_element_type_12), kwargs = {}) | |
| # %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%slice_tensor, %copy), kwargs = {}) | |
| triton_poi_fused__to_copy_copy_5 = async_compile.triton('triton_poi_fused__to_copy_copy_5', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 524288}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_copy_5', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused__to_copy_copy_5(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 524288 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x2 = xindex | |
| x0 = (xindex % 16384) | |
| x1 = xindex // 16384 | |
| tmp0 = tl.load(in_ptr0 + (x2), None) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (512 + x0 + 524288*x1), tmp1, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/dr/cdrrgy3ocqlbwxy4cv6ytxs3apsjv37qs3l2swp2p2sgajmtmocx.py | |
| # Topologically Sorted Source Nodes: [h, h_1, float_5, pow_2, mean_1, add_2, rsqrt_1, mul_4, output_3, mul_5], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_2 => add_2 | |
| # float_5 => convert_element_type_21 | |
| # h => embedding | |
| # h_1 => add_1 | |
| # mean_1 => mean_1 | |
| # mul_4 => mul_4 | |
| # mul_5 => mul_5 | |
| # output_3 => convert_element_type_22 | |
| # pow_2 => pow_2 | |
| # rsqrt_1 => rsqrt_1 | |
| # Graph fragment: | |
| # %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {}) | |
| # %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %view_22), kwargs = {}) | |
| # %convert_element_type_21 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_1, torch.float32), kwargs = {}) | |
| # %pow_2 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_21, 2), kwargs = {}) | |
| # %mean_1 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {}) | |
| # %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-05), kwargs = {}) | |
| # %rsqrt_1 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_2,), kwargs = {}) | |
| # %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_21, %rsqrt_1), kwargs = {}) | |
| # %convert_element_type_22 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_4, torch.bfloat16), kwargs = {}) | |
| # %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_22, %arg6_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6 = async_compile.triton('triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| x0 = xindex | |
| r1 = rindex | |
| tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last') | |
| tmp7 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp21 = tl.load(in_ptr3 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp1 = tl.full([RBLOCK], 32000, tl.int32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp3 = tmp0 < 0 | |
| tmp4 = tl.where(tmp3, tmp2, tmp0) | |
| tl.device_assert((0 <= tmp4) & (tmp4 < 32000), "index out of bounds: 0 <= tmp4 < 32000") | |
| tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), None).to(tl.float32) | |
| tmp8 = tmp6 + tmp7 | |
| tmp9 = tmp8.to(tl.float32) | |
| tmp10 = tmp9 * tmp9 | |
| tmp11 = tl.broadcast_to(tmp10, [RBLOCK]) | |
| tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0)) | |
| tmp14 = 512.0 | |
| tmp15 = (tmp13 / tmp14).to(tl.float32) | |
| tmp16 = 1e-05 | |
| tmp17 = tmp15 + tmp16 | |
| tmp18 = libdevice.rsqrt(tmp17) | |
| tmp19 = tmp9 * tmp18 | |
| tmp20 = tmp19.to(tl.float32) | |
| tmp22 = tmp20 * tmp21 | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp22, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/37/c37cieheeo7zfl4rbysqwzkbdr7pwfgmv5au7ifpl2r5xifjrj5v.py | |
| # Topologically Sorted Source Nodes: [silu, mul_6], Original ATen: [aten.silu, aten.mul] | |
| # Source node to ATen node mapping: | |
| # mul_6 => mul_7 | |
| # silu => convert_element_type_25, convert_element_type_26, mul_6, sigmoid | |
| # Graph fragment: | |
| # %convert_element_type_25 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_24, torch.float32), kwargs = {}) | |
| # %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type_25,), kwargs = {}) | |
| # %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_25, %sigmoid), kwargs = {}) | |
| # %convert_element_type_26 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_6, torch.bfloat16), kwargs = {}) | |
| # %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_26, %view_26), kwargs = {}) | |
| triton_poi_fused_mul_silu_7 = async_compile.triton('triton_poi_fused_mul_silu_7', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 2097152}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_7', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused_mul_silu_7(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 1572864 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x0 = (xindex % 1536) | |
| x1 = xindex // 1536 | |
| x2 = xindex | |
| tmp0 = tl.load(in_ptr0 + (1536 + x0 + 3072*x1), None).to(tl.float32) | |
| tmp5 = tl.load(in_ptr0 + (x0 + 3072*x1), None).to(tl.float32) | |
| tmp1 = tmp0.to(tl.float32) | |
| tmp2 = tl.sigmoid(tmp1) | |
| tmp3 = tmp1 * tmp2 | |
| tmp4 = tmp3.to(tl.float32) | |
| tmp6 = tmp4 * tmp5 | |
| tl.store(out_ptr0 + (x2), tmp6, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/oe/coezydc4r3hd4ebv3eien4ar5zw7e5oum674ezorliyuknllnim6.py | |
| # Topologically Sorted Source Nodes: [h, h_1, out, float_6, pow_3, mean_2, add_4, rsqrt_2, mul_7, output_4, mul_8], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_4 => add_4 | |
| # float_6 => convert_element_type_31 | |
| # h => embedding | |
| # h_1 => add_1 | |
| # mean_2 => mean_2 | |
| # mul_7 => mul_8 | |
| # mul_8 => mul_9 | |
| # out => add_3 | |
| # output_4 => convert_element_type_32 | |
| # pow_3 => pow_3 | |
| # rsqrt_2 => rsqrt_2 | |
| # Graph fragment: | |
| # %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {}) | |
| # %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %view_22), kwargs = {}) | |
| # %add_3 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, %view_28), kwargs = {}) | |
| # %convert_element_type_31 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_3, torch.float32), kwargs = {}) | |
| # %pow_3 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_31, 2), kwargs = {}) | |
| # %mean_2 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {}) | |
| # %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-05), kwargs = {}) | |
| # %rsqrt_2 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_4,), kwargs = {}) | |
| # %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_31, %rsqrt_2), kwargs = {}) | |
| # %convert_element_type_32 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_8, torch.bfloat16), kwargs = {}) | |
| # %mul_9 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_32, %arg10_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8 = async_compile.triton('triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| x0 = xindex | |
| r1 = rindex | |
| tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last') | |
| tmp7 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp9 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp23 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp1 = tl.full([RBLOCK], 32000, tl.int32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp3 = tmp0 < 0 | |
| tmp4 = tl.where(tmp3, tmp2, tmp0) | |
| tl.device_assert((0 <= tmp4) & (tmp4 < 32000), "index out of bounds: 0 <= tmp4 < 32000") | |
| tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), None).to(tl.float32) | |
| tmp8 = tmp6 + tmp7 | |
| tmp10 = tmp8 + tmp9 | |
| tmp11 = tmp10.to(tl.float32) | |
| tmp12 = tmp11 * tmp11 | |
| tmp13 = tl.broadcast_to(tmp12, [RBLOCK]) | |
| tmp15 = triton_helpers.promote_to_tensor(tl.sum(tmp13, 0)) | |
| tmp16 = 512.0 | |
| tmp17 = (tmp15 / tmp16).to(tl.float32) | |
| tmp18 = 1e-05 | |
| tmp19 = tmp17 + tmp18 | |
| tmp20 = libdevice.rsqrt(tmp19) | |
| tmp21 = tmp11 * tmp20 | |
| tmp22 = tmp21.to(tl.float32) | |
| tmp24 = tmp22 * tmp23 | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp24, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/hg/chg35he4i3wb65f2xla6t2f46hl3nvqroelcorgcij2tuqcakabg.py | |
| # Topologically Sorted Source Nodes: [h, h_1, out, h_2, float_10, pow_4, mean_3, add_6, rsqrt_3, mul_11, output_7, mul_12], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_6 => add_6 | |
| # float_10 => convert_element_type_51 | |
| # h => embedding | |
| # h_1 => add_1 | |
| # h_2 => add_5 | |
| # mean_3 => mean_3 | |
| # mul_11 => mul_12 | |
| # mul_12 => mul_13 | |
| # out => add_3 | |
| # output_7 => convert_element_type_52 | |
| # pow_4 => pow_4 | |
| # rsqrt_3 => rsqrt_3 | |
| # Graph fragment: | |
| # %embedding : [num_users=2] = call_function[target=torch.ops.aten.embedding.default](args = (%arg0_1, %arg92_1), kwargs = {}) | |
| # %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %view_22), kwargs = {}) | |
| # %add_3 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, %view_28), kwargs = {}) | |
| # %add_5 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_3, %view_51), kwargs = {}) | |
| # %convert_element_type_51 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_5, torch.float32), kwargs = {}) | |
| # %pow_4 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_51, 2), kwargs = {}) | |
| # %mean_3 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_4, [-1], True), kwargs = {}) | |
| # %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_3, 1e-05), kwargs = {}) | |
| # %rsqrt_3 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_6,), kwargs = {}) | |
| # %mul_12 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_51, %rsqrt_3), kwargs = {}) | |
| # %convert_element_type_52 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_12, torch.bfloat16), kwargs = {}) | |
| # %mul_13 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_52, %arg15_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9 = async_compile.triton('triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| x0 = xindex | |
| r1 = rindex | |
| tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last') | |
| tmp7 = tl.load(in_out_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp9 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp11 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp25 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp1 = tl.full([RBLOCK], 32000, tl.int32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp3 = tmp0 < 0 | |
| tmp4 = tl.where(tmp3, tmp2, tmp0) | |
| tl.device_assert((0 <= tmp4) & (tmp4 < 32000), "index out of bounds: 0 <= tmp4 < 32000") | |
| tmp6 = tl.load(in_ptr1 + (r1 + 512*tmp4), None).to(tl.float32) | |
| tmp8 = tmp6 + tmp7 | |
| tmp10 = tmp8 + tmp9 | |
| tmp12 = tmp10 + tmp11 | |
| tmp13 = tmp12.to(tl.float32) | |
| tmp14 = tmp13 * tmp13 | |
| tmp15 = tl.broadcast_to(tmp14, [RBLOCK]) | |
| tmp17 = triton_helpers.promote_to_tensor(tl.sum(tmp15, 0)) | |
| tmp18 = 512.0 | |
| tmp19 = (tmp17 / tmp18).to(tl.float32) | |
| tmp20 = 1e-05 | |
| tmp21 = tmp19 + tmp20 | |
| tmp22 = libdevice.rsqrt(tmp21) | |
| tmp23 = tmp13 * tmp22 | |
| tmp24 = tmp23.to(tl.float32) | |
| tmp26 = tmp24 * tmp25 | |
| tl.store(in_out_ptr0 + (r1 + 512*x0), tmp12, None) | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp26, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/iy/ciy4z4yq6ins6l5ivop2w7p7zdinblxm2voqklkpqk7pgxbr7pil.py | |
| # Topologically Sorted Source Nodes: [out_1, float_11, pow_5, mean_4, add_8, rsqrt_4, mul_14, output_8, mul_15], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_8 => add_8 | |
| # float_11 => convert_element_type_61 | |
| # mean_4 => mean_4 | |
| # mul_14 => mul_16 | |
| # mul_15 => mul_17 | |
| # out_1 => add_7 | |
| # output_8 => convert_element_type_62 | |
| # pow_5 => pow_5 | |
| # rsqrt_4 => rsqrt_4 | |
| # Graph fragment: | |
| # %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {}) | |
| # %convert_element_type_61 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_7, torch.float32), kwargs = {}) | |
| # %pow_5 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_61, 2), kwargs = {}) | |
| # %mean_4 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_5, [-1], True), kwargs = {}) | |
| # %add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_4, 1e-05), kwargs = {}) | |
| # %rsqrt_4 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_8,), kwargs = {}) | |
| # %mul_16 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_61, %rsqrt_4), kwargs = {}) | |
| # %convert_element_type_62 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {}) | |
| # %mul_17 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_62, %arg19_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| r1 = rindex | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp1 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp15 = tl.load(in_ptr2 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp3 = tmp2.to(tl.float32) | |
| tmp4 = tmp3 * tmp3 | |
| tmp5 = tl.broadcast_to(tmp4, [RBLOCK]) | |
| tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0)) | |
| tmp8 = 512.0 | |
| tmp9 = (tmp7 / tmp8).to(tl.float32) | |
| tmp10 = 1e-05 | |
| tmp11 = tmp9 + tmp10 | |
| tmp12 = libdevice.rsqrt(tmp11) | |
| tmp13 = tmp3 * tmp12 | |
| tmp14 = tmp13.to(tl.float32) | |
| tmp16 = tmp14 * tmp15 | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp16, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/pn/cpnh5dfi4oms6ccoppdagviyr73sfnfnxsb7fxqww3qrixi7lyh4.py | |
| # Topologically Sorted Source Nodes: [out_1, h_3, float_15, pow_6, mean_5, add_10, rsqrt_5, mul_18, output_11, mul_19], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_10 => add_10 | |
| # float_15 => convert_element_type_81 | |
| # h_3 => add_9 | |
| # mean_5 => mean_5 | |
| # mul_18 => mul_20 | |
| # mul_19 => mul_21 | |
| # out_1 => add_7 | |
| # output_11 => convert_element_type_82 | |
| # pow_6 => pow_6 | |
| # rsqrt_5 => rsqrt_5 | |
| # Graph fragment: | |
| # %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {}) | |
| # %add_9 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_80), kwargs = {}) | |
| # %convert_element_type_81 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_9, torch.float32), kwargs = {}) | |
| # %pow_6 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_81, 2), kwargs = {}) | |
| # %mean_5 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_6, [-1], True), kwargs = {}) | |
| # %add_10 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_5, 1e-05), kwargs = {}) | |
| # %rsqrt_5 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_10,), kwargs = {}) | |
| # %mul_20 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_81, %rsqrt_5), kwargs = {}) | |
| # %convert_element_type_82 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_20, torch.bfloat16), kwargs = {}) | |
| # %mul_21 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_82, %arg24_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| r1 = rindex | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp1 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp3 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp17 = tl.load(in_ptr3 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp4 = tmp2 + tmp3 | |
| tmp5 = tmp4.to(tl.float32) | |
| tmp6 = tmp5 * tmp5 | |
| tmp7 = tl.broadcast_to(tmp6, [RBLOCK]) | |
| tmp9 = triton_helpers.promote_to_tensor(tl.sum(tmp7, 0)) | |
| tmp10 = 512.0 | |
| tmp11 = (tmp9 / tmp10).to(tl.float32) | |
| tmp12 = 1e-05 | |
| tmp13 = tmp11 + tmp12 | |
| tmp14 = libdevice.rsqrt(tmp13) | |
| tmp15 = tmp5 * tmp14 | |
| tmp16 = tmp15.to(tl.float32) | |
| tmp18 = tmp16 * tmp17 | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp18, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/dw/cdwfdyb3igstcir6uoaelj2eyiyqrmimqovbk6l57r4zvxcvjy7k.py | |
| # Topologically Sorted Source Nodes: [out_1, h_3, out_2, float_16, pow_7, mean_6, add_12, rsqrt_6, mul_21, output_12, mul_22], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_12 => add_12 | |
| # float_16 => convert_element_type_91 | |
| # h_3 => add_9 | |
| # mean_6 => mean_6 | |
| # mul_21 => mul_24 | |
| # mul_22 => mul_25 | |
| # out_1 => add_7 | |
| # out_2 => add_11 | |
| # output_12 => convert_element_type_92 | |
| # pow_7 => pow_7 | |
| # rsqrt_6 => rsqrt_6 | |
| # Graph fragment: | |
| # %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {}) | |
| # %add_9 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_80), kwargs = {}) | |
| # %add_11 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_9, %view_86), kwargs = {}) | |
| # %convert_element_type_91 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_11, torch.float32), kwargs = {}) | |
| # %pow_7 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_91, 2), kwargs = {}) | |
| # %mean_6 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_7, [-1], True), kwargs = {}) | |
| # %add_12 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_6, 1e-05), kwargs = {}) | |
| # %rsqrt_6 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_12,), kwargs = {}) | |
| # %mul_24 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_91, %rsqrt_6), kwargs = {}) | |
| # %convert_element_type_92 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_24, torch.bfloat16), kwargs = {}) | |
| # %mul_25 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_92, %arg28_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 5, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| r1 = rindex | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp1 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp3 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp5 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp19 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp4 = tmp2 + tmp3 | |
| tmp6 = tmp4 + tmp5 | |
| tmp7 = tmp6.to(tl.float32) | |
| tmp8 = tmp7 * tmp7 | |
| tmp9 = tl.broadcast_to(tmp8, [RBLOCK]) | |
| tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0)) | |
| tmp12 = 512.0 | |
| tmp13 = (tmp11 / tmp12).to(tl.float32) | |
| tmp14 = 1e-05 | |
| tmp15 = tmp13 + tmp14 | |
| tmp16 = libdevice.rsqrt(tmp15) | |
| tmp17 = tmp7 * tmp16 | |
| tmp18 = tmp17.to(tl.float32) | |
| tmp20 = tmp18 * tmp19 | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp20, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/z3/cz3e34micft5wnsnm2gnpqmih6vve6ujd2ccrnpkysa5obrpxhpi.py | |
| # Topologically Sorted Source Nodes: [out_1, h_3, out_2, h_4, float_20, pow_8, mean_7, add_14, rsqrt_7, mul_25, output_15, mul_26], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_14 => add_14 | |
| # float_20 => convert_element_type_111 | |
| # h_3 => add_9 | |
| # h_4 => add_13 | |
| # mean_7 => mean_7 | |
| # mul_25 => mul_28 | |
| # mul_26 => mul_29 | |
| # out_1 => add_7 | |
| # out_2 => add_11 | |
| # output_15 => convert_element_type_112 | |
| # pow_8 => pow_8 | |
| # rsqrt_7 => rsqrt_7 | |
| # Graph fragment: | |
| # %add_7 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_57), kwargs = {}) | |
| # %add_9 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_80), kwargs = {}) | |
| # %add_11 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_9, %view_86), kwargs = {}) | |
| # %add_13 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_11, %view_109), kwargs = {}) | |
| # %convert_element_type_111 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_13, torch.float32), kwargs = {}) | |
| # %pow_8 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_111, 2), kwargs = {}) | |
| # %mean_7 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_8, [-1], True), kwargs = {}) | |
| # %add_14 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_7, 1e-05), kwargs = {}) | |
| # %rsqrt_7 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_14,), kwargs = {}) | |
| # %mul_28 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_111, %rsqrt_7), kwargs = {}) | |
| # %convert_element_type_112 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_28, torch.bfloat16), kwargs = {}) | |
| # %mul_29 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_112, %arg33_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 6, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| r1 = rindex | |
| x0 = xindex | |
| tmp0 = tl.load(in_out_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp1 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp3 = tl.load(in_ptr1 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp5 = tl.load(in_ptr2 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp7 = tl.load(in_ptr3 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp21 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp4 = tmp2 + tmp3 | |
| tmp6 = tmp4 + tmp5 | |
| tmp8 = tmp6 + tmp7 | |
| tmp9 = tmp8.to(tl.float32) | |
| tmp10 = tmp9 * tmp9 | |
| tmp11 = tl.broadcast_to(tmp10, [RBLOCK]) | |
| tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0)) | |
| tmp14 = 512.0 | |
| tmp15 = (tmp13 / tmp14).to(tl.float32) | |
| tmp16 = 1e-05 | |
| tmp17 = tmp15 + tmp16 | |
| tmp18 = libdevice.rsqrt(tmp17) | |
| tmp19 = tmp9 * tmp18 | |
| tmp20 = tmp19.to(tl.float32) | |
| tmp22 = tmp20 * tmp21 | |
| tl.store(in_out_ptr0 + (r1 + 512*x0), tmp8, None) | |
| tl.store(out_ptr1 + (r1 + 512*x0), tmp22, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/lv/clv7iqdidc6yovjvhgzra2tlgehvzuyjf6c4drcdk3cgzbb7tbn4.py | |
| # Topologically Sorted Source Nodes: [out_7, float_41, pow_17, mean_16, add_32, rsqrt_16, mul_56, output_32, h_9], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| # Source node to ATen node mapping: | |
| # add_32 => add_32 | |
| # float_41 => convert_element_type_241 | |
| # h_9 => mul_65 | |
| # mean_16 => mean_16 | |
| # mul_56 => mul_64 | |
| # out_7 => add_31 | |
| # output_32 => convert_element_type_242 | |
| # pow_17 => pow_17 | |
| # rsqrt_16 => rsqrt_16 | |
| # Graph fragment: | |
| # %add_31 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_29, %view_231), kwargs = {}) | |
| # %convert_element_type_241 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_31, torch.float32), kwargs = {}) | |
| # %pow_17 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_241, 2), kwargs = {}) | |
| # %mean_16 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_17, [-1], True), kwargs = {}) | |
| # %add_32 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_16, 1e-05), kwargs = {}) | |
| # %rsqrt_16 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_32,), kwargs = {}) | |
| # %mul_64 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_241, %rsqrt_16), kwargs = {}) | |
| # %convert_element_type_242 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_64, torch.bfloat16), kwargs = {}) | |
| # %mul_65 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_242, %arg73_1), kwargs = {}) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14 = async_compile.triton('triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.persistent_reduction( | |
| size_hints={'x': 1024, 'r': 512}, | |
| reduction_hint=ReductionHint.INNER, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} | |
| ) | |
| @triton.jit | |
| def triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14(in_out_ptr0, in_ptr0, in_ptr1, xnumel, rnumel): | |
| xnumel = 1024 | |
| XBLOCK: tl.constexpr = 1 | |
| rnumel = 512 | |
| RBLOCK: tl.constexpr = 512 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = tl.full([1], xoffset, tl.int32) | |
| xmask = tl.full([RBLOCK], True, tl.int1) | |
| rindex = tl.arange(0, RBLOCK)[:] | |
| roffset = 0 | |
| rmask = tl.full([RBLOCK], True, tl.int1) | |
| r1 = rindex | |
| x0 = xindex | |
| tmp0 = tl.load(in_out_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp1 = tl.load(in_ptr0 + (r1 + 512*x0), None).to(tl.float32) | |
| tmp15 = tl.load(in_ptr1 + (r1), None, eviction_policy='evict_last').to(tl.float32) | |
| tmp2 = tmp0 + tmp1 | |
| tmp3 = tmp2.to(tl.float32) | |
| tmp4 = tmp3 * tmp3 | |
| tmp5 = tl.broadcast_to(tmp4, [RBLOCK]) | |
| tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0)) | |
| tmp8 = 512.0 | |
| tmp9 = (tmp7 / tmp8).to(tl.float32) | |
| tmp10 = 1e-05 | |
| tmp11 = tmp9 + tmp10 | |
| tmp12 = libdevice.rsqrt(tmp11) | |
| tmp13 = tmp3 * tmp12 | |
| tmp14 = tmp13.to(tl.float32) | |
| tmp16 = tmp14 * tmp15 | |
| tl.store(in_out_ptr0 + (r1 + 512*x0), tmp16, None) | |
| ''', device_str='cuda') | |
| # kernel path: /tmp/torchinductor_t/uy/cuyzkjsgklqarsm2yt3x6ruv2di7jsrafhofrtw3gdor4jvuttkh.py | |
| # Topologically Sorted Source Nodes: [float_42], Original ATen: [aten._to_copy] | |
| # Source node to ATen node mapping: | |
| # float_42 => convert_element_type_245 | |
| # Graph fragment: | |
| # %convert_element_type_245 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_56, torch.float32), kwargs = {}) | |
| triton_poi_fused__to_copy_15 = async_compile.triton('triton_poi_fused__to_copy_15', ''' | |
| import triton | |
| import triton.language as tl | |
| from triton.compiler.compiler import AttrsDescriptor | |
| from torch._inductor.runtime import triton_helpers, triton_heuristics | |
| from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
| from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
| triton_helpers.set_driver_to_gpu() | |
| @triton_heuristics.pointwise( | |
| size_hints={'x': 1048576}, | |
| filename=__file__, | |
| triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}, | |
| inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_15', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '64BAA4D72EC2265776862FB8CE89A91800211173499DCF3E6DB3986179E2EACB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, | |
| min_elem_per_thread=0 | |
| ) | |
| @triton.jit | |
| def triton_poi_fused__to_copy_15(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
| xnumel = 1024000 | |
| xoffset = tl.program_id(0) * XBLOCK | |
| xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
| xmask = tl.full([XBLOCK], True, tl.int1) | |
| x0 = xindex | |
| tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
| tmp1 = tmp0.to(tl.float32) | |
| tl.store(out_ptr0 + (x0), tmp1, None) | |
| ''', device_str='cuda') | |
| async_compile.wait(globals()) | |
| del async_compile | |
| def call(args): | |
| arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1 = args | |
| args.clear() | |
| assert_size_stride(arg76_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg77_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg78_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg79_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg80_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg81_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg82_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg83_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg84_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg85_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg86_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg87_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg88_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg89_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg90_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg91_1, (32, 1024, 8, 64), (524288, 512, 64, 1)) | |
| assert_size_stride(arg92_1, (32, 32), (32, 1)) | |
| with torch.cuda._DeviceGuard(0): | |
| torch.cuda.set_device(0) | |
| buf1 = empty_strided_cuda((32, 32, 512), (16384, 512, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [h, float_1, pow_1, mean, add, rsqrt, mul, output, mul_1], Original ATen: [aten.embedding, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0.run(arg92_1, _frozen_param0, _frozen_param1, buf1, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf2 = empty_strided_cuda((1024, 1536), (1536, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf1, (1024, 512), (512, 1), 0), _frozen_param135, out=buf2) | |
| # Topologically Sorted Source Nodes: [setitem_1], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf2, arg77_1, 524288, grid=grid(524288), stream=stream0) | |
| buf3 = empty_strided_cuda((32, 32, 8, 32, 2), (16384, 512, 64, 2, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [xq_], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf2, buf3, 524288, grid=grid(524288), stream=stream0) | |
| buf10 = empty_strided_cuda((32, 32, 8, 32, 2), (16384, 512, 64, 2, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [xk_], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf2, buf10, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq_], Original ATen: [aten.view_as_complex] | |
| buf4 = torch.ops.aten.view_as_complex.default(buf3) | |
| buf5 = buf4 | |
| # Topologically Sorted Source Nodes: [mul_2], Original ATen: [aten.mul] | |
| buf6 = torch.ops.aten.mul.Tensor(buf5, _frozen_param79) | |
| del buf4 | |
| del buf5 | |
| buf7 = buf6 | |
| del buf6 | |
| # Topologically Sorted Source Nodes: [view_as_real], Original ATen: [aten.view_as_real] | |
| buf8 = torch.ops.aten.view_as_real.default(buf7) | |
| buf9 = buf8 | |
| buf19 = reinterpret_tensor(buf1, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf1 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf9, buf19, 524288, grid=grid(524288), stream=stream0) | |
| del buf7 | |
| del buf8 | |
| del buf9 | |
| # Topologically Sorted Source Nodes: [xk_], Original ATen: [aten.view_as_complex] | |
| buf11 = torch.ops.aten.view_as_complex.default(buf10) | |
| buf12 = buf11 | |
| # Topologically Sorted Source Nodes: [mul_3], Original ATen: [aten.mul] | |
| buf13 = torch.ops.aten.mul.Tensor(buf12, _frozen_param79) | |
| del buf11 | |
| del buf12 | |
| buf14 = buf13 | |
| del buf13 | |
| # Topologically Sorted Source Nodes: [view_as_real_1], Original ATen: [aten.view_as_real] | |
| buf15 = torch.ops.aten.view_as_real.default(buf14) | |
| buf16 = buf15 | |
| # Topologically Sorted Source Nodes: [xk_2, setitem], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf16, arg76_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf14 | |
| del buf15 | |
| del buf16 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf20 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf19, reinterpret_tensor(arg76_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg77_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf21 = buf20[0] | |
| del buf20 | |
| buf26 = reinterpret_tensor(buf19, (1024, 512), (512, 1), 0); del buf19 # reuse | |
| # Topologically Sorted Source Nodes: [linear_3], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf21, (1024, 512), (512, 1), 0), _frozen_param80, out=buf26) | |
| buf28 = reinterpret_tensor(buf21, (32, 32, 512), (16384, 512, 1), 0); del buf21 # reuse | |
| # Topologically Sorted Source Nodes: [h, h_1, float_5, pow_2, mean_1, add_2, rsqrt_1, mul_4, output_3, mul_5], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_6.run(arg92_1, _frozen_param0, buf26, _frozen_param6, buf28, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf29 = empty_strided_cuda((1024, 3072), (3072, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf28, (1024, 512), (512, 1), 0), _frozen_param136, out=buf29) | |
| buf30 = reinterpret_tensor(buf2, (32, 32, 1536), (49152, 1536, 1), 0); del buf2 # reuse | |
| # Topologically Sorted Source Nodes: [silu, mul_6], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf29, buf30, 1572864, grid=grid(1572864), stream=stream0) | |
| buf31 = reinterpret_tensor(buf28, (1024, 512), (512, 1), 0); del buf28 # reuse | |
| # Topologically Sorted Source Nodes: [linear_6], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf30, (1024, 1536), (1536, 1), 0), _frozen_param83, out=buf31) | |
| buf33 = empty_strided_cuda((32, 32, 512), (16384, 512, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [h, h_1, out, float_6, pow_3, mean_2, add_4, rsqrt_2, mul_7, output_4, mul_8], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_8.run(arg92_1, _frozen_param0, buf26, buf31, _frozen_param10, buf33, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf34 = reinterpret_tensor(buf30, (1024, 1536), (1536, 1), 0); del buf30 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf33, (1024, 512), (512, 1), 0), _frozen_param137, out=buf34) | |
| # Topologically Sorted Source Nodes: [setitem_3], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf34, arg79_1, 524288, grid=grid(524288), stream=stream0) | |
| buf35 = buf10; del buf10 # reuse | |
| # Topologically Sorted Source Nodes: [xq__1], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf34, buf35, 524288, grid=grid(524288), stream=stream0) | |
| buf42 = buf3; del buf3 # reuse | |
| # Topologically Sorted Source Nodes: [xk__1], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf34, buf42, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__1], Original ATen: [aten.view_as_complex] | |
| buf36 = torch.ops.aten.view_as_complex.default(buf35) | |
| buf37 = buf36 | |
| # Topologically Sorted Source Nodes: [mul_9], Original ATen: [aten.mul] | |
| buf38 = torch.ops.aten.mul.Tensor(buf37, _frozen_param79) | |
| del buf36 | |
| del buf37 | |
| buf39 = buf38 | |
| del buf38 | |
| # Topologically Sorted Source Nodes: [view_as_real_2], Original ATen: [aten.view_as_real] | |
| buf40 = torch.ops.aten.view_as_real.default(buf39) | |
| buf41 = buf40 | |
| buf51 = reinterpret_tensor(buf33, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf33 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf41, buf51, 524288, grid=grid(524288), stream=stream0) | |
| del buf39 | |
| del buf40 | |
| del buf41 | |
| # Topologically Sorted Source Nodes: [xk__1], Original ATen: [aten.view_as_complex] | |
| buf43 = torch.ops.aten.view_as_complex.default(buf42) | |
| buf44 = buf43 | |
| # Topologically Sorted Source Nodes: [mul_10], Original ATen: [aten.mul] | |
| buf45 = torch.ops.aten.mul.Tensor(buf44, _frozen_param79) | |
| del buf43 | |
| del buf44 | |
| buf46 = buf45 | |
| del buf45 | |
| # Topologically Sorted Source Nodes: [view_as_real_3], Original ATen: [aten.view_as_real] | |
| buf47 = torch.ops.aten.view_as_real.default(buf46) | |
| buf48 = buf47 | |
| # Topologically Sorted Source Nodes: [xk_5, setitem_2], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf48, arg78_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf46 | |
| del buf47 | |
| del buf48 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf52 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf51, reinterpret_tensor(arg78_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg79_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf53 = buf52[0] | |
| del buf52 | |
| buf58 = reinterpret_tensor(buf51, (1024, 512), (512, 1), 0); del buf51 # reuse | |
| # Topologically Sorted Source Nodes: [linear_10], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf53, (1024, 512), (512, 1), 0), _frozen_param87, out=buf58) | |
| buf59 = reinterpret_tensor(buf26, (32, 32, 512), (16384, 512, 1), 0); del buf26 # reuse | |
| buf61 = reinterpret_tensor(buf53, (32, 32, 512), (16384, 512, 1), 0); del buf53 # reuse | |
| # Topologically Sorted Source Nodes: [h, h_1, out, h_2, float_10, pow_4, mean_3, add_6, rsqrt_3, mul_11, output_7, mul_12], Original ATen: [aten.embedding, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_9.run(buf59, arg92_1, _frozen_param0, buf31, buf58, _frozen_param15, buf61, 1024, 512, grid=grid(1024), stream=stream0) | |
| del arg92_1 | |
| buf62 = buf29; del buf29 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf61, (1024, 512), (512, 1), 0), _frozen_param138, out=buf62) | |
| buf63 = reinterpret_tensor(buf34, (32, 32, 1536), (49152, 1536, 1), 0); del buf34 # reuse | |
| # Topologically Sorted Source Nodes: [silu_1, mul_13], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf62, buf63, 1572864, grid=grid(1572864), stream=stream0) | |
| buf64 = reinterpret_tensor(buf61, (1024, 512), (512, 1), 0); del buf61 # reuse | |
| # Topologically Sorted Source Nodes: [linear_13], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf63, (1024, 1536), (1536, 1), 0), _frozen_param90, out=buf64) | |
| buf66 = reinterpret_tensor(buf58, (32, 32, 512), (16384, 512, 1), 0); del buf58 # reuse | |
| # Topologically Sorted Source Nodes: [out_1, float_11, pow_5, mean_4, add_8, rsqrt_4, mul_14, output_8, mul_15], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf59, buf64, _frozen_param19, buf66, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf67 = reinterpret_tensor(buf63, (1024, 1536), (1536, 1), 0); del buf63 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf66, (1024, 512), (512, 1), 0), _frozen_param139, out=buf67) | |
| # Topologically Sorted Source Nodes: [setitem_5], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf67, arg81_1, 524288, grid=grid(524288), stream=stream0) | |
| buf68 = buf42; del buf42 # reuse | |
| # Topologically Sorted Source Nodes: [xq__2], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf67, buf68, 524288, grid=grid(524288), stream=stream0) | |
| buf75 = buf35; del buf35 # reuse | |
| # Topologically Sorted Source Nodes: [xk__2], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf67, buf75, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__2], Original ATen: [aten.view_as_complex] | |
| buf69 = torch.ops.aten.view_as_complex.default(buf68) | |
| buf70 = buf69 | |
| # Topologically Sorted Source Nodes: [mul_16], Original ATen: [aten.mul] | |
| buf71 = torch.ops.aten.mul.Tensor(buf70, _frozen_param79) | |
| del buf69 | |
| del buf70 | |
| buf72 = buf71 | |
| del buf71 | |
| # Topologically Sorted Source Nodes: [view_as_real_4], Original ATen: [aten.view_as_real] | |
| buf73 = torch.ops.aten.view_as_real.default(buf72) | |
| buf74 = buf73 | |
| buf84 = reinterpret_tensor(buf66, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf66 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf74, buf84, 524288, grid=grid(524288), stream=stream0) | |
| del buf72 | |
| del buf73 | |
| del buf74 | |
| # Topologically Sorted Source Nodes: [xk__2], Original ATen: [aten.view_as_complex] | |
| buf76 = torch.ops.aten.view_as_complex.default(buf75) | |
| buf77 = buf76 | |
| # Topologically Sorted Source Nodes: [mul_17], Original ATen: [aten.mul] | |
| buf78 = torch.ops.aten.mul.Tensor(buf77, _frozen_param79) | |
| del buf76 | |
| del buf77 | |
| buf79 = buf78 | |
| del buf78 | |
| # Topologically Sorted Source Nodes: [view_as_real_5], Original ATen: [aten.view_as_real] | |
| buf80 = torch.ops.aten.view_as_real.default(buf79) | |
| buf81 = buf80 | |
| # Topologically Sorted Source Nodes: [xk_8, setitem_4], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf81, arg80_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf79 | |
| del buf80 | |
| del buf81 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf85 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf84, reinterpret_tensor(arg80_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg81_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf86 = buf85[0] | |
| del buf85 | |
| buf91 = reinterpret_tensor(buf84, (1024, 512), (512, 1), 0); del buf84 # reuse | |
| # Topologically Sorted Source Nodes: [linear_17], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf86, (1024, 512), (512, 1), 0), _frozen_param94, out=buf91) | |
| buf93 = reinterpret_tensor(buf86, (32, 32, 512), (16384, 512, 1), 0); del buf86 # reuse | |
| # Topologically Sorted Source Nodes: [out_1, h_3, float_15, pow_6, mean_5, add_10, rsqrt_5, mul_18, output_11, mul_19], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf59, buf64, buf91, _frozen_param24, buf93, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf94 = buf62; del buf62 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf93, (1024, 512), (512, 1), 0), _frozen_param140, out=buf94) | |
| buf95 = reinterpret_tensor(buf67, (32, 32, 1536), (49152, 1536, 1), 0); del buf67 # reuse | |
| # Topologically Sorted Source Nodes: [silu_2, mul_20], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf94, buf95, 1572864, grid=grid(1572864), stream=stream0) | |
| buf96 = reinterpret_tensor(buf93, (1024, 512), (512, 1), 0); del buf93 # reuse | |
| # Topologically Sorted Source Nodes: [linear_20], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf95, (1024, 1536), (1536, 1), 0), _frozen_param97, out=buf96) | |
| buf98 = reinterpret_tensor(buf31, (32, 32, 512), (16384, 512, 1), 0); del buf31 # reuse | |
| # Topologically Sorted Source Nodes: [out_1, h_3, out_2, float_16, pow_7, mean_6, add_12, rsqrt_6, mul_21, output_12, mul_22], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf59, buf64, buf91, buf96, _frozen_param28, buf98, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf99 = reinterpret_tensor(buf95, (1024, 1536), (1536, 1), 0); del buf95 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf98, (1024, 512), (512, 1), 0), _frozen_param141, out=buf99) | |
| # Topologically Sorted Source Nodes: [setitem_7], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf99, arg83_1, 524288, grid=grid(524288), stream=stream0) | |
| buf100 = buf75; del buf75 # reuse | |
| # Topologically Sorted Source Nodes: [xq__3], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf99, buf100, 524288, grid=grid(524288), stream=stream0) | |
| buf107 = buf68; del buf68 # reuse | |
| # Topologically Sorted Source Nodes: [xk__3], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf99, buf107, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__3], Original ATen: [aten.view_as_complex] | |
| buf101 = torch.ops.aten.view_as_complex.default(buf100) | |
| buf102 = buf101 | |
| # Topologically Sorted Source Nodes: [mul_23], Original ATen: [aten.mul] | |
| buf103 = torch.ops.aten.mul.Tensor(buf102, _frozen_param79) | |
| del buf101 | |
| del buf102 | |
| buf104 = buf103 | |
| del buf103 | |
| # Topologically Sorted Source Nodes: [view_as_real_6], Original ATen: [aten.view_as_real] | |
| buf105 = torch.ops.aten.view_as_real.default(buf104) | |
| buf106 = buf105 | |
| buf116 = reinterpret_tensor(buf98, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf98 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf106, buf116, 524288, grid=grid(524288), stream=stream0) | |
| del buf104 | |
| del buf105 | |
| del buf106 | |
| # Topologically Sorted Source Nodes: [xk__3], Original ATen: [aten.view_as_complex] | |
| buf108 = torch.ops.aten.view_as_complex.default(buf107) | |
| buf109 = buf108 | |
| # Topologically Sorted Source Nodes: [mul_24], Original ATen: [aten.mul] | |
| buf110 = torch.ops.aten.mul.Tensor(buf109, _frozen_param79) | |
| del buf108 | |
| del buf109 | |
| buf111 = buf110 | |
| del buf110 | |
| # Topologically Sorted Source Nodes: [view_as_real_7], Original ATen: [aten.view_as_real] | |
| buf112 = torch.ops.aten.view_as_real.default(buf111) | |
| buf113 = buf112 | |
| # Topologically Sorted Source Nodes: [xk_11, setitem_6], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf113, arg82_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf111 | |
| del buf112 | |
| del buf113 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf117 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf116, reinterpret_tensor(arg82_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg83_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf118 = buf117[0] | |
| del buf117 | |
| buf123 = reinterpret_tensor(buf116, (1024, 512), (512, 1), 0); del buf116 # reuse | |
| # Topologically Sorted Source Nodes: [linear_24], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf118, (1024, 512), (512, 1), 0), _frozen_param101, out=buf123) | |
| buf124 = buf59; del buf59 # reuse | |
| buf126 = reinterpret_tensor(buf118, (32, 32, 512), (16384, 512, 1), 0); del buf118 # reuse | |
| # Topologically Sorted Source Nodes: [out_1, h_3, out_2, h_4, float_20, pow_8, mean_7, add_14, rsqrt_7, mul_25, output_15, mul_26], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf124, buf64, buf91, buf96, buf123, _frozen_param33, buf126, 1024, 512, grid=grid(1024), stream=stream0) | |
| del buf123 | |
| del buf64 | |
| buf127 = buf94; del buf94 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf126, (1024, 512), (512, 1), 0), _frozen_param142, out=buf127) | |
| buf128 = reinterpret_tensor(buf99, (32, 32, 1536), (49152, 1536, 1), 0); del buf99 # reuse | |
| # Topologically Sorted Source Nodes: [silu_3, mul_27], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf127, buf128, 1572864, grid=grid(1572864), stream=stream0) | |
| buf129 = reinterpret_tensor(buf126, (1024, 512), (512, 1), 0); del buf126 # reuse | |
| # Topologically Sorted Source Nodes: [linear_27], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf128, (1024, 1536), (1536, 1), 0), _frozen_param104, out=buf129) | |
| buf131 = reinterpret_tensor(buf96, (32, 32, 512), (16384, 512, 1), 0); del buf96 # reuse | |
| # Topologically Sorted Source Nodes: [out_3, float_21, pow_9, mean_8, add_16, rsqrt_8, mul_28, output_16, mul_29], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf124, buf129, _frozen_param37, buf131, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf132 = reinterpret_tensor(buf128, (1024, 1536), (1536, 1), 0); del buf128 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf131, (1024, 512), (512, 1), 0), _frozen_param143, out=buf132) | |
| # Topologically Sorted Source Nodes: [setitem_9], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf132, arg85_1, 524288, grid=grid(524288), stream=stream0) | |
| buf133 = buf107; del buf107 # reuse | |
| # Topologically Sorted Source Nodes: [xq__4], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf132, buf133, 524288, grid=grid(524288), stream=stream0) | |
| buf140 = buf100; del buf100 # reuse | |
| # Topologically Sorted Source Nodes: [xk__4], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf132, buf140, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__4], Original ATen: [aten.view_as_complex] | |
| buf134 = torch.ops.aten.view_as_complex.default(buf133) | |
| buf135 = buf134 | |
| # Topologically Sorted Source Nodes: [mul_30], Original ATen: [aten.mul] | |
| buf136 = torch.ops.aten.mul.Tensor(buf135, _frozen_param79) | |
| del buf134 | |
| del buf135 | |
| buf137 = buf136 | |
| del buf136 | |
| # Topologically Sorted Source Nodes: [view_as_real_8], Original ATen: [aten.view_as_real] | |
| buf138 = torch.ops.aten.view_as_real.default(buf137) | |
| buf139 = buf138 | |
| buf149 = reinterpret_tensor(buf131, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf131 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf139, buf149, 524288, grid=grid(524288), stream=stream0) | |
| del buf137 | |
| del buf138 | |
| del buf139 | |
| # Topologically Sorted Source Nodes: [xk__4], Original ATen: [aten.view_as_complex] | |
| buf141 = torch.ops.aten.view_as_complex.default(buf140) | |
| buf142 = buf141 | |
| # Topologically Sorted Source Nodes: [mul_31], Original ATen: [aten.mul] | |
| buf143 = torch.ops.aten.mul.Tensor(buf142, _frozen_param79) | |
| del buf141 | |
| del buf142 | |
| buf144 = buf143 | |
| del buf143 | |
| # Topologically Sorted Source Nodes: [view_as_real_9], Original ATen: [aten.view_as_real] | |
| buf145 = torch.ops.aten.view_as_real.default(buf144) | |
| buf146 = buf145 | |
| # Topologically Sorted Source Nodes: [xk_14, setitem_8], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf146, arg84_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf144 | |
| del buf145 | |
| del buf146 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf150 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf149, reinterpret_tensor(arg84_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg85_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf151 = buf150[0] | |
| del buf150 | |
| buf156 = reinterpret_tensor(buf149, (1024, 512), (512, 1), 0); del buf149 # reuse | |
| # Topologically Sorted Source Nodes: [linear_31], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf151, (1024, 512), (512, 1), 0), _frozen_param108, out=buf156) | |
| buf158 = reinterpret_tensor(buf151, (32, 32, 512), (16384, 512, 1), 0); del buf151 # reuse | |
| # Topologically Sorted Source Nodes: [out_3, h_5, float_25, pow_10, mean_9, add_18, rsqrt_9, mul_32, output_19, mul_33], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf124, buf129, buf156, _frozen_param42, buf158, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf159 = buf127; del buf127 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf158, (1024, 512), (512, 1), 0), _frozen_param144, out=buf159) | |
| buf160 = reinterpret_tensor(buf132, (32, 32, 1536), (49152, 1536, 1), 0); del buf132 # reuse | |
| # Topologically Sorted Source Nodes: [silu_4, mul_34], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf159, buf160, 1572864, grid=grid(1572864), stream=stream0) | |
| buf161 = reinterpret_tensor(buf158, (1024, 512), (512, 1), 0); del buf158 # reuse | |
| # Topologically Sorted Source Nodes: [linear_34], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf160, (1024, 1536), (1536, 1), 0), _frozen_param111, out=buf161) | |
| buf163 = reinterpret_tensor(buf91, (32, 32, 512), (16384, 512, 1), 0); del buf91 # reuse | |
| # Topologically Sorted Source Nodes: [out_3, h_5, out_4, float_26, pow_11, mean_10, add_20, rsqrt_10, mul_35, output_20, mul_36], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf124, buf129, buf156, buf161, _frozen_param46, buf163, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf164 = reinterpret_tensor(buf160, (1024, 1536), (1536, 1), 0); del buf160 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf163, (1024, 512), (512, 1), 0), _frozen_param145, out=buf164) | |
| # Topologically Sorted Source Nodes: [setitem_11], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf164, arg87_1, 524288, grid=grid(524288), stream=stream0) | |
| buf165 = buf140; del buf140 # reuse | |
| # Topologically Sorted Source Nodes: [xq__5], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf164, buf165, 524288, grid=grid(524288), stream=stream0) | |
| buf172 = buf133; del buf133 # reuse | |
| # Topologically Sorted Source Nodes: [xk__5], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf164, buf172, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__5], Original ATen: [aten.view_as_complex] | |
| buf166 = torch.ops.aten.view_as_complex.default(buf165) | |
| buf167 = buf166 | |
| # Topologically Sorted Source Nodes: [mul_37], Original ATen: [aten.mul] | |
| buf168 = torch.ops.aten.mul.Tensor(buf167, _frozen_param79) | |
| del buf166 | |
| del buf167 | |
| buf169 = buf168 | |
| del buf168 | |
| # Topologically Sorted Source Nodes: [view_as_real_10], Original ATen: [aten.view_as_real] | |
| buf170 = torch.ops.aten.view_as_real.default(buf169) | |
| buf171 = buf170 | |
| buf181 = reinterpret_tensor(buf163, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf163 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf171, buf181, 524288, grid=grid(524288), stream=stream0) | |
| del buf169 | |
| del buf170 | |
| del buf171 | |
| # Topologically Sorted Source Nodes: [xk__5], Original ATen: [aten.view_as_complex] | |
| buf173 = torch.ops.aten.view_as_complex.default(buf172) | |
| buf174 = buf173 | |
| # Topologically Sorted Source Nodes: [mul_38], Original ATen: [aten.mul] | |
| buf175 = torch.ops.aten.mul.Tensor(buf174, _frozen_param79) | |
| del buf173 | |
| del buf174 | |
| buf176 = buf175 | |
| del buf175 | |
| # Topologically Sorted Source Nodes: [view_as_real_11], Original ATen: [aten.view_as_real] | |
| buf177 = torch.ops.aten.view_as_real.default(buf176) | |
| buf178 = buf177 | |
| # Topologically Sorted Source Nodes: [xk_17, setitem_10], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf178, arg86_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf176 | |
| del buf177 | |
| del buf178 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf182 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf181, reinterpret_tensor(arg86_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg87_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf183 = buf182[0] | |
| del buf182 | |
| buf188 = reinterpret_tensor(buf181, (1024, 512), (512, 1), 0); del buf181 # reuse | |
| # Topologically Sorted Source Nodes: [linear_38], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf183, (1024, 512), (512, 1), 0), _frozen_param115, out=buf188) | |
| buf189 = buf124; del buf124 # reuse | |
| buf191 = reinterpret_tensor(buf183, (32, 32, 512), (16384, 512, 1), 0); del buf183 # reuse | |
| # Topologically Sorted Source Nodes: [out_3, h_5, out_4, h_6, float_30, pow_12, mean_11, add_22, rsqrt_11, mul_39, output_23, mul_40], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf189, buf129, buf156, buf161, buf188, _frozen_param51, buf191, 1024, 512, grid=grid(1024), stream=stream0) | |
| del buf129 | |
| del buf156 | |
| buf192 = buf159; del buf159 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf191, (1024, 512), (512, 1), 0), _frozen_param146, out=buf192) | |
| buf193 = reinterpret_tensor(buf164, (32, 32, 1536), (49152, 1536, 1), 0); del buf164 # reuse | |
| # Topologically Sorted Source Nodes: [silu_5, mul_41], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf192, buf193, 1572864, grid=grid(1572864), stream=stream0) | |
| buf194 = reinterpret_tensor(buf191, (1024, 512), (512, 1), 0); del buf191 # reuse | |
| # Topologically Sorted Source Nodes: [linear_41], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf193, (1024, 1536), (1536, 1), 0), _frozen_param118, out=buf194) | |
| buf196 = reinterpret_tensor(buf188, (32, 32, 512), (16384, 512, 1), 0); del buf188 # reuse | |
| # Topologically Sorted Source Nodes: [out_5, float_31, pow_13, mean_12, add_24, rsqrt_12, mul_42, output_24, mul_43], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf189, buf194, _frozen_param55, buf196, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf197 = reinterpret_tensor(buf193, (1024, 1536), (1536, 1), 0); del buf193 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf196, (1024, 512), (512, 1), 0), _frozen_param147, out=buf197) | |
| # Topologically Sorted Source Nodes: [setitem_13], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf197, arg89_1, 524288, grid=grid(524288), stream=stream0) | |
| buf198 = buf172; del buf172 # reuse | |
| # Topologically Sorted Source Nodes: [xq__6], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf197, buf198, 524288, grid=grid(524288), stream=stream0) | |
| buf205 = buf165; del buf165 # reuse | |
| # Topologically Sorted Source Nodes: [xk__6], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf197, buf205, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__6], Original ATen: [aten.view_as_complex] | |
| buf199 = torch.ops.aten.view_as_complex.default(buf198) | |
| buf200 = buf199 | |
| # Topologically Sorted Source Nodes: [mul_44], Original ATen: [aten.mul] | |
| buf201 = torch.ops.aten.mul.Tensor(buf200, _frozen_param79) | |
| del buf199 | |
| del buf200 | |
| buf202 = buf201 | |
| del buf201 | |
| # Topologically Sorted Source Nodes: [view_as_real_12], Original ATen: [aten.view_as_real] | |
| buf203 = torch.ops.aten.view_as_real.default(buf202) | |
| buf204 = buf203 | |
| buf214 = reinterpret_tensor(buf196, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf196 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf204, buf214, 524288, grid=grid(524288), stream=stream0) | |
| del buf202 | |
| del buf203 | |
| del buf204 | |
| # Topologically Sorted Source Nodes: [xk__6], Original ATen: [aten.view_as_complex] | |
| buf206 = torch.ops.aten.view_as_complex.default(buf205) | |
| buf207 = buf206 | |
| # Topologically Sorted Source Nodes: [mul_45], Original ATen: [aten.mul] | |
| buf208 = torch.ops.aten.mul.Tensor(buf207, _frozen_param79) | |
| del buf206 | |
| del buf207 | |
| buf209 = buf208 | |
| del buf208 | |
| # Topologically Sorted Source Nodes: [view_as_real_13], Original ATen: [aten.view_as_real] | |
| buf210 = torch.ops.aten.view_as_real.default(buf209) | |
| buf211 = buf210 | |
| # Topologically Sorted Source Nodes: [xk_20, setitem_12], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf211, arg88_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf209 | |
| del buf210 | |
| del buf211 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf215 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf214, reinterpret_tensor(arg88_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg89_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf216 = buf215[0] | |
| del buf215 | |
| buf221 = reinterpret_tensor(buf214, (1024, 512), (512, 1), 0); del buf214 # reuse | |
| # Topologically Sorted Source Nodes: [linear_45], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf216, (1024, 512), (512, 1), 0), _frozen_param122, out=buf221) | |
| buf223 = reinterpret_tensor(buf216, (32, 32, 512), (16384, 512, 1), 0); del buf216 # reuse | |
| # Topologically Sorted Source Nodes: [out_5, h_7, float_35, pow_14, mean_13, add_26, rsqrt_13, mul_46, output_27, mul_47], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf189, buf194, buf221, _frozen_param60, buf223, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf224 = buf192; del buf192 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf223, (1024, 512), (512, 1), 0), _frozen_param148, out=buf224) | |
| buf225 = reinterpret_tensor(buf197, (32, 32, 1536), (49152, 1536, 1), 0); del buf197 # reuse | |
| # Topologically Sorted Source Nodes: [silu_6, mul_48], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf224, buf225, 1572864, grid=grid(1572864), stream=stream0) | |
| buf226 = reinterpret_tensor(buf223, (1024, 512), (512, 1), 0); del buf223 # reuse | |
| # Topologically Sorted Source Nodes: [linear_48], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf225, (1024, 1536), (1536, 1), 0), _frozen_param125, out=buf226) | |
| buf228 = reinterpret_tensor(buf161, (32, 32, 512), (16384, 512, 1), 0); del buf161 # reuse | |
| # Topologically Sorted Source Nodes: [out_5, h_7, out_6, float_36, pow_15, mean_14, add_28, rsqrt_14, mul_49, output_28, mul_50], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf189, buf194, buf221, buf226, _frozen_param64, buf228, 1024, 512, grid=grid(1024), stream=stream0) | |
| buf229 = reinterpret_tensor(buf225, (1024, 1536), (1536, 1), 0); del buf225 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf228, (1024, 512), (512, 1), 0), _frozen_param149, out=buf229) | |
| # Topologically Sorted Source Nodes: [setitem_15], Original ATen: [aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_copy_1.run(buf229, arg91_1, 524288, grid=grid(524288), stream=stream0) | |
| buf230 = buf205; del buf205 # reuse | |
| # Topologically Sorted Source Nodes: [xq__7], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_2.run(buf229, buf230, 524288, grid=grid(524288), stream=stream0) | |
| buf237 = buf198; del buf198 # reuse | |
| # Topologically Sorted Source Nodes: [xk__7], Original ATen: [aten.view_as_complex] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_view_as_complex_3.run(buf229, buf237, 524288, grid=grid(524288), stream=stream0) | |
| # Topologically Sorted Source Nodes: [xq__7], Original ATen: [aten.view_as_complex] | |
| buf231 = torch.ops.aten.view_as_complex.default(buf230) | |
| buf232 = buf231 | |
| # Topologically Sorted Source Nodes: [mul_51], Original ATen: [aten.mul] | |
| buf233 = torch.ops.aten.mul.Tensor(buf232, _frozen_param79) | |
| del buf230 | |
| del buf231 | |
| del buf232 | |
| buf234 = buf233 | |
| del buf233 | |
| # Topologically Sorted Source Nodes: [view_as_real_14], Original ATen: [aten.view_as_real] | |
| buf235 = torch.ops.aten.view_as_real.default(buf234) | |
| buf236 = buf235 | |
| buf246 = reinterpret_tensor(buf228, (32, 8, 32, 64), (16384, 64, 512, 1), 0); del buf228 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_4.run(buf236, buf246, 524288, grid=grid(524288), stream=stream0) | |
| del buf234 | |
| del buf235 | |
| del buf236 | |
| # Topologically Sorted Source Nodes: [xk__7], Original ATen: [aten.view_as_complex] | |
| buf238 = torch.ops.aten.view_as_complex.default(buf237) | |
| buf239 = buf238 | |
| # Topologically Sorted Source Nodes: [mul_52], Original ATen: [aten.mul] | |
| buf240 = torch.ops.aten.mul.Tensor(buf239, _frozen_param79) | |
| del buf237 | |
| del buf238 | |
| del buf239 | |
| buf241 = buf240 | |
| del buf240 | |
| # Topologically Sorted Source Nodes: [view_as_real_15], Original ATen: [aten.view_as_real] | |
| buf242 = torch.ops.aten.view_as_real.default(buf241) | |
| buf243 = buf242 | |
| # Topologically Sorted Source Nodes: [xk_23, setitem_14], Original ATen: [aten._to_copy, aten.copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_copy_5.run(buf243, arg90_1, 524288, grid=grid(524288), stream=stream0) | |
| del buf241 | |
| del buf242 | |
| del buf243 | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| buf247 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf246, reinterpret_tensor(arg90_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), reinterpret_tensor(arg91_1, (32, 8, 33, 64), (524288, 64, 512, 1), 0), scale=0.125) | |
| buf248 = buf247[0] | |
| del buf247 | |
| buf253 = reinterpret_tensor(buf246, (1024, 512), (512, 1), 0); del buf246 # reuse | |
| # Topologically Sorted Source Nodes: [linear_52], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf248, (1024, 512), (512, 1), 0), _frozen_param129, out=buf253) | |
| buf254 = buf189; del buf189 # reuse | |
| buf256 = reinterpret_tensor(buf248, (32, 32, 512), (16384, 512, 1), 0); del buf248 # reuse | |
| # Topologically Sorted Source Nodes: [out_5, h_7, out_6, h_8, float_40, pow_16, mean_15, add_30, rsqrt_15, mul_53, output_31, mul_54], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf254, buf194, buf221, buf226, buf253, _frozen_param69, buf256, 1024, 512, grid=grid(1024), stream=stream0) | |
| del buf194 | |
| del buf221 | |
| del buf226 | |
| del buf253 | |
| buf257 = buf224; del buf224 # reuse | |
| # Topologically Sorted Source Nodes: [], Original ATen: [] | |
| extern_kernels.mm(reinterpret_tensor(buf256, (1024, 512), (512, 1), 0), _frozen_param150, out=buf257) | |
| buf258 = reinterpret_tensor(buf229, (32, 32, 1536), (49152, 1536, 1), 0); del buf229 # reuse | |
| # Topologically Sorted Source Nodes: [silu_7, mul_55], Original ATen: [aten.silu, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused_mul_silu_7.run(buf257, buf258, 1572864, grid=grid(1572864), stream=stream0) | |
| del buf257 | |
| buf259 = reinterpret_tensor(buf256, (1024, 512), (512, 1), 0); del buf256 # reuse | |
| # Topologically Sorted Source Nodes: [linear_55], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf258, (1024, 1536), (1536, 1), 0), _frozen_param132, out=buf259) | |
| del buf258 | |
| buf261 = buf254; del buf254 # reuse | |
| # Topologically Sorted Source Nodes: [out_7, float_41, pow_17, mean_16, add_32, rsqrt_16, mul_56, output_32, h_9], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] | |
| stream0 = get_raw_stream(0) | |
| triton_per_fused__to_copy_add_mean_mul_pow_rsqrt_14.run(buf261, buf259, _frozen_param73, 1024, 512, grid=grid(1024), stream=stream0) | |
| del buf259 | |
| buf262 = empty_strided_cuda((32, 32000), (32000, 1), torch.bfloat16) | |
| # Topologically Sorted Source Nodes: [output_33], Original ATen: [aten.mm] | |
| extern_kernels.mm(reinterpret_tensor(buf261, (32, 512), (16384, 1), 15872), _frozen_param134, out=buf262) | |
| del buf261 | |
| buf263 = empty_strided_cuda((32, 32000), (32000, 1), torch.float32) | |
| # Topologically Sorted Source Nodes: [float_42], Original ATen: [aten._to_copy] | |
| stream0 = get_raw_stream(0) | |
| triton_poi_fused__to_copy_15.run(buf262, buf263, 1024000, grid=grid(1024000), stream=stream0) | |
| del buf262 | |
| return (buf263, arg77_1, arg76_1, arg79_1, arg78_1, arg81_1, arg80_1, arg83_1, arg82_1, arg85_1, arg84_1, arg87_1, arg86_1, arg89_1, arg88_1, arg91_1, arg90_1, ) | |
| def benchmark_compiled_module(times=10, repeat=10): | |
| from torch._dynamo.testing import rand_strided | |
| from torch._inductor.utils import print_performance | |
| global _frozen_param0 | |
| _frozen_param0 = rand_strided((32000, 512), (512, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param1 | |
| _frozen_param1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param6 | |
| _frozen_param6 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param10 | |
| _frozen_param10 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param15 | |
| _frozen_param15 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param19 | |
| _frozen_param19 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param24 | |
| _frozen_param24 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param28 | |
| _frozen_param28 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param33 | |
| _frozen_param33 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param37 | |
| _frozen_param37 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param42 | |
| _frozen_param42 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param46 | |
| _frozen_param46 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param51 | |
| _frozen_param51 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param55 | |
| _frozen_param55 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param60 | |
| _frozen_param60 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param64 | |
| _frozen_param64 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param69 | |
| _frozen_param69 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param73 | |
| _frozen_param73 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param135 | |
| _frozen_param135 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param79 | |
| _frozen_param79 = rand_strided((1, 32, 1, 32), (1024, 32, 32, 1), device='cuda:0', dtype=torch.complex64) | |
| global _frozen_param80 | |
| _frozen_param80 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param136 | |
| _frozen_param136 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param83 | |
| _frozen_param83 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param137 | |
| _frozen_param137 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param87 | |
| _frozen_param87 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param138 | |
| _frozen_param138 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param90 | |
| _frozen_param90 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param139 | |
| _frozen_param139 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param94 | |
| _frozen_param94 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param140 | |
| _frozen_param140 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param97 | |
| _frozen_param97 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param141 | |
| _frozen_param141 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param101 | |
| _frozen_param101 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param142 | |
| _frozen_param142 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param104 | |
| _frozen_param104 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param143 | |
| _frozen_param143 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param108 | |
| _frozen_param108 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param144 | |
| _frozen_param144 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param111 | |
| _frozen_param111 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param145 | |
| _frozen_param145 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param115 | |
| _frozen_param115 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param146 | |
| _frozen_param146 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param118 | |
| _frozen_param118 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param147 | |
| _frozen_param147 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param122 | |
| _frozen_param122 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param148 | |
| _frozen_param148 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param125 | |
| _frozen_param125 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param149 | |
| _frozen_param149 = rand_strided((512, 1536), (1536, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param129 | |
| _frozen_param129 = rand_strided((512, 512), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param150 | |
| _frozen_param150 = rand_strided((512, 3072), (3072, 1), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param132 | |
| _frozen_param132 = rand_strided((1536, 512), (1, 1536), device='cuda:0', dtype=torch.bfloat16) | |
| global _frozen_param134 | |
| _frozen_param134 = rand_strided((512, 32000), (1, 512), device='cuda:0', dtype=torch.bfloat16) | |
| arg76_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg77_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg78_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg79_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg80_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg81_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg82_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg83_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg84_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg85_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg86_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg87_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg88_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg89_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg90_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg91_1 = rand_strided((32, 1024, 8, 64), (524288, 512, 64, 1), device='cuda:0', dtype=torch.bfloat16) | |
| arg92_1 = rand_strided((32, 32), (32, 1), device='cuda:0', dtype=torch.int32) | |
| fn = lambda: call([arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1]) | |
| return print_performance(fn, times=times, repeat=repeat) | |
if __name__ == "__main__":
    # Script entry point: hand this module's benchmark function to Inductor's
    # wrapper-benchmark driver under the label 'llama'. The import is kept
    # local to the guard so importing this module does not pull it in.
    import torch._inductor.wrapper_benchmark as _wrapper_benchmark

    _wrapper_benchmark.compiled_module_main('llama', benchmark_compiled_module)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment