Skip to content

Instantly share code, notes, and snippets.

@lhecker
Created May 13, 2014 15:14
Show Gist options
  • Save lhecker/d9cfb586c52a63e6e4ae to your computer and use it in GitHub Desktop.
Save lhecker/d9cfb586c52a63e6e4ae to your computer and use it in GitHub Desktop.
Optimized version of bytesum_intrinsics.c from http://jvns.ca/blog/2014/05/12/computers-are-fast/
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <time.h>
#include <emmintrin.h>
/*
 * Round n up to the nearest multiple of m that is greater than or equal to n
 * (n is returned unchanged when it already is a multiple of m).
 * m needs to be a power of 2.
 *
 * Notes:
 *  - m is evaluated twice, so do not pass an expression with side effects.
 *  - The mask ~((m) - 1) is computed in the type of m. Pass m as a plain
 *    (signed) integer constant or in a type at least as wide as n; an
 *    unsigned m narrower than n would clear the upper bits of the result.
 */
#define round_up(n, m) (((n) + (m) - 1) & ~((m) - 1))
/*
 * Round n down to the nearest multiple of m that is less than or equal to n
 * (n is returned unchanged when it already is a multiple of m).
 * m needs to be a power of 2.
 *
 * The mask ~((m) - 1) is computed in the type of m. Pass m as a plain
 * (signed) integer constant or in a type at least as wide as n; an
 * unsigned m narrower than n would clear the upper bits of the result.
 */
#define round_down(n, m) ((n) & ~((m) - 1))
/*
 * A view onto a run of bytes handed to the summing functions.
 */
struct buf_s {
    char *base;  /* first byte of the data */
    size_t size; /* number of data bytes starting at base */
};

typedef struct buf_s buf_t;

/*
 * Signature shared by all byte-sum implementations under test:
 * takes the buffer by value and returns an 8 bit sum.
 */
typedef uint8_t (*test_fn_t)(buf_t buf);
/*
 * test(buf, fn) runs fn on buf and prints fn's name, its result and the
 * elapsed clock() ticks. The macro only exists to stringify the function name.
 */
#define test(buf, test_fn) test_(buf, #test_fn, test_fn)

/*
 * Times a single call of test_fn(buf) with clock() and prints one line:
 *   <name>() --> <result> --- <ticks> ticks
 *
 * buf          - data passed by value to test_fn
 * test_fn_name - name to print; const because the test() macro always passes
 *                a string literal
 * test_fn      - implementation to run
 */
void test_(buf_t buf, const char *test_fn_name, test_fn_t test_fn) {
    const clock_t tic = clock();
    const uint8_t res = test_fn(buf);
    const clock_t toc = clock();

    /* res is promoted to int in a variadic call; cast so it matches %u. */
    printf("%s() --> %u --- %u ticks\n", test_fn_name, (unsigned int)res, (unsigned int)(toc - tic));
}
/*
 * plain C function by Julia Evans
 *
 * Scalar reference implementation: adds up every byte of buf one at a time.
 * The accumulator is 8 bits wide, so the result is the byte sum modulo 256.
 */
uint8_t sum_array_c_jvns(buf_t buf) {
    const char *cursor = buf.base;
    size_t remaining = buf.size;
    uint8_t total = 0;

    while (remaining != 0) {
        total += *cursor;
        ++cursor;
        --remaining;
    }
    return total;
}
/*
 * SSE function by Julia Evans
 *
 * Returns the sum of all bytes of buf modulo 256.
 *
 * Whole 16-byte chunks are widened to 16 bit lanes and accumulated into four
 * 32 bit partial sums; the remaining (size % 16) bytes are added one by one.
 * The original loop stepped by 16 until i >= size, so for sizes that are not
 * a multiple of 16 it summed up to 15 bytes beyond the data. It also used an
 * aligned load directly on buf.base; an unaligned load is used here so the
 * function does not depend on how the caller allocated the buffer.
 */
uint8_t sum_array_sse_jvns(buf_t buf) {
    const __m128i vk0 = _mm_setzero_si128(); // all 0s, used to zero-extend bytes via unpack
    const __m128i vk1 = _mm_set1_epi16(1);   // all 1s, turns _mm_madd_epi16 into a pairwise add
    __m128i vsum = _mm_setzero_si128();      // four partial 32 bit sums
    size_t i = 0;

    // i never exceeds buf.size, so the subtraction cannot wrap.
    for (; buf.size - i >= 16; i += 16) {
        __m128i v = _mm_loadu_si128((const __m128i *)(buf.base + i)); // 16 bytes, any alignment
        __m128i vl = _mm_unpacklo_epi8(v, vk0); // low 8 bytes as 16 bit values
        __m128i vh = _mm_unpackhi_epi8(v, vk0); // high 8 bytes as 16 bit values
        // add neighbouring 16 bit values and accumulate into the 32 bit partial sums
        vsum = _mm_add_epi32(vsum, _mm_madd_epi16(vl, vk1));
        vsum = _mm_add_epi32(vsum, _mm_madd_epi16(vh, vk1));
    }

    // horizontal add of the four 32 bit partial sums
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 8));
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 4));
    uint32_t sum = (uint32_t)_mm_cvtsi128_si32(vsum);

    // bytes that did not fill a whole 16-byte chunk
    for (; i < buf.size; i++) {
        sum += (uint8_t)buf.base[i];
    }
    return (uint8_t)sum;
}
/*
 * optimized SSE function by Leonard Hecker
 *
 * Returns the sum of all bytes of buf modulo 256.
 *
 * Layout of the work:
 *   1. scalar prologue up to the first 16-byte boundary,
 *   2. vector loop over every whole 32-byte block after that boundary,
 *   3. scalar epilogue for whatever is left (fewer than 32 bytes).
 * Buffers too small to hold the prologue plus one 32-byte block are summed
 * entirely by the scalar epilogue.
 */
uint8_t sum_array_sse_hecker(buf_t buf) {
    uint8_t sum = 0;
    size_t i = 0;
    /* Bytes in front of the first 16-byte boundary (0..15). */
    const size_t unaligned_bytes = round_up((uintptr_t)buf.base, 16) - (uintptr_t)buf.base;

    if (buf.size >= unaligned_bytes + 32) {
        /*
         * buf.base might be unaligned (i.e. not aligned on a 16 byte boundary
         * as required by _mm_load_si128), so consume the leading bytes first.
         */
        for (; i < unaligned_bytes; i++) {
            sum += buf.base[i];
        }

        /*
         * Add 32 bytes of data per iteration, bytewise and in parallel,
         * into 2 SSE vectors.
         */
        __m128i vsum1 = _mm_setzero_si128();
        __m128i vsum2 = _mm_setzero_si128();

        /*
         * The block count is taken relative to the aligned start, so every
         * whole 32-byte block is handled here and none is left to the
         * scalar epilogue.
         */
        const size_t end = i + round_down(buf.size - i, 32);

        for (; i < end; i += 32) {
            vsum1 = _mm_add_epi8(vsum1, _mm_load_si128((const __m128i *)(buf.base + i)));
            vsum2 = _mm_add_epi8(vsum2, _mm_load_si128((const __m128i *)(buf.base + i + 16)));
        }

        /*
         * Horizontally add by halving the "width" of vsum1, starting with
         * 16 bytes down to 1 byte. Each step shifts the upper half right
         * (towards the low bytes) and adds it to the lower half.
         */
        vsum1 = _mm_add_epi8(vsum1, vsum2);
        vsum1 = _mm_add_epi8(vsum1, _mm_srli_si128(vsum1, 8));
        vsum1 = _mm_add_epi8(vsum1, _mm_srli_si128(vsum1, 4));
        vsum1 = _mm_add_epi8(vsum1, _mm_srli_si128(vsum1, 2));
        vsum1 = _mm_add_epi8(vsum1, _mm_srli_si128(vsum1, 1));

        /* The total now sits in the lowest byte of vsum1. */
        sum += (uint8_t)_mm_cvtsi128_si32(vsum1);
    }

    /*
     * Add all remaining bytes.
     */
    for (; i < buf.size; i++) {
        sum += buf.base[i];
    }
    return sum;
}
/*
 * Loads the file named by argv[1] into memory, prints how many clock() ticks
 * loading took, then runs and times every byte-sum implementation on it.
 *
 * Returns 0 on success and 1 on a usage, I/O or allocation error.
 */
int main(int argc, char *argv[]) {
    buf_t buf = { NULL, 0 };
    FILE *f = NULL;
    long file_size = -1;
    size_t got = 0;
    clock_t tic = 0;
    clock_t toc = 0;
    int rc = 1;

    if (argc < 2) {
        printf("Too few arguments.\n");
        return 1;
    }

    f = fopen(argv[1], "rb");
    if (!f) {
        printf("Could not open %s\n", argv[1]);
        return 1;
    }

    tic = clock();

    /*
     * Determine the file size. ftell() returns -1 on failure (e.g. for a
     * non-seekable stream); converted to size_t that would make the
     * allocation size below wrap around, so bail out instead.
     */
    if (fseek(f, 0, SEEK_END) == 0) {
        file_size = ftell(f);
    }
    if (file_size < 0 || fseek(f, 0, SEEK_SET) != 0) {
        printf("Could not determine the size of %s\n", argv[1]);
        goto out;
    }
    buf.size = (size_t)file_size;

    /* 16 bytes of padding so a reader working in 16-byte steps stays inside the allocation. */
    buf.base = (char *)malloc(buf.size + 16);
    if (!buf.base) {
        printf("Could not allocate %zu bytes.\n", buf.size);
        goto out;
    }

    got = fread(buf.base, 1, buf.size, f);
    if (got != buf.size) {
        printf("Could not read %zu bytes. Only got %zu bytes.\n", buf.size, got);
        goto out;
    }

    toc = clock();
    printf("fread() --> %zu --- %u ticks\n", buf.size, (unsigned int)(toc - tic));

    /* Zero the padding so bytes read past the data cannot change a sum. */
    for (size_t k = 0; k < 16; k++) {
        buf.base[buf.size + k] = 0;
    }

    test(buf, sum_array_c_jvns);
    test(buf, sum_array_sse_jvns);
    test(buf, sum_array_sse_hecker);
    rc = 0;

out:
    /* Single exit path: release the buffer and the file on success and on error. */
    free(buf.base);
    fclose(f);
    return rc;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment