Created
October 15, 2024 20:59
-
-
Save Const-me/c61e836bed08cef2f06783c7b11b4e18 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Location of the CSV input file; main() parses it as one "a,b" pair of floats per line.
static constexpr const char* sourceDataPath = R"(C:\Temp\2remove\vectors.csv)";
#define _CRT_SECURE_NO_WARNINGS | |
#include <iostream> | |
#include <chrono> | |
#include <immintrin.h> | |
#include <assert.h> | |
using namespace std; | |
using namespace std::chrono; | |
// Number of equal batches the input is split into by cosine_similarity_omp(), one per parallel loop iteration
constexpr int THREADS{ 8 };
// Count of float elements in each of the two input vectors; must be a multiple of THREADS * 32
constexpr int SIZE{ 640000 };
// How many timed runs main() averages over
constexpr int EXECUTIONS{ 1000 };
inline float hadd( __m256 r ) | |
{ | |
__m128 r4 = _mm_add_ps( _mm256_castps256_ps128( r ), _mm256_extractf128_ps( r, 1 ) ); | |
__m128 r2 = _mm_add_ps( r4, _mm_movehl_ps( r4, r4 ) ); | |
__m128 r1 = _mm_add_ss( r2, _mm_movehdup_ps( r2 ) ); | |
return _mm_cvtss_f32( r1 ); | |
} | |
inline __m128 reduce3( __m256 a, __m256 b, __m256 c ) | |
{ | |
return _mm_setr_ps( hadd( a ), hadd( b ), hadd( c ), 0 ); | |
} | |
class Acc | |
{ | |
__m256 dot; | |
__m256 a2, b2; | |
public: | |
Acc() | |
{ | |
dot = _mm256_setzero_ps(); | |
a2 = _mm256_setzero_ps(); | |
b2 = _mm256_setzero_ps(); | |
} | |
__forceinline void add( const float* a, const float* b ) | |
{ | |
const __m256 v1 = _mm256_loadu_ps( a ); | |
const __m256 v2 = _mm256_loadu_ps( b ); | |
dot = _mm256_fmadd_ps( v1, v2, dot ); | |
a2 = _mm256_fmadd_ps( v1, v1, a2 ); | |
b2 = _mm256_fmadd_ps( v2, v2, b2 ); | |
} | |
__forceinline void operator +=( const Acc& that ) | |
{ | |
dot = _mm256_add_ps( dot, that.dot ); | |
a2 = _mm256_add_ps( a2, that.a2 ); | |
b2 = _mm256_add_ps( b2, that.b2 ); | |
} | |
__forceinline __m128 reduce() const | |
{ | |
return reduce3( a2, b2, dot ); | |
} | |
}; | |
inline __m128 cosine_similarity_batch( const float* A, const float* B, size_t length ) | |
{ | |
assert( 0 == length % 32 ); | |
// Each accumulator consumes 3 vectors so 12 total | |
// The instruction set defines 16 of them so we're good, need 2 extra for input vectors | |
Acc a0, a1, a2, a3; | |
const float* const endA = A + length; | |
while( A < endA ) | |
{ | |
a0.add( A, B ); | |
a1.add( A + 8, B + 8 ); | |
a2.add( A + 16, B + 16 ); | |
a3.add( A + 24, B + 24 ); | |
A += 32; | |
B += 32; | |
} | |
// Reduce vertically into a0 | |
a0 += a1; | |
a2 += a3; | |
a0 += a2; | |
// Reduce horizontally | |
return a0.reduce(); | |
} | |
static float cosine_similarity_omp( const float* A, const float* B ) | |
{ | |
static_assert( 0 == SIZE % ( THREADS * 32 ) ); | |
constexpr size_t batchSize = SIZE / THREADS; | |
__m128 results[ THREADS ]; | |
#pragma omp parallel for | |
for( int64_t i = 0; i < THREADS; i++ ) | |
{ | |
const size_t offset = i * batchSize; | |
results[ i ] = cosine_similarity_batch( A + offset, B + offset, batchSize ); | |
} | |
// Reduce across the threads | |
__m128 v = results[ 0 ]; | |
for( size_t i = 1; i < THREADS; i++ ) | |
v = _mm_add_ps( v, results[ i ] ); | |
// Compute v.z / ( sqrt( v.x ) * sqrt( v.y ) ) | |
const float mul = _mm_cvtss_f32( _mm_movehl_ps( v, v ) ); | |
v = _mm_sqrt_ps( v ); | |
v = _mm_mul_ss( v, _mm_movehdup_ps( v ) ); | |
return mul / _mm_cvtss_f32( v ); | |
} | |
namespace | |
{ | |
// Lazy VC++ compiler did whatever it could to optimize away the complete cosine_similarity_simd(), because the result is unused | |
// The following weird stuff is required to prevent these optimizations | |
static void const volatile* volatile globalPointer; | |
void useCharPointer( volatile const char* v ) | |
{ | |
globalPointer = reinterpret_cast<volatile const void*>( v ); | |
} | |
template <class Tp> | |
__forceinline void doNotOptimize( Tp const& value ) | |
{ | |
useCharPointer( &reinterpret_cast<char const volatile&>( value ) ); | |
_ReadWriteBarrier(); | |
} | |
} | |
int main() | |
{ | |
float* const A = (float*)_aligned_malloc( SIZE * 4, 32 ); | |
float* const B = (float*)_aligned_malloc( SIZE * 4, 32 ); | |
if( nullptr == A || nullptr == B ) | |
return -1; | |
FILE* fp; | |
fp = fopen( sourceDataPath, "r" ); | |
float a, b; | |
int i = 0; | |
while( fscanf( fp, "%g,%g\n", &a, &b ) == 2 ) | |
{ | |
A[ i ] = a; | |
B[ i ] = b; | |
i += 1; | |
} | |
duration<double, std::milli> duration; | |
double simd_accum = 0.0; | |
for( int i = 0; i < EXECUTIONS; i++ ) | |
{ | |
auto t1 = high_resolution_clock::now(); | |
const float result = cosine_similarity_omp( A, B ); | |
doNotOptimize( result ); | |
auto t2 = high_resolution_clock::now(); | |
duration = t2 - t1; | |
simd_accum += duration.count(); | |
} | |
std::cout << simd_accum / EXECUTIONS << " ms\n"; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment