Created
July 21, 2023 12:57
-
-
Save xen0n/09d333ed83b7716e3be50baad9be627a to your computer and use it in GitHub Desktop.
Sketching LoongArch SIMD acceleration for Linux XOR ops
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: GPL-2.0-or-later | |
/* | |
$ gcc -O3 -o linux-xor-simd-test linux-xor-simd-test.c | |
$ ./linux-xor-simd-test | |
ref (size=4096 ) passed 16383 times: 0.005044150 s total, 0.000000307 s per pass, 12687.191 MiB/s | |
lsx_32b (size=4096 ) passed 16383 times: 0.002663250 s total, 0.000000162 s per pass, 24029.323 MiB/s | |
lsx_64b (size=4096 ) passed 16383 times: 0.002517970 s total, 0.000000153 s per pass, 25415.749 MiB/s | |
lsx_128b (size=4096 ) passed 16383 times: 0.002517590 s total, 0.000000153 s per pass, 25419.585 MiB/s | |
lasx_32b (size=4096 ) passed 16383 times: 0.001935550 s total, 0.000000118 s per pass, 33063.519 MiB/s | |
lasx_64b (size=4096 ) passed 16383 times: 0.001813990 s total, 0.000000110 s per pass, 35279.188 MiB/s | |
lasx_128b (size=4096 ) passed 16383 times: 0.001756910 s total, 0.000000107 s per pass, 36425.368 MiB/s | |
*/ | |
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <sys/random.h>
// may not be widely available yet | |
// #include <lsxintrin.h> | |
// Range of buffer sizes to benchmark, expressed as log2(bytes).
// For a wider sweep, e.g. 512 B .. 1 MiB, use orders 9 .. 20.
// The default (4 KiB) is the same as crypto/xor.c.
#define DATA_SIZE_MIN_ORDER 12 // 4KiB
#define DATA_SIZE_MAX_ORDER 12 // 4KiB
// Number of timed passes per implementation. Each pass XORs the same source
// buffer into the destination in place, so after an odd number of passes the
// destination equals a single XOR and can be compared with the reference.
#define TIMES 16383 // must be odd
_Static_assert(TIMES % 2 == 1,
	       "TIMES must be odd so the in-place result equals a single XOR");
_Static_assert(DATA_SIZE_MIN_ORDER <= DATA_SIZE_MAX_ORDER,
	       "empty buffer size range");
// Buffer sizes are computed as `1 << order` in an int.
_Static_assert(DATA_SIZE_MAX_ORDER < 31, "buffer size must fit in an int");
// Common signature of every implementation under test: a ^= b over len bytes.
typedef void (*xor_impl_t)(void * __restrict, const void * __restrict, size_t);
// Portable reference implementation, taken from linux
// include/asm-generic/xor.h.
//
// XORs `bytes` bytes of p2 into p1 in place, eight unsigned longs (one
// "line") per iteration. `bytes` is expected to be a multiple of
// 8 * sizeof(long); any remainder is left untouched. Unlike the kernel's
// do/while form, which always processes at least one line, a size smaller
// than one line (including 0) touches no memory.
static void
xor_8regs_2(unsigned long bytes, unsigned long * __restrict p1,
	    const unsigned long * __restrict p2)
{
	long lines = bytes / (sizeof (long)) / 8;

	for (; lines > 0; lines--) {
		p1[0] ^= p2[0];
		p1[1] ^= p2[1];
		p1[2] ^= p2[2];
		p1[3] ^= p2[3];
		p1[4] ^= p2[4];
		p1[5] ^= p2[5];
		p1[6] ^= p2[6];
		p1[7] ^= p2[7];
		p1 += 8;
		p2 += 8;
	}
}
/* xor_impl_t adapter for the portable reference: a ^= b over len bytes. */
static void reference_xor(void * __restrict a, const void * __restrict b, size_t len)
{
	unsigned long *dst = a;
	const unsigned long *src = b;

	xor_8regs_2(len, dst, src);
}
// | |
// LSX | |
// | |
static void | |
xor_lsx_32b(unsigned long bytes, void * __restrict p1, const void * __restrict p2) | |
{ | |
long lines = bytes / 32; | |
do { | |
asm volatile ( | |
"vld $vr0, %[dst], 0\n\t" | |
"vld $vr1, %[dst], 16\n\t" | |
"vld $vr2, %[src], 0\n\t" | |
"vld $vr3, %[src], 16\n\t" | |
"vxor.v $vr0, $vr0, $vr2\n\t" | |
"vxor.v $vr1, $vr1, $vr3\n\t" | |
"vst $vr0, %[dst], 0\n\t" | |
"vst $vr1, %[dst], 16\n\t" | |
: : [dst] "r"(p1), [src] "r"(p2) | |
: "memory" | |
); | |
p1 += 32; | |
p2 += 32; | |
} while (--lines > 0); | |
} | |
/* xor_impl_t adapter for xor_lsx_32b: a ^= b over len bytes. */
static void lsx_32b_glue(void * __restrict a, const void * __restrict b, size_t len)
{
	unsigned long nbytes = len;

	xor_lsx_32b(nbytes, a, b);
}
static void | |
xor_lsx_64b(unsigned long bytes, void * __restrict p1, const void * __restrict p2) | |
{ | |
long lines = bytes / 64; | |
do { | |
asm volatile ( | |
"vld $vr0, %[dst], 0\n\t" | |
"vld $vr1, %[dst], 16\n\t" | |
"vld $vr2, %[dst], 32\n\t" | |
"vld $vr3, %[dst], 48\n\t" | |
"vld $vr4, %[src], 0\n\t" | |
"vld $vr5, %[src], 16\n\t" | |
"vld $vr6, %[src], 32\n\t" | |
"vld $vr7, %[src], 48\n\t" | |
"vxor.v $vr0, $vr0, $vr4\n\t" | |
"vxor.v $vr1, $vr1, $vr5\n\t" | |
"vxor.v $vr2, $vr2, $vr6\n\t" | |
"vxor.v $vr3, $vr3, $vr7\n\t" | |
"vst $vr0, %[dst], 0\n\t" | |
"vst $vr1, %[dst], 16\n\t" | |
"vst $vr2, %[dst], 32\n\t" | |
"vst $vr3, %[dst], 48\n\t" | |
: : [dst] "r"(p1), [src] "r"(p2) | |
: "memory" | |
); | |
p1 += 64; | |
p2 += 64; | |
} while (--lines > 0); | |
} | |
/* xor_impl_t adapter for xor_lsx_64b: a ^= b over len bytes. */
static void lsx_64b_glue(void * __restrict a, const void * __restrict b, size_t len)
{
	unsigned long nbytes = len;

	xor_lsx_64b(nbytes, a, b);
}
static void | |
xor_lsx_128b(unsigned long bytes, void * __restrict p1, const void * __restrict p2) | |
{ | |
long lines = bytes / 128; | |
do { | |
asm volatile ( | |
"vld $vr0, %[dst], 0\n\t" | |
"vld $vr1, %[dst], 16\n\t" | |
"vld $vr2, %[dst], 32\n\t" | |
"vld $vr3, %[dst], 48\n\t" | |
"vld $vr4, %[dst], 64\n\t" | |
"vld $vr5, %[dst], 80\n\t" | |
"vld $vr6, %[dst], 96\n\t" | |
"vld $vr7, %[dst], 112\n\t" | |
"vld $vr8, %[src], 0\n\t" | |
"vld $vr9, %[src], 16\n\t" | |
"vld $vr10, %[src], 32\n\t" | |
"vld $vr11, %[src], 48\n\t" | |
"vld $vr12, %[src], 64\n\t" | |
"vld $vr13, %[src], 80\n\t" | |
"vld $vr14, %[src], 96\n\t" | |
"vld $vr15, %[src], 112\n\t" | |
"vxor.v $vr0, $vr0, $vr8\n\t" | |
"vxor.v $vr1, $vr1, $vr9\n\t" | |
"vxor.v $vr2, $vr2, $vr10\n\t" | |
"vxor.v $vr3, $vr3, $vr11\n\t" | |
"vxor.v $vr4, $vr4, $vr12\n\t" | |
"vxor.v $vr5, $vr5, $vr13\n\t" | |
"vxor.v $vr6, $vr6, $vr14\n\t" | |
"vxor.v $vr7, $vr7, $vr15\n\t" | |
"vst $vr0, %[dst], 0\n\t" | |
"vst $vr1, %[dst], 16\n\t" | |
"vst $vr2, %[dst], 32\n\t" | |
"vst $vr3, %[dst], 48\n\t" | |
"vst $vr4, %[dst], 64\n\t" | |
"vst $vr5, %[dst], 80\n\t" | |
"vst $vr6, %[dst], 96\n\t" | |
"vst $vr7, %[dst], 112\n\t" | |
: : [dst] "r"(p1), [src] "r"(p2) | |
: "memory" | |
); | |
p1 += 128; | |
p2 += 128; | |
} while (--lines > 0); | |
} | |
/* xor_impl_t adapter for xor_lsx_128b: a ^= b over len bytes. */
static void lsx_128b_glue(void * __restrict a, const void * __restrict b, size_t len)
{
	unsigned long nbytes = len;

	xor_lsx_128b(nbytes, a, b);
}
// | |
// LASX | |
// | |
static void | |
xor_lasx_32b(unsigned long bytes, void * __restrict p1, const void * __restrict p2) | |
{ | |
long lines = bytes / 32; | |
do { | |
asm volatile ( | |
"xvld $xr0, %[dst], 0\n\t" | |
"xvld $xr1, %[src], 0\n\t" | |
"xvxor.v $xr0, $xr0, $xr1\n\t" | |
"xvst $xr0, %[dst], 0\n\t" | |
: : [dst] "r"(p1), [src] "r"(p2) | |
: "memory" | |
); | |
p1 += 32; | |
p2 += 32; | |
} while (--lines > 0); | |
} | |
/* xor_impl_t adapter for xor_lasx_32b: a ^= b over len bytes. */
static void lasx_32b_glue(void * __restrict a, const void * __restrict b, size_t len)
{
	unsigned long nbytes = len;

	xor_lasx_32b(nbytes, a, b);
}
static void | |
xor_lasx_64b(unsigned long bytes, void * __restrict p1, const void * __restrict p2) | |
{ | |
long lines = bytes / 64; | |
do { | |
asm volatile ( | |
"xvld $xr0, %[dst], 0\n\t" | |
"xvld $xr1, %[dst], 32\n\t" | |
"xvld $xr2, %[src], 0\n\t" | |
"xvld $xr3, %[src], 32\n\t" | |
"xvxor.v $xr0, $xr0, $xr2\n\t" | |
"xvxor.v $xr1, $xr1, $xr3\n\t" | |
"xvst $xr0, %[dst], 0\n\t" | |
"xvst $xr1, %[dst], 32\n\t" | |
: : [dst] "r"(p1), [src] "r"(p2) | |
: "memory" | |
); | |
p1 += 64; | |
p2 += 64; | |
} while (--lines > 0); | |
} | |
/* xor_impl_t adapter for xor_lasx_64b: a ^= b over len bytes. */
static void lasx_64b_glue(void * __restrict a, const void * __restrict b, size_t len)
{
	unsigned long nbytes = len;

	xor_lasx_64b(nbytes, a, b);
}
static void | |
xor_lasx_128b(unsigned long bytes, void * __restrict p1, const void * __restrict p2) | |
{ | |
long lines = bytes / 128; | |
do { | |
asm volatile ( | |
"xvld $xr0, %[dst], 0\n\t" | |
"xvld $xr1, %[dst], 32\n\t" | |
"xvld $xr2, %[dst], 64\n\t" | |
"xvld $xr3, %[dst], 96\n\t" | |
"xvld $xr4, %[src], 0\n\t" | |
"xvld $xr5, %[src], 32\n\t" | |
"xvld $xr6, %[src], 64\n\t" | |
"xvld $xr7, %[src], 96\n\t" | |
"xvxor.v $xr0, $xr0, $xr4\n\t" | |
"xvxor.v $xr1, $xr1, $xr5\n\t" | |
"xvxor.v $xr2, $xr2, $xr6\n\t" | |
"xvxor.v $xr3, $xr3, $xr7\n\t" | |
"xvst $xr0, %[dst], 0\n\t" | |
"xvst $xr1, %[dst], 32\n\t" | |
"xvst $xr2, %[dst], 64\n\t" | |
"xvst $xr3, %[dst], 96\n\t" | |
: : [dst] "r"(p1), [src] "r"(p2) | |
: "memory" | |
); | |
p1 += 128; | |
p2 += 128; | |
} while (--lines > 0); | |
} | |
/* xor_impl_t adapter for xor_lasx_128b: a ^= b over len bytes. */
static void lasx_128b_glue(void * __restrict a, const void * __restrict b, size_t len)
{
	unsigned long nbytes = len;

	xor_lasx_128b(nbytes, a, b);
}
// | |
// helpers | |
// | |
/*
 * Fill buf with len bytes from getrandom(2), looping over short reads.
 * Calls interrupted by a signal (EINTR) are retried; any other failure
 * aborts the program.
 */
static void must_fill_randomness(void *buf, size_t len)
{
	unsigned char *p = buf;

	while (len) {
		ssize_t ret = getrandom(p, len, 0);

		if (ret < 0) {
			if (errno == EINTR)
				continue;
			abort();
		}
		p += ret;
		len -= (size_t)ret;
	}
}
/*
 * Return time1 - time0. When time1's nanosecond field is smaller than
 * time0's, one second is borrowed so the result's tv_nsec stays non-negative
 * (given inputs with tv_nsec in [0, 1e9)).
 */
static struct timespec diff_timespec(
	const struct timespec *time1,
	const struct timespec *time0)
{
	const long ns_per_sec = 1000000000;
	int borrow = time1->tv_nsec < time0->tv_nsec;
	struct timespec diff = {
		.tv_sec = time1->tv_sec - time0->tv_sec - borrow,
		.tv_nsec = time1->tv_nsec - time0->tv_nsec + (borrow ? ns_per_sec : 0),
	};

	return diff;
}
/*
 * Divide a duration by denom (expected to be > 0), truncating to whole
 * nanoseconds. Intended for non-negative durations.
 *
 * The total nanosecond count is accumulated in long long so that it cannot
 * overflow on targets where long is 32 bits.
 */
static struct timespec div_timespec(struct timespec x, int denom)
{
	const long long ns_per_sec = 1000000000LL;
	long long ns = (long long)x.tv_sec * ns_per_sec + x.tv_nsec;

	ns /= denom;
	struct timespec ret = {
		.tv_sec = (time_t)(ns / ns_per_sec),
		.tv_nsec = (long)(ns % ns_per_sec),
	};
	return ret;
}
/*
 * Throughput in bytes per second for processing `size` bytes `times` times
 * over the duration `elapsed`. The caller converts the result to MiB/s.
 *
 * Seconds are computed in floating point rather than as an integer
 * nanosecond count, which would overflow a 32-bit long after ~2 s.
 * NOTE(review): a zero `elapsed` yields an IEEE infinity, not an error.
 */
static double get_throughput(int size, struct timespec elapsed, int times)
{
	double secs = (double)elapsed.tv_sec + (double)elapsed.tv_nsec / 1e9;
	double total_size = (double)size * (double)times;

	return total_size / secs;
}
/*
 * Benchmark one XOR implementation on buffers of 2^order bytes.
 *
 * Two random buffers a and b are generated, and the expected result a ^ b is
 * computed once with reference_xor(). fn is then timed (per-thread CPU time)
 * over TIMES in-place passes of a ^= b. Because TIMES is odd, a must end up
 * equal to the reference. One result line is printed to stdout.
 *
 * Returns 0 if the result matched and 1 otherwise. Allocation, randomness
 * and clock failures abort().
 */
static int run_order(int order, const char *desc, xor_impl_t fn)
{
	void *a, *b, *ref;
	int size = 1 << order;
	struct timespec start, end, elapsed, pass_time;
	int i, ret;

	if (!(a = malloc(size)))
		abort();
	if (!(b = malloc(size)))
		abort();
	if (!(ref = malloc(size)))
		abort();

	must_fill_randomness(a, size);
	must_fill_randomness(b, size);
	memcpy(ref, a, size);
	reference_xor(ref, b, size);

	if (clock_gettime(CLOCK_THREAD_CPUTIME_ID, &start))
		abort();
	for (i = 0; i < TIMES; i++)
		fn(a, b, size);
	if (clock_gettime(CLOCK_THREAD_CPUTIME_ID, &end))
		abort();

	elapsed = diff_timespec(&end, &start);
	pass_time = div_timespec(elapsed, TIMES);
	ret = memcmp(a, ref, size) != 0;

	/* time_t is not necessarily long; print it through long long. */
	printf(
		"%-10s(size=%-7d) %s %d times: %lld.%09ld s total, %lld.%09ld s per pass, %.3lf MiB/s\n",
		desc,
		size,
		ret ? "failed" : "passed",
		TIMES,
		(long long)elapsed.tv_sec,
		elapsed.tv_nsec,
		(long long)pass_time.tv_sec,
		pass_time.tv_nsec,
		get_throughput(size, elapsed, TIMES) / 1048576.0
	);

	free(ref);
	free(b);
	free(a);
	return ret;
}
/*
 * Benchmark fn at every size order from DATA_SIZE_MIN_ORDER to
 * DATA_SIZE_MAX_ORDER inclusive. Returns nonzero if any size failed.
 */
static int try_all_orders(const char *desc, xor_impl_t fn)
{
	int failed = 0;

	for (int order = DATA_SIZE_MIN_ORDER; order <= DATA_SIZE_MAX_ORDER; ++order)
		failed |= run_order(order, desc, fn);

	return failed;
}
int main(int argc, const char *argv[]) | |
{ | |
int ret = 0; | |
ret |= try_all_orders("ref", reference_xor); | |
ret |= try_all_orders("lsx_32b", lsx_32b_glue); | |
ret |= try_all_orders("lsx_64b", lsx_64b_glue); | |
ret |= try_all_orders("lsx_128b", lsx_128b_glue); | |
ret |= try_all_orders("lasx_32b", lasx_32b_glue); | |
ret |= try_all_orders("lasx_64b", lasx_64b_glue); | |
ret |= try_all_orders("lasx_128b", lasx_128b_glue); | |
return ret; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment