Last active
March 27, 2019 20:38
-
-
Save Barakat/60e2b3301accc863ea72e9ab7075689a to your computer and use it in GitHub Desktop.
Optimized dot product using SSE and AVX
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "pch.h"

#include <immintrin.h>
#include <xmmintrin.h>

#include <chrono>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <memory>
#include <random>
// A pair of 4-component float vectors stored interleaved per component: the `0` members
// (x0, y0, z0, w0) form the first operand and the `1` members (x1, y1, z1, w1) the second.
// The member order fixes the in-memory layout (x0, x1, y0, y1, z0, z1, w0, w1).
struct float4x2
{
    float x0, x1; // x component of the first / second vector
    float y0, y1; // y component of the first / second vector
    float z0, z1; // z component of the first / second vector
    float w0, w1; // w component of the first / second vector
};
void float4x2_dot_product(const float4x2* vectors, float* dots, const std::size_t count) | |
{ | |
for (std::size_t i = 0; i < count; ++i) | |
{ | |
const auto v = &vectors[i]; // shorthand | |
dots[i] = v->x0 * v->x1 + v->y0 * v->y1 + v->z0 * v->z1 + v->w0 * v->w1; | |
} | |
} | |
// Fills every component of the first `count` entries of `vectors` with pseudo-random
// floats drawn uniformly from [0, 10000).
//
// The Mersenne Twister engine is seeded from std::random_device, so the generated data
// differs from run to run.
//
// vectors: output array with room for `count` entries.
// count:   number of entries to fill.
void fill_float4_with_random_data(float4x2* vectors, const std::size_t count)
{
    std::random_device random_device;
    std::mt19937 generator(random_device());

    // Deliberately NOT const: uniform_real_distribution::operator() is a non-const member
    // function, so drawing from a const distribution is rejected by libstdc++ and libc++.
    std::uniform_real_distribution<float> distribution(0.0f, 10000.0f);

    for (std::size_t i = 0; i < count; ++i)
    {
        float4x2& v = vectors[i];
        // One draw per component, in declaration order.
        v.x0 = distribution(generator);
        v.x1 = distribution(generator);
        v.y0 = distribution(generator);
        v.y1 = distribution(generator);
        v.z0 = distribution(generator);
        v.z1 = distribution(generator);
        v.w0 = distribution(generator);
        v.w1 = distribution(generator);
    }
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
// SSE counterpart of float4x2: every member is a 128-bit register holding the same-named
// component of four consecutive vector pairs, one pair per float lane.
// The member order fixes the in-memory layout (x0, x1, y0, y1, z0, z1, w0, w1).
struct m1284x2
{
    __m128 x0, x1; // x component of the first / second vector, 4 lanes each
    __m128 y0, y1; // y component of the first / second vector
    __m128 z0, z1; // z component of the first / second vector
    __m128 w0, w1; // w component of the first / second vector
};
void m1284x2_dot_product(const m1284x2* vectors, __m128* dots, const std::size_t count) | |
{ | |
for (std::size_t i = 0; i < count; ++i) | |
{ | |
const auto v = &vectors[i]; // shorthand | |
const auto x = _mm_mul_ps(v->x0, v->x1); | |
const auto y = _mm_mul_ps(v->y0, v->y1); | |
const auto z = _mm_mul_ps(v->z0, v->z1); | |
const auto w = _mm_mul_ps(v->w0, v->w1); | |
const auto xy = _mm_add_ps(x, y); | |
const auto zw = _mm_add_ps(z, w); | |
dots[i] = _mm_add_ps(xy, zw); | |
} | |
} | |
// Repacks `count` scalar pairs from `b` into the SSE layout of `a`: lane j of a[g].F
// receives b[4 * g + j].F for every field F.
//
// a:     output array; ceil(count / 4) groups are written.
// b:     input array of `count` scalar pairs.
// count: number of scalar pairs in `b`. When it is not a multiple of 4, the unused lanes
//        of the last group are set to 0 instead of being read from past the end of `b`.
void copy_from_float4_to_m1284x2(m1284x2* a, const float4x2* b, const std::size_t count)
{
    constexpr std::size_t lanes = 4;

    // Collects one field from up to `lanes` consecutive source entries starting at `first`.
    const auto gather = [b, count](const std::size_t first, float float4x2::* field)
    {
        // _mm_load_ps requires a 16-byte aligned source; alignas is the portable spelling.
        alignas(16) float lane_values[lanes] = {};
        for (std::size_t lane = 0; lane < lanes && first + lane < count; ++lane)
        {
            lane_values[lane] = b[first + lane].*field;
        }
        return _mm_load_ps(lane_values);
    };

    for (std::size_t first = 0; first < count; first += lanes)
    {
        m1284x2& packed = a[first / lanes];
        packed.x0 = gather(first, &float4x2::x0);
        packed.x1 = gather(first, &float4x2::x1);
        packed.y0 = gather(first, &float4x2::y0);
        packed.y1 = gather(first, &float4x2::y1);
        packed.z0 = gather(first, &float4x2::z0);
        packed.z1 = gather(first, &float4x2::z1);
        packed.w0 = gather(first, &float4x2::w0);
        packed.w1 = gather(first, &float4x2::w1);
    }
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
// AVX counterpart of float4x2: every member is a 256-bit register holding the same-named
// component of eight consecutive vector pairs, one pair per float lane.
// The member order fixes the in-memory layout (x0, x1, y0, y1, z0, z1, w0, w1).
struct m2564x2
{
    __m256 x0, x1; // x component of the first / second vector, 8 lanes each
    __m256 y0, y1; // y component of the first / second vector
    __m256 z0, z1; // z component of the first / second vector
    __m256 w0, w1; // w component of the first / second vector
};
void m2564x2_dot_product(const m2564x2* vectors, __m256* dots, const std::size_t count) | |
{ | |
for (std::size_t i = 0; i < count; ++i) | |
{ | |
const auto v = &vectors[i]; // shorthand | |
const auto x = _mm256_mul_ps(v->x0, v->x1); | |
const auto y = _mm256_mul_ps(v->y0, v->y1); | |
const auto z = _mm256_mul_ps(v->z0, v->z1); | |
const auto w = _mm256_mul_ps(v->w0, v->w1); | |
const auto xy = _mm256_add_ps(x, y); | |
const auto zw = _mm256_add_ps(z, w); | |
dots[i] = _mm256_add_ps(xy, zw); | |
} | |
} | |
// Repacks `count` scalar pairs from `b` into the AVX layout of `a`: lane j of a[g].F
// receives b[8 * g + j].F for every field F.
//
// a:     output array; ceil(count / 8) groups are written.
// b:     input array of `count` scalar pairs.
// count: number of scalar pairs in `b`. When it is not a multiple of 8, the unused lanes
//        of the last group are set to 0 instead of being read from past the end of `b`.
void copy_from_float4_to_m2564x2(m2564x2* a, const float4x2* b, const std::size_t count)
{
    constexpr std::size_t lanes = 8;

    // Collects one field from up to `lanes` consecutive source entries starting at `first`.
    const auto gather = [b, count](const std::size_t first, float float4x2::* field)
    {
        // _mm256_load_ps requires a 32-byte aligned source; alignas is the portable spelling.
        alignas(32) float lane_values[lanes] = {};
        for (std::size_t lane = 0; lane < lanes && first + lane < count; ++lane)
        {
            lane_values[lane] = b[first + lane].*field;
        }
        return _mm256_load_ps(lane_values);
    };

    for (std::size_t first = 0; first < count; first += lanes)
    {
        m2564x2& packed = a[first / lanes];
        packed.x0 = gather(first, &float4x2::x0);
        packed.x1 = gather(first, &float4x2::x1);
        packed.y0 = gather(first, &float4x2::y0);
        packed.y1 = gather(first, &float4x2::y1);
        packed.z0 = gather(first, &float4x2::z0);
        packed.z1 = gather(first, &float4x2::z1);
        packed.w0 = gather(first, &float4x2::w0);
        packed.w1 = gather(first, &float4x2::w1);
    }
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
// Calls `function` `repeat` times back to back and returns the mean wall-clock time of a
// single call in MICROSECONDS, measured with std::chrono::steady_clock.
//
// function: callable to time; it is invoked exactly `repeat` times.
// repeat:   number of invocations. When 0, nothing is run and 0 is returned (instead of
//           the NaN a 0 / 0 division would produce).
float benchmark(const std::function<void()>& function, const std::size_t repeat)
{
    if (repeat == 0)
    {
        return 0.0f;
    }

    const auto start_time = std::chrono::steady_clock::now();
    for (std::size_t i = 0; i < repeat; ++i)
    {
        function();
    }
    const auto end_time = std::chrono::steady_clock::now();

    // A floating-point duration keeps the sub-microsecond part that an integral
    // duration_cast<microseconds> would truncate before the division.
    const std::chrono::duration<double, std::micro> elapsed = end_time - start_time;
    return static_cast<float>(elapsed.count() / static_cast<double>(repeat));
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// | |
int main() | |
{ | |
const std::size_t size = 4096; | |
static_assert(size % 8 == 0); | |
const auto v_float4x2 = std::make_unique<float4x2[]>(size); | |
fill_float4_with_random_data(v_float4x2.get(), size); | |
const auto v_m1284x2 = std::make_unique<m1284x2[]>(size / 4); | |
copy_from_float4_to_m1284x2(v_m1284x2.get(), v_float4x2.get(), size); | |
const auto v_m2564x2 = std::make_unique<m2564x2[]>(size / 8); | |
copy_from_float4_to_m2564x2(v_m2564x2.get(), v_float4x2.get(), size); | |
const auto dots_float = std::make_unique<float[]>(size); | |
const auto dots_m128 = std::make_unique<__m128[]>(size / 4); | |
const auto dots_m256 = std::make_unique<__m256[]>(size / 8); | |
static const std::size_t repeat = 1000; | |
const auto r0_mean = benchmark([&]() | |
{ | |
float4x2_dot_product(v_float4x2.get(), dots_float.get(), size); | |
}, repeat); | |
const auto r1_mean = benchmark([&]() | |
{ | |
m1284x2_dot_product(v_m1284x2.get(), dots_m128.get(), size / 4); | |
}, repeat); | |
const auto r2_mean = benchmark([&]() | |
{ | |
m2564x2_dot_product(v_m2564x2.get(), dots_m256.get(), size / 8); | |
}, repeat); | |
std::printf("float: %0.2f ms\n", r0_mean); | |
std::printf(" SSE: %0.2f ms (speed up x%0.2f)\n", r1_mean, r0_mean / r1_mean); | |
std::printf(" AVX: %0.2f ms (speed up x%0.2f)\n", r2_mean, r0_mean / r2_mean); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment