@BreadFish64
Created March 30, 2023 02:17
Matrix multiplication
#include "matrix.hpp"
#include <immintrin.h>
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstring>
#include <future>
#include <memory>
#include <new>
#include <thread>
#include <vector>
// Anonymous namespace for internal linkage
namespace {
// For generating masks and gather indices
#define INC_VECTOR_256_EPI32 _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0)
#define INC_VECTOR_512_EPI32 _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
// Makes a mask vector used by AVX2 for the limited number of instructions that support it
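// e.g. MaskFirst256EPI32(3) yields all-ones in the first three 32-bit lanes and zeros elsewhere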
[[gnu::always_inline, gnu::target("avx2,fma")]] inline auto MaskFirst256EPI32(int n) noexcept {
const auto n8 = _mm256_set1_epi32(n);
return _mm256_cmpgt_epi32(n8, INC_VECTOR_256_EPI32);
}
// _mm256_mask_i32gather_ps wants the mask as a float vector
// But the bit representation is the same either way
[[gnu::always_inline, gnu::target("avx2,fma")]] inline auto SIasPS(__m256i x) noexcept {
__m256 y;
std::memcpy(&y, &x, sizeof(y));
return y;
}
// Masks have their own registers and are allowed for pretty much everything in AVX512
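// e.g. MaskFirstMask16(5) == 0x001F, selecting the first five lanes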
[[gnu::always_inline, gnu::target("avx512f,tune=cascadelake")]] inline auto MaskFirstMask16(
int n) noexcept {
return _cvtu32_mask16((1u << n) - 1u);
}
// Didn't feel like supporting pre-Haswell processors
[[gnu::target("default")]] inline void TransposeBlock(
[[maybe_unused]] const Matrix<float* __restrict>& dst,
[[maybe_unused]] const Matrix<const float*>& src) noexcept {
assert(false);
}
// Takes several columns from the B matrix and transposes them so they can be accessed more efficiently
[[gnu::target("avx2,fma")]] inline void TransposeBlock(const Matrix<float* __restrict>& dst,
const Matrix<const float*>& src) noexcept {
ASSUME(dst.width == src.height);
ASSUME(src.width == dst.height);
const size_t width = dst.width;
const size_t height = dst.height;
ASSUME(height <= CACHE_LINE_FLOATS);
const size_t full_width = width / 8;
const size_t leftover = width % 8;
const auto gather_offsets =
_mm256_mullo_epi32(INC_VECTOR_256_EPI32, _mm256_set1_epi32(src.stride));
size_t x_tile = 0;
for (; x_tile < full_width; ++x_tile) {
for (size_t y = 0; y < height; ++y) {
// Gather loads each element from a separate offset
// So we can load columns by using multiple-of-stride offsets
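// Lane l reads src[x_tile * 8 + l][y], so eight consecutive rows of column y land in one
// vector and become eight consecutive elements of row y in the transposed dst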
auto v = _mm256_i32gather_ps(&src[x_tile * 8][y], gather_offsets, sizeof(float));
_mm256_store_ps(&dst[y][x_tile * 8], v);
}
}
if (leftover) {
const auto mask = MaskFirst256EPI32(leftover);
const auto fmask = SIasPS(mask);
for (size_t y = 0; y < height; ++y) {
const auto v = _mm256_mask_i32gather_ps(_mm256_setzero_ps(), &src[x_tile * 8][y],
gather_offsets, fmask, sizeof(float));
// Just assume that we have enough padding to not mask the stores
// Since the memory allocation for the transposed matrix was done internally
_mm256_store_ps(&dst[y][x_tile * 8], v);
}
}
}
[[gnu::target("avx512f")]] inline void TransposeBlock(
const Matrix<float* __restrict>& dst, const Matrix<const float*>& src) noexcept {
ASSUME(dst.width == src.height);
ASSUME(src.width == dst.height);
const size_t width = dst.width;
const size_t height = dst.height;
ASSUME(height <= CACHE_LINE_FLOATS);
const size_t full_width = width / 16;
const size_t leftover = width % 16;
const auto gather_offsets =
_mm512_mullo_epi32(INC_VECTOR_512_EPI32, _mm512_set1_epi32(src.stride));
size_t x_tile = 0;
for (; x_tile < full_width; ++x_tile) {
for (size_t y = 0; y < height; ++y) {
auto v = _mm512_i32gather_ps(gather_offsets, &src[x_tile * 16][y], sizeof(float));
_mm512_store_ps(&dst[y][x_tile * 16], v);
}
}
if (leftover) {
const auto mask = MaskFirstMask16(leftover);
for (size_t y = 0; y < height; ++y) {
const auto v = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, gather_offsets,
&src[x_tile * 16][y], sizeof(float));
_mm512_store_ps(&dst[y][x_tile * 16], v);
}
}
}
[[gnu::target("avx2,fma")]] inline void AccumulateChunk(const Matrix<float* __restrict>& dst,
const Matrix<const __m256*>& src) noexcept {
ASSUME(dst.width == src.width);
ASSUME(dst.height == src.height);
ASSUME(dst.width <= 8);
ASSUME(dst.height == 1);
const size_t width = dst.width;
auto s0 = src[0][0];
auto s1 = src[0][1];
auto s2 = src[0][2];
auto s3 = src[0][3];
auto s4 = src[0][4];
auto s5 = src[0][5];
auto s6 = src[0][6];
auto s7 = src[0][7];
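// Each s_k holds 8 partial products whose lanes still need to be summed horizontally.
// _mm256_hadd_ps adds adjacent pairs within each 128-bit half, so after two rounds
// s0 = {low-half sums of s0..s3 | high-half sums of s0..s3} and
// s1 = {low-half sums of s4..s7 | high-half sums of s4..s7}.
// Swapping 128-bit halves between them and adding puts the full sum of s_k into lane k.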
s0 = _mm256_hadd_ps(s0, s1);
s1 = _mm256_hadd_ps(s2, s3);
s2 = _mm256_hadd_ps(s4, s5);
s3 = _mm256_hadd_ps(s6, s7);
s0 = _mm256_hadd_ps(s0, s1);
s1 = _mm256_hadd_ps(s2, s3);
auto t0 = _mm256_extractf128_ps(s0, 1);
auto t1 = _mm256_extractf128_ps(s1, 0);
s0 = _mm256_insertf128_ps(s0, t1, 1);
s1 = _mm256_insertf128_ps(s1, t0, 0);
s0 = _mm256_add_ps(s0, s1);
const auto mask = MaskFirst256EPI32(width);
_mm256_maskstore_ps(dst.data, mask, s0);
}
[[gnu::target("avx512f")]] inline void AccumulateChunk(
const Matrix<float* __restrict>& dst, const Matrix<const __m512*>& src) noexcept {
ASSUME(dst.width == src.width);
ASSUME(dst.height == src.height);
ASSUME(dst.width <= 16);
ASSUME(dst.height == 1);
const size_t width = dst.width;
// There's no 512-bit hadd, and extracting a 256-bit float vector from a __m512 is a pain
const __m256* const src_ptr = reinterpret_cast<const __m256*>(src.data);
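// src_ptr[2 * k] and src_ptr[2 * k + 1] are the low and high 256-bit halves of accumulator k;
// adding them folds each __m512 into a __m256 with the same horizontal sum, so the 256-bit
// hadd network below can finish the reduction the same way the AVX2 path does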
auto s0 = _mm256_add_ps(src_ptr[0], src_ptr[1]);
auto s1 = _mm256_add_ps(src_ptr[2], src_ptr[3]);
auto s2 = _mm256_add_ps(src_ptr[4], src_ptr[5]);
auto s3 = _mm256_add_ps(src_ptr[6], src_ptr[7]);
auto s4 = _mm256_add_ps(src_ptr[8], src_ptr[9]);
auto s5 = _mm256_add_ps(src_ptr[10], src_ptr[11]);
auto s6 = _mm256_add_ps(src_ptr[12], src_ptr[13]);
auto s7 = _mm256_add_ps(src_ptr[14], src_ptr[15]);
auto s8 = _mm256_add_ps(src_ptr[16], src_ptr[17]);
auto s9 = _mm256_add_ps(src_ptr[18], src_ptr[19]);
auto s10 = _mm256_add_ps(src_ptr[20], src_ptr[21]);
auto s11 = _mm256_add_ps(src_ptr[22], src_ptr[23]);
auto s12 = _mm256_add_ps(src_ptr[24], src_ptr[25]);
auto s13 = _mm256_add_ps(src_ptr[26], src_ptr[27]);
auto s14 = _mm256_add_ps(src_ptr[28], src_ptr[29]);
auto s15 = _mm256_add_ps(src_ptr[30], src_ptr[31]);
s0 = _mm256_hadd_ps(s0, s1);
s1 = _mm256_hadd_ps(s2, s3);
s2 = _mm256_hadd_ps(s4, s5);
s3 = _mm256_hadd_ps(s6, s7);
s4 = _mm256_hadd_ps(s8, s9);
s5 = _mm256_hadd_ps(s10, s11);
s6 = _mm256_hadd_ps(s12, s13);
s7 = _mm256_hadd_ps(s14, s15);
s0 = _mm256_hadd_ps(s0, s1);
s1 = _mm256_hadd_ps(s2, s3);
s2 = _mm256_hadd_ps(s4, s5);
s3 = _mm256_hadd_ps(s6, s7);
__m512 t0 = _mm512_setzero_ps();
t0 = _mm512_insertf32x4(t0, _mm256_extractf128_ps(s0, 0), 0);
t0 = _mm512_insertf32x4(t0, _mm256_extractf128_ps(s1, 0), 1);
t0 = _mm512_insertf32x4(t0, _mm256_extractf128_ps(s2, 0), 2);
t0 = _mm512_insertf32x4(t0, _mm256_extractf128_ps(s3, 0), 3);
__m512 t1 = _mm512_setzero_ps();
t1 = _mm512_insertf32x4(t1, _mm256_extractf128_ps(s0, 1), 0);
t1 = _mm512_insertf32x4(t1, _mm256_extractf128_ps(s1, 1), 1);
t1 = _mm512_insertf32x4(t1, _mm256_extractf128_ps(s2, 1), 2);
t1 = _mm512_insertf32x4(t1, _mm256_extractf128_ps(s3, 1), 3);
auto result = _mm512_add_ps(t0, t1);
const auto mask = MaskFirstMask16(width);
_mm512_mask_storeu_ps(dst.data, mask, result);
}
// Reduce all of the temporary accumulator vectors to single floats for the result
template <typename VectorT>
inline void Accumulate(const Matrix<float* __restrict>& dst,
const Matrix<const VectorT*>& src) noexcept {
constexpr size_t VECTOR_SIZE = sizeof(VectorT) / sizeof(float);
ASSUME(dst.width == src.width);
ASSUME(dst.height == src.height);
ASSUME(dst.width <= CACHE_LINE_FLOATS);
ASSUME(dst.height <= CACHE_LINE_FLOATS);
const size_t height = dst.height;
const size_t width = dst.width;
const size_t subrows = CeilDiv(width, VECTOR_SIZE);
for (size_t y = 0; y < height; ++y) {
for (size_t s = 0; s < subrows; ++s) {
AccumulateChunk(dst.subMat(y, s * VECTOR_SIZE, 1, VECTOR_SIZE),
src.subMat(y, s * VECTOR_SIZE, 1, VECTOR_SIZE));
}
}
}
template [[gnu::target("avx2,fma")]] void Accumulate<__m256>(
const Matrix<float* __restrict>& dst, const Matrix<const __m256*>& src) noexcept;
template [[gnu::target("avx512f,tune=cascadelake")]] void Accumulate<__m512>(
const Matrix<float* __restrict>& dst, const Matrix<const __m512*>& src) noexcept;
// Implements C[i][k] += A[i][j] * B[j][k];
// But for a VECTOR_SIZE x VECTOR_SIZE block
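// Here B arrives pre-transposed and C holds un-reduced vectors, so per lane j this computes
// C[i][k].lane[j] += A[i][j] * B[k][j]; the horizontal sum over j happens later in Accumulate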
[[gnu::target("avx2,fma")]] inline void MultiplyMatrixSubtileIteration(
const Matrix<__m256* __restrict>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B) noexcept {
const auto [m, n, p] = ValidateTransposedMultiplication(C, A, B);
ASSUME(m <= 8);
ASSUME(n <= 8);
ASSUME(p <= 8);
ASSUME(C.stride == 8);
// Load 8x8 floats from B into YMM
// Compilers aren't good at doing this with arrays
#define EXPANDED(k) __m256 b##k;
#include "common/expand8.inl"
#undef EXPANDED
// Don't mask since we took care of zeroing the extras during the transpose
if (__builtin_expect(p == 8, true)) {
#define EXPANDED(k) b##k = _mm256_load_ps(B[k]);
#include "common/expand8.inl"
#undef EXPANDED
} else {
#define EXPANDED(k) \
if (k < p) { b##k = _mm256_load_ps(B[k]); }
#include "common/expand8.inl"
#undef EXPANDED
}
auto mask = MaskFirst256EPI32(n);
for (size_t i = 0; i < m; ++i) {
__m256 a = _mm256_maskload_ps(A[i], mask);
__m256* __restrict c_ptr = C[i];
if (__builtin_expect(p == 8, true)) {
#define EXPANDED(k) \
*c_ptr = _mm256_fmadd_ps(a, b##k, *c_ptr); \
c_ptr++;
#include "common/expand8.inl"
#undef EXPANDED
} else {
// I experimented with unrolling using Duff's device, but the compiler doesn't generate
// the code you would expect. Indirect branch prediction is worse anyway.
#define EXPANDED(k) \
if (k < p) { \
*c_ptr = _mm256_fmadd_ps(a, b##k, *c_ptr); \
c_ptr++; \
}
#include "common/expand8.inl"
#undef EXPANDED
}
}
}
[[gnu::target("avx512f,tune=cascadelake")]] inline void MultiplyMatrixSubtileIteration(
const Matrix<__m512* __restrict>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B) noexcept {
const auto [m, n, p] = ValidateTransposedMultiplication(C, A, B);
ASSUME(m <= 16);
ASSUME(n <= 16);
ASSUME(p <= 16);
ASSUME(C.stride == 16);
// Load 16x16 floats from B into ZMM
// Compilers aren't good at doing this with arrays
#define EXPANDED(k) __m512 b##k;
#include "common/expand16.inl"
#undef EXPANDED
if (__builtin_expect(p == 16, true)) {
#define EXPANDED(k) b##k = _mm512_load_ps(B[k]);
#include "common/expand16.inl"
#undef EXPANDED
} else {
#define EXPANDED(k) \
if (k < p) { b##k = _mm512_load_ps(B[k]); }
#include "common/expand16.inl"
#undef EXPANDED
}
auto mask = MaskFirstMask16(n);
for (size_t i = 0; i < m; ++i) {
__m512 a = _mm512_mask_loadu_ps(_mm512_setzero_ps(), mask, A[i]);
__m512* __restrict c_ptr = C[i];
if (__builtin_expect(p == 16, true)) {
#define EXPANDED(k) \
*c_ptr = _mm512_fmadd_ps(a, b##k, *c_ptr); \
c_ptr++;
#include "common/expand16.inl"
#undef EXPANDED
} else {
#define EXPANDED(k) \
if (k < p) { \
*c_ptr = _mm512_fmadd_ps(a, b##k, *c_ptr); \
c_ptr++; \
}
#include "common/expand16.inl"
#undef EXPANDED
}
}
}
// Tiled version of
// for (size_t j = 0; j < n; ++j) { sum += A[i][j] * B[j][k]; }
template <typename VectorT>
[[gnu::flatten]] inline void MultiplyMatrixSubtile(const Matrix<VectorT* __restrict>& C,
const Matrix<const float*>& A,
const Matrix<const float*>& B) noexcept {
constexpr size_t VECTOR_SIZE = sizeof(VectorT) / sizeof(float);
const auto [m, n, p] = ValidateTransposedMultiplication(C, A, B);
ASSUME(m <= VECTOR_SIZE);
ASSUME(p <= VECTOR_SIZE);
ASSUME(C.stride == VECTOR_SIZE);
size_t n_subtiles = CeilDiv<size_t>(n, VECTOR_SIZE);
for (size_t j_subtile = 0; j_subtile < n_subtiles; ++j_subtile) {
MultiplyMatrixSubtileIteration(C,
A.subMat(0, j_subtile * VECTOR_SIZE, A.height, VECTOR_SIZE),
B.subMat(0, j_subtile * VECTOR_SIZE, B.height, VECTOR_SIZE));
}
}
template [[gnu::target("avx2,fma")]] void MultiplyMatrixSubtile<__m256>(
const Matrix<__m256* __restrict>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B) noexcept;
template [[gnu::target("avx512f,tune=cascadelake")]] void MultiplyMatrixSubtile<__m512>(
const Matrix<__m512* __restrict>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B) noexcept;
// Tiled version of the whole matrix multiplication algorithm
// Splits cache-line-sized tiles into vector-sized tiles
template <typename VectorT>
inline void ImplMultiplyMatrixDispatchN(const Matrix<float* __restrict>& C,
const Matrix<const float*>& A,
const Matrix<const float*>& B) noexcept {
constexpr size_t VECTOR_SIZE = sizeof(VectorT) / sizeof(float);
const auto [m, n, p] = ValidateTransposedMultiplication(C, A, B);
ASSUME(m <= CACHE_LINE_FLOATS);
ASSUME(p <= CACHE_LINE_FLOATS);
// Reducing the vectors after every iteration is massively inefficient
// So store the whole vectors until the accumulation is finished
// then reduce them and store in C.
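// accumulation_scratch[i * VECTOR_SIZE + k] holds the per-lane partial sums that will be
// reduced into output element C[i][k] of the current subtile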
// Align to page boundary
alignas(4096) thread_local VectorT accumulation_scratch[VECTOR_SIZE * VECTOR_SIZE];
const size_t p_subtiles = CeilDiv<size_t>(p, VECTOR_SIZE);
const size_t m_subtiles = CeilDiv<size_t>(m, VECTOR_SIZE);
for (size_t k_subtile = 0; k_subtile < p_subtiles; ++k_subtile) {
auto sub_b = B.subMat(k_subtile * VECTOR_SIZE, 0, VECTOR_SIZE, B.width);
for (size_t i_subtile = 0; i_subtile < m_subtiles; ++i_subtile) {
// Zero out scratch between runs
for (VectorT& vec : accumulation_scratch) { vec = VectorT{}; }
auto sub_a = A.subMat(i_subtile * VECTOR_SIZE, 0, VECTOR_SIZE, A.width);
MultiplyMatrixSubtile(
Matrix<VectorT* __restrict>{
accumulation_scratch,
sub_a.height,
sub_b.height,
VECTOR_SIZE,
},
sub_a, sub_b);
// Accumulate
auto sub_c = C.subMat(i_subtile * VECTOR_SIZE, k_subtile * VECTOR_SIZE, VECTOR_SIZE,
VECTOR_SIZE);
Accumulate(sub_c, Matrix<const VectorT*>{
accumulation_scratch,
sub_a.height,
sub_b.height,
VECTOR_SIZE,
});
}
}
}
template [[gnu::target("avx2,fma")]] void ImplMultiplyMatrixDispatchN<__m256>(
const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B) noexcept;
template [[gnu::target("avx512f,tune=cascadelake")]] void ImplMultiplyMatrixDispatchN<__m512>(
const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B) noexcept;
[[gnu::target("default")]] inline void MultiplyMatrixDispatchN(
[[maybe_unused]] const Matrix<float* __restrict>& C,
[[maybe_unused]] const Matrix<const float*>& A,
[[maybe_unused]] const Matrix<const float*>& B) noexcept {
assert(false);
}
[[gnu::target("avx2,fma")]] inline void MultiplyMatrixDispatchN(
const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B) noexcept {
ImplMultiplyMatrixDispatchN<__m256>(C, A, B);
}
[[gnu::target("avx512f,tune=cascadelake")]] inline void MultiplyMatrixDispatchN(
const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B) noexcept {
ImplMultiplyMatrixDispatchN<__m512>(C, A, B);
}
inline void MultiplyMatrixDispatchM(const Matrix<float* __restrict>& C,
const Matrix<const float*>& A, const Matrix<const float*>& B,
float* __restrict transposed_B_storage) noexcept {
const auto [m, n, p] = ValidateMultiplication(C, A, B);
ASSUME(p <= CACHE_LINE_FLOATS);
// gather instructions aren't very efficient so we don't want to transpose the same data more
// than once, which is why the matrix is split into columns first.
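// The caller hands us one CACHE_LINE_FLOATS-wide column panel of B; transposing it once here
// lets every row tile of A below reuse the transposed copy with contiguous aligned loads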
TransposeBlock(
Matrix<float* __restrict>{
transposed_B_storage,
B.width,
B.height,
AlignUp(B.height, CACHE_LINE_FLOATS),
},
B);
Matrix<const float*> transposed_B{
transposed_B_storage,
B.width,
B.height,
AlignUp(B.height, CACHE_LINE_FLOATS),
};
const size_t m_tiles = CeilDiv(m, CACHE_LINE_FLOATS);
for (size_t i_tile = 0; i_tile < m_tiles; ++i_tile) {
MultiplyMatrixDispatchN(
C.subMat(i_tile * CACHE_LINE_FLOATS, 0, CACHE_LINE_FLOATS, CACHE_LINE_FLOATS),
A.subMat(i_tile * CACHE_LINE_FLOATS, 0, CACHE_LINE_FLOATS, A.width), transposed_B);
}
}
} // namespace
// Hand-optimized matrix multiplication that takes heavy advantage of
// SIMD and cache locality to be over 100x faster than the naive algorithm.
//
// The tiling alone produces a 4.5x speedup over applying SIMD to each individual row and column.
// This was mostly tested on an Alder Lake CPU with AVX2.
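// Rough structure of the tiling:
//   MultiplyMatricesTiled          - splits C/B into CACHE_LINE_FLOATS-wide column panels,
//                                    optionally one worker thread per panel (up to the core count)
//   MultiplyMatrixDispatchM        - transposes the B panel once, splits A/C into row tiles
//   MultiplyMatrixDispatchN        - runtime-dispatched AVX2/AVX-512 path, splits each
//                                    cache-line tile into VECTOR_SIZE subtiles
//   MultiplyMatrixSubtile          - walks the shared (j) dimension of a subtile pair
//   MultiplyMatrixSubtileIteration - FMA kernel; Accumulate then reduces the vector sums into C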
void MultiplyMatricesTiled(const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B, bool multithreaded) noexcept {
const auto [m, n, p] = ValidateMultiplication(C, A, B);
const size_t p_tiles = CeilDiv(p, CACHE_LINE_FLOATS);
if (!multithreaded) {
std::unique_ptr<float[]> transposed_B_storage{new (std::align_val_t{
CACHE_LINE_SIZE}) float[AlignUp(B.height, CACHE_LINE_FLOATS) * CACHE_LINE_FLOATS]};
for (size_t k_tile = 0; k_tile < p_tiles; ++k_tile) {
MultiplyMatrixDispatchM(
C.subMat(0, k_tile * CACHE_LINE_FLOATS, C.height, CACHE_LINE_FLOATS), A,
B.subMat(0, k_tile * CACHE_LINE_FLOATS, B.height, CACHE_LINE_FLOATS),
transposed_B_storage.get());
}
return;
}
std::atomic<size_t> global_k_tile{0};
const auto Work = [&] {
std::unique_ptr<float[]> transposed_B_storage{new (std::align_val_t{
CACHE_LINE_SIZE}) float[AlignUp(B.height, CACHE_LINE_FLOATS) * CACHE_LINE_FLOATS]};
while (true) {
const size_t k_tile = global_k_tile.fetch_add(1);
if (k_tile >= p_tiles) { break; }
MultiplyMatrixDispatchM(
C.subMat(0, k_tile * CACHE_LINE_FLOATS, C.height, CACHE_LINE_FLOATS), A,
B.subMat(0, k_tile * CACHE_LINE_FLOATS, B.height, CACHE_LINE_FLOATS),
transposed_B_storage.get());
}
};
// If there are multiple columns of tiles to process then start a thread for each column
// up to the number of cores (minus one, since the calling thread also does work)
std::vector<std::future<void>> workers(
std::min<size_t>(p_tiles, std::max<size_t>(std::thread::hardware_concurrency(), 1)) - 1);
for (auto& worker : workers) { worker = std::async(std::launch::async, Work); }
Work();
for (const auto& worker : workers) { worker.wait(); }
}
void MultiplyMatricesNaive(const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B) noexcept {
const auto [m, n, p] = ValidateMultiplication(C, A, B);
for (size_t i = 0; i < m; ++i) {
for (size_t k = 0; k < p; ++k) {
float sum = 0.0f;
for (size_t j = 0; j < n; ++j) { sum += A[i][j] * B[j][k]; }
C[i][k] = sum;
}
}
}
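
// matrix.hpp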
#pragma once
#include "common/utilities.hpp"
#include <type_traits>
template <typename T>
struct Matrix {
T data;
size_t height;
size_t width;
size_t stride;
[[gnu::always_inline]] Matrix(T data, size_t height, size_t width, size_t stride)
: data{data}, height{height}, width{width}, stride{stride} {}
[[gnu::always_inline]] Matrix(T data, size_t height, size_t width)
: Matrix{data, height, width, width} {}
[[gnu::always_inline]] constexpr size_t width_bytes() const { return width * sizeof(*data); }
[[gnu::always_inline]] constexpr size_t stride_bytes() const { return stride * sizeof(*data); }
[[gnu::always_inline]] constexpr T operator[](size_t y) const { return data + y * stride; }
[[gnu::always_inline]] constexpr Matrix<T> subMat(size_t y, size_t x, size_t h,
size_t w) const {
ASSUME(x <= width);
ASSUME(y <= height);
return Matrix<T>{(*this)[y] + x, min2(y + h, height) - y, min2(x + w, width) - x, stride};
}
};
struct ValidateMultiplicationRet {
size_t m;
size_t n;
size_t p;
};
[[gnu::always_inline]] inline ValidateMultiplicationRet ValidateMultiplication(
const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B) {
ASSUME(A.height == C.height);
ASSUME(A.width == B.height);
ASSUME(B.width == C.width);
return ValidateMultiplicationRet{C.height, A.width, C.width};
}
template <typename T>
[[gnu::always_inline]] inline ValidateMultiplicationRet ValidateTransposedMultiplication(
const Matrix<T>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B) {
ASSUME(A.height == C.height);
ASSUME(A.width == B.width);
ASSUME(B.height == C.width);
return ValidateMultiplicationRet{C.height, A.width, C.width};
}
void MultiplyMatricesNaive(const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B) noexcept;
void MultiplyMatricesTiled(const Matrix<float* __restrict>& C, const Matrix<const float*>& A,
const Matrix<const float*>& B, bool multithreaded) noexcept;
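
// A minimal usage sketch, assuming plain row-major float buffers; Example() and the
// dimensions below are illustrative only. The tiled path allocates its own aligned
// scratch internally, so no special alignment or padding is required of A, B, or C.
//
//   #include "matrix.hpp"
//   #include <vector>
//
//   void Example() {
//       constexpr size_t m = 300, n = 500, p = 400;
//       std::vector<float> a(m * n), b(n * p), c(m * p);
//       // ... fill a and b ...
//       const Matrix<const float*> A{a.data(), m, n};
//       const Matrix<const float*> B{b.data(), n, p};
//       const Matrix<float* __restrict> C{c.data(), m, p};
//       MultiplyMatricesTiled(C, A, B, /*multithreaded=*/true);
//   }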