Last active
March 3, 2025 08:53
-
-
Save vurtun/47c14b1abd98911b10e5f0a45c254616 to your computer and use it in GitHub Desktop.
object string filter using sse2, avx2, neon
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* This algorithm uses equality of the first and the last character of the substring as its predicate. These two characters are broadcast into two registers,
* F and L respectively. In each iteration two chunks of the string are loaded: the first chunk (A) is read from offset i (where i is the current offset)
* and the second chunk (B) is read from offset i + k - 1, where k is the substring's length. Then we compute the vector expression (F == A) & (B == L).
* This step yields a byte vector (or a bit mask) in which "true" values denote the positions of potential substring occurrences.
* Finally, an exact substring comparison is performed only at these positions.
*
* Example: Let's assume 8-byte registers. We're searching for the word "cat", thus:
*
* F = [ c | c | c | c | c | c | c | c ]
* L = [ t | t | t | t | t | t | t | t ]
*
* We're searching in the string "a_cat_tries". In the first iteration register A gets its data from offset 0 and register B from offset 2:
*
* A = [ a | _ | c | a | t | _ | t | r ]
* B = [ c | a | t | _ | t | r | i | e ]
*
* Now we compare:
*
* AF = ( A == F ) = [ 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 ]
* BL = ( B == L ) = [ 0 | 0 | 1 | 0 | 1 | 0 | 0 | 0 ]
*
* After merging the comparison results, i.e. AF & BL, we get the following mask:
*
* mask = [ 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 ]
*
* Since the mask is non-zero, there are possible substring occurrences.
* As we can see, there is only one non-zero element, at index 2, so only one substring comparison has to be performed.
*
* For the actual search over all objects we use a single buffer containing the descriptions of all objects, separated by '\0'.
* In addition, a separate array of object ids is kept: the number of '\0' bytes encountered before a description gives its array index,
* which in turn maps to the object id.
*/
#include <emmintrin.h> // SSE2 intrinsics | |
/*
 * fltr_obj (SSE2) - collect the index of every description that contains the
 * filter string.
 *
 * The bytes in [obj_desc_buf, obj_desc_buf_end) are treated as a sequence of
 * descriptions, each closed by a '\0' byte. The index reported for a hit is
 * the number of '\0' bytes located in front of the hit position.
 *
 * res               output array receiving the indices in ascending order.
 *                   Every index is stored at most once, so room for one int
 *                   per description is sufficient.
 * obj_desc_buf      first byte of the description buffer (any alignment).
 * obj_desc_buf_end  one past the last byte of the buffer; nothing at or
 *                   beyond this address is read except by the prefetch hint.
 * fltr_str          first byte of the filter string.
 * fltr_str_end      one past the last byte of the filter string.
 *
 * Returns the number of indices written to res; 0 for an empty filter.
 */
extern int
fltr_obj(int *res, const char *obj_desc_buf, const char *obj_desc_buf_end,
         const char *fltr_str, const char *fltr_str_end) {
  enum { LANES = 16 }; /* bytes handled per vector iteration */

  if (fltr_str >= fltr_str_end) {
    return 0;
  }
  const int fltr_len = (int)(fltr_str_end - fltr_str);
  const char *buf_ptr = obj_desc_buf;
  int match_cnt = 0;
  int str_cnt = 0;   /* number of '\0' bytes in front of buf_ptr */
  int last_idx = -1; /* last index stored in res, used to drop repeats */

  const __m128i zero_vec = _mm_setzero_si128();
  const __m128i first_vec = _mm_set1_epi8(fltr_str[0]);
  const __m128i last_vec = _mm_set1_epi8(fltr_str_end[-1]);

  /* Vector loop. The second load ends at buf_ptr + fltr_len - 1 + LANES, so
   * the loop only runs while that address is still inside the buffer. The
   * same bound guarantees that every candidate found here fits completely. */
  while ((obj_desc_buf_end - buf_ptr) - (LANES - 1) >= fltr_len) {
    _mm_prefetch(buf_ptr + 256, _MM_HINT_T0);
    /* Unaligned loads: the buffer is not required to be 16-byte aligned. */
    const __m128i first_block = _mm_loadu_si128((const __m128i *)buf_ptr);
    const __m128i last_block =
        _mm_loadu_si128((const __m128i *)(buf_ptr + fltr_len - 1));
    /* Lane i is a candidate when byte i equals the first filter character
     * and byte i + fltr_len - 1 equals the last one. */
    const __m128i cand_vec =
        _mm_and_si128(_mm_cmpeq_epi8(first_vec, first_block),
                      _mm_cmpeq_epi8(last_vec, last_block));
    unsigned match_msk = (unsigned)_mm_movemask_epi8(cand_vec);
    const unsigned zero_msk =
        (unsigned)_mm_movemask_epi8(_mm_cmpeq_epi8(first_block, zero_vec));

    while (match_msk) {
      const unsigned bit_pos = (unsigned)__builtin_ctz(match_msk);
      /* Byte-wise compare: no type-punned or misaligned wide loads. */
      if (__builtin_memcmp(buf_ptr + bit_pos, fltr_str, (unsigned)fltr_len) == 0) {
        /* Terminators below bit_pos belong to earlier descriptions. */
        const int cur_idx =
            str_cnt + __builtin_popcount(zero_msk & ((1u << bit_pos) - 1u));
        if (cur_idx != last_idx) {
          res[match_cnt++] = cur_idx;
          last_idx = cur_idx;
        }
      }
      match_msk &= match_msk - 1u; /* clear lowest set bit */
    }
    str_cnt += __builtin_popcount(zero_msk);
    buf_ptr += LANES;
  }

  /* Scalar tail for the bytes the vector loop could not load safely. */
  for (; buf_ptr < obj_desc_buf_end; ++buf_ptr) {
    if (*buf_ptr == fltr_str[0] &&
        obj_desc_buf_end - buf_ptr >= fltr_len &&
        str_cnt != last_idx &&
        __builtin_memcmp(buf_ptr, fltr_str, (unsigned)fltr_len) == 0) {
      res[match_cnt++] = str_cnt;
      last_idx = str_cnt;
    }
    if (*buf_ptr == '\0') {
      ++str_cnt;
    }
  }
  return match_cnt;
}
#include <immintrin.h> | |
/*
 * fltr_obj (AVX2) - collect the index of every description that contains the
 * filter string.
 *
 * The bytes in [obj_desc_buf, obj_desc_buf_end) are treated as a sequence of
 * descriptions, each closed by a '\0' byte. The index reported for a hit is
 * the number of '\0' bytes located in front of the hit position.
 *
 * res               output array receiving the indices in ascending order.
 *                   Every index is stored at most once, so room for one int
 *                   per description is sufficient.
 * obj_desc_buf      first byte of the description buffer (any alignment).
 * obj_desc_buf_end  one past the last byte of the buffer; nothing at or
 *                   beyond this address is read except by the prefetch hint.
 * fltr_str          first byte of the filter string.
 * fltr_str_end      one past the last byte of the filter string.
 *
 * Returns the number of indices written to res; 0 for an empty filter.
 *
 * The target attribute lets this function be compiled without a global
 * -mavx2; the caller is responsible for invoking it only on CPUs with AVX2.
 */
__attribute__((target("avx2")))
extern int
fltr_obj(int *res, const char *obj_desc_buf, const char *obj_desc_buf_end,
         const char *fltr_str, const char *fltr_str_end) {
  enum { LANES = 32 }; /* bytes handled per vector iteration */

  if (fltr_str >= fltr_str_end) {
    return 0;
  }
  const int fltr_len = (int)(fltr_str_end - fltr_str);
  const char *buf_ptr = obj_desc_buf;
  int match_cnt = 0;
  int str_cnt = 0;   /* number of '\0' bytes in front of buf_ptr */
  int last_idx = -1; /* last index stored in res, used to drop repeats */

  const __m256i zero_vec = _mm256_setzero_si256();
  const __m256i first_vec = _mm256_set1_epi8(fltr_str[0]);
  const __m256i last_vec = _mm256_set1_epi8(fltr_str_end[-1]);

  /* Vector loop. The second load ends at buf_ptr + fltr_len - 1 + LANES, so
   * the loop only runs while that address is still inside the buffer. The
   * same bound guarantees that every candidate found here fits completely. */
  while ((obj_desc_buf_end - buf_ptr) - (LANES - 1) >= fltr_len) {
    _mm_prefetch(buf_ptr + 512, _MM_HINT_T0);
    /* Unaligned loads: the buffer is not required to be 32-byte aligned. */
    const __m256i first_block = _mm256_loadu_si256((const __m256i *)buf_ptr);
    const __m256i last_block =
        _mm256_loadu_si256((const __m256i *)(buf_ptr + fltr_len - 1));
    /* Lane i is a candidate when byte i equals the first filter character
     * and byte i + fltr_len - 1 equals the last one. */
    const __m256i cand_vec =
        _mm256_and_si256(_mm256_cmpeq_epi8(first_vec, first_block),
                         _mm256_cmpeq_epi8(last_vec, last_block));
    unsigned match_msk = (unsigned)_mm256_movemask_epi8(cand_vec);
    const unsigned zero_msk =
        (unsigned)_mm256_movemask_epi8(_mm256_cmpeq_epi8(first_block, zero_vec));

    while (match_msk) {
      const unsigned bit_pos = (unsigned)__builtin_ctz(match_msk);
      /* Byte-wise compare: no type-punned or misaligned wide loads. */
      if (__builtin_memcmp(buf_ptr + bit_pos, fltr_str, (unsigned)fltr_len) == 0) {
        /* Terminators below bit_pos belong to earlier descriptions;
         * bit_pos is at most 31, so the shift stays in range. */
        const int cur_idx =
            str_cnt + __builtin_popcount(zero_msk & ((1u << bit_pos) - 1u));
        if (cur_idx != last_idx) {
          res[match_cnt++] = cur_idx;
          last_idx = cur_idx;
        }
      }
      match_msk &= match_msk - 1u; /* clear lowest set bit */
    }
    str_cnt += __builtin_popcount(zero_msk);
    buf_ptr += LANES;
  }

  /* Scalar tail for the bytes the vector loop could not load safely. */
  for (; buf_ptr < obj_desc_buf_end; ++buf_ptr) {
    if (*buf_ptr == fltr_str[0] &&
        obj_desc_buf_end - buf_ptr >= fltr_len &&
        str_cnt != last_idx &&
        __builtin_memcmp(buf_ptr, fltr_str, (unsigned)fltr_len) == 0) {
      res[match_cnt++] = str_cnt;
      last_idx = str_cnt;
    }
    if (*buf_ptr == '\0') {
      ++str_cnt;
    }
  }
  return match_cnt;
}
#include <arm_neon.h> | |
/*
 * Turn a NEON byte-compare result (every lane 0x00 or 0xFF) into a 16-bit
 * mask with bit i set when lane i is 0xFF.
 *
 * A plain vaddvq_u8 cannot be used for this: it returns the 8-bit wrapped sum
 * of the lanes, which carries no per-lane information. Here each lane is first
 * reduced to a distinct power of two, so the sum of each 8-lane half is an
 * exact 8-bit mask that cannot overflow.
 */
static inline uint32_t
fltr_neon_movemask(uint8x16_t cmp) {
  static const uint8_t lane_bit[16] = {1, 2, 4, 8, 16, 32, 64, 128,
                                       1, 2, 4, 8, 16, 32, 64, 128};
  const uint8x16_t bits = vandq_u8(cmp, vld1q_u8(lane_bit));
  return (uint32_t)vaddv_u8(vget_low_u8(bits)) |
         ((uint32_t)vaddv_u8(vget_high_u8(bits)) << 8);
}

/*
 * fltr_obj (NEON) - collect the index of every description that contains the
 * filter string.
 *
 * The bytes in [obj_desc_buf, obj_desc_buf_end) are treated as a sequence of
 * descriptions, each closed by a '\0' byte. The index reported for a hit is
 * the number of '\0' bytes located in front of the hit position.
 *
 * res               output array receiving the indices in ascending order.
 *                   Every index is stored at most once, so room for one int
 *                   per description is sufficient.
 * obj_desc_buf      first byte of the description buffer (any alignment).
 * obj_desc_buf_end  one past the last byte of the buffer; nothing at or
 *                   beyond this address is read except by the prefetch hint.
 * fltr_str          first byte of the filter string.
 * fltr_str_end      one past the last byte of the filter string.
 *
 * Returns the number of indices written to res; 0 for an empty filter.
 */
extern int
fltr_obj(int *res, const char *obj_desc_buf, const char *obj_desc_buf_end,
         const char *fltr_str, const char *fltr_str_end) {
  enum { LANES = 32 }; /* two 16-byte vectors per iteration */

  if (fltr_str >= fltr_str_end) {
    return 0;
  }
  const int fltr_len = (int)(fltr_str_end - fltr_str);
  const char *buf_ptr = obj_desc_buf;
  int match_cnt = 0;
  int str_cnt = 0;   /* number of '\0' bytes in front of buf_ptr */
  int last_idx = -1; /* last index stored in res, used to drop repeats */

  const uint8x16_t zero_vec = vdupq_n_u8(0);
  const uint8x16_t first_vec = vdupq_n_u8((uint8_t)fltr_str[0]);
  const uint8x16_t last_vec = vdupq_n_u8((uint8_t)fltr_str_end[-1]);

  /* Vector loop. The last load ends at buf_ptr + fltr_len - 1 + LANES, so the
   * loop only runs while that address is still inside the buffer. The same
   * bound guarantees that every candidate found here fits completely. */
  while ((obj_desc_buf_end - buf_ptr) - (LANES - 1) >= fltr_len) {
    __builtin_prefetch(buf_ptr + 256, 0, 3);
    const uint8_t *head = (const uint8_t *)buf_ptr;
    const uint8_t *tail = head + fltr_len - 1;
    const uint8x16_t first_lo = vld1q_u8(head);
    const uint8x16_t first_hi = vld1q_u8(head + 16);
    const uint8x16_t last_lo = vld1q_u8(tail);
    const uint8x16_t last_hi = vld1q_u8(tail + 16);

    /* Lane i is a candidate when byte i equals the first filter character
     * and byte i + fltr_len - 1 equals the last one. */
    const uint8x16_t cand_lo =
        vandq_u8(vceqq_u8(first_vec, first_lo), vceqq_u8(last_vec, last_lo));
    const uint8x16_t cand_hi =
        vandq_u8(vceqq_u8(first_vec, first_hi), vceqq_u8(last_vec, last_hi));
    uint32_t match_msk =
        fltr_neon_movemask(cand_lo) | (fltr_neon_movemask(cand_hi) << 16);
    const uint32_t zero_msk =
        fltr_neon_movemask(vceqq_u8(first_lo, zero_vec)) |
        (fltr_neon_movemask(vceqq_u8(first_hi, zero_vec)) << 16);

    while (match_msk) {
      const unsigned bit_pos = (unsigned)__builtin_ctz(match_msk);
      /* Byte-wise compare: no type-punned or misaligned wide loads. */
      if (__builtin_memcmp(buf_ptr + bit_pos, fltr_str, (unsigned)fltr_len) == 0) {
        /* Terminators below bit_pos belong to earlier descriptions;
         * bit_pos is at most 31, so the shift stays in range. */
        const int cur_idx =
            str_cnt + __builtin_popcount(zero_msk & ((1u << bit_pos) - 1u));
        if (cur_idx != last_idx) {
          res[match_cnt++] = cur_idx;
          last_idx = cur_idx;
        }
      }
      match_msk &= match_msk - 1u; /* clear lowest set bit */
    }
    str_cnt += __builtin_popcount(zero_msk);
    buf_ptr += LANES;
  }

  /* Scalar tail for the bytes the vector loop could not load safely. */
  for (; buf_ptr < obj_desc_buf_end; ++buf_ptr) {
    if (*buf_ptr == fltr_str[0] &&
        obj_desc_buf_end - buf_ptr >= fltr_len &&
        str_cnt != last_idx &&
        __builtin_memcmp(buf_ptr, fltr_str, (unsigned)fltr_len) == 0) {
      res[match_cnt++] = str_cnt;
      last_idx = str_cnt;
    }
    if (*buf_ptr == '\0') {
      ++str_cnt;
    }
  }
  return match_cnt;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment