Skip to content

Instantly share code, notes, and snippets.

@ChenZhongPu
Last active April 19, 2023 09:39
Show Gist options
  • Save ChenZhongPu/98e8de25652970d67e083bbe0ce58886 to your computer and use it in GitHub Desktop.
Save ChenZhongPu/98e8de25652970d67e083bbe0ce58886 to your computer and use it in GitHub Desktop.
SIMD Demo
#if defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)
/* This snippet uses AVX types, so make it self-contained on x86. */
#include <immintrin.h>

/*
 * Dot product of a[0..n) and b[0..n) using 256-bit AVX.
 *
 * The vector loop consumes 8 floats per iteration with unaligned loads. Any
 * remaining 0-7 elements are added by a scalar tail loop, so n does not have
 * to be a multiple of 8 and no element past a[n-1] / b[n-1] is read.
 * n <= 0 yields 0.0f.
 *
 * On GCC/Clang the target("avx") attribute lets this one function use AVX
 * even when the translation unit is not built with -mavx. The CPU running
 * it must still support AVX.
 */
#if defined(__GNUC__)
__attribute__((target("avx")))
#endif
float avx_dot_product(float *a, float *b, int n) {
  __m256 sum = _mm256_setzero_ps(); // eight running partial sums
  int i = 0;
  /* n - i >= 8 cannot overflow, unlike i + 8 <= n near INT_MAX. */
  for (; n - i >= 8; i += 8) {
    __m256 a_vals = _mm256_loadu_ps(a + i); // load 8 floats from a
    __m256 b_vals = _mm256_loadu_ps(b + i); // load 8 floats from b
    sum = _mm256_add_ps(sum, _mm256_mul_ps(a_vals, b_vals));
  }
  /* Reduce 8 lanes to 1: fold the upper 128 bits onto the lower, then two
   * horizontal adds sum the remaining four lanes. */
  __m128 low = _mm256_castps256_ps128(sum);
  __m128 high = _mm256_extractf128_ps(sum, 1);
  low = _mm_add_ps(low, high);
  low = _mm_hadd_ps(low, low);
  low = _mm_hadd_ps(low, low);
  float result = _mm_cvtss_f32(low);
  /* Scalar tail for the last n % 8 elements. */
  for (; i < n; i++) {
    result += a[i] * b[i];
  }
  return result;
}
#endif
#if defined(__ARM_NEON__)
#include <arm_neon.h>
#else
#include <immintrin.h>
#endif
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
/* Number of floats in each benchmark vector filled and timed by main(). */
#define N 1000000
/*
 * Time a single call f(a, b, n) with CLOCK_MONOTONIC.
 *
 * Returns the elapsed wall-clock time in nanoseconds, as a double.
 *
 * The value returned by f is stored to a volatile local. This keeps the
 * optimizer from removing the call as dead code when time_func is inlined
 * and f is known to have no side effects.
 *
 * NOTE(review): clock_gettime/CLOCK_MONOTONIC are POSIX, not ISO C. A strict
 * -std=c11 build may need _POSIX_C_SOURCE defined before the includes.
 */
double time_func(float (*f)(float *, float *, int), float *a, float *b, int n) {
  struct timespec start, end;
  clock_gettime(CLOCK_MONOTONIC, &start);
  volatile float sink = f(a, b, n); // forces the call to be executed
  clock_gettime(CLOCK_MONOTONIC, &end);
  (void)sink;
  double diff = (double)(end.tv_sec - start.tv_sec) * 1e9;
  diff += (double)(end.tv_nsec - start.tv_nsec);
  return diff;
}
/*
 * Dot product of a[0..n) and b[0..n) using 128-bit SIMD. It uses NEON when
 * __ARM_NEON__ is defined and SSE otherwise.
 *
 * The vector loop consumes 4 floats per iteration. Any remaining 0-3
 * elements are added by a scalar tail loop, so n does not have to be a
 * multiple of 4 and nothing past a[n-1] / b[n-1] is read. n <= 0 yields 0.0f.
 *
 * NOTE(review): this #if must agree with the header selection at the top of
 * the file. Some AArch64 toolchains (e.g. GCC) define only __ARM_NEON, not
 * __ARM_NEON__. Consider testing both macros in both places.
 */
float simd_dot_product(float *a, float *b, int n) {
  float result = 0.0f;
  int i = 0;
#if defined(__ARM_NEON__)
  float32x4_t sum = vdupq_n_f32(0.0f);
  for (; n - i >= 4; i += 4) {
    float32x4_t va = vld1q_f32(a + i);
    float32x4_t vb = vld1q_f32(b + i);
    sum = vmlaq_f32(sum, va, vb); // sum += va * vb, lane-wise
  }
  /* Fold the four lanes into two (low half + high half), then add the pair. */
  float32x2_t sum2 = vadd_f32(vget_low_f32(sum), vget_high_f32(sum));
  sum2 = vpadd_f32(sum2, sum2);
  result = vget_lane_f32(sum2, 0);
#else
  __m128 sum = _mm_setzero_ps(); // four running partial sums
  for (; n - i >= 4; i += 4) {
    __m128 a_vals = _mm_loadu_ps(a + i); // load 4 floats from a
    __m128 b_vals = _mm_loadu_ps(b + i); // load 4 floats from b
    sum = _mm_add_ps(sum, _mm_mul_ps(a_vals, b_vals));
  }
  /* Horizontal sum using only SSE1 shuffles, with no SSE3 _mm_hadd_ps, so
   * this builds on baseline x86-64. It computes (s0+s1)+(s2+s3), the same
   * association as a pair of hadds. */
  __m128 swapped = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1)); // s1 s0 s3 s2
  __m128 pairs = _mm_add_ps(sum, swapped);      // s0+s1, -, s2+s3, -
  __m128 upper = _mm_movehl_ps(swapped, pairs); // s2+s3 moved to lane 0
  _mm_store_ss(&result, _mm_add_ss(pairs, upper));
#endif
  /* Scalar tail for the last n % 4 elements. */
  for (; i < n; i++) {
    result += a[i] * b[i];
  }
  return result;
}
/*
 * Scalar reference dot product.
 *
 * Returns the sum of a[k] * b[k] for k in [0, n), accumulated left to right
 * in single precision. Returns 0.0f when n <= 0.
 */
float dot_product(float *a, float *b, int n) {
  float acc = 0.0f;
  int k = 0;
  while (k < n) {
    acc += a[k] * b[k];
    ++k;
  }
  return acc;
}
/*
 * Benchmark driver.
 *
 * Fills two N-element vectors with pseudo-random values in [0, 1], times
 * simd_dot_product and dot_product once each, and prints both times and
 * their ratio.
 *
 * The vectors are heap-allocated. Two stack arrays of N floats (about 8 MB
 * together for N = 1000000) can exceed the default stack size on common
 * platforms, e.g. 1 MB with MSVC.
 *
 * rand() is never seeded, so every run uses the same inputs.
 */
int main(int argc, char *argv[]) {
  (void)argc;
  (void)argv;

  float *a = malloc(N * sizeof *a);
  float *b = malloc(N * sizeof *b);
  if (a == NULL || b == NULL) {
    fprintf(stderr, "failed to allocate %d-element vectors\n", N);
    free(a);
    free(b);
    return EXIT_FAILURE;
  }

  for (int i = 0; i < N; i++) {
    // random values in [0, 1]
    a[i] = rand() / (float)RAND_MAX;
    b[i] = rand() / (float)RAND_MAX;
  }

  double t = time_func(simd_dot_product, a, b, N);
  printf("SIMD time = %f ns\n", t);
  double t2 = time_func(dot_product, a, b, N);
  printf("Scalar time = %f ns\n", t2);
  printf("SIMD speedup = %f\n", t2 / t);

  free(a);
  free(b);
  return EXIT_SUCCESS;
}
@ChenZhongPu
Copy link
Author

ChenZhongPu commented Apr 18, 2023

On my Intel desktop, with the -O0 flag:

SSE SIMD time = 812549.000000 ns
AVX SIMD time = 433023.000000 ns
Scalar time = 1630100.000000 ns
SSE speedup = 2.006156
AVX speedup = 3.764465

@ChenZhongPu
Copy link
Author

With the -O3 flag:

SSE SIMD time = 261804.000000 ns
AVX SIMD time = 221153.000000 ns
Scalar time = 417953.000000 ns
SSE speedup = 1.596435
AVX speedup = 1.889882

@ChenZhongPu
Copy link
Author

On a Mac with an M1 Pro, with the -O3 flag:

SIMD time = 507000.000000 ns
Scalar time = 1364000.000000 ns
SIMD speedup = 2.690335

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment