Created
January 4, 2024 13:48
-
-
Save bjourne/c2d0db48b2e50aaadf884e4450c6aa50 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h> | |
#include <stdlib.h> | |
#include <time.h> | |
#include <xmmintrin.h> | |
#include <mmintrin.h> | |
#include "matrix.h" | |
#include "z_order.h" | |
using namespace nda; | |
// Keep each multiplication kernel out of line so it retains its own symbol,
// which makes its generated assembly easy to locate and read.
#define NOINLINE [[gnu::noinline]]
// A textbook implementation of matrix multiplication. This is very simple, | |
// but it is slow, primarily because of poor locality of the loads of B. The | |
// reduction loop is innermost. | |
template <typename T> | |
NOINLINE void multiply_reduce_cols(const_matrix_ref<T> A, const_matrix_ref<T> B, matrix_ref<T> C) { | |
for (index_t i : C.i()) { | |
for (index_t j : C.j()) { | |
C(i, j) = 0; | |
for (index_t k : A.j()) { | |
C(i, j) += A(i, k) * B(k, j); | |
} | |
} | |
} | |
} | |
// This is similar to the above, but: | |
// - It additionally splits the reduction dimension k, | |
// - It traverses the io, jo loops in z order, to improve locality, | |
// - It prefetches in the inner loop. | |
// This version achieves ~90% of the theoretical peak performance of my AMD Ryzen 5800X. | |
template <typename T> | |
NOINLINE void multiply_reduce_tiles_z_order(const_matrix_ref<T> A, const_matrix_ref<T> B, matrix_ref<T> C) { | |
// Adjust this depending on the target architecture. For AVX2, | |
// vectors are 256-bit. | |
constexpr index_t vector_size = 32 / sizeof(T); | |
constexpr index_t cache_line_size = 64 / sizeof(T); | |
// We want the tiles to be as big as possible without spilling any | |
// of the accumulator registers to the stack. | |
constexpr index_t tile_rows = 4; | |
constexpr index_t tile_cols = vector_size * 3; | |
constexpr index_t tile_k = 256; | |
// TODO: It seems like z-ordering all of io, jo, ko should be best... | |
// But this seems better, even without the added convenience for initializing | |
// the output. | |
for (auto ko : split(A.j(), tile_k)) { | |
auto split_i = split<tile_rows>(C.i()); | |
auto split_j = split<tile_cols>(C.j()); | |
for_all_in_z_order(std::make_tuple(split_i, split_j), [&](auto io, auto jo) { | |
// Make a reference to this tile of the output. | |
auto C_ijo = C(io, jo); | |
// Define an accumulator buffer. | |
T buffer[tile_rows * tile_cols] = {0}; | |
auto accumulator = make_array_ref(buffer, make_compact(C_ijo.shape())); | |
// Perform the matrix multiplication for this tile. | |
for (index_t k : ko) { | |
for (index_t i = 0; i < io.extent(); i += cache_line_size) { | |
_mm_prefetch(&A(io.min() + i, k + 8), _MM_HINT_T0); | |
} | |
for (index_t j = 0; j < jo.extent(); j += cache_line_size) { | |
_mm_prefetch(&B(k + 4, jo.min() + j), _MM_HINT_T0); | |
} | |
for (index_t i : io) { | |
for (index_t j : jo) { | |
accumulator(i, j) += A(i, k) * B(k, j); | |
} | |
} | |
} | |
// Add the accumulators for this iteration of ko to the output. | |
// Because we split the K dimension, we are doing this more than once per | |
// tile of output. To avoid adding to overlapping regions more than once | |
// (when `split<>` is applied to a dimension not divided by the split factor), | |
// we need to only initialize the result for the first iteration of ko. | |
if (ko.min() == A.j().min()) { | |
for (index_t i : io) { | |
for (index_t j : jo) { | |
C_ijo(i, j) = accumulator(i, j); | |
} | |
} | |
} else { | |
for (index_t i : io) { | |
for (index_t j : jo) { | |
C_ijo(i, j) += accumulator(i, j); | |
} | |
} | |
} | |
}); | |
} | |
} | |
// 0.95s for NumPy | |
#define M 4096 | |
#define K 4096 | |
#define N 4096 | |
int | |
main() { | |
srand(time(NULL)); | |
matrix<double> A({M, K}); | |
matrix<double> B({K, N}); | |
for (size_t i = 0; i < M; i++) { | |
for (size_t j = 0; j < K; j++) { | |
A(i, j) = rand() % 500; | |
} | |
} | |
for (size_t i = 0; i < K; i++) { | |
for (size_t j = 0; j < N; j++) { | |
B(i, j) = rand() % 500; | |
} | |
} | |
printf("Multiplying...\n"); | |
matrix<double> C({M, N}); | |
// 37 s for float, 28 s for double | |
//multiply_reduce_tiles_z_order<double>(A.cref(), B.cref(), C.ref()); | |
multiply_reduce_cols<double>(A.cref(), B.cref(), C.ref()); | |
// for (size_t i = 0; i < 5; i++) { | |
// for (size_t j = 0; j < 8; j++) { | |
// A(i, j) = rand() % 500; | |
// } | |
// } | |
//printf("%d\n", m(0, 3)); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment