benoitjacob@desk:~$ cc -DM0=8 -DK0=4 -DN0=8 -DMMT4D_VARIANT_GENERIC mmt4dkernels.c -o /tmp/mmt4dkernels && /tmp/mmt4dkernels
called kernel: mmt4d_kernel_generic_8x4x8
57 30 63 45 27 45 33 42 
60 27 60 51 24 39 33 60 
60 24 66 45 36 42 36 33 
54 30 54 51 24 60 18 48 
69 45 66 57 18 54 39 48 
33 18 42 39 18 30 12 36 
72 39 63 63 30 66 39 51 
63 36 57 66 36 72 39 51 
benoitjacob@desk:~$ ~/android-ndk-r21d/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android30-clang -march=armv8.2-a+dotprod -DM0=8 -DK0=4 -DN0=8 -DMMT4D_VARIANT_NEON_DOTPROD mmt4dkernels.c -o /tmp/mmt4dkernels && adb push /tmp/mmt4dkernels /data/local/tmp && adb shell /data/local/tmp/mmt4dkernels
/tmp/mmt4dkernels: 1 file pushed, 0 skipped. 47.5 MB/s (12472 bytes in 0.000s)
called kernel: mmt4d_kernel_neon_dotprod_8x4x8
57 30 63 45 27 45 33 42 
60 27 60 51 24 39 33 60 
60 24 66 45 36 42 36 33 
54 30 54 51 24 60 18 48 
69 45 66 57 18 54 39 48 
33 18 42 39 18 30 12 36 
72 39 63 63 30 66 39 51 
63 36 57 66 36 72 39 51 
          Last active
          November 6, 2021 01:50 
        
      - 
      
 - 
        
Save bjacob/2965a199e83ca0171e7812f7acdd62e1 to your computer and use it in GitHub Desktop.  
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | #ifdef MMT4D_VARIANT_GENERIC | |
| #define VARIANT generic | |
| #elif defined MMT4D_VARIANT_NEON_DOTPROD | |
| #define VARIANT neon_dotprod | |
| #else | |
| #error need to pick a variant | |
| #endif | |
// SYMBOL_NAME expands to the kernel's function name, built from the variant
// and the tile sizes: e.g. mmt4d_kernel_generic_8x4x8 when VARIANT is generic
// and M0, K0, N0 are 8, 4, 8.
//
// Two helper layers are required because ## pastes its operands without
// expanding them first: SYMBOL_NAME_EXPAND exists only so that VARIANT, M0, K0
// and N0 are replaced by their values before SYMBOL_NAME_PASTE glues the
// pieces together.
//
// The helper and parameter names avoid a leading underscore followed by an
// uppercase letter or a second underscore, since such identifiers are
// reserved for the implementation.
#define SYMBOL_NAME_PASTE(variant, m0, k0, n0) \
  mmt4d_kernel_##variant##_##m0##x##k0##x##n0
#define SYMBOL_NAME_EXPAND(variant, m0, k0, n0) \
  SYMBOL_NAME_PASTE(variant, m0, k0, n0)
#define SYMBOL_NAME SYMBOL_NAME_EXPAND(VARIANT, M0, K0, N0)
| #include <stdint.h> | |
| #include <string.h> | |
| #ifdef __aarch64__ | |
| void impl_neon_dotprod_kernel_8x4x8(int k_size, const int8_t* lhs, | |
| const int8_t* rhs, int32_t* dst); | |
| #endif | |
| void SYMBOL_NAME(int k_size, const int8_t* lhs, const int8_t* rhs, | |
| int32_t* dst) { | |
| #if (defined MMT4D_VARIANT_NEON_DOTPROD) && M0 == 8 && K0 == 4 && N0 == 8 | |
| impl_neon_dotprod_kernel_8x4x8(k_size, lhs, rhs, dst); | |
| #else | |
| int32_t acc[M0 * N0] = {0}; | |
| for (int k = 0; k < k_size; k += K0) { | |
| for (int m0 = 0; m0 < M0; m0++) { | |
| for (int n0 = 0; n0 < N0; n0++) { | |
| int32_t a = 0; | |
| for (int k0 = 0; k0 < K0; k0++) { | |
| a += lhs[m0 * K0 + k0] * rhs[n0 * K0 + k0]; | |
| } | |
| acc[m0 * N0 + n0] += a; | |
| } | |
| } | |
| lhs += M0 * K0; | |
| rhs += N0 * K0; | |
| } | |
| memcpy(dst, acc, M0 * N0 * sizeof(*dst)); | |
| #endif | |
| } | |
| #ifdef __aarch64__ | |
| #include <arm_neon.h> | |
// NEON dot-product implementation of the 8x4x8 kernel: accumulates an 8x8
// tile of int32 results from int8 operands and stores it to dst.
//
//   k_size  reduction depth; consumed 4 elements per loop iteration (a
//           k_size that is not a multiple of 4 still runs a full final step).
//   lhs     32 int8 values are read per iteration.
//   rhs     32 int8 values are read per iteration.
//   dst     receives 64 int32 values.
//
// NOTE(review): the comments below assume each operand block is 8 rows of 4
// int8 values stored back to back, i.e. the layout the portable kernel indexes
// as lhs[m0 * K0 + k0] / rhs[n0 * K0 + k0] with K0 == 4 -- confirm against
// the caller.
void impl_neon_dotprod_kernel_8x4x8(int k_size, const int8_t* lhs,
                                    const int8_t* rhs, int32_t* dst) {
  // Sixteen 4-lane accumulators cover the 64 outputs. Judging by the stores at
  // the end, acc(2r) holds dst row r columns 0-3 and acc(2r+1) holds dst row r
  // columns 4-7.
  int32x4_t acc0 = vdupq_n_s32(0);
  int32x4_t acc1 = vdupq_n_s32(0);
  int32x4_t acc2 = vdupq_n_s32(0);
  int32x4_t acc3 = vdupq_n_s32(0);
  int32x4_t acc4 = vdupq_n_s32(0);
  int32x4_t acc5 = vdupq_n_s32(0);
  int32x4_t acc6 = vdupq_n_s32(0);
  int32x4_t acc7 = vdupq_n_s32(0);
  int32x4_t acc8 = vdupq_n_s32(0);
  int32x4_t acc9 = vdupq_n_s32(0);
  int32x4_t acc10 = vdupq_n_s32(0);
  int32x4_t acc11 = vdupq_n_s32(0);
  int32x4_t acc12 = vdupq_n_s32(0);
  int32x4_t acc13 = vdupq_n_s32(0);
  int32x4_t acc14 = vdupq_n_s32(0);
  int32x4_t acc15 = vdupq_n_s32(0);
  for (int k = 0; k < k_size; k += 4) {
    // Each 16-byte load picks up four 4-byte groups: lhs0/rhs0 are groups 0-3
    // of the current block and lhs4/rhs4 are groups 4-7.
    int8x16_t lhs0 = vld1q_s8(lhs + 0);
    int8x16_t lhs4 = vld1q_s8(lhs + 16);
    int8x16_t rhs0 = vld1q_s8(rhs + 0);
    int8x16_t rhs4 = vld1q_s8(rhs + 16);
    // vdotq_lane_s32(acc, a, b, lane) adds to each 32-bit lane i of acc the
    // dot product of a's i-th 4-byte group with the 4-byte group of b chosen
    // by `lane`. Here `a` is an rhs vector (four rhs groups at once) and the
    // lane selects a single lhs group: low half lanes 0/1 are groups 0/1 of
    // the loaded lhs vector, high half lanes 0/1 are groups 2/3.
    acc0 = vdotq_lane_s32(acc0, rhs0, vget_low_s8(lhs0), 0);
    acc1 = vdotq_lane_s32(acc1, rhs4, vget_low_s8(lhs0), 0);
    acc2 = vdotq_lane_s32(acc2, rhs0, vget_low_s8(lhs0), 1);
    acc3 = vdotq_lane_s32(acc3, rhs4, vget_low_s8(lhs0), 1);
    acc4 = vdotq_lane_s32(acc4, rhs0, vget_high_s8(lhs0), 0);
    acc5 = vdotq_lane_s32(acc5, rhs4, vget_high_s8(lhs0), 0);
    acc6 = vdotq_lane_s32(acc6, rhs0, vget_high_s8(lhs0), 1);
    acc7 = vdotq_lane_s32(acc7, rhs4, vget_high_s8(lhs0), 1);
    acc8 = vdotq_lane_s32(acc8, rhs0, vget_low_s8(lhs4), 0);
    acc9 = vdotq_lane_s32(acc9, rhs4, vget_low_s8(lhs4), 0);
    acc10 = vdotq_lane_s32(acc10, rhs0, vget_low_s8(lhs4), 1);
    acc11 = vdotq_lane_s32(acc11, rhs4, vget_low_s8(lhs4), 1);
    acc12 = vdotq_lane_s32(acc12, rhs0, vget_high_s8(lhs4), 0);
    acc13 = vdotq_lane_s32(acc13, rhs4, vget_high_s8(lhs4), 0);
    acc14 = vdotq_lane_s32(acc14, rhs0, vget_high_s8(lhs4), 1);
    acc15 = vdotq_lane_s32(acc15, rhs4, vget_high_s8(lhs4), 1);
    // Advance both operands by one 8x4 block (32 bytes).
    lhs += 8 * 4;
    rhs += 8 * 4;
  }
  // Write the accumulators out back to back: 16 stores of 4 int32 each fill
  // dst[0..63].
  vst1q_s32(dst + 0, acc0);
  vst1q_s32(dst + 4, acc1);
  vst1q_s32(dst + 8, acc2);
  vst1q_s32(dst + 12, acc3);
  vst1q_s32(dst + 16, acc4);
  vst1q_s32(dst + 20, acc5);
  vst1q_s32(dst + 24, acc6);
  vst1q_s32(dst + 28, acc7);
  vst1q_s32(dst + 32, acc8);
  vst1q_s32(dst + 36, acc9);
  vst1q_s32(dst + 40, acc10);
  vst1q_s32(dst + 44, acc11);
  vst1q_s32(dst + 48, acc12);
  vst1q_s32(dst + 52, acc13);
  vst1q_s32(dst + 56, acc14);
  vst1q_s32(dst + 60, acc15);
}
| #endif | |
| #include <stdio.h> | |
| #include <stdlib.h> | |
// Deterministic stand-in for a random number generator. Each call advances a
// function-local state with state = (state * 123 + 456) % 321 and returns the
// new state, so every result is in [0, 320] and the sequence is the same on
// every run (it begins 135, 48, 261, 138).
//
// The parameter list is spelled (void) so the definition acts as a prototype
// and a call that passes arguments is rejected by the compiler. Because the
// state is a static local, the function is not reentrant.
uint32_t dummy_random(void) {
  static uint32_t state = 0;
  state = (state * 123 + 456) % 321;
  return state;
}
| int main(int argc, char* argv[]) { | |
| const int k_size = 4 * K0; | |
| int8_t* lhs = malloc(k_size * M0); | |
| for (int i = 0; i < k_size * M0; i++) { | |
| lhs[i] = dummy_random() % 5; | |
| } | |
| int8_t* rhs = malloc(k_size * N0); | |
| for (int i = 0; i < k_size * N0; i++) { | |
| rhs[i] = dummy_random() % 6; | |
| } | |
| int32_t* dst = malloc(M0 * N0 * sizeof(int32_t)); | |
| SYMBOL_NAME(k_size, lhs, rhs, dst); | |
| #define STR(s) STR_UNEXPANDED(s) | |
| #define STR_UNEXPANDED(s) #s | |
| printf("called kernel: " STR(SYMBOL_NAME) "\n"); | |
| for (int m = 0; m < M0; m++) { | |
| for (int n = 0; n < N0; n++) { | |
| printf("%d ", dst[m * N0 + n]); | |
| } | |
| printf("\n"); | |
| } | |
| } | 
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment