Created
March 11, 2023 15:17
-
-
Save Const-me/65ff46c31553493d13fcd6646e162494 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// ==== AVX2 + FMA routines for Q4_0 / Q4_1 compressed blocks: decompression, dot product, and Q4_0 compression ====
#include <array>
#include <cstdint>
#include <cstring>
#include <assert.h>
#include <float.h>
#include <immintrin.h>
// Expand 16 packed bytes (32 nibbles) into 32 bytes, one nibble per byte.
// Output byte 2*i holds the low nibble of rsi[ i ], output byte 2*i + 1 holds its high nibble,
// so every output byte is in the [ 0 .. 15 ] interval.
// rsi must point to at least 16 readable bytes; there is no alignment requirement.
inline __m256i bytesFromNibbles( const uint8_t* rsi )
{
	// Zero-extend every source byte into its own 16-bit lane: 0x00HL
	const __m128i packed = _mm_loadu_si128( reinterpret_cast<const __m128i*>( rsi ) );
	const __m256i lanes = _mm256_cvtepu8_epi16( packed );
	// Low nibble stays in the lower byte of the lane: 0x000L
	const __m256i lowNibbles = _mm256_and_si256( lanes, _mm256_set1_epi16( 0x000F ) );
	// Shift left by 4 and mask, so the high nibble lands in the upper byte: 0x0H00
	const __m256i highNibbles = _mm256_and_si256( _mm256_slli_epi16( lanes, 4 ), _mm256_set1_epi16( 0x0F00 ) );
	return _mm256_or_si256( lowNibbles, highNibbles );
}
// Sign-extend the lowest 8 bytes of the vector (interpreted as int8_t) and convert them to 8 floats.
// The upper 8 bytes of the argument are ignored.
inline __m256 makeFloats( __m128i bytes )
{
	return _mm256_cvtepi32_ps( _mm256_cvtepi8_epi32( bytes ) );
}
// Decompress Q4_0 compressed block, the block size is 32 | |
// The block payload contains 1 reference value (the first argument), and 32 4-bit values packed into 16 bytes (second argument) | |
std::array<__m256, 4> decompressBlock40( const float* scaling, const uint8_t* rsi ) | |
{ | |
// Unpack 4-bit fields into bytes | |
__m256i bytes = bytesFromNibbles( rsi ); | |
// Now we have a vector with bytes in [0..15], offset into [-8..+7] | |
const __m256i off = _mm256_set1_epi8( 8 ); | |
bytes = _mm256_sub_epi8( bytes, off ); | |
// Broadcast ref1 into AVX vector | |
const __m256 sv = _mm256_broadcast_ss( scaling ); | |
// Produce the result | |
std::array<__m256, 4> arr; | |
__m128i tmp = _mm256_castsi256_si128( bytes ); | |
arr[ 0 ] = _mm256_mul_ps( sv, makeFloats( tmp ) ); | |
tmp = _mm_srli_si128( tmp, 8 ); | |
arr[ 1 ] = _mm256_mul_ps( sv, makeFloats( tmp ) ); | |
tmp = _mm256_extracti128_si256( bytes, 1 ); | |
arr[ 2 ] = _mm256_mul_ps( sv, makeFloats( tmp ) ); | |
tmp = _mm_srli_si128( tmp, 8 ); | |
arr[ 3 ] = _mm256_mul_ps( sv, makeFloats( tmp ) ); | |
return arr; | |
} | |
// Decompress Q4_1 compressed block, the block size is 32 | |
// The block payload contains min value, scaling vector, and 32 4-bit values packed into 16 bytes | |
std::array<__m256, 4> decompressBlock41( const float* minValue, const float* scaling, const uint8_t* rsi ) | |
{ | |
// Unpack 4-bit fields into bytes | |
const __m256i bytes = bytesFromNibbles( rsi ); | |
// Broadcast both floats into AVX vectors | |
const __m256 iv = _mm256_broadcast_ss( minValue ); | |
const __m256 sv = _mm256_broadcast_ss( scaling ); | |
// Produce the result | |
std::array<__m256, 4> arr; | |
__m128i tmp = _mm256_castsi256_si128( bytes ); | |
arr[ 0 ] = _mm256_fmadd_ps( sv, makeFloats( tmp ), iv ); | |
tmp = _mm_srli_si128( tmp, 8 ); | |
arr[ 1 ] = _mm256_fmadd_ps( sv, makeFloats( tmp ), iv ); | |
tmp = _mm256_extracti128_si256( bytes, 1 ); | |
arr[ 2 ] = _mm256_fmadd_ps( sv, makeFloats( tmp ), iv ); | |
tmp = _mm_srli_si128( tmp, 8 ); | |
arr[ 3 ] = _mm256_fmadd_ps( sv, makeFloats( tmp ), iv ); | |
return arr; | |
} | |
// Compute dot product of two vectors, both compressed into a sequence of Q4_0 blocks | |
float dotProductCompressed40( size_t len, const uint8_t* x, const uint8_t* y ) | |
{ | |
assert( 0 == ( len % 32 ) ); | |
const size_t countBlocks = len / 32; | |
// Prepare the source pointers | |
const float* scalesX = (const float*)x; | |
const float* scalesY = (const float*)y; | |
const float* const sxEnd = scalesX + countBlocks; | |
const uint8_t* bytesX = (const uint8_t*)( scalesX + countBlocks ); | |
const uint8_t* bytesY = (const uint8_t*)( scalesY + countBlocks ); | |
// Initialize accumulator with zeros | |
__m256 acc = _mm256_setzero_ps(); | |
// Main loop | |
while( scalesX < sxEnd ) | |
{ | |
// Compute combined scale for the block | |
const __m256 scale = _mm256_mul_ps( _mm256_broadcast_ss( scalesX ), _mm256_broadcast_ss( scalesY ) ); | |
scalesX++; | |
scalesY++; | |
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes | |
__m256i bx = bytesFromNibbles( bytesX ); | |
__m256i by = bytesFromNibbles( bytesY ); | |
bytesX += 16; | |
bytesY += 16; | |
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. | |
const __m256i off = _mm256_set1_epi8( 8 ); | |
bx = _mm256_sub_epi8( bx, off ); | |
by = _mm256_sub_epi8( by, off ); | |
// Sign-extend first 16 signed bytes into int16_t | |
__m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) ); | |
__m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); | |
// Compute products of int16_t integers, add pairwise | |
__m256i i32 = _mm256_madd_epi16( x16, y16 ); | |
// Sign-extend last 16 signed bytes into int16_t vectors | |
x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) ); | |
y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); | |
// Accumulate products of int16_t integers | |
i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) ); | |
// Convert int32_t to float | |
__m256 p = _mm256_cvtepi32_ps( i32 ); | |
// Apply the scale, and accumulate | |
acc = _mm256_fmadd_ps( scale, p, acc ); | |
} | |
// Return horizontal sum of the acc vector | |
__m128 res = _mm256_extractf128_ps( acc, 1 ); | |
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); | |
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); | |
res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); | |
return _mm_cvtss_f32( res ); | |
} | |
// Pack 32 bytes, each expected in [ 0 .. 15 ], into 16 bytes holding two 4-bit values each.
// Byte 2*i becomes the low nibble and byte 2*i + 1 the high nibble of output byte i,
// which is the inverse of bytesFromNibbles. Inputs above 15 are not masked and yield meaningless output.
inline __m128i packNibbles( __m256i bytes )
{
	// In each 16-bit lane 0x0H0L: keep the low byte, move the high byte down by 4 bits -> 0x00HL
	const __m256i lowPart = _mm256_and_si256( bytes, _mm256_set1_epi16( 0x00FF ) );
	const __m256i highPart = _mm256_srli_epi16( _mm256_and_si256( bytes, _mm256_set1_epi16( (short)0xFF00 ) ), 4 );
	const __m256i merged = _mm256_or_si256( lowPart, highPart );
	// Narrow the 16 uint16_t lanes to bytes, low 128-bit half first
	return _mm_packus_epi16( _mm256_castsi256_si128( merged ), _mm256_extracti128_si256( merged, 1 ) );
}
// Compress row into Q4_0 compressed blocks, the block size is 32 | |
void compressRow40( uint8_t* rdi, const float* rsi, size_t length ) | |
{ | |
assert( 0 == ( length % 32 ) ); | |
const size_t countBlocks = length / 32; | |
const float* const rsiEnd = rsi + length; | |
float* rdiScale = (float*)( rdi ); | |
uint8_t* rdiBytes = (uint8_t*)( rdiScale + countBlocks ); | |
while( rsi < rsiEnd ) | |
{ | |
// Load elements into 4 AVX vectors | |
__m256 v0 = _mm256_loadu_ps( rsi ); | |
__m256 v1 = _mm256_loadu_ps( rsi + 8 ); | |
__m256 v2 = _mm256_loadu_ps( rsi + 16 ); | |
__m256 v3 = _mm256_loadu_ps( rsi + 24 ); | |
rsi += 32; | |
// Compute max(abs(e)) for the block | |
const __m256 signBit = _mm256_set1_ps( -0.0f ); | |
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); | |
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); | |
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); | |
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); | |
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); | |
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); | |
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); | |
const float maxScalar = _mm_cvtss_f32( max4 ); | |
// Quantize these floats | |
const float d = maxScalar / 7.0f; | |
*rdiScale = d; | |
rdiScale++; | |
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f; | |
const __m256 mul = _mm256_set1_ps( id ); | |
// Apply the multiplier | |
v0 = _mm256_mul_ps( v0, mul ); | |
v1 = _mm256_mul_ps( v1, mul ); | |
v2 = _mm256_mul_ps( v2, mul ); | |
v3 = _mm256_mul_ps( v3, mul ); | |
// Round to nearest integer | |
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); | |
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); | |
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); | |
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); | |
// Convert floats to integers | |
__m256i i0 = _mm256_cvtps_epi32( v0 ); | |
__m256i i1 = _mm256_cvtps_epi32( v1 ); | |
__m256i i2 = _mm256_cvtps_epi32( v2 ); | |
__m256i i3 = _mm256_cvtps_epi32( v3 ); | |
// Convert int32 to int16 | |
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 | |
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 | |
// Convert int16 to int8 | |
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 | |
// We got our precious signed bytes, but the order is now wrong | |
// These AVX2 pack instructions process 16-byte pieces independently | |
// The following instruction is fixing the order | |
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); | |
i0 = _mm256_permutevar8x32_epi32( i0, perm ); | |
// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ] | |
const __m256i off = _mm256_set1_epi8( 8 ); | |
i0 = _mm256_add_epi8( i0, off ); | |
// Compress the vector into 4 bit/value, and store | |
__m128i res = packNibbles( i0 ); | |
_mm_storeu_si128( ( __m128i* )rdiBytes, res ); | |
rdiBytes += 16; | |
} | |
} | |
// ==== Debug Functions ==== | |
#include <cmath> | |
#include <stdio.h> | |
// Store the 4 vectors produced by decompressBlock40 / decompressBlock41 into a 32-element array:
// v[ 0 ] goes to elements 0-7, v[ 1 ] to 8-15, v[ 2 ] to 16-23 and v[ 3 ] to 24-31.
inline void storeBlock( std::array<float, 32>& arr, std::array<__m256, 4> v )
{
	for( size_t i = 0; i < v.size(); i++ )
		_mm256_storeu_ps( arr.data() + 8 * i, v[ i ] );
}
// Scalar reference decoder for one Q4_0 element: a nibble in [ 0 .. 15 ] maps to scaling * ( byte - 8 ).
float decompressScalar40( float scaling, uint8_t byte )
{
	assert( byte <= 15 );
	// Same int8_t arithmetic as the SIMD decoder: [ 0 .. 15 ] -> [ -8 .. +7 ]
	const int8_t centered = static_cast<int8_t>( static_cast<int8_t>( byte ) - 8 );
	return scaling * static_cast<float>( centered );
}
// Scalar reference decoder for one Q4_1 element: scaling * byte + minValue with a single rounding,
// matching the fused multiply-add used by the SIMD decoder.
float decompressScalar41( float minValue, float scaling, uint8_t byte )
{
	assert( byte <= 15 );
	const float quantized = static_cast<float>( byte );
	return std::fma( scaling, quantized, minValue );
}
int testDecompressor() | |
{ | |
const float scaling = 13; | |
const float min = 44; | |
// From random.org | |
const std::array<uint8_t, 16> bytes = { 188, 56, 77, 68, 113, 245, 126, 231, 143, 225, 48, 216, 191, 53, 110, 118 }; | |
// Decompress and store these bytes in both compressed formats | |
std::array<float, 32> b40, b41; | |
storeBlock( b40, decompressBlock40( &scaling, bytes.data() ) ); | |
storeBlock( b41, decompressBlock41( &min, &scaling, bytes.data() ) ); | |
// Verify the data | |
for( size_t i = 0; i < 32; i++ ) | |
{ | |
uint8_t byte = bytes[ i / 2 ]; | |
if( 0 == ( i % 2 ) ) | |
byte &= 0xF; | |
else | |
byte = byte >> 4; | |
// Verify Q4_0 decompressor | |
float fast = b40[ i ]; | |
float scalar = decompressScalar40( scaling, byte ); | |
if( fast != scalar ) | |
return 1; | |
// Verify Q4_1 decompressor | |
fast = b41[ i ]; | |
scalar = decompressScalar41( min, scaling, byte ); | |
if( fast != scalar ) | |
return 1; | |
} | |
printf( "Success!\n" ); | |
return 0; | |
} | |
// One Q4_0 block in the layout compressRow40 / dotProductCompressed40 use for a single-block row:
// the float scale immediately followed by 16 bytes of packed nibbles.
// The implicit conversions to byte pointers let test code pass a block straight to those functions.
struct CompressedBlock40
{
	float scale;
	std::array<uint8_t, 16> bytes;
	operator const uint8_t* ( ) const
	{
		return reinterpret_cast<const uint8_t*>( this );
	}
	operator uint8_t* ( )
	{
		return reinterpret_cast<uint8_t*>( this );
	}
};
// The nibbles must start right after the scale, and whole blocks are compared with memcmp,
// so the struct must contain no padding.
static_assert( sizeof( CompressedBlock40 ) == sizeof( float ) + 16, "CompressedBlock40 must have no padding" );
int testDotProduct() | |
{ | |
const CompressedBlock40 x | |
{ | |
3.5f, | |
{ 188, 56, 77, 68, 113, 245, 126, 231, 143, 225, 48, 216, 191, 53, 110, 118 } | |
}; | |
const CompressedBlock40 y | |
{ | |
4.17f, | |
{ 194, 237, 156, 194, 32, 200, 60, 253, 21, 69, 120, 124, 63, 77, 150, 143 } | |
}; | |
const float dotCompressed = dotProductCompressed40( 32, x, y ); | |
std::array<float, 32> xf, yf; | |
storeBlock( xf, decompressBlock40( &x.scale, x.bytes.data() ) ); | |
storeBlock( yf, decompressBlock40( &y.scale, y.bytes.data() ) ); | |
double dotScalar = 0; | |
for( size_t i = 0; i < 32; i++ ) | |
dotScalar += (double)xf[ i ] * yf[ i ]; | |
printf( "dotProductCompressed40: %g\nScalar: %g\n", dotCompressed, dotScalar ); | |
return 0; | |
} | |
int testCompressor() | |
{ | |
const CompressedBlock40 orig | |
{ | |
// We want multiplier to be power of 2 because in this test we comparing the compressed block for exact equality with memcmp() | |
// Scaling floats by powers of 2 is lossless, both multiplication and division | |
16.0f, | |
// Generated by random.org, and removed the zeros | |
{ 0x8f, 0xd1, 0x14, 0xfe, 0x3e, 0x4c, 0x3a, 0x31, 0xce, 0x15, 0x77, 0xc6, 0x43, 0x51, 0x8e, 0x71 } | |
}; | |
std::array<float, 32> fp32; | |
storeBlock( fp32, decompressBlock40( &orig.scale, orig.bytes.data() ) ); | |
CompressedBlock40 recompressed; | |
compressRow40( recompressed, fp32.data(), 32 ); | |
const int cmp = memcmp( &orig, &recompressed, sizeof( CompressedBlock40 ) ); | |
if( 0 == cmp ) | |
{ | |
printf( "Success\n" ); | |
return 0; | |
} | |
else | |
{ | |
printf( "Fail\n" ); | |
return 1; | |
} | |
} | |
int main() | |
{ | |
// return testDecompressor(); | |
// return testDotProduct(); | |
return testCompressor(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment