Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16 to your computer and use it in GitHub Desktop.

Select an option

Save leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16 to your computer and use it in GitHub Desktop.
import torch
from torch._inductor.codecache import CppWrapperCodeCache
cpp_wrapper_src = (
'''
#include <optional>
#include <Python.h>
#define PYBIND11_SIMPLE_GIL_MANAGEMENT
#include <pybind11/gil.h>
namespace py = pybind11;
class RAIIPyObject {
public:
RAIIPyObject() : obj_(nullptr) {}
RAIIPyObject(PyObject* obj) : obj_(obj) {}
~RAIIPyObject() {
Py_XDECREF(obj_);
}
RAIIPyObject& operator=(const RAIIPyObject& other) {
if (this != &other) {
Py_XDECREF(obj_);
obj_ = other.obj_;
Py_XINCREF(obj_);
}
return *this;
}
operator PyObject*() {
return obj_;
}
PyObject* get() {
return obj_;
}
private:
PyObject* obj_;
};
#include <torch/csrc/inductor/aoti_runtime/device_utils.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
using namespace torch::aot_inductor;
#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h>
#include <c10/util/generic_math.h>
typedef at::Half half;
typedef at::BFloat16 bfloat16;
// Round up to the nearest multiple of 64
[[maybe_unused]] static int64_t align(int64_t nbytes) {
return (nbytes + 64 - 1) & -64;
}
// _frozen_param4 device(type='cpu') torch.float16 (32, 52) (1, 0) 7f20b991d620
// _frozen_param5 device(type='cpu') torch.float16 (32, 52) (1, 0) 7f20b991ee80
// constant2 device(type='cpu') torch.float16 (1, 52, 32) (1664, 32, 1) 7f20bb9676a0
// constant3 device(type='cpu') torch.float16 (1, 52, 32) (1664, 32, 1) 7f20b9903970
#include "/tmp/torchinductor_leslie/db/cdb7hyptwxpzukwd42x4ajfjlgrpum4a4htdd6lhb65apclsmno4.h"
#include <c10/util/Unroll.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
// Register-tiled micro-kernel: C[BLOCK_M x BLOCK_N] (+)= A[BLOCK_M x K] * B[K x BLOCK_N].
// A and B are `half`; products are accumulated in fp32 and C is fp32.
// The whole C tile lives in ROWS*COLS float vectors for the duration of the K loop
// and is written back once at the end.
// accum == true: existing contents of C are loaded and accumulated into.
// accum == false: the tile starts from zero and C is overwritten.
// NOTE(review): COLS = BLOCK_N / VLEN assumes BLOCK_N is a multiple of the float
// vector width; the only caller in this file uses BLOCK_N = 32 -- confirm for the target ISA.
template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
inline void cpp_fused_0_micro_gemm_kernel(
const half* __restrict__ A,
const half* __restrict__ B,
float* __restrict__ C,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc
) {
using Vectorized = at::vec::Vectorized<float>;
using VectorizedIn = at::vec::Vectorized<half>;
constexpr auto VLEN = Vectorized::size();
constexpr auto ROWS = BLOCK_M;
constexpr auto COLS = BLOCK_N / VLEN;
// va: broadcast of the current A element.
// vb: one row segment of B, widened to fp32.
// vc: the C accumulator tile, row-major by (row, col vector).
Vectorized va;
at::vec::VectorizedN<float, COLS> vb;
at::vec::VectorizedN<float, ROWS*COLS> vc;
// Initialize the accumulator tile: load C when accumulating, otherwise zero it.
// The lambdas use `i` in constant expressions, so ForcedUnroll presumably passes
// compile-time indices 0 .. ROWS*COLS-1.
auto loadc = [&](auto i) {
if constexpr (accum) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
} else {
vc[i] = Vectorized(0.0f);
}
};
c10::ForcedUnroll<ROWS * COLS>{}(loadc);
// One rank-1 update for tile element i at reduction index k.
// - A[row, k] is broadcast only when col == 0 and reused for the row's other columns.
// - The B segment for column vector `col` is loaded only when row == 0 and reused
//   by later rows.
// Both reuses rely on ForcedUnroll visiting i in ascending order.
// The B load reads exactly VLEN halves, one float vector worth of columns.
// NOTE(review): explicitly capturing the constexpr COLS is not required.
auto compute = [&, COLS](auto i, int k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = Vectorized(static_cast<float>(A[row * lda + k]));
}
if constexpr (row == 0) {
auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN);
vb[col] = at::vec::convert<float>(b);
}
constexpr int idx = row * COLS + col;
vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
};
for (int k = 0; k < K; ++k) {
c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
}
// store to C
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
vc[i].store(C + row * ldc + col * VLEN);
};
c10::ForcedUnroll<ROWS * COLS>{}(storec);
}
// Drives cpp_fused_0_micro_gemm_kernel over an M x N output:
//   C[M x N] (+)= A[M x K] * B[K x N]   (half inputs, fp32 C)
// N is walked in 32-column panels and M in 8-row blocks.
// The trailing M block (1..7 rows) is dispatched to the matching compile-time
// BLOCK_M instantiation, because the kernel's tile height is a template parameter.
// accum selects whether existing C values are accumulated into (true) or
// overwritten (false).
template <bool accum>
inline void cpp_fused_0_micro_gemm(
    const half* __restrict__ A,
    const half* __restrict__ B,
    float* __restrict__ C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
  TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32");
  TORCH_CHECK(K % 1 == 0, "K dimension must be multiple of 1");
  // TODO(jgong5): loop unroll for M and N
  for (int64_t m = 0; m < M; m += 8) {
    // m < M, so block_m is in [1, 8].
    int64_t block_m = std::min<int64_t>(M - m, 8);
    for (int64_t n = 0; n < N; n += 32) {
      if (block_m == 8) {
        cpp_fused_0_micro_gemm_kernel<8, 32, accum>(
            A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc);
      } else {
        switch (block_m) {
          case 7:
            cpp_fused_0_micro_gemm_kernel<7, 32, accum>(
                A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc);
            break;
          case 6:
            cpp_fused_0_micro_gemm_kernel<6, 32, accum>(
                A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc);
            break;
          case 5:
            cpp_fused_0_micro_gemm_kernel<5, 32, accum>(
                A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc);
            break;
          case 4:
            cpp_fused_0_micro_gemm_kernel<4, 32, accum>(
                A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc);
            break;
          case 3:
            cpp_fused_0_micro_gemm_kernel<3, 32, accum>(
                A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc);
            break;
          case 2:
            cpp_fused_0_micro_gemm_kernel<2, 32, accum>(
                A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc);
            break;
          case 1:
            cpp_fused_0_micro_gemm_kernel<1, 32, accum>(
                A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc);
            break;
          default:
            // Unreachable given the range above. If it ever fires, report the
            // actual value instead of a hard-coded constant.
            TORCH_CHECK(false, "Unsupported block_m: ", block_m);
        }
      }
    }
  }
}
// Fused dual GEMM (shapes baked in at codegen time):
//   Y0[M x N] = X0[M x K] * W0,   Y1[M x N] = X0[M x K] * W1,   M = 16, N = 32, K = 52
// Both products share the A operand X0 (row stride 52). Each thread accumulates
// them in fp32 scratch buffers, then converts to half and stores to Y0/Y1
// (row-major, row stride 32).
// NOTE(review): W0/W1 are addressed as 32L*k + 1664L*nci (1664 = 52 * 32). That
// suggests a [N/Nr][K][Nr] blocked layout matching the (1, 52, 32) constants -- confirm.
extern "C"
void cpp_fused_0(const half* X0, const half* W0, const half* W1, half* Y0, half* Y1)
{
// Problem sizes and blocking parameters.
// Mr/Nr/Kr: micro-kernel block sizes.
// *t_blocks: per-thread partition sizes, in units of Mr/Nr/Kr blocks.
// *c_blocks: cache-blocking sizes, in the same units.
constexpr int64_t num_threads = 56;
constexpr int64_t N = 32;
constexpr int64_t K = 52;
constexpr int64_t Mr = 8;
constexpr int64_t Nr = 32;
constexpr int64_t Kr = 1;
constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
constexpr int64_t M = static_cast<int64_t>(16L);
constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
constexpr int64_t Mt_blocks = 1;
constexpr int64_t Nt_blocks = 1;
constexpr int64_t Kt_blocks = 52;
constexpr int64_t Mc_blocks = 1;
constexpr int64_t Nc_blocks = 1;
constexpr int64_t Kc_blocks = 52;
// NOTE(review): num_threads, num_Mc_blocks, num_Nc_blocks and num_Mt_blocks
// (like m_size / n_size further down) are computed but never read.
constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
// make sure all partitions are assigned
// NOTE(review): the thread count 56 is repeated as a literal here and in the
// omp pragma instead of using num_threads.
TORCH_CHECK(
Mt_blocks * Nt_blocks * Kt_blocks * 56 >= Mr_blocks * Nr_blocks * Kr_blocks,
"Not all partitions are assigned."
);
// Each thread derives its k, n and m block ranges from its id.
// With this blocking (num_Kt_blocks = num_Nt_blocks = 1, Mt_blocks = 1, Mr_blocks = 2):
// - thread 0 handles M block 0 and thread 1 handles M block 1;
// - every other thread gets an empty M range and skips the mc loop, though it
//   still allocates its scratch buffers.
#pragma omp parallel num_threads(56)
{
const int tid = omp_get_thread_num();
const int64_t k_group_id = tid / num_Kt_blocks;
const int64_t k_slice_id = tid % num_Kt_blocks;
const int64_t n_group_id = k_group_id / num_Nt_blocks;
const int64_t n_slice_id = k_group_id % num_Nt_blocks;
const int64_t k_block_start = k_slice_id * Kt_blocks;
const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
const int64_t n_block_start = n_slice_id * Nt_blocks;
const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
// Per-thread fp32 accumulators, one per output. Each is a
// (Mc_blocks*Mr) x (Nc_blocks*Nr) tile with row stride Nc_blocks*Nr.
auto _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_0 = _local_acc_buf_0.get();
auto _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_1 = _local_acc_buf_1.get();
// The starting mc block is rotated by n_slice_id (presumably so threads sharing
// an M range start on different blocks). The loop does not run when the range
// is empty, so the modulo never sees a zero divisor.
for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
const int64_t m_start = mc * Mr;
const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
const int64_t m_size = m_end - m_start;
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
const int64_t n_start = nc * Nr;
const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
const int64_t n_size = n_end - n_start;
// NB: assume we pad N, nc_block_end won't exceed padded N here.
const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
// NOTE(review): both buffers are allocated unconditionally above, so these
// re-allocation guards never fire in this function.
if (_local_acc_buf_0 == nullptr) { _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_0 = _local_acc_buf_0.get(); }
if (_local_acc_buf_1 == nullptr) { _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_1 = _local_acc_buf_1.get(); }
// The thread's first K chunk initializes the accumulators (accum = false) and
// later chunks add to them (accum = true). Here Kc_blocks equals the thread's
// whole K range, so there is exactly one chunk.
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
int64_t k_start = kc * Kr;
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
for (int64_t nci = nc; nci < nc_block_end; nci++) {
if (kc == k_block_start) {
cpp_fused_0_micro_gemm<static_cast<bool>(false)>(
&(X0[static_cast<int64_t>(k_start + 52L*m_start)]),
&(W0[static_cast<int64_t>(32L*k_start + 1664L*nci)]),
&(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(52L),
static_cast<int64_t>(32L),
static_cast<int64_t>(Nc_blocks*Nr)
);
cpp_fused_0_micro_gemm<static_cast<bool>(false)>(
&(X0[static_cast<int64_t>(k_start + 52L*m_start)]),
&(W1[static_cast<int64_t>(32L*k_start + 1664L*nci)]),
&(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(52L),
static_cast<int64_t>(32L),
static_cast<int64_t>(Nc_blocks*Nr)
);
} else {
cpp_fused_0_micro_gemm<static_cast<bool>(true)>(
&(X0[static_cast<int64_t>(k_start + 52L*m_start)]),
&(W0[static_cast<int64_t>(32L*k_start + 1664L*nci)]),
&(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(52L),
static_cast<int64_t>(32L),
static_cast<int64_t>(Nc_blocks*Nr)
);
cpp_fused_0_micro_gemm<static_cast<bool>(true)>(
&(X0[static_cast<int64_t>(k_start + 52L*m_start)]),
&(W1[static_cast<int64_t>(32L*k_start + 1664L*nci)]),
&(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(52L),
static_cast<int64_t>(32L),
static_cast<int64_t>(Nc_blocks*Nr)
);
}
}
}
// Epilogue: convert this (mc, nc) tile of both fp32 accumulators to half and
// store it to Y0/Y1. Columns are handled 16 at a time; a count-limited
// load/store handles the tail when the tile width is not a multiple of 16.
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=static_cast<int64_t>(16L))
{
{
if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16));
auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16));
auto tmp1 = at::vec::convert<half>(tmp0);
auto tmp3 = at::vec::convert<half>(tmp2);
tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(16));
tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(16));
}
if(C10_UNLIKELY(x1 >= static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))) && x1 < static_cast<int64_t>(n_end + ((-1L)*n_start))))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
auto tmp1 = at::vec::convert<half>(tmp0);
auto tmp3 = at::vec::convert<half>(tmp2);
tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
}
}
}
}
}
}
}
}
}
}
CACHE_TORCH_DTYPE(float16);
CACHE_TORCH_DEVICE(cpu);
// Entry point of the cpp_wrapper module.
// - Takes ownership of 5 input handles: the activation plus four constants.
// - Allocates two (16, 32) float16 CPU outputs.
// - Runs cpp_fused_0 with the GIL released.
// - Returns the outputs through output_handles.
void inductor_entry_impl(
AtenTensorHandle*
input_handles, // array of input AtenTensorHandle; handles
// are stolen; the array itself is borrowed
AtenTensorHandle*
output_handles // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed
) {
// Release the GIL for the rest of the call (re-acquired when `release` is destroyed).
py::gil_scoped_release release;
auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 5);
auto arg2_1 = std::move(inputs[0]);
// NOTE(review): _frozen_param4/_frozen_param5 are taken over but not used below;
// only constant2/constant3 are passed to cpp_fused_0.
[[maybe_unused]] auto _frozen_param4 = std::move(inputs[1]);
[[maybe_unused]] auto _frozen_param5 = std::move(inputs[2]);
[[maybe_unused]] auto constant2 = std::move(inputs[3]);
[[maybe_unused]] auto constant3 = std::move(inputs[4]);
// Both outputs use shape (16, 32) with row-major strides (32, 1).
static constexpr int64_t int_array_0[] = {16L, 32L};
static constexpr int64_t int_array_1[] = {32L, 1L};
AtenTensorHandle buf1_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float16, cached_torch_device_type_cpu, 0, &buf1_handle));
RAIIAtenTensorHandle buf1(buf1_handle);
AtenTensorHandle buf2_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float16, cached_torch_device_type_cpu, 0, &buf2_handle));
RAIIAtenTensorHandle buf2(buf2_handle);
cpp_fused_0((const half*)(arg2_1.data_ptr()), (const half*)(constant2.data_ptr()), (const half*)(constant3.data_ptr()), (half*)(buf1.data_ptr()), (half*)(buf2.data_ptr()));
// Drop the activation now. release() gives up ownership of the outputs so the
// caller can steal the raw handles.
arg2_1.reset();
output_handles[0] = buf1.release();
output_handles[1] = buf2.release();
} // inductor_entry_impl
'''
)
# Compile (or fetch from Inductor's code cache) the C++ wrapper source above and
# load it as a Python-callable binding. The binding takes the list of input
# handles and returns a std::vector of output handles. The trailing 2 is
# presumably the number of outputs: inductor_entry_impl fills output_handles[0]
# and output_handles[1].
inductor_entry = CppWrapperCodeCache.load_pybinding(
["std::vector<AtenTensorHandle>"], cpp_wrapper_src, "cpu", 2)
def _wrap_func(f):
def g(args):
input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]
constants_tensor = [_frozen_param4, _frozen_param5, constant2, constant3]
input_tensors.extend(constants_tensor)
input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors)
args.clear()
output_handles = f(input_handles)
output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles)
return output_tensors
return g
# Public entry point. call(args) promotes args to tensors, appends the constants
# and runs the compiled wrapper through _wrap_func.
call = _wrap_func(inductor_entry)
def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark ``call`` on random float16 CPU inputs.

    Steps:
    - Assign the module-level constants with the sizes/strides recorded at
      compile time, in the same order as the generated code.
    - Build a (16, 52) activation.
    - Time ``call`` with ``print_performance`` and return its result.
    """
    global _frozen_param4, _frozen_param5, constant2, constant3
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    def _cpu_half(size, stride):
        return rand_strided(size, stride, device='cpu', dtype=torch.float16)

    _frozen_param4 = _cpu_half((32, 52), (1, 0))
    _frozen_param5 = _cpu_half((32, 52), (1, 0))
    constant2 = _cpu_half((1, 52, 32), (1664, 32, 1))
    constant3 = _cpu_half((1, 52, 32), (1664, 32, 1))
    activation = _cpu_half((16, 52), (52, 1))

    def run():
        # A fresh list per call: ``call`` clears the list it is given.
        return call([activation])

    return print_performance(run, times=times, repeat=repeat)
if __name__ == "__main__":
    # Run the standard Inductor benchmark harness for this compiled module.
    from torch._inductor import wrapper_benchmark

    wrapper_benchmark.compiled_module_main('None', benchmark_compiled_module)
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# Frozen parameters / constants referenced by `call`. They are None here and are
# assigned by benchmark_compiled_module before benchmarking. Each trailing
# comment records device, dtype, size and stride, followed by what looks like
# the original tensor's address.
_frozen_param4 = None # device(type='cpu') torch.bfloat16 (32, 52) (1, 0) 7f6281a2b970
_frozen_param5 = None # device(type='cpu') torch.bfloat16 (32, 52) (1, 0) 7f6281a2bab0
constant2 = None # device(type='cpu') torch.bfloat16 (1, 52, 32) (1664, 32, 1) 7f62751fb3d0
constant3 = None # device(type='cpu') torch.bfloat16 (1, 52, 32) (1664, 32, 1) 7f62751fb380
cpp_fused_0 = async_compile.cpp_pybinding(['const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'bfloat16*', 'bfloat16*', 'const int64_t'], '''
#include "/tmp/torchinductor_leslie/db/cdb7hyptwxpzukwd42x4ajfjlgrpum4a4htdd6lhb65apclsmno4.h"
#include <c10/util/Unroll.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
// AMX micro-kernel for a 32 x 32 fp32 tile: C (+)= A[32 x K] * B[K x 32], with bf16 A and B.
// Tile register usage (from the loads below):
//   tiles 0-3: C, as a 2 x 2 grid of 16-row x 16-column fp32 tiles
//   tiles 4-5: A rows 0-15 and rows 16-31 for the current K chunk
//   tiles 6-7: B columns 0-15 and 16-31 for the current K chunk
// K is consumed in chunks of 32 (64 bytes of bf16 per A tile row). A remaining
// tail (< 32) is computed under a second tile configuration.
// tilecfg_rows: rows configured for the tiles (the caller in this file passes 16).
// NOTE(review): B appears to be VNNI-packed (pairs of K rows interleaved). That would
// explain the 2 * ldb row stride and the +32-element offset of the second column
// tile -- confirm against the weight packing.
// NOTE(review): AMXState::configure lives in the included header. Its arguments look
// like (rows, A-row bytes, #M tiles, #N tiles, loader), but that is inferred from
// the call sites only.
template <bool accum>
inline void kernel_micro_gemm_amx_kernel_32_2(
AMXState& amx_state,
const bfloat16* __restrict__ A,
const bfloat16* __restrict__ B,
float* __restrict__ C,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
uint8_t tilecfg_rows
) {
// TODO(jgong5): add prefetch hint for A, B, C
auto loadconfig = [](const amx_tilecfg& cfg) {
_tile_loadconfig(&cfg);
};
// Split K into whole 32-element chunks and a tail (< 32 elements).
const auto last_k_offset = K / 32 * 32;
const auto tail_k_size = K - last_k_offset;
// Configure for full 64-byte A rows when at least one whole chunk exists;
// otherwise configure directly for the tail.
if C10_LIKELY (last_k_offset > 0) {
amx_state.configure(tilecfg_rows, 64, 32 / 16, 2, loadconfig);
} else {
amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
}
auto load_c = [&]() {
_tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
_tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
_tile_loadd(2, C + 16 * ldc + 0, ldc * sizeof(float));
_tile_loadd(3, C + 16 * ldc + 16, ldc * sizeof(float));
};
auto zero_c = [&]() {
_tile_zero(0);
_tile_zero(1);
_tile_zero(2);
_tile_zero(3);
};
// Initialize the C tiles: reload partial sums when accumulating, otherwise zero.
if constexpr (accum) {
load_c();
} else {
zero_c();
}
// One K chunk starting at k: 2 A tiles x 2 B tiles give 4 bf16 dot-product
// updates into the C tiles. A is loaded with _tile_stream_loadd (TILELOADDT1,
// a load with a non-temporal hint).
auto compute = [&](int k) {
_tile_stream_loadd(4, A + 0 * lda + k, lda * sizeof(bfloat16));
_tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
_tile_dpbf16ps(0, 4, 6);
_tile_loadd(7, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
_tile_dpbf16ps(1, 4, 7);
_tile_stream_loadd(5, A + 16 * lda + k, lda * sizeof(bfloat16));
_tile_dpbf16ps(2, 5, 6);
_tile_dpbf16ps(3, 5, 7);
};
#pragma GCC unroll 4
for (int k = 0; k < last_k_offset; k += 32) {
compute(k);
}
auto store_c = [&]() {
// store to C
_tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
_tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
_tile_stored(2, C + 16 * ldc + 0, ldc * sizeof(float));
_tile_stored(3, C + 16 * ldc + 16, ldc * sizeof(float));
};
// Tail chunk. If the main loop ran, spill C to memory, reconfigure the tiles for
// the shorter tail and reload C before computing the final partial chunk.
// TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
if C10_UNLIKELY (tail_k_size > 0) {
if C10_LIKELY (last_k_offset > 0) {
store_c();
amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
load_c();
}
compute(last_k_offset);
}
store_c();
}
// AMX micro-kernel for a (up to) 16 x 32 fp32 tile: C (+)= A[rows x K] * B[K x 32],
// with bf16 A and B.
// Tile register usage (from the loads below):
//   tiles 0-1: C columns 0-15 and 16-31 (fp32)
//   tile 2:    A for the current K chunk
//   tiles 3-4: B columns 0-15 and 16-31 for the current K chunk
// tilecfg_rows: configured tile rows. The caller in this file passes 16 for a
// full block, or the remaining row count (< 16) for the last rows of M.
// K handling (32-element chunks plus a tail) mirrors the 32 x 32 kernel.
// NOTE(review): B addressing (2 * ldb stride, +32 offset) suggests a VNNI-packed
// B -- confirm against the weight packing.
template <bool accum>
inline void kernel_micro_gemm_amx_kernel_16_2(
AMXState& amx_state,
const bfloat16* __restrict__ A,
const bfloat16* __restrict__ B,
float* __restrict__ C,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
uint8_t tilecfg_rows
) {
// TODO(jgong5): add prefetch hint for A, B, C
auto loadconfig = [](const amx_tilecfg& cfg) {
_tile_loadconfig(&cfg);
};
// Split K into whole 32-element chunks and a tail (< 32 elements).
const auto last_k_offset = K / 32 * 32;
const auto tail_k_size = K - last_k_offset;
if C10_LIKELY (last_k_offset > 0) {
amx_state.configure(tilecfg_rows, 64, 16 / 16, 2, loadconfig);
} else {
amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
}
auto load_c = [&]() {
_tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
_tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
};
auto zero_c = [&]() {
_tile_zero(0);
_tile_zero(1);
};
// Initialize the C tiles: reload partial sums when accumulating, otherwise zero.
if constexpr (accum) {
load_c();
} else {
zero_c();
}
// One K chunk starting at k: 1 A tile x 2 B tiles give 2 dot-product updates.
auto compute = [&](int k) {
_tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16));
_tile_loadd(3, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
_tile_dpbf16ps(0, 2, 3);
_tile_loadd(4, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
_tile_dpbf16ps(1, 2, 4);
};
#pragma GCC unroll 4
for (int k = 0; k < last_k_offset; k += 32) {
compute(k);
}
auto store_c = [&]() {
// store to C
_tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
_tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
};
// Tail chunk: spill C, reconfigure for the shorter K tail and reload C when
// the main loop ran; then compute the final partial chunk.
// TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
if C10_UNLIKELY (tail_k_size > 0) {
if C10_LIKELY (last_k_offset > 0) {
store_c();
amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
load_c();
}
compute(last_k_offset);
}
store_c();
}
// Tiles an M x N output over the AMX micro-kernels:
//   C[M x N] (+)= A[M x K] * B[K x N]   (bf16 inputs, fp32 C)
// N is walked in 32-column panels. Within a panel, M is walked in 32-row steps:
// - a full step uses the 32-row kernel;
// - a 16..31-row remainder uses the 16-row kernel;
// - whatever is still left (1..15 rows) uses the 16-row kernel configured with
//   that many tile rows.
template <bool accum>
inline void kernel_micro_gemm(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
  TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32");
  TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
  // ldb is forwarded as-is. It is the full B row stride, which need not equal
  // the 32-column panel width when N != 32.
  for (int64_t col = 0; col < N; col += 32) {
    const bfloat16* B_panel = B + col;
    for (int64_t row = 0; row < M; row += 32) {
      const int64_t rows_here = std::min<int64_t>(M - row, 32);
      int64_t rows_done = 0;
      if (rows_here >= 32) {
        kernel_micro_gemm_amx_kernel_32_2<accum>(
            amx_state, A + row * lda, B_panel, C + row * ldc + col,
            K, lda, ldb, ldc, 16);
        rows_done = 32;
      } else if (rows_here >= 16) {
        kernel_micro_gemm_amx_kernel_16_2<accum>(
            amx_state, A + row * lda, B_panel, C + row * ldc + col,
            K, lda, ldb, ldc, 16);
        rows_done = 16;
      }
      const int64_t rows_left = rows_here - rows_done;
      if (rows_left > 0) {
        const int64_t tail_row = row + rows_done;
        kernel_micro_gemm_amx_kernel_16_2<accum>(
            amx_state, A + tail_row * lda, B_panel, C + tail_row * ldc + col,
            K, lda, ldb, ldc, rows_left);
      }
    }
  }
}
// Dynamic-M fused dual GEMM:
//   Y0[M x N] = X0[M x K] * W0,   Y1[M x N] = X0[M x K] * W1
// M = ks0 is supplied at runtime; N = 32 and K = 52 are fixed. Inputs/outputs are
// bf16 and accumulation is fp32 through the AMX micro-kernels. X0 has row stride
// 52; Y0/Y1 are row-major with row stride 32.
// Thread blocking and cache blocking are chosen at runtime by
// mm_get_thread_blocking / mm_get_cache_blocking (declared in the included
// header), which write their results through the out-parameters below.
extern "C"
void kernel(const bfloat16* X0, const bfloat16* W0, const bfloat16* W1, bfloat16* Y0, bfloat16* Y1, const int64_t ks0)
{
constexpr int64_t num_threads = 56;
constexpr int64_t N = 32;
constexpr int64_t K = 52;
constexpr int64_t Mr = 32;
constexpr int64_t Nr = 32;
constexpr int64_t Kr = 32;
constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
const int64_t M = static_cast<int64_t>(ks0);
const int64_t Mr_blocks = (M + Mr - 1) / Mr;
int64_t Mt_blocks, Nt_blocks, Kt_blocks;
mm_get_thread_blocking(num_threads, 1, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks);
int64_t Mc_blocks, Nc_blocks, Kc_blocks;
// L1/L2 cache sizes in bytes, fixed at codegen time.
uint32_t L1_cache_size = 49152;
uint32_t L2_cache_size = 2097152;
mm_get_cache_blocking<bfloat16, bfloat16>(
num_threads,
M,
N,
K,
Mr,
Nr,
Kr,
Mt_blocks,
Nt_blocks,
Kt_blocks,
Mc_blocks,
Nc_blocks,
Kc_blocks,
L1_cache_size,
L2_cache_size
);
// NOTE(review): num_Mc_blocks, num_Nc_blocks and num_Mt_blocks (like m_size /
// n_size further down) are computed but never read.
const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
const int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
// make sure all partitions are assigned
// NOTE(review): the thread count 56 is repeated as a literal here and in the
// omp pragma instead of using num_threads.
TORCH_CHECK(
Mt_blocks * Nt_blocks * Kt_blocks * 56 >= Mr_blocks * Nr_blocks * Kr_blocks,
"Not all partitions are assigned."
);
// Each thread derives its k, n and m block ranges (in units of Kr/Nr/Mr) from its id.
#pragma omp parallel num_threads(56)
{
const int tid = omp_get_thread_num();
const int64_t k_group_id = tid / num_Kt_blocks;
const int64_t k_slice_id = tid % num_Kt_blocks;
const int64_t n_group_id = k_group_id / num_Nt_blocks;
const int64_t n_slice_id = k_group_id % num_Nt_blocks;
const int64_t k_block_start = k_slice_id * Kt_blocks;
const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
const int64_t n_block_start = n_slice_id * Nt_blocks;
const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
// Per-thread AMX tile state (released at the end of the parallel region) and
// fp32 accumulators, each (Mc_blocks*Mr) x (Nc_blocks*Nr) with row stride Nc_blocks*Nr.
AMXState amx_state;
auto _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_0 = _local_acc_buf_0.get();
auto _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_1 = _local_acc_buf_1.get();
// The starting mc block is rotated by n_slice_id (presumably so threads sharing
// an M range start on different blocks).
for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
const int64_t m_start = mc * Mr;
const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
const int64_t m_size = m_end - m_start;
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
const int64_t n_start = nc * Nr;
const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
const int64_t n_size = n_end - n_start;
// NB: assume we pad N, nc_block_end won't exceed padded N here.
const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
// NOTE(review): both buffers are allocated unconditionally above, so these
// re-allocation guards never fire in this function.
if (_local_acc_buf_0 == nullptr) { _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_0 = _local_acc_buf_0.get(); }
if (_local_acc_buf_1 == nullptr) { _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_1 = _local_acc_buf_1.get(); }
// The thread's first K chunk initializes the accumulators (accum = false) and
// later chunks add to them (accum = true). k_end is clamped to K, so the last
// chunk may be shorter than Kc_blocks*Kr.
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
int64_t k_start = kc * Kr;
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
for (int64_t nci = nc; nci < nc_block_end; nci++) {
if (kc == k_block_start) {
kernel_micro_gemm<static_cast<bool>(false)>(
amx_state,
&(X0[static_cast<int64_t>(k_start + 52L*m_start)]),
&(W0[static_cast<int64_t>(32L*k_start + 1664L*nci)]),
&(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(52L),
static_cast<int64_t>(32L),
static_cast<int64_t>(Nc_blocks*Nr)
);
kernel_micro_gemm<static_cast<bool>(false)>(
amx_state,
&(X0[static_cast<int64_t>(k_start + 52L*m_start)]),
&(W1[static_cast<int64_t>(32L*k_start + 1664L*nci)]),
&(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(52L),
static_cast<int64_t>(32L),
static_cast<int64_t>(Nc_blocks*Nr)
);
} else {
kernel_micro_gemm<static_cast<bool>(true)>(
amx_state,
&(X0[static_cast<int64_t>(k_start + 52L*m_start)]),
&(W0[static_cast<int64_t>(32L*k_start + 1664L*nci)]),
&(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(52L),
static_cast<int64_t>(32L),
static_cast<int64_t>(Nc_blocks*Nr)
);
kernel_micro_gemm<static_cast<bool>(true)>(
amx_state,
&(X0[static_cast<int64_t>(k_start + 52L*m_start)]),
&(W1[static_cast<int64_t>(32L*k_start + 1664L*nci)]),
&(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(52L),
static_cast<int64_t>(32L),
static_cast<int64_t>(Nc_blocks*Nr)
);
}
}
}
// Epilogue: convert this (mc, nc) tile of both fp32 accumulators to bf16 and
// store it to Y0/Y1. Columns are handled 16 at a time; a count-limited
// load/store handles the tail when the tile width is not a multiple of 16.
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=static_cast<int64_t>(16L))
{
{
if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16));
auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16));
auto tmp1 = at::vec::convert<bfloat16>(tmp0);
auto tmp3 = at::vec::convert<bfloat16>(tmp2);
tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(16));
tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(16));
}
if(C10_UNLIKELY(x1 >= static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))) && x1 < static_cast<int64_t>(n_end + ((-1L)*n_start))))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
auto tmp1 = at::vec::convert<bfloat16>(tmp0);
auto tmp3 = at::vec::convert<bfloat16>(tmp2);
tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 32L*m_start + 32L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
}
}
}
}
}
}
}
}
// Release this thread's AMX tile state before leaving the parallel region.
amx_state.release([]() { _tile_release(); });
}
}
''')
# Wait for the asynchronously compiled kernels (e.g. cpp_fused_0) to finish.
# globals() is passed so wait() can presumably rebind the finished kernels in
# this module. The compile helper is then no longer needed.
async_compile.wait(globals())
del async_compile
def call(args):
    """Run the compiled graph: two bf16 GEMMs that share the input activation.

    ``args`` must be ``[s0, arg3_1]``:
    - ``s0`` is the runtime row count;
    - ``arg3_1`` is an ``(s0, 52)`` tensor with strides ``(52, 1)`` (checked by
      ``assert_size_stride``).

    The list is cleared so the caller stops referencing the inputs.
    Returns ``(buf1, buf2)``, two ``(s0, 32)`` bfloat16 CPU tensors.
    """
    s0, arg3_1 = args
    args.clear()
    assert_size_stride(arg3_1, (s0, 52), (52, 1))
    buf1, buf2 = (
        empty_strided_cpu((s0, 32), (32, 1), torch.bfloat16) for _ in range(2)
    )
    cpp_fused_0(arg3_1, constant2, constant3, buf1, buf2, s0)
    del arg3_1
    return (buf1, buf2, )
def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark ``call`` on random bfloat16 CPU inputs with 16 rows.

    Steps:
    - Assign the module-level constants with the sizes/strides recorded at
      compile time, in the same order as the generated code.
    - Build a (16, 52) activation.
    - Time ``call([16, activation])`` with ``print_performance`` and return its result.
    """
    global _frozen_param4, _frozen_param5, constant2, constant3
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    def _cpu_bf16(size, stride):
        return rand_strided(size, stride, device='cpu', dtype=torch.bfloat16)

    _frozen_param4 = _cpu_bf16((32, 52), (1, 0))
    _frozen_param5 = _cpu_bf16((32, 52), (1, 0))
    constant2 = _cpu_bf16((1, 52, 32), (1664, 32, 1))
    constant3 = _cpu_bf16((1, 52, 32), (1664, 32, 1))
    num_rows = 16
    activation = _cpu_bf16((16, 52), (52, 1))

    def run():
        # A fresh list per call: ``call`` clears the list it is given.
        return call([num_rows, activation])

    return print_performance(run, times=times, repeat=repeat)
if __name__ == "__main__":
    # Run the standard Inductor benchmark harness for this compiled module.
    from torch._inductor import wrapper_benchmark

    wrapper_benchmark.compiled_module_main('None', benchmark_compiled_module)
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# Short aliases for op namespaces and the guard/allocation helpers; call() below
# uses assert_size_stride, empty_strided_cpu and reinterpret_tensor.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compiler used to build cpp_fused_0 below; it is waited on and deleted after
# the kernel definitions.
async_compile = AsyncCompile()
# NOTE(review): plain attribute access at import time; a torch build without
# torch._C._distributed_c10d would fail to import this module here.
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# Placeholders for weights/constants. call() reads constant2/constant3;
# benchmark_compiled_module assigns random tensors to all four (presumably the
# real loader does too -- TODO confirm). Trailing comments list device, dtype,
# size, stride and what looks like an object id.
_frozen_param4 = None # device(type='cpu') torch.bfloat16 (1024, 512) (1, 0) 7f9d8b5cf8d0
_frozen_param5 = None # device(type='cpu') torch.bfloat16 (1024, 512) (1, 0) 7f9d8b5ceb60
constant2 = None # device(type='cpu') torch.bfloat16 (64, 512, 16) (8192, 16, 1) 7f9d8abc4e00
constant3 = None # device(type='cpu') torch.bfloat16 (64, 512, 16) (8192, 16, 1) 7f9d8abc4db0
cpp_fused_0 = async_compile.cpp_pybinding(['const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include "/tmp/torchinductor_leslie/db/cdb7hyptwxpzukwd42x4ajfjlgrpum4a4htdd6lhb65apclsmno4.h"
#include <c10/util/Unroll.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
// AMX BF16 micro-kernel: C[48 x 16] (+)= A[48 x K] * B[K x 16], fp32 accumulation.
// Tile usage: tiles 0-2 hold the three 16-row C accumulators, tiles 3-5 the
// matching A tiles, tile 6 the shared B tile.
//   accum        - true: start from the values already in C; false: start from zero.
//   A            - bf16 rows with a stride of `lda` elements.
//   B            - bf16 panel; K-row k starts at B + k * ldb and tile rows are
//                  2 * ldb elements apart (presumably VNNI pairs of K -- TODO confirm).
//   C            - fp32 rows with a stride of `ldc` elements.
//   K            - reduction length: full 32-wide steps, then one tail step.
//   tilecfg_rows - rows configured for each A/C tile.
// K offsets and the loop index are int64_t so a large K never narrows to int.
template <bool accum>
inline void kernel_micro_gemm_amx_kernel_48_1(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    // K is consumed 32 bf16 at a time (64 bytes per A-tile row); the remainder
    // [last_k_offset, K) is handled by a narrower tail step.
    const int64_t last_k_offset = K / 32 * 32;
    const int64_t tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 48 / 16, 1, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 48 / 16, 1, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(1, C + 16 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(2, C + 32 * ldc + 0, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
        _tile_zero(1);
        _tile_zero(2);
    };
    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }
    // One K step starting at element offset k (int64_t: no narrowing of K).
    auto compute = [&](int64_t k) {
        _tile_stream_loadd(3, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 3, 6);
        _tile_stream_loadd(4, A + 16 * lda + k, lda * sizeof(bfloat16));
        _tile_dpbf16ps(1, 4, 6);
        _tile_stream_loadd(5, A + 32 * lda + k, lda * sizeof(bfloat16));
        _tile_dpbf16ps(2, 5, 6);
    };
    #pragma GCC unroll 4
    for (int64_t k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }
    auto store_c = [&]() {
        // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_stored(1, C + 16 * ldc + 0, ldc * sizeof(float));
        _tile_stored(2, C + 32 * ldc + 0, ldc * sizeof(float));
    };
    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            // Spill the accumulators, reconfigure the tiles for the narrower
            // tail width, then reload the accumulators before the last step.
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 48 / 16, 1, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }
    store_c();
}
// AMX BF16 micro-kernel: C[32 x 16] (+)= A[32 x K] * B[K x 16], fp32 accumulation.
// Tile usage: tiles 0-1 hold the two 16-row C accumulators, tiles 2-3 the
// matching A tiles, tile 4 the shared B tile.
//   accum        - true: start from the values already in C; false: start from zero.
//   A            - bf16 rows with a stride of `lda` elements.
//   B            - bf16 panel; K-row k starts at B + k * ldb and tile rows are
//                  2 * ldb elements apart (presumably VNNI pairs of K -- TODO confirm).
//   C            - fp32 rows with a stride of `ldc` elements.
//   K            - reduction length: full 32-wide steps, then one tail step.
//   tilecfg_rows - rows configured for each A/C tile.
// K offsets and the loop index are int64_t so a large K never narrows to int.
template <bool accum>
inline void kernel_micro_gemm_amx_kernel_32_1(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    // K is consumed 32 bf16 at a time (64 bytes per A-tile row); the remainder
    // [last_k_offset, K) is handled by a narrower tail step.
    const int64_t last_k_offset = K / 32 * 32;
    const int64_t tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 32 / 16, 1, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 1, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(1, C + 16 * ldc + 0, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
        _tile_zero(1);
    };
    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }
    // One K step starting at element offset k (int64_t: no narrowing of K).
    auto compute = [&](int64_t k) {
        _tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(4, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 2, 4);
        _tile_stream_loadd(3, A + 16 * lda + k, lda * sizeof(bfloat16));
        _tile_dpbf16ps(1, 3, 4);
    };
    #pragma GCC unroll 4
    for (int64_t k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }
    auto store_c = [&]() {
        // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_stored(1, C + 16 * ldc + 0, ldc * sizeof(float));
    };
    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            // Spill the accumulators, reconfigure the tiles for the narrower
            // tail width, then reload the accumulators before the last step.
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 1, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }
    store_c();
}
// AMX BF16 micro-kernel: C[rows x 16] (+)= A[rows x K] * B[K x 16], fp32 accumulation,
// where rows = tilecfg_rows (kernel_micro_gemm also uses this kernel for a
// leftover of fewer than 16 rows).
// Tile usage: tile 0 holds the C accumulator, tile 1 the A tile, tile 2 the B tile.
//   accum        - true: start from the values already in C; false: start from zero.
//   A            - bf16 rows with a stride of `lda` elements.
//   B            - bf16 panel; K-row k starts at B + k * ldb and tile rows are
//                  2 * ldb elements apart (presumably VNNI pairs of K -- TODO confirm).
//   C            - fp32 rows with a stride of `ldc` elements.
//   K            - reduction length: full 32-wide steps, then one tail step.
//   tilecfg_rows - rows configured for the A/C tiles.
// K offsets and the loop index are int64_t so a large K never narrows to int.
template <bool accum>
inline void kernel_micro_gemm_amx_kernel_16_1(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    // K is consumed 32 bf16 at a time (64 bytes per A-tile row); the remainder
    // [last_k_offset, K) is handled by a narrower tail step.
    const int64_t last_k_offset = K / 32 * 32;
    const int64_t tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 16 / 16, 1, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 1, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
    };
    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }
    // One K step starting at element offset k (int64_t: no narrowing of K).
    auto compute = [&](int64_t k) {
        _tile_stream_loadd(1, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(2, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 1, 2);
    };
    #pragma GCC unroll 4
    for (int64_t k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }
    auto store_c = [&]() {
        // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
    };
    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            // Spill the accumulator, reconfigure the tiles for the narrower
            // tail width, then reload the accumulator before the last step.
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 1, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }
    store_c();
}
// Runs an M x N x K bf16 GEMM tile (fp32 C) on the fixed-size AMX micro-kernels.
// N is walked in 16-column strips. Each strip is walked in 48-row chunks, and
// each chunk is served by the widest full-height kernel (48, 32 or 16 rows)
// that fits. Any leftover rows (fewer than 16) then go to the 16-row kernel,
// configured with that smaller row count.
//   accum       - forwarded to every micro-kernel (accumulate into C vs. overwrite).
//   lda/ldb/ldc - row strides, in elements, of A, B and C; ldb is passed through
//                 unchanged (it need not equal the 16-column strip width).
template <bool accum>
inline void kernel_micro_gemm(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    TORCH_CHECK(N % 16 == 0, "N dimension must be multiple of 16");
    TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
    // TODO(jgong5): loop unroll for M and N
    for (int64_t col = 0; col < N; col += 16) {
        const bfloat16* const b_strip = B + col;
        for (int64_t row = 0; row < M; row += 48) {
            const int64_t rows = std::min<int64_t>(M - row, 48);
            const bfloat16* const a_chunk = A + row * lda;
            float* const c_chunk = C + row * ldc + col;
            // Number of rows of this chunk handled by a full 16-row-multiple kernel.
            int64_t covered = 0;
            if (rows >= 48) {
                kernel_micro_gemm_amx_kernel_48_1<accum>(
                    amx_state, a_chunk, b_strip, c_chunk, K, lda, ldb, ldc, 16);
                covered = 48;
            } else if (rows >= 32) {
                kernel_micro_gemm_amx_kernel_32_1<accum>(
                    amx_state, a_chunk, b_strip, c_chunk, K, lda, ldb, ldc, 16);
                covered = 32;
            } else if (rows >= 16) {
                kernel_micro_gemm_amx_kernel_16_1<accum>(
                    amx_state, a_chunk, b_strip, c_chunk, K, lda, ldb, ldc, 16);
                covered = 16;
            }
            const int64_t leftover = rows - covered;
            if (leftover > 0) {
                // leftover < 16 here, so it fits the uint8_t row count.
                kernel_micro_gemm_amx_kernel_16_1<accum>(
                    amx_state,
                    A + (row + covered) * lda,
                    b_strip,
                    C + (row + covered) * ldc + col,
                    K,
                    lda,
                    ldb,
                    ldc,
                    static_cast<uint8_t>(leftover)
                );
            }
        }
    }
}
// Inductor-generated fused kernel: two GEMMs that share the activation X0,
//   Y0 = X0 * W0 and Y1 = X0 * W1. Results are accumulated in per-thread fp32
//   buffers and converted to bf16 on store.
// Problem size and blocking are compile-time constants below (M = 4, N = 1024, K = 512).
// X0: row-major M x K (element (m, k) at k + 512*m; passed with lda = 512).
// W0/W1: pre-blocked. Column block nci (Nr = 16 columns) starts at 8192*nci,
//   and K-row k of it sits at 16*k within that block, giving one K x 16 panel
//   per block. Any finer packing is presumably what the micro-kernel's
//   2*ldb tile stride expects -- TODO confirm.
// Y0/Y1: row-major M x N (element (m, n) at n + 1024*m).
extern "C"
void kernel(const bfloat16* X0, const bfloat16* W0, const bfloat16* W1, bfloat16* Y0, bfloat16* Y1)
{
// Register block sizes (Mr x Nr x Kr), plus block counts per thread (Mt/Nt/Kt)
// and per inner loop step (Mc/Nc/Kc), in units of register blocks.
// NOTE(review): num_threads, num_Mc_blocks, num_Nc_blocks and num_Mt_blocks are
// not referenced below; the pragma and the TORCH_CHECK hardcode 56 threads.
constexpr int64_t num_threads = 56;
constexpr int64_t N = 1024;
constexpr int64_t K = 512;
constexpr int64_t Mr = 48;
constexpr int64_t Nr = 16;
constexpr int64_t Kr = 32;
constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
constexpr int64_t M = static_cast<int64_t>(4L);
constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
constexpr int64_t Mt_blocks = 1;
constexpr int64_t Nt_blocks = 2;
constexpr int64_t Kt_blocks = 16;
constexpr int64_t Mc_blocks = 1;
constexpr int64_t Nc_blocks = 1;
constexpr int64_t Kc_blocks = 16;
constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
// make sure all partitions are assigned
TORCH_CHECK(
Mt_blocks * Nt_blocks * Kt_blocks * 56 >= Mr_blocks * Nr_blocks * Kr_blocks,
"Not all partitions are assigned."
);
#pragma omp parallel num_threads(56)
{
// Split the thread id into (M group, N slice, K slice). With these constants
// num_Kt_blocks == 1, so every thread covers the whole K range and no
// cross-thread reduction is needed. num_Nt_blocks == 32, so threads with
// tid >= 32 get m_block_start == Mr_blocks and do no GEMM work.
const int tid = omp_get_thread_num();
const int64_t k_group_id = tid / num_Kt_blocks;
const int64_t k_slice_id = tid % num_Kt_blocks;
const int64_t n_group_id = k_group_id / num_Nt_blocks;
const int64_t n_slice_id = k_group_id % num_Nt_blocks;
const int64_t k_block_start = k_slice_id * Kt_blocks;
const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
const int64_t n_block_start = n_slice_id * Nt_blocks;
const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
// Per-thread AMX tile state; released at the end of the parallel region.
AMXState amx_state;
// Per-thread fp32 accumulators for the two GEMMs: up to Mc_blocks*Mr rows by
// Nc_blocks*Nr columns, row stride Nc_blocks*Nr.
auto _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_0 = _local_acc_buf_0.get();
auto _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf_1 = _local_acc_buf_1.get();
for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
// Rotate the starting M block by n_slice_id (presumably to stagger threads
// that share an M range -- TODO confirm intent).
const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
const int64_t m_start = mc * Mr;
const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
const int64_t m_size = m_end - m_start;
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
const int64_t n_start = nc * Nr;
const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
const int64_t n_size = n_end - n_start;
// NB: assume we pad N, nc_block_end won't exceed padded N here.
const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
// NOTE(review): both buffers are allocated above and never reset, so these
// re-allocation checks appear never to fire in this kernel.
if (_local_acc_buf_0 == nullptr) { _local_acc_buf_0 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_0 = _local_acc_buf_0.get(); }
if (_local_acc_buf_1 == nullptr) { _local_acc_buf_1 = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf_1 = _local_acc_buf_1.get(); }
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
int64_t k_start = kc * Kr;
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
for (int64_t nci = nc; nci < nc_block_end; nci++) {
// The first K chunk overwrites the accumulators (accum = false, tiles start
// at zero). Later chunks accumulate into them (accum = true).
if (kc == k_block_start) {
kernel_micro_gemm<static_cast<bool>(false)>(
amx_state,
&(X0[static_cast<int64_t>(k_start + 512L*m_start)]),
&(W0[static_cast<int64_t>(16L*k_start + 8192L*nci)]),
&(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(512L),
static_cast<int64_t>(16L),
static_cast<int64_t>(Nc_blocks*Nr)
);
kernel_micro_gemm<static_cast<bool>(false)>(
amx_state,
&(X0[static_cast<int64_t>(k_start + 512L*m_start)]),
&(W1[static_cast<int64_t>(16L*k_start + 8192L*nci)]),
&(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(512L),
static_cast<int64_t>(16L),
static_cast<int64_t>(Nc_blocks*Nr)
);
} else {
kernel_micro_gemm<static_cast<bool>(true)>(
amx_state,
&(X0[static_cast<int64_t>(k_start + 512L*m_start)]),
&(W0[static_cast<int64_t>(16L*k_start + 8192L*nci)]),
&(local_acc_buf_0[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(512L),
static_cast<int64_t>(16L),
static_cast<int64_t>(Nc_blocks*Nr)
);
kernel_micro_gemm<static_cast<bool>(true)>(
amx_state,
&(X0[static_cast<int64_t>(k_start + 512L*m_start)]),
&(W1[static_cast<int64_t>(16L*k_start + 8192L*nci)]),
&(local_acc_buf_1[static_cast<int64_t>(Nr*nci + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(512L),
static_cast<int64_t>(16L),
static_cast<int64_t>(Nc_blocks*Nr)
);
}
}
}
// Epilogue: convert this M x N block of both accumulators to bf16 and store it
// into Y0/Y1, 16 columns per vector. The second branch handles a column tail
// of fewer than 16 elements using element-count loads and stores.
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=static_cast<int64_t>(16L))
{
{
if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16));
auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(16));
auto tmp1 = at::vec::convert<bfloat16>(tmp0);
auto tmp3 = at::vec::convert<bfloat16>(tmp2);
tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(16));
tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(16));
}
if(C10_UNLIKELY(x1 >= static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L)))) && x1 < static_cast<int64_t>(n_end + ((-1L)*n_start))))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf_0 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf_1 + static_cast<int64_t>(x1 + Nc_blocks*Nr*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
auto tmp1 = at::vec::convert<bfloat16>(tmp0);
auto tmp3 = at::vec::convert<bfloat16>(tmp2);
tmp1.store(Y0 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
tmp3.store(Y1 + static_cast<int64_t>(n_start + x1 + 1024L*m_start + 1024L*x0), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(n_end + ((-1L)*n_start)), static_cast<int64_t>(16L))))));
}
}
}
}
}
}
}
}
amx_state.release([]() { _tile_release(); });
}
}
''')
# Wait for the kernels queued on async_compile (it is handed this module's
# globals(), presumably so compiled callables such as cpp_fused_0 get bound
# there -- TODO confirm against AsyncCompile.wait), then drop the compiler object.
async_compile.wait(globals())
del async_compile
def call(args):
    """Run ``cpp_fused_0`` on a (2, 2, 512) input and return two (2, 2, 1024) views.

    ``args`` must be a one-element list holding the input tensor, whose
    size/stride are checked to be ``(2, 2, 512)``/``(1024, 512, 1)``. The list is
    emptied so it no longer keeps the tensor alive.

    The kernel treats the input as a flat 4 x 512 matrix and fills two
    4 x 1024 bfloat16 buffers. Those buffers are returned re-viewed as
    (2, 2, 1024) through ``reinterpret_tensor``.
    """
    (activation,) = args
    args.clear()
    assert_size_stride(activation, (2, 2, 512), (1024, 512, 1))

    flat_size, flat_stride = (4, 1024), (1024, 1)
    out0 = empty_strided_cpu(flat_size, flat_stride, torch.bfloat16)
    out1 = empty_strided_cpu(flat_size, flat_stride, torch.bfloat16)
    cpp_fused_0(activation, constant2, constant3, out0, out1)

    # Release the local input reference before building the result views.
    del activation
    view_size, view_stride = (2, 2, 1024), (2048, 1024, 1)
    return (
        reinterpret_tensor(out0, view_size, view_stride, 0),
        reinterpret_tensor(out1, view_size, view_stride, 0),
    )
def benchmark_compiled_module(times=10, repeat=10):
    """Fill the module-level weights with random tensors and time ``call``.

    Tensors are created in a fixed order: the two frozen params, constant2,
    constant3, then the activation. ``times`` and ``repeat`` are forwarded to
    ``print_performance``, whose result is returned.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    global _frozen_param4, _frozen_param5, constant2, constant3
    bf16 = torch.bfloat16
    _frozen_param4 = rand_strided((1024, 512), (1, 0), device='cpu', dtype=bf16)
    _frozen_param5 = rand_strided((1024, 512), (1, 0), device='cpu', dtype=bf16)
    constant2 = rand_strided((64, 512, 16), (8192, 16, 1), device='cpu', dtype=bf16)
    constant3 = rand_strided((64, 512, 16), (8192, 16, 1), device='cpu', dtype=bf16)
    activation = rand_strided((2, 2, 512), (1024, 512, 1), device='cpu', dtype=bf16)

    def fn():
        # call() clears the list it receives, so build a fresh one each run.
        return call([activation])

    return print_performance(fn, times=times, repeat=repeat)
# Script entry point: hands benchmark_compiled_module to Inductor's
# compiled_module_main (see torch._inductor.wrapper_benchmark); the 'None'
# argument is passed through unchanged.
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main('None', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment