Skip to content

Instantly share code, notes, and snippets.

@Const-me
Last active October 16, 2024 13:37
Show Gist options
  • Save Const-me/41b013229b20f920bcee22a856c8569f to your computer and use it in GitHub Desktop.
Save Const-me/41b013229b20f920bcee22a856c8569f to your computer and use it in GitHub Desktop.
// Path of the text file main() loads the two input vectors from; each line is parsed with the format "%g,%g" (one element of A, one element of B)
static const char* const sourceDataPath = R"(C:\Temp\2remove\vectors.csv)";
#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <chrono>
#include <immintrin.h>
using namespace std;
using namespace std::chrono;
// Count of float elements in each of the two input vectors; cosine_similarity_opt() static_asserts that it is a multiple of 32
constexpr int SIZE = 640000;
// How many timed calls of cosine_similarity_opt() main() runs; the printed figure is the average over these calls
constexpr int EXECUTIONS = 1000;
// Horizontal sum: adds the 8 float lanes of `r` together and returns the scalar total
inline float hadd( __m256 r )
{
	// 8 lanes -> 4 lanes: lower 128-bit half plus upper 128-bit half
	const __m128 lowerHalf = _mm256_castps256_ps128( r );
	const __m128 upperHalf = _mm256_extractf128_ps( r, 1 );
	const __m128 quad = _mm_add_ps( lowerHalf, upperHalf );
	// 4 lanes -> 2 lanes: lanes 0,1 plus lanes 2,3
	const __m128 pair = _mm_add_ps( quad, _mm_movehl_ps( quad, quad ) );
	// 2 lanes -> 1 lane: lane 0 plus lane 1
	const __m128 oddLane = _mm_movehdup_ps( pair );
	return _mm_cvtss_f32( _mm_add_ss( pair, oddLane ) );
}
// Returns sqrt( hadd( a ) ) * sqrt( hadd( b ) ), reducing both inputs side by side in a single 128-bit vector
inline float reduce2( __m256 a, __m256 b )
{
	// Fold the upper 128-bit half of each input onto its lower half: 2x8 lanes become 2x4
	const __m128 aQuad = _mm_add_ps( _mm256_extractf128_ps( a, 1 ), _mm256_castps256_ps128( a ) );
	const __m128 bQuad = _mm_add_ps( _mm256_extractf128_ps( b, 1 ), _mm256_castps256_ps128( b ) );

	// Pack both quads into one vector of partial sums: lanes 0,1 belong to `a`, lanes 2,3 belong to `b`
	const __m128 kept = _mm_blend_ps( aQuad, bQuad, 0b1100 );                           // a.x, a.y, b.z, b.w
	const __m128 crossed = _mm_shuffle_ps( aQuad, bQuad, _MM_SHUFFLE( 1, 0, 3, 2 ) );   // a.z, a.w, b.x, b.y
	const __m128 pairs = _mm_add_ps( kept, crossed );

	// Add each odd lane onto the even lane before it; lane 0 now holds sum( a ), lane 2 holds sum( b )
	const __m128 totals = _mm_add_ps( pairs, _mm_movehdup_ps( pairs ) );

	// Both square roots with one instruction
	const __m128 roots = _mm_sqrt_ps( totals );

	// Multiply lane 0 by lane 2 and return that scalar
	return _mm_cvtss_f32( _mm_mul_ss( roots, _mm_movehl_ps( roots, roots ) ) );
}
// Running sums for one stream of 8-float chunks: the products a*b, a*a and b*b, kept as 8 partial sums each
class Acc
{
	// Partial sums of a[i] * b[i]
	__m256 sumAB;
	// Partial sums of a[i] * a[i] and of b[i] * b[i]
	__m256 sumAA, sumBB;

public:
	// All partial sums start at zero
	Acc() :
		sumAB( _mm256_setzero_ps() ),
		sumAA( _mm256_setzero_ps() ),
		sumBB( _mm256_setzero_ps() )
	{
	}

	// Loads 8 floats from `a` and 8 floats from `b` with unaligned loads, and accumulates the three products using fused multiply-add
	__forceinline void add( const float* a, const float* b )
	{
		const __m256 va = _mm256_loadu_ps( a );
		const __m256 vb = _mm256_loadu_ps( b );
		sumAB = _mm256_fmadd_ps( va, vb, sumAB );
		sumAA = _mm256_fmadd_ps( va, va, sumAA );
		sumBB = _mm256_fmadd_ps( vb, vb, sumBB );
	}

	// Lane-wise merge of another accumulator into this one
	__forceinline void operator +=( const Acc& that )
	{
		sumAB = _mm256_add_ps( sumAB, that.sumAB );
		sumAA = _mm256_add_ps( sumAA, that.sumAA );
		sumBB = _mm256_add_ps( sumBB, that.sumBB );
	}

	// Collapses the lanes and returns sum( a*b ) / ( sqrt( sum( a*a ) ) * sqrt( sum( b*b ) ) )
	__forceinline float reduce() const
	{
		const float numerator = hadd( sumAB );
		const float denominator = reduce2( sumAA, sumBB );
		return numerator / denominator;
	}
};
static float cosine_similarity_opt( const float* A, const float* B )
{
// Each accumulator consumes 3 vectors so 12 total
// The instruction set defines 16 of them so we're good, need 2 extra for input vectors
Acc a0, a1, a2, a3;
static_assert( 0 == SIZE % 32 );
const float* const endA = A + SIZE;
while( A < endA )
{
a0.add( A, B );
a1.add( A + 8, B + 8 );
a2.add( A + 16, B + 16 );
a3.add( A + 24, B + 24 );
A += 32;
B += 32;
}
// Reduce vertically into a0
a0 += a1;
a2 += a3;
a0 += a2;
// Reduce horizontally
return a0.reduce();
}
namespace
{
// Lazy VC++ compiler did whatever it could to optimize away the complete cosine_similarity_opt(), because the result is unused
// The following weird stuff is required to prevent these optimizations: main() passes the result to doNotOptimize(),
// which stores the address of that result into a volatile global

// Sink for the addresses passed to useCharPointer(); both the pointer itself and the data it points to are volatile-qualified
static void const volatile* volatile globalPointer;
// Writes `v` into globalPointer
void useCharPointer( volatile const char* v )
{
globalPointer = reinterpret_cast<volatile const void*>( v );
}
// Takes the address of `value` viewed as volatile chars, hands it to useCharPointer(), then issues _ReadWriteBarrier()
// NOTE(review): _ReadWriteBarrier is an MSVC compiler intrinsic; presumably it only restricts compiler reordering of memory accesses and emits no CPU fence — confirm against the MSVC documentation
template <class Tp>
__forceinline void doNotOptimize( Tp const& value )
{
useCharPointer( &reinterpret_cast<char const volatile&>( value ) );
_ReadWriteBarrier();
}
}
int main()
{
float* const A = (float*)_aligned_malloc( SIZE * 4, 32 );
float* const B = (float*)_aligned_malloc( SIZE * 4, 32 );
if( nullptr == A || nullptr == B )
return -1;
FILE* fp;
fp = fopen( sourceDataPath, "r" );
float a, b;
int i = 0;
while( fscanf( fp, "%g,%g\n", &a, &b ) == 2 )
{
A[ i ] = a;
B[ i ] = b;
i += 1;
}
duration<double, std::milli> duration;
double simd_accum = 0.0;
for( int i = 0; i < EXECUTIONS; i++ )
{
auto t1 = high_resolution_clock::now();
const float result = cosine_similarity_opt( &A[ 0 ], &B[ 0 ] );
doNotOptimize( result );
auto t2 = high_resolution_clock::now();
duration = t2 - t1;
simd_accum += duration.count();
}
std::cout << simd_accum / EXECUTIONS << " ms\n";
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment