// Source: GitHub gist zhangce/3cb47f9b5e0d401fd3f0 by @zhangce, created April 15, 2015.
// (GitHub page navigation text that preceded the code was removed; it was not part of the program.)
#include <stdio.h>
#include <stdlib.h>
#include <x86intrin.h>
#include "timer.h"
#include "cblas.h"
#define BLOCKSIZE 256
/*
 * cstore16x1_rstore1x6_gemm(A_c, B_r)
 *
 * One rank-1 update of a 16x6 tile kept in twelve __m256 accumulators:
 * for j in 1..6, C_cj_{first8,second8} += A_c[0..15] * B_r[j-1].
 *
 *   A_c  pointer to 16 floats, read with aligned 256-bit loads (so it must be
 *        32-byte aligned). A_c+16 and A_c+24 are only prefetched, not read.
 *   B_r  pointer to 6 floats, each broadcast to all eight lanes.
 *
 * The macro does not declare anything: the expansion site must already have
 * __m256 variables named A_ci_first8, A_ci_second8, B_ri_num1, B_ri_num2 and
 * C_c1_first8 .. C_c6_second8 in scope.
 *
 * The body is wrapped in do { } while (0) so that `macro(a, b);` is a single
 * statement (safe under an unbraced if/else), and both arguments are
 * parenthesized so that arbitrary pointer expressions can be passed.
 */
#define cstore16x1_rstore1x6_gemm(A_c, B_r) \
do { \
    A_ci_first8 = _mm256_load_ps((A_c)); \
    A_ci_second8 = _mm256_load_ps((A_c)+8); \
    __builtin_prefetch((A_c)+16, 0); \
    __builtin_prefetch((A_c)+24, 0); \
    B_ri_num1 = _mm256_broadcast_ss((B_r)); \
    B_ri_num2 = _mm256_broadcast_ss((B_r)+1); \
    C_c1_first8 = _mm256_fmadd_ps(A_ci_first8, B_ri_num1, C_c1_first8); \
    C_c1_second8 = _mm256_fmadd_ps(A_ci_second8, B_ri_num1, C_c1_second8); \
    C_c2_first8 = _mm256_fmadd_ps(A_ci_first8, B_ri_num2, C_c2_first8); \
    C_c2_second8 = _mm256_fmadd_ps(A_ci_second8, B_ri_num2, C_c2_second8); \
    B_ri_num1 = _mm256_broadcast_ss((B_r)+2); \
    B_ri_num2 = _mm256_broadcast_ss((B_r)+3); \
    C_c3_first8 = _mm256_fmadd_ps(A_ci_first8, B_ri_num1, C_c3_first8); \
    C_c3_second8 = _mm256_fmadd_ps(A_ci_second8, B_ri_num1, C_c3_second8); \
    C_c4_first8 = _mm256_fmadd_ps(A_ci_first8, B_ri_num2, C_c4_first8); \
    C_c4_second8 = _mm256_fmadd_ps(A_ci_second8, B_ri_num2, C_c4_second8); \
    B_ri_num1 = _mm256_broadcast_ss((B_r)+4); \
    B_ri_num2 = _mm256_broadcast_ss((B_r)+5); \
    C_c5_first8 = _mm256_fmadd_ps(A_ci_first8, B_ri_num1, C_c5_first8); \
    C_c5_second8 = _mm256_fmadd_ps(A_ci_second8, B_ri_num1, C_c5_second8); \
    C_c6_first8 = _mm256_fmadd_ps(A_ci_first8, B_ri_num2, C_c6_first8); \
    C_c6_second8 = _mm256_fmadd_ps(A_ci_second8, B_ri_num2, C_c6_second8); \
} while (0)
/*
 * Accumulates a single-plane 5x5 filter response into a 16x6 output tile.
 *
 *   K      25 consecutive groups of 16 floats (one group per filter tap, taps
 *          visited row by row); read with aligned 256-bit loads.
 *   D      one 28-wide input plane; only scalar broadcasts are taken from it.
 *   O      6 consecutive groups of 16 floats. The tile is loaded, updated and
 *          stored back, so existing contents are added to, not overwritten.
 *          Accessed with aligned 256-bit loads/stores.
 *   chunk  linear index of the first of the 6 outputs in a 24-wide output
 *          grid: row = chunk / 24, column = chunk % 24.
 *
 * For tile slot p (0..5) and lane m (0..15):
 *   O[p*16 + m] += sum over (kr, kc) of
 *                  K[(kr*5 + kc)*16 + m] * D[(row + kr)*28 + col + kc + p]
 */
void cstore16x5x5x1_rstore28x28x1_chunk_within_depth(
const float * const K, const float * const D, float * O,
const int chunk){
    enum { IMG_W = 28, OUT_W = 24, KSIZE = 5, TILE = 6, MAPS = 16 };

    const int out_row = chunk / OUT_W;
    const int out_col = chunk % OUT_W;

    // Two 8-lane registers per tile slot: lanes 0-7 and lanes 8-15.
    __m256 acc_lo[TILE];
    __m256 acc_hi[TILE];
    for (int p = 0; p < TILE; ++p) {
        acc_lo[p] = _mm256_load_ps(O + p * MAPS);
        acc_hi[p] = _mm256_load_ps(O + p * MAPS + 8);
    }

    const float *tap = K;  // advances by 16 floats per filter tap
    for (int dy = 0; dy < KSIZE; ++dy) {
        const float *in_row = D + (out_row + dy) * IMG_W + out_col;
        for (int dx = 0; dx < KSIZE; ++dx, tap += MAPS) {
            const __m256 w_lo = _mm256_load_ps(tap);
            const __m256 w_hi = _mm256_load_ps(tap + 8);
            // Each of the 6 neighbouring outputs sees the same weights but an
            // input sample shifted by one column.
            for (int p = 0; p < TILE; ++p) {
                const __m256 sample = _mm256_broadcast_ss(in_row + dx + p);
                acc_lo[p] = _mm256_fmadd_ps(w_lo, sample, acc_lo[p]);
                acc_hi[p] = _mm256_fmadd_ps(w_hi, sample, acc_hi[p]);
            }
        }
    }

    for (int p = 0; p < TILE; ++p) {
        _mm256_store_ps(O + p * MAPS, acc_lo[p]);
        _mm256_store_ps(O + p * MAPS + 8, acc_hi[p]);
    }
}
/*
 * Timed driver: runs the single-plane 6-output kernel
 * (cstore16x5x5x1_rstore28x28x1_chunk_within_depth) once per input plane and
 * per group of 6 output positions, then prints timing statistics.
 *
 *   K  96 blocks of 5*5*16 floats, one block per input plane.
 *   D  96 planes of 28*28 floats.
 *   O  24*24 groups of 16 floats; the kernel adds into it, so its contents on
 *      entry are part of the result.
 *
 * The printed "GFLOPS" figure divides by 1024^3 (binary giga), not 10^9.
 */
void cstore16x5x5x96_rstore28x28x96_CONV_within_depth(
const float * const K, const float * const D, float * O){
    const int oR = (28-5+1);
    Timer t;

    // Outer loop walks the input planes; inner loop walks the output grid in
    // steps of 6 positions (one kernel call each).
    for (int plane = 0; plane < 96; ++plane) {
        const float * const plane_data = D + plane*28*28;
        const float * const plane_taps = K + plane*5*5*16;
        int pos = 0;
        while (pos < oR*oR) {
            cstore16x5x5x1_rstore28x28x1_chunk_within_depth(
                plane_taps, plane_data, O + pos*16, pos);
            pos += 6;
        }
    }

    float elapsed = t.elapsed();
    printf("time: %f seconds\n", elapsed);
    printf("time for 256x256: %f seconds\n", elapsed * 256 * (256/16));

    // 2 flops (multiply + add) per tap, per plane, per output value.
    float flop = 16*oR*oR*5*5*96*2;
    float gflops = flop/elapsed/1024/1024/1024;
    printf("flops: %f \n", flop);
    printf("GFlops: %f GFLOPS\n", gflops);
}
/*
 * Accumulates a 96-plane 5x5 filter response into a 16x6 output tile, keeping
 * the tile in registers across all planes.
 *
 *   K      96*25 consecutive groups of 16 floats, ordered plane-major, then
 *          filter row, then filter column; read with aligned 256-bit loads.
 *   D      96 planes of 28*28 floats; only scalar broadcasts are taken.
 *   O      6 consecutive groups of 16 floats, loaded, updated and stored back
 *          (existing contents are added to). Aligned 256-bit access.
 *   chunk  linear index of the first of the 6 outputs in a 24-wide output
 *          grid: row = chunk / 24, column = chunk % 24.
 *
 * For tile slot p (0..5) and lane m (0..15):
 *   O[p*16 + m] += sum over (z, kr, kc) of
 *       K[((z*5 + kr)*5 + kc)*16 + m] * D[z*784 + (row + kr)*28 + col + kc + p]
 */
void cstore16x5x5x1_rstore28x28x1_chunk(
const float * const K, const float * const D, float * O,
const int chunk){
    enum { IMG_W = 28, OUT_W = 24, KSIZE = 5, TILE = 6, MAPS = 16, PLANES = 96 };

    const int out_row = chunk / OUT_W;
    const int out_col = chunk % OUT_W;

    // Two 8-lane registers per tile slot: lanes 0-7 and lanes 8-15.
    __m256 acc_lo[TILE];
    __m256 acc_hi[TILE];
    for (int p = 0; p < TILE; ++p) {
        acc_lo[p] = _mm256_load_ps(O + p * MAPS);
        acc_hi[p] = _mm256_load_ps(O + p * MAPS + 8);
    }

    const float *tap = K;  // runs linearly through all 96*25 tap groups
    for (int z = 0; z < PLANES; ++z) {
        const float *plane = D + z * IMG_W * IMG_W;
        for (int dy = 0; dy < KSIZE; ++dy) {
            const float *in_row = plane + (out_row + dy) * IMG_W + out_col;
            for (int dx = 0; dx < KSIZE; ++dx, tap += MAPS) {
                const __m256 w_lo = _mm256_load_ps(tap);
                const __m256 w_hi = _mm256_load_ps(tap + 8);
                // Hint the next tap group into cache while this one is used.
                __builtin_prefetch(tap + 16, 0);
                __builtin_prefetch(tap + 24, 0);
                for (int p = 0; p < TILE; ++p) {
                    const __m256 sample = _mm256_broadcast_ss(in_row + dx + p);
                    acc_lo[p] = _mm256_fmadd_ps(w_lo, sample, acc_lo[p]);
                    acc_hi[p] = _mm256_fmadd_ps(w_hi, sample, acc_hi[p]);
                }
            }
        }
    }

    for (int p = 0; p < TILE; ++p) {
        _mm256_store_ps(O + p * MAPS, acc_lo[p]);
        _mm256_store_ps(O + p * MAPS + 8, acc_hi[p]);
    }
}
/*
 * Timed driver: runs the 96-plane 6-output kernel
 * (cstore16x5x5x1_rstore28x28x1_chunk) once per group of 6 output positions
 * and prints timing, flop-rate and bandwidth figures.
 *
 *   K  96*5*5 groups of 16 floats.
 *   D  96 planes of 28*28 floats.
 *   O  24*24 groups of 16 floats; the kernel adds into it, so its contents on
 *      entry are part of the result.
 *
 * Unlike cstore16x5x5x96_rstore28x28x96_CONV_within_depth, the loop over
 * input planes lives inside the kernel, so K and D are passed through whole.
 * The printed "GFLOPS" and "GB/s" figures divide by 1024^3, not 10^9.
 */
void cstore16x5x5x96_rstore28x28x96_CONV(
const float * const K, const float * const D, float * O){
    const int oR = (28-5+1);
    Timer t;

    // Only the output grid is blocked (6 positions per call); the inputs are
    // not split per plane here.
    const float * const all_data = D;
    const float * const all_taps = K;
    int pos = 0;
    while (pos < oR*oR) {
        cstore16x5x5x1_rstore28x28x1_chunk(all_taps, all_data, O + pos*16, pos);
        pos += 6;
    }

    float elapsed = t.elapsed();
    printf("time: %f seconds\n", elapsed);
    printf("time for 256x256: %f seconds\n", elapsed * 256 * (256/16));

    // 2 flops (multiply + add) per tap, per plane, per output value.
    float flop = 16*oR*oR*5*5*96*2;
    float gflops = flop/elapsed/1024/1024/1024;
    printf("flops: %f \n", flop);
    printf("GFlops: %f GFLOPS\n", gflops);

    // Bytes occupied by K, D and O, each counted once.
    float io = 1.0 * (16*5*5*96+28*28*96+oR*oR*16) * sizeof(float);
    float gb = 1.0* io / elapsed / 1024/1024/1024;
    printf("GB: %f GB/s\n", gb);
}
// Unimplemented stub: the body is empty, so none of the parameters are read
// and O is left untouched.
// NOTE(review): judging by the signature this was presumably meant to be a
// version of the convolution with configurable input depth and output count
// (depth, oD) -- confirm before relying on it.
void cstore5x5_rstore28x28_CONV(
const float * const K, const float * const D, float * O, const int depth, const int oD){
}
/*
 * Computes a 16x6 tile C = sum over i in [0, BLOCKSIZE) of
 * A[i*16 .. i*16+15] (as a column) times B[i*6 .. i*6+5] (as a row).
 *
 *   A  BLOCKSIZE groups of 16 floats, read with aligned 256-bit loads
 *      (32-byte alignment required).
 *   B  BLOCKSIZE groups of 6 floats, read as scalar broadcasts.
 *   C  96 floats, written with aligned 256-bit stores; element (r, c) of the
 *      tile lands at C[r + c*16]. Previous contents of C are overwritten.
 *
 * Despite the "16x4 / 4x6" in the name, the inner dimension is BLOCKSIZE.
 *
 * The accumulators are explicitly zeroed: the first fused multiply-add reads
 * them, so leaving them uninitialized is undefined behaviour and produces
 * garbage in C.
 */
inline void cstore16x4_rstore4x6_gemm(
const float * const A,
const float * const B,
float * const C){
    // Names below are fixed by the cstore16x1_rstore1x6_gemm macro, which
    // refers to them directly.
    __m256 C_c1_first8 = _mm256_setzero_ps(); __m256 C_c1_second8 = _mm256_setzero_ps();
    __m256 C_c2_first8 = _mm256_setzero_ps(); __m256 C_c2_second8 = _mm256_setzero_ps();
    __m256 C_c3_first8 = _mm256_setzero_ps(); __m256 C_c3_second8 = _mm256_setzero_ps();
    __m256 C_c4_first8 = _mm256_setzero_ps(); __m256 C_c4_second8 = _mm256_setzero_ps();
    __m256 C_c5_first8 = _mm256_setzero_ps(); __m256 C_c5_second8 = _mm256_setzero_ps();
    __m256 C_c6_first8 = _mm256_setzero_ps(); __m256 C_c6_second8 = _mm256_setzero_ps();

    // Scratch registers assigned by the macro before every use.
    __m256 A_ci_first8;
    __m256 A_ci_second8;
    __m256 B_ri_num1;
    __m256 B_ri_num2;

    // One rank-1 update per step of the inner dimension.
    for(int i=0;i<BLOCKSIZE;i++){
        cstore16x1_rstore1x6_gemm(A+i*16, B+i*6);
    }

    _mm256_store_ps(C, C_c1_first8);
    _mm256_store_ps(C+8, C_c1_second8);
    _mm256_store_ps(C+16, C_c2_first8);
    _mm256_store_ps(C+24, C_c2_second8);
    _mm256_store_ps(C+32, C_c3_first8);
    _mm256_store_ps(C+40, C_c3_second8);
    _mm256_store_ps(C+48, C_c4_first8);
    _mm256_store_ps(C+56, C_c4_second8);
    _mm256_store_ps(C+64, C_c5_first8);
    _mm256_store_ps(C+72, C_c5_second8);
    _mm256_store_ps(C+80, C_c6_first8);
    _mm256_store_ps(C+88, C_c6_second8);
}
/*
 * Micro-benchmark for cstore16x4_rstore4x6_gemm: fills A (16 x BLOCKSIZE) and
 * B (BLOCKSIZE x 6) with deterministic values, times N_ABLOCK kernel calls,
 * and prints elapsed time, flop rate, bytes-touched rate and the 16x6 result.
 *
 * Changes over the previous version: the three buffers are checked after
 * allocation and released before returning (they used to leak), and the
 * locals no longer use reserved identifiers (_A, _B, _C).
 *
 * The printed "GFLOPS" and "GB/s" figures divide by 1024^3, not 10^9.
 */
void test_throughput(){
    const int N_ABLOCK = 1;
    const int N_BBLOCK = N_ABLOCK;
    const int N_CBLOCK = N_ABLOCK;

    // 32-byte alignment is required by the aligned AVX loads/stores in the kernel.
    float * const a_buf = (float*) _mm_malloc(16*BLOCKSIZE*sizeof(float) * N_ABLOCK, 32);
    float * const b_buf = (float*) _mm_malloc(BLOCKSIZE*6*sizeof(float) * N_BBLOCK, 32);
    float * const c_buf = (float*) _mm_malloc(16*6*sizeof(float) * N_CBLOCK, 32);
    if(a_buf == NULL || b_buf == NULL || c_buf == NULL){
        fprintf(stderr, "test_throughput: allocation failed\n");
        if(a_buf != NULL) _mm_free(a_buf);
        if(b_buf != NULL) _mm_free(b_buf);
        if(c_buf != NULL) _mm_free(c_buf);
        return;
    }

    for(int i=0;i<16*BLOCKSIZE;i++){
        a_buf[i] = i % 100;
    }
    for(int i=0;i<BLOCKSIZE*6;i++){
        b_buf[i] = i * 100;
    }

    Timer t;
    float * A = a_buf;
    float * B = b_buf;
    float * C = c_buf;
    const int inc_a = 1;
    const int inc_b = 1;
    const int inc_c = 1;
    for(int i=0;i<N_ABLOCK;i++){
        // Reference for comparison (from the original author's notes):
        // cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, 16, 6, BLOCKSIZE,
        //             1.0, A, 16, B, 6, 1.0, C, 16);
        cstore16x4_rstore4x6_gemm(A, B, C);
        A += inc_a * 16*4;
        B += inc_b * 4*6;
        C += inc_c * 16*6;
    }
    double elapsed = t.elapsed();

    printf("Elapsed: %f\n seconds\n", elapsed);
    float flop = 1.0* N_ABLOCK * 16 * BLOCKSIZE * 6 * 2;
    float gflops = flop/elapsed/1024/1024/1024;
    printf("GFlops: %f GFLOPS\n", gflops);
    float io = 1.0 * N_ABLOCK * (16*BLOCKSIZE*inc_a + BLOCKSIZE*6*inc_b + 16*6*inc_c) * sizeof(float);
    float gb = 1.0* io / elapsed / 1024/1024/1024;
    printf("GB: %f GB/s\n", gb);

    // First result tile; element (r, c) is stored at c_buf[r + c*16].
    printf("C=\n");
    for(int r=0;r<16;r++){
        for(int c=0;c<6;c++){
            printf("%f ", c_buf[r+c*16]);
        }
        printf("\n");
    }

    _mm_free(a_buf);
    _mm_free(b_buf);
    _mm_free(c_buf);
}
/*
 * Debug helper: builds a 16x4 matrix A (A[i] = i, printed as A[r + c*16]) and
 * a 4x6 matrix B (B[i] = i, printed as B[c + r*6]), and prints A, B and a
 * 16x6 matrix C (printed as C[r + c*16]).
 *
 * No multiplication is performed here, so C is printed as all zeros.
 * Previously C was printed straight from uninitialized memory; it is now
 * zero-filled. The buffers are also checked after allocation and freed
 * before returning.
 */
void test_16x6(){
    float * const A = (float*) _mm_malloc(16*4*sizeof(float), 32);
    float * const B = (float*) _mm_malloc(4*6*sizeof(float), 32);
    float * const C = (float*) _mm_malloc(16*6*sizeof(float), 32);
    if(A == NULL || B == NULL || C == NULL){
        fprintf(stderr, "test_16x6: allocation failed\n");
        if(A != NULL) _mm_free(A);
        if(B != NULL) _mm_free(B);
        if(C != NULL) _mm_free(C);
        return;
    }

    for(int i=0;i<16*4;i++){
        A[i] = i;
    }
    for(int i=0;i<4*6;i++){
        B[i] = i;
    }
    // Give C a defined value before it is printed.
    for(int i=0;i<16*6;i++){
        C[i] = 0.0f;
    }

    printf("A=\n");
    for(int r=0;r<16;r++){
        for(int c=0;c<4;c++){
            printf("%f ", A[r+c*16]);
        }
        printf("\n");
    }
    printf("B=\n");
    for(int r=0;r<4;r++){
        for(int c=0;c<6;c++){
            printf("%f ", B[c+r*6]);
        }
        printf("\n");
    }
    printf("C=\n");
    for(int r=0;r<16;r++){
        for(int c=0;c<6;c++){
            printf("%f ", C[r+c*16]);
        }
        printf("\n");
    }

    _mm_free(A);
    _mm_free(B);
    _mm_free(C);
}
/*
 * Times one 1000x1000x1000 single-precision cblas_sgemm call
 * (C = 1.0 * A * B^T + 1.0 * C, column-major) and prints the elapsed time,
 * the flop rate, and the top-left corners of A, B and C.
 *
 * Changes over the previous version:
 *  - A, B and C are fully initialized. Before, only the first 64 elements of
 *    A and 24 of B were set and C was not set at all, yet sgemm reads every
 *    element of all three (beta is 1.0), i.e. it consumed uninitialized
 *    memory. The first 64 / 24 elements keep their old values; the rest is 0.
 *  - Allocation results are checked and the buffers are freed.
 *  - Buffer sizes and the leading dimension of B are spelled with the
 *    dimensions they belong to (all are 1000, so the numbers are unchanged).
 *
 * The printed "GFLOPS" figure divides by 1024^3, not 10^9.
 */
void test_blas_speed_16x6(){
    const int M = 1000;
    const int N = M;
    const int K = M;

    float * const A = (float*) _mm_malloc(M*K*sizeof(float), 32);  // M x K, lda = M
    float * const B = (float*) _mm_malloc(N*K*sizeof(float), 32);  // N x K, ldb = N (used transposed)
    float * const C = (float*) _mm_malloc(M*N*sizeof(float), 32);  // M x N, ldc = M
    if(A == NULL || B == NULL || C == NULL){
        fprintf(stderr, "test_blas_speed_16x6: allocation failed\n");
        if(A != NULL) _mm_free(A);
        if(B != NULL) _mm_free(B);
        if(C != NULL) _mm_free(C);
        return;
    }

    for(int i=0;i<M*K;i++){
        A[i] = (i < 16*4) ? (float)i : 0.0f;
    }
    for(int i=0;i<N*K;i++){
        B[i] = (i < 4*6) ? (float)i : 0.0f;
    }
    for(int i=0;i<M*N;i++){
        C[i] = 0.0f;
    }

    Timer t;
    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, M, N, K,
        1.0, A, M, B, N, 1.0, C, M);
    float elapsed = t.elapsed();

    printf("Elapsed: %f\n seconds\n", elapsed);
    float flop = 1.0* M * N * K * 2;
    float gflops = flop/elapsed/1024/1024/1024;
    printf("GFlops: %f GFLOPS\n", gflops);

    // The dumps below index with a stride of 16 (A, C) or 6 (B), as in
    // test_16x6, not with the real leading dimension of 1000.
    printf("A=\n");
    for(int r=0;r<16;r++){
        for(int c=0;c<4;c++){
            printf("%f ", A[r+c*16]);
        }
        printf("\n");
    }
    printf("B=\n");
    for(int r=0;r<4;r++){
        for(int c=0;c<6;c++){
            printf("%f ", B[c+r*6]);
        }
        printf("\n");
    }
    printf("C=\n");
    for(int r=0;r<16;r++){
        for(int c=0;c<6;c++){
            printf("%f ", C[r+c*16]);
        }
        printf("\n");
    }

    _mm_free(A);
    _mm_free(B);
    _mm_free(C);
}
/*
 * Straightforward scalar reference for the 96-plane 5x5 convolution with 16
 * output maps over a 28x28 input (24x24 valid outputs).
 *
 *   K  weight for (omap, depth, kr, kc) at K[omap + (depth*25 + kr*5 + kc)*16]
 *   D  input sample (depth, row, col) at D[depth*784 + row*28 + col]
 *   O  result (omap, outr, outc) written to O[omap + (outr*24 + outc)*16];
 *      every element is overwritten (nothing is accumulated from O).
 *
 * The products are summed in the order kr, kc, depth (depth innermost).
 */
void cstore16x5x5x96_rstore28x28x96_baseline(
const float * const K, const float * const D, float * O){
    const int out_rows = (28-5+1);
    const int out_cols = (28-5+1);

    for(int map=0; map<16; ++map){
        const float * const map_weights = K + map;
        for(int row=0; row<out_rows; ++row){
            for(int col=0; col<out_cols; ++col){
                // Top-left corner of the 5x5 window in plane 0.
                const float * const window = D + row*28 + col;
                float acc = 0.0;
                for(int dy=0; dy<5; ++dy){
                    for(int dx=0; dx<5; ++dx){
                        for(int z=0; z<96; ++z){
                            const float sample = window[z*28*28 + dy*28 + dx];
                            const float weight = map_weights[(z*5*5 + dy*5 + dx)*16];
                            acc += sample*weight;
                        }
                    }
                }
                O[map + (row*out_rows + col)*16] = acc;
            }
        }
    }
}
/*
 * Benchmark entry point: allocates 32-byte aligned weights K (16*5*5*96),
 * input D (28*28*96) and output O (16*24*24), fills K and D with i % 7, and
 * runs the timed cstore16x5x5x96_rstore28x28x96_CONV driver once.
 *
 * Changes over the previous version:
 *  - O is zeroed before the run. The CONV kernel loads O and adds into it, so
 *    starting from uninitialized memory produced garbage results.
 *  - Allocation results are checked and the buffers are freed.
 *
 * Returns 0 on success, 1 if an allocation fails.
 *
 * Other entry points in this file that can be called from here instead:
 * test_16x6(), test_blas_speed_16x6(), test_throughput(), and the scalar
 * reference cstore16x5x5x96_rstore28x28x96_baseline(K, D, O).
 */
int main(int argc, char** argv){
    (void)argc;
    (void)argv;

    openblas_set_num_threads(1);

    const int k_len = 16*5*5*96;
    const int d_len = 28*28*96;
    const int o_len = 16*24*24;

    float * const K = (float*) _mm_malloc(k_len*sizeof(float), 32);
    float * const D = (float*) _mm_malloc(d_len*sizeof(float), 32);
    float * const O = (float*) _mm_malloc(o_len*sizeof(float), 32);
    if(K == NULL || D == NULL || O == NULL){
        fprintf(stderr, "main: allocation failed\n");
        if(K != NULL) _mm_free(K);
        if(D != NULL) _mm_free(D);
        if(O != NULL) _mm_free(O);
        return 1;
    }

    for(int i=0;i<k_len;i++){
        K[i] = i%7;
    }
    for(int i=0;i<d_len;i++){
        D[i] = i%7;
    }
    // The kernel accumulates into O, so it must start from a defined value.
    for(int i=0;i<o_len;i++){
        O[i] = 0.0f;
    }

    cstore16x5x5x96_rstore28x28x96_CONV(K, D, O);

    _mm_free(K);
    _mm_free(D);
    _mm_free(O);
    return 0;
}
// End of gist. (GitHub page footer text that followed the code was removed; it was not part of the program.)