Skip to content

Instantly share code, notes, and snippets.

@idatsy
Created October 16, 2024 06:37
Show Gist options
  • Save idatsy/203e3c228a075aa2f4eb84851174b3a2 to your computer and use it in GitHub Desktop.
Save idatsy/203e3c228a075aa2f4eb84851174b3a2 to your computer and use it in GitHub Desktop.
C function dispatch based on the available x86 SIMD instruction set (scalar, SSE4.1, AVX2, AVX-512 matmul kernels)
#include <immintrin.h>
// Scalar (default) implementation: c = a @ b, one output element at a time.
//
// Every access to a, b and c goes through that array's strides, so this
// kernel works for any row/column layout. Each c[row][col] is the dot product
// of row `row` of a with column `col` of b, accumulated in increasing k order
// in a float and then written to c.
// NOTE(review): shapes are not validated here; callers presumably guarantee
// a->shape[1] == b->shape[0] and that c has room for
// a->shape[0] x b->shape[1] elements — confirm at call sites.
__attribute__((target("default")))
void matmul_default(Arr* c, Arr* a, Arr* b) {
    const int rows = a->shape[0];
    const int inner = a->shape[1];
    const int cols = b->shape[1];

    for (int row = 0; row < rows; ++row) {
        /* Row offsets do not depend on col or k, so compute them once. */
        int a_row = row * a->strides[0];
        int c_row = row * c->strides[0];

        for (int col = 0; col < cols; ++col) {
            int b_col = col * b->strides[1];
            float acc = 0.0f;

            int n = 0;
            while (n < inner) {
                acc += a->values[a_row + n * a->strides[1]]
                     * b->values[n * b->strides[0] + b_col];
                ++n;
            }

            c->values[c_row + col * c->strides[1]] = acc;
        }
    }
}
// SSE4.1 implementation: c = a @ b, four output columns per vector.
//
// _mm_loadu_ps/_mm_storeu_ps touch four *consecutive* floats, so the vector
// path is only valid when the columns of b and c are contiguous
// (strides[1] == 1) and at least four columns remain. Columns [0, R_vec)
// take the vector path; the leftover R % 4 columns (or every column, for
// non-contiguous layouts) are computed by a scalar loop. This keeps every
// load and store inside column range [0, R).
// NOTE(review): assumes values is a float array and that
// a->shape[1] == b->shape[0]; neither is checked here.
__attribute__((target("sse4.1")))
void matmul_sse(Arr* c, Arr* a, Arr* b) {
    int P = a->shape[0];
    int Q = a->shape[1];
    int R = b->shape[1];
    /* End of the columns that can safely be processed as whole 4-wide vectors. */
    int R_vec = (b->strides[1] == 1 && c->strides[1] == 1) ? R - R % 4 : 0;

    for (int i = 0; i < P; i++) {
        int row_a = i * a->strides[0];
        int row_c = i * c->strides[0];

        for (int j = 0; j < R_vec; j += 4) {
            __m128 sum = _mm_setzero_ps();
            for (int k = 0; k < Q; k++) {
                __m128 a_val = _mm_set1_ps(a->values[row_a + k * a->strides[1]]);
                __m128 b_val = _mm_loadu_ps(&b->values[k * b->strides[0] + j * b->strides[1]]);
                sum = _mm_add_ps(sum, _mm_mul_ps(a_val, b_val));
            }
            _mm_storeu_ps(&c->values[row_c + j * c->strides[1]], sum);
        }

        /* Scalar tail: remaining columns, or all of them if strided. */
        for (int j = R_vec; j < R; j++) {
            float tmp = 0.0f;
            for (int k = 0; k < Q; k++) {
                tmp += a->values[row_a + k * a->strides[1]]
                     * b->values[k * b->strides[0] + j * b->strides[1]];
            }
            c->values[row_c + j * c->strides[1]] = tmp;
        }
    }
}
// AVX2 implementation: c = a @ b, eight output columns per vector.
//
// _mm256_loadu_ps/_mm256_storeu_ps touch eight *consecutive* floats, so the
// vector path is only valid when the columns of b and c are contiguous
// (strides[1] == 1) and at least eight columns remain. Columns [0, R_vec)
// take the vector path; the leftover R % 8 columns (or every column, for
// non-contiguous layouts) are computed by a scalar loop. This keeps every
// load and store inside column range [0, R).
// NOTE(review): assumes values is a float array and that
// a->shape[1] == b->shape[0]; neither is checked here.
__attribute__((target("avx2")))
void matmul_avx2(Arr* c, Arr* a, Arr* b) {
    int P = a->shape[0];
    int Q = a->shape[1];
    int R = b->shape[1];
    /* End of the columns that can safely be processed as whole 8-wide vectors. */
    int R_vec = (b->strides[1] == 1 && c->strides[1] == 1) ? R - R % 8 : 0;

    for (int i = 0; i < P; i++) {
        int row_a = i * a->strides[0];
        int row_c = i * c->strides[0];

        for (int j = 0; j < R_vec; j += 8) {
            __m256 sum = _mm256_setzero_ps();
            for (int k = 0; k < Q; k++) {
                __m256 a_val = _mm256_set1_ps(a->values[row_a + k * a->strides[1]]);
                __m256 b_val = _mm256_loadu_ps(&b->values[k * b->strides[0] + j * b->strides[1]]);
                sum = _mm256_add_ps(sum, _mm256_mul_ps(a_val, b_val));
            }
            _mm256_storeu_ps(&c->values[row_c + j * c->strides[1]], sum);
        }

        /* Scalar tail: remaining columns, or all of them if strided. */
        for (int j = R_vec; j < R; j++) {
            float tmp = 0.0f;
            for (int k = 0; k < Q; k++) {
                tmp += a->values[row_a + k * a->strides[1]]
                     * b->values[k * b->strides[0] + j * b->strides[1]];
            }
            c->values[row_c + j * c->strides[1]] = tmp;
        }
    }
}
// AVX-512 implementation: c = a @ b, sixteen output columns per vector.
//
// _mm512_loadu_ps/_mm512_storeu_ps touch sixteen *consecutive* floats, so the
// vector path is only valid when the columns of b and c are contiguous
// (strides[1] == 1) and at least sixteen columns remain. Columns [0, R_vec)
// take the vector path; the leftover R % 16 columns (or every column, for
// non-contiguous layouts) are computed by a scalar loop. This keeps every
// load and store inside column range [0, R).
// NOTE(review): assumes values is a float array and that
// a->shape[1] == b->shape[0]; neither is checked here.
__attribute__((target("avx512f")))
void matmul_avx512(Arr* c, Arr* a, Arr* b) {
    int P = a->shape[0];
    int Q = a->shape[1];
    int R = b->shape[1];
    /* End of the columns that can safely be processed as whole 16-wide vectors. */
    int R_vec = (b->strides[1] == 1 && c->strides[1] == 1) ? R - R % 16 : 0;

    for (int i = 0; i < P; i++) {
        int row_a = i * a->strides[0];
        int row_c = i * c->strides[0];

        for (int j = 0; j < R_vec; j += 16) {
            __m512 sum = _mm512_setzero_ps();
            for (int k = 0; k < Q; k++) {
                __m512 a_val = _mm512_set1_ps(a->values[row_a + k * a->strides[1]]);
                __m512 b_val = _mm512_loadu_ps(&b->values[k * b->strides[0] + j * b->strides[1]]);
                sum = _mm512_add_ps(sum, _mm512_mul_ps(a_val, b_val));
            }
            _mm512_storeu_ps(&c->values[row_c + j * c->strides[1]], sum);
        }

        /* Scalar tail: remaining columns, or all of them if strided. */
        for (int j = R_vec; j < R; j++) {
            float tmp = 0.0f;
            for (int k = 0; k < Q; k++) {
                tmp += a->values[row_a + k * a->strides[1]]
                     * b->values[k * b->strides[0] + j * b->strides[1]];
            }
            c->values[row_c + j * c->strides[1]] = tmp;
        }
    }
}
// Dispatcher: runs the widest matmul kernel the executing CPU supports.
//
// The matmul_* kernels have distinct names, so their target attributes do not
// form a multiversioned function, and nothing is selected automatically. The
// CPU is therefore queried explicitly at run time with the GCC/Clang x86
// builtins.
//
// A vector kernel is chosen only when b and c have contiguous columns
// (strides[1] == 1) and R is a multiple of that kernel's lane count. A kernel
// that processes whole vectors only therefore never runs past the end of a
// row. Every other case uses the stride-aware scalar kernel.
void matmul(Arr* c, Arr* a, Arr* b) {
    int R = b->shape[1];
    int contiguous = b->strides[1] == 1 && c->strides[1] == 1;

    /* Only required if this runs before constructors; harmless otherwise. */
    __builtin_cpu_init();

    if (contiguous && R % 16 == 0 && __builtin_cpu_supports("avx512f")) {
        matmul_avx512(c, a, b);
    } else if (contiguous && R % 8 == 0 && __builtin_cpu_supports("avx2")) {
        matmul_avx2(c, a, b);
    } else if (contiguous && R % 4 == 0 && __builtin_cpu_supports("sse4.1")) {
        matmul_sse(c, a, b);
    } else {
        matmul_default(c, a, b);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment