@leslie-fang-intel
Created February 12, 2025 02:33
#include "/tmp/torchinductor_leslie/3b/c3bi5gk6mslf6u4iaqafhxm64z6u65e3eain4xlary5blqnvv6xx.h"
#include <c10/util/Unroll.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
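// Register-blocked micro-GEMM kernel: computes a BLOCK_M x BLOCK_N tile of C (+)= A * B,
// keeping the tile in at::vec::Vectorized registers; `accum` selects whether the tile is
// loaded from C or zero-initialized before the K reduction.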
template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
inline void kernel_micro_gemm_kernel(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc
) {
using Vectorized = at::vec::Vectorized<float>;
using VectorizedIn = at::vec::Vectorized<float>;
constexpr auto VLEN = Vectorized::size();
constexpr auto ROWS = BLOCK_M;
constexpr auto COLS = BLOCK_N / VLEN;
Vectorized va;
at::vec::VectorizedN<float, COLS> vb;
at::vec::VectorizedN<float, ROWS*COLS> vc;
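// vc is the ROWS x COLS accumulator tile; loadc below either loads it from C (accum) or zeroes it.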
auto loadc = [&](auto i) {
if constexpr (accum) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
} else {
vc[i] = Vectorized(0.0f);
}
};
c10::ForcedUnroll<ROWS * COLS>{}(loadc);
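// One K step: broadcast A[row][k] into va, load one vector of B[k] per column into vb, FMA into vc.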
auto compute = [&, COLS](auto i, int k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = Vectorized(static_cast<float>(A[row * lda + k]));
}
if constexpr (row == 0) {
vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN);
}
constexpr int idx = row * COLS + col;
vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
};
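// K reduction; the tile update is fully unrolled across the ROWS * COLS accumulators each iteration.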
for (int k = 0; k < K; ++k) {
c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
}
// store to C
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
vc[i].store(C + row * ldc + col * VLEN);
};
c10::ForcedUnroll<ROWS * COLS>{}(storec);
}
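// Wrapper over the micro-kernel: tiles M in steps of 8 and N in steps of 32; the switch
// dispatches the M tail (block_m < 8) to the matching template instantiation.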
template <bool accum>
inline void kernel_micro_gemm(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc
) {
TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32");
TORCH_CHECK(K % 1 == 0, "K dimension must be multiple of 1");
// TODO(jgong5): loop unroll for M and N
for (int64_t m = 0; m < M; m += 8) {
int64_t block_m = std::min<int64_t>(M - m, 8);
for (int64_t n = 0; n < N; n += 32) {
if (block_m == 8) {
kernel_micro_gemm_kernel<8, 32, accum>(
A + m * lda,
B + n,
C + m * ldc + n,
K,
lda,
ldb,
ldc
);
} else {
switch (block_m) {
case 7:
kernel_micro_gemm_kernel<7, 32, accum>(
A + m * lda,
B + n,
C + m * ldc + n,
K,
lda,
ldb,
ldc
);
break;
case 6:
kernel_micro_gemm_kernel<6, 32, accum>(
A + m * lda,
B + n,
C + m * ldc + n,
K,
lda,
ldb,
ldc
);
break;
case 5:
kernel_micro_gemm_kernel<5, 32, accum>(
A + m * lda,
B + n,
C + m * ldc + n,
K,
lda,
ldb,
ldc
);
break;
case 4:
kernel_micro_gemm_kernel<4, 32, accum>(
A + m * lda,
B + n,
C + m * ldc + n,
K,
lda,
ldb,
ldc
);
break;
case 3:
kernel_micro_gemm_kernel<3, 32, accum>(
A + m * lda,
B + n,
C + m * ldc + n,
K,
lda,
ldb,
ldc
);
break;
case 2:
kernel_micro_gemm_kernel<2, 32, accum>(
A + m * lda,
B + n,
C + m * ldc + n,
K,
lda,
ldb,
ldc
);
break;
case 1:
kernel_micro_gemm_kernel<1, 32, accum>(
A + m * lda,
B + n,
C + m * ldc + n,
K,
lda,
ldb,
ldc
);
break;
default:
TORCH_CHECK(false, "Unsupported block_m: 8");
}
}
}
}
}
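// Generated GEMM entry point: Y[8, 128] = bfloat16((X[8, 64] @ packed W) * in_ptr2 + in_ptr3),
// parallelized over 56 threads with thread-level (Mt/Nt/Kt) and cache-level (Mc/Nc/Kc) blocking
// on top of the Mr x Nr = 8 x 32 register tile.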
extern "C"
void kernel(const float* X, const float* W, const float* in_ptr2, const float* in_ptr3, bfloat16* Y)
{
constexpr int64_t num_threads = 56;
constexpr int64_t N = 128;
constexpr int64_t K = 64;
constexpr int64_t Mr = 8;
constexpr int64_t Nr = 32;
constexpr int64_t Kr = 1;
constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
constexpr int64_t M = static_cast<int64_t>(8L);
constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
constexpr int64_t Mt_blocks = 1;
constexpr int64_t Nt_blocks = 1;
constexpr int64_t Kt_blocks = 64;
constexpr int64_t Mc_blocks = 1;
constexpr int64_t Nc_blocks = 1;
constexpr int64_t Kc_blocks = 64;
constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
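// Mr/Nr/Kr are register-tile sizes; Mt/Nt/Kt split blocks across threads, Mc/Nc/Kc across cache levels.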
// make sure all partitions are assigned
TORCH_CHECK(
Mt_blocks * Nt_blocks * Kt_blocks * 56 >= Mr_blocks * Nr_blocks * Kr_blocks,
"Not all partitions are assigned."
);
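// Each thread decodes its id into (M group, N slice, K slice) coordinates and owns the corresponding block ranges.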
#pragma omp parallel num_threads(56)
{
const int tid = omp_get_thread_num();
const int64_t k_group_id = tid / num_Kt_blocks;
const int64_t k_slice_id = tid % num_Kt_blocks;
const int64_t n_group_id = k_group_id / num_Nt_blocks;
const int64_t n_slice_id = k_group_id % num_Nt_blocks;
const int64_t k_block_start = k_slice_id * Kt_blocks;
const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
const int64_t n_block_start = n_slice_id * Nt_blocks;
const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
const int64_t m_start = mc * Mr;
const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
const int64_t m_size = m_end - m_start;
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
const int64_t n_start = nc * Nr;
const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
const int64_t n_size = n_end - n_start;
// NB: assume we pad N, nc_block_end won't exceed padded N here.
const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
int64_t k_start = kc * Kr;
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
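// The first K block initializes the output tile (accum = false); subsequent K blocks accumulate into it.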
for (int64_t nci = nc; nci < nc_block_end; nci++) {
if (kc == k_block_start) {
kernel_micro_gemm<static_cast<bool>(false)>(
&(X[static_cast<int64_t>(k_start + 64L*m_start)]),
&(W[static_cast<int64_t>(32L*k_start + 2048L*nci)]),
&(Y[static_cast<int64_t>(n_start + 128L*m_start + Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(64L),
static_cast<int64_t>(32L),
static_cast<int64_t>(128L)
);
} else {
kernel_micro_gemm<static_cast<bool>(true)>(
&(X[static_cast<int64_t>(k_start + 64L*m_start)]),
&(W[static_cast<int64_t>(32L*k_start + 2048L*nci)]),
&(Y[static_cast<int64_t>(n_start + 128L*m_start + Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(64L),
static_cast<int64_t>(32L),
static_cast<int64_t>(128L)
);
}
}
}
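// Fused epilogue over the [m, n] tile just computed: multiply by in_ptr2, add in_ptr3,
// convert to bfloat16, and store back to Y (with a partial-vector path for the tail).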
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=static_cast<int64_t>(16L))
{
{
if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(Y + static_cast<int64_t>(n_start + x1 + 128L*m_start + 128L*x0), static_cast<int64_t>(16));
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(n_start + x1), static_cast<int64_t>(16));
auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(n_start + x1), static_cast<int64_t>(16));
auto tmp2 = tmp0 * tmp1;
auto tmp4 = tmp2 + tmp3;
auto tmp5 = at::vec::convert<bfloat16>(tmp4);
tmp5.store(Y + static_cast<int64_t>(n_start + x1 + 128L*m_start + 128L*x0), static_cast<int64_t>(16));
}
if(C10_UNLIKELY(x1 >= static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))) && x1 < static_cast<int64_t>(n_end + ((-1L)*n_start))))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(Y + static_cast<int64_t>(n_start + x1 + 128L*m_start + 128L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(n_start + x1), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
auto tmp3 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<int64_t>(n_start + x1), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
auto tmp2 = tmp0 * tmp1;
auto tmp4 = tmp2 + tmp3;
auto tmp5 = at::vec::convert<bfloat16>(tmp4);
tmp5.store(Y + static_cast<int64_t>(n_start + x1 + 128L*m_start + 128L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
}
}
}
}
}
}
}
}
}
}
// Python bindings to call kernel():
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <sstream>
#include <cstdlib>
#ifndef _MSC_VER
#if __cplusplus < 202002L
// Pre-C++20: emulate [[likely]]/[[unlikely]] with __builtin_expect
// https://en.cppreference.com/w/cpp/language/attributes/likely
#define likely(x) __builtin_expect(!!(x), 1)
#define unlikely(x) __builtin_expect(!!(x), 0)
#endif
#else
#define likely(x) (x)
#define unlikely(x) (x)
#endif
// This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow.
// We manually link it below to workaround issues with fbcode build.
static void* (*_torchinductor_pyobject_tensor_data_ptr)(PyObject* obj);
template <typename T> static inline T parse_arg(PyObject* args, size_t n) {
static_assert(std::is_pointer_v<T>, "arg type must be pointer or long");
return static_cast<T>(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n)));
}
template <> inline int64_t parse_arg<int64_t>(PyObject* args, size_t n) {
auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n));
if(unlikely(result == -1 && PyErr_Occurred()))
throw std::runtime_error("expected int arg");
return result;
}
template <> inline uintptr_t parse_arg<uintptr_t>(PyObject* args, size_t n) {
auto result = PyLong_AsVoidPtr(PyTuple_GET_ITEM(args, n));
if(unlikely(result == reinterpret_cast<void*>(-1) && PyErr_Occurred()))
throw std::runtime_error("expected int arg");
return reinterpret_cast<uintptr_t>(result);
}
static PyObject* kernel_py(PyObject* self, PyObject* args) {
try {
if(unlikely(!PyTuple_CheckExact(args)))
throw std::runtime_error("tuple args required");
if(unlikely(PyTuple_GET_SIZE(args) != 5))
throw std::runtime_error("requires 5 args");
kernel(parse_arg<float*>(args, 0), parse_arg<float*>(args, 1), parse_arg<float*>(args, 2), parse_arg<float*>(args, 3), parse_arg<bfloat16*>(args, 4));
Py_RETURN_NONE;
} catch(std::exception const& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return nullptr;
} catch(...) {
PyErr_SetString(PyExc_RuntimeError, "unhandled error");
return nullptr;
}
}
static PyMethodDef py_methods[] = {
{"kernel", kernel_py, METH_VARARGS, ""},
{NULL, NULL, 0, NULL}};
static struct PyModuleDef py_module =
{PyModuleDef_HEAD_INIT, "kernel", NULL, -1, py_methods};
PyMODINIT_FUNC PyInit_kernel(void) {
const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR");
if(!str_addr) {
PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set");
return nullptr;
}
std::istringstream iss(str_addr);
uintptr_t addr = 0;
iss >> addr;
_torchinductor_pyobject_tensor_data_ptr =
reinterpret_cast<decltype(_torchinductor_pyobject_tensor_data_ptr)>(addr);
PyObject* module = PyModule_Create(&py_module);
if (module == NULL) {
return NULL;
}
#ifdef Py_GIL_DISABLED
PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED);
#endif
return module;
}
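// Hypothetical standalone driver (not part of the generated output), sketching the buffer
// shapes and calling convention that kernel() appears to assume, inferred from the indexing
// above: X is [8, 64] fp32, W is pre-packed [4, 64, 32] fp32, in_ptr2/in_ptr3 are [128] fp32,
// and Y is [8, 128] bfloat16. It relies on the generated header for the bfloat16 alias and is
// wrapped in #if 0 so it does not affect the extension build.
#if 0
#include <vector>
int main() {
    std::vector<float> X(8 * 64, 1.0f);        // activation, row-major [M=8, K=64]
    std::vector<float> W(4 * 64 * 32, 0.01f);  // weight, packed as [N/Nr=4, K=64, Nr=32]
    std::vector<float> scale(128, 1.0f);       // in_ptr2: per-output-channel multiplier
    std::vector<float> bias(128, 0.0f);        // in_ptr3: per-output-channel bias
    std::vector<bfloat16> Y(8 * 128);          // output, row-major [M=8, N=128]
    kernel(X.data(), W.data(), scale.data(), bias.data(), Y.data());
    return 0;
}
#endif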