Matrix multiplication
#include "matrix.hpp" | |
#include <immintrin.h> | |
#include <cstring> | |
#include <future> | |
#include <vector> | |
// Anonymous namespace for internal linkage
namespace {
// For generating masks and gather indices
#define INC_VECTOR_256_EPI32 _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0)
#define INC_VECTOR_512_EPI32 _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
// Makes a mask vector used by AVX2 for the limited number of instructions that support it
[[gnu::always_inline, gnu::target("avx2,fma")]] inline auto MaskFirst256EPI32(int n) noexcept {
    const auto n8 = _mm256_set1_epi32(n);
    return _mm256_cmpgt_epi32(n8, INC_VECTOR_256_EPI32);
}
// _mm256_mask_i32gather_ps wants the mask as a float vector
// But the bit representation is the same either way
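// (the memcpy should compile away; it performs the same reinterpretation as _mm256_castsi256_ps)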
[[gnu::always_inline, gnu::target("avx2,fma")]] inline auto SIasPS(__m256i x) noexcept {
    __m256 y;
    std::memcpy(&y, &x, sizeof(y));
    return y;
}
// Masks have their own registers and are allowed for pretty much everything in AVX512
[[gnu::always_inline, gnu::target("avx512f,tune=cascadelake")]] inline auto MaskFirstMask16(
    int n) noexcept {
    return _cvtu32_mask16((1u << n) - 1u);
}
// Didn't feel like supporting pre-Haswell processors
[[gnu::target("default")]] inline void TransposeBlock( | |
[[maybe_unused]] const Matrix<float* __restrict>& dst, | |
[[maybe_unused]] const Matrix<const float*>& src) noexcept { | |
assert(false); | |
} | |
// Takes several columns from the B matrix and transposes them so they can be accessed more efficiently
[[gnu::target("avx2,fma")]] inline void TransposeBlock(const Matrix<float* __restrict>& dst, | |
const Matrix<const float*>& src) noexcept { | |
ASSUME(dst.width == src.height); | |
ASSUME(src.width == dst.height); | |
const size_t width = dst.width; | |
const size_t height = dst.height; | |
ASSUME(height <= CACHE_LINE_FLOATS); | |
const size_t full_width = width / 8; | |
const size_t leftover = width % 8; | |
const auto gather_offsets = | |
_mm256_mullo_epi32(INC_VECTOR_256_EPI32, _mm256_set1_epi32(src.stride)); | |
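    // Lane k's offset is k * stride elements, so a gather based at &src[row][y] loads
    // src[row + k][y] into lane k, i.e. one column of src per gather.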
    size_t x_tile = 0;
    for (; x_tile < full_width; ++x_tile) {
        for (size_t y = 0; y < height; ++y) {
            // Gather loads each element from a separate offset
            // So we can load columns by using multiple-of-stride offsets
            auto v = _mm256_i32gather_ps(&src[x_tile * 8][y], gather_offsets, sizeof(float));
            _mm256_store_ps(&dst[y][x_tile * 8], v);
        }
    }
    if (leftover) {
        const auto mask = MaskFirst256EPI32(leftover);
        const auto fmask = SIasPS(mask);
        for (size_t y = 0; y < height; ++y) {
            const auto v = _mm256_mask_i32gather_ps(_mm256_setzero_ps(), &src[x_tile * 8][y],
                                                    gather_offsets, fmask, sizeof(float));
            // Just assume that we have enough padding to not mask the stores
            // Since the memory allocation for the transposed matrix was done internally
            _mm256_store_ps(&dst[y][x_tile * 8], v);
        }
    }
}
[[gnu::target("avx512f")]] inline void TransposeBlock(
    const Matrix<float* __restrict>& dst, const Matrix<const float*>& src) noexcept {
    ASSUME(dst.width == src.height);
    ASSUME(src.width == dst.height);
    const size_t width = dst.width;
    const size_t height = dst.height;
    ASSUME(height <= CACHE_LINE_FLOATS);
    const size_t full_width = width / 16;
    const size_t leftover = width % 16;
    const auto gather_offsets =
        _mm512_mullo_epi32(INC_VECTOR_512_EPI32, _mm512_set1_epi32(src.stride));
    size_t x_tile = 0;
    for (; x_tile < full_width; ++x_tile) {
        for (size_t y = 0; y < height; ++y) {
            auto v = _mm512_i32gather_ps(gather_offsets, &src[x_tile * 16][y], sizeof(float));
            _mm512_store_ps(&dst[y][x_tile * 16], v);
        }
    }
    if (leftover) {
        const auto mask = MaskFirstMask16(leftover);
        for (size_t y = 0; y < height; ++y) {
            const auto v = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, gather_offsets,
                                                    &src[x_tile * 16][y], sizeof(float));
            _mm512_store_ps(&dst[y][x_tile * 16], v);
        }
    }
}
[[gnu::target("avx2,fma")]] inline void AccumulateChunk(const Matrix<float* __restrict>& dst,
                                                        const Matrix<const __m256*>& src) noexcept {
    ASSUME(dst.width == src.width);
    ASSUME(dst.height == src.height);
    ASSUME(dst.width <= 8);
    ASSUME(dst.height == 1);
    const size_t width = dst.width;
    auto s0 = src[0][0];
    auto s1 = src[0][1];
    auto s2 = src[0][2];
    auto s3 = src[0][3];
    auto s4 = src[0][4];
    auto s5 = src[0][5];
    auto s6 = src[0][6];
    auto s7 = src[0][7];
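    // Each s_k holds eight lane-wise partial sums for output element k. Two rounds of hadd
    // collapse them pairwise; afterwards the low 128-bit half of s0/s1 holds the low-half
    // sums of vectors 0-3 / 4-7 and the high half the high-half sums. The 128-bit swap below
    // lines these up so the final add leaves element k equal to the horizontal sum of s_k.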
    s0 = _mm256_hadd_ps(s0, s1);
    s1 = _mm256_hadd_ps(s2, s3);
    s2 = _mm256_hadd_ps(s4, s5);
    s3 = _mm256_hadd_ps(s6, s7);
    s0 = _mm256_hadd_ps(s0, s1);
    s1 = _mm256_hadd_ps(s2, s3);
    auto t0 = _mm256_extractf128_ps(s0, 1);
    auto t1 = _mm256_extractf128_ps(s1, 0);
    s0 = _mm256_insertf128_ps(s0, t1, 1);
    s1 = _mm256_insertf128_ps(s1, t0, 0);
    s0 = _mm256_add_ps(s0, s1);
    const auto mask = MaskFirst256EPI32(width);
    _mm256_maskstore_ps(dst.data, mask, s0);
}
[[gnu::target("avx512f")]] inline void AccumulateChunk(
    const Matrix<float* __restrict>& dst, const Matrix<const __m512*>& src) noexcept {
    ASSUME(dst.width == src.width);
    ASSUME(dst.height == src.height);
    ASSUME(dst.width <= 16);
    ASSUME(dst.height == 1);
    const size_t width = dst.width;
    // There's no 512-bit hadd, and extracting a 256-bit float vector from a __m512 is a pain
    const __m256* const src_ptr = reinterpret_cast<const __m256*>(src.data);
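    // src_ptr[2k] and src_ptr[2k + 1] are the low and high 256-bit halves of accumulator
    // vector k, so each add below starts the horizontal reduction of one input vector.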
    auto s0 = _mm256_add_ps(src_ptr[0], src_ptr[1]);
    auto s1 = _mm256_add_ps(src_ptr[2], src_ptr[3]);
    auto s2 = _mm256_add_ps(src_ptr[4], src_ptr[5]);
    auto s3 = _mm256_add_ps(src_ptr[6], src_ptr[7]);
    auto s4 = _mm256_add_ps(src_ptr[8], src_ptr[9]);
    auto s5 = _mm256_add_ps(src_ptr[10], src_ptr[11]);
    auto s6 = _mm256_add_ps(src_ptr[12], src_ptr[13]);
    auto s7 = _mm256_add_ps(src_ptr[14], src_ptr[15]);
    auto s8 = _mm256_add_ps(src_ptr[16], src_ptr[17]);
    auto s9 = _mm256_add_ps(src_ptr[18], src_ptr[19]);
    auto s10 = _mm256_add_ps(src_ptr[20], src_ptr[21]);
    auto s11 = _mm256_add_ps(src_ptr[22], src_ptr[23]);
    auto s12 = _mm256_add_ps(src_ptr[24], src_ptr[25]);
    auto s13 = _mm256_add_ps(src_ptr[26], src_ptr[27]);
    auto s14 = _mm256_add_ps(src_ptr[28], src_ptr[29]);
    auto s15 = _mm256_add_ps(src_ptr[30], src_ptr[31]);
    s0 = _mm256_hadd_ps(s0, s1);
    s1 = _mm256_hadd_ps(s2, s3);
    s2 = _mm256_hadd_ps(s4, s5);
    s3 = _mm256_hadd_ps(s6, s7);
    s4 = _mm256_hadd_ps(s8, s9);
    s5 = _mm256_hadd_ps(s10, s11);
    s6 = _mm256_hadd_ps(s12, s13);
    s7 = _mm256_hadd_ps(s14, s15);
    s0 = _mm256_hadd_ps(s0, s1);
    s1 = _mm256_hadd_ps(s2, s3);
    s2 = _mm256_hadd_ps(s4, s5);
    s3 = _mm256_hadd_ps(s6, s7);
    __m512 t0 = _mm512_setzero_ps();
    t0 = _mm512_insertf32x4(t0, _mm256_extractf128_ps(s0, 0), 0);
    t0 = _mm512_insertf32x4(t0, _mm256_extractf128_ps(s1, 0), 1);
    t0 = _mm512_insertf32x4(t0, _mm256_extractf128_ps(s2, 0), 2);
    t0 = _mm512_insertf32x4(t0, _mm256_extractf128_ps(s3, 0), 3);
    __m512 t1 = _mm512_setzero_ps();
    t1 = _mm512_insertf32x4(t1, _mm256_extractf128_ps(s0, 1), 0);
    t1 = _mm512_insertf32x4(t1, _mm256_extractf128_ps(s1, 1), 1);
    t1 = _mm512_insertf32x4(t1, _mm256_extractf128_ps(s2, 1), 2);
    t1 = _mm512_insertf32x4(t1, _mm256_extractf128_ps(s3, 1), 3);
    auto result = _mm512_add_ps(t0, t1);
    const auto mask = MaskFirstMask16(width);
    _mm512_mask_storeu_ps(dst.data, mask, result);
}
// Reduce all of the temporary accumulator vectors to single floats for the result
template <typename VectorT>
inline void Accumulate(const Matrix<float* __restrict>& dst,
                       const Matrix<const VectorT*>& src) noexcept {
    constexpr size_t VECTOR_SIZE = sizeof(VectorT) / sizeof(float);
    ASSUME(dst.width == src.width);
    ASSUME(dst.height == src.height);
    ASSUME(dst.width <= CACHE_LINE_FLOATS);
    ASSUME(dst.height <= CACHE_LINE_FLOATS);
    const size_t height = dst.height;
    const size_t width = dst.width;
    const size_t subrows = CeilDiv(width, VECTOR_SIZE);
    for (size_t y = 0; y < height; ++y) {
        for (size_t s = 0; s < subrows; ++s) {
            AccumulateChunk(dst.subMat(y, s * VECTOR_SIZE, 1, VECTOR_SIZE),
                            src.subMat(y, s * VECTOR_SIZE, 1, VECTOR_SIZE));
        }
    }
}
template [[gnu::target("avx2,fma")]] void Accumulate<__m256>(
    const Matrix<float* __restrict>& dst, const Matrix<const __m256*>& src) noexcept;
template [[gnu::target("avx512f,tune=cascadelake")]] void Accumulate<__m512>(
    const Matrix<float* __restrict>& dst, const Matrix<const __m512*>& src) noexcept;
// Implements C[i][k] += A[i][j] * B[j][k];
// But for a VECTOR_SIZE x VECTOR_SIZE block
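// B here is the already-transposed copy, so B[k] is a contiguous row holding column k of the
// original matrix; each FMA below accumulates lane-wise partial products of C[i][k].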
[[gnu::target("avx2,fma")]] inline void MultiplyMatrixSubtileIteration( | |
const Matrix<__m256* __restrict>& C, const Matrix<const float*>& A, | |
const Matrix<const float*>& B) noexcept { | |
const auto [m, n, p] = ValidateTransposedMultiplication(C, A, B); | |
ASSUME(m <= 8); | |
ASSUME(n <= 8); | |
ASSUME(p <= 8); | |
ASSUME(C.stride == 8); | |
// Load 8x8 floats from B into YMM | |
// Compilers aren't good at doing this with arrays; | |
#define EXPANDED(k) __m256 b##k; | |
#include "common/expand8.inl" | |
#undef EXPANDED | |
// Don't mask since we took care of zeroing the extras during the transpose | |
if (__builtin_expect(p, 8)) { | |
#define EXPANDED(k) b##k = _mm256_load_ps(B[k]); | |
#include "common/expand8.inl" | |
#undef EXPANDED | |
} else { | |
#define EXPANDED(k) \ | |
if (k < p) { b##k = _mm256_load_ps(B[k]); } | |
#include "common/expand8.inl" | |
#undef EXPANDED | |
} | |
auto mask = MaskFirst256EPI32(n); | |
for (size_t i = 0; i < m; ++i) { | |
__m256 a = _mm256_maskload_ps(A[i], mask); | |
__m256* __restrict c_ptr = C[i]; | |
if (__builtin_expect(p, 8)) { | |
#define EXPANDED(k) \ | |
*c_ptr = _mm256_fmadd_ps(a, b##k, *c_ptr); \ | |
c_ptr++; | |
#include "common/expand8.inl" | |
#undef EXPANDED | |
} else { | |
// I experimented with unrolling using a Duff's device but the compiler doesn't generate | |
// the code you would expect. Indirect branch prediction is worse anyway. | |
#define EXPANDED(k) \ | |
if (k < p) { \ | |
*c_ptr = _mm256_fmadd_ps(a, b##k, *c_ptr); \ | |
c_ptr++; \ | |
} | |
#include "common/expand8.inl" | |
#undef EXPANDED | |
} | |
} | |
} | |
[[gnu::target("avx512f,tune=cascadelake")]] inline void MultiplyMatrixSubtileIteration( | |
const Matrix<__m512* __restrict>& C, const Matrix<const float*>& A, | |
const Matrix<const float*>& B) noexcept { | |
const auto [m, n, p] = ValidateTransposedMultiplication(C, A, B); | |
ASSUME(m <= 16); | |
ASSUME(n <= 16); | |
ASSUME(p <= 16); | |
ASSUME(C.stride == 16); | |
// Load 16x16 floats from B into ZMM | |
// Compilers aren't good at doing this with arrays; | |
#define EXPANDED(k) __m512 b##k; | |
#include "common/expand16.inl" | |
#undef EXPANDED | |
if (__builtin_expect(p, 16)) { | |
#define EXPANDED(k) b##k = _mm512_load_ps(B[k]); | |
#include "common/expand16.inl" | |
#undef EXPANDED | |
} else { | |
#define EXPANDED(k) \ | |
if (k < p) { b##k = _mm512_load_ps(B[k]); } | |
#include "common/expand16.inl" | |
#undef EXPANDED | |
} | |
auto mask = MaskFirstMask16(n); | |
for (size_t i = 0; i < m; ++i) { | |
__m512 a = _mm512_mask_loadu_ps(_mm512_setzero_ps(), mask, A[i]); | |
__m512* __restrict c_ptr = C[i]; | |
if (__builtin_expect(p, 16)) { | |
#define EXPANDED(k) \ | |
*c_ptr = _mm512_fmadd_ps(a, b##k, *c_ptr); \ | |
c_ptr++; | |
#include "common/expand16.inl" | |
#undef EXPANDED | |
} else { | |
#define EXPANDED(k) \ | |
if (k < p) { \ | |
*c_ptr = _mm512_fmadd_ps(a, b##k, *c_ptr); \ | |
c_ptr++; \ | |
} | |
#include "common/expand16.inl" | |
#undef EXPANDED | |
} | |
} | |
} | |
// Tiled version of | |
// for (size_t j = 0; j < n; ++j) { sum += A[i][j] * B[j][k]; } | |
template <typename VectorT> | |
[[gnu::flatten]] inline void MultiplyMatrixSubtile(const Matrix<VectorT* __restrict>& C, | |
const Matrix<const float*>& A, | |
const Matrix<const float*>& B) noexcept { | |
constexpr size_t VECTOR_SIZE = sizeof(VectorT) / sizeof(float); | |
const auto [m, n, p] = ValidateTransposedMultiplication(C, A, B); | |
ASSUME(m <= VECTOR_SIZE); | |
ASSUME(p <= VECTOR_SIZE); | |
ASSUME(C.stride == VECTOR_SIZE); | |
size_t n_subtiles = CeilDiv<size_t>(n, VECTOR_SIZE); | |
for (size_t j_subtile = 0; j_subtile < n_subtiles; ++j_subtile) { | |
MultiplyMatrixSubtileIteration(C, | |
A.subMat(0, j_subtile * VECTOR_SIZE, A.height, VECTOR_SIZE), | |
B.subMat(0, j_subtile * VECTOR_SIZE, B.height, VECTOR_SIZE)); | |
} | |
} | |
template [[gnu::target("avx2,fma")]] void MultiplyMatrixSubtile<__m256>( | |
const Matrix<__m256* __restrict>& C, const Matrix<const float*>& A, | |
const Matrix<const float*>& B) noexcept; | |
template [[gnu::target("avx512f,tune=cascadelake")]] void MultiplyMatrixSubtile<__m512>( | |
const Matrix<__m512* __restrict>& C, const Matrix<const float*>& A, | |
const Matrix<const float*>& B) noexcept; | |
// Tiled version of whole matrix multiplication algorithm | |
// Split cache line sized tiles into vector sized tiles | |
template <typename VectorT> | |
inline void ImplMultiplyMatrixDispatchN(const Matrix<float* __restrict>& C, | |
const Matrix<const float*>& A, | |
const Matrix<const float*>& B) noexcept { | |
constexpr size_t VECTOR_SIZE = sizeof(VectorT) / sizeof(float); | |
const auto [m, n, p] = ValidateTransposedMultiplication(C, A, B); | |
ASSUME(m <= CACHE_LINE_FLOATS); | |
ASSUME(p <= CACHE_LINE_FLOATS); | |
// Reducing the vectors after every iteration is massively inefficient | |
// So store the whole vectors until the accumulation is finished | |
// then reduce them and store in C. | |
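    // scratch[i * VECTOR_SIZE + k] holds the lane-wise partial dot product for element (i, k)
    // of the current output tile; Accumulate() reduces the lanes afterwards.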
    // Align to page boundary
    alignas(4096) thread_local VectorT accumulation_scratch[VECTOR_SIZE * VECTOR_SIZE];
    const size_t p_subtiles = CeilDiv<size_t>(p, VECTOR_SIZE);
    const size_t m_subtiles = CeilDiv<size_t>(m, VECTOR_SIZE);
    for (size_t k_subtile = 0; k_subtile < p_subtiles; ++k_subtile) {
        auto sub_b = B.subMat(k_subtile * VECTOR_SIZE, 0, VECTOR_SIZE, B.width);
        for (size_t i_subtile = 0; i_subtile < m_subtiles; ++i_subtile) {
            // Zero out scratch between runs
            for (VectorT& vec : accumulation_scratch) { vec = VectorT{}; }
            auto sub_a = A.subMat(i_subtile * VECTOR_SIZE, 0, VECTOR_SIZE, A.width);
            MultiplyMatrixSubtile(
                Matrix<VectorT* __restrict>{
                    accumulation_scratch,
                    sub_a.height,
                    sub_b.height,
                    VECTOR_SIZE,
                },
                sub_a, sub_b);
            // Accumulate
            auto sub_c = C.subMat(i_subtile * VECTOR_SIZE, k_subtile * VECTOR_SIZE, VECTOR_SIZE,
                                  VECTOR_SIZE);
            Accumulate(sub_c, Matrix<const VectorT*>{
                                  accumulation_scratch,
                                  sub_a.height,
                                  sub_b.height,
                                  VECTOR_SIZE,
                              });
        }
    }
}
template [[gnu::target("avx2,fma")]] void ImplMultiplyMatrixDispatchN<__m256>(
    const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
    const Matrix<const float*>& B) noexcept;
template [[gnu::target("avx512f,tune=cascadelake")]] void ImplMultiplyMatrixDispatchN<__m512>(
    const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
    const Matrix<const float*>& B) noexcept;
[[gnu::target("default")]] inline void MultiplyMatrixDispatchN(
    [[maybe_unused]] const Matrix<float* __restrict>& C,
    [[maybe_unused]] const Matrix<const float*>& A,
    [[maybe_unused]] const Matrix<const float*>& B) noexcept {
    assert(false);
}
[[gnu::target("avx2,fma")]] inline void MultiplyMatrixDispatchN(
    const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
    const Matrix<const float*>& B) noexcept {
    ImplMultiplyMatrixDispatchN<__m256>(C, A, B);
}
[[gnu::target("avx512f,tune=cascadelake")]] inline void MultiplyMatrixDispatchN(
    const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
    const Matrix<const float*>& B) noexcept {
    ImplMultiplyMatrixDispatchN<__m512>(C, A, B);
}
inline void MultiplyMatrixDispatchM(const Matrix<float* __restrict>& C,
                                    const Matrix<const float*>& A, const Matrix<const float*>& B,
                                    float* __restrict transposed_B_storage) noexcept {
    const auto [m, n, p] = ValidateMultiplication(C, A, B);
    ASSUME(p <= CACHE_LINE_FLOATS);
    // gather instructions aren't very efficient so we don't want to transpose the same data more
    // than once, which is why the matrix is split into columns first.
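    // The transposed copy is laid out as B.width rows of B.height floats with the stride
    // rounded up to a whole cache line, so the unmasked stores above and the full-width
    // loads of B later stay inside the internally allocated buffer.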
    TransposeBlock(
        Matrix<float* __restrict>{
            transposed_B_storage,
            B.width,
            B.height,
            AlignUp(B.height, CACHE_LINE_FLOATS),
        },
        B);
    Matrix<const float*> transposed_B{
        transposed_B_storage,
        B.width,
        B.height,
        AlignUp(B.height, CACHE_LINE_FLOATS),
    };
    const size_t m_tiles = CeilDiv(m, CACHE_LINE_FLOATS);
    for (size_t i_tile = 0; i_tile < m_tiles; ++i_tile) {
        MultiplyMatrixDispatchN(
            C.subMat(i_tile * CACHE_LINE_FLOATS, 0, CACHE_LINE_FLOATS, CACHE_LINE_FLOATS),
            A.subMat(i_tile * CACHE_LINE_FLOATS, 0, CACHE_LINE_FLOATS, A.width), transposed_B);
    }
}
} // namespace
// Hand-optimized matrix multiplication that leans heavily on SIMD and cache locality
// to be over 100x faster than the naive algorithm.
//
// The tiling alone gives roughly a 4.5x speedup over applying SIMD to individual rows and
// columns. This was mostly tested on an Alder Lake CPU with AVX2.
void MultiplyMatricesTiled(const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
                           const Matrix<const float*>& B, bool multithreaded) noexcept {
    const auto [m, n, p] = ValidateMultiplication(C, A, B);
    const size_t p_tiles = CeilDiv(p, CACHE_LINE_FLOATS);
    if (!multithreaded) {
        std::unique_ptr<float[]> transposed_B_storage{new (std::align_val_t{
            CACHE_LINE_SIZE}) float[AlignUp(B.height, CACHE_LINE_FLOATS) * CACHE_LINE_FLOATS]};
        for (size_t k_tile = 0; k_tile < p_tiles; ++k_tile) {
            MultiplyMatrixDispatchM(
                C.subMat(0, k_tile * CACHE_LINE_FLOATS, C.height, CACHE_LINE_FLOATS), A,
                B.subMat(0, k_tile * CACHE_LINE_FLOATS, B.height, CACHE_LINE_FLOATS),
                transposed_B_storage.get());
        }
        return;
    }
    std::atomic<size_t> global_k_tile{0};
    const auto Work = [&] {
        std::unique_ptr<float[]> transposed_B_storage{new (std::align_val_t{
            CACHE_LINE_SIZE}) float[AlignUp(B.height, CACHE_LINE_FLOATS) * CACHE_LINE_FLOATS]};
        while (true) {
            const size_t k_tile = global_k_tile.fetch_add(1);
            if (k_tile >= p_tiles) { break; }
            MultiplyMatrixDispatchM(
                C.subMat(0, k_tile * CACHE_LINE_FLOATS, C.height, CACHE_LINE_FLOATS), A,
                B.subMat(0, k_tile * CACHE_LINE_FLOATS, B.height, CACHE_LINE_FLOATS),
                transposed_B_storage.get());
        }
    };
    // If there are multiple columns of tiles to process, start a thread for each column,
    // up to the number of cores
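    // (the count is one less than the core count because the calling thread also runs Work())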
    std::vector<std::future<void>> workers(
        std::min<size_t>(p_tiles, std::max<size_t>(std::thread::hardware_concurrency(), 1)) - 1);
    for (auto& worker : workers) { worker = std::async(std::launch::async, Work); }
    Work();
    for (const auto& worker : workers) { worker.wait(); }
}
void MultiplyMatricesNaive(const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
                           const Matrix<const float*>& B) noexcept {
    const auto [m, n, p] = ValidateMultiplication(C, A, B);
    for (size_t i = 0; i < m; ++i) {
        for (size_t k = 0; k < p; ++k) {
            float sum = 0.0f;
            for (size_t j = 0; j < n; ++j) { sum += A[i][j] * B[j][k]; }
            C[i][k] = sum;
        }
    }
}
matrix.hpp
#pragma once
#include "common/utilities.hpp"
#include <cstddef>
#include <type_traits>
template <typename T>
struct Matrix {
    T data;
    size_t height;
    size_t width;
    size_t stride;
    [[gnu::always_inline]] Matrix(T data, size_t height, size_t width, size_t stride)
        : data{data}, height{height}, width{width}, stride{stride} {}
    [[gnu::always_inline]] Matrix(T data, size_t height, size_t width)
        : Matrix{data, height, width, width} {}
    [[gnu::always_inline]] constexpr size_t width_bytes() const { return width * sizeof(*data); }
    [[gnu::always_inline]] constexpr size_t stride_bytes() const { return stride * sizeof(*data); }
    [[gnu::always_inline]] constexpr T operator[](size_t y) const { return data + y * stride; }
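    // Dimensions are clamped to the parent matrix, so edge tiles can be requested at full size.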
    [[gnu::always_inline]] constexpr Matrix<T> subMat(size_t y, size_t x, size_t h,
                                                      size_t w) const {
        ASSUME(x <= width);
        ASSUME(y <= height);
        return Matrix<T>{(*this)[y] + x, min2(y + h, height) - y, min2(x + w, width) - x, stride};
    }
};
struct ValidateMultiplicationRet {
    size_t m;
    size_t n;
    size_t p;
};
[[gnu::always_inline]] inline ValidateMultiplicationRet ValidateMultiplication(
    const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
    const Matrix<const float*>& B) {
    ASSUME(A.height == C.height);
    ASSUME(A.width == B.height);
    ASSUME(B.width == C.width);
    return ValidateMultiplicationRet{C.height, A.width, C.width};
}
template <typename T>
[[gnu::always_inline]] inline ValidateMultiplicationRet ValidateTransposedMultiplication(
    const Matrix<T>& C, const Matrix<const float*>& A, const Matrix<const float*>& B) {
    ASSUME(A.height == C.height);
    ASSUME(A.width == B.width);
    ASSUME(B.height == C.width);
    return ValidateMultiplicationRet{C.height, A.width, C.width};
}
void MultiplyMatricesNaive(const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
                           const Matrix<const float*>& B) noexcept;
void MultiplyMatricesTiled(const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
                           const Matrix<const float*>& B, bool multithreaded) noexcept;
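For reference, a minimal usage sketch against the declarations above. It is only illustrative and not part of the gist: the sizes and fill values are arbitrary, and it assumes plain row-major, densely packed buffers (stride == width), which is what the three-argument Matrix constructor produces.

#include "matrix.hpp"
#include <cstddef>
#include <vector>

int main() {
    const std::size_t m = 512, n = 384, p = 256;
    std::vector<float> a(m * n, 1.0f); // A is m x n
    std::vector<float> b(n * p, 2.0f); // B is n x p
    std::vector<float> c(m * p);       // C is m x p
    const Matrix<const float*> A{a.data(), m, n};
    const Matrix<const float*> B{b.data(), n, p};
    const Matrix<float* __restrict> C{c.data(), m, p};
    MultiplyMatricesTiled(C, A, B, /*multithreaded=*/true);
    // Every element of C should now equal 2 * n.
}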