Created
July 17, 2016 09:01
-
-
Save koturn/43ac1695297b6ba88f4ea95091c7a114 to your computer and use it in GitHub Desktop.
Calculate an inner product with SIMD instructions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#if defined(ENABLE_AVX) && !defined(__AVX__) | |
# error Macro: ENABLE_AVX is defined, but unable to use AVX intrinsic functions | |
#elif defined(ENABLE_SSE) && !defined(__SSE2__) | |
# error Macro: ENABLE_SSE is defined, but unable to use SSE intrinsic functions | |
#elif defined(ENABLE_NEON) && !defined(__ARM_NEON) && !defined(__ARM_NEON__) | |
# error Macro: ENABLE_NEON is defined, but unable to use NEON intrinsic functions | |
#else | |
#include <cstddef> | |
#include <iostream> | |
#include <memory> | |
#include <type_traits> | |
#if defined(_MSC_VER) || defined(__MINGW32__) | |
# include <malloc.h> | |
#else | |
# include <cstdlib> | |
#endif // defined(_MSC_VER) || defined(__MINGW32__) | |
#if defined(ENABLE_AVX) || defined(ENABLE_SSE) | |
# ifdef _MSC_VER | |
# include <intrin.h> | |
# else | |
# include <x86intrin.h> | |
# endif // _MSC_VER | |
#elif defined(ENABLE_NEON) | |
# include <arm_neon.h> | |
#endif // defined(ENABLE_AVX) || defined(ENABLE_SSE) | |
// Fallback spelling of the C++11 `alignof` operator for pre-C++11 compilers,
// mapped onto each compiler's extension keyword.
// MSVC keeps __cplusplus at 199711L unless /Zc:__cplusplus is given, so prefer
// _MSVC_LANG (the real language level) when the compiler provides it; otherwise
// the keyword would be needlessly macroized on C++11+ MSVC.
// NOTE(review): the rest of this file already requires C++11 (constexpr,
// noexcept, nullptr, alignas), so this only matters for partially-C++11 compilers.
#if defined(__cplusplus) && (defined(_MSVC_LANG) ? _MSVC_LANG : __cplusplus) < 201103L
#  ifdef _MSC_VER
#    define alignof(n) __alignof(n)
#  else
#    define alignof(n) __alignof__(n)
#  endif  // _MSC_VER
#endif  // defined(__cplusplus) && (C++ language level) < 201103L
/**
 * @brief Allocate @p size bytes whose address is a multiple of @p alignment.
 *
 * @tparam T  Pointer type to return (defaults to void*); non-pointer types are
 *            rejected at compile time via enable_if.
 * @param size       Number of bytes to allocate.
 * @param alignment  Requested alignment in bytes; must be a power of two.
 *                   On POSIX, values smaller than sizeof(void*) are rounded up
 *                   to sizeof(void*), which still satisfies the request.
 * @return Pointer to the new block, or nullptr on failure. Release it with
 *         alignedFree() (never with delete).
 */
template<typename T = void*, typename std::enable_if<std::is_pointer<T>::value, std::nullptr_t>::type = nullptr>
static inline T
alignedMalloc(std::size_t size, std::size_t alignment) noexcept
{
#if defined(_MSC_VER) || defined(__MINGW32__)
  return static_cast<T>(_aligned_malloc(size, alignment));
#else
  // posix_memalign() fails with EINVAL unless alignment is a power-of-two
  // multiple of sizeof(void*), whereas _aligned_malloc() accepts any power of
  // two. Round small alignments up so both platforms succeed for the same
  // inputs; over-aligning is always acceptable to the caller.
  if (alignment < sizeof(void*)) {
    alignment = sizeof(void*);
  }
  void* p = nullptr;
  return static_cast<T>(posix_memalign(&p, alignment, size) == 0 ? p : nullptr);
#endif  // defined(_MSC_VER) || defined(__MINGW32__)
}
/**
 * @brief Release a block previously returned by alignedMalloc().
 * @param ptr  Block to release; nullptr is accepted and ignored.
 */
static inline void
alignedFree(void* ptr) noexcept
{
  if (ptr == nullptr) {
    // Both _aligned_free() and free() do nothing for a null pointer.
    return;
  }
#if defined(_MSC_VER) || defined(__MINGW32__)
  _aligned_free(ptr);  // pairs with _aligned_malloc()
#else
  std::free(ptr);      // posix_memalign() memory is released with plain free()
#endif  // defined(_MSC_VER) || defined(__MINGW32__)
}
struct AlignedDeleter | |
{ | |
void | |
operator()(void* p) const noexcept | |
{ | |
alignedFree(p); | |
} | |
}; | |
// Alignment (in bytes) of the buffers handed to innerProduct(): the natural
// alignment of the active backend's vector register type, or 8 bytes when no
// SIMD backend is enabled.
#ifdef ENABLE_AVX
static constexpr int ALIGN = static_cast<int>(alignof(__m256));
#elif defined(ENABLE_SSE)
static constexpr int ALIGN = static_cast<int>(alignof(__m128));
#elif defined(ENABLE_NEON)
static constexpr int ALIGN = static_cast<int>(alignof(float32x4_t));
#else
static constexpr int ALIGN = 8;
#endif  // ENABLE_AVX / ENABLE_SSE / ENABLE_NEON
/**
 * @brief Compute the dot product sum(a[i] * b[i]) for i in [0, n).
 *
 * The enabled backend (AVX, SSE, NEON, or plain scalar) processes as many
 * whole vectors as fit in n, then a shared scalar loop handles the remaining
 * n % INTERVAL elements. Elements at index >= n are never read.
 *
 * @param a  First operand; with ENABLE_AVX / ENABLE_SSE it must be aligned to
 *           ALIGN bytes, because aligned vector loads are used.
 * @param b  Second operand; same alignment requirement as @p a.
 * @param n  Number of elements to read from each operand (0 yields 0).
 * @return The accumulated sum as a float.
 */
static inline float
innerProduct(const float* a, const float* b, std::size_t n) noexcept
{
  float sum = 0.0f;
  std::size_t i = 0;
#if defined(ENABLE_AVX)
  static constexpr std::size_t INTERVAL = sizeof(__m256) / sizeof(float);
  // Stop at the last whole vector: loading a partial vector would read past
  // a[n - 1] / b[n - 1], and the scalar tail below covers those elements.
  const std::size_t nVec = n - n % INTERVAL;
  __m256 sumx8 = _mm256_setzero_ps();
  for (; i < nVec; i += INTERVAL) {
    const __m256 ax8 = _mm256_load_ps(&a[i]);
    const __m256 bx8 = _mm256_load_ps(&b[i]);
    sumx8 = _mm256_add_ps(sumx8, _mm256_mul_ps(ax8, bx8));
  }
  alignas(ALIGN) float s[INTERVAL] = {};
  _mm256_store_ps(s, sumx8);
  sum = s[0] + s[1] + s[2] + s[3] + s[4] + s[5] + s[6] + s[7];
#elif defined(ENABLE_SSE)
  static constexpr std::size_t INTERVAL = sizeof(__m128) / sizeof(float);
  // Whole vectors only; see the AVX branch.
  const std::size_t nVec = n - n % INTERVAL;
  __m128 sumx4 = _mm_setzero_ps();
  for (; i < nVec; i += INTERVAL) {
    const __m128 ax4 = _mm_load_ps(&a[i]);
    const __m128 bx4 = _mm_load_ps(&b[i]);
    sumx4 = _mm_add_ps(sumx4, _mm_mul_ps(ax4, bx4));
  }
  alignas(ALIGN) float s[INTERVAL] = {};
  _mm_store_ps(s, sumx4);
  sum = s[0] + s[1] + s[2] + s[3];
#elif defined(ENABLE_NEON)
  static constexpr std::size_t INTERVAL = sizeof(float32x4_t) / sizeof(float);
  // Whole vectors only; see the AVX branch.
  const std::size_t nVec = n - n % INTERVAL;
  float32x4_t sumx4 = vdupq_n_f32(0.0f);
  for (; i < nVec; i += INTERVAL) {
    sumx4 = vmlaq_f32(sumx4, vld1q_f32(&a[i]), vld1q_f32(&b[i]));
  }
  // float32x4_t has no element members; extract the lanes explicitly
  // (vgetq_lane_f32 is available on both ARMv7 NEON and AArch64).
  sum = vgetq_lane_f32(sumx4, 0) + vgetq_lane_f32(sumx4, 1)
      + vgetq_lane_f32(sumx4, 2) + vgetq_lane_f32(sumx4, 3);
#endif  // defined(ENABLE_AVX)
  // Scalar tail for the SIMD backends, or the whole computation otherwise.
  for (; i < n; ++i) {
    sum += a[i] * b[i];
  }
  return sum;
}
int | |
main() | |
{ | |
static constexpr int N_ELEMENT = 256; | |
std::unique_ptr<float[], AlignedDeleter> a(alignedMalloc<float*>(N_ELEMENT * sizeof(float), ALIGN)); | |
std::unique_ptr<float[], AlignedDeleter> b(alignedMalloc<float*>(N_ELEMENT * sizeof(float), ALIGN)); | |
if (a.get() == nullptr || b.get() == nullptr) { | |
std::cerr << "Failed to allocate memory" << std::endl; | |
return 1; | |
} | |
for (int i = 0; i < N_ELEMENT; i++) { | |
a[i] = static_cast<float>(i); | |
b[i] = static_cast<float>(i); | |
} | |
std::cout << innerProduct(a.get(), b.get(), N_ELEMENT) << std::endl; | |
return 0; | |
} | |
#endif // defined(ENABLE_AVX) && !defined(__AVX__) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment