| import torch | |
| from torch._inductor.codecache import CppWrapperCodeCache | |
| cpp_wrapper_src = ( | |
| ''' | |
| #include <optional> | |
| #include <Python.h> | |
| #define PYBIND11_SIMPLE_GIL_MANAGEMENT | |
| #include <pybind11/gil.h> | |
| namespace py = pybind11; | |
| class RAIIPyObject { | |
| public: | |
| RAIIPyObject() : obj_(nullptr) {} | |
| RAIIPyObject(PyObject* obj) : obj_(obj) {} | |
| ~RAIIPyObject() { | |
| Py_XDECREF(obj_); | |
| } | |
| RAIIPyObject& operator=(const RAIIPyObject& other) { | |
| if (this != &other) { | |
| Py_XDECREF(obj_); | |
| obj_ = other.obj_; | |
| Py_XINCREF(obj_); | |
| } | |
| return *this; | |
| } | |
| operator PyObject*() { | |
| return obj_; | |
| } | |
| PyObject* get() { | |
| return obj_; | |
| } | |
| private: | |
| PyObject* obj_; | |
| }; | |
| #include <torch/csrc/inductor/aoti_runtime/device_utils.h> | |
| #include <torch/csrc/inductor/aoti_runtime/utils.h> | |
| using namespace torch::aot_inductor; | |
| #include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h> | |
| #include <torch/csrc/inductor/aoti_runtime/thread_local.h> | |
| #include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h> | |
| #include <torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h> | |
| #include <c10/util/generic_math.h> | |
| typedef at::Half half; | |
| typedef at::BFloat16 bfloat16; | |
| // Round up to the nearest multiple of 64 | |
| [[maybe_unused]] static int64_t align(int64_t nbytes) { | |
| return (nbytes + 64 - 1) & -64; | |
| } | |
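| // e.g. align(1) == 64, align(64) == 64, align(65) == 128 | |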
| // _frozen_param4 device(type='cpu') torch.float16 (32, 52) (1, 0) 7f20b991d620 | |
| // _frozen_param5 device(type='cpu') torch.float16 (32, 52) (1, 0) 7f20b991ee80 | |
| // constant2 device(type='cpu') torch.float16 (1, 52, 32) (1664, 32, 1) 7f20bb9676a0 | |
| // constant3 device(type='cpu') torch.float16 (1, 52, 32) (1664, 32, 1) 7f20b9903970 | |
| #include "/tmp/torchinductor_leslie/db/cdb7hyptwxpzukwd42x4ajfjlgrpum4a4htdd6lhb65apclsmno4.h" | |
| #include <c10/util/Unroll.h> | |
| #include <torch/csrc/inductor/aoti_torch/c/shim.h> | |
| template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum> | |
| inline void cpp_fused_0_micro_gemm_kernel( | |
| const half* __restrict__ A, | |
| const half* __restrict__ B, | |
| float* __restrict__ C, | |
| int64_t K, | |
| int64_t lda, | |
| int64_t ldb, | |
| int64_t ldc | |
| ) { | |
| using Vectorized = at::vec::Vectorized<float>; | |
| using VectorizedIn = at::vec::Vectorized<half>; | |
| constexpr auto VLEN = Vectorized::size(); | |
| constexpr auto ROWS = BLOCK_M; | |
| constexpr auto COLS = BLOCK_N / VLEN; | |
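| // Register blocking: the BLOCK_M x BLOCK_N tile of C is kept in ROWS*COLS fp32 vector registers (vc). | |
| // With AVX-512 (VLEN == 16) and BLOCK_M == 8, BLOCK_N == 32 this is 8 x 2 = 16 accumulator vectors; | |
| // with AVX2 (VLEN == 8) it would be 8 x 4 = 32. Each k step broadcasts one fp16 element of A into va, | |
| // converts one fp16 vector of B to fp32 into vb, and issues an fmadd into vc. | |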
| Vectorized va; | |
| at::vec::VectorizedN<float, COLS> vb; | |
| at::vec::VectorizedN<float, ROWS*COLS> vc; | |
| auto loadc = [&](auto i) { | |
| if constexpr (accum) { | |
| constexpr int row = i / COLS; | |
| constexpr int col = i % COLS; | |
| vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN); | |
| } else { | |
| vc[i] = Vectorized(0.0f); | |
| } | |
| }; | |
| c10::ForcedUnroll<ROWS * COLS>{}(loadc); | |
| auto compute = [&, COLS](auto i, int k) { | |
| constexpr int row = i / COLS; | |
| constexpr int col = i % COLS; | |
| if constexpr (col == 0) { | |
| va = Vectorized(static_cast<float>(A[row * lda + k])); | |
| } | |
| if constexpr (row == 0) { | |
| auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN); | |
| vb[col] = at::vec::convert<float>(b); | |
| } | |
| constexpr int idx = row * COLS + col; | |
| vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); | |
| }; | |
| for (int k = 0; k < K; ++k) { | |
| c10::ForcedUnroll<ROWS * COLS>{}(compute, k); | |
| } | |
| // store to C | |
| auto storec = [&](auto i) { | |
| constexpr int row = i / COLS; | |
| constexpr int col = i % COLS; | |
| vc[i].store(C + row * ldc + col * VLEN); | |
| }; | |
| c10::ForcedUnroll<ROWS * COLS>{}(storec); | |
| } | |
| template <bool accum> | |
| inline void cpp_fused_0_micro_gemm( | |
| const half* __restrict__ A, | |
| const half* __restrict__ B, | |
| float* __restrict__ C, | |
| int64_t M, | |
| int64_t N, | |
| int64_t K, | |
| int64_t lda, | |
| int64_t ldb, | |
| int64_t ldc | |
| ) { | |
| TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32"); | |
| TORCH_CHECK(K % 1 == 0, "K dimension must be multiple of 1"); | |
| // TODO(jgong5): loop unroll for M and N | |
| for (int64_t m = 0; m < M; m += 8) { | |
| int64_t block_m = std::min<int64_t>(M - m, 8); | |
| for (int64_t n = 0; n < N; n += 32) { | |
| if (block_m == 8) { | |
| cpp_fused_0_micro_gemm_kernel<8, 32, accum>( | |
| A + m * lda, | |
| B + n, | |
| C + m * ldc + n, | |
| K, | |
| lda, | |
| ldb, | |
| ldc | |
| ); | |
| } else { | |
| switch (block_m) { | |
| case 7: | |
| cpp_fused_0_micro_gemm_kernel<7, 32, accum>( | |
| A + m * lda, | |
| B + n, | |
| C + m * ldc + n, | |
| K, | |
| lda, | |
| ldb, | |
| ldc | |
| ); | |
| break; | |
| case 6: | |
| cpp_fused_0_micro_gemm_kernel<6, 32, accum>( | |
| A + m * lda, | |
| B + n, | |
| C + m * ldc + n, | |
| K, | |
| lda, | |
| ldb, | |
| ldc | |
| ); | |
| break; | |
| case 5: | |
| cpp_fused_0_micro_gemm_kernel<5, 32, accum>( | |
| A + m * lda, | |
| B + n, | |
| C + m * ldc + n, | |
| K, | |
| lda, | |
| ldb, | |
| ldc | |
| ); | |
| break; | |
| case 4: | |
| cpp_fused_0_micro_gemm_kernel<4, 32, accum>( | |
| A + m * lda, | |
| B + n, | |
| C + m * ldc + n, | |
| K, | |
| lda, | |
| ldb, | |
| ldc | |
| ); | |
| break; | |
| case 3: | |
| cpp_fused_0_micro_gemm_kernel<3, 32, accum>( | |
| A + m * lda, | |
| B + n, | |
| C + m * ldc + n, | |
| K, | |
| lda, | |
| ldb, | |
| ldc | |
| ); | |
| break; | |
| case 2: | |
| cpp_fused_0_micro_gemm_kernel<2, 32, accum>( | |
| A + m * lda, | |
| B + n, | |
| C + m * ldc + n, | |
| K, | |
| lda, | |
| ldb, | |
| ldc | |
| ); | |
| break; | |
| case 1: | |
| cpp_fused_0_micro_gemm_kernel<1, 32, accum>( | |
| A + m * lda, | |
| B + n, | |
| C + m * ldc + n, | |
| K, | |
| lda, | |
| ldb, | |
| ldc | |
| ); | |
| break; | |
| default: | |
| TORCH_CHECK(false, "Unsupported block_m: 8"); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| extern "C" | |
| void cpp_fused_0(const half* X0, const half* W0, const half* W1, half* Y0, half* Y1) | |
| { | |
| constexpr int64_t num_threads = 56; | |
| constexpr int64_t N = 32; | |
| constexpr int64_t K = 52; | |
| constexpr int64_t Mr = 8; | |
| constexpr int64_t Nr = 32; | |
| constexpr int64_t Kr = 1; | |
| constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; | |
| constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; | |
| constexpr int64_t M = static_cast<int64_t>(16L); | |
| constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; | |
| constexpr int64_t Mt_blocks = 1; | |
| constexpr int64_t Nt_blocks = 1; | |
| constexpr int64_t Kt_blocks = 52; | |
| constexpr int64_t Mc_blocks = 1; | |
| constexpr int64_t Nc_blocks = 1; | |
| constexpr int64_t Kc_blocks = 52; | |
| constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; | |
| constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; | |
| constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; | |
| constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; | |
| constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; | |
| // make sure all partitions are assigned | |
| TORCH_CHECK( | |
| Mt_blocks * Nt_blocks * Kt_blocks * 56 >= Mr_blocks * Nr_blocks * Kr_blocks, | |
| "Not all partitions are assigned." | |
| ); | |
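| // With these constants: Mr_blocks = 2, Nr_blocks = 1, Kr_blocks = 52 and num_Nt_blocks = num_Kt_blocks = 1, | |
| // so each thread id maps directly to one M block: tid 0 computes rows 0-7, tid 1 computes rows 8-15, and | |
| // the remaining 54 threads see m_block_start == m_block_end == Mr_blocks and skip the work loop entirely. | |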
| #pragma omp parallel num_threads(56) | |
| { | |
| const int tid = omp_get_thread_num(); | |
| const int64_t k_group_id = tid / num_Kt_blocks; | |
| const int64_t k_slice_id = tid % num_Kt_blocks; | |
| const int64_t n_group_id = k_group_id / num_Nt_blocks; | |
| const int64_t n_slice_id = k_group_id % num_Nt_blocks; | |
| const int64_t k_block_start = k_slice_id * Kt_blocks; | |
| const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks); | |
| const int64_t n_block_start = n_slice_id * Nt_blocks; | |
| const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks); | |
| const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks); | |
| const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks); | |
| const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks; | |
| auto _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_0 = _local_acc_buf_0.get(); | |
| auto _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_1 = _local_acc_buf_1.get(); | |
| for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { | |
| const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; | |
| const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; | |
| const int64_t m_start = mc * Mr; | |
| const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); | |
| const int64_t m_size = m_end - m_start; | |
| for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { | |
| const int64_t n_start = nc * Nr; | |
| const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); | |
| const int64_t n_size = n_end - n_start; | |
| // NB: assume we pad N, nc_block_end won't exceed padded N here. | |
| const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); | |
| if (_local_acc_buf_0 == nullptr) { _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_0 = _local_acc_buf_0.get(); } | |
| if (_local_acc_buf_1 == nullptr) { _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_1 = _local_acc_buf_1.get(); } | |
| for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { | |
| int64_t k_start = kc * Kr; | |
| int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); | |
| for (int64_t nci = nc; nci < nc_block_end; nci++) { | |
| if (kc == k_block_start) { | |
| cpp_fused_0_micro_gemm<static_cast<bool>(false)>( | |
| &(X0[static_cast<int64_t>(k_start + 52L*m_start)]), | |
| &(W0[static_cast<int64_t>(32L*k_start + 1664L*nci)]), | |
| &(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]), | |
| static_cast<int64_t>(m_end + ((-1L)*m_start)), | |
| static_cast<int64_t>(Nr), | |
| static_cast<int64_t>(k_end + ((-1L)*k_start)), | |
| static_cast<int64_t>(52L), | |
| static_cast<int64_t>(32L), | |
| static_cast<int64_t>(Nc_blocks*Nr) | |
| ); | |
| cpp_fused_0_micro_gemm<static_cast<bool>(false)>( | |
| &(X0[static_cast<int64_t>(k_start + 52L*m_start)]), | |
| &(W1[static_cast<int64_t>(32L*k_start + 1664L*nci)]), | |
| &(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]), | |
| static_cast<int64_t>(m_end + ((-1L)*m_start)), | |
| static_cast<int64_t>(Nr), | |
| static_cast<int64_t>(k_end + ((-1L)*k_start)), | |
| static_cast<int64_t>(52L), | |
| static_cast<int64_t>(32L), | |
| static_cast<int64_t>(Nc_blocks*Nr) | |
| ); | |
| } else { | |
| cpp_fused_0_micro_gemm<static_cast<bool>(true)>( | |
| &(X0[static_cast<int64_t>(k_start + 52L*m_start)]), | |
| &(W0[static_cast<int64_t>(32L*k_start + 1664L*nci)]), | |
| &(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]), | |
| static_cast<int64_t>(m_end + ((-1L)*m_start)), | |
| static_cast<int64_t>(Nr), | |
| static_cast<int64_t>(k_end + ((-1L)*k_start)), | |
| static_cast<int64_t>(52L), | |
| static_cast<int64_t>(32L), | |
| static_cast<int64_t>(Nc_blocks*Nr) | |
| ); | |
| cpp_fused_0_micro_gemm<static_cast<bool>(true)>( | |
| &(X0[static_cast<int64_t>(k_start + 52L*m_start)]), | |
| &(W1[static_cast<int64_t>(32L*k_start + 1664L*nci)]), | |
| &(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]), | |
| static_cast<int64_t>(m_end + ((-1L)*m_start)), | |
| static_cast<int64_t>(Nr), | |
| static_cast<int64_t>(k_end + ((-1L)*k_start)), | |
| static_cast<int64_t>(52L), | |
| static_cast<int64_t>(32L), | |
| static_cast<int64_t>(Nc_blocks*Nr) | |
| ); | |
| } | |
| } | |
| } | |
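| // Epilogue: convert the fp32 accumulators to fp16 and store them into the output buffers Y0 and Y1. | |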
| { | |
| { | |
| #pragma GCC ivdep | |
| for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L)) | |
| { | |
| for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=static_cast<int64_t>(16L)) | |
| { | |
| { | |
| if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))) | |
| { | |
| auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16)); | |
| auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16)); | |
| auto tmp1 = at::vec::convert<half>(tmp0); | |
| auto tmp3 = at::vec::convert<half>(tmp2); | |
| tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(16)); | |
| tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(16)); | |
| } | |
| if(C10_UNLIKELY(x1 >= static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))) && x1 < static_cast<int64_t>(n_end + ((-1L)*n_start)))) | |
| { | |
| auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))); | |
| auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))); | |
| auto tmp1 = at::vec::convert<half>(tmp0); | |
| auto tmp3 = at::vec::convert<half>(tmp2); | |
| tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))); | |
| tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| CACHE_TORCH_DTYPE(float16); | |
| CACHE_TORCH_DEVICE(cpu); | |
| void inductor_entry_impl( | |
| AtenTensorHandle* | |
| input_handles, // array of input AtenTensorHandle; handles | |
| // are stolen; the array itself is borrowed | |
| AtenTensorHandle* | |
| output_handles // array for writing output AtenTensorHandle; handles | |
| // will be stolen by the caller; the array itself is | |
| // borrowed | |
| ) { | |
| py::gil_scoped_release release; | |
| auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 5); | |
| auto arg2_1 = std::move(inputs[0]); | |
| [[maybe_unused]] auto _frozen_param4 = std::move(inputs[1]); | |
| [[maybe_unused]] auto _frozen_param5 = std::move(inputs[2]); | |
| [[maybe_unused]] auto constant2 = std::move(inputs[3]); | |
| [[maybe_unused]] auto constant3 = std::move(inputs[4]); | |
| static constexpr int64_t int_array_0[] = {16L, 32L}; | |
| static constexpr int64_t int_array_1[] = {32L, 1L}; | |
| AtenTensorHandle buf1_handle; | |
| AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float16, cached_torch_device_type_cpu, 0, &buf1_handle)); | |
| RAIIAtenTensorHandle buf1(buf1_handle); | |
| AtenTensorHandle buf2_handle; | |
| AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float16, cached_torch_device_type_cpu, 0, &buf2_handle)); | |
| RAIIAtenTensorHandle buf2(buf2_handle); | |
| cpp_fused_0((const half*)(arg2_1.data_ptr()), (const half*)(constant2.data_ptr()), (const half*)(constant3.data_ptr()), (half*)(buf1.data_ptr()), (half*)(buf2.data_ptr())); | |
| arg2_1.reset(); | |
| output_handles[0] = buf1.release(); | |
| output_handles[1] = buf2.release(); | |
| } // inductor_entry_impl | |
| ''' | |
| ) | |
| inductor_entry = CppWrapperCodeCache.load_pybinding( | |
| ["std::vector<AtenTensorHandle>"], cpp_wrapper_src, "cpu", 2) | |
| def _wrap_func(f): | |
| def g(args): | |
| input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args] | |
| constants_tensor = [_frozen_param4, _frozen_param5, constant2, constant3] | |
| input_tensors.extend(constants_tensor) | |
| input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors) | |
| args.clear() | |
| output_handles = f(input_handles) | |
| output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles) | |
| return output_tensors | |
| return g | |
| call = _wrap_func(inductor_entry) | |
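| # Ownership note: unsafe_alloc_void_ptrs_from_tensors hands raw AtenTensorHandle pointers to the C++ entry, | |
| # which steals them (see steal_from_raw_handles_to_raii_handles above); the returned output handles are | |
| # re-wrapped as tensors via alloc_tensors_by_stealing_from_void_ptrs. | |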
| def benchmark_compiled_module(times=10, repeat=10): | |
| from torch._dynamo.testing import rand_strided | |
| from torch._inductor.utils import print_performance | |
| global _frozen_param4 | |
| _frozen_param4 = rand_strided((32, 52), (1, 0), device='cpu', dtype=torch.float16) | |
| global _frozen_param5 | |
| _frozen_param5 = rand_strided((32, 52), (1, 0), device='cpu', dtype=torch.float16) | |
| global constant2 | |
| constant2 = rand_strided((1, 52, 32), (1664, 32, 1), device='cpu', dtype=torch.float16) | |
| global constant3 | |
| constant3 = rand_strided((1, 52, 32), (1664, 32, 1), device='cpu', dtype=torch.float16) | |
| arg2_1 = rand_strided((16, 52), (52, 1), device='cpu', dtype=torch.float16) | |
| fn = lambda: call([arg2_1]) | |
| return print_performance(fn, times=times, repeat=repeat) | |
| if __name__ == "__main__": | |
| from torch._inductor.wrapper_benchmark import compiled_module_main | |
| compiled_module_main('None', benchmark_compiled_module) |
| # AOT ID: ['0_inference'] | |
| from ctypes import c_void_p, c_long, c_int | |
| import torch | |
| import math | |
| import random | |
| import os | |
| import tempfile | |
| from math import inf, nan | |
| from torch._inductor.hooks import run_intermediate_hooks | |
| from torch._inductor.utils import maybe_profile | |
| from torch._inductor.codegen.memory_planning import _align as align | |
| from torch import device, empty_strided | |
| from torch._inductor.async_compile import AsyncCompile | |
| from torch._inductor.select_algorithm import extern_kernels | |
| from torch._inductor.codegen.multi_kernel import MultiKernelCall | |
| aten = torch.ops.aten | |
| inductor_ops = torch.ops.inductor | |
| _quantized = torch.ops._quantized | |
| assert_size_stride = torch._C._dynamo.guards.assert_size_stride | |
| empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu | |
| empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda | |
| empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu | |
| reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor | |
| alloc_from_pool = torch.ops.inductor._alloc_from_pool | |
| async_compile = AsyncCompile() | |
| empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p | |
| _frozen_param4 = None # device(type='cpu') torch.bfloat16 (32, 52) (1, 0) 7f6281a2b970 | |
| _frozen_param5 = None # device(type='cpu') torch.bfloat16 (32, 52) (1, 0) 7f6281a2bab0 | |
| constant2 = None # device(type='cpu') torch.bfloat16 (1, 52, 32) (1664, 32, 1) 7f62751fb3d0 | |
| constant3 = None # device(type='cpu') torch.bfloat16 (1, 52, 32) (1664, 32, 1) 7f62751fb380 | |
| cpp_fused_0 = async_compile.cpp_pybinding(['const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'bfloat16*', 'bfloat16*', 'const int64_t'], ''' | |
| #include "/tmp/torchinductor_leslie/db/cdb7hyptwxpzukwd42x4ajfjlgrpum4a4htdd6lhb65apclsmno4.h" | |
| #include <c10/util/Unroll.h> | |
| #include <torch/csrc/inductor/aoti_torch/c/shim.h> | |
| template <bool accum> | |
| inline void kernel_micro_gemm_amx_kernel_32_2( | |
| AMXState& amx_state, | |
| const bfloat16* __restrict__ A, | |
| const bfloat16* __restrict__ B, | |
| float* __restrict__ C, | |
| int64_t K, | |
| int64_t lda, | |
| int64_t ldb, | |
| int64_t ldc, | |
| uint8_t tilecfg_rows | |
| ) { | |
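| // AMX tile usage: tiles 0-3 hold the 32x32 fp32 accumulator (four 16x16 sub-tiles of C), tiles 4-5 stream | |
| // two 16-row panels of A, and tiles 6-7 hold the VNNI-packed B panels; _tile_dpbf16ps accumulates bf16 | |
| // pair dot products into fp32, consuming 32 values along K per compute(k) step. | |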
| // TODO(jgong5): add prefetch hint for A, B, C | |
| auto loadconfig = [](const amx_tilecfg& cfg) { | |
| _tile_loadconfig(&cfg); | |
| }; | |
| const auto last_k_offset = K / 32 * 32; | |
| const auto tail_k_size = K - last_k_offset; | |
| if C10_LIKELY (last_k_offset > 0) { | |
| amx_state.configure(tilecfg_rows, 64, 32 / 16, 2, loadconfig); | |
| } else { | |
| amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig); | |
| } | |
| auto load_c = [&]() { | |
| _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float)); | |
| _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float)); | |
| _tile_loadd(2, C + 16 * ldc + 0, ldc * sizeof(float)); | |
| _tile_loadd(3, C + 16 * ldc + 16, ldc * sizeof(float)); | |
| }; | |
| auto zero_c = [&]() { | |
| _tile_zero(0); | |
| _tile_zero(1); | |
| _tile_zero(2); | |
| _tile_zero(3); | |
| }; | |
| if constexpr (accum) { | |
| load_c(); | |
| } else { | |
| zero_c(); | |
| } | |
| auto compute = [&](int k) { | |
| _tile_stream_loadd(4, A + 0 * lda + k, lda * sizeof(bfloat16)); | |
| _tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16)); | |
| _tile_dpbf16ps(0, 4, 6); | |
| _tile_loadd(7, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16)); | |
| _tile_dpbf16ps(1, 4, 7); | |
| _tile_stream_loadd(5, A + 16 * lda + k, lda * sizeof(bfloat16)); | |
| _tile_dpbf16ps(2, 5, 6); | |
| _tile_dpbf16ps(3, 5, 7); | |
| }; | |
| #pragma GCC unroll 4 | |
| for (int k = 0; k < last_k_offset; k += 32) { | |
| compute(k); | |
| } | |
| auto store_c = [&]() { | |
| // store to C | |
| _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float)); | |
| _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float)); | |
| _tile_stored(2, C + 16 * ldc + 0, ldc * sizeof(float)); | |
| _tile_stored(3, C + 16 * ldc + 16, ldc * sizeof(float)); | |
| }; | |
| // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead | |
| if C10_UNLIKELY (tail_k_size > 0) { | |
| if C10_LIKELY (last_k_offset > 0) { | |
| store_c(); | |
| amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig); | |
| load_c(); | |
| } | |
| compute(last_k_offset); | |
| } | |
| store_c(); | |
| } | |
| template <bool accum> | |
| inline void kernel_micro_gemm_amx_kernel_16_2( | |
| AMXState& amx_state, | |
| const bfloat16* __restrict__ A, | |
| const bfloat16* __restrict__ B, | |
| float* __restrict__ C, | |
| int64_t K, | |
| int64_t lda, | |
| int64_t ldb, | |
| int64_t ldc, | |
| uint8_t tilecfg_rows | |
| ) { | |
| // TODO(jgong5): add prefetch hint for A, B, C | |
| auto loadconfig = [](const amx_tilecfg& cfg) { | |
| _tile_loadconfig(&cfg); | |
| }; | |
| const auto last_k_offset = K / 32 * 32; | |
| const auto tail_k_size = K - last_k_offset; | |
| if C10_LIKELY (last_k_offset > 0) { | |
| amx_state.configure(tilecfg_rows, 64, 16 / 16, 2, loadconfig); | |
| } else { | |
| amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig); | |
| } | |
| auto load_c = [&]() { | |
| _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float)); | |
| _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float)); | |
| }; | |
| auto zero_c = [&]() { | |
| _tile_zero(0); | |
| _tile_zero(1); | |
| }; | |
| if constexpr (accum) { | |
| load_c(); | |
| } else { | |
| zero_c(); | |
| } | |
| auto compute = [&](int k) { | |
| _tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16)); | |
| _tile_loadd(3, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16)); | |
| _tile_dpbf16ps(0, 2, 3); | |
| _tile_loadd(4, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16)); | |
| _tile_dpbf16ps(1, 2, 4); | |
| }; | |
| #pragma GCC unroll 4 | |
| for (int k = 0; k < last_k_offset; k += 32) { | |
| compute(k); | |
| } | |
| auto store_c = [&]() { | |
| // store to C | |
| _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float)); | |
| _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float)); | |
| }; | |
| // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead | |
| if C10_UNLIKELY (tail_k_size > 0) { | |
| if C10_LIKELY (last_k_offset > 0) { | |
| store_c(); | |
| amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig); | |
| load_c(); | |
| } | |
| compute(last_k_offset); | |
| } | |
| store_c(); | |
| } | |
| template <bool accum> | |
| inline void kernel_micro_gemm( | |
| AMXState& amx_state, | |
| const bfloat16* __restrict__ A, | |
| const bfloat16* __restrict__ B, | |
| float* __restrict__ C, | |
| int64_t M, | |
| int64_t N, | |
| int64_t K, | |
| int64_t lda, | |
| int64_t ldb, | |
| int64_t ldc | |
| ) { | |
| TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32"); | |
| TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2"); | |
| // The ldb would not be block_n if N != block_n | |
| const int64_t updated_ldb = ldb; | |
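| // M is processed in blocks of 32: a full 32-row block uses the 32x32 AMX kernel, a 16..31-row remainder | |
| // uses the 16-row kernel, and any final tail (< 16 rows) reuses the 16-row kernel with tilecfg_rows set | |
| // to the remaining row count. | |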
| // TODO(jgong5): loop unroll for M and N | |
| for (int64_t n = 0; n < N; n += 32) { | |
| for (int64_t m = 0; m < M; m += 32) { | |
| int64_t block_m = std::min<int64_t>(M - m, 32); | |
| int64_t m_tail = m; | |
| if (block_m >= 32) { | |
| kernel_micro_gemm_amx_kernel_32_2<accum>( | |
| amx_state, | |
| A + m * lda, | |
| B + n, | |
| C + m * ldc + n, | |
| K, | |
| lda, | |
| updated_ldb, | |
| ldc, | |
| 16 | |
| ); | |
| block_m -= 32; | |
| m_tail += 32; | |
| } | |
| else | |
| if (block_m >= 16) { | |
| kernel_micro_gemm_amx_kernel_16_2<accum>( | |
| amx_state, | |
| A + m * lda, | |
| B + n, | |
| C + m * ldc + n, | |
| K, | |
| lda, | |
| updated_ldb, | |
| ldc, | |
| 16 | |
| ); | |
| block_m -= 16; | |
| m_tail += 16; | |
| } | |
| if (block_m > 0) { | |
| kernel_micro_gemm_amx_kernel_16_2<accum>( | |
| amx_state, | |
| A + m_tail * lda, | |
| B + n, | |
| C + m_tail * ldc + n, | |
| K, | |
| lda, | |
| updated_ldb, | |
| ldc, | |
| block_m | |
| ); | |
| } | |
| } | |
| } | |
| } | |
| extern "C" | |
| void kernel(const bfloat16* X0, const bfloat16* W0, const bfloat16* W1, bfloat16* Y0, bfloat16* Y1, const int64_t ks0) | |
| { | |
| constexpr int64_t num_threads = 56; | |
| constexpr int64_t N = 32; | |
| constexpr int64_t K = 52; | |
| constexpr int64_t Mr = 32; | |
| constexpr int64_t Nr = 32; | |
| constexpr int64_t Kr = 32; | |
| constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; | |
| constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; | |
| const int64_t M = static_cast<int64_t>(ks0); | |
| const int64_t Mr_blocks = (M + Mr - 1) / Mr; | |
| int64_t Mt_blocks, Nt_blocks, Kt_blocks; | |
| mm_get_thread_blocking(num_threads, 1, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks); | |
| int64_t Mc_blocks, Nc_blocks, Kc_blocks; | |
| uint32_t L1_cache_size = 49152; | |
| uint32_t L2_cache_size = 2097152; | |
| mm_get_cache_blocking<bfloat16, bfloat16>( | |
| num_threads, | |
| M, | |
| N, | |
| K, | |
| Mr, | |
| Nr, | |
| Kr, | |
| Mt_blocks, | |
| Nt_blocks, | |
| Kt_blocks, | |
| Mc_blocks, | |
| Nc_blocks, | |
| Kc_blocks, | |
| L1_cache_size, | |
| L2_cache_size | |
| ); | |
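| // Unlike the static-shape kernels in this gist, M (ks0) is only known at runtime here, so the thread-level | |
| // (Mt/Nt/Kt) and cache-level (Mc/Nc/Kc) blockings are computed at runtime from M and the L1/L2 cache sizes | |
| // rather than being baked in as constexpr constants. | |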
| const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; | |
| const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; | |
| const int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; | |
| const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; | |
| const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; | |
| // make sure all partitions are assigned | |
| TORCH_CHECK( | |
| Mt_blocks * Nt_blocks * Kt_blocks * 56 >= Mr_blocks * Nr_blocks * Kr_blocks, | |
| "Not all partitions are assigned." | |
| ); | |
| #pragma omp parallel num_threads(56) | |
| { | |
| const int tid = omp_get_thread_num(); | |
| const int64_t k_group_id = tid / num_Kt_blocks; | |
| const int64_t k_slice_id = tid % num_Kt_blocks; | |
| const int64_t n_group_id = k_group_id / num_Nt_blocks; | |
| const int64_t n_slice_id = k_group_id % num_Nt_blocks; | |
| const int64_t k_block_start = k_slice_id * Kt_blocks; | |
| const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks); | |
| const int64_t n_block_start = n_slice_id * Nt_blocks; | |
| const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks); | |
| const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks); | |
| const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks); | |
| const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks; | |
| AMXState amx_state; | |
| auto _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_0 = _local_acc_buf_0.get(); | |
| auto _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_1 = _local_acc_buf_1.get(); | |
| for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { | |
| const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; | |
| const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; | |
| const int64_t m_start = mc * Mr; | |
| const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); | |
| const int64_t m_size = m_end - m_start; | |
| for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { | |
| const int64_t n_start = nc * Nr; | |
| const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); | |
| const int64_t n_size = n_end - n_start; | |
| // NB: assume we pad N, nc_block_end won't exceed padded N here. | |
| const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); | |
| if (_local_acc_buf_0 == nullptr) { _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_0 = _local_acc_buf_0.get(); } | |
| if (_local_acc_buf_1 == nullptr) { _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_1 = _local_acc_buf_1.get(); } | |
| for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { | |
| int64_t k_start = kc * Kr; | |
| int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); | |
| for (int64_t nci = nc; nci < nc_block_end; nci++) { | |
| if (kc == k_block_start) { | |
| kernel_micro_gemm<static_cast<bool>(false)>( | |
| amx_state, | |
| &(X0[static_cast<int64_t>(k_start + 52L*m_start)]), | |
| &(W0[static_cast<int64_t>(32L*k_start + 1664L*nci)]), | |
| &(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]), | |
| static_cast<int64_t>(m_end + ((-1L)*m_start)), | |
| static_cast<int64_t>(Nr), | |
| static_cast<int64_t>(k_end + ((-1L)*k_start)), | |
| static_cast<int64_t>(52L), | |
| static_cast<int64_t>(32L), | |
| static_cast<int64_t>(Nc_blocks*Nr) | |
| ); | |
| kernel_micro_gemm<static_cast<bool>(false)>( | |
| amx_state, | |
| &(X0[static_cast<int64_t>(k_start + 52L*m_start)]), | |
| &(W1[static_cast<int64_t>(32L*k_start + 1664L*nci)]), | |
| &(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]), | |
| static_cast<int64_t>(m_end + ((-1L)*m_start)), | |
| static_cast<int64_t>(Nr), | |
| static_cast<int64_t>(k_end + ((-1L)*k_start)), | |
| static_cast<int64_t>(52L), | |
| static_cast<int64_t>(32L), | |
| static_cast<int64_t>(Nc_blocks*Nr) | |
| ); | |
| } else { | |
| kernel_micro_gemm<static_cast<bool>(true)>( | |
| amx_state, | |
| &(X0[static_cast<int64_t>(k_start + 52L*m_start)]), | |
| &(W0[static_cast<int64_t>(32L*k_start + 1664L*nci)]), | |
| &(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]), | |
| static_cast<int64_t>(m_end + ((-1L)*m_start)), | |
| static_cast<int64_t>(Nr), | |
| static_cast<int64_t>(k_end + ((-1L)*k_start)), | |
| static_cast<int64_t>(52L), | |
| static_cast<int64_t>(32L), | |
| static_cast<int64_t>(Nc_blocks*Nr) | |
| ); | |
| kernel_micro_gemm<static_cast<bool>(true)>( | |
| amx_state, | |
| &(X0[static_cast<int64_t>(k_start + 52L*m_start)]), | |
| &(W1[static_cast<int64_t>(32L*k_start + 1664L*nci)]), | |
| &(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]), | |
| static_cast<int64_t>(m_end + ((-1L)*m_start)), | |
| static_cast<int64_t>(Nr), | |
| static_cast<int64_t>(k_end + ((-1L)*k_start)), | |
| static_cast<int64_t>(52L), | |
| static_cast<int64_t>(32L), | |
| static_cast<int64_t>(Nc_blocks*Nr) | |
| ); | |
| } | |
| } | |
| } | |
| { | |
| { | |
| #pragma GCC ivdep | |
| for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L)) | |
| { | |
| for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=static_cast<int64_t>(16L)) | |
| { | |
| { | |
| if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))) | |
| { | |
| auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16)); | |
| auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16)); | |
| auto tmp1 = at::vec::convert<bfloat16>(tmp0); | |
| auto tmp3 = at::vec::convert<bfloat16>(tmp2); | |
| tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(16)); | |
| tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(16)); | |
| } | |
| if(C10_UNLIKELY(x1 >= static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))) && x1 < static_cast<int64_t>(n_end + ((-1L)*n_start)))) | |
| { | |
| auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))); | |
| auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))); | |
| auto tmp1 = at::vec::convert<bfloat16>(tmp0); | |
| auto tmp3 = at::vec::convert<bfloat16>(tmp2); | |
| tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))); | |
| tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| amx_state.release([]() { _tile_release(); }); | |
| } | |
| } | |
| ''') | |
| async_compile.wait(globals()) | |
| del async_compile | |
| def call(args): | |
| arg2_1, arg3_1 = args | |
| args.clear() | |
| s0 = arg2_1 | |
| assert_size_stride(arg3_1, (s0, 52), (52, 1)) | |
| buf1 = empty_strided_cpu((s0, 32), (32, 1), torch.bfloat16) | |
| buf2 = empty_strided_cpu((s0, 32), (32, 1), torch.bfloat16) | |
| cpp_fused_0(arg3_1, constant2, constant3, buf1, buf2, s0) | |
| del arg3_1 | |
| return (buf1, buf2, ) | |
| def benchmark_compiled_module(times=10, repeat=10): | |
| from torch._dynamo.testing import rand_strided | |
| from torch._inductor.utils import print_performance | |
| global _frozen_param4 | |
| _frozen_param4 = rand_strided((32, 52), (1, 0), device='cpu', dtype=torch.bfloat16) | |
| global _frozen_param5 | |
| _frozen_param5 = rand_strided((32, 52), (1, 0), device='cpu', dtype=torch.bfloat16) | |
| global constant2 | |
| constant2 = rand_strided((1, 52, 32), (1664, 32, 1), device='cpu', dtype=torch.bfloat16) | |
| global constant3 | |
| constant3 = rand_strided((1, 52, 32), (1664, 32, 1), device='cpu', dtype=torch.bfloat16) | |
| arg2_1 = 16 | |
| arg3_1 = rand_strided((16, 52), (52, 1), device='cpu', dtype=torch.bfloat16) | |
| fn = lambda: call([arg2_1, arg3_1]) | |
| return print_performance(fn, times=times, repeat=repeat) | |
| if __name__ == "__main__": | |
| from torch._inductor.wrapper_benchmark import compiled_module_main | |
| compiled_module_main('None', benchmark_compiled_module) |
| # AOT ID: ['0_inference'] | |
| from ctypes import c_void_p, c_long, c_int | |
| import torch | |
| import math | |
| import random | |
| import os | |
| import tempfile | |
| from math import inf, nan | |
| from torch._inductor.hooks import run_intermediate_hooks | |
| from torch._inductor.utils import maybe_profile | |
| from torch._inductor.codegen.memory_planning import _align as align | |
| from torch import device, empty_strided | |
| from torch._inductor.async_compile import AsyncCompile | |
| from torch._inductor.select_algorithm import extern_kernels | |
| from torch._inductor.codegen.multi_kernel import MultiKernelCall | |
| aten = torch.ops.aten | |
| inductor_ops = torch.ops.inductor | |
| _quantized = torch.ops._quantized | |
| assert_size_stride = torch._C._dynamo.guards.assert_size_stride | |
| empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu | |
| empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda | |
| empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu | |
| reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor | |
| alloc_from_pool = torch.ops.inductor._alloc_from_pool | |
| async_compile = AsyncCompile() | |
| empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p | |
| _frozen_param4 = None # device(type='cpu') torch.bfloat16 (1024, 512) (1, 0) 7f9d8b5cf8d0 | |
| _frozen_param5 = None # device(type='cpu') torch.bfloat16 (1024, 512) (1, 0) 7f9d8b5ceb60 | |
| constant2 = None # device(type='cpu') torch.bfloat16 (64, 512, 16) (8192, 16, 1) 7f9d8abc4e00 | |
| constant3 = None # device(type='cpu') torch.bfloat16 (64, 512, 16) (8192, 16, 1) 7f9d8abc4db0 | |
| cpp_fused_0 = async_compile.cpp_pybinding(['const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'bfloat16*', 'bfloat16*'], ''' | |
| #include "/tmp/torchinductor_leslie/db/cdb7hyptwxpzukwd42x4ajfjlgrpum4a4htdd6lhb65apclsmno4.h" | |
| #include <c10/util/Unroll.h> | |
| #include <torch/csrc/inductor/aoti_torch/c/shim.h> | |
| template <bool accum> | |
| inline void kernel_micro_gemm_amx_kernel_48_1( | |
| AMXState& amx_state, | |
| const bfloat16* __restrict__ A, | |
| const bfloat16* __restrict__ B, | |
| float* __restrict__ C, | |
| int64_t K, | |
| int64_t lda, | |
| int64_t ldb, | |
| int64_t ldc, | |
| uint8_t tilecfg_rows | |
| ) { | |
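| // AMX tile usage: tiles 0-2 hold the 48x16 fp32 accumulator (three 16x16 sub-tiles of C along M), | |
| // tiles 3-5 stream the three 16-row panels of A, and tile 6 holds the VNNI-packed B panel. | |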
| // TODO(jgong5): add prefetch hint for A, B, C | |
| auto loadconfig = [](const amx_tilecfg& cfg) { | |
| _tile_loadconfig(&cfg); | |
| }; | |
| const auto last_k_offset = K / 32 * 32; | |
| const auto tail_k_size = K - last_k_offset; | |
| if C10_LIKELY (last_k_offset > 0) { | |
| amx_state.configure(tilecfg_rows, 64, 48 / 16, 1, loadconfig); | |
| } else { | |
| amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 48 / 16, 1, loadconfig); | |
| } | |
| auto load_c = [&]() { | |
| _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float)); | |
| _tile_loadd(1, C + 16 * ldc + 0, ldc * sizeof(float)); | |
| _tile_loadd(2, C + 32 * ldc + 0, ldc * sizeof(float)); | |
| }; | |
| auto zero_c = [&]() { | |
| _tile_zero(0); | |
| _tile_zero(1); | |
| _tile_zero(2); | |
| }; | |
| if constexpr (accum) { | |
| load_c(); | |
| } else { | |
| zero_c(); | |
| } | |
| auto compute = [&](int k) { | |
| _tile_stream_loadd(3, A + 0 * lda + k, lda * sizeof(bfloat16)); | |
| _tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16)); | |
| _tile_dpbf16ps(0, 3, 6); | |
| _tile_stream_loadd(4, A + 16 * lda + k, lda * sizeof(bfloat16)); | |
| _tile_dpbf16ps(1, 4, 6); | |
| _tile_stream_loadd(5, A + 32 * lda + k, lda * sizeof(bfloat16)); | |
| _tile_dpbf16ps(2, 5, 6); | |
| }; | |
| #pragma GCC unroll 4 | |
| for (int k = 0; k < last_k_offset; k += 32) { | |
| compute(k); | |
| } | |
| auto store_c = [&]() { | |
| // store to C | |
| _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float)); | |
| _tile_stored(1, C + 16 * ldc + 0, ldc * sizeof(float)); | |
| _tile_stored(2, C + 32 * ldc + 0, ldc * sizeof(float)); | |
| }; | |
| // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead | |
| if C10_UNLIKELY (tail_k_size > 0) { | |
| if C10_LIKELY (last_k_offset > 0) { | |
| store_c(); | |
| amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 48 / 16, 1, loadconfig); | |
| load_c(); | |
| } | |
| compute(last_k_offset); | |
| } | |
| store_c(); | |
| } | |
| template <bool accum> | |
| inline void kernel_micro_gemm_amx_kernel_32_1( | |
| AMXState& amx_state, | |
| const bfloat16* __restrict__ A, | |
| const bfloat16* __restrict__ B, | |
| float* __restrict__ C, | |
| int64_t K, | |
| int64_t lda, | |
| int64_t ldb, | |
| int64_t ldc, | |
| uint8_t tilecfg_rows | |
| ) { | |
| // TODO(jgong5): add prefetch hint for A, B, C | |
| auto loadconfig = [](const amx_tilecfg& cfg) { | |
| _tile_loadconfig(&cfg); | |
| }; | |
| const auto last_k_offset = K / 32 * 32; | |
| const auto tail_k_size = K - last_k_offset; | |
| if C10_LIKELY (last_k_offset > 0) { | |
| amx_state.configure(tilecfg_rows, 64, 32 / 16, 1, loadconfig); | |
| } else { | |
| amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 1, loadconfig); | |
| } | |
| auto load_c = [&]() { | |
| _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float)); | |
| _tile_loadd(1, C + 16 * ldc + 0, ldc * sizeof(float)); | |
| }; | |
| auto zero_c = [&]() { | |
| _tile_zero(0); | |
| _tile_zero(1); | |
| }; | |
| if constexpr (accum) { | |
| load_c(); | |
| } else { | |
| zero_c(); | |
| } | |
| auto compute = [&](int k) { | |
| _tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16)); | |
| _tile_loadd(4, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16)); | |
| _tile_dpbf16ps(0, 2, 4); | |
| _tile_stream_loadd(3, A + 16 * lda + k, lda * sizeof(bfloat16)); | |
| _tile_dpbf16ps(1, 3, 4); | |
| }; | |
| #pragma GCC unroll 4 | |
| for (int k = 0; k < last_k_offset; k += 32) { | |
| compute(k); | |
| } | |
| auto store_c = [&]() { | |
| // store to C | |
| _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float)); | |
| _tile_stored(1, C + 16 * ldc + 0, ldc * sizeof(float)); | |
| }; | |
| // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead | |
| if C10_UNLIKELY (tail_k_size > 0) { | |
| if C10_LIKELY (last_k_offset > 0) { | |
| store_c(); | |
| amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 1, loadconfig); | |
| load_c(); | |
| } | |
| compute(last_k_offset); | |
| } | |
| store_c(); | |
| } | |
| template <bool accum> | |
| inline void kernel_micro_gemm_amx_kernel_16_1( | |
| AMXState& amx_state, | |
| const bfloat16* __restrict__ A, | |
| const bfloat16* __restrict__ B, | |
| float* __restrict__ C, | |
| int64_t K, | |
| int64_t lda, | |
| int64_t ldb, | |
| int64_t ldc, | |
| uint8_t tilecfg_rows | |
| ) { | |
| // TODO(jgong5): add prefetch hint for A, B, C | |
| auto loadconfig = [](const amx_tilecfg& cfg) { | |
| _tile_loadconfig(&cfg); | |
| }; | |
| const auto last_k_offset = K / 32 * 32; | |
| const auto tail_k_size = K - last_k_offset; | |
| if C10_LIKELY (last_k_offset > 0) { | |
| amx_state.configure(tilecfg_rows, 64, 16 / 16, 1, loadconfig); | |
| } else { | |
| amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 1, loadconfig); | |
| } | |
| auto load_c = [&]() { | |
| _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float)); | |
| }; | |
| auto zero_c = [&]() { | |
| _tile_zero(0); | |
| }; | |
| if constexpr (accum) { | |
| load_c(); | |
| } else { | |
| zero_c(); | |
| } | |
| auto compute = [&](int k) { | |
| _tile_stream_loadd(1, A + 0 * lda + k, lda * sizeof(bfloat16)); | |
| _tile_loadd(2, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16)); | |
| _tile_dpbf16ps(0, 1, 2); | |
| }; | |
| #pragma GCC unroll 4 | |
| for (int k = 0; k < last_k_offset; k += 32) { | |
| compute(k); | |
| } | |
| auto store_c = [&]() { | |
| // store to C | |
| _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float)); | |
| }; | |
| // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead | |
| if C10_UNLIKELY (tail_k_size > 0) { | |
| if C10_LIKELY (last_k_offset > 0) { | |
| store_c(); | |
| amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 1, loadconfig); | |
| load_c(); | |
| } | |
| compute(last_k_offset); | |
| } | |
| store_c(); | |
| } | |
| template <bool accum> | |
| inline void kernel_micro_gemm( | |
| AMXState& amx_state, | |
| const bfloat16* __restrict__ A, | |
| const bfloat16* __restrict__ B, | |
| float* __restrict__ C, | |
| int64_t M, | |
| int64_t N, | |
| int64_t K, | |
| int64_t lda, | |
| int64_t ldb, | |
| int64_t ldc | |
| ) { | |
| TORCH_CHECK(N % 16 == 0, "N dimension must be multiple of 16"); | |
| TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2"); | |
| // The ldb would not be block_n if N != block_n | |
| const int64_t updated_ldb = ldb; | |
| // TODO(jgong5): loop unroll for M and N | |
| for (int64_t n = 0; n < N; n += 16) { | |
| for (int64_t m = 0; m < M; m += 48) { | |
| int64_t block_m = std::min<int64_t>(M - m, 48); | |
| int64_t m_tail = m; | |
| if (block_m >= 48) { | |
| kernel_micro_gemm_amx_kernel_48_1<accum>( | |
| amx_state, | |
| A + m * lda, | |
| B + n, | |
| C + m * ldc + n, | |
| K, | |
| lda, | |
| updated_ldb, | |
| ldc, | |
| 16 | |
| ); | |
| block_m -= 48; | |
| m_tail += 48; | |
| } | |
| else | |
| if (block_m >= 32) { | |
| kernel_micro_gemm_amx_kernel_32_1<accum>( | |
| amx_state, | |
| A + m * lda, | |
| B + n, | |
| C + m * ldc + n, | |
| K, | |
| lda, | |
| updated_ldb, | |
| ldc, | |
| 16 | |
| ); | |
| block_m -= 32; | |
| m_tail += 32; | |
| } | |
| else | |
| if (block_m >= 16) { | |
| kernel_micro_gemm_amx_kernel_16_1<accum>( | |
| amx_state, | |
| A + m * lda, | |
| B + n, | |
| C + m * ldc + n, | |
| K, | |
| lda, | |
| updated_ldb, | |
| ldc, | |
| 16 | |
| ); | |
| block_m -= 16; | |
| m_tail += 16; | |
| } | |
| if (block_m > 0) { | |
| kernel_micro_gemm_amx_kernel_16_1<accum>( | |
| amx_state, | |
| A + m_tail * lda, | |
| B + n, | |
| C + m_tail * ldc + n, | |
| K, | |
| lda, | |
| updated_ldb, | |
| ldc, | |
| block_m | |
| ); | |
| } | |
| } | |
| } | |
| } | |
| extern "C" | |
| void kernel(const bfloat16* X0, const bfloat16* W0, const bfloat16* W1, bfloat16* Y0, bfloat16* Y1) | |
| { | |
| constexpr int64_t num_threads = 56; | |
| constexpr int64_t N = 1024; | |
| constexpr int64_t K = 512; | |
| constexpr int64_t Mr = 48; | |
| constexpr int64_t Nr = 16; | |
| constexpr int64_t Kr = 32; | |
| constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; | |
| constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; | |
| constexpr int64_t M = static_cast<int64_t>(4L); | |
| constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; | |
| constexpr int64_t Mt_blocks = 1; | |
| constexpr int64_t Nt_blocks = 2; | |
| constexpr int64_t Kt_blocks = 16; | |
| constexpr int64_t Mc_blocks = 1; | |
| constexpr int64_t Nc_blocks = 1; | |
| constexpr int64_t Kc_blocks = 16; | |
| constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; | |
| constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; | |
| constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; | |
| constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; | |
| constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; | |
| // make sure all partitions are assigned | |
| TORCH_CHECK( | |
| Mt_blocks * Nt_blocks * Kt_blocks * 56 >= Mr_blocks * Nr_blocks * Kr_blocks, | |
| "Not all partitions are assigned." | |
| ); | |
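| // With these constants: Mr_blocks = 1, Nr_blocks = 64, Kr_blocks = 16, num_Kt_blocks = 1, num_Nt_blocks = 32, | |
| // so threads 0-31 each take one N slice of Nt_blocks = 2 (32 output columns) covering all 4 rows of M, | |
| // while threads 32-55 see m_block_start == m_block_end and do no work. | |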
| #pragma omp parallel num_threads(56) | |
| { | |
| const int tid = omp_get_thread_num(); | |
| const int64_t k_group_id = tid / num_Kt_blocks; | |
| const int64_t k_slice_id = tid % num_Kt_blocks; | |
| const int64_t n_group_id = k_group_id / num_Nt_blocks; | |
| const int64_t n_slice_id = k_group_id % num_Nt_blocks; | |
| const int64_t k_block_start = k_slice_id * Kt_blocks; | |
| const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks); | |
| const int64_t n_block_start = n_slice_id * Nt_blocks; | |
| const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks); | |
| const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks); | |
| const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks); | |
| const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks; | |
| AMXState amx_state; | |
| auto _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_0 = _local_acc_buf_0.get(); | |
| auto _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_1 = _local_acc_buf_1.get(); | |
| for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { | |
| const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; | |
| const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; | |
| const int64_t m_start = mc * Mr; | |
| const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); | |
| const int64_t m_size = m_end - m_start; | |
| for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { | |
| const int64_t n_start = nc * Nr; | |
| const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); | |
| const int64_t n_size = n_end - n_start; | |
| // NB: assume we pad N, nc_block_end won't exceed padded N here. | |
| const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); | |
| if (_local_acc_buf_0 == nullptr) { _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_0 = _local_acc_buf_0.get(); } | |
| if (_local_acc_buf_1 == nullptr) { _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_1 = _local_acc_buf_1.get(); } | |
| for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { | |
| int64_t k_start = kc * Kr; | |
| int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); | |
| for (int64_t nci = nc; nci < nc_block_end; nci++) { | |
| if (kc == k_block_start) { | |
| kernel_micro_gemm<static_cast<bool>(false)>( | |
| amx_state, | |
| &(X0[static_cast<int64_t>(k_start + 512L*m_start)]), | |
| &(W0[static_cast<int64_t>(16L*k_start + 8192L*nci)]), | |
| &(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]), | |
| static_cast<int64_t>(m_end + ((-1L)*m_start)), | |
| static_cast<int64_t>(Nr), | |
| static_cast<int64_t>(k_end + ((-1L)*k_start)), | |
| static_cast<int64_t>(512L), | |
| static_cast<int64_t>(16L), | |
| static_cast<int64_t>(Nc_blocks*Nr) | |
| ); | |
| kernel_micro_gemm<static_cast<bool>(false)>( | |
| amx_state, | |
| &(X0[static_cast<int64_t>(k_start + 512L*m_start)]), | |
| &(W1[static_cast<int64_t>(16L*k_start + 8192L*nci)]), | |
| &(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]), | |
| static_cast<int64_t>(m_end + ((-1L)*m_start)), | |
| static_cast<int64_t>(Nr), | |
| static_cast<int64_t>(k_end + ((-1L)*k_start)), | |
| static_cast<int64_t>(512L), | |
| static_cast<int64_t>(16L), | |
| static_cast<int64_t>(Nc_blocks*Nr) | |
| ); | |
| } else { | |
| kernel_micro_gemm<static_cast<bool>(true)>( | |
| amx_state, | |
| &(X0[static_cast<int64_t>(k_start + 512L*m_start)]), | |
| &(W0[static_cast<int64_t>(16L*k_start + 8192L*nci)]), | |
| &(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]), | |
| static_cast<int64_t>(m_end + ((-1L)*m_start)), | |
| static_cast<int64_t>(Nr), | |
| static_cast<int64_t>(k_end + ((-1L)*k_start)), | |
| static_cast<int64_t>(512L), | |
| static_cast<int64_t>(16L), | |
| static_cast<int64_t>(Nc_blocks*Nr) | |
| ); | |
| kernel_micro_gemm<static_cast<bool>(true)>( | |
| amx_state, | |
| &(X0[static_cast<int64_t>(k_start + 512L*m_start)]), | |
| &(W1[static_cast<int64_t>(16L*k_start + 8192L*nci)]), | |
| &(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]), | |
| static_cast<int64_t>(m_end + ((-1L)*m_start)), | |
| static_cast<int64_t>(Nr), | |
| static_cast<int64_t>(k_end + ((-1L)*k_start)), | |
| static_cast<int64_t>(512L), | |
| static_cast<int64_t>(16L), | |
| static_cast<int64_t>(Nc_blocks*Nr) | |
| ); | |
| } | |
| } | |
| } | |
| { | |
| { | |
| #pragma GCC ivdep | |
| for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L)) | |
| { | |
| for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=static_cast<int64_t>(16L)) | |
| { | |
| { | |
| if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))) | |
| { | |
| auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16)); | |
| auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16)); | |
| auto tmp1 = at::vec::convert<bfloat16>(tmp0); | |
| auto tmp3 = at::vec::convert<bfloat16>(tmp2); | |
| tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(16)); | |
| tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(16)); | |
| } | |
| if(C10_UNLIKELY(x1 >= static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))) && x1 < static_cast<int64_t>(n_end + ((-1L)*n_start)))) | |
| { | |
| auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))); | |
| auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))); | |
| auto tmp1 = at::vec::convert<bfloat16>(tmp0); | |
| auto tmp3 = at::vec::convert<bfloat16>(tmp2); | |
| tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))); | |
| tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))))); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| amx_state.release([]() { _tile_release(); }); | |
| } | |
| } | |
| ''') | |
| async_compile.wait(globals()) | |
| del async_compile | |
| def call(args): | |
| arg2_1, = args | |
| args.clear() | |
| assert_size_stride(arg2_1, (2, 2, 512), (1024, 512, 1)) | |
| buf1 = empty_strided_cpu((4, 1024), (1024, 1), torch.bfloat16) | |
| buf2 = empty_strided_cpu((4, 1024), (1024, 1), torch.bfloat16) | |
| cpp_fused_0(arg2_1, constant2, constant3, buf1, buf2) | |
| del arg2_1 | |
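| # The kernel writes (4, 1024) contiguous buffers; reinterpret_tensor views them back as (2, 2, 1024) to | |
| # match the leading (2, 2) dims of the (2, 2, 512) input arg2_1. | |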
| return (reinterpret_tensor(buf1, (2, 2, 1024), (2048, 1024, 1), 0), reinterpret_tensor(buf2, (2, 2, 1024), (2048, 1024, 1), 0), ) | |
| def benchmark_compiled_module(times=10, repeat=10): | |
| from torch._dynamo.testing import rand_strided | |
| from torch._inductor.utils import print_performance | |
| global _frozen_param4 | |
| _frozen_param4 = rand_strided((1024, 512), (1, 0), device='cpu', dtype=torch.bfloat16) | |
| global _frozen_param5 | |
| _frozen_param5 = rand_strided((1024, 512), (1, 0), device='cpu', dtype=torch.bfloat16) | |
| global constant2 | |
| constant2 = rand_strided((64, 512, 16), (8192, 16, 1), device='cpu', dtype=torch.bfloat16) | |
| global constant3 | |
| constant3 = rand_strided((64, 512, 16), (8192, 16, 1), device='cpu', dtype=torch.bfloat16) | |
| arg2_1 = rand_strided((2, 2, 512), (1024, 512, 1), device='cpu', dtype=torch.bfloat16) | |
| fn = lambda: call([arg2_1]) | |
| return print_performance(fn, times=times, repeat=repeat) | |
| if __name__ == "__main__": | |
| from torch._inductor.wrapper_benchmark import compiled_module_main | |
| compiled_module_main('None', benchmark_compiled_module) |