# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
_frozen_param6 = None  # device(type='cpu') torch.bfloat16 (1024, 512) (1, 0) 7fec343e7d30
_frozen_param7 = None  # device(type='cpu') torch.bfloat16 (1024, 512) (1, 0) 7fec343e7ce0
_frozen_param8 = None  # device(type='cpu') torch.bfloat16 (1024, 512) (1, 0) 7fec343e7d80
constant3 = None  # device(type='cpu') torch.bfloat16 (64, 512, 16) (8192, 16, 1) 7fec338031f0
constant4 = None  # device(type='cpu') torch.bfloat16 (64, 512, 16) (8192, 16, 1) 7fec338031a0
constant5 = None  # device(type='cpu') torch.bfloat16 (64, 512, 16) (8192, 16, 1) 7fec33eceed0
cpp_fused_0 = async_compile.cpp_pybinding(['const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include "/tmp/torchinductor_leslie/2r/c2rnilspx43ivnzu4uieul65kx65dfhfbptbh5og4wk6rqebuxoo.h"
#include <c10/util/Unroll.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
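// The three templates below are AMX (Advanced Matrix Extensions) micro-GEMM kernels
// emitted by Inductor's C++ GEMM template. Each one multiplies an M-row slice of a bf16
// A panel with a VNNI-style packed bf16 B panel and accumulates into fp32 C tiles:
//   kernel_micro_gemm_amx_kernel_48_1: 48 rows of A (3 row tiles of 16) x 16 columns of B
//   kernel_micro_gemm_amx_kernel_32_1: 32 rows of A (2 row tiles of 16) x 16 columns of B
//   kernel_micro_gemm_amx_kernel_16_1: 16 rows of A (1 row tile of 16)  x 16 columns of B
// Each K-loop iteration consumes 32 bf16 values per row via _tile_dpbf16ps; a K tail of
// fewer than 32 elements is handled by reconfiguring the tiles before the final step.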
template <bool accum>
inline void kernel_micro_gemm_amx_kernel_48_1(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / 32 * 32;
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 48 / 16, 1, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 48 / 16, 1, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(1, C + 16 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(2, C + 32 * ldc + 0, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
        _tile_zero(1);
        _tile_zero(2);
    };
    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }
    auto compute = [&](int k) {
        _tile_stream_loadd(3, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 3, 6);
        _tile_stream_loadd(4, A + 16 * lda + k, lda * sizeof(bfloat16));
        _tile_dpbf16ps(1, 4, 6);
        _tile_stream_loadd(5, A + 32 * lda + k, lda * sizeof(bfloat16));
        _tile_dpbf16ps(2, 5, 6);
    };
    #pragma GCC unroll 4
    for (int k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }
    auto store_c = [&]() {
        // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_stored(1, C + 16 * ldc + 0, ldc * sizeof(float));
        _tile_stored(2, C + 32 * ldc + 0, ldc * sizeof(float));
    };
    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 48 / 16, 1, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }
    store_c();
}
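// The 32-row and 16-row variants below follow the same pattern as the 48-row kernel
// above, using two and one A row tiles respectively; kernel_micro_gemm uses them for
// the M tail when fewer than 48 rows remain.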
template <bool accum>
inline void kernel_micro_gemm_amx_kernel_32_1(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / 32 * 32;
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 32 / 16, 1, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 1, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(1, C + 16 * ldc + 0, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
        _tile_zero(1);
    };
    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }
    auto compute = [&](int k) {
        _tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(4, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 2, 4);
        _tile_stream_loadd(3, A + 16 * lda + k, lda * sizeof(bfloat16));
        _tile_dpbf16ps(1, 3, 4);
    };
    #pragma GCC unroll 4
    for (int k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }
    auto store_c = [&]() {
        // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_stored(1, C + 16 * ldc + 0, ldc * sizeof(float));
    };
    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 1, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }
    store_c();
}
template <bool accum>
inline void kernel_micro_gemm_amx_kernel_16_1(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / 32 * 32;
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 16 / 16, 1, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 1, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
    };
    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }
    auto compute = [&](int k) {
        _tile_stream_loadd(1, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(2, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 1, 2);
    };
    #pragma GCC unroll 4
    for (int k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }
    auto store_c = [&]() {
        // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
    };
    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 1, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }
    store_c();
}
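// kernel_micro_gemm drives the AMX micro-kernels over an (M x N) output block:
// N is walked in steps of 16 (one B tile), and M is walked in steps of 48, picking the
// largest fitting variant (48/32/16 rows). Any remaining M tail of fewer than 16 rows
// reuses the 16-row kernel with tilecfg_rows set to the actual number of rows, so
// partial row tiles are computed without touching out-of-bounds memory.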
template <bool accum>
inline void kernel_micro_gemm(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    AOTI_TORCH_CHECK(N % 16 == 0, "N dimension must be multiple of 16");
    AOTI_TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
    // The ldb would not be block_n if N != block_n
    const int64_t updated_ldb = ldb;
    // TODO(jgong5): loop unroll for M and N
    for (int64_t n = 0; n < N; n += 16) {
        for (int64_t m = 0; m < M; m += 48) {
            int64_t block_m = std::min<int64_t>(M - m, 48);
            int64_t m_tail = m;
            if (block_m >= 48) {
                kernel_micro_gemm_amx_kernel_48_1<accum>(
                    amx_state,
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    updated_ldb,
                    ldc,
                    16
                );
                block_m -= 48;
                m_tail += 48;
            } else if (block_m >= 32) {
                kernel_micro_gemm_amx_kernel_32_1<accum>(
                    amx_state,
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    updated_ldb,
                    ldc,
                    16
                );
                block_m -= 32;
                m_tail += 32;
            } else if (block_m >= 16) {
                kernel_micro_gemm_amx_kernel_16_1<accum>(
                    amx_state,
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    updated_ldb,
                    ldc,
                    16
                );
                block_m -= 16;
                m_tail += 16;
            }
            if (block_m > 0) {
                kernel_micro_gemm_amx_kernel_16_1<accum>(
                    amx_state,
                    A + m_tail * lda,
                    B + n,
                    C + m_tail * ldc + n,
                    K,
                    lda,
                    updated_ldb,
                    ldc,
                    block_m
                );
            }
        }
    }
}
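// Outer kernel: one fused call computes three independent GEMMs that share the same
// bf16 activation X0 (M=4, K=512) against three pre-packed bf16 weights W0/W1/W2
// (each K=512, N=1024), writing bf16 outputs Y0/Y1/Y2. The loop nest is a standard
// blocked GEMM: threads are partitioned over (M, N, K) register-block ranges, each
// thread accumulates into private fp32 buffers (local_acc_buf_*), and an epilogue
// converts the fp32 accumulators to bf16 and stores them into the output tensors.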
| extern "C" | |
| void kernel(const bfloat16* X0, const bfloat16* W0, const bfloat16* W1, const bfloat16* W2, bfloat16* Y0, bfloat16* Y1, bfloat16* Y2) | |
| { | |
| constexpr int64_t num_threads = 56; | |
| constexpr int64_t N = 1024; | |
| constexpr int64_t K = 512; | |
| constexpr int64_t Mr = 48; | |
| constexpr int64_t Nr = 16; | |
| constexpr int64_t Kr = 32; | |
| constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; | |
| constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; | |
| constexpr int64_t M = static_cast<int64_t>(4L); | |
| constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; | |
| constexpr int64_t Mt_blocks = 1; | |
| constexpr int64_t Nt_blocks = 2; | |
| constexpr int64_t Kt_blocks = 16; | |
| constexpr int64_t Mc_blocks = 1; | |
| constexpr int64_t Nc_blocks = 1; | |
| constexpr int64_t Kc_blocks = 16; | |
| constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; | |
| constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; | |
| constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; | |
| constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; | |
| constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; | |
| // make sure all partitions are assigned | |
| AOTI_TORCH_CHECK( | |
| Mt_blocks * Nt_blocks * Kt_blocks * 56 >= Mr_blocks * Nr_blocks * Kr_blocks, | |
| "Not all partitions are assigned." | |
| ); | |
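    // For this problem size the partitioning works out as follows:
    //   Mr_blocks = ceil(4 / 48) = 1, Nr_blocks = 1024 / 16 = 64, Kr_blocks = 512 / 32 = 16
    //   num_Mt_blocks = 1, num_Nt_blocks = 32, num_Kt_blocks = 1
    // so k_slice_id is always 0 (every thread sees the full K range), threads 0..31 each
    // take Nt_blocks = 2 consecutive Nr blocks (32 output columns) of the single M block,
    // and threads 32..55 get an empty M range (m_block_start == m_block_end) and fall
    // through without doing any work.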
    #pragma omp parallel num_threads(56)
    {
        const int tid = omp_get_thread_num();
        const int64_t k_group_id = tid / num_Kt_blocks;
        const int64_t k_slice_id = tid % num_Kt_blocks;
        const int64_t n_group_id = k_group_id / num_Nt_blocks;
        const int64_t n_slice_id = k_group_id % num_Nt_blocks;
        const int64_t k_block_start = k_slice_id * Kt_blocks;
        const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
        const int64_t n_block_start = n_slice_id * Nt_blocks;
        const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
        const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
        const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
        const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
        AMXState amx_state;
        auto _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_0 = _local_acc_buf_0.get();
        auto _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_1 = _local_acc_buf_1.get();
        auto _local_acc_buf_2 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_2 = _local_acc_buf_2.get();
        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
            const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
            const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
            const int64_t m_start = mc * Mr;
            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
            const int64_t m_size = m_end - m_start;
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                const int64_t n_start = nc * Nr;
                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                const int64_t n_size = n_end - n_start;
                // NB: assume we pad N, nc_block_end won't exceed padded N here.
                const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
                if (_local_acc_buf_0 == nullptr) { _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_0 = _local_acc_buf_0.get(); }
                if (_local_acc_buf_1 == nullptr) { _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_1 = _local_acc_buf_1.get(); }
                if (_local_acc_buf_2 == nullptr) { _local_acc_buf_2 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_2 = _local_acc_buf_2.get(); }
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * Kr;
                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
                        if (kc == k_block_start) {
                            kernel_micro_gemm<static_cast<bool>(false)>(
                                amx_state,
                                &(X0[static_cast<int64_t>(k_start + 512L*m_start)]),
                                &(W0[static_cast<int64_t>(16L*k_start + 8192L*nci)]),
                                &(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(512L),
                                static_cast<int64_t>(16L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );
                            kernel_micro_gemm<static_cast<bool>(false)>(
                                amx_state,
                                &(X0[static_cast<int64_t>(k_start + 512L*m_start)]),
                                &(W1[static_cast<int64_t>(16L*k_start + 8192L*nci)]),
                                &(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(512L),
                                static_cast<int64_t>(16L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );
                            kernel_micro_gemm<static_cast<bool>(false)>(
                                amx_state,
                                &(X0[static_cast<int64_t>(k_start + 512L*m_start)]),
                                &(W2[static_cast<int64_t>(16L*k_start + 8192L*nci)]),
                                &(local_acc_buf_2[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(512L),
                                static_cast<int64_t>(16L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );
                        } else {
                            kernel_micro_gemm<static_cast<bool>(true)>(
                                amx_state,
                                &(X0[static_cast<int64_t>(k_start + 512L*m_start)]),
                                &(W0[static_cast<int64_t>(16L*k_start + 8192L*nci)]),
                                &(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(512L),
                                static_cast<int64_t>(16L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );
                            kernel_micro_gemm<static_cast<bool>(true)>(
                                amx_state,
                                &(X0[static_cast<int64_t>(k_start + 512L*m_start)]),
                                &(W1[static_cast<int64_t>(16L*k_start + 8192L*nci)]),
                                &(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(512L),
                                static_cast<int64_t>(16L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );
                            kernel_micro_gemm<static_cast<bool>(true)>(
                                amx_state,
                                &(X0[static_cast<int64_t>(k_start + 512L*m_start)]),
                                &(W2[static_cast<int64_t>(16L*k_start + 8192L*nci)]),
                                &(local_acc_buf_2[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(512L),
                                static_cast<int64_t>(16L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );
                        }
                    }
                }
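                // Epilogue for this (mc, nc) block: the fp32 partial sums in the three
                // local accumulators are converted to bf16 and stored to Y0/Y1/Y2. The
                // C10_LIKELY branch handles full 16-lane vectors; the C10_UNLIKELY branch
                // handles a masked tail when the remaining N extent is not a multiple of
                // 16 (never taken here since N = 1024).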
                {
                    {
                        #pragma GCC ivdep
                        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=static_cast<int64_t>(16L))
                            {
                                {
                                    if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))))
                                    {
                                        auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16));
                                        auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16));
                                        auto tmp4 = at::vec::Vectorized<float>::loadu(local_acc_buf_2 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16));
                                        auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                        auto tmp3 = at::vec::convert<bfloat16>(tmp2);
                                        auto tmp5 = at::vec::convert<bfloat16>(tmp4);
                                        tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(16));
                                        tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(16));
                                        tmp5.store(Y2 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(16));
                                    }
                                    if(C10_UNLIKELY(x1 >= static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))) && x1 < static_cast<int64_t>(n_end + ((-1L)*n_start))))
                                    {
                                        auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
                                        auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
                                        auto tmp4 = at::vec::Vectorized<float>::loadu(local_acc_buf_2 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
                                        auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                        auto tmp3 = at::vec::convert<bfloat16>(tmp2);
                                        auto tmp5 = at::vec::convert<bfloat16>(tmp4);
                                        tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
                                        tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
                                        tmp5.store(Y2 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        amx_state.release([]() { _tile_release(); });
    }
}
''')
async_compile.wait(globals())
del async_compile
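# `call` is the generated entry point: it takes a single (4, 512) bf16 activation and
# returns three (4, 1024) bf16 outputs, one per constant-folded (frozen) weight. The
# packed weights constant3/constant4/constant5 are module-level constants that must be
# populated before `call` runs (benchmark_compiled_module below fills them with random
# data for timing purposes).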
def call(args):
    arg3_1, = args
    args.clear()
    assert_size_stride(arg3_1, (4, 512), (512, 1))
    buf1 = empty_strided_cpu((4, 1024), (1024, 1), torch.bfloat16)
    buf2 = empty_strided_cpu((4, 1024), (1024, 1), torch.bfloat16)
    buf3 = empty_strided_cpu((4, 1024), (1024, 1), torch.bfloat16)
    buf0 = [buf1, buf2, buf3, ]
    cpp_fused_0(arg3_1, constant3, constant4, constant5, buf1, buf2, buf3)
    del arg3_1
    buf1 = buf0[0]
    buf2 = buf0[1]
    buf3 = buf0[2]
    return (buf1, buf2, buf3, )

def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    global _frozen_param6
    _frozen_param6 = rand_strided((1024, 512), (1, 0), device='cpu', dtype=torch.bfloat16)
    global _frozen_param7
    _frozen_param7 = rand_strided((1024, 512), (1, 0), device='cpu', dtype=torch.bfloat16)
    global _frozen_param8
    _frozen_param8 = rand_strided((1024, 512), (1, 0), device='cpu', dtype=torch.bfloat16)
    global constant3
    constant3 = rand_strided((64, 512, 16), (8192, 16, 1), device='cpu', dtype=torch.bfloat16)
    global constant4
    constant4 = rand_strided((64, 512, 16), (8192, 16, 1), device='cpu', dtype=torch.bfloat16)
    global constant5
    constant5 = rand_strided((64, 512, 16), (8192, 16, 1), device='cpu', dtype=torch.bfloat16)
    arg3_1 = rand_strided((4, 512), (512, 1), device='cpu', dtype=torch.bfloat16)
    fn = lambda: call([arg3_1])
    return print_performance(fn, times=times, repeat=repeat)
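# Hypothetical reproducer sketch (an assumption, not part of the generated output): code of
# this shape is roughly what the CPU C++ GEMM template emits for three bias-free bf16
# Linear(512, 1024) layers sharing one (4, 512) input, compiled with freezing and
# max-autotune enabled. The names below (ThreeLinear, _hypothetical_source_model_sketch)
# are illustrative only; the function is defined but never called by this file.
def _hypothetical_source_model_sketch():
    import torch
    import torch._inductor.config as inductor_config

    class ThreeLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Three independent projections of the same input, mirroring the three
            # packed weights W0/W1/W2 consumed by cpp_fused_0 above.
            self.fc0 = torch.nn.Linear(512, 1024, bias=False)
            self.fc1 = torch.nn.Linear(512, 1024, bias=False)
            self.fc2 = torch.nn.Linear(512, 1024, bias=False)

        def forward(self, x):
            return self.fc0(x), self.fc1(x), self.fc2(x)

    # Freezing folds the weights into constants; max-autotune lets Inductor consider
    # the C++ GEMM template that emits AMX micro-kernels like the ones in cpp_fused_0.
    inductor_config.freezing = True
    inductor_config.max_autotune = True
    with torch.no_grad():
        model = ThreeLinear().eval().to(torch.bfloat16)
        x = torch.randn(4, 512, dtype=torch.bfloat16)
        compiled = torch.compile(model)
        return compiled(x)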
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)