Skip to content

Instantly share code, notes, and snippets.

@zingaburga
Last active July 8, 2024 11:56
Show Gist options
  • Save zingaburga/e1a2650703f0c85cf9cf21cd6b028d82 to your computer and use it in GitHub Desktop.
Despace AVX512: aligned vs unaligned
#include <immintrin.h>
#include <stdio.h>
#include <sys/time.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
// assume initial pointers are aligned and len is a multiple of 128
// compiled with: cc -march=sapphirerapids -O3 despace.c
// Reference implementation: copy `in` to `out`, dropping every space (0x20)
// byte. Returns the number of bytes written. Serves as the correctness
// baseline for the vectorized variants.
size_t scalar(char *restrict out, const char *restrict in, size_t len) {
    size_t written = 0;
    for (size_t i = 0; i < len; i++) {
        char c = in[i];
        if (c != ' ')
            out[written++] = c;
    }
    return written;
}
// Despace 128 bytes per iteration using AVX-512 VBMI2 compress, writing the
// compacted bytes at whatever (unaligned) offset the output has reached.
// Assumes `in` is 64-byte aligned and `len` is a multiple of 128 (see file
// header). Returns the number of non-space bytes written to `out`.
size_t unaligned(char *restrict out, const char *restrict in, size_t len) {
const __m512i space = _mm512_set1_epi8(32); // ' ' broadcast to all 64 byte lanes
// 2 cycle unroll
size_t outlen = 0;
for(size_t i=0; i<len; i+=128) {
__m512i d1 = _mm512_load_si512(in + i);
__m512i d2 = _mm512_load_si512(in + i + 64);
// mask bit set for every byte to keep (byte != ' ')
__mmask64 m1 = _mm512_cmpneq_epi8_mask(d1, space);
__mmask64 m2 = _mm512_cmpneq_epi8_mask(d2, space);
#if defined(__tune_znver4__)
// Zen 4 path: compress in-register, then do a plain full 64-byte store that
// overlaps the next chunk's data (later overwritten).
// NOTE(review): presumably this guard exists because compress-to-memory is
// slow on Zen 4 — confirm against AMD's optimization guidance.
d1 = _mm512_maskz_compress_epi8(m1, d1);
d2 = _mm512_maskz_compress_epi8(m2, d2);
_mm512_storeu_si512(out + outlen, d1);
outlen += _mm_popcnt_u64(_cvtmask64_u64(m1)); // popcount of mask = bytes kept
_mm512_storeu_si512(out + outlen, d2);
outlen += _mm_popcnt_u64(_cvtmask64_u64(m2));
#else
// Other targets: compress directly to memory; only the kept bytes are stored.
_mm512_mask_compressstoreu_epi8(out + outlen, m1, d1);
outlen += _mm_popcnt_u64(_cvtmask64_u64(m1));
_mm512_mask_compressstoreu_epi8(out + outlen, m2, d2);
outlen += _mm_popcnt_u64(_cvtmask64_u64(m2));
#endif
}
return outlen;
}
// Despace into *aligned* output: compressed bytes are accumulated in
// `aligned_vec` and flushed to memory only as full, 64-byte-aligned stores.
// Assumes in/out are 64-byte aligned and len is a multiple of 128.
// Returns the number of non-space bytes; the final partial-chunk store may
// touch up to 64 bytes past the returned length ("assume no overflow" below).
size_t aligned(char *restrict out, const char *restrict in, size_t len) {
const __m512i space = _mm512_set1_epi8(32);
// identity byte permutation: byte j of permidx holds the value j
const __m512i permidx = _mm512_set_epi32(
0x3f3e3d3c, 0x3b3a3938, 0x37363534, 0x33323130,
0x2f2e2d2c, 0x2b2a2928, 0x27262524, 0x23222120,
0x1f1e1d1c, 0x1b1a1918, 0x17161514, 0x13121110,
0x0f0e0d0c, 0x0b0a0908, 0x07060504, 0x03020100
);
// 2 cycle unroll
size_t outlen = 0;
__m512i aligned_vec; // pending bytes of the current 64B output chunk (only lanes < outlen&63 are meaningful; first-iteration read is of an uninitialized value but masked out since len0==0)
for(size_t i=0; i<len; i+=128) {
__m512i d1 = _mm512_load_si512(in + i);
__m512i d2 = _mm512_load_si512(in + i + 64);
__mmask64 m1 = _mm512_cmpneq_epi8_mask(d1, space);
__mmask64 m2 = _mm512_cmpneq_epi8_mask(d2, space);
// compress the non-space bytes down to the bottom of each vector
d1 = _mm512_maskz_compress_epi8(m1, d1);
d2 = _mm512_maskz_compress_epi8(m2, d2);
size_t len1 = _mm_popcnt_u64(_cvtmask64_u64(m1)); // kept bytes in d1
size_t len2 = _mm_popcnt_u64(_cvtmask64_u64(m2)); // kept bytes in d2
// merge
size_t len0 = outlen & 63; // bytes already pending in aligned_vec
// lane j of pos1 = j - len0: sign bit set exactly for lanes j < len0 (still
// owned by aligned_vec); vpermb only uses the low 6 bits of each index, so
// this simultaneously rotates d1 up by len0 positions (byte 0 of d1 lands
// in lane len0, wrapping mod 64)
__m512i pos1 = _mm512_sub_epi8(permidx, _mm512_set1_epi8(len0));
__m512i shifted1 = _mm512_permutexvar_epi8(pos1, d1);
__mmask64 merge1 = _mm512_movepi8_mask(pos1); // sign bits -> lanes to keep from aligned_vec
aligned_vec = _mm512_mask_blend_epi8(merge1, shifted1, aligned_vec); // old bytes below len0, rotated new bytes above
// if the vector is now full, write it out
if(len0 + len1 >= 64) {
_mm512_store_si512(out + (outlen & ~63), aligned_vec);
aligned_vec = shifted1; // wrapped-around bytes of d1 start the next chunk
outlen += len1;
len1 = (len0 + len1) & 63; // pending byte count after the flush
} else {
outlen += len1;
len1 += len0; // no flush: pending count simply grows
}
// second half: identical merge step, using len1 as the new pending count
__m512i pos2 = _mm512_sub_epi8(permidx, _mm512_set1_epi8(len1));
__m512i shifted2 = _mm512_permutexvar_epi8(pos2, d2);
__mmask64 merge2 = _mm512_movepi8_mask(pos2);
aligned_vec = _mm512_mask_blend_epi8(merge2, shifted2, aligned_vec);
if(len1 + len2 >= 64) {
_mm512_store_si512(out + (outlen & ~63), aligned_vec);
aligned_vec = shifted2;
}
outlen += len2;
}
// flush the final partial chunk
_mm512_store_si512(out + (outlen & ~63), aligned_vec); // assume no overflow
return outlen;
}
// Same accumulate-and-rotate scheme as aligned(), but with the data-dependent
// branches removed: the current chunk is stored unconditionally every step
// (a partial chunk is harmlessly re-stored later at the same address), and
// the "chunk became full" decision becomes an all-or-nothing qword mask.
size_t aligned_branchless(char *restrict out, const char *restrict in, size_t len) {
const __m512i space = _mm512_set1_epi8(32);
// identity byte permutation: byte j of permidx holds the value j
const __m512i permidx = _mm512_set_epi32(
0x3f3e3d3c, 0x3b3a3938, 0x37363534, 0x33323130,
0x2f2e2d2c, 0x2b2a2928, 0x27262524, 0x23222120,
0x1f1e1d1c, 0x1b1a1918, 0x17161514, 0x13121110,
0x0f0e0d0c, 0x0b0a0908, 0x07060504, 0x03020100
);
// 2 cycle unroll
size_t outlen = 0;
__m512i aligned_vec; // pending bytes of the current 64B output chunk
for(size_t i=0; i<len; i+=128) {
__m512i d1 = _mm512_load_si512(in + i);
__m512i d2 = _mm512_load_si512(in + i + 64);
__mmask64 m1 = _mm512_cmpneq_epi8_mask(d1, space);
__mmask64 m2 = _mm512_cmpneq_epi8_mask(d2, space);
d1 = _mm512_maskz_compress_epi8(m1, d1);
d2 = _mm512_maskz_compress_epi8(m2, d2);
size_t len1 = _mm_popcnt_u64(_cvtmask64_u64(m1));
size_t len2 = _mm_popcnt_u64(_cvtmask64_u64(m2));
// merge
size_t len0 = outlen & 63; // bytes already pending in aligned_vec
// lane j of pos1 = j - len0: sign bit marks lanes still owned by
// aligned_vec; vpermb uses only the low 6 index bits, so d1 is also
// rotated up by len0 positions (wrapping mod 64)
__m512i pos1 = _mm512_sub_epi8(permidx, _mm512_set1_epi8(len0));
__m512i shifted1 = _mm512_permutexvar_epi8(pos1, d1);
__mmask64 merge1 = _mm512_movepi8_mask(pos1);
aligned_vec = _mm512_mask_blend_epi8(merge1, shifted1, aligned_vec);
// unconditional store: if the chunk isn't full, outlen doesn't advance past
// it and the same address gets stored again with more bytes merged in
_mm512_store_si512(out + (outlen & ~63), aligned_vec);
// if the vector is now full, move in next vector
// writeout is 0x00 or 0xFF across all 8 qword lanes: copy all of shifted1
// (the wrapped bytes start the next chunk) or copy nothing
__mmask8 writeout = -(len0 + len1 >= 64);
aligned_vec = _mm512_mask_mov_epi64(aligned_vec, writeout, shifted1);
outlen += len1;
// pending count: (len0+len1)&63 is correct whether or not a flush happened,
// since without a flush len0+len1 < 64 already
len1 = (len0 + len1) & 63;
// second half: identical step with len1 as the pending count
__m512i pos2 = _mm512_sub_epi8(permidx, _mm512_set1_epi8(len1));
__m512i shifted2 = _mm512_permutexvar_epi8(pos2, d2);
__mmask64 merge2 = _mm512_movepi8_mask(pos2);
aligned_vec = _mm512_mask_blend_epi8(merge2, shifted2, aligned_vec);
_mm512_store_si512(out + (outlen & ~63), aligned_vec);
writeout = -(len1 + len2 >= 64);
aligned_vec = _mm512_mask_mov_epi64(aligned_vec, writeout, shifted2);
outlen += len2;
}
// flush the final partial chunk
_mm512_store_si512(out + (outlen & ~63), aligned_vec); // assume no overflow
return outlen;
}
// Variant that keeps no pending vector at all: each rotated result is written
// straight to memory with two masked stores — one into the current 64-byte
// chunk, and one (the wrapped-around lanes) into the following chunk. Bytes
// past the current output length may temporarily hold garbage until a later
// store overwrites them, and the trailing-chunk store can touch up to 63
// bytes past the final output length.
size_t aligned_maskedwrite(char *restrict out, const char *restrict in, size_t len) {
const __m512i space = _mm512_set1_epi8(32);
// byte j of permidx holds 0x80|j: vpermb ignores the high index bits (only
// the low 6 select a source byte), while the 0x80 keeps the sign bit set
// until the subtraction below borrows past lane j's offset
const __m512i permidx = _mm512_set_epi32(
0xbfbebdbc, 0xbbbab9b8, 0xb7b6b5b4, 0xb3b2b1b0,
0xafaeadac, 0xabaaa9a8, 0xa7a6a5a4, 0xa3a2a1a0,
0x9f9e9d9c, 0x9b9a9998, 0x97969594, 0x93929190,
0x8f8e8d8c, 0x8b8a8988, 0x87868584, 0x83828180
);
// 2 cycle unroll
size_t outlen = 0;
for(size_t i=0; i<len; i+=128) {
__m512i d1 = _mm512_load_si512(in + i);
__m512i d2 = _mm512_load_si512(in + i + 64);
__mmask64 m1 = _mm512_cmpneq_epi8_mask(d1, space);
__mmask64 m2 = _mm512_cmpneq_epi8_mask(d2, space);
d1 = _mm512_maskz_compress_epi8(m1, d1);
d2 = _mm512_maskz_compress_epi8(m2, d2);
size_t len1 = _mm_popcnt_u64(_cvtmask64_u64(m1));
size_t len2 = _mm_popcnt_u64(_cvtmask64_u64(m2));
// merge
size_t len0 = outlen & 63; // output offset within the current 64B chunk
// lane j of pos1 = (0x80|j) - len0: the sign bit stays set exactly for
// lanes j >= len0, i.e. the lanes that land inside the current chunk;
// the low 6 bits give (j - len0) mod 64, rotating d1 up by len0
__m512i pos1 = _mm512_sub_epi8(permidx, _mm512_set1_epi8(len0));
__mmask64 write1 = _mm512_movepi8_mask(pos1); // lanes targeting the current chunk
__m512i shifted1 = _mm512_permutexvar_epi8(pos1, d1);
// new bytes into positions len0..63 of the current chunk...
_mm512_mask_storeu_epi8(out + (outlen & ~63), write1, shifted1);
// ...and the wrapped lanes (j < len0) into the following chunk; when
// len0 == 0 the mask is empty and nothing is stored
_mm512_mask_storeu_epi8(out + (outlen & ~63) + 64, _knot_mask64(write1), shifted1);
outlen += len1;
len1 = (len0 + len1) & 63; // offset within the chunk for the second half
__m512i pos2 = _mm512_sub_epi8(permidx, _mm512_set1_epi8(len1));
__mmask64 write2 = _mm512_movepi8_mask(pos2);
__m512i shifted2 = _mm512_permutexvar_epi8(pos2, d2);
_mm512_mask_storeu_epi8(out + (outlen & ~63), write2, shifted2);
_mm512_mask_storeu_epi8(out + (outlen & ~63) + 64, _knot_mask64(write2), shifted2);
outlen += len2;
}
return outlen;
}
// Branch-light scalar despace ("Michael's" version): the bulk loop loads 8
// input bytes at a time, unconditionally stores each byte, and advances the
// output pointer only when the byte is not a space — turning the per-byte
// branch into a data-dependent increment.
// Fix vs. original: the bulk loop dereferenced a (const uint64_t*) cast of a
// char buffer, which violates strict aliasing (undefined behavior). The loads
// now go through memcpy, which compiles to the same single 8-byte load.
// Returns the number of bytes written to `out`.
size_t scalar_michaels(char * __restrict out, const char *__restrict in, size_t len) {
    char * __restrict out0 = out;
    const char SPACE = ' ';
    // remove trailing spaces (they'd be dropped anyway; this shortens the tail)
    while (len > 0 && in[len-1] == SPACE) --len;
    if (len > 16) {
        // align input on an 8-byte boundary with a plain byte loop
        size_t allen = -(intptr_t)in & 7;
        len -= allen;
        while (allen > 0) {
            char c = *in++;
            if (c != SPACE)
                *out++ = c;
            --allen;
        }
        size_t nit = len / 16;
        len = len % 16;
        // process the bulk of the input, 16 bytes (two 8-byte words) per iteration
        const uint64_t space_x1 = SPACE;
        const uint64_t space_x2 = (space_x1 << 8*1) + space_x1;
        const uint64_t space_x4 = (space_x2 << 8*2) + space_x2;
        const uint64_t space_x8 = (space_x4 << 8*4) + space_x4; // ' ' in every byte
        const uint64_t msk0 = 255;           // per-byte masks for the "kept?" tests
        const uint64_t msk1 = msk0 << 8;
        const uint64_t msk2 = msk1 << 8;
        const uint64_t msk3 = msk2 << 8;
        const uint64_t msk4 = msk3 << 8;
        const uint64_t msk5 = msk4 << 8;
        const uint64_t msk6 = msk5 << 8;
        const uint64_t msk7 = msk6 << 8;
        for (size_t i = 0; i < nit; i++) {
            uint64_t inw;
            memcpy(&inw, in, sizeof inw); // aliasing-safe 8-byte load
            in += sizeof inw;
            uint64_t neq = inw ^ space_x8; // byte is zero iff input byte was a space
            *out = (char)inw; inw >>= 8; out += (neq & msk0) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk1) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk2) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk3) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk4) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk5) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk6) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk7) != 0;
            memcpy(&inw, in, sizeof inw); // second word of the pair
            in += sizeof inw;
            neq = inw ^ space_x8;
            *out = (char)inw; inw >>= 8; out += (neq & msk0) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk1) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk2) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk3) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk4) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk5) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk6) != 0;
            *out = (char)inw; inw >>= 8; out += (neq & msk7) != 0;
        }
    }
    // process tail (fewer than 16 remaining bytes)
    while (len > 0) {
        char c = *in++;
        if (c != SPACE)
            *out++ = c;
        --len;
    }
    return out - out0;
}
/* Compute *result = *x - *y as a normalized timeval (tv_usec kept in
   [0, 1000000) when the difference is non-negative).
   Note: normalizes *y in place as a side effect of carrying/borrowing.
   Returns 1 if the difference is negative, 0 otherwise. */
int timeval_subtract (struct timeval *result, struct timeval *x, struct timeval *y)
{
    /* Borrow a second (or more) into y's microseconds if x's are smaller. */
    if (x->tv_usec < y->tv_usec) {
        int borrow = (y->tv_usec - x->tv_usec) / 1000000 + 1;
        y->tv_sec += borrow;
        y->tv_usec -= 1000000 * borrow;
    }
    /* Carry excess microseconds (over a full second) into y's seconds. */
    if (x->tv_usec - y->tv_usec > 1000000) {
        int carry = (x->tv_usec - y->tv_usec) / 1000000;
        y->tv_sec -= carry;
        y->tv_usec += 1000000 * carry;
    }
    /* After normalization the microsecond difference is non-negative. */
    result->tv_sec  = x->tv_sec  - y->tv_sec;
    result->tv_usec = x->tv_usec - y->tv_usec;
    /* Sign of the overall difference. */
    return x->tv_sec < y->tv_sec;
}
// Benchmark driver: reads the file named on the command line, truncates it to
// a multiple of 128 bytes (the kernels assume that), verifies every variant
// against the scalar reference, then times each one.
// Fixes vs. original: argc validated; fopen/fseek/ftell/fread/posix_memalign/
// malloc results checked; binary ("rb") open mode; `out` gets 64 bytes of
// slack because the aligned variants' final full-vector / masked tail stores
// can touch up to 64 bytes past the compacted length; buffers are freed.
int main(int argc, char** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <input-file>\n", argv[0]);
        return 1;
    }
    // get input data ("rb" so newline translation can't alter byte counts)
    FILE *fp = fopen(argv[1], "rb");
    if (!fp) {
        perror("fopen");
        return 1;
    }
    if (fseek(fp, 0L, SEEK_END) != 0) {
        perror("fseek");
        fclose(fp);
        return 1;
    }
    long fpos = ftell(fp);
    if (fpos < 0) {
        perror("ftell");
        fclose(fp);
        return 1;
    }
    size_t sz = (size_t)fpos & ~(size_t)127; // force 128 byte multiple
    if (sz == 0) {
        fprintf(stderr, "input must be at least 128 bytes\n");
        fclose(fp);
        return 1;
    }
    fseek(fp, 0L, SEEK_SET);
    char* buf;
    if (posix_memalign((void**)&buf, 64, sz) != 0) {
        fprintf(stderr, "input allocation failed\n");
        fclose(fp);
        return 1;
    }
    if (fread(buf, sizeof(char), sz, fp) != sz) {
        fprintf(stderr, "short read\n");
        fclose(fp);
        return 1;
    }
    fclose(fp);
    // +64 bytes of slack for the aligned variants' over-writing tail stores
    char* out;
    if (posix_memalign((void**)&out, 64, sz + 64) != 0) {
        fprintf(stderr, "output allocation failed\n");
        return 1;
    }
    size_t r1;
    // basic test: every variant must reproduce the scalar reference exactly
    char* ref = malloc(sz);
    if (!ref) {
        fprintf(stderr, "reference allocation failed\n");
        return 1;
    }
    size_t rr = scalar(ref, buf, sz);
#define TEST(f) \
    if(f(out, buf, sz) != rr) { \
        printf(#f " length mismatch\n"); \
        return 1; \
    } \
    if(memcmp(ref, out, rr)) { \
        printf(#f " data mismatch\n"); \
        return 1; \
    }
    TEST(unaligned)
    TEST(aligned)
    TEST(aligned_branchless)
    TEST(aligned_maskedwrite)
    TEST(scalar_michaels)
    unsigned long int elapsed;
    struct timeval t_start, t_end, t_diff;
#define BENCH(f) \
    gettimeofday(&t_start, NULL); \
    for(int i=0; i<10; i++) { \
        r1 = f(out, buf, sz); \
    } \
    gettimeofday(&t_end, NULL); \
    timeval_subtract(&t_diff, &t_end, &t_start); \
    elapsed = (t_diff.tv_sec*1e6 + t_diff.tv_usec); \
    printf(#f " runs in: %lu microsecs\n", elapsed);
    BENCH(unaligned)
    BENCH(aligned)
    BENCH(aligned_branchless)
    BENCH(aligned_maskedwrite)
    BENCH(scalar)
    BENCH(scalar_michaels)
    (void)r1; // result only used to keep the calls observable
    free(ref);
    free(out);
    free(buf);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment