Skip to content

Instantly share code, notes, and snippets.

@vurtun
Last active March 3, 2025 08:53
Show Gist options
  • Save vurtun/47c14b1abd98911b10e5f0a45c254616 to your computer and use it in GitHub Desktop.
Save vurtun/47c14b1abd98911b10e5f0a45c254616 to your computer and use it in GitHub Desktop.
object string filter using sse2, avx2, neon
/* This algorithm uses as a predicate equality of the first and the last characters from the substring. These two characters are populated in two registers
* F and L respectively. Then in each iteration two chunks of strings are loaded. The first chunk (A) is read from offset i (where i is the current offset)
* and the second chunk (B) is read from offset i + k - 1, where k is sub string's length. Then we compute a vector expression F == A and B == L.
* This step yields a byte vector (or a bit mask), where "true" values denote position of potential substring occurrences.
* Finally, exact substring comparisons are performed only at these positions.
*
* Example: Let's assume 8-byte registers. We're searching for word "cat", thus:
*
* F = [ c | c | c | c | c | c | c | c ]
* L = [ t | t | t | t | t | t | t | t ]
*
* We're searching in the string "a_cat_tries". In the first iteration the register A gets data from offset 0, B from offset 2:
*
* A = [ a | _ | c | a | t | _ | t | r ]
* B = [ c | a | t | _ | t | r | i | e ]
*
* Now we compare:
*
* AF = ( A == F ) = [ 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 ]
* BL = ( B == L ) = [ 0 | 0 | 1 | 0 | 1 | 0 | 0 | 0 ]
*
* After merging the comparison results, i.e. AF & BL, we get the following mask:
*
* mask = [ 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 ]
*
* Since the mask is non-zero, it means there are possible substring occurrences.
* As we see, there is only one non-zero element at index 2, thus only one sub string comparison must be performed.
*
* For the actual search of all objects we use a single buffer containing all descriptions of all objects separated by '\0'.
* In addition a separate array of object ids is kept to map from description based on number of `\0` encountered to array index
* and therefore object id.
*/
#include <string.h>    // memcmp, size_t
#include <emmintrin.h> // SSE2 intrinsics
/**
 * Collects the indices of all descriptions that contain the filter string
 * (SSE2 variant, 16 candidate offsets per step).
 *
 * [obj_desc_buf, obj_desc_buf_end) holds the descriptions back to back,
 * separated by '\0'. The index of a description is the number of '\0' bytes
 * in front of it.
 *
 * @param res              Output array. Every matching description index is
 *                         written exactly once, in ascending order, so one
 *                         slot per description is always sufficient.
 * @param obj_desc_buf     First byte of the description buffer. No alignment
 *                         or trailing padding is required.
 * @param obj_desc_buf_end One past the last byte of the description buffer.
 * @param fltr_str         First byte of the filter string.
 * @param fltr_str_end     One past the last byte of the filter string.
 * @return Number of indices written to res; 0 for an empty filter, an empty
 *         buffer, or a filter longer than the buffer.
 */
extern int
fltr_obj(int *res, const char *obj_desc_buf, const char *obj_desc_buf_end,
         const char *fltr_str, const char *fltr_str_end) {
    enum { LANES = 16 }; /* bytes per SSE2 register */

    if (fltr_str >= fltr_str_end || obj_desc_buf >= obj_desc_buf_end) {
        return 0;
    }
    const size_t fltr_len = (size_t)(fltr_str_end - fltr_str);
    const size_t buf_len = (size_t)(obj_desc_buf_end - obj_desc_buf);
    if (fltr_len > buf_len) {
        return 0; /* the filter cannot fit anywhere */
    }
    /* Number of offsets at which a complete occurrence can still start. */
    const size_t cand_cnt = buf_len - fltr_len + 1;
    const char first_chr = fltr_str[0];
    const char last_chr = fltr_str_end[-1];

    const __m128i zero_vec = _mm_setzero_si128();
    const __m128i first_vec = _mm_set1_epi8(first_chr);
    const __m128i last_vec = _mm_set1_epi8(last_chr);

    int match_cnt = 0;
    int str_cnt = 0;   /* '\0' bytes seen in front of the current block */
    int last_hit = -1; /* most recently reported index, -1 if none yet */
    size_t off = 0;

    /* Vector part: only runs while both 16-byte loads stay inside the buffer
     * (off + 15 is a valid start offset, hence off + 15 + fltr_len - 1 is
     * still a valid byte). */
    for (; off + LANES <= cand_cnt; off += LANES) {
        const char *blk = obj_desc_buf + off;
        const __m128i first_block = _mm_loadu_si128((const __m128i *)blk);
        const __m128i last_block =
            _mm_loadu_si128((const __m128i *)(blk + fltr_len - 1));
        /* Lane i is a candidate when byte i equals the filter's first
         * character and byte i + fltr_len - 1 equals its last one. */
        const __m128i eq_both =
            _mm_and_si128(_mm_cmpeq_epi8(first_vec, first_block),
                          _mm_cmpeq_epi8(last_vec, last_block));
        unsigned match_msk = (unsigned)_mm_movemask_epi8(eq_both);
        const unsigned zero_msk =
            (unsigned)_mm_movemask_epi8(_mm_cmpeq_epi8(first_block, zero_vec));

        while (match_msk) {
            const unsigned bit_pos = (unsigned)__builtin_ctz(match_msk);
            /* Separators below bit_pos belong to earlier descriptions. */
            const int cur_idx = str_cnt +
                __builtin_popcount(zero_msk & ((1u << bit_pos) - 1u));
            match_msk &= match_msk - 1; /* clear the lowest set bit */
            /* Indices only grow, so comparing against the last reported one
             * is enough to report every description a single time. */
            if (cur_idx != last_hit &&
                memcmp(blk + bit_pos, fltr_str, fltr_len) == 0) {
                res[match_cnt++] = cur_idx;
                last_hit = cur_idx;
            }
        }
        str_cnt += __builtin_popcount(zero_msk);
    }

    /* Scalar tail: the remaining (fewer than 16) start offsets. */
    for (; off < cand_cnt; ++off) {
        const char *cand = obj_desc_buf + off;
        if (cand[0] == first_chr && cand[fltr_len - 1] == last_chr &&
            str_cnt != last_hit && memcmp(cand, fltr_str, fltr_len) == 0) {
            res[match_cnt++] = str_cnt;
            last_hit = str_cnt;
        }
        if (cand[0] == '\0') {
            str_cnt++;
        }
    }
    return match_cnt;
}
#include <immintrin.h>
/**
 * Collects the indices of all descriptions that contain the filter string
 * (AVX2 variant, 32 candidate offsets per step).
 *
 * [obj_desc_buf, obj_desc_buf_end) holds the descriptions back to back,
 * separated by '\0'. The index of a description is the number of '\0' bytes
 * in front of it.
 *
 * With GCC/Clang the function is compiled for AVX2 even when the translation
 * unit is not built with -mavx2; the caller is responsible for only invoking
 * it on CPUs that support AVX2.
 *
 * @param res              Output array. Every matching description index is
 *                         written exactly once, in ascending order, so one
 *                         slot per description is always sufficient.
 * @param obj_desc_buf     First byte of the description buffer. No alignment
 *                         or trailing padding is required.
 * @param obj_desc_buf_end One past the last byte of the description buffer.
 * @param fltr_str         First byte of the filter string.
 * @param fltr_str_end     One past the last byte of the filter string.
 * @return Number of indices written to res; 0 for an empty filter, an empty
 *         buffer, or a filter longer than the buffer.
 */
#if (defined(__GNUC__) || defined(__clang__)) && !defined(__AVX2__)
__attribute__((target("avx2")))
#endif
extern int
fltr_obj(int *res, const char *obj_desc_buf, const char *obj_desc_buf_end,
         const char *fltr_str, const char *fltr_str_end) {
    enum { LANES = 32 }; /* bytes per AVX2 register */

    if (fltr_str >= fltr_str_end || obj_desc_buf >= obj_desc_buf_end) {
        return 0;
    }
    const size_t fltr_len = (size_t)(fltr_str_end - fltr_str);
    const size_t buf_len = (size_t)(obj_desc_buf_end - obj_desc_buf);
    if (fltr_len > buf_len) {
        return 0; /* the filter cannot fit anywhere */
    }
    /* Number of offsets at which a complete occurrence can still start. */
    const size_t cand_cnt = buf_len - fltr_len + 1;
    const char first_chr = fltr_str[0];
    const char last_chr = fltr_str_end[-1];

    const __m256i zero_vec = _mm256_setzero_si256();
    const __m256i first_vec = _mm256_set1_epi8(first_chr);
    const __m256i last_vec = _mm256_set1_epi8(last_chr);

    int match_cnt = 0;
    int str_cnt = 0;   /* '\0' bytes seen in front of the current block */
    int last_hit = -1; /* most recently reported index, -1 if none yet */
    size_t off = 0;

    /* Vector part: only runs while both 32-byte loads stay inside the buffer
     * (off + 31 is a valid start offset, hence off + 31 + fltr_len - 1 is
     * still a valid byte). */
    for (; off + LANES <= cand_cnt; off += LANES) {
        const char *blk = obj_desc_buf + off;
        const __m256i first_block = _mm256_loadu_si256((const __m256i *)blk);
        const __m256i last_block =
            _mm256_loadu_si256((const __m256i *)(blk + fltr_len - 1));
        /* Lane i is a candidate when byte i equals the filter's first
         * character and byte i + fltr_len - 1 equals its last one. */
        const __m256i eq_both =
            _mm256_and_si256(_mm256_cmpeq_epi8(first_vec, first_block),
                             _mm256_cmpeq_epi8(last_vec, last_block));
        unsigned match_msk = (unsigned)_mm256_movemask_epi8(eq_both);
        const unsigned zero_msk = (unsigned)_mm256_movemask_epi8(
            _mm256_cmpeq_epi8(first_block, zero_vec));

        while (match_msk) {
            const unsigned bit_pos = (unsigned)__builtin_ctz(match_msk);
            /* Separators below bit_pos belong to earlier descriptions;
             * bit_pos is at most 31, so the shift is well defined. */
            const int cur_idx = str_cnt +
                __builtin_popcount(zero_msk & ((1u << bit_pos) - 1u));
            match_msk &= match_msk - 1; /* clear the lowest set bit */
            /* Indices only grow, so comparing against the last reported one
             * is enough to report every description a single time. */
            if (cur_idx != last_hit &&
                memcmp(blk + bit_pos, fltr_str, fltr_len) == 0) {
                res[match_cnt++] = cur_idx;
                last_hit = cur_idx;
            }
        }
        str_cnt += __builtin_popcount(zero_msk);
    }

    /* Scalar tail: the remaining (fewer than 32) start offsets. */
    for (; off < cand_cnt; ++off) {
        const char *cand = obj_desc_buf + off;
        if (cand[0] == first_chr && cand[fltr_len - 1] == last_chr &&
            str_cnt != last_hit && memcmp(cand, fltr_str, fltr_len) == 0) {
            res[match_cnt++] = str_cnt;
            last_hit = str_cnt;
        }
        if (cand[0] == '\0') {
            str_cnt++;
        }
    }
    return match_cnt;
}
#include <arm_neon.h>
/* Collapses a byte-wise comparison result (every lane 0x00 or 0xFF) into a
 * 16-bit mask in which bit i is set when lane i is 0xFF. */
static inline uint32_t
fltr_neon_movemask(uint8x16_t cmp) {
    static const uint8_t lane_bit[16] = {
        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80
    };
    /* Keep one distinct bit per lane; each 8-lane half then sums to at most
     * 0xFF, so the horizontal add cannot wrap. */
    const uint8x16_t bits = vandq_u8(cmp, vld1q_u8(lane_bit));
    return (uint32_t)vaddv_u8(vget_low_u8(bits)) |
           ((uint32_t)vaddv_u8(vget_high_u8(bits)) << 8);
}

/**
 * Collects the indices of all descriptions that contain the filter string
 * (AArch64 NEON variant, 32 candidate offsets per step).
 *
 * [obj_desc_buf, obj_desc_buf_end) holds the descriptions back to back,
 * separated by '\0'. The index of a description is the number of '\0' bytes
 * in front of it.
 *
 * @param res              Output array. Every matching description index is
 *                         written exactly once, in ascending order, so one
 *                         slot per description is always sufficient.
 * @param obj_desc_buf     First byte of the description buffer. No alignment
 *                         or trailing padding is required.
 * @param obj_desc_buf_end One past the last byte of the description buffer.
 * @param fltr_str         First byte of the filter string.
 * @param fltr_str_end     One past the last byte of the filter string.
 * @return Number of indices written to res; 0 for an empty filter, an empty
 *         buffer, or a filter longer than the buffer.
 */
extern int
fltr_obj(int *res, const char *obj_desc_buf, const char *obj_desc_buf_end,
         const char *fltr_str, const char *fltr_str_end) {
    enum { LANES = 16, STEP = 2 * LANES }; /* two 128-bit vectors per step */

    if (fltr_str >= fltr_str_end || obj_desc_buf >= obj_desc_buf_end) {
        return 0;
    }
    const size_t fltr_len = (size_t)(fltr_str_end - fltr_str);
    const size_t buf_len = (size_t)(obj_desc_buf_end - obj_desc_buf);
    if (fltr_len > buf_len) {
        return 0; /* the filter cannot fit anywhere */
    }
    /* Number of offsets at which a complete occurrence can still start. */
    const size_t cand_cnt = buf_len - fltr_len + 1;
    const char first_chr = fltr_str[0];
    const char last_chr = fltr_str_end[-1];

    const uint8x16_t zero_vec = vdupq_n_u8(0);
    const uint8x16_t first_vec = vdupq_n_u8((uint8_t)first_chr);
    const uint8x16_t last_vec = vdupq_n_u8((uint8_t)last_chr);

    int match_cnt = 0;
    int str_cnt = 0;   /* '\0' bytes seen in front of the current block */
    int last_hit = -1; /* most recently reported index, -1 if none yet */
    size_t off = 0;

    /* Vector part: only runs while all four 16-byte loads stay inside the
     * buffer (off + 31 is a valid start offset, hence off + 31 + fltr_len - 1
     * is still a valid byte). */
    for (; off + STEP <= cand_cnt; off += STEP) {
        const char *blk = obj_desc_buf + off;
        const uint8_t *head = (const uint8_t *)blk;
        const uint8_t *tail = head + fltr_len - 1;
        const uint8x16_t first_lo = vld1q_u8(head);
        const uint8x16_t first_hi = vld1q_u8(head + LANES);
        const uint8x16_t last_lo = vld1q_u8(tail);
        const uint8x16_t last_hi = vld1q_u8(tail + LANES);

        /* Lane i is a candidate when byte i equals the filter's first
         * character and byte i + fltr_len - 1 equals its last one. */
        const uint8x16_t cand_lo = vandq_u8(vceqq_u8(first_vec, first_lo),
                                            vceqq_u8(last_vec, last_lo));
        const uint8x16_t cand_hi = vandq_u8(vceqq_u8(first_vec, first_hi),
                                            vceqq_u8(last_vec, last_hi));
        uint32_t match_msk = fltr_neon_movemask(cand_lo) |
                             (fltr_neon_movemask(cand_hi) << 16);
        const uint32_t zero_msk =
            fltr_neon_movemask(vceqq_u8(first_lo, zero_vec)) |
            (fltr_neon_movemask(vceqq_u8(first_hi, zero_vec)) << 16);

        while (match_msk) {
            const unsigned bit_pos = (unsigned)__builtin_ctz(match_msk);
            /* Separators below bit_pos belong to earlier descriptions;
             * bit_pos is at most 31, so the shift is well defined. */
            const int cur_idx = str_cnt +
                __builtin_popcount(zero_msk & ((1u << bit_pos) - 1u));
            match_msk &= match_msk - 1; /* clear the lowest set bit */
            /* Indices only grow, so comparing against the last reported one
             * is enough to report every description a single time. */
            if (cur_idx != last_hit &&
                memcmp(blk + bit_pos, fltr_str, fltr_len) == 0) {
                res[match_cnt++] = cur_idx;
                last_hit = cur_idx;
            }
        }
        str_cnt += __builtin_popcount(zero_msk);
    }

    /* Scalar tail: the remaining (fewer than 32) start offsets. */
    for (; off < cand_cnt; ++off) {
        const char *cand = obj_desc_buf + off;
        if (cand[0] == first_chr && cand[fltr_len - 1] == last_chr &&
            str_cnt != last_hit && memcmp(cand, fltr_str, fltr_len) == 0) {
            res[match_cnt++] = str_cnt;
            last_hit = str_cnt;
        }
        if (cand[0] == '\0') {
            str_cnt++;
        }
    }
    return match_cnt;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment