Created
April 15, 2015 22:22
-
-
Save zhangce/3cb47f9b5e0d401fd3f0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #include <stdio.h> | |
| #include <stdlib.h> | |
| #include <x86intrin.h> | |
| #include "timer.h" | |
| #include "cblas.h" | |
#define BLOCKSIZE 256

/*
 * cstore16x1_rstore1x6_gemm(A_c, B_r)
 *
 * One rank-1 step of a 16x6 register tile:
 *   for j in 0..5:  C_c{j+1} += A_c[0..15] * B_r[j]
 * where each C_c{j} is a 16-float column split into C_c{j}_first8 (lanes 0-7)
 * and C_c{j}_second8 (lanes 8-15).
 *
 * A_c : pointer to 16 consecutive floats; must be 32-byte aligned because it
 *       is read with _mm256_load_ps. The 16 floats after it (A_c+16, A_c+24)
 *       are prefetched.
 * B_r : pointer to 6 consecutive floats; each one is broadcast to 8 lanes.
 *
 * Not self-contained: the expanding scope must declare the __m256 variables
 * A_ci_first8, A_ci_second8, B_ri_num1, B_ri_num2 and the twelve
 * accumulators C_c1_first8 .. C_c6_second8.
 *
 * Arguments are parenthesized and the body is wrapped in do { } while (0) so
 * the macro behaves as one statement (safe under an unbraced if/for).
 * Requires AVX + FMA.
 */
#define cstore16x1_rstore1x6_gemm(A_c, B_r) \
    do { \
        A_ci_first8  = _mm256_load_ps((A_c)); \
        A_ci_second8 = _mm256_load_ps((A_c) + 8); \
        __builtin_prefetch((A_c) + 16, 0); \
        __builtin_prefetch((A_c) + 24, 0); \
        B_ri_num1 = _mm256_broadcast_ss((B_r)); \
        B_ri_num2 = _mm256_broadcast_ss((B_r) + 1); \
        C_c1_first8  = _mm256_fmadd_ps(A_ci_first8,  B_ri_num1, C_c1_first8); \
        C_c1_second8 = _mm256_fmadd_ps(A_ci_second8, B_ri_num1, C_c1_second8); \
        C_c2_first8  = _mm256_fmadd_ps(A_ci_first8,  B_ri_num2, C_c2_first8); \
        C_c2_second8 = _mm256_fmadd_ps(A_ci_second8, B_ri_num2, C_c2_second8); \
        B_ri_num1 = _mm256_broadcast_ss((B_r) + 2); \
        B_ri_num2 = _mm256_broadcast_ss((B_r) + 3); \
        C_c3_first8  = _mm256_fmadd_ps(A_ci_first8,  B_ri_num1, C_c3_first8); \
        C_c3_second8 = _mm256_fmadd_ps(A_ci_second8, B_ri_num1, C_c3_second8); \
        C_c4_first8  = _mm256_fmadd_ps(A_ci_first8,  B_ri_num2, C_c4_first8); \
        C_c4_second8 = _mm256_fmadd_ps(A_ci_second8, B_ri_num2, C_c4_second8); \
        B_ri_num1 = _mm256_broadcast_ss((B_r) + 4); \
        B_ri_num2 = _mm256_broadcast_ss((B_r) + 5); \
        C_c5_first8  = _mm256_fmadd_ps(A_ci_first8,  B_ri_num1, C_c5_first8); \
        C_c5_second8 = _mm256_fmadd_ps(A_ci_second8, B_ri_num1, C_c5_second8); \
        C_c6_first8  = _mm256_fmadd_ps(A_ci_first8,  B_ri_num2, C_c6_first8); \
        C_c6_second8 = _mm256_fmadd_ps(A_ci_second8, B_ri_num2, C_c6_second8); \
    } while (0)
| void cstore16x5x5x1_rstore28x28x1_chunk_within_depth( | |
| const float * const K, const float * const D, float * O, | |
| const int chunk){ | |
| const int ir = chunk/24; | |
| const int ic = chunk%24; | |
| __m256 num1; | |
| __m256 num2; | |
| __m256 K_c_first8, K_c_second8; | |
| __m256 C_c1_first8 = _mm256_load_ps(O) ; __m256 C_c1_second8 = _mm256_load_ps(O+8) ; | |
| __m256 C_c2_first8 = _mm256_load_ps(O + 16) ; __m256 C_c2_second8 = _mm256_load_ps(O+24) ; | |
| __m256 C_c3_first8 = _mm256_load_ps(O + 32) ; __m256 C_c3_second8 = _mm256_load_ps(O+40) ; | |
| __m256 C_c4_first8 = _mm256_load_ps(O + 48) ; __m256 C_c4_second8 = _mm256_load_ps(O+56) ; | |
| __m256 C_c5_first8 = _mm256_load_ps(O + 64) ; __m256 C_c5_second8 = _mm256_load_ps(O+72) ; | |
| __m256 C_c6_first8 = _mm256_load_ps(O + 80) ; __m256 C_c6_second8 = _mm256_load_ps(O+88) ; | |
| const float * K_c = K; | |
| for(int kr=0;kr<5;kr++){ | |
| for(int kc=0;kc<5;kc++){ | |
| K_c_first8 = _mm256_load_ps(K_c); | |
| K_c_second8 = _mm256_load_ps(K_c+8); | |
| const float * _d = &D[ic+kc + (ir+kr)*28]; | |
| num1 = _mm256_broadcast_ss(_d); | |
| num2 = _mm256_broadcast_ss(_d+1); | |
| C_c1_first8 = _mm256_fmadd_ps(K_c_first8, num1, C_c1_first8); | |
| C_c1_second8 = _mm256_fmadd_ps(K_c_second8, num1, C_c1_second8); | |
| C_c2_first8 = _mm256_fmadd_ps(K_c_first8, num2, C_c2_first8); | |
| C_c2_second8 = _mm256_fmadd_ps(K_c_second8, num2, C_c2_second8); | |
| num1 = _mm256_broadcast_ss(_d+2); | |
| num2 = _mm256_broadcast_ss(_d+3); | |
| C_c3_first8 = _mm256_fmadd_ps(K_c_first8, num1, C_c3_first8); | |
| C_c3_second8 = _mm256_fmadd_ps(K_c_second8, num1, C_c3_second8); | |
| C_c4_first8 = _mm256_fmadd_ps(K_c_first8, num2, C_c4_first8); | |
| C_c4_second8 = _mm256_fmadd_ps(K_c_second8, num2, C_c4_second8); | |
| num1 = _mm256_broadcast_ss(_d+4); | |
| num2 = _mm256_broadcast_ss(_d+5); | |
| C_c5_first8 = _mm256_fmadd_ps(K_c_first8, num1, C_c5_first8); | |
| C_c5_second8 = _mm256_fmadd_ps(K_c_second8, num1, C_c5_second8); | |
| C_c6_first8 = _mm256_fmadd_ps(K_c_first8, num2, C_c6_first8); | |
| C_c6_second8 = _mm256_fmadd_ps(K_c_second8, num2, C_c6_second8); | |
| K_c += 16; // next column of the kernel | |
| } | |
| } | |
| _mm256_store_ps(O, C_c1_first8); | |
| _mm256_store_ps(O+8, C_c1_second8); | |
| _mm256_store_ps(O+16, C_c2_first8); | |
| _mm256_store_ps(O+24, C_c2_second8); | |
| _mm256_store_ps(O+32, C_c3_first8); | |
| _mm256_store_ps(O+40, C_c3_second8); | |
| _mm256_store_ps(O+48, C_c4_first8); | |
| _mm256_store_ps(O+56, C_c4_second8); | |
| _mm256_store_ps(O+64, C_c5_first8); | |
| _mm256_store_ps(O+72, C_c5_second8); | |
| _mm256_store_ps(O+80, C_c6_first8); | |
| _mm256_store_ps(O+88, C_c6_second8); | |
| } | |
| void cstore16x5x5x96_rstore28x28x96_CONV_within_depth( | |
| const float * const K, const float * const D, float * O){ | |
| const int oR = (28-5+1); | |
| Timer t; | |
| // D is blocked by two things -- feature map, and every 6 output position, | |
| // First, get the pointer for each feature map | |
| for(int d=0;d<96;d++){ | |
| const float * const D_ = &D[d*28*28]; | |
| const float * const K_ = &K[d*5*5*16]; | |
| for(int i=0;i<oR*oR;i+=6){ | |
| float * const O_ = &O[i*16]; | |
| cstore16x5x5x1_rstore28x28x1_chunk_within_depth(K_, D_, O_, i); | |
| } | |
| } | |
| float elapsed = t.elapsed(); | |
| printf("time: %f seconds\n", elapsed); | |
| printf("time for 256x256: %f seconds\n", elapsed * 256 * (256/16)); | |
| float flop = 16*oR*oR*5*5*96*2; | |
| float gflops = flop/elapsed/1024/1024/1024; | |
| printf("flops: %f \n", flop); | |
| printf("GFlops: %f GFLOPS\n", gflops); | |
| } | |
| void cstore16x5x5x1_rstore28x28x1_chunk( | |
| const float * const K, const float * const D, float * O, | |
| const int chunk){ | |
| const int ir = chunk/24; | |
| const int ic = chunk%24; | |
| __m256 num1; | |
| __m256 num2; | |
| __m256 K_c_first8, K_c_second8; | |
| __m256 C_c1_first8 = _mm256_load_ps(O) ; __m256 C_c1_second8 = _mm256_load_ps(O+8) ; | |
| __m256 C_c2_first8 = _mm256_load_ps(O + 16) ; __m256 C_c2_second8 = _mm256_load_ps(O+24) ; | |
| __m256 C_c3_first8 = _mm256_load_ps(O + 32) ; __m256 C_c3_second8 = _mm256_load_ps(O+40) ; | |
| __m256 C_c4_first8 = _mm256_load_ps(O + 48) ; __m256 C_c4_second8 = _mm256_load_ps(O+56) ; | |
| __m256 C_c5_first8 = _mm256_load_ps(O + 64) ; __m256 C_c5_second8 = _mm256_load_ps(O+72) ; | |
| __m256 C_c6_first8 = _mm256_load_ps(O + 80) ; __m256 C_c6_second8 = _mm256_load_ps(O+88) ; | |
| const float * K_c = K; | |
| for(int depth=0;depth<96;depth++){ | |
| for(int kr=0;kr<5;kr++){ | |
| for(int kc=0;kc<5;kc++){ | |
| K_c_first8 = _mm256_load_ps(K_c); | |
| K_c_second8 = _mm256_load_ps(K_c+8); | |
| __builtin_prefetch(K_c+16, 0); | |
| __builtin_prefetch(K_c+24, 0); | |
| const float * _d = &D[ic+kc + (ir+kr)*28 + depth*28*28]; | |
| num1 = _mm256_broadcast_ss(_d); | |
| num2 = _mm256_broadcast_ss(_d+1); | |
| C_c1_first8 = _mm256_fmadd_ps(K_c_first8, num1, C_c1_first8); | |
| C_c1_second8 = _mm256_fmadd_ps(K_c_second8, num1, C_c1_second8); | |
| C_c2_first8 = _mm256_fmadd_ps(K_c_first8, num2, C_c2_first8); | |
| C_c2_second8 = _mm256_fmadd_ps(K_c_second8, num2, C_c2_second8); | |
| num1 = _mm256_broadcast_ss(_d+2); | |
| num2 = _mm256_broadcast_ss(_d+3); | |
| C_c3_first8 = _mm256_fmadd_ps(K_c_first8, num1, C_c3_first8); | |
| C_c3_second8 = _mm256_fmadd_ps(K_c_second8, num1, C_c3_second8); | |
| C_c4_first8 = _mm256_fmadd_ps(K_c_first8, num2, C_c4_first8); | |
| C_c4_second8 = _mm256_fmadd_ps(K_c_second8, num2, C_c4_second8); | |
| num1 = _mm256_broadcast_ss(_d+4); | |
| num2 = _mm256_broadcast_ss(_d+5); | |
| C_c5_first8 = _mm256_fmadd_ps(K_c_first8, num1, C_c5_first8); | |
| C_c5_second8 = _mm256_fmadd_ps(K_c_second8, num1, C_c5_second8); | |
| C_c6_first8 = _mm256_fmadd_ps(K_c_first8, num2, C_c6_first8); | |
| C_c6_second8 = _mm256_fmadd_ps(K_c_second8, num2, C_c6_second8); | |
| K_c += 16; // next column of the kernel | |
| } | |
| } | |
| } | |
| _mm256_store_ps(O, C_c1_first8); | |
| _mm256_store_ps(O+8, C_c1_second8); | |
| _mm256_store_ps(O+16, C_c2_first8); | |
| _mm256_store_ps(O+24, C_c2_second8); | |
| _mm256_store_ps(O+32, C_c3_first8); | |
| _mm256_store_ps(O+40, C_c3_second8); | |
| _mm256_store_ps(O+48, C_c4_first8); | |
| _mm256_store_ps(O+56, C_c4_second8); | |
| _mm256_store_ps(O+64, C_c5_first8); | |
| _mm256_store_ps(O+72, C_c5_second8); | |
| _mm256_store_ps(O+80, C_c6_first8); | |
| _mm256_store_ps(O+88, C_c6_second8); | |
| } | |
| void cstore16x5x5x96_rstore28x28x96_CONV( | |
| const float * const K, const float * const D, float * O){ | |
| const int oR = (28-5+1); | |
| Timer t; | |
| // D is blocked by two things -- feature map, and every 6 output position, | |
| // However, it turns out the we do not need to block by feature map | |
| // compare with this function with the variant cstore16x5x5x96_rstore28x28x96_CONV_within_depth | |
| const float * const D_ = D; | |
| const float * const K_ = K; | |
| for(int i=0;i<oR*oR;i+=6){ | |
| float * const O_ = &O[i*16]; | |
| cstore16x5x5x1_rstore28x28x1_chunk(K_, D_, O_, i); | |
| } | |
| float elapsed = t.elapsed(); | |
| printf("time: %f seconds\n", elapsed); | |
| printf("time for 256x256: %f seconds\n", elapsed * 256 * (256/16)); | |
| float flop = 16*oR*oR*5*5*96*2; | |
| float gflops = flop/elapsed/1024/1024/1024; | |
| printf("flops: %f \n", flop); | |
| printf("GFlops: %f GFLOPS\n", gflops); | |
| float io = 1.0 * (16*5*5*96+28*28*96+oR*oR*16) * sizeof(float); | |
| float gb = 1.0* io / elapsed / 1024/1024/1024; | |
| printf("GB: %f GB/s\n", gb); | |
| } | |
/*
 * Placeholder for a generic-depth 5x5 convolution over 28x28 inputs.
 * Not implemented: it ignores every argument and leaves O untouched.
 */
void cstore5x5_rstore28x28_CONV(
    const float * const K, const float * const D, float * O, const int depth, const int oD){
    // Explicitly mark all parameters unused until an implementation lands.
    (void)K;
    (void)D;
    (void)O;
    (void)depth;
    (void)oD;
}
| inline void cstore16x4_rstore4x6_gemm( | |
| const float * const A, | |
| const float * const B, | |
| float * const C){ | |
| __m256 C_c1_first8 ; __m256 C_c1_second8 ; | |
| __m256 C_c2_first8 ; __m256 C_c2_second8 ; | |
| __m256 C_c3_first8 ; __m256 C_c3_second8 ; | |
| __m256 C_c4_first8 ; __m256 C_c4_second8 ; | |
| __m256 C_c5_first8 ; __m256 C_c5_second8 ; | |
| __m256 C_c6_first8 ; __m256 C_c6_second8 ; | |
| __m256 A_ci_first8; | |
| __m256 A_ci_second8; | |
| __m256 B_ri_num1; | |
| __m256 B_ri_num2; | |
| for(int i=0;i<BLOCKSIZE;i++){ | |
| cstore16x1_rstore1x6_gemm(A+i*16, B+i*6); | |
| } | |
| _mm256_store_ps(C, C_c1_first8); | |
| _mm256_store_ps(C+8, C_c1_second8); | |
| _mm256_store_ps(C+16, C_c2_first8); | |
| _mm256_store_ps(C+24, C_c2_second8); | |
| _mm256_store_ps(C+32, C_c3_first8); | |
| _mm256_store_ps(C+40, C_c3_second8); | |
| _mm256_store_ps(C+48, C_c4_first8); | |
| _mm256_store_ps(C+56, C_c4_second8); | |
| _mm256_store_ps(C+64, C_c5_first8); | |
| _mm256_store_ps(C+72, C_c5_second8); | |
| _mm256_store_ps(C+80, C_c6_first8); | |
| _mm256_store_ps(C+88, C_c6_second8); | |
| } | |
| void test_throughput(){ | |
| const int N_ABLOCK = 1; | |
| const int N_BBLOCK = N_ABLOCK; | |
| const int N_CBLOCK = N_ABLOCK; | |
| float * const _A = (float*) _mm_malloc(16*BLOCKSIZE*sizeof(float) * N_ABLOCK, 32); | |
| float * const _B = (float*) _mm_malloc(BLOCKSIZE*6*sizeof(float) * N_BBLOCK, 32); | |
| float * const _C = (float*) _mm_malloc(16*6*sizeof(float) * N_CBLOCK, 32); | |
| for(int i=0;i<16*BLOCKSIZE;i++){ | |
| _A[i] = i % 100; | |
| } | |
| for(int i=0;i<BLOCKSIZE*6;i++){ | |
| _B[i] = i * 100; | |
| } | |
| Timer t; | |
| float * A = _A; | |
| float * B = _B; | |
| float * C = _C; | |
| const int inc_a = 1; | |
| const int inc_b = 1; | |
| const int inc_c = 1; | |
| for(int i=0;i<N_ABLOCK;i++){ | |
| //cstore16x4_rstore4x6_gemm( | |
| // A, A+16, A+32, A+48, | |
| // B, B+6, A+12, B+18, | |
| // C, C+16, C+32, C+48, C+64, C+80); | |
| cstore16x4_rstore4x6_gemm(A, B, C); | |
| //cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, 16, 6, BLOCKSIZE, | |
| // 1.0, A, 16, B, 6, 1.0, C, 16); | |
| //cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, 16, 6, BLOCKSIZE, | |
| // 1.0, A, 16, B, BLOCKSIZE, 1.0, C, 16); | |
| A += inc_a * 16*4; | |
| B += inc_b * 4*6; | |
| C += inc_c * 16*6; | |
| } | |
| double elapsed = t.elapsed(); | |
| printf("Elapsed: %f\n seconds\n", elapsed); | |
| float flop = 1.0* N_ABLOCK * 16 * BLOCKSIZE * 6 * 2; | |
| float gflops = flop/elapsed/1024/1024/1024; | |
| printf("GFlops: %f GFLOPS\n", gflops); | |
| float io = 1.0 * N_ABLOCK * (16*BLOCKSIZE*inc_a + BLOCKSIZE*6*inc_b + 16*6*inc_c) * sizeof(float); | |
| float gb = 1.0* io / elapsed / 1024/1024/1024; | |
| printf("GB: %f GB/s\n", gb); | |
| printf("C=\n"); | |
| for(int r=0;r<16;r++){ | |
| for(int c=0;c<6;c++){ | |
| printf("%f ", _C[r+c*16]); | |
| } | |
| printf("\n"); | |
| } | |
| } | |
/*
 * Debug helper for the 16x6 tile layout. It builds a 16x4 column-major A
 * (A[r + c*16] = r + c*16), a B printed as 4x6 via B[c + r*6] (B[i] = i) and
 * a 16x6 column-major C, then prints all three. No product is currently
 * computed (both candidate calls are disabled), so C prints as zeros.
 */
void test_16x6(){
    float * const A = (float*) _mm_malloc(16*4*sizeof(float), 32);
    float * const B = (float*) _mm_malloc(4*6*sizeof(float), 32);
    float * const C = (float*) _mm_malloc(16*6*sizeof(float), 32);
    if (A == NULL || B == NULL || C == NULL) {
        fprintf(stderr, "test_16x6: allocation failed\n");
        _mm_free(A);
        _mm_free(B);
        _mm_free(C);
        return;
    }
    for (int i = 0; i < 16*4; i++) {
        A[i] = i;
    }
    for (int i = 0; i < 4*6; i++) {
        B[i] = i;
    }
    // C must be initialized: it is printed below, and the cblas variant
    // accumulates into it (beta = 1.0).
    for (int i = 0; i < 16*6; i++) {
        C[i] = 0.0f;
    }

    // Disabled: cstore16x4_rstore4x6_gemm iterates BLOCKSIZE rank-1 steps, so
    // it would read past these 4-column A/B buffers.
    //cstore16x4_rstore4x6_gemm(A, B, C);
    //cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, 16, 6, 4,
    //        1.0, A, 16, B, 6, 1.0, C, 16);

    printf("A=\n");
    for (int r = 0; r < 16; r++) {
        for (int c = 0; c < 4; c++) {
            printf("%f ", A[r+c*16]);
        }
        printf("\n");
    }
    printf("B=\n");
    for (int r = 0; r < 4; r++) {
        for (int c = 0; c < 6; c++) {
            printf("%f ", B[c+r*6]);
        }
        printf("\n");
    }
    printf("C=\n");
    for (int r = 0; r < 16; r++) {
        for (int c = 0; c < 6; c++) {
            printf("%f ", C[r+c*16]);
        }
        printf("\n");
    }

    _mm_free(A);
    _mm_free(B);
    _mm_free(C);
}
| void test_blas_speed_16x6(){ | |
| const int M = 1000; | |
| const int N = M; | |
| const int K = M; | |
| float * const A = (float*) _mm_malloc(M*N*sizeof(float), 32); | |
| for(int i=0;i<16*4;i++){ | |
| A[i] = i; | |
| } | |
| float * const B = (float*) _mm_malloc(N*K*sizeof(float), 32); | |
| for(int i=0;i<4*6;i++){ | |
| B[i] = i; | |
| } | |
| float * const C = (float*) _mm_malloc(N*K*sizeof(float), 32); | |
| Timer t; | |
| cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, M, N, K, | |
| 1.0, A, M, B, M, 1.0, C, M); | |
| float elapsed = t.elapsed(); | |
| printf("Elapsed: %f\n seconds\n", elapsed); | |
| float flop = 1.0* M * N * K * 2; | |
| float gflops = flop/elapsed/1024/1024/1024; | |
| printf("GFlops: %f GFLOPS\n", gflops); | |
| printf("A=\n"); | |
| for(int r=0;r<16;r++){ | |
| for(int c=0;c<4;c++){ | |
| printf("%f ", A[r+c*16]); | |
| } | |
| printf("\n"); | |
| } | |
| printf("B=\n"); | |
| for(int r=0;r<4;r++){ | |
| for(int c=0;c<6;c++){ | |
| printf("%f ", B[c+r*6]); | |
| } | |
| printf("\n"); | |
| } | |
| printf("C=\n"); | |
| for(int r=0;r<16;r++){ | |
| for(int c=0;c<6;c++){ | |
| printf("%f ", C[r+c*16]); | |
| } | |
| printf("\n"); | |
| } | |
| /* | |
| for(int i=0;i<16*6;i++){ | |
| printf("%f ", C[i]); | |
| } | |
| printf("\n"); | |
| */ | |
| } | |
/*
 * Scalar reference: 16-map 5x5 valid convolution of a 28x28x96 input.
 *
 * K : weight for map m, plane ch, tap (dy, dx) at K[m + (ch*25 + dy*5 + dx)*16].
 * D : 96 row-major 28x28 planes, pixel (ch, y, x) at D[ch*784 + y*28 + x].
 * O : output map m at position (y, x) written to O[m + (y*24 + x)*16].
 *     Every entry is overwritten (no accumulation into prior contents).
 *
 * Summation runs over dy, then dx, then ch innermost, with one float
 * accumulator per output.
 */
void cstore16x5x5x96_rstore28x28x96_baseline(
    const float * const K, const float * const D, float * O){
    const int out_rows = (28-5+1);
    const int out_cols = (28-5+1);
    for (int m = 0; m < 16; ++m) {
        for (int y = 0; y < out_rows; ++y) {
            for (int x = 0; x < out_cols; ++x) {
                float acc = 0.0;
                for (int dy = 0; dy < 5; ++dy) {
                    for (int dx = 0; dx < 5; ++dx) {
                        // Same (y+dy, x+dx) pixel / (m, dy, dx) weight in plane 0;
                        // consecutive planes are 28*28 and 5*5*16 floats apart.
                        const float * const pix = D + (y + dy)*28 + (x + dx);
                        const float * const w = K + m + (dy*5 + dx)*16;
                        for (int ch = 0; ch < 96; ++ch) {
                            acc += pix[ch*28*28] * w[ch*5*5*16];
                        }
                    }
                }
                O[m + (y*out_cols + x)*16] = acc;
            }
        }
    }
}
| int main(int argc, char** argv){ | |
| //cstore16x4_rstore4x6_gemm(NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, | |
| // NULL, NULL, NULL, NULL, NULL, NULL); | |
| openblas_set_num_threads(1); | |
| //test_16x6(); | |
| //test_blas_speed_16x6(); | |
| //test_throughput(); | |
| float * const K = (float*) _mm_malloc(16*5*5*96*sizeof(float), 32); | |
| for(int i=0;i<16*5*5*96;i++){ | |
| K[i] = i%7; | |
| } | |
| float * const D = (float*) _mm_malloc(28*28*96*sizeof(float), 32); | |
| for(int i=0;i<28*28*96;i++){ | |
| D[i] = i%7; | |
| } | |
| float * const O = (float*) _mm_malloc(16*24*24*sizeof(float), 32); | |
| cstore16x5x5x96_rstore28x28x96_CONV(K, D, O); | |
| //cstore16x5x5x96_rstore28x28x96_baseline(K,D,O); | |
| /* | |
| for(int i=0;i<16;i++){ | |
| for(int j=0;j<24*24;j++){ | |
| printf("%f ", O[i*24*24+j]); | |
| } | |
| printf("\n"); | |
| } | |
| */ | |
| return 0; | |
| } | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment