Skip to content

Instantly share code, notes, and snippets.

@koturn
Created July 17, 2016 09:01
Show Gist options
  • Save koturn/43ac1695297b6ba88f4ea95091c7a114 to your computer and use it in GitHub Desktop.
Save koturn/43ac1695297b6ba88f4ea95091c7a114 to your computer and use it in GitHub Desktop.
Calculate inner product with SIMD instructions
#if defined(ENABLE_AVX) && !defined(__AVX__)
# error Macro: ENABLE_AVX is defined, but unable to use AVX intrinsic functions
#elif defined(ENABLE_SSE) && !defined(__SSE2__)
# error Macro: ENABLE_SSE is defined, but unable to use SSE intrinsic functions
#elif defined(ENABLE_NEON) && !defined(__ARM_NEON) && !defined(__ARM_NEON__)
# error Macro: ENABLE_NEON is defined, but unable to use NEON intrinsic functions
#else
#include <cstddef>
#include <iostream>
#include <memory>
#include <type_traits>
#if defined(_MSC_VER) || defined(__MINGW32__)
# include <malloc.h>
#else
# include <cstdlib>
#endif // defined(_MSC_VER) || defined(__MINGW32__)
#if defined(ENABLE_AVX) || defined(ENABLE_SSE)
# ifdef _MSC_VER
# include <intrin.h>
# else
# include <x86intrin.h>
# endif // _MSC_VER
#elif defined(ENABLE_NEON)
# include <arm_neon.h>
#endif // defined(ENABLE_AVX) || defined(ENABLE_SSE)
// Pre-C++11 fallback: emulate the `alignof` keyword with the compiler's own
// alignment operator (`__alignof` on MSVC, `__alignof__` on GCC/Clang).
#if defined(__cplusplus) && __cplusplus < 201103L
# ifdef _MSC_VER
# define alignof(n) __alignof(n)
# else
# define alignof(n) __alignof__(n)
# endif // _MSC_VER
#endif // defined(__cplusplus) && __cplusplus < 201103L
/**
 * @brief Allocate a block of memory whose address is a multiple of `alignment`.
 *
 * Uses `_aligned_malloc()` on MSVC / MinGW and `posix_memalign()` elsewhere.
 * The returned block must be released with alignedFree().
 *
 * @tparam T  Pointer type of the result (defaults to `void*`).
 * @param [in] size       Number of bytes to allocate.
 * @param [in] alignment  Requested alignment in bytes; must be a power of two.
 * @return Pointer to the allocated block, or `nullptr` if the allocation failed.
 */
template<typename T = void*, typename std::enable_if<std::is_pointer<T>::value, std::nullptr_t>::type = nullptr>
static inline T
alignedMalloc(std::size_t size, std::size_t alignment) noexcept
{
#if defined(_MSC_VER) || defined(__MINGW32__)
  return static_cast<T>(_aligned_malloc(size, alignment));
#else
  // posix_memalign() rejects alignments smaller than sizeof(void*), while
  // _aligned_malloc() accepts any power of two.  Raising a small power-of-two
  // alignment to sizeof(void*) still satisfies the caller's request (the larger
  // value is a multiple of the smaller one) and keeps both platforms consistent.
  const bool isPowerOfTwo = alignment != 0 && (alignment & (alignment - 1)) == 0;
  const std::size_t effectiveAlignment =
    (isPowerOfTwo && alignment < sizeof(void*)) ? sizeof(void*) : alignment;

  void* p = nullptr;
  if (posix_memalign(&p, effectiveAlignment, size) != 0) {
    return nullptr;
  }
  return static_cast<T>(p);
#endif // defined(_MSC_VER) || defined(__MINGW32__)
}
/**
 * @brief Release a block of memory through the platform's matching free routine.
 *
 * Calls `_aligned_free()` on MSVC / MinGW and `std::free()` everywhere else.
 *
 * @param [in] ptr  Pointer to the block to release.
 */
static inline void
alignedFree(void* ptr) noexcept
{
#if !defined(_MSC_VER) && !defined(__MINGW32__)
  std::free(ptr);
#else
  _aligned_free(ptr);
#endif // !defined(_MSC_VER) && !defined(__MINGW32__)
}
/**
 * @brief Deleter functor that forwards the pointer to alignedFree().
 *
 * Usable as the deleter type of a smart pointer such as `std::unique_ptr`.
 */
struct AlignedDeleter
{
  /**
   * @brief Release the given block via alignedFree().
   * @param [in] ptr  Pointer to the block to release.
   */
  void operator()(void* ptr) const noexcept { alignedFree(ptr); }
};
// Byte alignment used for the float buffers: the alignment of the vector
// register type of the selected SIMD instruction set, or 8 when no SIMD
// instruction set is enabled.
static constexpr int ALIGN =
#if defined(ENABLE_AVX)
  alignof(__m256);
#elif defined(ENABLE_SSE)
  alignof(__m128);
#elif defined(ENABLE_NEON)
  alignof(float32x4_t);
#else
  8;
#endif // defined(ENABLE_AVX)
/**
 * @brief Compute the inner (dot) product of two float arrays.
 *
 * The implementation is chosen at compile time:
 *   - ENABLE_AVX : 8 floats per step with AVX intrinsics
 *   - ENABLE_SSE : 4 floats per step with SSE intrinsics
 *   - ENABLE_NEON: 4 floats per step with NEON intrinsics
 *   - otherwise  : plain scalar loop
 *
 * In the SIMD variants only whole vectors are processed by the vector loop;
 * the remaining `n % INTERVAL` elements are accumulated by a scalar tail loop,
 * so no element beyond index `n - 1` is ever read.
 *
 * @param [in] a  First array; at least `n` elements.  The AVX / SSE variants
 *                use aligned loads, so the pointer must be aligned to ALIGN.
 * @param [in] b  Second array; same requirements as `a`.
 * @param [in] n  Number of elements to multiply and accumulate.
 * @return Sum of `a[i] * b[i]` for `i` in `[0, n)`; 0 when `n` is 0.
 */
static inline float
innerProduct(const float* a, const float* b, std::size_t n) noexcept
{
#if defined(ENABLE_AVX)
  static constexpr std::size_t INTERVAL = sizeof(__m256) / sizeof(float);
  // End of the range that can be covered with full 8-float vectors.
  const std::size_t simdEnd = n - n % INTERVAL;
  __m256 sumx8 = _mm256_setzero_ps();
  for (std::size_t i = 0; i < simdEnd; i += INTERVAL) {
    const __m256 ax8 = _mm256_load_ps(&a[i]);
    const __m256 bx8 = _mm256_load_ps(&b[i]);
    sumx8 = _mm256_add_ps(sumx8, _mm256_mul_ps(ax8, bx8));
  }
  // Spill the eight partial sums to memory and reduce them horizontally.
  alignas(ALIGN) float s[INTERVAL] = {0};
  _mm256_store_ps(s, sumx8);
  float sum = s[0] + s[1] + s[2] + s[3] + s[4] + s[5] + s[6] + s[7];
  // Scalar tail for the elements that do not fill a whole vector.
  for (std::size_t i = simdEnd; i < n; i++) {
    sum += a[i] * b[i];
  }
  return sum;
#elif defined(ENABLE_SSE)
  static constexpr std::size_t INTERVAL = sizeof(__m128) / sizeof(float);
  // End of the range that can be covered with full 4-float vectors.
  const std::size_t simdEnd = n - n % INTERVAL;
  __m128 sumx4 = _mm_setzero_ps();
  for (std::size_t i = 0; i < simdEnd; i += INTERVAL) {
    const __m128 ax4 = _mm_load_ps(&a[i]);
    const __m128 bx4 = _mm_load_ps(&b[i]);
    sumx4 = _mm_add_ps(sumx4, _mm_mul_ps(ax4, bx4));
  }
  // Spill the four partial sums to memory and reduce them horizontally.
  alignas(ALIGN) float s[INTERVAL] = {0};
  _mm_store_ps(s, sumx4);
  float sum = s[0] + s[1] + s[2] + s[3];
  // Scalar tail for the elements that do not fill a whole vector.
  for (std::size_t i = simdEnd; i < n; i++) {
    sum += a[i] * b[i];
  }
  return sum;
#elif defined(ENABLE_NEON)
  static constexpr std::size_t INTERVAL = sizeof(float32x4_t) / sizeof(float);
  // End of the range that can be covered with full 4-float vectors.
  const std::size_t simdEnd = n - n % INTERVAL;
  float32x4_t sumx4 = vdupq_n_f32(0.0f);
  for (std::size_t i = 0; i < simdEnd; i += INTERVAL) {
    const float32x4_t ax4 = vld1q_f32(&a[i]);
    const float32x4_t bx4 = vld1q_f32(&b[i]);
    sumx4 = vmlaq_f32(sumx4, ax4, bx4);
  }
  // Reduce the four lanes of the accumulator horizontally.
  float sum = vgetq_lane_f32(sumx4, 0) + vgetq_lane_f32(sumx4, 1)
            + vgetq_lane_f32(sumx4, 2) + vgetq_lane_f32(sumx4, 3);
  // Scalar tail for the elements that do not fill a whole vector.
  for (std::size_t i = simdEnd; i < n; i++) {
    sum += a[i] * b[i];
  }
  return sum;
#else
  float sum = 0.0f;
  for (std::size_t i = 0; i < n; i++) {
    sum += a[i] * b[i];
  }
  return sum;
#endif // defined(ENABLE_AVX)
}
int
main()
{
static constexpr int N_ELEMENT = 256;
std::unique_ptr<float[], AlignedDeleter> a(alignedMalloc<float*>(N_ELEMENT * sizeof(float), ALIGN));
std::unique_ptr<float[], AlignedDeleter> b(alignedMalloc<float*>(N_ELEMENT * sizeof(float), ALIGN));
if (a.get() == nullptr || b.get() == nullptr) {
std::cerr << "Failed to allocate memory" << std::endl;
return 1;
}
for (int i = 0; i < N_ELEMENT; i++) {
a[i] = static_cast<float>(i);
b[i] = static_cast<float>(i);
}
std::cout << innerProduct(a.get(), b.get(), N_ELEMENT) << std::endl;
return 0;
}
#endif // defined(ENABLE_AVX) && !defined(__AVX__)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment