Multigetbits, the second
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <smmintrin.h>
#ifdef __RADAVX__
#include <immintrin.h>
#endif
#if !defined(__clang__) && defined(_MSC_VER)
#include <intrin.h>
static inline uint16_t bswap16(uint16_t x) { return _byteswap_ushort(x); }
static inline uint32_t bswap32(uint32_t x) { return _byteswap_ulong(x); }
static inline uint64_t bswap64(uint64_t x) { return _byteswap_uint64(x); }
static inline uint64_t rotl64(uint64_t x, uint32_t k) { return _rotl64(x, k); }
#else
static inline uint16_t bswap16(uint16_t x) { return __builtin_bswap16(x); }
static inline uint32_t bswap32(uint32_t x) { return __builtin_bswap32(x); }
static inline uint64_t bswap64(uint64_t x) { return __builtin_bswap64(x); }
static inline uint64_t rrRotlVar64(uint64_t x, uint32_t k) { __asm__("rolq %%cl, %0" : "+r"(x) : "c"(k)); return x; }
#define rotl64(u64,num) (__builtin_constant_p((num)) ? ( ( (u64) << (num) ) | ( (u64) >> (64 - (num))) ) : rrRotlVar64((u64),(num)))
#endif
static inline __m128i prefix_sum_u8(__m128i x)
{
#if 1
    // alternative form that uses shifts, not the general shuffle network on port 5 (which is a bottleneck
    // for us)
    x = _mm_add_epi8(x, _mm_slli_epi64(x, 8));
    x = _mm_add_epi8(x, _mm_slli_epi64(x, 16));
    x = _mm_add_epi8(x, _mm_slli_epi64(x, 32));
    x = _mm_add_epi8(x, _mm_shuffle_epi8(x, _mm_setr_epi8(-1,-1,-1,-1,-1,-1,-1,-1, 7,7,7,7,7,7,7,7)));
#else
    // x[0], x[1], x[2], x[3], ...
    x = _mm_add_epi8(x, _mm_slli_si128(x, 1));
    // x[0], sum(x[0:2]), sum(x[1:3]), sum(x[2:4]), ...
    x = _mm_add_epi8(x, _mm_slli_si128(x, 2));
    // x[0], sum(x[0:2]), sum(x[0:3]), sum(x[0:4]), sum(x[1:5]), sum(x[2:6]), ...
    x = _mm_add_epi8(x, _mm_slli_si128(x, 4));
    // longest group now sums over 8 elems
    x = _mm_add_epi8(x, _mm_slli_si128(x, 8));
#endif
    // and now we're done
    return x;
}
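// Illustrative scalar reference for the prefix sums above (editor's sketch, not
// part of the original pipeline; the _ref name is made up): both variants of
// prefix_sum_u8 compute an inclusive byte-wise prefix sum modulo 256.
static inline void prefix_sum_u8_ref(uint8_t out[16], const uint8_t x[16])
{
    uint8_t sum = 0;
    for (int i = 0; i < 16; i++)
    {
        sum = (uint8_t)(sum + x[i]); // running total, wrapping like PADDB does
        out[i] = sum;                // out[i] = x[0] + ... + x[i] (mod 256)
    }
}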
static inline __m128i prefix_sum_u16(__m128i x)
{
#if 1
    x = _mm_add_epi16(x, _mm_slli_epi64(x, 16));
    x = _mm_add_epi16(x, _mm_slli_epi64(x, 32));
    x = _mm_add_epi16(x, _mm_shuffle_epi8(x, _mm_setr_epi8(-1,-1,-1,-1,-1,-1,-1,-1, 6,7,6,7,6,7,6,7)));
#else
    x = _mm_add_epi16(x, _mm_slli_si128(x, 2));
    x = _mm_add_epi16(x, _mm_slli_si128(x, 4));
    x = _mm_add_epi16(x, _mm_slli_si128(x, 8));
#endif
    return x;
}
static inline __m128i prefix_sum_u32(__m128i x)
{
#if 1
    x = _mm_add_epi32(x, _mm_slli_epi64(x, 32));
    x = _mm_add_epi32(x, _mm_shuffle_epi8(x, _mm_setr_epi8(-1,-1,-1,-1,-1,-1,-1,-1, 4,5,6,7,4,5,6,7)));
#else
    // x[0], x[1], x[2], x[3]
    x = _mm_add_epi32(x, _mm_slli_si128(x, 4));
    // x[0], sum(x[0:2]), sum(x[1:3]), sum(x[2:4])
    x = _mm_add_epi32(x, _mm_slli_si128(x, 8));
    // x[0], sum(x[0:2]), sum(x[0:3]), sum(x[0:4])
#endif
    return x;
}
// individual field_widths in [0,8]
// MSB-first bit packing convention, SSSE3+
//
// compiled with /arch:AVX (to get rid of reg-reg moves Jaguar code wouldn't have):
// ballpark is ~42 ops for the 16 getbits, so ~2.63 ops/getbits.
//
// so expect maybe 1 cycle/lane on the big cores, 1.7 cycles/lane on Jaguar. (!)
static inline __m128i multigetbits8(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
    uint32_t bit_basepos = *pbit_basepos;
    // prefix-sum the field widths and advance bit position pointer
    __m128i summed_widths = prefix_sum_u8(field_widths);
    uint32_t total_width = (uint32_t)_mm_extract_epi16(summed_widths, 7) >> 8; // no PEXTRB before SSE4.1, and this is the only place where SSE4.1+ helps
    *pbit_basepos = bit_basepos + total_width;
    // NOTE once this is done (which is something like 1/4 into the whole thing by op count),
    // OoO cores can start working on next iter
    // -> this will get good core utilization
    // determine starting bit position for every lane
    // and split into bit-within-byte and byte indices
    __m128i basepos_u8 = _mm_shuffle_epi8(_mm_cvtsi32_si128(bit_basepos & 7), _mm_setzero_si128());
    __m128i first_bit_index = _mm_add_epi8(basepos_u8, _mm_slli_si128(summed_widths, 1));
    __m128i first_byte_index = _mm_and_si128(_mm_srli_epi16(first_bit_index, 3), _mm_set1_epi8(0x1f)); // no "shift bytes", sigh.
    // source bytes
    __m128i src_byte0 = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3) + 0));
    __m128i src_byte1 = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3) + 1));
    // first/second bytes for every lane
    __m128i byte0 = _mm_shuffle_epi8(src_byte0, first_byte_index);
    __m128i byte1 = _mm_shuffle_epi8(src_byte1, first_byte_index);
    // assemble words
    __m128i words0 = _mm_unpacklo_epi8(byte1, byte0);
    __m128i words1 = _mm_unpackhi_epi8(byte1, byte0);
    // now, need to shift
    //   ((byte0<<8) | byte1) >> (16 - width - (first_bit_index & 7))
    // we don't have per-lane variable shifts in SSSE3, but we do have PMULHUW,
    // and we can do the multiplier table lookup via PSHUFB.
    __m128i shift_amt = _mm_add_epi8(_mm_and_si128(first_bit_index, _mm_set1_epi8(7)), field_widths);
    __m128i shiftm0_lut = _mm_setr_epi8(0x01,0x02,0x04,0x08, 0x10,0x20,0x40,0x80, 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00);
    __m128i shiftm1_lut = _mm_setr_epi8(0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, 0x01,0x02,0x04,0x08, 0x10,0x20,0x40,0x80);
    __m128i shiftm0 = _mm_shuffle_epi8(shiftm0_lut, shift_amt);
    __m128i shiftm1 = _mm_shuffle_epi8(shiftm1_lut, shift_amt);
    __m128i shift_mul0 = _mm_unpacklo_epi8(shiftm0, shiftm1);
    __m128i shift_mul1 = _mm_unpackhi_epi8(shiftm0, shiftm1);
    __m128i shifted0 = _mm_mulhi_epu16(words0, shift_mul0);
    __m128i shifted1 = _mm_mulhi_epu16(words1, shift_mul1);
    // pack the results back into bytes
    __m128i byte_mask = _mm_set1_epi16(0xff);
    __m128i shifted_bytes = _mm_packus_epi16(_mm_and_si128(shifted0, byte_mask), _mm_and_si128(shifted1, byte_mask));
    // mask by field width, again using a PSHUFB LUT
    __m128i width_mask_lut = _mm_setr_epi8(0,1,3,7, 15,31,63,127, -1,-1,-1,-1, -1,-1,-1,-1);
    __m128i width_mask = _mm_shuffle_epi8(width_mask_lut, field_widths);
    __m128i result = _mm_and_si128(shifted_bytes, width_mask);
    return result;
}
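// Scalar model of the PMULHUW trick in multigetbits8 (editor's illustrative
// sketch; the helper name is made up): a variable right shift by (16 - k) is
// the high half of a 16x16 multiply by (1 << k), and PSHUFB can look that
// power-of-two multiplier up per lane.
static inline uint8_t multigetbits8_lane_ref(uint8_t byte0, uint8_t byte1, uint32_t first_bit, uint32_t width)
{
    uint32_t word = ((uint32_t)byte0 << 8) | byte1;  // the two source bytes, MSB-first
    uint32_t k = (first_bit & 7) + width;            // shift_amt above, in [0,15] for width <= 8
    uint32_t shifted = (word * (1u << k)) >> 16;     // == word >> (16 - k), i.e. PMULHUW
    return (uint8_t)(shifted & ((1u << width) - 1)); // width mask, the PSHUFB LUT above
}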
static inline __m128i big_endian_load_shift128(const uint8_t *in_ptr, uint32_t bit_basepos)
{
    // Grab 128 source bits starting "bit_basepos" bits into the bit stream
    __m128i src_byte0 = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3) + 0));
    __m128i src_byte1 = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3) + 1));
    // We need to consume the first (bit_basepos & 7) bits with a big-endian 128-bit
    // funnel shift, which we don't have ready at hand, so we need to get creative;
    // specifically, use 16-bit shifts by a single distance for all lanes (which we have)
    // and make sure to not grab any bits that crossed byte boundaries (which would be
    // taken from the wrong byte due to the endianness difference)
    uint32_t basepos7 = bit_basepos & 7;
    __m128i basepos_shiftamt = _mm_cvtsi32_si128(basepos7);
    // Combine to big-endian 16-bit words and shift those since we don't have 8-bit shifts
    // at hand
    __m128i merged0 = _mm_unpacklo_epi8(src_byte1, src_byte0);
    __m128i merged1 = _mm_unpackhi_epi8(src_byte1, src_byte0);
    __m128i shifted0 = _mm_sll_epi16(merged0, basepos_shiftamt);
    __m128i shifted1 = _mm_sll_epi16(merged1, basepos_shiftamt);
    __m128i reduced0 = _mm_srli_epi16(shifted0, 8);
    __m128i reduced1 = _mm_srli_epi16(shifted1, 8);
    __m128i shifted_src_bytes = _mm_packus_epi16(reduced0, reduced1);
    return shifted_src_bytes;
}
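// Scalar model of what big_endian_load_shift128 produces per output byte
// (editor's illustrative sketch; the helper name is made up): byte i of the
// result is the stream re-aligned so the bit at bit_basepos lands on a byte
// boundary. Same deliberate over-read as the SIMD version.
static inline uint8_t be_load_shift_byte_ref(const uint8_t *in_ptr, uint32_t bit_basepos, uint32_t i)
{
    const uint8_t *p = in_ptr + (bit_basepos >> 3);
    uint32_t word = ((uint32_t)p[i] << 8) | p[i + 1];   // big-endian 16-bit window
    return (uint8_t)((word << (bit_basepos & 7)) >> 8); // drop the leading (bit_basepos & 7) bits
}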
// Once more, with feeling
// trying to figure out a niftier way to do this that'll also allow me to do full
// multigetbits16 and multigetbits32 variants that don't suck
static inline __m128i multigetbits8b(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
    uint32_t bit_basepos = *pbit_basepos;
    // Prefix-sum the field widths and advance bit position pointer
    __m128i end_bit_index = prefix_sum_u8(field_widths);
    uint32_t total_width = (uint32_t)_mm_extract_epi16(end_bit_index, 7) >> 8; // no PEXTRB before SSE4.1, and this is the only place where SSE4.1+ helps
    *pbit_basepos = bit_basepos + total_width;
    // Doing this shift is a bit of a production, but it simplifies the rest.
    __m128i shifted_src_bytes = big_endian_load_shift128(in_ptr, bit_basepos);
    __m128i end_byte_index = _mm_and_si128(_mm_srli_epi16(end_bit_index, 3), _mm_set1_epi8(0x1f)); // no "shift bytes", sigh.
    // Grab first/second bytes for every lane
    __m128i byte1 = _mm_shuffle_epi8(shifted_src_bytes, end_byte_index);
    __m128i byte0 = _mm_shuffle_epi8(shifted_src_bytes, _mm_sub_epi8(end_byte_index, _mm_set1_epi8(1)));
    // Assemble words (byte1 << 8) | byte0
    __m128i words0 = _mm_unpacklo_epi8(byte1, byte0);
    __m128i words1 = _mm_unpackhi_epi8(byte1, byte0);
    // Now do a left shift by 1 << (end_bit_index & 7) using a multiply,
    // putting the end of the bit field at the boundary between the low and high byte
    // in every word.
    __m128i end_bit_index7 = _mm_and_si128(end_bit_index, _mm_set1_epi8(7));
    __m128i left_shift_lut = _mm_setr_epi8(1,2,4,8, 16,32,64,-128, 1,2,4,8, 16,32,64,-128);
    __m128i shiftm = _mm_shuffle_epi8(left_shift_lut, end_bit_index7);
    __m128i shift_mul0 = _mm_unpacklo_epi8(shiftm, _mm_setzero_si128());
    __m128i shift_mul1 = _mm_unpackhi_epi8(shiftm, _mm_setzero_si128());
    __m128i shifted0 = _mm_mullo_epi16(words0, shift_mul0);
    __m128i shifted1 = _mm_mullo_epi16(words1, shift_mul1);
    // Grab the high byte of the results and pack back into bytes
    __m128i shifted_bytes = _mm_packus_epi16(_mm_srli_epi16(shifted0, 8), _mm_srli_epi16(shifted1, 8));
    // mask by field width, again using a PSHUFB LUT
    __m128i width_mask_lut = _mm_setr_epi8(0,1,3,7, 15,31,63,127, -1,-1,-1,-1, -1,-1,-1,-1);
    __m128i width_mask = _mm_shuffle_epi8(width_mask_lut, field_widths);
    __m128i result = _mm_and_si128(shifted_bytes, width_mask);
    return result;
}
static inline __m128i multigetbits_leftshift_mult(__m128i end_bit_index)
{
#if 1
    // This requires 0 <= end_bit_index <= 127!
    // We use that PSHUFB only looks at the bottom 4 bits for the index, plus bit 7 to decide whether to
    // substitute in zero.
    //
    // Since end_bit_index < 128, we know bit 7 is clear, so we don't need to AND with 7. Just replicate
    // the 8-entry table twice.
    __m128i left_shift_lut = _mm_setr_epi8(1,2,4,8, 16,32,64,-128, 1,2,4,8, 16,32,64,-128);
    __m128i left_shift_mult = _mm_and_si128(_mm_shuffle_epi8(left_shift_lut, end_bit_index), _mm_set1_epi32(0xff));
#else
    __m128i left_shift_amt = _mm_and_si128(end_bit_index, _mm_set1_epi32(7));
    __m128i left_shift_magic = _mm_add_epi32(_mm_slli_epi32(left_shift_amt, 23), _mm_set1_epi32(0x3f800000));
    __m128i left_shift_mult = _mm_cvttps_epi32(_mm_castsi128_ps(left_shift_magic));
#endif
    return left_shift_mult;
}
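// Worked example for the float-exponent variant in the #else branch above:
// 0x3f800000 is the bit pattern of 1.0f, and the IEEE-754 exponent field starts
// at bit 23, so (amt << 23) + 0x3f800000 is exactly the bit pattern of the
// float 2^amt for amt in [0,7] (no mantissa bits get disturbed), and CVTTPS2DQ
// turns that into the integer 1 << amt. E.g. amt = 5:
//   0x3f800000 + (5 << 23) = 0x42000000 = 32.0f  ->  32 = 1 << 5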
// field widths here are U32[4] in [0,24]
static inline __m128i multigetbits24a(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
    uint32_t bit_basepos = *pbit_basepos;
    __m128i summed_widths = prefix_sum_u32(field_widths);
    uint32_t total_width = _mm_extract_epi16(summed_widths, 6); // using PEXTRW (SSE2) instead of PEXTRD (SSE4.1+)
    *pbit_basepos = bit_basepos + total_width;
    __m128i basepos_u32 = _mm_shuffle_epi32(_mm_cvtsi32_si128(bit_basepos & 7), 0x00);
    __m128i end_bit_index = _mm_add_epi32(basepos_u32, summed_widths);
    // say bit_basepos = 3 and field_widths[0] = 11
    // then end_bit_index[0] = 3 + 11 = 14
    //
    // we want to shuffle the input bytes so the byte containing bit 14 (in bit stream order) ends up
    // in the least significant byte position of lane 0
    //
    // this is byte 1, so we want shuffle[0] = 14>>3 = 1
    // and then we need to shift left by another (14 & 7) = 6 bit positions to have the bottom of the
    // bit field be flush with bit 8 of lane 0.
    //
    // note that this Just Works(tm) if end_bit_index[i] ends up a multiple of 8: we fetch one byte
    // too far (since we use end_bit_index and not end_bit_index-1) but then shift by 0, so the field
    // still ends up flush with bit 8 of the target lane, which is exactly what we want.
    __m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
    __m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
    byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
    // grab source bytes and shuffle
    __m128i src_bytes = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3)));
    __m128i dwords = _mm_shuffle_epi8(src_bytes, byte_shuffle);
    // left shift the source dwords
    __m128i left_shift_mult = multigetbits_leftshift_mult(end_bit_index);
    __m128i shifted_dwords = _mm_mullo_epi32(dwords, left_shift_mult);
    // right shift by 8 to align it to the bottom
    __m128i finished_bit_grab = _mm_srli_epi32(shifted_dwords, 8);
    // create width mask constants
    __m128i width_magic = _mm_add_epi32(_mm_slli_epi32(field_widths, 23), _mm_set1_epi32(0xbf800000));
    __m128i width_mask = _mm_cvttps_epi32(_mm_castsi128_ps(width_magic));
    __m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
    return masked_fields;
}
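// Worked example for the width-mask magic above: 0xbf800000 is the bit pattern
// of -1.0f, so (w << 23) + 0xbf800000 is the float -(2^w), and CVTTPS2DQ gives
// the integer -(1 << w) = ~((1 << w) - 1); ANDNOT with that keeps exactly the
// low w bits. E.g. w = 11: 0xbf800000 + (11 << 23) = 0xc5000000 = -2048.0f,
// truncated to -2048 = 0xfffff800, and ~0xfffff800 = 0x000007ff, an 11-bit
// mask. w = 0 yields -1, a mask that clears everything, as desired.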
// field widths here are given packed little-endian into an U32
static inline __m128i multigetbits24b(const uint8_t *in_ptr, uint32_t *pbit_basepos, uint32_t packed_widths)
{
    uint32_t bit_basepos = *pbit_basepos;
    // use a multiply to do the inclusive prefix sum
    uint32_t field_end = (packed_widths + (bit_basepos & 7)) * 0x01010101u;
    *pbit_basepos = (bit_basepos & ~7) + (field_end >> 24);
    __m128i widths_vec = _mm_cvtepu8_epi32(_mm_cvtsi32_si128(packed_widths));
    __m128i end_bit_index = _mm_cvtepu8_epi32(_mm_cvtsi32_si128(field_end));
    // say bit_basepos = 3 and field_widths[0] = 11
    // then end_bit_index[0] = 3 + 11 = 14
    //
    // we want to shuffle the input bytes so the byte containing bit 14 (in bit stream order) ends up
    // in the least significant byte position of lane 0
    //
    // this is byte 1, so we want shuffle[0] = 14>>3 = 1
    // and then we need to shift left by another (14 & 7) = 6 bit positions to have the bottom of the
    // bit field be flush with bit 8 of lane 0.
    //
    // note that this Just Works(tm) if end_bit_index[i] ends up a multiple of 8: we fetch one byte
    // too far (since we use end_bit_index and not end_bit_index-1) but then shift by 0, so the field
    // still ends up flush with bit 8 of the target lane, which is exactly what we want.
    __m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
    __m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
    byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
    // grab source bytes and shuffle
    __m128i src_bytes = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3)));
    __m128i dwords = _mm_shuffle_epi8(src_bytes, byte_shuffle);
    // left shift the source dwords
    __m128i left_shift_mult = multigetbits_leftshift_mult(end_bit_index);
    __m128i shifted_dwords = _mm_mullo_epi32(dwords, left_shift_mult);
    // right shift by 8 to align it to the bottom
    __m128i finished_bit_grab = _mm_srli_epi32(shifted_dwords, 8);
    // create width mask constants
    __m128i width_magic = _mm_add_epi32(_mm_slli_epi32(widths_vec, 23), _mm_set1_epi32(0xbf800000));
    __m128i width_mask = _mm_cvttps_epi32(_mm_castsi128_ps(width_magic));
    __m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
    return masked_fields;
}
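// Worked example for the multiply-as-prefix-sum above: multiplying the packed
// little-endian bytes [w0,w1,w2,w3] by 0x01010101 makes byte k of the product
// w0 + ... + wk, provided no partial sum exceeds 255 (here at most
// 4*24 + 7 = 103, so no carries leak between bytes). E.g. widths 5,11,0,24
// with (bit_basepos & 7) = 3: input bytes [8,11,0,24] (the 3 folds into the
// low byte), product bytes [8,19,19,43] = per-field end bit indices, and
// field_end >> 24 = 43 is the total advance including the initial alignment.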
// Field widths are [0,30].
// Limit here is 30 so that we consume at most 30*4 + 7 (for the initial align) = 127 bits from the source
// any more turns out to get messy
static inline __m128i multigetbits30(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
    uint32_t bit_basepos = *pbit_basepos;
    __m128i summed_widths = prefix_sum_u32(field_widths);
    uint32_t total_width = _mm_extract_epi16(summed_widths, 6); // using PEXTRW (SSE2) instead of PEXTRD (SSE4.1+)
    *pbit_basepos = bit_basepos + total_width;
    __m128i basepos_u32 = _mm_shuffle_epi32(_mm_cvtsi32_si128(bit_basepos & 7), 0x00);
    __m128i end_bit_index = _mm_add_epi32(basepos_u32, summed_widths);
    __m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
    __m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
    byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
    // grab source bytes and shuffle
    __m128i src_bytes = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3)));
    __m128i dwords0 = _mm_shuffle_epi8(src_bytes, byte_shuffle);
    __m128i dwords1 = _mm_shuffle_epi8(src_bytes, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
    // left shift the source dwords
    // The high concept here is that the "l" values contain the low bits of the result, and
    // the "h" values contain the high bits of the result.
    //
    // The top approach computes this with 16-bit multiplies which are usually faster,
    // but this requires a slightly more complicated setup for the multipliers.
    //
    // The bottom approach just uses 32-bit multiplies.
#if 1
    __m128i left_shift_amt = _mm_and_si128(end_bit_index, _mm_set1_epi32(7));
    __m128i left_shift_magic = _mm_add_epi32(_mm_slli_epi32(left_shift_amt, 23), _mm_castps_si128(_mm_set1_ps((float)0x10001)));
    __m128i left_shift_mult = _mm_cvttps_epi32(_mm_castsi128_ps(left_shift_magic));
    __m128i shifted_dwordsl = _mm_mullo_epi16(dwords0, left_shift_mult);
    __m128i shifted_dwordsh = _mm_mullo_epi16(dwords1, left_shift_mult);
#else
    __m128i left_shift_mult = multigetbits_leftshift_mult(end_bit_index);
    __m128i shifted_dwordsl = _mm_mullo_epi32(dwords0, left_shift_mult);
    __m128i shifted_dwordsh = _mm_mullo_epi32(dwords1, left_shift_mult);
#endif
    // combine the low and high parts
    __m128i finished_bit_grab = _mm_or_si128(_mm_srli_epi32(shifted_dwordsl, 8), shifted_dwordsh);
    // create width mask constants
    __m128i width_magic = _mm_add_epi32(_mm_slli_epi32(field_widths, 23), _mm_set1_epi32(0xbf800000));
    __m128i width_mask = _mm_cvttps_epi32(_mm_castsi128_ps(width_magic));
    __m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
    return masked_fields;
}
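// Worked example for the 0x10001 multiplier setup above: (float)0x10001 is
// exactly 65537.0f (its 17 significand bits fit in a float), so adding amt to
// its exponent field gives the float 65537 * 2^amt, and CVTTPS2DQ yields
// 0x10001 << amt: both 16-bit halves of the dword equal 1 << amt for amt in
// [0,7]. That is precisely what PMULLW needs to shift both halves of each
// source dword left by the same per-dword amount. E.g. amt = 6 gives the
// multiplier dword 0x00400040.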
// Field widths are [0,32].
// with the big_endian_load_shift128 primitive, we can support 32 bits in every lane
static inline __m128i multigetbits32(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
    uint32_t bit_basepos = *pbit_basepos;
    __m128i end_bit_index = prefix_sum_u32(field_widths);
    uint32_t total_width = _mm_extract_epi16(end_bit_index, 6); // using PEXTRW (SSE2) instead of PEXTRD (SSE4.1+)
    *pbit_basepos = bit_basepos + total_width;
    __m128i shifted_src_bytes = big_endian_load_shift128(in_ptr, bit_basepos);
    __m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
    __m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
    byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
    // grab source bytes and shuffle
    __m128i dwords0 = _mm_shuffle_epi8(shifted_src_bytes, byte_shuffle);
    __m128i dwords1 = _mm_shuffle_epi8(shifted_src_bytes, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
    // left shift the source dwords
    // The high concept here is that the "l" values contain the low bits of the result, and
    // the "h" values contain the high bits of the result.
    //
    // The top approach computes this with 16-bit multiplies which are usually faster,
    // but this requires a slightly more complicated setup for the multipliers.
    //
    // The bottom approach just uses 32-bit multiplies.
#if 1
    __m128i left_shift_amt = _mm_and_si128(end_bit_index, _mm_set1_epi32(7));
    __m128i left_shift_magic = _mm_add_epi32(_mm_slli_epi32(left_shift_amt, 23), _mm_castps_si128(_mm_set1_ps((float)0x10001)));
    __m128i left_shift_mult = _mm_cvttps_epi32(_mm_castsi128_ps(left_shift_magic));
    __m128i shifted_dwordsl = _mm_mullo_epi16(dwords0, left_shift_mult);
    __m128i shifted_dwordsh = _mm_mullo_epi16(dwords1, left_shift_mult);
#else
    __m128i left_shift_mult = multigetbits_leftshift_mult(end_bit_index);
    __m128i shifted_dwordsl = _mm_mullo_epi32(dwords0, left_shift_mult);
    __m128i shifted_dwordsh = _mm_mullo_epi32(dwords1, left_shift_mult);
#endif
    // combine the low and high parts
    __m128i finished_bit_grab = _mm_or_si128(_mm_srli_epi32(shifted_dwordsl, 8), shifted_dwordsh);
    // create width mask constants
    // supporting width=32 here adds an extra wrinkle
    __m128i width_magic = _mm_add_epi32(_mm_slli_epi32(field_widths, 23), _mm_set1_epi32(0xbf800000));
    __m128i width_mask0 = _mm_cvttps_epi32(_mm_castsi128_ps(width_magic));
    __m128i width_gt31 = _mm_cmpgt_epi32(field_widths, _mm_set1_epi32(31));
    __m128i width_mask = _mm_andnot_si128(width_gt31, width_mask0);
    __m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
    return masked_fields;
}
static inline __m128i wideload_shuffle(__m128i src0_128, __m128i src1_32, __m128i shuf_index)
{
    __m128i lower = _mm_shuffle_epi8(src0_128, shuf_index);
    __m128i upper = _mm_andnot_si128(_mm_cmpgt_epi8(shuf_index, _mm_set1_epi8(-1)), src1_32);
    return _mm_or_si128(lower, upper);
    //__m128i upper = _mm_shuffle_epi8(src1_32, _mm_xor_si128(shuf_index, _mm_set1_epi8(0x83 - 0x100)));
    //return _mm_or_si128(lower, upper);
}
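// How wideload_shuffle stitches the two sources together (restating the trick
// for clarity): PSHUFB zeroes any destination byte whose index has bit 7 set.
// The caller (multigetbits32c below) biases its byte indices by +0x70, so
// in-range indices 0..15 become 0x70..0x7f (bit 7 clear, low 4 bits intact),
// while index 16, one byte past the 128-bit load, becomes 0x80 (bit 7 set and
// thus zeroed). The CMPGT/ANDNOT pair then substitutes src1_32, a broadcast of
// that 17th source byte, into exactly those zeroed lanes.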
// Field widths are [0,32].
// alternative approach without big_endian_load_shift128, using a different load strategy instead
//
// interesting but worse
static inline __m128i multigetbits32c(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
    uint32_t bit_basepos = *pbit_basepos;
    __m128i summed_field_widths = prefix_sum_u32(field_widths);
    uint32_t total_width = _mm_extract_epi16(summed_field_widths, 6); // using PEXTRW (SSE2) instead of PEXTRD (SSE4.1+)
    *pbit_basepos = bit_basepos + total_width;
    __m128i end_bit_index = _mm_add_epi32(summed_field_widths, _mm_shuffle_epi32(_mm_cvtsi32_si128(bit_basepos & 7), 0));
    __m128i src_bytes0 = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3) + 0));
    __m128i src_bytes1 = _mm_cvtsi32_si128(*(int *) (in_ptr + (bit_basepos >> 3) + 13)); // MOVD
    src_bytes1 = _mm_shuffle_epi8(src_bytes1, _mm_set1_epi8(3)); // broadcast final byte
    __m128i end_byte_index = _mm_add_epi32(_mm_srli_epi32(end_bit_index, 3), _mm_set1_epi32(0x70));
    __m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
    byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
    // grab source bytes and shuffle
    __m128i dwords0 = wideload_shuffle(src_bytes0, src_bytes1, byte_shuffle);
    __m128i dwords1 = wideload_shuffle(src_bytes0, src_bytes1, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
    // left shift the source dwords
    // The high concept here is that the "l" values contain the low bits of the result, and
    // the "h" values contain the high bits of the result.
    //
    // The top approach computes this with 16-bit multiplies which are usually faster,
    // but this requires a slightly more complicated setup for the multipliers.
    //
    // The bottom approach just uses 32-bit multiplies.
#if 1
    __m128i left_shift_amt = _mm_and_si128(end_bit_index, _mm_set1_epi32(7));
    __m128i left_shift_magic = _mm_add_epi32(_mm_slli_epi32(left_shift_amt, 23), _mm_castps_si128(_mm_set1_ps((float)0x10001)));
    __m128i left_shift_mult = _mm_cvttps_epi32(_mm_castsi128_ps(left_shift_magic));
    __m128i shifted_dwordsl = _mm_mullo_epi16(dwords0, left_shift_mult);
    __m128i shifted_dwordsh = _mm_mullo_epi16(dwords1, left_shift_mult);
#else
    __m128i left_shift_mult = multigetbits_leftshift_mult(end_bit_index);
    __m128i shifted_dwordsl = _mm_mullo_epi32(dwords0, left_shift_mult);
    __m128i shifted_dwordsh = _mm_mullo_epi32(dwords1, left_shift_mult);
#endif
    // combine the low and high parts
    __m128i finished_bit_grab = _mm_or_si128(_mm_srli_epi32(shifted_dwordsl, 8), shifted_dwordsh);
    // create width mask constants
    // supporting width=32 here adds an extra wrinkle
    __m128i width_magic = _mm_add_epi32(_mm_slli_epi32(field_widths, 23), _mm_set1_epi32(0xbf800000));
    __m128i width_mask0 = _mm_cvttps_epi32(_mm_castsi128_ps(width_magic));
    __m128i width_gt31 = _mm_cmpgt_epi32(field_widths, _mm_set1_epi32(31));
    __m128i width_mask = _mm_andnot_si128(width_gt31, width_mask0);
    __m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
    return masked_fields;
}
// Field widths are [0,15].
// Limit is 15 so that we consume at most 15*8 + 7 (for the initial align) = 127 bits from the source
// any more gets messy
static inline __m128i multigetbits15(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
    uint32_t bit_basepos = *pbit_basepos;
    __m128i summed_widths = prefix_sum_u16(field_widths);
    uint32_t total_width = _mm_extract_epi16(summed_widths, 7);
    *pbit_basepos = bit_basepos + total_width;
    __m128i basepos_u16 = _mm_set1_epi16(bit_basepos & 7);
    __m128i end_bit_index = _mm_add_epi16(basepos_u16, summed_widths);
    __m128i end_byte_index = _mm_srli_epi16(end_bit_index, 3);
    __m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,2,2, 4,4,6,6, 8,8,10,10, 12,12,14,14));
    byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_set1_epi16(0x0100));
    // grab source bytes and shuffle
    __m128i src_bytes = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3)));
    __m128i words0 = _mm_shuffle_epi8(src_bytes, byte_shuffle);
    __m128i words1 = _mm_shuffle_epi8(src_bytes, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
    // left shift the source words
    __m128i left_shift_lut = _mm_setr_epi8(1,2,4,8, 16,32,64,-128, 1,2,4,8, 16,32,64,-128);
    __m128i left_shift_mult = _mm_and_si128(_mm_shuffle_epi8(left_shift_lut, end_bit_index), _mm_set1_epi16(0xff));
    __m128i shifted_words0 = _mm_mullo_epi16(words0, left_shift_mult);
    __m128i shifted_words1 = _mm_mullo_epi16(words1, left_shift_mult);
    // combine the low and high parts
    __m128i finished_bit_grab = _mm_or_si128(_mm_srli_epi16(shifted_words0, 8), shifted_words1);
    // create width mask constants
    __m128i width_exps = _mm_add_epi16(_mm_slli_epi16(field_widths, 7), _mm_set1_epi16(0xbf80));
    __m128i zero = _mm_setzero_si128();
    __m128i widthm0 = _mm_cvttps_epi32(_mm_castsi128_ps(_mm_unpacklo_epi16(zero, width_exps)));
    __m128i widthm1 = _mm_cvttps_epi32(_mm_castsi128_ps(_mm_unpackhi_epi16(zero, width_exps)));
    __m128i width_mask = _mm_packs_epi32(widthm0, widthm1);
    __m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
    return masked_fields;
}
// Field widths are [0,16].
// with the big_endian_load_shift128 primitive, we can support 16 bits in every lane
static inline __m128i multigetbits16(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
    uint32_t bit_basepos = *pbit_basepos;
    __m128i end_bit_index = prefix_sum_u16(field_widths);
    uint32_t total_width = _mm_extract_epi16(end_bit_index, 7);
    *pbit_basepos = bit_basepos + total_width;
    __m128i shifted_src_bytes = big_endian_load_shift128(in_ptr, bit_basepos);
    __m128i end_byte_index = _mm_srli_epi16(end_bit_index, 3);
    __m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,2,2, 4,4,6,6, 8,8,10,10, 12,12,14,14));
    byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_set1_epi16(0x0100));
    // shuffle source bytes
    __m128i words0 = _mm_shuffle_epi8(shifted_src_bytes, byte_shuffle);
    __m128i words1 = _mm_shuffle_epi8(shifted_src_bytes, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
    // left shift the source words
    __m128i left_shift_lut = _mm_setr_epi8(1,2,4,8, 16,32,64,-128, 1,2,4,8, 16,32,64,-128);
    __m128i left_shift_mult = _mm_and_si128(_mm_shuffle_epi8(left_shift_lut, _mm_and_si128(end_bit_index, _mm_set1_epi8(7))), _mm_set1_epi16(0xff));
    __m128i shifted_words0 = _mm_mullo_epi16(words0, left_shift_mult);
    __m128i shifted_words1 = _mm_mullo_epi16(words1, left_shift_mult);
    // combine the low and high parts
    __m128i finished_bit_grab = _mm_or_si128(_mm_srli_epi16(shifted_words0, 8), shifted_words1);
    // create width mask
    // need to do this differently from multigetbits15 logic here to make width=16 work
    __m128i base_mask_lut = _mm_setr_epi8(-1,-2,-4,-8, -16,-32,-64,-128, -1,-2,-4,-8, -16,-32,-64,-128);
    __m128i width_mask0 = _mm_shuffle_epi8(base_mask_lut, field_widths); // gives (-1 << (field_widths & 7))
    __m128i width_mask0s = _mm_slli_epi16(width_mask0, 8);
    __m128i width_gt7 = _mm_cmpgt_epi16(field_widths, _mm_set1_epi16(7));
    // conditionally shift by 8 where field_widths >= 8
    __m128i width_mask1 = _mm_or_si128(_mm_and_si128(width_mask0s, width_gt7), _mm_andnot_si128(width_gt7, width_mask0));
    // conditionally zero mask where field_widths >= 16
    __m128i width_gt15 = _mm_cmpgt_epi16(field_widths, _mm_set1_epi16(15));
    __m128i width_mask = _mm_andnot_si128(width_gt15, width_mask1);
    __m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
    return masked_fields;
}
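// Worked example for the [0,16] width mask above: base_mask_lut yields the
// byte -1 << (w & 7). For w = 11 the 16-bit lane starts out as 0xfff8; since
// w > 7 it is shifted left by 8 to 0xf800, so ANDNOT keeps the low 11 bits.
// For w = 16 the lane is 0xffff, but the width_gt15 step forces the mask to 0,
// so ANDNOT keeps all 16 bits, which the multigetbits15-style float trick
// could not express.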
#ifdef __RADAVX__
// field widths here are U32[4] in [0,24]
static inline __m128i multigetbits24a_avx2(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
    uint32_t bit_basepos = *pbit_basepos;
    __m128i summed_widths = prefix_sum_u32(field_widths);
    uint32_t total_width = _mm_extract_epi32(summed_widths, 3);
    *pbit_basepos = bit_basepos + total_width;
    __m128i basepos_u32 = _mm_shuffle_epi32(_mm_cvtsi32_si128(bit_basepos & 7), 0x00);
    __m128i end_bit_index = _mm_add_epi32(basepos_u32, summed_widths);
    // say bit_basepos = 3 and field_widths[0] = 11
    // then end_bit_index[0] = 3 + 11 = 14
    //
    // we want to shuffle the input bytes so the byte containing bit 14 (in bit stream order) ends up
    // in the least significant byte position of lane 0
    //
    // this is byte 1, so we want shuffle[0] = 14>>3 = 1
    // and then we need to shift left by another (14 & 7) = 6 bit positions to have the bottom of the
    // bit field be flush with bit 8 of lane 0.
    //
    // note that this Just Works(tm) if end_bit_index[i] ends up a multiple of 8: we fetch one byte
    // too far (since we use end_bit_index and not end_bit_index-1) but then shift by 0, so the field
    // still ends up flush with bit 8 of the target lane, which is exactly what we want.
    __m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
    __m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
    byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
    // grab source bytes and shuffle
    __m128i src_bytes = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3)));
    __m128i dwords = _mm_shuffle_epi8(src_bytes, byte_shuffle);
    // right shift the source dwords to align the correct bits at the bottom
    __m128i shift_amt = _mm_sub_epi32(_mm_set1_epi32(8), _mm_and_si128(end_bit_index, _mm_set1_epi32(7)));
    __m128i finished_bit_grab = _mm_srlv_epi32(dwords, shift_amt);
    // mask to desired field widths
    __m128i width_mask = _mm_sllv_epi32(_mm_set1_epi32(-1), field_widths);
    __m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
    return masked_fields;
}
// Field widths are [0,30].
// Limit here is 30 so that we consume at most 30*4 + 7 (for the initial align) = 127 bits from the source
// any more turns out to get messy
static inline __m128i multigetbits30_avx2(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
    uint32_t bit_basepos = *pbit_basepos;
    __m128i summed_widths = prefix_sum_u32(field_widths);
    uint32_t total_width = _mm_extract_epi32(summed_widths, 3);
    *pbit_basepos = bit_basepos + total_width;
    __m128i basepos_u32 = _mm_shuffle_epi32(_mm_cvtsi32_si128(bit_basepos & 7), 0x00);
    __m128i end_bit_index = _mm_add_epi32(basepos_u32, summed_widths);
    __m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
    __m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
    byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
    // grab source bytes and shuffle
    __m128i src_bytes = _mm_loadu_si128((const __m128i *) (in_ptr + (bit_basepos >> 3)));
    __m128i dwords0 = _mm_shuffle_epi8(src_bytes, byte_shuffle);
    __m128i dwords1 = _mm_shuffle_epi8(src_bytes, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
    // shift the source dwords
    // The high concept here is that the "l" values contain the low bits of the result, and
    // the "h" values contain the high bits of the result.
    __m128i shift_amt = _mm_and_si128(end_bit_index, _mm_set1_epi32(7));
    __m128i rev_shift_amt = _mm_sub_epi32(_mm_set1_epi32(8), shift_amt);
    __m128i shifted_dwordsl = _mm_srlv_epi32(dwords0, rev_shift_amt);
    __m128i shifted_dwordsh = _mm_sllv_epi32(dwords1, shift_amt);
    // combine the low and high parts
    __m128i finished_bit_grab = _mm_or_si128(shifted_dwordsl, shifted_dwordsh);
    // mask to desired field widths
    __m128i width_mask = _mm_sllv_epi32(_mm_set1_epi32(-1), field_widths);
    __m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
    return masked_fields;
}
// Field widths are [0,32].
static inline __m128i multigetbits32_avx2(const uint8_t *in_ptr, uint32_t *pbit_basepos, __m128i field_widths)
{
    uint32_t bit_basepos = *pbit_basepos;
    __m128i end_bit_index = prefix_sum_u32(field_widths);
    uint32_t total_width = _mm_extract_epi32(end_bit_index, 3);
    *pbit_basepos = bit_basepos + total_width;
    __m128i shifted_src_bytes = big_endian_load_shift128(in_ptr, bit_basepos);
    __m128i end_byte_index = _mm_srli_epi32(end_bit_index, 3);
    __m128i byte_shuffle = _mm_shuffle_epi8(end_byte_index, _mm_setr_epi8(0,0,0,0, 4,4,4,4, 8,8,8,8, 12,12,12,12));
    byte_shuffle = _mm_sub_epi8(byte_shuffle, _mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3));
    // shuffle source bytes
    __m128i dwords0 = _mm_shuffle_epi8(shifted_src_bytes, byte_shuffle);
    __m128i dwords1 = _mm_shuffle_epi8(shifted_src_bytes, _mm_sub_epi8(byte_shuffle, _mm_set1_epi8(1)));
    // shift the source dwords
    // The high concept here is that the "l" values contain the low bits of the result, and
    // the "h" values contain the high bits of the result.
    __m128i shift_amt = _mm_and_si128(end_bit_index, _mm_set1_epi32(7));
    __m128i rev_shift_amt = _mm_sub_epi32(_mm_set1_epi32(8), shift_amt);
    __m128i shifted_dwordsl = _mm_srlv_epi32(dwords0, rev_shift_amt);
    __m128i shifted_dwordsh = _mm_sllv_epi32(dwords1, shift_amt);
    // combine the low and high parts
    __m128i finished_bit_grab = _mm_or_si128(shifted_dwordsl, shifted_dwordsh);
    // mask to desired field widths
    __m128i width_mask = _mm_sllv_epi32(_mm_set1_epi32(-1), field_widths);
    __m128i masked_fields = _mm_andnot_si128(width_mask, finished_bit_grab);
    return masked_fields;
}
#endif
// ---- testbed
static void decode8_ref(uint8_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    static const uint8_t masks[9] = { 0,1,3,7, 15,31,63,127, 255 };
    // we just gleefully over-read; not worrying about that right now.
    uint64_t bits = 0, bitc = 0;
    for (size_t i = 0; i < count; i += 7)
    {
        // refill bit buffer (so it contains at least 56 bits)
        uint64_t bytes_consumed = (63 - bitc) >> 3;
        bits |= bswap64(*(const uint64_t *) in_ptr) >> bitc;
        bitc |= 56;
        in_ptr += bytes_consumed;
        // decode 7 values
        uint32_t w;
        uint64_t mask;
#if 1
        // !!better!! (by 0.5 cycles/elem on SNB!)
#define DECONE(ind) \
        w = width_arr[i + ind]; \
        bits = rotl64(bits, w); \
        mask = masks[w] & bits; \
        out_ptr[i + ind] = (uint8_t) mask; \
        bits ^= mask; \
        bitc -= w
#else
#define DECONE(ind) \
        w = width_arr[i + ind]; \
        bits = rotl64(bits, w); \
        mask = masks[w]; \
        out_ptr[i + ind] = (uint8_t) (bits & mask); \
        bits &= ~mask; \
        bitc -= w
#endif
        DECONE(0);
        DECONE(1);
        DECONE(2);
        DECONE(3);
        DECONE(4);
        DECONE(5);
        DECONE(6);
#undef DECONE
    }
}
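// Note on the DECONE rotate trick (restating the invariant the macro relies
// on): "bits" holds the stream MSB-first with all already-consumed bits
// zeroed, so after rotl64(bits, w) the next w bits sit at the bottom, masks[w]
// extracts them, and "bits ^= mask" clears exactly those bits, keeping them
// from reappearing on later rotates. It has the same effect as the
// "bits &= ~mask" variant in the #else branch but reuses the already-masked
// value, which the comment above credits with ~0.5 cycles/elem on SNB.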
static void decode8_SSSE3(uint8_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    for (size_t i = 0; i < count; i += 16)
    {
        __m128i widths = _mm_loadu_si128((const __m128i *) (width_arr + i));
        __m128i values = multigetbits8(in_ptr, &bit_pos, widths);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
static void decode8b_SSSE3(uint8_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    for (size_t i = 0; i < count; i += 16)
    {
        __m128i widths = _mm_loadu_si128((const __m128i *) (width_arr + i));
        __m128i values = multigetbits8b(in_ptr, &bit_pos, widths);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
static void decode16_ref(uint16_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    static const uint16_t masks[17] = { 0x0000, 0x0001,0x0003,0x0007,0x000f, 0x001f,0x003f,0x007f,0x00ff, 0x01ff,0x03ff,0x07ff,0x0fff, 0x1fff,0x3fff,0x7fff,0xffff };
    // we just gleefully over-read; not worrying about that right now.
    uint64_t bits = 0, bitc = 0;
    for (size_t i = 0; i < count; i += 3)
    {
        // refill bit buffer (so it contains at least 56 bits)
        uint64_t bytes_consumed = (63 - bitc) >> 3;
        bits |= bswap64(*(const uint64_t *) in_ptr) >> bitc;
        bitc |= 56;
        in_ptr += bytes_consumed;
        // decode 3 values
        uint32_t w;
        uint64_t mask;
#define DECONE(ind) \
        w = width_arr[i + ind]; \
        bits = rotl64(bits, w); \
        mask = masks[w] & bits; \
        out_ptr[i + ind] = (uint16_t) mask; \
        bits ^= mask; \
        bitc -= w
        DECONE(0);
        DECONE(1);
        DECONE(2);
#undef DECONE
    }
}
static void decode16_SSSE3(uint16_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    __m128i byteswap_shuf = _mm_setr_epi8(1,0, 3,2, 5,4, 7,6, 9,8, 11,10, 13,12, 15,14);
    for (size_t i = 0; i < count; i += 8)
    {
        __m128i widths = _mm_loadl_epi64((const __m128i *) (width_arr + i));
        widths = _mm_unpacklo_epi8(widths, widths);
        widths = _mm_subs_epu8(widths, _mm_set1_epi16(0x0008));
        widths = _mm_min_epu8(widths, _mm_set1_epi8(8));
        __m128i values = multigetbits8(in_ptr, &bit_pos, widths);
        values = _mm_shuffle_epi8(values, byteswap_shuf);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
static void decode16b_SSSE3(uint16_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    for (size_t i = 0; i < count; i += 8)
    {
        __m128i widths = _mm_loadl_epi64((const __m128i *) (width_arr + i));
        widths = _mm_unpacklo_epi8(widths, _mm_setzero_si128());
        __m128i values = multigetbits16(in_ptr, &bit_pos, widths);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
static void decode15_SSSE3(uint16_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    for (size_t i = 0; i < count; i += 8)
    {
        __m128i widths = _mm_loadl_epi64((const __m128i *) (width_arr + i));
        widths = _mm_unpacklo_epi8(widths, _mm_setzero_si128());
        __m128i values = multigetbits15(in_ptr, &bit_pos, widths);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
static void decode32_ref(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    static const uint32_t masks[33] =
    {
        0x00000000,
        0x00000001, 0x00000003, 0x00000007, 0x0000000f,
        0x0000001f, 0x0000003f, 0x0000007f, 0x000000ff,
        0x000001ff, 0x000003ff, 0x000007ff, 0x00000fff,
        0x00001fff, 0x00003fff, 0x00007fff, 0x0000ffff,
        0x0001ffff, 0x0003ffff, 0x0007ffff, 0x000fffff,
        0x001fffff, 0x003fffff, 0x007fffff, 0x00ffffff,
        0x01ffffff, 0x03ffffff, 0x07ffffff, 0x0fffffff,
        0x1fffffff, 0x3fffffff, 0x7fffffff, 0xffffffff,
    };
    // we just gleefully over-read; not worrying about that right now.
    uint64_t bitc = 0;
    for (size_t i = 0; i < count; i++)
    {
        // grab value
        uint64_t bits = bswap64(*(const uint64_t *) (in_ptr + (bitc >> 3)));
        uint32_t w = width_arr[i];
        out_ptr[i] = rotl64(bits, w + (bitc & 7)) & masks[w];
        bitc += w;
    }
}
static void decode32_SSSE3(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    __m128i broadcast_shuf = _mm_setr_epi8(0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3);
    __m128i byteswap_shuf = _mm_setr_epi8(3,2,1,0, 7,6,5,4, 11,10,9,8, 15,14,13,12);
    for (size_t i = 0; i < count; i += 4)
    {
        __m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
        widths = _mm_shuffle_epi8(widths, broadcast_shuf);
        widths = _mm_subs_epu8(widths, _mm_set1_epi32(0x00081018));
        widths = _mm_min_epu8(widths, _mm_set1_epi8(8));
        __m128i values = multigetbits8(in_ptr, &bit_pos, widths);
        values = _mm_shuffle_epi8(values, byteswap_shuf);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
static void decode32b_SSSE3(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    __m128i expand_shuf = _mm_setr_epi8(0,-1,-1,-1, 1,-1,-1,-1, 2,-1,-1,-1, 3,-1,-1,-1);
    for (size_t i = 0; i < count; i += 4)
    {
        __m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
        widths = _mm_shuffle_epi8(widths, expand_shuf);
        __m128i values = multigetbits32(in_ptr, &bit_pos, widths);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
static void decode32c_SSSE3(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    __m128i expand_shuf = _mm_setr_epi8(0,-1,-1,-1, 1,-1,-1,-1, 2,-1,-1,-1, 3,-1,-1,-1);
    for (size_t i = 0; i < count; i += 4)
    {
        __m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
        widths = _mm_shuffle_epi8(widths, expand_shuf);
        __m128i values = multigetbits32c(in_ptr, &bit_pos, widths);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
static void decode24_ref(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    static const uint32_t masks[33] =
    {
        0x00000000,
        0x00000001, 0x00000003, 0x00000007, 0x0000000f,
        0x0000001f, 0x0000003f, 0x0000007f, 0x000000ff,
        0x000001ff, 0x000003ff, 0x000007ff, 0x00000fff,
        0x00001fff, 0x00003fff, 0x00007fff, 0x0000ffff,
        0x0001ffff, 0x0003ffff, 0x0007ffff, 0x000fffff,
        0x001fffff, 0x003fffff, 0x007fffff, 0x00ffffff,
        0x01ffffff, 0x03ffffff, 0x07ffffff, 0x0fffffff,
        0x1fffffff, 0x3fffffff, 0x7fffffff, 0xffffffff,
    };
    // we just gleefully over-read; not worrying about that right now.
    uint64_t bitc = 0;
    for (size_t i = 0; i < count; i += 2)
    {
        // grab value
        uint64_t bits = bswap64(*(const uint64_t *) in_ptr);
        uint32_t w;
        w = width_arr[i + 0];
        out_ptr[i + 0] = rotl64(bits, w + bitc) & masks[w];
        bitc += w;
        w = width_arr[i + 1];
        out_ptr[i + 1] = rotl64(bits, w + bitc) & masks[w];
        bitc += w;
        in_ptr += bitc >> 3;
        bitc &= 7;
    }
}
static void decode24_SSE4_v1(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    for (size_t i = 0; i < count; i += 4)
    {
        __m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
        widths = _mm_cvtepu8_epi32(widths);
        __m128i values = multigetbits24a(in_ptr, &bit_pos, widths);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
static void decode24_SSE4_v2(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    for (size_t i = 0; i < count; i += 4)
    {
        uint32_t widths = *(const uint32_t *) (width_arr + i);
        __m128i values = multigetbits24b(in_ptr, &bit_pos, widths);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
static void decode30_SSE4(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    for (size_t i = 0; i < count; i += 4)
    {
        __m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
        widths = _mm_cvtepu8_epi32(widths);
        __m128i values = multigetbits30(in_ptr, &bit_pos, widths);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
#ifdef __RADAVX__
static void decode24_AVX2_v1(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    for (size_t i = 0; i < count; i += 4)
    {
        __m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
        widths = _mm_cvtepu8_epi32(widths);
        __m128i values = multigetbits24a_avx2(in_ptr, &bit_pos, widths);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
static void decode30_AVX2(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    for (size_t i = 0; i < count; i += 4)
    {
        __m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
        widths = _mm_cvtepu8_epi32(widths);
        __m128i values = multigetbits30_avx2(in_ptr, &bit_pos, widths);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
static void decode32_AVX2(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count)
{
    uint32_t bit_pos = 0;
    for (size_t i = 0; i < count; i += 4)
    {
        __m128i widths = _mm_cvtsi32_si128(*(const int *) (width_arr + i));
        widths = _mm_cvtepu8_epi32(widths);
        __m128i values = multigetbits32_avx2(in_ptr, &bit_pos, widths);
        _mm_storeu_si128((__m128i *) (out_ptr + i), values);
    }
}
#endif
// ---- RNG (PCG XSH RR 64/32 MCG)
typedef struct {
    uint64_t state;
} rng32;
rng32 rng32_seed(uint64_t seed)
{
    rng32 r;
    // state cannot be 0 (MCG)
    // also do one multiply step in case the input is a small integer (which it often is)
    r.state = (seed | 1) * 6364136223846793005ULL;
    return r;
}
uint32_t rng32_random(rng32 *r)
{
    // Generate output from old state
    uint64_t oldstate = r->state;
    uint32_t rot_input = (uint32_t) (((oldstate >> 18) ^ oldstate) >> 27);
    uint32_t rot_amount = (uint32_t) (oldstate >> 59);
    uint32_t output = (rot_input >> rot_amount) | (rot_input << ((0u - rot_amount) & 31)); // rotr(rot_input, rot_amount)
    // Advance multiplicative congruential generator
    // Constant from PCG reference impl.
    r->state = oldstate * 6364136223846793005ull;
    return output;
}
// ---- test driver
typedef void kernel8_func(uint8_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count);
typedef void kernel16_func(uint16_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count);
typedef void kernel32_func(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count);
static inline uint32_t cycle_timer()
{
#ifdef _MSC_VER
    return (uint32_t)__rdtsc();
#else
    uint32_t lo, hi;
    asm volatile ("rdtsc" : "=a" (lo), "=d" (hi) );
    return lo;
#endif
}
int g_sink = 0; // to prevent DCE
static int int_compare(const void *a, const void *b)
{
    int ai = *(const int *)a;
    int bi = *(const int *)b;
    return (ai > bi) - (ai < bi);
}
static uint32_t run_test8(uint8_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count, kernel8_func *func)
{
    uint32_t start_tsc = cycle_timer();
    func(out_ptr, width_arr, in_ptr, count);
    uint32_t cycle_count = cycle_timer() - start_tsc;
    return cycle_count;
}
static void testone8(char const *name, kernel8_func *func)
{
    static const size_t kNumValues = 7*16*128; // a bit under 16k, divisible by both 16 and 7
    //static const size_t kNumValues = 7*16;
    static const size_t kNumPadded = kNumValues + 16;
    uint8_t out1_bytes[kNumPadded] = {};
    uint8_t out2_bytes[kNumPadded] = {};
    uint8_t in_bytes[kNumPadded] = {};
    uint8_t widths[kNumPadded] = {};
    // set up some test data
    rng32 rng = rng32_seed(83467);
    for (size_t i = 0; i < kNumValues; i++)
    {
        uint32_t val = rng32_random(&rng);
        in_bytes[i] = val & 0xff;
        widths[i] = (val >> 8) % 9; // values in 0..8
    }
    // verify by comparing against ref
    run_test8(out1_bytes, widths, in_bytes, kNumValues, func);
    run_test8(out2_bytes, widths, in_bytes, kNumValues, decode8_ref);
    if (memcmp(out1_bytes, out2_bytes, kNumValues) != 0)
    {
        size_t i = 0;
        while (out1_bytes[i] == out2_bytes[i])
            ++i;
        printf("%20s: MISMATCH! (at byte %d)\n", name, (int)i);
        return;
    }
    // warm-up
    static const size_t nWarmupRuns = 100;
    for (size_t run = 0; run < nWarmupRuns; ++run)
        g_sink += run_test8(out1_bytes, widths, in_bytes, kNumValues, func);
    // benchmark
    static const size_t nRuns = 10000;
    int *run_lens = new int[nRuns];
    for (size_t run = 0; run < nRuns; ++run)
        run_lens[run] = run_test8(out1_bytes, widths, in_bytes, kNumValues, func);
    qsort(run_lens, nRuns, sizeof(*run_lens), int_compare);
    double ratio = 1.0 / kNumValues;
    printf("%20s: med %.2f/b, 1st%% %.2f/b, 95th%% %.2f/b\n", name, run_lens[nRuns/2]*ratio, run_lens[nRuns/100]*ratio, run_lens[nRuns-1-nRuns/20]*ratio);
    delete[] run_lens;
}
static uint32_t run_test16(uint16_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count, kernel16_func *func)
{
    uint32_t start_tsc = cycle_timer();
    func(out_ptr, width_arr, in_ptr, count);
    uint32_t cycle_count = cycle_timer() - start_tsc;
    return cycle_count;
}
static void testone16(char const *name, kernel16_func *func, int max_width)
{
    static const size_t kNumValues = 3*8*128; // a bit under 16k bytes, divisible by both 8 and 3
    static const size_t kNumPadded = kNumValues + 16;
    uint16_t out1_buf[kNumPadded] = {};
    uint16_t out2_buf[kNumPadded] = {};
    uint8_t in_bytes[kNumPadded*2] = {};
    uint8_t widths[kNumPadded] = {};
    // set up some test data
    rng32 rng = rng32_seed(83467);
    for (size_t i = 0; i < kNumValues; i++)
    {
        uint32_t val = rng32_random(&rng);
        in_bytes[i*2+0] = val & 0xff;
        in_bytes[i*2+1] = (val >> 8) & 0xff;
        widths[i] = (val >> 16) % (max_width + 1);
    }
    // verify by comparing against ref
    run_test16(out1_buf, widths, in_bytes, kNumValues, func);
    run_test16(out2_buf, widths, in_bytes, kNumValues, decode16_ref);
    if (memcmp(out1_buf, out2_buf, kNumValues * sizeof(uint16_t)) != 0)
    {
        size_t i = 0;
        while (out1_buf[i] == out2_buf[i])
            ++i;
        printf("%20s: MISMATCH! (at word %d)\n", name, (int)i);
        return;
    }
    // warm-up
    static const size_t nWarmupRuns = 100;
    for (size_t run = 0; run < nWarmupRuns; ++run)
        g_sink += run_test16(out1_buf, widths, in_bytes, kNumValues, func);
    // benchmark
    static const size_t nRuns = 10000;
    int *run_lens = new int[nRuns];
    for (size_t run = 0; run < nRuns; ++run)
        run_lens[run] = run_test16(out1_buf, widths, in_bytes, kNumValues, func);
    qsort(run_lens, nRuns, sizeof(*run_lens), int_compare);
    double ratio = 1.0 / kNumValues;
    printf("%20s: med %.2f/b, 1st%% %.2f/b, 95th%% %.2f/b\n", name, run_lens[nRuns/2]*ratio, run_lens[nRuns/100]*ratio, run_lens[nRuns-1-nRuns/20]*ratio);
    delete[] run_lens;
}
static uint32_t run_test32(uint32_t *out_ptr, const uint8_t *width_arr, const uint8_t *in_ptr, size_t count, kernel32_func *func)
{
    uint32_t start_tsc = cycle_timer();
    func(out_ptr, width_arr, in_ptr, count);
    uint32_t cycle_count = cycle_timer() - start_tsc;
    return cycle_count;
}
static void testone32(char const *name, kernel32_func *func, int max_width)
{
    static const size_t kNumValues = 3*8*128; // a bit under 16k bytes, divisible by both 8 and 3
    static const size_t kNumPadded = kNumValues + 16;
    uint32_t out1_buf[kNumPadded] = {};
    uint32_t out2_buf[kNumPadded] = {};
    uint8_t in_bytes[kNumPadded*2] = {};
    uint8_t widths[kNumPadded] = {};
    // set up some test data
    rng32 rng = rng32_seed(83467);
    for (size_t i = 0; i < kNumValues; i++)
    {
        uint32_t val = rng32_random(&rng);
        in_bytes[i*2+0] = val & 0xff;
        in_bytes[i*2+1] = (val >> 8) & 0xff;
        widths[i] = (val >> 16) % (max_width + 1);
    }
    // verify by comparing against ref
    run_test32(out1_buf, widths, in_bytes, kNumValues, func);
    run_test32(out2_buf, widths, in_bytes, kNumValues, decode32_ref);
    if (memcmp(out1_buf, out2_buf, kNumValues * sizeof(uint32_t)) != 0)
    {
        size_t i = 0;
        while (out1_buf[i] == out2_buf[i])
            ++i;
        printf("%20s: MISMATCH! (at word %d)\n", name, (int)i);
        return;
    }
    // warm-up
    static const size_t nWarmupRuns = 1000;
    for (size_t run = 0; run < nWarmupRuns; ++run)
        g_sink += run_test32(out1_buf, widths, in_bytes, kNumValues, func);
    // benchmark
    static const size_t nRuns = 30000;
    int *run_lens = new int[nRuns];
    for (size_t run = 0; run < nRuns; ++run)
        run_lens[run] = run_test32(out1_buf, widths, in_bytes, kNumValues, func);
    qsort(run_lens, nRuns, sizeof(*run_lens), int_compare);
    double ratio = 1.0 / kNumValues;
    printf("%20s: med %.2f/b, 1st%% %.2f/b, 95th%% %.2f/b\n", name, run_lens[nRuns/2]*ratio, run_lens[nRuns/100]*ratio, run_lens[nRuns-1-nRuns/20]*ratio);
    delete[] run_lens;
}
int main()
{
#define TESTIT8(what) testone8(#what, what)
#define TESTIT16(what, width) testone16(#what, what, width)
#define TESTIT32(what, width) testone32(#what, what, width)
    TESTIT8(decode8_ref);
    TESTIT8(decode8_SSSE3);
    TESTIT8(decode8b_SSSE3);
    TESTIT16(decode15_SSSE3, 15);
    TESTIT16(decode16_ref, 16);
    TESTIT16(decode16_SSSE3, 16);
    TESTIT16(decode16b_SSSE3, 16);
    TESTIT32(decode24_ref, 24);
    TESTIT32(decode24_SSE4_v1, 24);
    TESTIT32(decode24_SSE4_v2, 24);
    TESTIT32(decode30_SSE4, 30);
    TESTIT32(decode32_ref, 32);
    TESTIT32(decode32_SSSE3, 32);
    TESTIT32(decode32b_SSSE3, 32);
    TESTIT32(decode32c_SSSE3, 32);
#ifdef __RADAVX__
    TESTIT32(decode24_AVX2_v1, 24);
    TESTIT32(decode30_AVX2, 30);
    TESTIT32(decode32_AVX2, 32);
#endif
#undef TESTIT8
#undef TESTIT16
#undef TESTIT32
    return 0;
}