| #include "/tmp/torchinductor_leslie/3b/c3bi5gk6mslf6u4iaqafhxm64z6u65e3eain4xlary5blqnvv6xx.h" | |
| #include <c10/util/Unroll.h> | |
| #include <torch/csrc/inductor/aoti_torch/c/shim.h> | |
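// Register-blocked FP32 micro-GEMM: computes a BLOCK_M x BLOCK_N tile of C (+)= A @ B,
// holding the tile in BLOCK_M * (BLOCK_N / VLEN) vector accumulators (vc) and issuing
// one fmadd per (row, column-vector) pair for each k. With AVX-512 (VLEN = 16, which
// the 16-lane epilogue loads further below suggest), the 8x32 tile uses 16 accumulator
// vectors plus one broadcast of A (va) and two vectors of B (vb) per k step.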
template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
inline void kernel_micro_gemm_kernel(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    using Vectorized = at::vec::Vectorized<float>;
    using VectorizedIn = at::vec::Vectorized<float>;
    constexpr auto VLEN = Vectorized::size();
    constexpr auto ROWS = BLOCK_M;
    constexpr auto COLS = BLOCK_N / VLEN;
    Vectorized va;
    at::vec::VectorizedN<float, COLS> vb;
    at::vec::VectorizedN<float, ROWS*COLS> vc;
    auto loadc = [&](auto i) {
        if constexpr (accum) {
            constexpr int row = i / COLS;
            constexpr int col = i % COLS;
            vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
        } else {
            vc[i] = Vectorized(0.0f);
        }
    };
    c10::ForcedUnroll<ROWS * COLS>{}(loadc);
    auto compute = [&, COLS](auto i, int k) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;
        if constexpr (col == 0) {
            va = Vectorized(static_cast<float>(A[row * lda + k]));
        }
        if constexpr (row == 0) {
            vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN);
        }
        constexpr int idx = row * COLS + col;
        vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
    };
    for (int k = 0; k < K; ++k) {
        c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
    }
    // store to C
    auto storec = [&](auto i) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;
        vc[i].store(C + row * ldc + col * VLEN);
    };
    c10::ForcedUnroll<ROWS * COLS>{}(storec);
}
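// M-remainder dispatcher: full 8-row tiles go straight to the 8x32 micro-kernel, and a
// trailing block of 1-7 rows is routed through the switch below so that BLOCK_M stays a
// compile-time constant and the register tile is fully unrolled in every case.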
template <bool accum>
inline void kernel_micro_gemm(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32");
    TORCH_CHECK(K % 1 == 0, "K dimension must be multiple of 1");
    // TODO(jgong5): loop unroll for M and N
    for (int64_t m = 0; m < M; m += 8) {
        int64_t block_m = std::min<int64_t>(M - m, 8);
        for (int64_t n = 0; n < N; n += 32) {
            if (block_m == 8) {
                kernel_micro_gemm_kernel<8, 32, accum>(
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc
                );
            } else {
                switch (block_m) {
                    case 7:
                        kernel_micro_gemm_kernel<7, 32, accum>(
                            A + m * lda,
                            B + n,
                            C + m * ldc + n,
                            K,
                            lda,
                            ldb,
                            ldc
                        );
                        break;
                    case 6:
                        kernel_micro_gemm_kernel<6, 32, accum>(
                            A + m * lda,
                            B + n,
                            C + m * ldc + n,
                            K,
                            lda,
                            ldb,
                            ldc
                        );
                        break;
                    case 5:
                        kernel_micro_gemm_kernel<5, 32, accum>(
                            A + m * lda,
                            B + n,
                            C + m * ldc + n,
                            K,
                            lda,
                            ldb,
                            ldc
                        );
                        break;
                    case 4:
                        kernel_micro_gemm_kernel<4, 32, accum>(
                            A + m * lda,
                            B + n,
                            C + m * ldc + n,
                            K,
                            lda,
                            ldb,
                            ldc
                        );
                        break;
                    case 3:
                        kernel_micro_gemm_kernel<3, 32, accum>(
                            A + m * lda,
                            B + n,
                            C + m * ldc + n,
                            K,
                            lda,
                            ldb,
                            ldc
                        );
                        break;
                    case 2:
                        kernel_micro_gemm_kernel<2, 32, accum>(
                            A + m * lda,
                            B + n,
                            C + m * ldc + n,
                            K,
                            lda,
                            ldb,
                            ldc
                        );
                        break;
                    case 1:
                        kernel_micro_gemm_kernel<1, 32, accum>(
                            A + m * lda,
                            B + n,
                            C + m * ldc + n,
                            K,
                            lda,
                            ldb,
                            ldc
                        );
                        break;
                    default:
                        TORCH_CHECK(false, "Unsupported block_m: ", block_m);
                }
            }
        }
    }
}
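// Generated entry point for one GEMM plus fused epilogue:
//   Y[8, 128] = to_bfloat16((X[8, 64] @ W) * in_ptr2[128] + in_ptr3[128])
// where W is pre-packed into [N/Nr, K, Nr] = [4, 64, 32] blocks, as the 2048L*nci
// (= K*Nr) offsets and ldb = 32 used below imply. Mr/Nr/Kr are the register-tile
// sizes, Mt/Nt/Kt the per-thread tile sizes and Mc/Nc/Kc the cache-tile sizes, all
// counted in register blocks.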
| extern "C" | |
| void kernel(const float* X, const float* W, const float* in_ptr2, const float* in_ptr3, bfloat16* Y) | |
| { | |
| constexpr int64_t num_threads = 56; | |
| constexpr int64_t N = 128; | |
| constexpr int64_t K = 64; | |
| constexpr int64_t Mr = 8; | |
| constexpr int64_t Nr = 32; | |
| constexpr int64_t Kr = 1; | |
| constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; | |
| constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; | |
| constexpr int64_t M = static_cast<int64_t>(8L); | |
| constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; | |
| constexpr int64_t Mt_blocks = 1; | |
| constexpr int64_t Nt_blocks = 1; | |
| constexpr int64_t Kt_blocks = 64; | |
| constexpr int64_t Mc_blocks = 1; | |
| constexpr int64_t Nc_blocks = 1; | |
| constexpr int64_t Kc_blocks = 64; | |
| constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; | |
| constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; | |
| constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; | |
| constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; | |
| constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; | |
| // make sure all partitions are assigned | |
| TORCH_CHECK( | |
| Mt_blocks * Nt_blocks * Kt_blocks * 56 >= Mr_blocks * Nr_blocks * Kr_blocks, | |
| "Not all partitions are assigned." | |
| ); | |
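    // Decompose each thread id into a (m, n, k) slice of the register-block grid:
    // the K slice is innermost, then N, then M. With Mt_blocks = Nt_blocks = 1 and
    // Kt_blocks = 64 >= Kr_blocks, every active thread covers the whole K range, and
    // for this 8x128x64 problem only Mr_blocks * Nr_blocks = 4 of the 56 threads end
    // up with a non-empty slice; the rest see empty m/n ranges and fall through.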
    #pragma omp parallel num_threads(56)
    {
        const int tid = omp_get_thread_num();
        const int64_t k_group_id = tid / num_Kt_blocks;
        const int64_t k_slice_id = tid % num_Kt_blocks;
        const int64_t n_group_id = k_group_id / num_Nt_blocks;
        const int64_t n_slice_id = k_group_id % num_Nt_blocks;
        const int64_t k_block_start = k_slice_id * Kt_blocks;
        const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
        const int64_t n_block_start = n_slice_id * Nt_blocks;
        const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
        const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
        const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
        const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
            const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
            const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
            const int64_t m_start = mc * Mr;
            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
            const int64_t m_size = m_end - m_start;
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                const int64_t n_start = nc * Nr;
                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                const int64_t n_size = n_end - n_start;
                // NB: assume we pad N, nc_block_end won't exceed padded N here.
                const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * Kr;
                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
                        if (kc == k_block_start) {
                            kernel_micro_gemm<static_cast<bool>(false)>(
                                &(X[static_cast<int64_t>(k_start + 64L*m_start)]),
                                &(W[static_cast<int64_t>(32L*k_start + 2048L*nci)]),
                                &(Y[static_cast<int64_t>(n_start + 128L*m_start + Nr*nci + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(64L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(128L)
                            );
                        } else {
                            kernel_micro_gemm<static_cast<bool>(true)>(
                                &(X[static_cast<int64_t>(k_start + 64L*m_start)]),
                                &(W[static_cast<int64_t>(32L*k_start + 2048L*nci)]),
                                &(Y[static_cast<int64_t>(n_start + 128L*m_start + Nr*nci + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(64L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(128L)
                            );
                        }
                    }
                }
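                // Fused epilogue for this (m, n) tile: take the fp32 GEMM result staged in Y,
                // multiply element-wise by in_ptr2 and add in_ptr3 (both indexed along N only),
                // convert to bfloat16 and store back into Y in place, 16 lanes at a time with a
                // masked tail for any remainder of the N extent.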
                {
                    {
                        #pragma GCC ivdep
                        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=static_cast<int64_t>(16L))
                            {
                                {
                                    if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))))
                                    {
                                        auto tmp0 = at::vec::Vectorized<float>::loadu(Y + static_cast<int64_t>(n_start + x1 + 128L*m_start + 128L*x0), static_cast<int64_t>(16));
                                        auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(n_start + x1), static_cast<int64_t>(16));
                                        auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(n_start + x1), static_cast<int64_t>(16));
                                        auto tmp2 = tmp0 * tmp1;
                                        auto tmp4 = tmp2 + tmp3;
                                        auto tmp5 = at::vec::convert<bfloat16>(tmp4);
                                        tmp5.store(Y + static_cast<int64_t>(n_start + x1 + 128L*m_start + 128L*x0), static_cast<int64_t>(16));
                                    }
                                    if(C10_UNLIKELY(x1 >= static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))) && x1 < static_cast<int64_t>(n_end + ((-1L)*n_start))))
                                    {
                                        auto tmp0 = at::vec::Vectorized<float>::loadu(Y + static_cast<int64_t>(n_start + x1 + 128L*m_start + 128L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
                                        auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(n_start + x1), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
                                        auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(n_start + x1), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
                                        auto tmp2 = tmp0 * tmp1;
                                        auto tmp4 = tmp2 + tmp3;
                                        auto tmp5 = at::vec::convert<bfloat16>(tmp4);
                                        tmp5.store(Y + static_cast<int64_t>(n_start + x1 + 128L*m_start + 128L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
// Python bindings to call kernel():
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <sstream>
#include <cstdlib>
#ifndef _MSC_VER
#if __cplusplus < 202002L
// Pre-C++20: use __builtin_expect in place of the C++20 [[likely]]/[[unlikely]] attributes
// https://en.cppreference.com/w/cpp/language/attributes/likely
#define likely(x) __builtin_expect(!!(x), 1)
#define unlikely(x) __builtin_expect(!!(x), 0)
#endif
#else
#define likely(x) (x)
#define unlikely(x) (x)
#endif
// This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow.
// We manually link it below to work around issues with the fbcode build.
static void* (*_torchinductor_pyobject_tensor_data_ptr)(PyObject* obj);
template <typename T> static inline T parse_arg(PyObject* args, size_t n) {
    static_assert(std::is_pointer_v<T>, "arg type must be pointer or long");
    return static_cast<T>(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n)));
}
template <> inline int64_t parse_arg<int64_t>(PyObject* args, size_t n) {
    auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n));
    if(unlikely(result == -1 && PyErr_Occurred()))
        throw std::runtime_error("expected int arg");
    return result;
}
template <> inline uintptr_t parse_arg<uintptr_t>(PyObject* args, size_t n) {
    auto result = PyLong_AsVoidPtr(PyTuple_GET_ITEM(args, n));
    if(unlikely(result == reinterpret_cast<void*>(-1) && PyErr_Occurred()))
        throw std::runtime_error("expected int arg");
    return reinterpret_cast<uintptr_t>(result);
}
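// Python-facing wrapper: expects a 5-tuple (X, W, in_ptr2, in_ptr3, Y), converts each
// element to a raw data pointer through the hook installed in PyInit_kernel(), and
// forwards them to kernel().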
static PyObject* kernel_py(PyObject* self, PyObject* args) {
    try {
        if(unlikely(!PyTuple_CheckExact(args)))
            throw std::runtime_error("tuple args required");
        if(unlikely(PyTuple_GET_SIZE(args) != 5))
            throw std::runtime_error("requires 5 args");
        kernel(parse_arg<float*>(args, 0), parse_arg<float*>(args, 1), parse_arg<float*>(args, 2), parse_arg<float*>(args, 3), parse_arg<bfloat16*>(args, 4));
        Py_RETURN_NONE;
    } catch(std::exception const& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    } catch(...) {
        PyErr_SetString(PyExc_RuntimeError, "unhandled error");
        return nullptr;
    }
}
static PyMethodDef py_methods[] = {
    {"kernel", kernel_py, METH_VARARGS, ""},
    {NULL, NULL, 0, NULL}};
static struct PyModuleDef py_module =
    {PyModuleDef_HEAD_INIT, "kernel", NULL, -1, py_methods};
PyMODINIT_FUNC PyInit_kernel(void) {
    const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR");
    if(!str_addr) {
        PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set");
        return nullptr;
    }
    std::istringstream iss(str_addr);
    uintptr_t addr = 0;
    iss >> addr;
    _torchinductor_pyobject_tensor_data_ptr =
        reinterpret_cast<decltype(_torchinductor_pyobject_tensor_data_ptr)>(addr);
    PyObject* module = PyModule_Create(&py_module);
    if (module == NULL) {
        return NULL;
    }
#ifdef Py_GIL_DISABLED
    PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED);
#endif
    return module;
}
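
// ---------------------------------------------------------------------------------
// Illustrative scalar reference (not part of the generated file): a minimal sketch of
// what kernel() computes, assuming W is packed as [N/Nr][K][Nr] with Nr = 32 (as the
// 2048L*nci / 32L*k_start indexing above suggests) and the epilogue is a per-column
// multiply by in_ptr2 followed by an add of in_ptr3. The names ref_kernel, scale and
// bias are hypothetical, and the final bfloat16 down-cast is omitted so the result
// stays in fp32.
// ---------------------------------------------------------------------------------
#include <cstdint>
#include <cstdio>
#include <vector>

static void ref_kernel(const float* X, const float* W_packed,
                       const float* scale, const float* bias,
                       float* Y, int64_t M, int64_t N, int64_t K) {
    constexpr int64_t Nr = 32;  // N register-block width used by the template above
    for (int64_t m = 0; m < M; ++m) {
        for (int64_t n = 0; n < N; ++n) {
            float acc = 0.0f;
            for (int64_t k = 0; k < K; ++k) {
                // W_packed layout: [n / Nr][k][n % Nr]
                acc += X[m * K + k] * W_packed[(n / Nr) * K * Nr + k * Nr + (n % Nr)];
            }
            Y[m * N + n] = acc * scale[n] + bias[n];
        }
    }
}

int main() {
    constexpr int64_t M = 8, N = 128, K = 64;
    std::vector<float> X(M * K, 1.0f), W(N * K, 1.0f), s(N, 2.0f), b(N, 0.5f), Y(M * N);
    ref_kernel(X.data(), W.data(), s.data(), b.data(), Y.data(), M, N, K);
    std::printf("Y[0][0] = %f\n", Y[0]);  // all-ones inputs: 64 * 2 + 0.5 = 128.5
    return 0;
}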