Last active
March 1, 2023 22:16
-
-
Save Const-me/15111c4f9502b001eef31ffde4aa7770 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <immintrin.h>
// Compute product of width*16 column-major matrix by vector of length `width`;
// the result is a vector of length 16.
// BTW, according to godbolt.org, gcc does better than clang for this code.
// Multiply a column-major matrix of size width*16 by a vector of length `width`,
// writing the 16-float result to `rdi`.
//
// Requirements implied by the intrinsics used below:
//   - `mat` and `rdi` must be 32-byte aligned (_mm256_load_ps / _mm256_store_ps fault otherwise)
//   - each matrix column is 16 contiguous floats; columns are stored back to back
//   - the CPU must support AVX and FMA3 (_mm256_fmadd_ps)
// NOTE(review): `constexpr` below means this file compiles as C++ (or C23), not C11.
void multiplyInner_avx16( const float* mat, const float* vec, size_t width, float* rdi )
{
	// Using 4 accumulators per row, 4*16=64 scalars in 8 AVX vectors.
	// aX0 holds partial sums for result rows [0..7], aX1 for rows [8..15];
	// the X index cycles over 4 consecutive matrix columns, so every output
	// element is spread across 4 independent accumulators — this hides the
	// latency of the fused multiply-add instructions in the main loop.
	__m256 a00 = _mm256_setzero_ps();
	__m256 a01 = _mm256_setzero_ps();
	__m256 a10 = _mm256_setzero_ps();
	__m256 a11 = _mm256_setzero_ps();
	__m256 a20 = _mm256_setzero_ps();
	__m256 a21 = _mm256_setzero_ps();
	__m256 a30 = _mm256_setzero_ps();
	__m256 a31 = _mm256_setzero_ps();
	// Compute these products over `width` rounded down to a multiple of 4
	constexpr size_t maskAlign4 = ~(size_t)3;
	const float* const vecEndAligned = vec + ( width & maskAlign4 );
	while( vec < vecEndAligned )
	{
		// Each iteration of this loop consumes 4 elements from the vector, and 4 columns = 64 elements from the matrix
		// Broadcast 4 elements from the vector into both 16-byte lanes of the register:
		// v4 = [ vec[0], vec[1], vec[2], vec[3], vec[0], vec[1], vec[2], vec[3] ]
		const __m256 v4 = _mm256_broadcast_ps( ( const __m128* )vec );
		vec += 4;
		// Column #0: replicate vec[0] into all 8 lanes, then accumulate rows [0..7] and [8..15]
		__m256 v = _mm256_permute_ps( v4, _MM_SHUFFLE( 0, 0, 0, 0 ) );
		a00 = _mm256_fmadd_ps( v, _mm256_load_ps( mat ), a00 );
		a01 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 8 ), a01 );
		// Column #1: same with vec[1]
		v = _mm256_permute_ps( v4, _MM_SHUFFLE( 1, 1, 1, 1 ) );
		a10 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 16 ), a10 );
		a11 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 24 ), a11 );
		// Column #2: same with vec[2]
		v = _mm256_permute_ps( v4, _MM_SHUFFLE( 2, 2, 2, 2 ) );
		a20 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 32 ), a20 );
		a21 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 40 ), a21 );
		// Column #3: same with vec[3]
		v = _mm256_permute_ps( v4, _MM_SHUFFLE( 3, 3, 3, 3 ) );
		a30 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 48 ), a30 );
		a31 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 56 ), a31 );
		mat += 64;
	}
	// Handle the remainder of 0..3 columns
	// The branches are predictable, same outcome every time this function is called
	const size_t rem = width % 4;
	if( rem == 1 )
	{
		// Column #0: broadcast the single leftover vector element
		const __m256 v = _mm256_broadcast_ss( vec );
		a00 = _mm256_fmadd_ps( v, _mm256_load_ps( mat ), a00 );
		a01 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 8 ), a01 );
	}
	else if( rem > 1 )
	{
		// Broadcast 2 elements from the vector with a single 64-bit load:
		// v2 = [ vec[0], vec[1], vec[0], vec[1], vec[0], vec[1], vec[0], vec[1] ]
		// NOTE(review): this type-puns float data through a double-typed load;
		// accepted practice with intrinsics on mainstream compilers, but a
		// strict-aliasing purist would use _mm_loadl_pi or memcpy — verify if
		// porting to an unusual toolchain.
		const __m256 v2 = _mm256_castpd_ps( _mm256_broadcast_sd( (const double*)vec ) );
		// Column #0: duplicate even lanes -> vec[0] in all 8 lanes
		__m256 v = _mm256_moveldup_ps( v2 );
		a00 = _mm256_fmadd_ps( v, _mm256_load_ps( mat ), a00 );
		a01 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 8 ), a01 );
		// Column #1: duplicate odd lanes -> vec[1] in all 8 lanes
		v = _mm256_movehdup_ps( v2 );
		a10 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 16 ), a10 );
		a11 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 24 ), a11 );
		if( rem > 2 )
		{
			// Column #2: broadcast the third leftover element
			v = _mm256_broadcast_ss( vec + 2 );
			a20 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 32 ), a20 );
			a21 = _mm256_fmadd_ps( v, _mm256_load_ps( mat + 40 ), a21 );
		}
	}
	// Reduce 64 accumulated scalars to 32: fold column groups 2,3 into 0,1
	a00 = _mm256_add_ps( a00, a20 );
	a01 = _mm256_add_ps( a01, a21 );
	a10 = _mm256_add_ps( a10, a30 );
	a11 = _mm256_add_ps( a11, a31 );
	// Reduce 32 accumulated scalars to the final 16
	a00 = _mm256_add_ps( a00, a10 );
	a01 = _mm256_add_ps( a01, a11 );
	// Finally, store the products: rows [0..7] then rows [8..15]
	_mm256_store_ps( rdi, a00 );
	_mm256_store_ps( rdi + 8, a01 );
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment