Last active
July 8, 2024 11:56
-
-
Save zingaburga/e1a2650703f0c85cf9cf21cd6b028d82 to your computer and use it in GitHub Desktop.
Despace AVX512: aligned vs unaligned
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <immintrin.h> | |
#include <stdio.h> | |
#include <sys/time.h> | |
#include <stdlib.h> | |
#include <string.h> | |
#include <stdint.h> | |
// assume initial pointers are aligned and len is a multiple of 128 | |
// compiled with: cc -march=sapphirerapids -O3 despace.c | |
// Scalar reference implementation: copy every byte of `in` to `out`
// except ASCII spaces (0x20). Returns the number of bytes written.
size_t scalar(char *restrict out, const char *restrict in, size_t len) {
	size_t written = 0;
	for (size_t i = 0; i < len; i++) {
		char c = in[i];
		if (c != ' ') {
			out[written++] = c;
		}
	}
	return written;
}
// SIMD despace writing to unaligned output positions.
// Processes 128 input bytes per iteration (two 512-bit vectors). Assumes
// `in` is 64-byte aligned and `len` is a multiple of 128 (see file note).
// Returns the number of non-space bytes written to `out`.
size_t unaligned(char *restrict out, const char *restrict in, size_t len) {
	const __m512i space = _mm512_set1_epi8(32);
	// 2 cycle unroll
	size_t outlen = 0;
	for(size_t i=0; i<len; i+=128) {
		__m512i d1 = _mm512_load_si512(in + i);
		__m512i d2 = _mm512_load_si512(in + i + 64);
		// per-byte keep mask: set for every byte that is not a space
		__mmask64 m1 = _mm512_cmpneq_epi8_mask(d1, space);
		__mmask64 m2 = _mm512_cmpneq_epi8_mask(d2, space);
#if defined(__tune_znver4__)
		// Zen 4 path: compress in a register, then do a plain unaligned
		// full-vector store — presumably because masked VPCOMPRESSB-to-memory
		// is slow there (the #ifdef implies it; confirm against uarch docs).
		// This over-writes up to a full vector past the kept bytes, which is
		// harmless since the next store / later data overwrites the slack.
		d1 = _mm512_maskz_compress_epi8(m1, d1);
		d2 = _mm512_maskz_compress_epi8(m2, d2);
		_mm512_storeu_si512(out + outlen, d1);
		outlen += _mm_popcnt_u64(_cvtmask64_u64(m1)); // advance by kept count
		_mm512_storeu_si512(out + outlen, d2);
		outlen += _mm_popcnt_u64(_cvtmask64_u64(m2));
#else
		// Generic path: compress directly to memory, storing only kept bytes.
		_mm512_mask_compressstoreu_epi8(out + outlen, m1, d1);
		outlen += _mm_popcnt_u64(_cvtmask64_u64(m1));
		_mm512_mask_compressstoreu_epi8(out + outlen, m2, d2);
		outlen += _mm_popcnt_u64(_cvtmask64_u64(m2));
#endif
	}
	return outlen;
}
// SIMD despace that only ever stores to 64-byte-aligned output addresses.
// Compressed bytes are staged in a register (`aligned_vec`) and flushed one
// full aligned vector at a time. Assumes `out` is 64-byte aligned, `in` is
// 64-byte aligned, and `len` is a multiple of 128.
// NOTE(review): the final flush can store up to 63 bytes past the returned
// length ("assume no overflow" below) — the output buffer needs that slack.
size_t aligned(char *restrict out, const char *restrict in, size_t len) {
	const __m512i space = _mm512_set1_epi8(32);
	// byte pattern 0,1,2,...,63 (_mm512_set_epi32 arguments are high->low)
	const __m512i permidx = _mm512_set_epi32(
		0x3f3e3d3c, 0x3b3a3938, 0x37363534, 0x33323130,
		0x2f2e2d2c, 0x2b2a2928, 0x27262524, 0x23222120,
		0x1f1e1d1c, 0x1b1a1918, 0x17161514, 0x13121110,
		0x0f0e0d0c, 0x0b0a0908, 0x07060504, 0x03020100
	);
	// 2 cycle unroll
	size_t outlen = 0;
	// Pending bytes for the current, partially-filled 64-byte output chunk.
	// Deliberately left uninitialized: on the first iteration len0 == 0, so
	// the merge mask is empty and no stale lane survives the first blend.
	__m512i aligned_vec;
	for(size_t i=0; i<len; i+=128) {
		__m512i d1 = _mm512_load_si512(in + i);
		__m512i d2 = _mm512_load_si512(in + i + 64);
		__mmask64 m1 = _mm512_cmpneq_epi8_mask(d1, space);
		__mmask64 m2 = _mm512_cmpneq_epi8_mask(d2, space);
		// pack kept bytes to the front of each vector (tail lanes zeroed)
		d1 = _mm512_maskz_compress_epi8(m1, d1);
		d2 = _mm512_maskz_compress_epi8(m2, d2);
		size_t len1 = _mm_popcnt_u64(_cvtmask64_u64(m1)); // kept from d1
		size_t len2 = _mm_popcnt_u64(_cvtmask64_u64(m2)); // kept from d2
		// merge: rotate the packed bytes so they start at the current fill
		// level.  Lane i of pos1 is i - len0; lanes i < len0 go negative, and
		// their sign bit marks "keep the existing aligned_vec byte" in the
		// blend below.  VPERMB only uses the low 6 index bits, so the rotate
		// wraps mod 64.
		size_t len0 = outlen & 63;
		__m512i pos1 = _mm512_sub_epi8(permidx, _mm512_set1_epi8(len0));
		__m512i shifted1 = _mm512_permutexvar_epi8(pos1, d1);
		__mmask64 merge1 = _mm512_movepi8_mask(pos1); // sign bits -> mask
		aligned_vec = _mm512_mask_blend_epi8(merge1, shifted1, aligned_vec);
		// if the vector is now full, write it out
		if(len0 + len1 >= 64) {
			_mm512_store_si512(out + (outlen & ~63), aligned_vec);
			// the wrapped-around tail of shifted1 seeds the next chunk
			aligned_vec = shifted1;
			outlen += len1;
			len1 = (len0 + len1) & 63; // new fill level of aligned_vec
		} else {
			outlen += len1;
			len1 += len0; // fill level for the second merge below
		}
		// second vector: same rotate-and-blend at fill level len1
		__m512i pos2 = _mm512_sub_epi8(permidx, _mm512_set1_epi8(len1));
		__m512i shifted2 = _mm512_permutexvar_epi8(pos2, d2);
		__mmask64 merge2 = _mm512_movepi8_mask(pos2);
		aligned_vec = _mm512_mask_blend_epi8(merge2, shifted2, aligned_vec);
		if(len1 + len2 >= 64) {
			_mm512_store_si512(out + (outlen & ~63), aligned_vec);
			aligned_vec = shifted2;
		}
		outlen += len2;
	}
	// flush the final partial vector (stores a full 64 bytes)
	_mm512_store_si512(out + (outlen & ~63), aligned_vec); // assume no overflow
	return outlen;
}
// Aligned-store despace without the data-dependent branches of aligned():
// the staging vector is unconditionally stored every half-iteration (an
// aligned store to the same cache line is simply overwritten later), and
// the "vector is now full" decision becomes an all-or-nothing mask move
// instead of a branch.  Same assumptions and output-slack caveat as
// aligned().
size_t aligned_branchless(char *restrict out, const char *restrict in, size_t len) {
	const __m512i space = _mm512_set1_epi8(32);
	// byte pattern 0,1,2,...,63 (_mm512_set_epi32 arguments are high->low)
	const __m512i permidx = _mm512_set_epi32(
		0x3f3e3d3c, 0x3b3a3938, 0x37363534, 0x33323130,
		0x2f2e2d2c, 0x2b2a2928, 0x27262524, 0x23222120,
		0x1f1e1d1c, 0x1b1a1918, 0x17161514, 0x13121110,
		0x0f0e0d0c, 0x0b0a0908, 0x07060504, 0x03020100
	);
	// 2 cycle unroll
	size_t outlen = 0;
	// pending bytes for the current aligned 64-byte chunk; first-iteration
	// merge mask is empty (len0 == 0), so no uninitialized lane is kept
	__m512i aligned_vec;
	for(size_t i=0; i<len; i+=128) {
		__m512i d1 = _mm512_load_si512(in + i);
		__m512i d2 = _mm512_load_si512(in + i + 64);
		__mmask64 m1 = _mm512_cmpneq_epi8_mask(d1, space);
		__mmask64 m2 = _mm512_cmpneq_epi8_mask(d2, space);
		// pack kept bytes to the front (tail lanes zeroed)
		d1 = _mm512_maskz_compress_epi8(m1, d1);
		d2 = _mm512_maskz_compress_epi8(m2, d2);
		size_t len1 = _mm_popcnt_u64(_cvtmask64_u64(m1));
		size_t len2 = _mm_popcnt_u64(_cvtmask64_u64(m2));
		// merge: rotate packed bytes to the current fill level; negative
		// pos1 lanes (i < len0) keep the old aligned_vec byte in the blend
		size_t len0 = outlen & 63;
		__m512i pos1 = _mm512_sub_epi8(permidx, _mm512_set1_epi8(len0));
		__m512i shifted1 = _mm512_permutexvar_epi8(pos1, d1);
		__mmask64 merge1 = _mm512_movepi8_mask(pos1);
		aligned_vec = _mm512_mask_blend_epi8(merge1, shifted1, aligned_vec);
		// unconditional store; if the chunk wasn't actually full, the next
		// store to the same address overwrites it
		_mm512_store_si512(out + (outlen & ~63), aligned_vec);
		// if the vector is now full, move in next vector
		// writeout is all-ones when len0+len1 >= 64, else all-zero, so the
		// qword mask-move replaces aligned_vec wholesale or keeps it intact
		__mmask8 writeout = -(len0 + len1 >= 64);
		aligned_vec = _mm512_mask_mov_epi64(aligned_vec, writeout, shifted1);
		outlen += len1;
		len1 = (len0 + len1) & 63; // fill level after the first vector
		// second vector: same sequence at fill level len1
		__m512i pos2 = _mm512_sub_epi8(permidx, _mm512_set1_epi8(len1));
		__m512i shifted2 = _mm512_permutexvar_epi8(pos2, d2);
		__mmask64 merge2 = _mm512_movepi8_mask(pos2);
		aligned_vec = _mm512_mask_blend_epi8(merge2, shifted2, aligned_vec);
		_mm512_store_si512(out + (outlen & ~63), aligned_vec);
		writeout = -(len1 + len2 >= 64);
		aligned_vec = _mm512_mask_mov_epi64(aligned_vec, writeout, shifted2);
		outlen += len2;
	}
	// flush the final partial vector (stores a full 64 bytes)
	_mm512_store_si512(out + (outlen & ~63), aligned_vec); // assume no overflow
	return outlen;
}
// Aligned-chunk despace with no register staging: each rotated vector is
// written straight to memory with a pair of complementary masked stores —
// one into the current 64-byte chunk, one into the next.  The permute
// indices start at 0x80 (not 0x00) so that after subtracting the fill
// level, the sign bit of lane i directly encodes "belongs to the current
// chunk" (i >= fill) vs "wrapped into the next chunk" (i < fill).  VPERMB
// ignores the high index bits, so the rotation is still (i - fill) mod 64.
// NOTE(review): the second store can write up to 64 bytes into the chunk
// after the current output position — verify callers provide that slack.
size_t aligned_maskedwrite(char *restrict out, const char *restrict in, size_t len) {
	const __m512i space = _mm512_set1_epi8(32);
	// byte pattern 0x80,0x81,...,0xbf: lane index i with the sign bit set
	const __m512i permidx = _mm512_set_epi32(
		0xbfbebdbc, 0xbbbab9b8, 0xb7b6b5b4, 0xb3b2b1b0,
		0xafaeadac, 0xabaaa9a8, 0xa7a6a5a4, 0xa3a2a1a0,
		0x9f9e9d9c, 0x9b9a9998, 0x97969594, 0x93929190,
		0x8f8e8d8c, 0x8b8a8988, 0x87868584, 0x83828180
	);
	// 2 cycle unroll
	size_t outlen = 0;
	for(size_t i=0; i<len; i+=128) {
		__m512i d1 = _mm512_load_si512(in + i);
		__m512i d2 = _mm512_load_si512(in + i + 64);
		__mmask64 m1 = _mm512_cmpneq_epi8_mask(d1, space);
		__mmask64 m2 = _mm512_cmpneq_epi8_mask(d2, space);
		// pack kept bytes to the front (tail lanes zeroed)
		d1 = _mm512_maskz_compress_epi8(m1, d1);
		d2 = _mm512_maskz_compress_epi8(m2, d2);
		size_t len1 = _mm_popcnt_u64(_cvtmask64_u64(m1));
		size_t len2 = _mm_popcnt_u64(_cvtmask64_u64(m2));
		// merge: lanes i >= len0 keep the sign bit (still >= 0x80) and form
		// write1; lanes i < len0 drop below 0x80 and go to the next chunk
		size_t len0 = outlen & 63;
		__m512i pos1 = _mm512_sub_epi8(permidx, _mm512_set1_epi8(len0));
		__mmask64 write1 = _mm512_movepi8_mask(pos1);
		__m512i shifted1 = _mm512_permutexvar_epi8(pos1, d1);
		// current chunk: fill positions len0..63 without touching 0..len0-1
		_mm512_mask_storeu_epi8(out + (outlen & ~63), write1, shifted1);
		// next chunk: the bytes that wrapped past the 64-byte boundary
		// (over-written later by subsequent iterations if not yet valid)
		_mm512_mask_storeu_epi8(out + (outlen & ~63) + 64, _knot_mask64(write1), shifted1);
		outlen += len1;
		len1 = (len0 + len1) & 63; // fill level for the second vector
		// second vector: identical store pair at fill level len1
		__m512i pos2 = _mm512_sub_epi8(permidx, _mm512_set1_epi8(len1));
		__mmask64 write2 = _mm512_movepi8_mask(pos2);
		__m512i shifted2 = _mm512_permutexvar_epi8(pos2, d2);
		_mm512_mask_storeu_epi8(out + (outlen & ~63), write2, shifted2);
		_mm512_mask_storeu_epi8(out + (outlen & ~63) + 64, _knot_mask64(write2), shifted2);
		outlen += len2;
	}
	return outlen;
}
// Branchless scalar despace ("michaels" variant): after aligning the input,
// the bulk loop reads 8 bytes at a time as a uint64_t and, for each byte,
// stores it unconditionally and advances the output pointer only when the
// byte differs from ' ' — no data-dependent branches in the hot loop.
// NOTE(review): because every byte is stored before the conditional
// advance, up to one garbage byte may be left just past the returned
// length — callers must tolerate that slack in the output buffer.
// NOTE(review): the (const uint64_t*)in cast reads char data through a
// uint64_t lvalue — technically a strict-aliasing violation; memcpy into a
// uint64_t would be the clean form. Confirm the build relies on -O3 GCC/
// Clang behavior here.
size_t scalar_michaels(char * __restrict out, const char *__restrict in, size_t len) {
	char * __restrict out0 = out; // remember start to compute the count
	const char SPACE = ' ';
	// remove trailing spaces
	while (len > 0 && in[len-1] == SPACE) --len;
	if (len > 16) {
		// align input on 8-byte boundary
		size_t allen = -(intptr_t)in & 7;
		len -= allen;
		while (allen > 0) {
			char c = *in++;
			if (c != SPACE)
				*out++ = c;
			--allen;
		}
		const uint64_t* in8 = (const uint64_t*)in;
		size_t nit = len / 16; // number of 16-byte (two-word) iterations
		len = len % 16;        // remainder goes to the tail loop below
		// process the bulk of the input
		// space_x8 is eight copies of ' ' packed into one 64-bit word
		const uint64_t space_x1 = SPACE;
		const uint64_t space_x2 = (space_x1 << 8*1) + space_x1;
		const uint64_t space_x4 = (space_x2 << 8*2) + space_x2;
		const uint64_t space_x8 = (space_x4 << 8*4) + space_x4;
		// mskN isolates byte N of a 64-bit word
		const uint64_t msk0 = 255;
		const uint64_t msk1 = msk0 << 8;
		const uint64_t msk2 = msk1 << 8;
		const uint64_t msk3 = msk2 << 8;
		const uint64_t msk4 = msk3 << 8;
		const uint64_t msk5 = msk4 << 8;
		const uint64_t msk6 = msk5 << 8;
		const uint64_t msk7 = msk6 << 8;
		for(size_t i=0; i<nit; i++) {
			// neq has a non-zero byte wherever inw differs from ' ';
			// each step stores the low byte and advances out iff non-space
			uint64_t inw = *in8++;
			uint64_t neq = inw ^ space_x8;
			*out = (char)inw; inw >>= 8; out += (neq & msk0) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk1) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk2) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk3) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk4) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk5) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk6) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk7) != 0;
			// second word of the 16-byte iteration, same pattern
			inw = *in8++;
			neq = inw ^ space_x8;
			*out = (char)inw; inw >>= 8; out += (neq & msk0) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk1) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk2) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk3) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk4) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk5) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk6) != 0;
			*out = (char)inw; inw >>= 8; out += (neq & msk7) != 0;
		}
		in = (const char*)in8;
	}
	// process tail (fewer than 16 bytes, or small inputs entirely)
	while (len > 0) {
		char c = *in++;
		if (c != SPACE)
			*out++ = c;
		--len;
	}
	return out-out0;
}
/* Compute *result = *x - *y for struct timeval values.
 * Follows the classic GNU libc manual idiom: *y is normalized IN PLACE
 * (callers must not rely on y afterwards) so the microsecond subtraction
 * cannot underflow. Returns 1 if the difference is negative, else 0. */
int
timeval_subtract (struct timeval *result, struct timeval *x, struct timeval *y)
{
	/* Borrow whole seconds into y until y's usec does not exceed x's. */
	if (x->tv_usec < y->tv_usec) {
		int borrow = (y->tv_usec - x->tv_usec) / 1000000 + 1;
		y->tv_sec += borrow;
		y->tv_usec -= 1000000 * borrow;
	}
	/* Carry the other way when the usec gap exceeds one second. */
	if (x->tv_usec - y->tv_usec > 1000000) {
		int carry = (x->tv_usec - y->tv_usec) / 1000000;
		y->tv_sec -= carry;
		y->tv_usec += 1000000 * carry;
	}
	/* After normalization the usec difference is non-negative. */
	result->tv_sec = x->tv_sec - y->tv_sec;
	result->tv_usec = x->tv_usec - y->tv_usec;
	return x->tv_sec < y->tv_sec;
}
// Driver: load a file (truncated to a 128-byte multiple, as the SIMD
// kernels require), verify each despace variant against the scalar
// reference, then time each variant over 10 runs.
// Fixes vs original: argc checked before argv[1]; fopen/posix_memalign/
// fread/malloc results checked; output buffer gets 128 bytes of slack
// because the aligned kernels can store a full vector past the result.
int main(int argc, char** argv) {
	if (argc < 2) {
		fprintf(stderr, "usage: %s <input-file>\n", argv[0]);
		return 1;
	}
	FILE *fp = fopen(argv[1], "r");
	if (!fp) {
		perror("fopen");
		return 1;
	}
	fseek(fp, 0L, SEEK_END);
	size_t sz = ftell(fp) & ~(size_t)127; // force 128 byte multiple
	fseek(fp, 0L, SEEK_SET);
	if (sz == 0) {
		fprintf(stderr, "input must be at least 128 bytes\n");
		fclose(fp);
		return 1;
	}
	// 64-byte aligned buffers; `out` gets slack for over-writing kernels
	char* buf;
	char* out;
	if (posix_memalign((void**)&buf, 64, sz) ||
	    posix_memalign((void**)&out, 64, sz + 128)) {
		fprintf(stderr, "allocation failed\n");
		fclose(fp);
		return 1;
	}
	if (fread(buf, sizeof(char), sz, fp) != sz) {
		fprintf(stderr, "short read\n");
		fclose(fp);
		return 1;
	}
	fclose(fp);
	size_t r1;
	// basic test: every variant must match the scalar reference exactly
	char* ref = (char*)malloc(sz);
	if (!ref) {
		fprintf(stderr, "allocation failed\n");
		return 1;
	}
	size_t rr = scalar(ref, buf, sz);
#define TEST(f) \
	if(f(out, buf, sz) != rr) { \
		printf(#f " length mismatch\n"); \
		return 1; \
	} \
	if(memcmp(ref, out, rr)) { \
		printf(#f " data mismatch\n"); \
		return 1; \
	}
	TEST(unaligned)
	TEST(aligned)
	TEST(aligned_branchless)
	TEST(aligned_maskedwrite)
	TEST(scalar_michaels)
	unsigned long int elapsed;
	struct timeval t_start, t_end, t_diff;
#define BENCH(f) \
	gettimeofday(&t_start, NULL); \
	for(int i=0; i<10; i++) { \
		r1 = f(out, buf, sz); \
	} \
	gettimeofday(&t_end, NULL); \
	timeval_subtract(&t_diff, &t_end, &t_start); \
	elapsed = (t_diff.tv_sec*1000000UL + t_diff.tv_usec); \
	printf(#f " runs in: %lu microsecs\n", elapsed);
	BENCH(unaligned)
	BENCH(aligned)
	BENCH(aligned_branchless)
	BENCH(aligned_maskedwrite)
	BENCH(scalar)
	BENCH(scalar_michaels)
	(void)r1; // result already validated by TEST above
	free(ref);
	free(buf);
	free(out);
	return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment