Created
September 2, 2020 14:19
-
-
Save clausecker/42f4cc547c2eafa2c6ce6a493b4e2a73 to your computer and use it in GitHub Desktop.
Summing decimal digits (digital root of a digit string) with different approaches
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// unsigned getnumericvalue(const char *ptr)
// ABI:    SysV AMD64; requires AVX-512F/BW (vptestmb, kortestq, zmm vpsadbw)
// In:     rdi = ptr, NUL-terminated string; assumes every byte before
//         the NUL is an ASCII digit '0'..'9'
// Out:    eax = 0 when the digit sum is 0 (e.g. "" or "000"),
//         otherwise 1 + (sum - 1) % 9 (the digital root, 1..9)
// Clobb:  rax, rcx, rdx, rdi, zmm0-zmm3, k0, flags
// Note:   the vector body loads whole aligned 64-byte lines, so it may
//         read past the NUL inside the same line (never across a page).
// Syntax: AT&T with // comments -- assemble as a cpp-preprocessed .S file.
        .section .text
        .type   getnumericvalue, @function
        .globl  getnumericvalue
getnumericvalue:
        xor     %eax, %eax              // rax = running digit sum

        // scalar head: consume bytes until ptr is 64-byte aligned
        test    $64-1, %dil
        jz      .Lvector
.Lhead: movzbl  (%rdi), %edx            // load one byte
        inc     %rdi
        test    %edx, %edx              // NUL terminator?
        jz      .Lend
        sub     $'0', %edx              // ASCII digit -> value
        add     %edx, %eax              // <= 63 digits here, fits in 32 bits
        test    $64-1, %dil
        jnz     .Lhead

        // vector body: one aligned 64-byte line per iteration, stopping
        // before the first line that contains a NUL byte
.Lvector:
        vpbroadcastb zero(%rip), %zmm1  // '0' in every byte lane
        vpxor   %xmm3, %xmm3, %xmm3     // VEX write zeroes all of zmm3: 8 qword sums
        vmovdqa32 (%rdi), %zmm0         // load one line
        vptestmb %zmm0, %zmm0, %k0      // k0 bit i = (byte i != 0)
        kortestq %k0, %k0               // CF = 1 iff no byte is NUL
        jnc     .Lreduce                // first line already holds the NUL
        .p2align 4
.Lloop: add     $64, %rdi
        vpsadbw %zmm1, %zmm0, %zmm0     // per qword: sum of |byte - '0'| = digit sum
        vpaddq  %zmm0, %zmm3, %zmm3     // accumulate into the 8 qword counters
        vmovdqa32 (%rdi), %zmm0         // load next line
        vptestmb %zmm0, %zmm0, %k0
        kortestq %k0, %k0
        jc      .Lloop                  // continue while the line has no NUL

        // horizontal sum of the 8 qword counters into rdx
.Lreduce:
        vextracti64x4 $1, %zmm3, %ymm2  // high 4 qwords
        vpaddq  %ymm2, %ymm3, %ymm3     // add to the low 4 qwords
        vextracti128 $1, %ymm3, %xmm2   // high 2 qwords
        vpaddq  %xmm2, %xmm3, %xmm3     // add to the low 2 qwords
        vpshufd $0x4e, %xmm3, %xmm2     // swap the two qwords
        vpaddq  %xmm2, %xmm3, %xmm3     // xmm3[63:0] = total
        vmovq   %xmm3, %rdx
        add     %rdx, %rax              // combine with the scalar head

        // scalar tail: rdi points at the line that contains the NUL
.Ltail: movzbl  (%rdi), %edx
        inc     %rdi
        test    %edx, %edx              // NUL terminator?
        jz      .Lend
        sub     $'0', %edx
        add     %rdx, %rax
        jmp     .Ltail                  // unconditional: a zero running sum
                                        // must not terminate the loop

        // result: 0 for a zero sum, else 1 + (sum - 1) % 9
.Lend:  test    %rax, %rax
        jz      .Ldone                  // eax is already 0
        dec     %rax
        xor     %edx, %edx              // rdx:rax = sum - 1
        mov     $9, %ecx
        div     %rcx                    // rdx = (sum - 1) % 9
        lea     1(%rdx), %eax           // 1..9
.Ldone: vzeroupper                      // avoid AVX->SSE transition penalties in the caller
        ret
        .size   getnumericvalue, .-getnumericvalue

        // constants
        .section .rodata
zero:   .byte   '0'                     // broadcast into every byte lane of zmm1

        // this object does not need an executable stack
        .section .note.GNU-stack,"",@progbits
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h> | |
#include <stdlib.h> | |
#include <time.h> | |
#include <stdint.h> | |
/* Implementations under test; each maps a digit string to an unsigned result. */
unsigned getnumericvalue_simple(const char *);
unsigned getnumericvalue_naive(const char *);
unsigned getnumericvalue_parallel(const char *);
unsigned getnumericvalue(const char *);
/*
 * Call fun(p) a fixed number of times and print the accumulated result
 * together with the average CPU time per call in milliseconds.
 */
static void measure(const char *name, int digits, const char *p, unsigned(*fun)(const char*)) {
    const int reps = 10000;
    unsigned checksum = 0;      /* keeps the calls observable */
    clock_t t0 = clock();

    for (int rep = 0; rep != reps; ++rep)
        checksum += fun(p);

    double total_ms = (double)(clock() - t0) * 1000.0 / CLOCKS_PER_SEC;
    printf("%-9s %d digits -> %u, %7.3f msec\n", name, digits, checksum, total_ms / reps);
}
int main(int argc, char *argv[]) { | |
int digits = argc < 2 ? 1000000 : strtol(argv[1], NULL, 0); | |
char *p = malloc(digits + 1); | |
for (int i = 0; i < digits; i++) | |
p[i] = "0123456789123456"[i & 15]; | |
p[digits] = '\0'; | |
measure("simple", digits, p, getnumericvalue_simple); | |
measure("naive", digits, p, getnumericvalue_naive); | |
measure("parallel", digits, p, getnumericvalue_parallel); | |
measure("simd", digits, p, getnumericvalue); | |
return 0; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
/*
 * Digital root of the digit sum of a NUL-terminated string of ASCII digits.
 *
 * '9' digits are skipped in the sum because 9 is congruent to 0 mod 9.
 * Whether one was seen is still recorded, so an input whose only nonzero
 * digits are nines (e.g. "9", "90") yields 9 rather than 0.
 *
 * Returns 0 for an empty string or an all-zero string, otherwise 1..9.
 */
unsigned getnumericvalue_simple(const char *in_str) {
    unsigned long number = 0;
    int saw_nine = 0;

    /* Test for NUL before reading a digit, so "" is handled without over-reading. */
    for (const char *ptr = in_str; *ptr != '\0'; ptr++) {
        if (*ptr == '9')
            saw_nine = 1;
        else
            number += (unsigned long)(*ptr - '0');
    }

    if (number == 0)
        return saw_nine ? 9 : 0;
    return number <= 9 ? number : ((number - 1) % 9) + 1;
}
/*
 * Reference implementation: sum every byte's offset from '0' up to the
 * NUL terminator, then reduce the total to its digital root
 * (0 for a zero total, otherwise a value in 1..9).
 */
unsigned getnumericvalue_naive(const char *ptr) {
    unsigned long total;

    for (total = 0; *ptr != '\0'; ptr++)
        total += *ptr - '0';

    if (total == 0)
        return 0;
    return 1 + (total - 1) % 9;
}
/*
 * Digital root of the digit sum of a NUL-terminated string of ASCII
 * digits. After alignment, eight bytes are processed per step inside a
 * 64-bit word.
 *
 * Returns 0 when the digit sum is 0 (including ""), otherwise
 * 1 + (sum - 1) % 9, the same formula getnumericvalue_naive uses.
 *
 * Assumes every byte before the NUL is '0'..'9'. The word loop loads
 * whole aligned 8-byte words, so it may read up to 7 bytes past the
 * terminator. An aligned word never crosses a page, so this cannot
 * fault, but memory checkers may report it.
 */
unsigned getnumericvalue_parallel(const char *ptr) {
    unsigned long long number = 0;
    unsigned long long pack1, pack2, pack3;

    /* align source on ull boundary; short strings may finish here */
    while ((uintptr_t)ptr & (sizeof(unsigned long long) - 1)) {
        if (*ptr == '\0')
            return number ? 1 + (number - 1) % 9 : 0;
        number += *ptr++ - '0';
    }

    /*
     * Scan 8 bytes at a time.
     *
     * pack3 holds '0' (0x30) in every byte lane. A word in which some
     * lane lacks those bits fails (pack2 & pack3) == pack3. A NUL byte
     * always fails, so that word is left to the scalar tail.
     *
     * pack1 collects per-lane digit sums for at most 28 words. Since
     * 28 * 9 = 252 fits in a byte, lanes never carry into each other.
     *
     * memcpy performs the word load without breaking C's strict
     * aliasing rule; compilers typically emit a single 8-byte move for it.
     */
    pack3 = 0x3030303030303030;
#define REP8(x) x;x;x;x;x;x;x;x
#define REP28(x) x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x
    for (;;) {
        pack1 = 0;
        REP28(memcpy(&pack2, ptr, sizeof pack2);
              if ((pack2 & pack3) != pack3)
                  break;
              ptr += sizeof(unsigned long long);
              pack1 += pack2 - pack3);
        /* full batch of 28 words: fold its 8 lane sums into number */
        REP8(number += pack1 & 0xFF; pack1 >>= 8);
    }
    /* fold the lane sums of the final, partial batch */
    REP8(number += pack1 & 0xFF; pack1 >>= 8);
#undef REP28
#undef REP8

    /* finish trailing bytes of the word that contains the NUL */
    while (*ptr) {
        number += *ptr++ - '0';
    }
    return number ? 1 + (number - 1) % 9 : 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment