Created
December 14, 2025 18:28
-
-
Save syoyo/0e9605b6c7384f4c68f3ea45b1f84ee8 to your computer and use it in GitHub Desktop.
a64fx fp16 gemm with fp32 accumulation experiment
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // A64FX SVE 1.0 High-Performance FP16 GEMM with FP32 Accumulation | |
| // Target: ~160% of FP32 peak (~80% of FP16 peak) | |
| // | |
| // Strategy: | |
| // - Use FP16 FMLA directly for maximum throughput (32 ops/instruction) | |
| // - Accumulate in FP16 for K_INTERVAL iterations | |
| // - Convert to FP32 and add to FP32 accumulators periodically | |
| // - Software pipeline: overlap loads, converts, and FMAs | |
| // | |
| // A64FX FP16 peak: ~264 GFLOPS/core (2x FP32) | |
| // Target: ~211 GFLOPS (80% of peak) | |
| .arch armv8.2-a+sve | |
| .text | |
// ============================================================================
// hgemm_kernel - FP16 x FP16 -> FP32 GEMM driver with inlined micro-kernels
//
// Computes C = A * B (C is overwritten, never read), where
//   A : M x K  FP16, row-major, row stride lda elements
//   B : K x N  FP16, row-major, row stride ldb elements
//   C : M x N  FP32, row-major, row stride ldc elements
//
// C-equivalent prototype (AAPCS64):
//   void hgemm_kernel(int64_t M, int64_t N, int64_t K,
//                     const __fp16 *A, int64_t lda,
//                     const __fp16 *B, int64_t ldb,
//                     float *C, int64_t ldc);
//   x0-x7 carry the first eight arguments; ldc is the ninth argument and is
//   read from the caller's stack.
//
// Tiling:
//   rows    : blocks of 6, then the remaining 0-5 rows one at a time
//   columns : blocks of 32 (FP16 FMLA, folded into FP32 every 8 K steps),
//             then one block of 16 (FP32 FMLA), then < 16 (predicated FP32)
//
// Assumes a 512-bit SVE vector length (32 FP16 / 16 FP32 lanes): the column
// steps of 64 and 128 bytes are hard-coded.
//
// Preserves: x19-x26 and d8-d15 (saved in a 128-byte frame).
// Clobbers : x9-x17, z0-z7, z16-z31, upper bits of z8-z15, p0-p2, flags.
// ============================================================================

x_M             .req    x0
x_N             .req    x1
x_K             .req    x2
x_A             .req    x3
x_lda           .req    x4
x_B             .req    x5
x_ldb           .req    x6
x_C             .req    x7
x_ldc           .req    x19
x_m_cnt         .req    x9
x_n_cnt         .req    x10
x_k_cnt         .req    x11
x_A_row         .req    x12
x_B_col         .req    x13
x_C_row         .req    x14
x_A_ptr         .req    x15
x_B_ptr         .req    x16
x_lda_bytes     .req    x20
x_ldb_bytes     .req    x21
x_ldc_bytes     .req    x22
x_C_ptr         .req    x23

        // Frame: x19-x26 (64 bytes) followed by d8-d15 (64 bytes).
        .equ    HGEMM_FRAME_BYTES, 128

        // Widen the 32 FP16 lanes of z<acc16> to FP32 and add them to the
        // accumulator pair z<lo32> (columns 0-15) / z<hi32> (columns 16-31).
        // Requires p0 = ptrue.s.  Scratch: z26-z29.
        .macro  hgemm_fold16 acc16, lo32, hi32
        fcvt    z26.s, p0/m, z\acc16\().h       // even FP16 lanes -> FP32
        lsr     z27.s, z\acc16\().s, #16        // bring odd lanes into the low halves
        fcvt    z27.s, p0/m, z27.h              // odd FP16 lanes -> FP32
        zip1    z28.s, z26.s, z27.s             // columns 0..15 in order
        zip2    z29.s, z26.s, z27.s             // columns 16..31 in order
        fadd    z\lo32\().s, p0/m, z\lo32\().s, z28.s
        fadd    z\hi32\().s, p0/m, z\hi32\().s, z29.s
        .endm

        // Load one FP16 scalar from [addr], widen it to FP32 and broadcast it
        // to every .s lane of z<reg>.
        .macro  hgemm_ld_a32 reg, addr
        ldr     h\reg, [\addr]
        fcvt    s\reg, h\reg
        mov     z\reg\().s, s\reg
        .endm

        // One K step of the 6-row FP32 path under predicate <pg>:
        // z0-z5 += A[row 0..5][k] * B[k][columns selected by pg].
        // Advances x_A_ptr / x_B_ptr and decrements x_k_cnt.  Scratch: x26.
        .macro  hgemm_fp32_step6 pg
        hgemm_ld_a32 18, x_A_ptr
        add     x26, x_A_ptr, x_lda_bytes
        hgemm_ld_a32 19, x26
        add     x26, x26, x_lda_bytes
        hgemm_ld_a32 20, x26
        add     x26, x26, x_lda_bytes
        hgemm_ld_a32 21, x26
        add     x26, x26, x_lda_bytes
        hgemm_ld_a32 22, x26
        add     x26, x26, x_lda_bytes
        hgemm_ld_a32 23, x26
        ld1h    {z24.s}, \pg\()/z, [x_B_ptr]    // FP16 unpacked into 32-bit lanes
        fcvt    z24.s, \pg\()/m, z24.h
        fmla    z0.s, \pg\()/m, z18.s, z24.s
        fmla    z1.s, \pg\()/m, z19.s, z24.s
        fmla    z2.s, \pg\()/m, z20.s, z24.s
        fmla    z3.s, \pg\()/m, z21.s, z24.s
        fmla    z4.s, \pg\()/m, z22.s, z24.s
        fmla    z5.s, \pg\()/m, z23.s, z24.s
        add     x_A_ptr, x_A_ptr, #2
        add     x_B_ptr, x_B_ptr, x_ldb_bytes
        sub     x_k_cnt, x_k_cnt, #1
        .endm

        // Store z0-z5 to six consecutive C rows starting at x_C_ptr under
        // predicate <pg>.  Scratch: x26.
        .macro  hgemm_fp32_store6 pg
        st1w    {z0.s}, \pg, [x_C_ptr]
        add     x26, x_C_ptr, x_ldc_bytes
        st1w    {z1.s}, \pg, [x26]
        add     x26, x26, x_ldc_bytes
        st1w    {z2.s}, \pg, [x26]
        add     x26, x26, x_ldc_bytes
        st1w    {z3.s}, \pg, [x26]
        add     x26, x26, x_ldc_bytes
        st1w    {z4.s}, \pg, [x26]
        add     x26, x26, x_ldc_bytes
        st1w    {z5.s}, \pg, [x26]
        .endm

        .align  4
        .global hgemm_kernel
        .type   hgemm_kernel, %function
hgemm_kernel:
        stp     x19, x20, [sp, #-HGEMM_FRAME_BYTES]!
        stp     x21, x22, [sp, #16]
        stp     x23, x24, [sp, #32]
        stp     x25, x26, [sp, #48]
        // z8-z15 are written by the 6x32 path; AAPCS64 requires their low
        // 64 bits (d8-d15) to survive the call.
        stp     d8,  d9,  [sp, #64]
        stp     d10, d11, [sp, #80]
        stp     d12, d13, [sp, #96]
        stp     d14, d15, [sp, #112]

        ldr     x_ldc, [sp, #HGEMM_FRAME_BYTES] // ninth argument, above our frame

        lsl     x_lda_bytes, x_lda, #1          // FP16 stride in bytes
        lsl     x_ldb_bytes, x_ldb, #1          // FP16 stride in bytes
        lsl     x_ldc_bytes, x_ldc, #2          // FP32 stride in bytes

        mov     x_A_row, x_A
        mov     x_C_row, x_C

        ptrue   p0.s                            // all FP32 lanes
        ptrue   p1.h                            // all FP16 lanes

        mov     x_m_cnt, x_M

// ---------------------------------------------------------------- 6-row blocks
.Lhg_m6:
        cmp     x_m_cnt, #6
        b.lt    .Lhg_m1                         // fewer than 6 rows left
        mov     x_B_col, x_B
        mov     x_C_ptr, x_C_row
        mov     x_n_cnt, x_N

// 6 rows x 32 columns: z0-z11 FP32 accumulators (2 per row),
// z12-z17 FP16 partial sums (1 per row).
.Lhg_n6x32:
        cmp     x_n_cnt, #32
        b.lt    .Lhg_n6x16

        .irp    idx, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17
        mov     z\idx\().d, #0
        .endr

        mov     x_A_ptr, x_A_row
        mov     x_B_ptr, x_B_col
        mov     x_k_cnt, x_K
        mov     x25, #0                         // K steps held in FP16 since last fold

.Lhg_k6x32:
        cmp     x_k_cnt, #2
        b.lt    .Lhg_k6x32_tail

        // Prefetch the B row two K steps ahead.
        add     x17, x_B_ptr, x_ldb_bytes, lsl #1
        prfh    pldl1keep, p1, [x17]

        // Step k: broadcast A[row 0..5][k].
        ld1rh   {z18.h}, p1/z, [x_A_ptr]
        add     x26, x_A_ptr, x_lda_bytes
        ld1rh   {z19.h}, p1/z, [x26]
        add     x26, x26, x_lda_bytes
        ld1rh   {z20.h}, p1/z, [x26]
        add     x26, x26, x_lda_bytes
        ld1rh   {z21.h}, p1/z, [x26]
        add     x26, x26, x_lda_bytes
        ld1rh   {z22.h}, p1/z, [x26]
        add     x26, x26, x_lda_bytes
        ld1rh   {z23.h}, p1/z, [x26]

        ld1h    {z24.h}, p1/z, [x_B_ptr]        // B[k][0..31]

        // Step k+1: start loading A early (rows 0 and 1) to overlap the FMAs.
        add     x17, x_A_ptr, #2
        ld1rh   {z30.h}, p1/z, [x17]
        add     x26, x17, x_lda_bytes
        ld1rh   {z31.h}, p1/z, [x26]

        fmla    z12.h, p1/m, z18.h, z24.h
        fmla    z13.h, p1/m, z19.h, z24.h
        fmla    z14.h, p1/m, z20.h, z24.h
        fmla    z15.h, p1/m, z21.h, z24.h
        fmla    z16.h, p1/m, z22.h, z24.h
        fmla    z17.h, p1/m, z23.h, z24.h

        add     x17, x_B_ptr, x_ldb_bytes
        ld1h    {z25.h}, p1/z, [x17]            // B[k+1][0..31]

        // Step k+1, rows 2..5: z18-z21 are free again after the FMAs above.
        add     x26, x26, x_lda_bytes
        ld1rh   {z18.h}, p1/z, [x26]
        add     x26, x26, x_lda_bytes
        ld1rh   {z19.h}, p1/z, [x26]
        add     x26, x26, x_lda_bytes
        ld1rh   {z20.h}, p1/z, [x26]
        add     x26, x26, x_lda_bytes
        ld1rh   {z21.h}, p1/z, [x26]

        fmla    z12.h, p1/m, z30.h, z25.h
        fmla    z13.h, p1/m, z31.h, z25.h
        fmla    z14.h, p1/m, z18.h, z25.h
        fmla    z15.h, p1/m, z19.h, z25.h
        fmla    z16.h, p1/m, z20.h, z25.h
        fmla    z17.h, p1/m, z21.h, z25.h

        add     x_A_ptr, x_A_ptr, #4            // two FP16 values
        add     x_B_ptr, x_B_ptr, x_ldb_bytes, lsl #1
        sub     x_k_cnt, x_k_cnt, #2
        add     x25, x25, #2

        cmp     x25, #8                         // fold once 8 steps are pending
        b.lt    .Lhg_k6x32
        b       .Lhg_fold6x32

.Lhg_k6x32_tail:
        cbz     x_k_cnt, .Lhg_store6x32

        // Single remaining K step.
        ld1rh   {z18.h}, p1/z, [x_A_ptr]
        add     x26, x_A_ptr, x_lda_bytes
        ld1rh   {z19.h}, p1/z, [x26]
        add     x26, x26, x_lda_bytes
        ld1rh   {z20.h}, p1/z, [x26]
        add     x26, x26, x_lda_bytes
        ld1rh   {z21.h}, p1/z, [x26]
        add     x26, x26, x_lda_bytes
        ld1rh   {z22.h}, p1/z, [x26]
        add     x26, x26, x_lda_bytes
        ld1rh   {z23.h}, p1/z, [x26]
        ld1h    {z24.h}, p1/z, [x_B_ptr]

        fmla    z12.h, p1/m, z18.h, z24.h
        fmla    z13.h, p1/m, z19.h, z24.h
        fmla    z14.h, p1/m, z20.h, z24.h
        fmla    z15.h, p1/m, z21.h, z24.h
        fmla    z16.h, p1/m, z22.h, z24.h
        fmla    z17.h, p1/m, z23.h, z24.h

        add     x_A_ptr, x_A_ptr, #2
        add     x_B_ptr, x_B_ptr, x_ldb_bytes
        sub     x_k_cnt, x_k_cnt, #1
        add     x25, x25, #1
        cmp     x25, #8
        b.lt    .Lhg_k6x32_tail

.Lhg_fold6x32:
        // Move the FP16 partial sums into the FP32 accumulators and restart
        // the FP16 partial sums from zero.
        hgemm_fold16 12, 0, 1
        hgemm_fold16 13, 2, 3
        hgemm_fold16 14, 4, 5
        hgemm_fold16 15, 6, 7
        hgemm_fold16 16, 8, 9
        hgemm_fold16 17, 10, 11
        .irp    idx, 12, 13, 14, 15, 16, 17
        mov     z\idx\().d, #0
        .endr
        mov     x25, #0
        b       .Lhg_k6x32

.Lhg_store6x32:
        // Fold whatever is still pending in FP16, then write the 6x32 tile.
        hgemm_fold16 12, 0, 1
        hgemm_fold16 13, 2, 3
        hgemm_fold16 14, 4, 5
        hgemm_fold16 15, 6, 7
        hgemm_fold16 16, 8, 9
        hgemm_fold16 17, 10, 11

        mov     x26, x_C_ptr                    // x26 walks the six C rows
        st1w    {z0.s},  p0, [x26]
        st1w    {z1.s},  p0, [x26, #1, mul vl]
        add     x26, x26, x_ldc_bytes
        st1w    {z2.s},  p0, [x26]
        st1w    {z3.s},  p0, [x26, #1, mul vl]
        add     x26, x26, x_ldc_bytes
        st1w    {z4.s},  p0, [x26]
        st1w    {z5.s},  p0, [x26, #1, mul vl]
        add     x26, x26, x_ldc_bytes
        st1w    {z6.s},  p0, [x26]
        st1w    {z7.s},  p0, [x26, #1, mul vl]
        add     x26, x26, x_ldc_bytes
        st1w    {z8.s},  p0, [x26]
        st1w    {z9.s},  p0, [x26, #1, mul vl]
        add     x26, x26, x_ldc_bytes
        st1w    {z10.s}, p0, [x26]
        st1w    {z11.s}, p0, [x26, #1, mul vl]

        add     x_B_col, x_B_col, #64           // 32 FP16 columns
        add     x_C_ptr, x_C_ptr, #128          // 32 FP32 columns
        sub     x_n_cnt, x_n_cnt, #32
        b       .Lhg_n6x32

// 6 rows x 16 columns, computed entirely in FP32 (z0-z5 accumulators).
.Lhg_n6x16:
        cbz     x_n_cnt, .Lhg_next6
        cmp     x_n_cnt, #16
        b.lt    .Lhg_n6_rem

        .irp    idx, 0, 1, 2, 3, 4, 5
        mov     z\idx\().d, #0
        .endr
        mov     x_A_ptr, x_A_row
        mov     x_B_ptr, x_B_col
        mov     x_k_cnt, x_K

.Lhg_k6x16:
        cbz     x_k_cnt, .Lhg_store6x16
        hgemm_fp32_step6 p0
        b       .Lhg_k6x16

.Lhg_store6x16:
        hgemm_fp32_store6 p0
        add     x_B_col, x_B_col, #32           // 16 FP16 columns
        add     x_C_ptr, x_C_ptr, #64           // 16 FP32 columns
        sub     x_n_cnt, x_n_cnt, #16
        b       .Lhg_n6x16

// 6 rows x (1..15) columns, FP32 under predicate p2.
.Lhg_n6_rem:
        cbz     x_n_cnt, .Lhg_next6
        whilelt p2.s, xzr, x_n_cnt              // first n_cnt FP32 lanes active

        .irp    idx, 0, 1, 2, 3, 4, 5
        mov     z\idx\().d, #0
        .endr
        mov     x_A_ptr, x_A_row
        mov     x_B_ptr, x_B_col
        mov     x_k_cnt, x_K

.Lhg_k6_rem:
        cbz     x_k_cnt, .Lhg_store6_rem
        hgemm_fp32_step6 p2
        b       .Lhg_k6_rem

.Lhg_store6_rem:
        hgemm_fp32_store6 p2

.Lhg_next6:
        // Advance A and C by six rows.
        add     x26, x_lda_bytes, x_lda_bytes, lsl #1   // 3 * lda_bytes
        add     x_A_row, x_A_row, x26, lsl #1           // + 6 * lda_bytes
        add     x26, x_ldc_bytes, x_ldc_bytes, lsl #1   // 3 * ldc_bytes
        add     x_C_row, x_C_row, x26, lsl #1           // + 6 * ldc_bytes
        sub     x_m_cnt, x_m_cnt, #6
        b       .Lhg_m6

// ------------------------------------------------- remaining rows, one by one
.Lhg_m1:
        cbz     x_m_cnt, .Lhg_done
        mov     x_B_col, x_B
        mov     x_C_ptr, x_C_row
        mov     x_n_cnt, x_N

// 1 row x 32 columns: z0/z1 FP32 accumulators, z2 FP16 partial sums.
.Lhg_n1x32:
        cmp     x_n_cnt, #32
        b.lt    .Lhg_n1x16

        mov     z0.d, #0
        mov     z1.d, #0
        mov     z2.d, #0
        mov     x_A_ptr, x_A_row
        mov     x_B_ptr, x_B_col
        mov     x_k_cnt, x_K

.Lhg_k1x32:
        cbz     x_k_cnt, .Lhg_store1x32
        ld1rh   {z3.h}, p1/z, [x_A_ptr]
        add     x_A_ptr, x_A_ptr, #2
        ld1h    {z4.h}, p1/z, [x_B_ptr]
        fmla    z2.h, p1/m, z3.h, z4.h
        add     x_B_ptr, x_B_ptr, x_ldb_bytes
        sub     x_k_cnt, x_k_cnt, #1
        tst     x_k_cnt, #7                     // fold when the remaining count is a
        b.ne    .Lhg_k1x32                      // multiple of 8 (this includes 0)
        hgemm_fold16 2, 0, 1
        mov     z2.d, #0
        b       .Lhg_k1x32

.Lhg_store1x32:
        // z2 is always zero here: the loop folds it on the step that brings
        // x_k_cnt to 0, so no final fold is needed.
        st1w    {z0.s}, p0, [x_C_ptr]
        st1w    {z1.s}, p0, [x_C_ptr, #1, mul vl]
        add     x_B_col, x_B_col, #64
        add     x_C_ptr, x_C_ptr, #128
        sub     x_n_cnt, x_n_cnt, #32
        b       .Lhg_n1x32

// 1 row x 16 columns, FP32.
.Lhg_n1x16:
        cmp     x_n_cnt, #16
        b.lt    .Lhg_n1_rem

        mov     z0.d, #0
        mov     x_A_ptr, x_A_row
        mov     x_B_ptr, x_B_col
        mov     x_k_cnt, x_K

.Lhg_k1x16:
        cbz     x_k_cnt, .Lhg_store1x16
        ld1rh   {z3.h}, p1/z, [x_A_ptr]         // A[k] in every FP16 lane
        add     x_A_ptr, x_A_ptr, #2
        ld1h    {z4.s}, p0/z, [x_B_ptr]         // 16 FP16 unpacked into 32-bit lanes
        fcvt    z4.s, p0/m, z4.h                // B -> FP32
        fcvt    z5.s, p0/m, z3.h                // A -> FP32
        fmla    z0.s, p0/m, z5.s, z4.s          // multiply-accumulate in FP32
        add     x_B_ptr, x_B_ptr, x_ldb_bytes
        sub     x_k_cnt, x_k_cnt, #1
        b       .Lhg_k1x16

.Lhg_store1x16:
        st1w    {z0.s}, p0, [x_C_ptr]
        add     x_B_col, x_B_col, #32
        add     x_C_ptr, x_C_ptr, #64
        sub     x_n_cnt, x_n_cnt, #16
        b       .Lhg_n1x32

// 1 row x (1..15) columns, FP32 under predicate p2.
.Lhg_n1_rem:
        cbz     x_n_cnt, .Lhg_next1
        whilelt p2.s, xzr, x_n_cnt

        mov     z0.d, #0
        mov     x_A_ptr, x_A_row
        mov     x_B_ptr, x_B_col
        mov     x_k_cnt, x_K

.Lhg_k1_rem:
        cbz     x_k_cnt, .Lhg_store1_rem
        ldr     h1, [x_A_ptr], #2
        fcvt    s1, h1
        mov     z1.s, s1
        ld1h    {z2.s}, p2/z, [x_B_ptr]
        fcvt    z2.s, p2/m, z2.h
        fmla    z0.s, p2/m, z1.s, z2.s
        add     x_B_ptr, x_B_ptr, x_ldb_bytes
        sub     x_k_cnt, x_k_cnt, #1
        b       .Lhg_k1_rem

.Lhg_store1_rem:
        st1w    {z0.s}, p2, [x_C_ptr]

.Lhg_next1:
        add     x_A_row, x_A_row, x_lda_bytes
        add     x_C_row, x_C_row, x_ldc_bytes
        sub     x_m_cnt, x_m_cnt, #1
        b       .Lhg_m1

.Lhg_done:
        ldp     d14, d15, [sp, #112]
        ldp     d12, d13, [sp, #96]
        ldp     d10, d11, [sp, #80]
        ldp     d8,  d9,  [sp, #64]
        ldp     x25, x26, [sp, #48]
        ldp     x23, x24, [sp, #32]
        ldp     x21, x22, [sp, #16]
        ldp     x19, x20, [sp], #HGEMM_FRAME_BYTES
        ret
        .size   hgemm_kernel, .-hgemm_kernel
// ============================================================================
// hgemm_kernel_6x32 - packed 6 x 32 micro-kernel
//
// Multiplies in FP16 and keeps the running sums in FP16 for at most 8 K steps
// before folding them into FP32 accumulators.
//
// In:   x0 = K
//       x1 = A, read as 6 consecutive FP16 values per K step (12 bytes/step)
//       x2 = B, read as 32 consecutive FP16 values per K step (64 bytes/step)
//       x3 = C, written as 12 consecutive vectors of FP32 (row r occupies
//            vectors 2r and 2r+1)
// Assumes a 512-bit SVE vector length (the second store base is x3 + 512).
//
// Registers:
//   z0-z11   FP32 accumulators, two vectors per row
//   z12-z17  FP16 partial sums, one vector per row
//   z18-z23  broadcast A values, z24 = B row
//   z26-z29  conversion scratch
//   p0 = ptrue.s, p1 = ptrue.h
// Preserves d8-d15.  Clobbers x1, x2, x4-x6, flags.
// ============================================================================

        // Add the 32 FP16 lanes of z<acc16>, widened to FP32, to z<lo32>
        // (columns 0-15) and z<hi32> (columns 16-31).
        .macro  hgemm6x32_fold acc16, lo32, hi32
        fcvt    z26.s, p0/m, z\acc16\().h       // even lanes: h0, h2, h4, ...
        lsr     z27.s, z\acc16\().s, #16
        fcvt    z27.s, p0/m, z27.h              // odd lanes:  h1, h3, h5, ...
        zip1    z28.s, z26.s, z27.s             // h0 .. h15
        zip2    z29.s, z26.s, z27.s             // h16 .. h31
        fadd    z\lo32\().s, p0/m, z\lo32\().s, z28.s
        fadd    z\hi32\().s, p0/m, z\hi32\().s, z29.s
        .endm

        .align  4
        .global hgemm_kernel_6x32
        .type   hgemm_kernel_6x32, %function
hgemm_kernel_6x32:
        stp     d8,  d9,  [sp, #-64]!
        stp     d10, d11, [sp, #16]
        stp     d12, d13, [sp, #32]
        stp     d14, d15, [sp, #48]

        ptrue   p0.s
        ptrue   p1.h

        // z0-z11: FP32 accumulators, z12-z17: FP16 partial sums.
        .irp    idx, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17
        mov     z\idx\().d, #0
        .endr

        mov     x4, x0                          // K steps remaining
        mov     x5, #0                          // steps pending in FP16

.Lk6x32_step:
        cbz     x4, .Lk6x32_finish

        add     x6, x2, #256                    // prefetch B 256 bytes ahead
        prfh    pldl1keep, p1, [x6]

        ld1rh   {z18.h}, p1/z, [x1]
        ld1rh   {z19.h}, p1/z, [x1, #2]
        ld1rh   {z20.h}, p1/z, [x1, #4]
        ld1rh   {z21.h}, p1/z, [x1, #6]
        ld1rh   {z22.h}, p1/z, [x1, #8]
        ld1rh   {z23.h}, p1/z, [x1, #10]
        ld1h    {z24.h}, p1/z, [x2]

        fmla    z12.h, p1/m, z18.h, z24.h
        fmla    z13.h, p1/m, z19.h, z24.h
        fmla    z14.h, p1/m, z20.h, z24.h
        fmla    z15.h, p1/m, z21.h, z24.h
        fmla    z16.h, p1/m, z22.h, z24.h
        fmla    z17.h, p1/m, z23.h, z24.h

        add     x1, x1, #12                     // 6 FP16 values
        add     x2, x2, #64                     // 32 FP16 values
        sub     x4, x4, #1
        add     x5, x5, #1
        cmp     x5, #8
        b.lt    .Lk6x32_step

        // Eight steps are pending: fold them into FP32 and restart the
        // FP16 partial sums from zero.
        hgemm6x32_fold 12, 0, 1
        hgemm6x32_fold 13, 2, 3
        hgemm6x32_fold 14, 4, 5
        hgemm6x32_fold 15, 6, 7
        hgemm6x32_fold 16, 8, 9
        hgemm6x32_fold 17, 10, 11
        .irp    idx, 12, 13, 14, 15, 16, 17
        mov     z\idx\().d, #0
        .endr
        mov     x5, #0
        b       .Lk6x32_step

.Lk6x32_finish:
        // Fold whatever is still pending in FP16.
        hgemm6x32_fold 12, 0, 1
        hgemm6x32_fold 13, 2, 3
        hgemm6x32_fold 14, 4, 5
        hgemm6x32_fold 15, 6, 7
        hgemm6x32_fold 16, 8, 9
        hgemm6x32_fold 17, 10, 11

        // The st1w immediate only reaches #7, so the last four vectors use a
        // second base.
        st1w    {z0.s},  p0, [x3]
        st1w    {z1.s},  p0, [x3, #1, mul vl]
        st1w    {z2.s},  p0, [x3, #2, mul vl]
        st1w    {z3.s},  p0, [x3, #3, mul vl]
        st1w    {z4.s},  p0, [x3, #4, mul vl]
        st1w    {z5.s},  p0, [x3, #5, mul vl]
        st1w    {z6.s},  p0, [x3, #6, mul vl]
        st1w    {z7.s},  p0, [x3, #7, mul vl]
        add     x6, x3, #512
        st1w    {z8.s},  p0, [x6]
        st1w    {z9.s},  p0, [x6, #1, mul vl]
        st1w    {z10.s}, p0, [x6, #2, mul vl]
        st1w    {z11.s}, p0, [x6, #3, mul vl]

        ldp     d14, d15, [sp, #48]
        ldp     d12, d13, [sp, #32]
        ldp     d10, d11, [sp, #16]
        ldp     d8,  d9,  [sp], #64
        ret
        .size   hgemm_kernel_6x32, .-hgemm_kernel_6x32
// ============================================================================
// hgemm_kernel_8x32_unroll4 - packed 8 x 32 micro-kernel, K unrolled by 4
//
// Multiplies in FP16 and keeps the running sums in FP16 for at most 8 K steps
// before folding them into FP32 accumulators.
//
// In:   x0 = K (a multiple of 4 runs entirely in the unrolled loop; any
//            remaining 1-3 steps are handled one at a time)
//       x1 = A, read as 8 consecutive FP16 values per K step (16 bytes/step)
//       x2 = B, read as 32 consecutive FP16 values per K step (64 bytes/step)
//       x3 = C, written as 16 consecutive vectors of FP32 (row r occupies
//            vectors 2r and 2r+1)
// Assumes a 512-bit SVE vector length (32 FP16 lanes per B row).
//
// Registers:
//   z0-z15   FP32 accumulators, two vectors per row
//   z16-z23  FP16 partial sums, one vector per row
//   z24      B row; z25-z31 broadcast A values (z25 is reused for row 7)
//   z24-z27  conversion scratch while folding (A/B values are dead then)
//   p0 = ptrue.s, p1 = ptrue.h
// Preserves d8-d15.  Clobbers x1, x2, x4-x6, flags.
// ============================================================================

        // Add the 32 FP16 lanes of z<acc16>, widened to FP32, to z<lo32>
        // (columns 0-15) and z<hi32> (columns 16-31).
        // fcvt .h -> .s only reads the low half of each 32-bit lane (the even
        // FP16 lanes), so the odd lanes are shifted down and converted
        // separately, then zip1/zip2 restore column order.
        .macro  hgemm8x32_fold acc16, lo32, hi32
        fcvt    z24.s, p0/m, z\acc16\().h       // even lanes: h0, h2, h4, ...
        lsr     z25.s, z\acc16\().s, #16
        fcvt    z25.s, p0/m, z25.h              // odd lanes:  h1, h3, h5, ...
        zip1    z26.s, z24.s, z25.s             // h0 .. h15
        zip2    z27.s, z24.s, z25.s             // h16 .. h31
        fadd    z\lo32\().s, p0/m, z\lo32\().s, z26.s
        fadd    z\hi32\().s, p0/m, z\hi32\().s, z27.s
        .endm

        // One K step: z16-z23 += A[k][0..7] * B[k][0..31]; advances x1 and x2.
        .macro  hgemm8x32_kstep
        ld1h    {z24.h}, p1/z, [x2]
        ld1rh   {z25.h}, p1/z, [x1, #0]
        ld1rh   {z26.h}, p1/z, [x1, #2]
        ld1rh   {z27.h}, p1/z, [x1, #4]
        ld1rh   {z28.h}, p1/z, [x1, #6]
        ld1rh   {z29.h}, p1/z, [x1, #8]
        ld1rh   {z30.h}, p1/z, [x1, #10]
        ld1rh   {z31.h}, p1/z, [x1, #12]
        fmla    z16.h, p1/m, z25.h, z24.h
        fmla    z17.h, p1/m, z26.h, z24.h
        fmla    z18.h, p1/m, z27.h, z24.h
        fmla    z19.h, p1/m, z28.h, z24.h
        fmla    z20.h, p1/m, z29.h, z24.h
        fmla    z21.h, p1/m, z30.h, z24.h
        fmla    z22.h, p1/m, z31.h, z24.h
        ld1rh   {z25.h}, p1/z, [x1, #14]        // row 7 reuses z25 (only 7 free regs)
        fmla    z23.h, p1/m, z25.h, z24.h
        add     x1, x1, #16                     // 8 FP16 values
        add     x2, x2, #64                     // 32 FP16 values
        .endm

        .align  4
        .global hgemm_kernel_8x32_unroll4
        .type   hgemm_kernel_8x32_unroll4, %function
hgemm_kernel_8x32_unroll4:
        stp     d8,  d9,  [sp, #-64]!
        stp     d10, d11, [sp, #16]
        stp     d12, d13, [sp, #32]
        stp     d14, d15, [sp, #48]

        ptrue   p0.s
        ptrue   p1.h

        // z0-z15: FP32 accumulators, z16-z23: FP16 partial sums.
        .irp    idx, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
        mov     z\idx\().d, #0
        .endr
        .irp    idx, 16, 17, 18, 19, 20, 21, 22, 23
        mov     z\idx\().d, #0
        .endr

        lsr     x4, x0, #2                      // number of 4-step groups
        mov     x5, #0                          // steps pending in FP16

.Lk8x32_group:
        cbz     x4, .Lk8x32_tail

        add     x6, x2, #512                    // prefetch B 512 bytes ahead
        prfh    pldl1keep, p1, [x6]

        .rept   4
        hgemm8x32_kstep
        .endr

        sub     x4, x4, #1
        add     x5, x5, #4
        tst     x5, #7                          // fold after every second group
        b.ne    .Lk8x32_group

        hgemm8x32_fold 16, 0, 1
        hgemm8x32_fold 17, 2, 3
        hgemm8x32_fold 18, 4, 5
        hgemm8x32_fold 19, 6, 7
        hgemm8x32_fold 20, 8, 9
        hgemm8x32_fold 21, 10, 11
        hgemm8x32_fold 22, 12, 13
        hgemm8x32_fold 23, 14, 15
        .irp    idx, 16, 17, 18, 19, 20, 21, 22, 23
        mov     z\idx\().d, #0
        .endr
        mov     x5, #0
        b       .Lk8x32_group

.Lk8x32_tail:
        // K % 4 leftover steps.  At most 4 steps are pending here, so adding
        // up to 3 more stays within the 8-step FP16 window.
        and     x4, x0, #3
        cbz     x4, .Lk8x32_finish
.Lk8x32_tail_step:
        hgemm8x32_kstep
        subs    x4, x4, #1
        b.ne    .Lk8x32_tail_step

.Lk8x32_finish:
        // Fold whatever is still pending in FP16.
        hgemm8x32_fold 16, 0, 1
        hgemm8x32_fold 17, 2, 3
        hgemm8x32_fold 18, 4, 5
        hgemm8x32_fold 19, 6, 7
        hgemm8x32_fold 20, 8, 9
        hgemm8x32_fold 21, 10, 11
        hgemm8x32_fold 22, 12, 13
        hgemm8x32_fold 23, 14, 15

        // The st1w immediate only reaches #7, so vectors 8-15 use a second
        // base eight vector lengths further on.
        st1w    {z0.s},  p0, [x3]
        st1w    {z1.s},  p0, [x3, #1, mul vl]
        st1w    {z2.s},  p0, [x3, #2, mul vl]
        st1w    {z3.s},  p0, [x3, #3, mul vl]
        st1w    {z4.s},  p0, [x3, #4, mul vl]
        st1w    {z5.s},  p0, [x3, #5, mul vl]
        st1w    {z6.s},  p0, [x3, #6, mul vl]
        st1w    {z7.s},  p0, [x3, #7, mul vl]
        addvl   x6, x3, #8
        st1w    {z8.s},  p0, [x6]
        st1w    {z9.s},  p0, [x6, #1, mul vl]
        st1w    {z10.s}, p0, [x6, #2, mul vl]
        st1w    {z11.s}, p0, [x6, #3, mul vl]
        st1w    {z12.s}, p0, [x6, #4, mul vl]
        st1w    {z13.s}, p0, [x6, #5, mul vl]
        st1w    {z14.s}, p0, [x6, #6, mul vl]
        st1w    {z15.s}, p0, [x6, #7, mul vl]

        ldp     d14, d15, [sp, #48]
        ldp     d12, d13, [sp, #32]
        ldp     d10, d11, [sp, #16]
        ldp     d8,  d9,  [sp], #64
        ret
        .size   hgemm_kernel_8x32_unroll4, .-hgemm_kernel_8x32_unroll4
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment