Last active
July 7, 2023 16:16
-
-
Save Const-me/3ade77faad47f0fbb0538965ae7f8e04 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdint.h> | |
#include <immintrin.h> | |
#include <intrin.h> | |
#include <stdio.h> | |
// Count of set bits in `plus` minus count of set bits in `minus`.
// Both arguments are 32-bit bitmaps, so the result is in the [ -32 .. +32 ] interval.
inline int popCntDiff( uint32_t plus, uint32_t minus )
{
#if defined( _MSC_VER )
    // MSVC-style intrinsic, declared in <intrin.h>
    const int plusBits = (int)__popcnt( plus );
    const int minusBits = (int)__popcnt( minus );
#else
    // `__popcnt` does not exist on GCC / Clang; their builtin computes the same value and needs no header
    const int plusBits = __builtin_popcount( plus );
    const int minusBits = __builtin_popcount( minus );
#endif
    return plusBits - minusBits;
}
// Horizontal sum of all 4 int64_t elements in the AVX2 vector | |
inline int64_t hadd_epi64( __m256i v32 ) | |
{ | |
__m128i v = _mm256_extracti128_si256( v32, 1 ); | |
v = _mm_add_epi64( v, _mm256_castsi256_si128( v32 ) ); | |
const int64_t high = _mm_extract_epi64( v, 1 ); | |
const int64_t low = _mm_cvtsi128_si64( v ); | |
return high + low; | |
} | |
// AVX2 implementation of that algorithm | |
int computeWithAvx2( const char* input ) | |
{ | |
// Create a few constants | |
const __m256i zero = _mm256_setzero_si256(); | |
const __m256i s = _mm256_set1_epi8( 's' ); | |
const __m256i p = _mm256_set1_epi8( 'p' ); | |
__m256i counter; | |
// Prologue, make sure the pointer is aligned by 32 bytes | |
const size_t rem = ( (size_t)input ) % 32; | |
if( 0 == rem ) | |
{ | |
// The input already aligned, initialize accumulator with 0 | |
counter = _mm256_setzero_si256(); | |
} | |
else | |
{ | |
// Load aligned vector from the address before the input buffer | |
// Same VMEM page as the first byte of the input, no access violations | |
input -= rem; | |
const __m256i v = _mm256_load_si256( ( const __m256i* )input ); | |
uint32_t bmpZero = (uint32_t)_mm256_movemask_epi8( _mm256_cmpeq_epi8( v, zero ) ); | |
uint32_t bmpPlus = (uint32_t)_mm256_movemask_epi8( _mm256_cmpeq_epi8( v, s ) ); | |
uint32_t bmpMinus = (uint32_t)_mm256_movemask_epi8( _mm256_cmpeq_epi8( v, p ) ); | |
// Discard lower bits resulted from loading before the start of the buffer | |
bmpZero >>= (uint32_t)rem; | |
bmpPlus >>= (uint32_t)rem; | |
bmpMinus >>= (uint32_t)rem; | |
if( 0 == bmpZero ) | |
{ | |
// No `\0` encountered in the initial bytes of the input | |
// Compute initial value for the accumulator vector | |
__m128i iv = _mm_cvtsi64_si128( popCntDiff( bmpPlus, bmpMinus ) * 0xFF ); | |
counter = _mm256_blend_epi32( zero, _mm256_castsi128_si256( iv ), 0b00000011 ); | |
input += 32; | |
} | |
else | |
{ | |
// The input was tiny, found `\0` already | |
// Clear higher bits in the two bitmaps which were after the first `\0` | |
const uint32_t len = _tzcnt_u32( bmpZero ); | |
bmpPlus = _bzhi_u32( bmpPlus, len ); | |
bmpMinus = _bzhi_u32( bmpMinus, len ); | |
// Compute the result | |
return popCntDiff( bmpPlus, bmpMinus ); | |
} | |
} | |
// The pointer is aligned by 32 bytes, which serves two purposes: we can use aligned loads, | |
// and most importantly loading 32 bytes guarantees to not cross page boundary. | |
// VMEM permissions are defined for aligned 4kb pages, we can technically load within a page without access violations, | |
// despite the language standard says it's UB | |
while( true ) | |
{ | |
// Load 32 bytes from the pointer | |
const __m256i v = _mm256_load_si256( ( const __m256i* )input ); | |
// Compare bytes for v == '\0' | |
const __m256i z = _mm256_cmpeq_epi8( v, zero ); | |
// Compare bytes for equality with these two other markers | |
__m256i cmpPlus = _mm256_cmpeq_epi8( v, s ); | |
__m256i cmpMinus = _mm256_cmpeq_epi8( v, p ); | |
const uint32_t bmpZero = (uint32_t)_mm256_movemask_epi8( z ); | |
if( 0 != bmpZero ) | |
{ | |
// At least one byte of the 32 was zero | |
const int res = (int)( hadd_epi64( counter ) / 0xFF ); | |
uint32_t bmpPlus = (uint32_t)_mm256_movemask_epi8( cmpPlus ); | |
uint32_t bmpMinus = (uint32_t)_mm256_movemask_epi8( cmpMinus ); | |
// Clear higher bits in the two bitmaps which were after the first found `\0` | |
const uint32_t len = _tzcnt_u32( bmpZero ); | |
bmpPlus = _bzhi_u32( bmpPlus, len ); | |
bmpMinus = _bzhi_u32( bmpMinus, len ); | |
// Produce the result | |
return res + popCntDiff( bmpPlus, bmpMinus ); | |
} | |
// Increment the source pointer | |
input += 32; | |
// Compute horizontal sum of bytes within 8-byte lanes | |
cmpPlus = _mm256_sad_epu8( cmpPlus, zero ); | |
cmpMinus = _mm256_sad_epu8( cmpMinus, zero ); | |
cmpPlus = _mm256_sub_epi64( cmpPlus, cmpMinus ); | |
// Update the counter | |
counter = _mm256_add_epi64( counter, cmpPlus ); | |
} | |
} | |
#include <vector> | |
#include <random> | |
// Generate a buffer of pseudo-random characters picked from { 's', 'p', '0', '1' },
// with the last element overwritten by the terminating `\0`.
// The buffer has `length` elements; a zero `length` yields a single `\0`,
// so the result can always be consumed as a null-terminated string.
std::vector<char> nullTerminatedRandom( size_t length )
{
    constexpr size_t patternSize = 4;
    const char pattern[ patternSize ]{ 's', 'p', '0', '1' };

    // At least one element is required to hold the terminator
    std::vector<char> result( length > 0 ? length : 1 );

    // Deliberately seeding RNG with 0, to generate same output every time
    std::mt19937 gen( 0 );
    // Both bounds of the distribution are inclusive, the upper one must be the last valid index of the pattern
    std::uniform_int_distribution<size_t> distrib( 0, patternSize - 1 );
    for( char& c : result )
        c = pattern[ distrib( gen ) ];

    // Write terminating `\0` into the last element
    result.back() = '\0';
    return result;
}
// Scalar reference implementation:
// walks the null-terminated `input` one byte at a time, adding 1 for every 's' and subtracting 1 for every 'p'.
int computeWithSwitches( const char* input )
{
    int balance = 0;
    for( const char* cursor = input;; ++cursor )
    {
        switch( *cursor )
        {
        case 's':
            ++balance;
            continue;
        case 'p':
            --balance;
            continue;
        case '\0':
            // End of the string
            return balance;
        default:
            // Any other byte does not affect the result
            continue;
        }
    }
}
int main() | |
{ | |
constexpr bool useAvx = true; | |
// Using odd length slightly over 1GB, just for lulz | |
const size_t len = 1024 * 1024 * 1024 + 17; | |
const auto data = nullTerminatedRandom( len ); | |
const char* const rsi = data.data(); | |
// Compute the result using either of these two methods, measuring the time | |
const int64_t tscStart = __rdtsc(); | |
const int sum = useAvx ? computeWithAvx2( rsi ) : computeWithSwitches( rsi ); | |
const int64_t tscElapsed = __rdtsc() - tscStart; | |
printf( "%i; elapsed time: %lli\n", sum, tscElapsed ); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment