// Gist: ShigekiKarita/15092e7f2c3f96ba007a336dd11f36b3
// SSE,AVX組み込み関数を用いたベクトルの内積計算高速化の実験コード
// (Experimental code for speeding up vector dot products with SSE/AVX intrinsics.)
// origin https://gist.githubusercontent.com/belltailjp/4653695/raw/1cf8b5cbb6c3b4d4f9374b8b1ccae702867543ef/simd.cpp
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>
#include <xmmintrin.h>
#include <immintrin.h>
// #include <boost/format.hpp>
// #include <osakana/stopwatch.hpp>
/// Scalar reference dot product.
///
/// Returns vec1[0]*vec2[0] + vec1[1]*vec2[1] + ... + vec1[n-1]*vec2[n-1],
/// accumulated strictly left to right in type T (starting from T(0)).
/// The SIMD kernels are benchmarked and checked against this baseline.
///
/// @param vec1 first operand, at least n elements
/// @param vec2 second operand, at least n elements
/// @param n    number of elements (0 yields 0)
template <typename T>
T dot_normal(const T *vec1, const T *vec2, unsigned n)
{
    T acc = 0;
    const T *lhs = vec1;
    const T *rhs = vec2;
    const T *const lhs_end = vec1 + n;
    while (lhs != lhs_end)
    {
        acc += *lhs * *rhs;
        ++lhs;
        ++rhs;
    }
    return acc;
}
/// Dot product of two float arrays using 128-bit SSE intrinsics.
///
/// Four floats are multiplied and accumulated per iteration in one vector
/// accumulator. The four lanes are then summed, and the remaining n % 4
/// products are added with scalar code. Any n (including 0) is therefore
/// handled without reading past the end of either array.
///
/// Unaligned loads are used, so vec1/vec2 need no particular alignment. On
/// 16-byte aligned data they are typically no slower than aligned loads on
/// recent x86 cores.
///
/// @param vec1 first operand, at least n elements
/// @param vec2 second operand, at least n elements
/// @param n    number of elements
/// @return sum of vec1[i] * vec2[i] for i in [0, n)
float dot_sse(const float *vec1, const float *vec2, unsigned n)
{
    // Elements covered by full 4-wide steps (written this way to avoid
    // the i + 4 overflow a "i + 4 <= n" condition would have near UINT_MAX).
    const unsigned n_vec = n - n % 4;
    __m128 u = _mm_setzero_ps();
    for (unsigned i = 0; i < n_vec; i += 4)
    {
        const __m128 w = _mm_loadu_ps(&vec1[i]);
        const __m128 x = _mm_loadu_ps(&vec2[i]);
        u = _mm_add_ps(u, _mm_mul_ps(w, x));
    }
    // Spill the accumulator and reduce its lanes left to right.
    alignas(16) float t[4];
    _mm_store_ps(t, u);
    float sum = t[0] + t[1] + t[2] + t[3];
    // Scalar tail: at most 3 elements.
    for (unsigned i = n_vec; i < n; ++i)
        sum += vec1[i] * vec2[i];
    return sum;
}
/// Dot product of two float arrays using 256-bit AVX intrinsics.
///
/// Eight floats are multiplied and accumulated per iteration in one vector
/// accumulator. The eight lanes are then summed, and the remaining n % 8
/// products are added with scalar code, so any n (including 0) is safe.
/// Unaligned loads are used, so no alignment is required of vec1/vec2.
///
/// The function is compiled for AVX via target("avx"), so the file builds
/// without a global -mavx. Calling it on a CPU without AVX is still invalid;
/// callers may check __builtin_cpu_supports("avx") first.
///
/// @param vec1 first operand, at least n elements
/// @param vec2 second operand, at least n elements
/// @param n    number of elements
/// @return sum of vec1[i] * vec2[i] for i in [0, n)
__attribute__((target("avx")))
float dot_avx(const float *vec1, const float *vec2, unsigned n)
{
    const unsigned n_vec = n - n % 8;  // elements covered by full 8-wide steps
    __m256 u = _mm256_setzero_ps();
    for(unsigned i = 0; i < n_vec; i += 8)
    {
        const __m256 w = _mm256_loadu_ps(&vec1[i]);
        const __m256 x = _mm256_loadu_ps(&vec2[i]);
        u = _mm256_add_ps(u, _mm256_mul_ps(w, x));
    }
    // Spill the accumulator and reduce its lanes left to right.
    alignas(32) float t[8];
    _mm256_store_ps(t, u);
    float sum = t[0] + t[1] + t[2] + t[3] + t[4] + t[5] + t[6] + t[7];
    // Scalar tail: at most 7 elements.
    for(unsigned i = n_vec; i < n; ++i)
        sum += vec1[i] * vec2[i];
    return sum;
}
/// Dot product of two float arrays using AVX with two independent
/// accumulators (16 floats per iteration).
///
/// Two accumulators hide some of the add latency. After the 16-wide loop,
/// one extra 8-wide step is taken if at least 8 elements remain. The last
/// n % 8 products are added with scalar code. Any n (including 0) is safe,
/// and unaligned loads mean no alignment is required of vec1/vec2.
///
/// The function is compiled for AVX via target("avx"); calling it on a CPU
/// without AVX is still invalid.
///
/// @param vec1 first operand, at least n elements
/// @param vec2 second operand, at least n elements
/// @param n    number of elements
/// @return sum of vec1[i] * vec2[i] for i in [0, n)
__attribute__((target("avx")))
float dot_avx_2(const float *vec1, const float *vec2, unsigned n)
{
    const unsigned n16 = n - n % 16;  // elements covered by full 16-wide steps
    __m256 u1 = _mm256_setzero_ps();
    __m256 u2 = _mm256_setzero_ps();
    unsigned i = 0;
    for(; i < n16; i += 16)
    {
        const __m256 w1 = _mm256_loadu_ps(&vec1[i]);
        const __m256 w2 = _mm256_loadu_ps(&vec1[i + 8]);
        const __m256 x1 = _mm256_loadu_ps(&vec2[i]);
        const __m256 x2 = _mm256_loadu_ps(&vec2[i + 8]);
        u1 = _mm256_add_ps(u1, _mm256_mul_ps(w1, x1));
        u2 = _mm256_add_ps(u2, _mm256_mul_ps(w2, x2));
    }
    // One leftover 8-wide chunk, if any (i <= n, so n - i cannot wrap).
    if(n - i >= 8)
    {
        const __m256 w = _mm256_loadu_ps(&vec1[i]);
        const __m256 x = _mm256_loadu_ps(&vec2[i]);
        u1 = _mm256_add_ps(u1, _mm256_mul_ps(w, x));
        i += 8;
    }
    u1 = _mm256_add_ps(u1, u2);
    // Spill the combined accumulator and reduce its lanes left to right.
    alignas(32) float t[8];
    _mm256_store_ps(t, u1);
    float sum = t[0] + t[1] + t[2] + t[3] + t[4] + t[5] + t[6] + t[7];
    // Scalar tail: at most 7 elements.
    for(; i < n; ++i)
        sum += vec1[i] * vec2[i];
    return sum;
}
// FMA version (float).
/// Dot product of two float arrays using AVX + FMA3 fused multiply-add,
/// with two independent accumulators (16 floats per iteration).
///
/// After the 16-wide loop, one extra 8-wide step is taken if at least 8
/// elements remain. The last n % 8 products are added with scalar code, so
/// any n (including 0) is safe. Unaligned loads mean no alignment is
/// required of vec1/vec2.
///
/// The function is compiled for AVX and FMA via target("avx,fma"), so the
/// file builds without global -mavx -mfma. FMA3 needs Haswell or a later
/// CPU, and calling this on a CPU without it is invalid.
///
/// @param vec1 first operand, at least n elements
/// @param vec2 second operand, at least n elements
/// @param n    number of elements
/// @return sum of vec1[i] * vec2[i] for i in [0, n)
__attribute__((target("avx,fma")))
float dot_avx_fma(const float *vec1, const float *vec2, unsigned n)
{
    const unsigned n16 = n - n % 16;  // elements covered by full 16-wide steps
    __m256 u1 = _mm256_setzero_ps();
    __m256 u2 = _mm256_setzero_ps();
    unsigned i = 0;
    for(; i < n16; i += 16)
    {
        const __m256 w1 = _mm256_loadu_ps(&vec1[i]);
        const __m256 w2 = _mm256_loadu_ps(&vec1[i + 8]);
        const __m256 x1 = _mm256_loadu_ps(&vec2[i]);
        const __m256 x2 = _mm256_loadu_ps(&vec2[i + 8]);
        // Multiply and add in a single instruction: u += w * x.
        u1 = _mm256_fmadd_ps(w1, x1, u1);
        u2 = _mm256_fmadd_ps(w2, x2, u2);
    }
    // One leftover 8-wide chunk, if any (i <= n, so n - i cannot wrap).
    if(n - i >= 8)
    {
        u1 = _mm256_fmadd_ps(_mm256_loadu_ps(&vec1[i]), _mm256_loadu_ps(&vec2[i]), u1);
        i += 8;
    }
    u1 = _mm256_add_ps(u1, u2);
    // Write the register back to memory and reduce its lanes left to right.
    alignas(32) float t[8];
    _mm256_store_ps(t, u1);
    float sum = t[0] + t[1] + t[2] + t[3] + t[4] + t[5] + t[6] + t[7];
    // Scalar tail: at most 7 elements.
    for(; i < n; ++i)
        sum += vec1[i] * vec2[i];
    return sum;
}
// FMA version (double).
/// Dot product of two double arrays using AVX + FMA3 fused multiply-add,
/// with two independent accumulators (8 doubles per iteration).
///
/// After the 8-wide loop, one extra 4-wide step is taken if at least 4
/// elements remain. The last n % 4 products are added with scalar code, so
/// any n (including 0) is safe. Unaligned loads mean no alignment is
/// required of vec1/vec2.
///
/// The function is compiled for AVX and FMA via target("avx,fma"), so the
/// file builds without global -mavx -mfma. FMA3 needs Haswell or a later
/// CPU, and calling this on a CPU without it is invalid.
///
/// @param vec1 first operand, at least n elements
/// @param vec2 second operand, at least n elements
/// @param n    number of elements
/// @return sum of vec1[i] * vec2[i] for i in [0, n)
__attribute__((target("avx,fma")))
double dot_avx_fma(const double *vec1, const double *vec2, unsigned n)
{
    const unsigned n8 = n - n % 8;  // elements covered by full 8-wide steps
    __m256d u1 = _mm256_setzero_pd();
    __m256d u2 = _mm256_setzero_pd();
    unsigned i = 0;
    for(; i < n8; i += 8)
    {
        const __m256d w1 = _mm256_loadu_pd(&vec1[i]);
        const __m256d w2 = _mm256_loadu_pd(&vec1[i + 4]);
        const __m256d x1 = _mm256_loadu_pd(&vec2[i]);
        const __m256d x2 = _mm256_loadu_pd(&vec2[i + 4]);
        // Multiply and add in a single instruction: u += w * x.
        u1 = _mm256_fmadd_pd(w1, x1, u1);
        u2 = _mm256_fmadd_pd(w2, x2, u2);
    }
    // One leftover 4-wide chunk, if any (i <= n, so n - i cannot wrap).
    if(n - i >= 4)
    {
        u1 = _mm256_fmadd_pd(_mm256_loadu_pd(&vec1[i]), _mm256_loadu_pd(&vec2[i]), u1);
        i += 4;
    }
    u1 = _mm256_add_pd(u1, u2);
    // Write the register back to memory and reduce its lanes left to right.
    alignas(32) double t[4];
    _mm256_store_pd(t, u1);
    double sum = t[0] + t[1] + t[2] + t[3];
    // Scalar tail: at most 3 elements.
    for(; i < n; ++i)
        sum += vec1[i] * vec2[i];
    return sum;
}
/*
// AVX-512F + FMA variant, disabled because it needs an AVX-512F capable CPU
// and matching compiler flags. As written it assumes n is a multiple of 32
// and that both arrays are 64-byte aligned (_mm512_load_ps).
float dot_avx512_fma(const float *vec1, const float *vec2, unsigned n)
{
    __m512 u1 = {0};
    __m512 u2 = {0};
    for(unsigned i = 0; i < n; i += 32)
    {
        __m512 w1 = _mm512_load_ps(&vec1[i]);
        __m512 w2 = _mm512_load_ps(&vec1[i + 16]);
        __m512 x1 = _mm512_load_ps(&vec2[i]);
        __m512 x2 = _mm512_load_ps(&vec2[i + 16]);
        // Fused multiply-add: u += w * x in a single instruction.
        u1 = _mm512_fmadd_ps(w1, x1, u1);
        u2 = _mm512_fmadd_ps(w2, x2, u2);
    }
    u1 = _mm512_add_ps(u1, u2);
    // Write the register back to memory. _mm512_store_ps requires 64-byte
    // alignment, so aligned(32) would not be enough here.
    __attribute__((aligned(64))) float t[16] = {0};
    _mm512_store_ps(t, u1);
    float ret = 0;
    for (unsigned i = 0; i < 16; ++i) {
        ret += t[i];
    }
    return ret;
}
*/
#include <chrono>
/// Benchmarks a callable.
///
/// Calls t() repeatedly until at least `ms` milliseconds have elapsed and
/// returns the mean wall-clock time per call, in milliseconds. Each result is
/// folded into a volatile sink so the optimizer cannot discard the calls.
/// t() is always called at least once, so the result is finite even when
/// ms == 0.
///
/// @param t  nullary callable whose result is convertible to float
/// @param ms minimum measurement window in milliseconds
/// @return mean milliseconds per call (fractional)
template<class T>
double calc_for_a_moment(T t, unsigned ms)
{
    // steady_clock is monotonic. high_resolution_clock may be an alias of
    // system_clock and can jump when the wall clock is adjusted.
    using Clock = std::chrono::steady_clock;
    const auto budget = std::chrono::milliseconds(ms);
    const auto start = Clock::now();
    auto elapsed = start;
    unsigned long long cnt = 0;
    volatile float sum = 0; // keeps the calls from being optimized away
    do
    {
        sum = sum + t();
        elapsed = Clock::now();
        ++cnt;
    } while(elapsed - start < budget);
    // Keep fractional milliseconds. Truncating the total to whole ms before
    // dividing would discard up to 1 ms of the measurement.
    const std::chrono::duration<double, std::milli> total = elapsed - start;
    return total.count() / static_cast<double>(cnt);
}
/// Throws unless b is within relative tolerance eps of the reference a,
/// i.e. unless |a - b| <= eps * |a|.
///
/// Exactly equal values always pass, which covers a == b == 0 and matching
/// infinities. The test is written as "not within tolerance" so that a NaN
/// in either argument fails: every ordered comparison with NaN is false, and
/// a plain "error > eps" check would let NaN slip through.
///
/// @param a    reference (expected) value
/// @param b    value under test
/// @param name label used in the exception message
/// @param eps  maximum allowed relative error
/// @throws std::runtime_error "<name> is wrong answer" when the check fails
template <typename T>
void assert_approx(T a, T b, std::string name, T eps=1e-4) {
    if (a == b)
        return;
    const T abs_err = std::fabs(a - b);
    if (!(abs_err <= eps * std::fabs(a))) {
        std::cerr << "std::fabs((" << a << " - " << b << ") / " << a << ") > " << eps << '\n';
        throw std::runtime_error(name + " is wrong answer");
    }
}
int main() | |
{ | |
const unsigned len_begin = 8; | |
const unsigned len_end = 512 * 1024; | |
const unsigned len_fact = 2; | |
const unsigned run_ms = 100; | |
std::mt19937 rng; | |
std::uniform_real_distribution<> dst(-1, 1); | |
using F = double; | |
using M256 = __m256; | |
for(unsigned len = len_begin; len <= len_end; len *= len_fact) | |
{ | |
F *p1 = new __attribute__((aligned(32))) F[len + 8]; | |
F *p2 = new __attribute__((aligned(32))) F[len + 8]; | |
F *vec1 = p1; | |
F *vec2 = p2; | |
while(reinterpret_cast<long>(vec1) % 32) ++vec1; | |
while(reinterpret_cast<long>(vec2) % 32) ++vec2; | |
std::generate(vec1, vec1 + len, [&rng, &dst](){ return dst(rng); }); | |
std::generate(vec2, vec2 + len, [&rng, &dst](){ return dst(rng); }); | |
// std::cout << (boost::format("%d %lf %lf %lf %lf") | |
printf("len:\t%d\n" | |
"ref:\t%lf\n" | |
// "sse:\t%lf\n" | |
// "avx:\t%lf\n" | |
// "avx2:\t%lf\n" | |
"avxfma:\t%lf\n", | |
len, | |
calc_for_a_moment([vec1, vec2, len](){ return dot_normal(vec1, vec2, len); }, run_ms), | |
// calc_for_a_moment([vec1, vec2, len](){ return dot_sse (vec1, vec2, len); }, run_ms), | |
// calc_for_a_moment([vec1, vec2, len](){ return dot_avx (vec1, vec2, len); }, run_ms), | |
// calc_for_a_moment([vec1, vec2, len](){ return dot_avx_2 (vec1, vec2, len); }, run_ms), | |
calc_for_a_moment([vec1, vec2, len](){ return dot_avx_fma(vec1, vec2, len); }, run_ms) | |
); | |
auto expected = dot_normal(vec1, vec2, len); | |
// assert_approx(expected, dot_sse(vec1, vec2, len), "sse"); | |
assert_approx(expected, dot_avx_fma(vec1, vec2, len), "avx_fma"); | |
// ) << std::endl; | |
delete[] p1; | |
delete[] p2; | |
} | |
} |
// Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment.