-
-
Save kaja47/241145603bbebf8766f16da943ce2bdd to your computer and use it in GitHub Desktop.
matrix multiplication
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * Horizontal sum of the eight float lanes of an AVX vector.
 *
 * _mm256_hadd_ps adds adjacent pairs within each 128-bit half, so after two
 * passes lane 0 holds (x0+x1)+(x2+x3) and lane 4 holds (x4+x5)+(x6+x7); the
 * result is those two partial sums added together.
 *
 * `static` gives the helper internal linkage: a plain C99 `inline`
 * definition emits no external symbol, so any call the compiler chose not
 * to inline would fail to link.  The two lanes are read with the
 * cast/extract intrinsics rather than by casting &x to float*, which relied
 * on compiler-specific aliasing rules for vector types.
 */
static inline __attribute__((always_inline)) float hadd(__m256 x) {
    x = _mm256_hadd_ps(x, x);
    x = _mm256_hadd_ps(x, x);
    __m128 lo = _mm256_castps256_ps128(x);     /* lanes 0..3 */
    __m128 hi = _mm256_extractf128_ps(x, 1);   /* lanes 4..7 */
    return _mm_cvtss_f32(lo) + _mm_cvtss_f32(hi);
}
#define PAD 16 | |
void square_mat_mul_tiered(float *a, float *b, const int len, float *res) { | |
for (int i = 0; i < len*(len+PAD); i++) { res[i] = 0; } | |
const int tile1 = TILE1; | |
const int tile2 = TILE2; | |
const int tile3 = TILE3; | |
const int segment = S; | |
#pragma omp parallel for | |
for (int tilei3 = 0; tilei3 < len; tilei3 += tile3) { | |
const int TILEI1 = tile1*2; | |
__m256 sums[TILEI1*tile1]; | |
for (int tilej3 = 0; tilej3 < len; tilej3 += tile3) { | |
for (int tilei2 = tilei3; tilei2 < tilei3+tile3; tilei2 += tile2) { | |
for (int tilej2 = tilej3; tilej2 < tilej3+tile3; tilej2 += tile2) { | |
for (int tilei1 = tilei2; tilei1 < tilei2+tile2; tilei1 += TILEI1) { | |
for (int tilej1 = tilej2; tilej1 < tilej2+tile2; tilej1 += tile1) { | |
__m256 zero = _mm256_set1_ps(0.0); | |
for (int i = 0; i < TILEI1*tile1; i++) sums[i] = zero; | |
for (int p = 0; p < len; p += segment) { | |
for (int i = tilei1; i < tilei1+TILEI1; i += 2) { | |
for (int j = tilej1; j < tilej1+tile1; j += 2) { | |
int ii = i-tilei1; | |
int jj = j-tilej1; | |
float *_a = a+(i*(len+PAD))+p; | |
float *_b = a+((i+1)*(len+PAD))+p; | |
float *_c = b+(j*(len+PAD))+p; | |
float *_d = b+((j+1)*(len+PAD))+p; | |
__m256 da = sums[ ii *tile1+jj] ; | |
__m256 db = sums[ ii *tile1+jj+1]; | |
__m256 dc = sums[(ii+1)*tile1+jj] ; | |
__m256 dd = sums[(ii+1)*tile1+jj+1]; | |
for (int i = 0; i < segment; i+=8) { | |
__m256 aa = _mm256_load_ps(_a + i); | |
__m256 bb = _mm256_load_ps(_b + i); | |
__m256 cc = _mm256_load_ps(_c + i); | |
__m256 dd = _mm256_load_ps(_d + i); | |
da = _mm256_add_ps(da, _mm256_mul_ps(aa, cc)); | |
db = _mm256_add_ps(db, _mm256_mul_ps(aa, dd)); | |
dc = _mm256_add_ps(dc, _mm256_mul_ps(bb, cc)); | |
dd = _mm256_add_ps(dd, _mm256_mul_ps(bb, dd)); | |
} | |
sums[ ii *tile1+jj] = da; | |
sums[ ii *tile1+jj+1] = db; | |
sums[(ii+1)*tile1+jj] = dc; | |
sums[(ii+1)*tile1+jj+1] = dd; | |
} | |
} | |
} | |
for (int ii = 0; ii < TILEI1; ii += 1) { | |
for (int jj = 0; jj < tile1; jj += 1) { | |
int i = ii+tilei1; | |
int j = jj+tilej1; | |
res[i*len+j] = hadd(sums[ii*tile1+jj]); | |
} | |
} | |
} | |
} | |
} | |
} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment