Created
October 16, 2024 06:37
-
-
Save idatsy/203e3c228a075aa2f4eb84851174b3a2 to your computer and use it in GitHub Desktop.
C compile-time function dispatch based on the available SIMD instruction set
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stddef.h>
#include <immintrin.h>
// Scalar (default) implementation | |
__attribute__((target("default"))) | |
void matmul_default(Arr* c, Arr* a, Arr* b) { | |
int P = a->shape[0]; | |
int Q = a->shape[1]; | |
int R = b->shape[1]; | |
for (int i = 0; i < P; i++) { | |
for (int j = 0; j < R; j++) { | |
float tmp = 0.0f; | |
for (int k = 0; k < Q; k++) { | |
int pos_a = i * a->strides[0] + k * a->strides[1]; | |
int pos_b = k * b->strides[0] + j * b->strides[1]; | |
tmp += a->values[pos_a] * b->values[pos_b]; | |
} | |
int pos_c = i * c->strides[0] + j * c->strides[1]; | |
c->values[pos_c] = tmp; | |
} | |
} | |
} | |
// SSE implementation | |
__attribute__((target("sse4.1"))) | |
void matmul_sse(Arr* c, Arr* a, Arr* b) { | |
int P = a->shape[0]; | |
int Q = a->shape[1]; | |
int R = b->shape[1]; | |
for (int i = 0; i < P; i++) { | |
for (int j = 0; j < R; j += 4) { | |
__m128 sum = _mm_setzero_ps(); | |
for (int k = 0; k < Q; k++) { | |
int pos_a = i * a->strides[0] + k * a->strides[1]; | |
__m128 a_val = _mm_set1_ps(a->values[pos_a]); | |
int pos_b = k * b->strides[0] + j * b->strides[1]; | |
__m128 b_val = _mm_loadu_ps(&b->values[pos_b]); | |
sum = _mm_add_ps(sum, _mm_mul_ps(a_val, b_val)); | |
} | |
int pos_c = i * c->strides[0] + j * c->strides[1]; | |
_mm_storeu_ps(&c->values[pos_c], sum); | |
} | |
} | |
} | |
// AVX2 implementation | |
__attribute__((target("avx2"))) | |
void matmul_avx2(Arr* c, Arr* a, Arr* b) { | |
int P = a->shape[0]; | |
int Q = a->shape[1]; | |
int R = b->shape[1]; | |
for (int i = 0; i < P; i++) { | |
for (int j = 0; j < R; j += 8) { | |
__m256 sum = _mm256_setzero_ps(); | |
for (int k = 0; k < Q; k++) { | |
int pos_a = i * a->strides[0] + k * a->strides[1]; | |
__m256 a_val = _mm256_set1_ps(a->values[pos_a]); | |
int pos_b = k * b->strides[0] + j * b->strides[1]; | |
__m256 b_val = _mm256_loadu_ps(&b->values[pos_b]); | |
sum = _mm256_add_ps(sum, _mm256_mul_ps(a_val, b_val)); | |
} | |
int pos_c = i * c->strides[0] + j * c->strides[1]; | |
_mm256_storeu_ps(&c->values[pos_c], sum); | |
} | |
} | |
} | |
// AVX-512 implementation | |
__attribute__((target("avx512f"))) | |
void matmul_avx512(Arr* c, Arr* a, Arr* b) { | |
int P = a->shape[0]; | |
int Q = a->shape[1]; | |
int R = b->shape[1]; | |
for (int i = 0; i < P; i++) { | |
for (int j = 0; j < R; j += 16) { | |
__m512 sum = _mm512_setzero_ps(); | |
for (int k = 0; k < Q; k++) { | |
int pos_a = i * a->strides[0] + k * a->strides[1]; | |
__m512 a_val = _mm512_set1_ps(a->values[pos_a]); | |
int pos_b = k * b->strides[0] + j * b->strides[1]; | |
__m512 b_val = _mm512_loadu_ps(&b->values[pos_b]); | |
sum = _mm512_add_ps(sum, _mm512_mul_ps(a_val, b_val)); | |
} | |
int pos_c = i * c->strides[0] + j * c->strides[1]; | |
_mm512_storeu_ps(&c->values[pos_c], sum); | |
} | |
} | |
} | |
// Dispatcher: C = A * B using the widest SIMD kernel the running CPU supports.
//
// The CPU is queried at run time with the GCC/Clang builtin
// __builtin_cpu_supports. The vector kernels load and store whole vectors
// along rows of B and C, so a kernel is only chosen when both B and C have
// unit column stride and R (the column count of B) is a multiple of that
// kernel's vector width. This keeps every vector access inside the row; all
// other cases use the scalar kernel.
void matmul(Arr* c, Arr* a, Arr* b) {
    int R = b->shape[1];
    int unit_cols = b->strides[1] == 1 && c->strides[1] == 1;

    // Only required before constructors run, but cheap and harmless here.
    __builtin_cpu_init();

    if (unit_cols && R % 16 == 0 && __builtin_cpu_supports("avx512f")) {
        matmul_avx512(c, a, b);
    } else if (unit_cols && R % 8 == 0 && __builtin_cpu_supports("avx2")) {
        matmul_avx2(c, a, b);
    } else if (unit_cols && R % 4 == 0 && __builtin_cpu_supports("sse4.1")) {
        matmul_sse(c, a, b);
    } else {
        matmul_default(c, a, b);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment