Skip to content

Instantly share code, notes, and snippets.

@syoyo
Created December 14, 2025 18:28
Show Gist options
  • Select an option

  • Save syoyo/0e9605b6c7384f4c68f3ea45b1f84ee8 to your computer and use it in GitHub Desktop.

Select an option

Save syoyo/0e9605b6c7384f4c68f3ea45b1f84ee8 to your computer and use it in GitHub Desktop.
A64FX FP16 GEMM with FP32-accumulation experiment
// A64FX SVE 1.0 High-Performance FP16 GEMM with FP32 Accumulation
// Target: ~160% of FP32 peak (~80% of FP16 peak)
//
// Strategy:
// - Use FP16 FMLA directly for maximum throughput (32 ops/instruction)
// - Accumulate in FP16 for K_INTERVAL iterations
// - Convert to FP32 and add to FP32 accumulators periodically
// - Software pipeline: overlap loads, converts, and FMAs
//
// A64FX FP16 peak: ~264 GFLOPS/core (2x FP32)
// Target: ~211 GFLOPS (80% of peak)
.arch armv8.2-a+sve
.text
// ============================================================================
// void hgemm_kernel(int64_t M, int64_t N, int64_t K,
//                   const __fp16 *A, int64_t lda,
//                   const __fp16 *B, int64_t ldb,
//                   float *C, int64_t ldc /* 9th arg, passed on the stack */)
//
// Computes C = A * B (C is overwritten, NOT accumulated into).
//   A: M x K fp16, row-major, leading dimension lda (in elements)
//   B: K x N fp16, row-major, leading dimension ldb (in elements)
//   C: M x N fp32, row-major, leading dimension ldc (in elements)
// Assumes a 512-bit SVE vector length (A64FX): 32 fp16 / 16 fp32 per z reg.
//
// Main path: 6-row x 32-column tiles using fp16 FMLA, with the fp16 partial
// sums widened and folded into fp32 accumulators every 8 K steps to bound
// fp16 rounding/overflow. N remainders (<32) and M remainders (<6 rows) fall
// back to simpler fp32 paths.
// Clobbers: x8-x17 range scratch (x9-x17), z0-z31, p0-p2.
// Preserves: x19-x26 and d8-d15 per AAPCS64.
// ============================================================================
.align 4
.global hgemm_kernel
.type hgemm_kernel, %function
// Symbolic register names for readability
x_M .req x0
x_N .req x1
x_K .req x2
x_A .req x3
x_lda .req x4
x_B .req x5
x_ldb .req x6
x_C .req x7
x_ldc .req x19
x_m_cnt .req x9
x_n_cnt .req x10
x_k_cnt .req x11
x_A_row .req x12
x_B_col .req x13
x_C_row .req x14
x_A_ptr .req x15
x_B_ptr .req x16
x_lda_bytes .req x20
x_ldb_bytes .req x21
x_ldc_bytes .req x22
x_C_ptr .req x23
hgemm_kernel:
// Save callee-saved GPRs. AAPCS64 also requires preserving the low 64 bits
// of v8-v15; this kernel clobbers z8-z17 (which alias v8-v17), so d8-d15
// must be saved as well. Frame is 128 bytes, keeping sp 16-byte aligned.
stp x19, x20, [sp, #-128]!
stp x21, x22, [sp, #16]
stp x23, x24, [sp, #32]
stp x25, x26, [sp, #48]
stp d8, d9, [sp, #64]
stp d10, d11, [sp, #80]
stp d12, d13, [sp, #96]
stp d14, d15, [sp, #112]
ldr x_ldc, [sp, #128] // 9th argument (ldc) sits just above our frame
// Convert leading dimensions from elements to bytes (fp16 = 2 B, fp32 = 4 B)
lsl x_lda_bytes, x_lda, #1
lsl x_ldb_bytes, x_ldb, #1
lsl x_ldc_bytes, x_ldc, #2
mov x_A_row, x_A
mov x_C_row, x_C
ptrue p0.s // 16 FP32 elements
ptrue p1.h // 32 FP16 elements
// Process 6 rows at a time for better register utilization
mov x_m_cnt, x_M
.loop_m6:
cmp x_m_cnt, #6
b.lt .loop_m1 // Handle remaining rows individually
mov x_B_col, x_B
mov x_C_ptr, x_C_row
mov x_n_cnt, x_N
.loop_n6x32:
cmp x_n_cnt, #32
b.lt .loop_n6x16
// Process 6 rows x 32 columns with FP16 FMLA
// FP32 accumulators: z0-z11 (6 rows x 2 vectors)
// FP16 temp: z12-z17 (6 rows x 1 vector)
mov z0.d, #0
mov z1.d, #0
mov z2.d, #0
mov z3.d, #0
mov z4.d, #0
mov z5.d, #0
mov z6.d, #0
mov z7.d, #0
mov z8.d, #0
mov z9.d, #0
mov z10.d, #0
mov z11.d, #0
mov z12.d, #0
mov z13.d, #0
mov z14.d, #0
mov z15.d, #0
mov z16.d, #0
mov z17.d, #0
mov x_A_ptr, x_A_row
mov x_B_ptr, x_B_col
mov x_k_cnt, x_K
mov x25, #0 // K mod 8 counter (forces a fold every 8 K steps)
.loop_k6x32:
cmp x_k_cnt, #2
b.lt .loop_k6x32_tail
// Prefetch B two K rows ahead
add x17, x_B_ptr, x_ldb_bytes, lsl #1
prfh pldl1keep, p1, [x17]
// K iteration 0: Load 6 A values (one per row) and broadcast
ld1rh {z18.h}, p1/z, [x_A_ptr]
add x26, x_A_ptr, x_lda_bytes
ld1rh {z19.h}, p1/z, [x26]
add x26, x26, x_lda_bytes
ld1rh {z20.h}, p1/z, [x26]
add x26, x26, x_lda_bytes
ld1rh {z21.h}, p1/z, [x26]
add x26, x26, x_lda_bytes
ld1rh {z22.h}, p1/z, [x26]
add x26, x26, x_lda_bytes
ld1rh {z23.h}, p1/z, [x26]
// Load 32 FP16 from B (K=0)
ld1h z24.h, p1/z, [x_B_ptr]
// K iteration 1: start loading A rows 0-1 early (software pipelining)
add x17, x_A_ptr, #2
ld1rh {z30.h}, p1/z, [x17]
add x26, x17, x_lda_bytes
ld1rh {z31.h}, p1/z, [x26]
// FP16 FMA for 6 rows (K=0)
fmla z12.h, p1/m, z18.h, z24.h
fmla z13.h, p1/m, z19.h, z24.h
fmla z14.h, p1/m, z20.h, z24.h
fmla z15.h, p1/m, z21.h, z24.h
fmla z16.h, p1/m, z22.h, z24.h
fmla z17.h, p1/m, z23.h, z24.h
// Load B for K=1
add x17, x_B_ptr, x_ldb_bytes
ld1h z25.h, p1/z, [x17]
// Finish loading A for K=1 (z18-z21 are dead after the K=0 FMAs above)
add x26, x26, x_lda_bytes
ld1rh {z18.h}, p1/z, [x26] // Reuse z18 for row 2 of K=1
add x26, x26, x_lda_bytes
ld1rh {z19.h}, p1/z, [x26] // Reuse z19 for row 3 of K=1
add x26, x26, x_lda_bytes
ld1rh {z20.h}, p1/z, [x26] // Reuse z20 for row 4 of K=1
add x26, x26, x_lda_bytes
ld1rh {z21.h}, p1/z, [x26] // Reuse z21 for row 5 of K=1
// FP16 FMA for 6 rows (K=1)
fmla z12.h, p1/m, z30.h, z25.h
fmla z13.h, p1/m, z31.h, z25.h
fmla z14.h, p1/m, z18.h, z25.h
fmla z15.h, p1/m, z19.h, z25.h
fmla z16.h, p1/m, z20.h, z25.h
fmla z17.h, p1/m, z21.h, z25.h
add x_A_ptr, x_A_ptr, #4 // 2 FP16 = 4 bytes
add x_B_ptr, x_B_ptr, x_ldb_bytes, lsl #1
sub x_k_cnt, x_k_cnt, #2
add x25, x25, #2
// Every 8 K iterations, fold fp16 partials into the fp32 accumulators
cmp x25, #8
b.lt .loop_k6x32
b .convert6x32
.loop_k6x32_tail:
cbz x_k_cnt, .store6x32
// Single K iteration for the odd-K tail
ld1rh {z18.h}, p1/z, [x_A_ptr]
add x26, x_A_ptr, x_lda_bytes
ld1rh {z19.h}, p1/z, [x26]
add x26, x26, x_lda_bytes
ld1rh {z20.h}, p1/z, [x26]
add x26, x26, x_lda_bytes
ld1rh {z21.h}, p1/z, [x26]
add x26, x26, x_lda_bytes
ld1rh {z22.h}, p1/z, [x26]
add x26, x26, x_lda_bytes
ld1rh {z23.h}, p1/z, [x26]
ld1h z24.h, p1/z, [x_B_ptr]
fmla z12.h, p1/m, z18.h, z24.h
fmla z13.h, p1/m, z19.h, z24.h
fmla z14.h, p1/m, z20.h, z24.h
fmla z15.h, p1/m, z21.h, z24.h
fmla z16.h, p1/m, z22.h, z24.h
fmla z17.h, p1/m, z23.h, z24.h
add x_A_ptr, x_A_ptr, #2
add x_B_ptr, x_B_ptr, x_ldb_bytes
sub x_k_cnt, x_k_cnt, #1
add x25, x25, #1
cmp x25, #8
b.lt .loop_k6x32_tail
.convert6x32:
// Widen the six FP16 partial accumulators and fold into the FP32 totals.
// FCVT .s<-.h converts the even-indexed fp16 halves; LSR #16 exposes the
// odd-indexed halves; ZIP1/ZIP2 restore column order (0..15 and 16..31).
// Row 0
fcvt z26.s, p0/m, z12.h
lsr z27.s, z12.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z0.s, p0/m, z0.s, z28.s
fadd z1.s, p0/m, z1.s, z29.s
mov z12.d, #0
// Row 1
fcvt z26.s, p0/m, z13.h
lsr z27.s, z13.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z2.s, p0/m, z2.s, z28.s
fadd z3.s, p0/m, z3.s, z29.s
mov z13.d, #0
// Row 2
fcvt z26.s, p0/m, z14.h
lsr z27.s, z14.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z4.s, p0/m, z4.s, z28.s
fadd z5.s, p0/m, z5.s, z29.s
mov z14.d, #0
// Row 3
fcvt z26.s, p0/m, z15.h
lsr z27.s, z15.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z6.s, p0/m, z6.s, z28.s
fadd z7.s, p0/m, z7.s, z29.s
mov z15.d, #0
// Row 4
fcvt z26.s, p0/m, z16.h
lsr z27.s, z16.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z8.s, p0/m, z8.s, z28.s
fadd z9.s, p0/m, z9.s, z29.s
mov z16.d, #0
// Row 5
fcvt z26.s, p0/m, z17.h
lsr z27.s, z17.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z10.s, p0/m, z10.s, z28.s
fadd z11.s, p0/m, z11.s, z29.s
mov z17.d, #0
mov x25, #0
b .loop_k6x32
.store6x32:
// Final fold of whatever accumulated since the last conversion
fcvt z26.s, p0/m, z12.h
lsr z27.s, z12.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z0.s, p0/m, z0.s, z28.s
fadd z1.s, p0/m, z1.s, z29.s
fcvt z26.s, p0/m, z13.h
lsr z27.s, z13.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z2.s, p0/m, z2.s, z28.s
fadd z3.s, p0/m, z3.s, z29.s
fcvt z26.s, p0/m, z14.h
lsr z27.s, z14.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z4.s, p0/m, z4.s, z28.s
fadd z5.s, p0/m, z5.s, z29.s
fcvt z26.s, p0/m, z15.h
lsr z27.s, z15.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z6.s, p0/m, z6.s, z28.s
fadd z7.s, p0/m, z7.s, z29.s
fcvt z26.s, p0/m, z16.h
lsr z27.s, z16.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z8.s, p0/m, z8.s, z28.s
fadd z9.s, p0/m, z9.s, z29.s
fcvt z26.s, p0/m, z17.h
lsr z27.s, z17.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z10.s, p0/m, z10.s, z28.s
fadd z11.s, p0/m, z11.s, z29.s
// Store 6 rows x 32 columns (two 64-byte vectors per row; row r at r*ldc)
st1w z0.s, p0, [x_C_ptr]
add x26, x_C_ptr, #64
st1w z1.s, p0, [x26]
add x26, x_C_ptr, x_ldc_bytes
st1w z2.s, p0, [x26]
add x26, x26, #64
st1w z3.s, p0, [x26]
add x26, x_C_ptr, x_ldc_bytes, lsl #1
st1w z4.s, p0, [x26]
add x26, x26, #64
st1w z5.s, p0, [x26]
// Row 3: offset 3*ldc_bytes = ldc_bytes + 2*ldc_bytes
mov x26, x_ldc_bytes
add x26, x26, x_ldc_bytes, lsl #1
add x26, x_C_ptr, x26
st1w z6.s, p0, [x26]
add x26, x26, #64
st1w z7.s, p0, [x26]
// Row 4: offset 4*ldc_bytes
add x26, x_C_ptr, x_ldc_bytes, lsl #2
st1w z8.s, p0, [x26]
add x26, x26, #64
st1w z9.s, p0, [x26]
// Row 5: offset 5*ldc_bytes = ldc_bytes + 4*ldc_bytes
mov x26, x_ldc_bytes
add x26, x26, x_ldc_bytes, lsl #2
add x26, x_C_ptr, x26
st1w z10.s, p0, [x26]
add x26, x26, #64
st1w z11.s, p0, [x26]
add x_B_col, x_B_col, #64
add x_C_ptr, x_C_ptr, #128
sub x_n_cnt, x_n_cnt, #32
b .loop_n6x32
.loop_n6x16:
// Handle remaining N columns (< 32) for the 6-row block.
// Uses an FP32 path for simplicity (throughput doesn't matter on edges).
cbz x_n_cnt, .next_6rows
cmp x_n_cnt, #16
b.lt .loop_n6_rem
// Process 6 rows x 16 columns with FP32 accumulation
mov z0.d, #0
mov z1.d, #0
mov z2.d, #0
mov z3.d, #0
mov z4.d, #0
mov z5.d, #0
mov x_A_ptr, x_A_row
mov x_B_ptr, x_B_col
mov x_k_cnt, x_K
.loop_k6x16:
cbz x_k_cnt, .store6x16
// Load 6 A scalars, widen to FP32, broadcast across the vector
ldr h18, [x_A_ptr]
fcvt s18, h18
mov z18.s, s18
add x26, x_A_ptr, x_lda_bytes
ldr h19, [x26]
fcvt s19, h19
mov z19.s, s19
add x26, x26, x_lda_bytes
ldr h20, [x26]
fcvt s20, h20
mov z20.s, s20
add x26, x26, x_lda_bytes
ldr h21, [x26]
fcvt s21, h21
mov z21.s, s21
add x26, x26, x_lda_bytes
ldr h22, [x26]
fcvt s22, h22
mov z22.s, s22
add x26, x26, x_lda_bytes
ldr h23, [x26]
fcvt s23, h23
mov z23.s, s23
// Load 16 FP16 from B (one per .s lane), widen to FP32
ld1h z24.s, p0/z, [x_B_ptr]
fcvt z24.s, p0/m, z24.h
// FP32 FMA for 6 rows
fmla z0.s, p0/m, z18.s, z24.s
fmla z1.s, p0/m, z19.s, z24.s
fmla z2.s, p0/m, z20.s, z24.s
fmla z3.s, p0/m, z21.s, z24.s
fmla z4.s, p0/m, z22.s, z24.s
fmla z5.s, p0/m, z23.s, z24.s
add x_A_ptr, x_A_ptr, #2
add x_B_ptr, x_B_ptr, x_ldb_bytes
sub x_k_cnt, x_k_cnt, #1
b .loop_k6x16
.store6x16:
st1w z0.s, p0, [x_C_ptr]
add x26, x_C_ptr, x_ldc_bytes
st1w z1.s, p0, [x26]
add x26, x26, x_ldc_bytes
st1w z2.s, p0, [x26]
add x26, x26, x_ldc_bytes
st1w z3.s, p0, [x26]
add x26, x26, x_ldc_bytes
st1w z4.s, p0, [x26]
add x26, x26, x_ldc_bytes
st1w z5.s, p0, [x26]
add x_B_col, x_B_col, #32
add x_C_ptr, x_C_ptr, #64
sub x_n_cnt, x_n_cnt, #16
b .loop_n6x16
.loop_n6_rem:
// Handle remaining < 16 columns for 6 rows (predicated FP32 path)
cbz x_n_cnt, .next_6rows
whilelt p2.s, xzr, x_n_cnt
mov z0.d, #0
mov z1.d, #0
mov z2.d, #0
mov z3.d, #0
mov z4.d, #0
mov z5.d, #0
mov x_A_ptr, x_A_row
mov x_B_ptr, x_B_col
mov x_k_cnt, x_K
.loop_k6_rem:
cbz x_k_cnt, .store6_rem
ldr h18, [x_A_ptr]
fcvt s18, h18
mov z18.s, s18
add x26, x_A_ptr, x_lda_bytes
ldr h19, [x26]
fcvt s19, h19
mov z19.s, s19
add x26, x26, x_lda_bytes
ldr h20, [x26]
fcvt s20, h20
mov z20.s, s20
add x26, x26, x_lda_bytes
ldr h21, [x26]
fcvt s21, h21
mov z21.s, s21
add x26, x26, x_lda_bytes
ldr h22, [x26]
fcvt s22, h22
mov z22.s, s22
add x26, x26, x_lda_bytes
ldr h23, [x26]
fcvt s23, h23
mov z23.s, s23
ld1h z24.s, p2/z, [x_B_ptr]
fcvt z24.s, p2/m, z24.h
fmla z0.s, p2/m, z18.s, z24.s
fmla z1.s, p2/m, z19.s, z24.s
fmla z2.s, p2/m, z20.s, z24.s
fmla z3.s, p2/m, z21.s, z24.s
fmla z4.s, p2/m, z22.s, z24.s
fmla z5.s, p2/m, z23.s, z24.s
add x_A_ptr, x_A_ptr, #2
add x_B_ptr, x_B_ptr, x_ldb_bytes
sub x_k_cnt, x_k_cnt, #1
b .loop_k6_rem
.store6_rem:
st1w z0.s, p2, [x_C_ptr]
add x26, x_C_ptr, x_ldc_bytes
st1w z1.s, p2, [x26]
add x26, x26, x_ldc_bytes
st1w z2.s, p2, [x26]
add x26, x26, x_ldc_bytes
st1w z3.s, p2, [x26]
add x26, x26, x_ldc_bytes
st1w z4.s, p2, [x26]
add x26, x26, x_ldc_bytes
st1w z5.s, p2, [x26]
.next_6rows:
// Advance A and C by 6 rows: 6*stride = 3*stride doubled
mov x26, x_lda_bytes
add x26, x26, x_lda_bytes, lsl #1 // 3 * lda_bytes
add x26, x26, x26 // 6 * lda_bytes
add x_A_row, x_A_row, x26
mov x26, x_ldc_bytes
add x26, x26, x_ldc_bytes, lsl #1
add x26, x26, x26
add x_C_row, x_C_row, x26
sub x_m_cnt, x_m_cnt, #6
b .loop_m6
.loop_m1:
// Handle remaining 0-5 rows one at a time
cbz x_m_cnt, .done
mov x_B_col, x_B
mov x_C_ptr, x_C_row
mov x_n_cnt, x_N
.loop_n1x32:
cmp x_n_cnt, #32
b.lt .handle_n16
// Process 1 row x 32 columns: fp16 partials in z2, fp32 totals in z0/z1
mov z0.d, #0
mov z1.d, #0
mov z2.d, #0
mov x_A_ptr, x_A_row
mov x_B_ptr, x_B_col
mov x_k_cnt, x_K
.loop_k1x32:
cbz x_k_cnt, .store1x32
ld1rh {z3.h}, p1/z, [x_A_ptr]
add x_A_ptr, x_A_ptr, #2
ld1h z4.h, p1/z, [x_B_ptr]
fmla z2.h, p1/m, z3.h, z4.h
add x_B_ptr, x_B_ptr, x_ldb_bytes
sub x_k_cnt, x_k_cnt, #1
// Fall through to the FP32 fold every 8 K steps (k_cnt % 8 == 0)
ands x17, x_k_cnt, #7
b.ne .loop_k1x32
fcvt z5.s, p0/m, z2.h
lsr z6.s, z2.s, #16
fcvt z6.s, p0/m, z6.h
zip1 z7.s, z5.s, z6.s
zip2 z3.s, z5.s, z6.s
fadd z0.s, p0/m, z0.s, z7.s
fadd z1.s, p0/m, z1.s, z3.s
mov z2.d, #0
b .loop_k1x32
.store1x32:
// Final fold (z2 is zero here if K % 8 == 0; adding zero is harmless)
fcvt z5.s, p0/m, z2.h
lsr z6.s, z2.s, #16
fcvt z6.s, p0/m, z6.h
zip1 z7.s, z5.s, z6.s
zip2 z3.s, z5.s, z6.s
fadd z0.s, p0/m, z0.s, z7.s
fadd z1.s, p0/m, z1.s, z3.s
st1w z0.s, p0, [x_C_ptr]
add x17, x_C_ptr, #64
st1w z1.s, p0, [x17]
add x_B_col, x_B_col, #64
add x_C_ptr, x_C_ptr, #128
sub x_n_cnt, x_n_cnt, #32
b .loop_n1x32
.handle_n16:
cmp x_n_cnt, #16
b.lt .handle_remainder
mov z0.d, #0
mov z2.d, #0
mov x_A_ptr, x_A_row
mov x_B_ptr, x_B_col
mov x_k_cnt, x_K
.loop_k16:
cbz x_k_cnt, .store16
ld1rh {z3.h}, p1/z, [x_A_ptr]
add x_A_ptr, x_A_ptr, #2
// Load 16 FP16 into .s lanes so the widening conversion lines up
ld1h z4.s, p0/z, [x_B_ptr]
// Widen B and the broadcast A scalar to FP32, accumulate in FP32
fcvt z4.s, p0/m, z4.h
fcvt z5.s, p0/m, z3.h // Convert broadcast A value
fmla z0.s, p0/m, z5.s, z4.s
add x_B_ptr, x_B_ptr, x_ldb_bytes
sub x_k_cnt, x_k_cnt, #1
b .loop_k16
.store16:
st1w z0.s, p0, [x_C_ptr]
add x_B_col, x_B_col, #32
add x_C_ptr, x_C_ptr, #64
sub x_n_cnt, x_n_cnt, #16
b .loop_n1x32
.handle_remainder:
cbz x_n_cnt, .next_row
whilelt p2.s, xzr, x_n_cnt
mov z0.d, #0
mov x_A_ptr, x_A_row
mov x_B_ptr, x_B_col
mov x_k_cnt, x_K
.loop_k_rem:
cbz x_k_cnt, .store_rem
ldr h1, [x_A_ptr], #2
fcvt s1, h1
mov z1.s, s1
ld1h z2.s, p2/z, [x_B_ptr]
fcvt z2.s, p2/m, z2.h
fmla z0.s, p2/m, z1.s, z2.s
add x_B_ptr, x_B_ptr, x_ldb_bytes
sub x_k_cnt, x_k_cnt, #1
b .loop_k_rem
.store_rem:
st1w z0.s, p2, [x_C_ptr]
.next_row:
add x_A_row, x_A_row, x_lda_bytes
add x_C_row, x_C_row, x_ldc_bytes
sub x_m_cnt, x_m_cnt, #1
b .loop_m1
.done:
// Restore callee-saved state (mirror of the prologue)
ldp d14, d15, [sp, #112]
ldp d12, d13, [sp, #96]
ldp d10, d11, [sp, #80]
ldp d8, d9, [sp, #64]
ldp x25, x26, [sp, #48]
ldp x23, x24, [sp, #32]
ldp x21, x22, [sp, #16]
ldp x19, x20, [sp], #128
ret
.size hgemm_kernel, .-hgemm_kernel
// ============================================================================
// High-Performance 6x32 Micro-Kernel
// void hgemm_kernel_6x32(int64_t K, const __fp16 *A, const __fp16 *B, float *C)
// Uses FP16 FMLA with periodic FP32 accumulation
// 6 rows x 32 columns = 192 outputs, overwritten into C (no accumulate)
// K_INTERVAL = 8 (accumulate in FP16, then convert to FP32)
// Assumes a 512-bit SVE vector length (A64FX): 32 fp16 / 16 fp32 per vector.
// Clobbers: x4-x6, z0-z31, p0-p1; preserves d8-d15 per AAPCS64.
// ============================================================================
.align 4
.global hgemm_kernel_6x32
.type hgemm_kernel_6x32, %function
// Register allocation:
// z0-z11: FP32 accumulators (6 rows x 2 vectors of 16 FP32 each)
// z12-z17: FP16 temp accumulators (6 rows x 1 vector of 32 FP16 each)
// z18-z23: A values broadcast (6 rows)
// z24-z25: B values (32 FP16)
// z26-z31: Scratch for conversion
// p0: ptrue.s (16 elements)
// p1: ptrue.h (32 elements)
hgemm_kernel_6x32:
// x0: K
// x1: A pointer (K x 6: 6 fp16 values per K step — packed, column-major
//     for this kernel; NOTE(review): confirm packing against the caller)
// x2: B pointer (K x 32: one 64-byte row of fp16 per K step)
// x3: C pointer (6 x 32 fp32, contiguous)
// Save callee-saved registers (AAPCS64: low 64 bits of v8-v15)
stp d8, d9, [sp, #-64]!
stp d10, d11, [sp, #16]
stp d12, d13, [sp, #32]
stp d14, d15, [sp, #48]
ptrue p0.s // 16 FP32 elements
ptrue p1.h // 32 FP16 elements
// Initialize FP32 accumulators to zero
mov z0.d, #0
mov z1.d, #0
mov z2.d, #0
mov z3.d, #0
mov z4.d, #0
mov z5.d, #0
mov z6.d, #0
mov z7.d, #0
mov z8.d, #0
mov z9.d, #0
mov z10.d, #0
mov z11.d, #0
// Initialize FP16 temp accumulators to zero
mov z12.d, #0
mov z13.d, #0
mov z14.d, #0
mov z15.d, #0
mov z16.d, #0
mov z17.d, #0
mov x4, x0 // K counter
mov x5, #0 // K steps since the last FP32 fold (K mod 8 counter)
.loop_k_opt:
cbz x4, .final_convert
// Prefetch B four K rows ahead (4 x 64 bytes)
add x6, x2, #256
prfh pldl1keep, p1, [x6]
// Load 6 A values and broadcast each across a full FP16 vector
ld1rh {z18.h}, p1/z, [x1]
ld1rh {z19.h}, p1/z, [x1, #2]
ld1rh {z20.h}, p1/z, [x1, #4]
ld1rh {z21.h}, p1/z, [x1, #6]
ld1rh {z22.h}, p1/z, [x1, #8]
ld1rh {z23.h}, p1/z, [x1, #10]
// Load 32 FP16 from B
ld1h z24.h, p1/z, [x2]
// FP16 FMA for all 6 rows (32 MAC lanes each = 192 FP16 ops total)
fmla z12.h, p1/m, z18.h, z24.h
fmla z13.h, p1/m, z19.h, z24.h
fmla z14.h, p1/m, z20.h, z24.h
fmla z15.h, p1/m, z21.h, z24.h
fmla z16.h, p1/m, z22.h, z24.h
fmla z17.h, p1/m, z23.h, z24.h
// Advance pointers
add x1, x1, #12 // 6 FP16 = 12 bytes
add x2, x2, #64 // 32 FP16 = 64 bytes
sub x4, x4, #1
add x5, x5, #1
// Every 8 iterations, convert FP16 to FP32 and accumulate
// (bounds fp16 rounding/overflow in the partial sums)
cmp x5, #8
b.lt .loop_k_opt
// Convert FP16 temp accumulators to FP32
// Pattern: fcvt gets even elements, lsr+fcvt gets odd, zip interleaves
// so the fp32 results land back in original column order
// Row 0 (z12 -> z0, z1)
fcvt z26.s, p0/m, z12.h // Even: h0, h2, h4, ...
lsr z27.s, z12.s, #16
fcvt z27.s, p0/m, z27.h // Odd: h1, h3, h5, ...
zip1 z28.s, z26.s, z27.s // h0, h1, h2, ..., h15
zip2 z29.s, z26.s, z27.s // h16, h17, ..., h31
fadd z0.s, p0/m, z0.s, z28.s
fadd z1.s, p0/m, z1.s, z29.s
mov z12.d, #0
// Row 1 (z13 -> z2, z3)
fcvt z26.s, p0/m, z13.h
lsr z27.s, z13.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z2.s, p0/m, z2.s, z28.s
fadd z3.s, p0/m, z3.s, z29.s
mov z13.d, #0
// Row 2 (z14 -> z4, z5)
fcvt z26.s, p0/m, z14.h
lsr z27.s, z14.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z4.s, p0/m, z4.s, z28.s
fadd z5.s, p0/m, z5.s, z29.s
mov z14.d, #0
// Row 3 (z15 -> z6, z7)
fcvt z26.s, p0/m, z15.h
lsr z27.s, z15.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z6.s, p0/m, z6.s, z28.s
fadd z7.s, p0/m, z7.s, z29.s
mov z15.d, #0
// Row 4 (z16 -> z8, z9)
fcvt z26.s, p0/m, z16.h
lsr z27.s, z16.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z8.s, p0/m, z8.s, z28.s
fadd z9.s, p0/m, z9.s, z29.s
mov z16.d, #0
// Row 5 (z17 -> z10, z11)
fcvt z26.s, p0/m, z17.h
lsr z27.s, z17.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z10.s, p0/m, z10.s, z28.s
fadd z11.s, p0/m, z11.s, z29.s
mov z17.d, #0
mov x5, #0 // Reset K mod counter
b .loop_k_opt
.final_convert:
// Convert any remaining FP16 accumulation to FP32
// Same pattern: fcvt for even, lsr+fcvt for odd, zip to interleave.
// No need to re-zero z12-z17 here: the kernel returns immediately after.
// Row 0 (z12 -> z0, z1)
fcvt z26.s, p0/m, z12.h
lsr z27.s, z12.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z0.s, p0/m, z0.s, z28.s
fadd z1.s, p0/m, z1.s, z29.s
// Row 1 (z13 -> z2, z3)
fcvt z26.s, p0/m, z13.h
lsr z27.s, z13.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z2.s, p0/m, z2.s, z28.s
fadd z3.s, p0/m, z3.s, z29.s
// Row 2 (z14 -> z4, z5)
fcvt z26.s, p0/m, z14.h
lsr z27.s, z14.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z4.s, p0/m, z4.s, z28.s
fadd z5.s, p0/m, z5.s, z29.s
// Row 3 (z15 -> z6, z7)
fcvt z26.s, p0/m, z15.h
lsr z27.s, z15.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z6.s, p0/m, z6.s, z28.s
fadd z7.s, p0/m, z7.s, z29.s
// Row 4 (z16 -> z8, z9)
fcvt z26.s, p0/m, z16.h
lsr z27.s, z16.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z8.s, p0/m, z8.s, z28.s
fadd z9.s, p0/m, z9.s, z29.s
// Row 5 (z17 -> z10, z11)
fcvt z26.s, p0/m, z17.h
lsr z27.s, z17.s, #16
fcvt z27.s, p0/m, z27.h
zip1 z28.s, z26.s, z27.s
zip2 z29.s, z26.s, z27.s
fadd z10.s, p0/m, z10.s, z28.s
fadd z11.s, p0/m, z11.s, z29.s
// Store 6 x 32 FP32 results, contiguous (12 vectors of 64 bytes).
// st1w's scaled immediate only reaches #7 mul vl, so rows 4-5 go
// through a rebased pointer (8 vectors x 64 B = 512 B offset).
st1w z0.s, p0, [x3, #0, mul vl]
st1w z1.s, p0, [x3, #1, mul vl]
st1w z2.s, p0, [x3, #2, mul vl]
st1w z3.s, p0, [x3, #3, mul vl]
st1w z4.s, p0, [x3, #4, mul vl]
st1w z5.s, p0, [x3, #5, mul vl]
st1w z6.s, p0, [x3, #6, mul vl]
st1w z7.s, p0, [x3, #7, mul vl]
add x6, x3, #512
st1w z8.s, p0, [x6, #0, mul vl]
st1w z9.s, p0, [x6, #1, mul vl]
st1w z10.s, p0, [x6, #2, mul vl]
st1w z11.s, p0, [x6, #3, mul vl]
// Restore callee-saved registers
ldp d14, d15, [sp, #48]
ldp d12, d13, [sp, #32]
ldp d10, d11, [sp, #16]
ldp d8, d9, [sp], #64
ret
.size hgemm_kernel_6x32, .-hgemm_kernel_6x32
// ============================================================================
// 8x32 Micro-Kernel with K-unrolling
// void hgemm_kernel_8x32_unroll4(int64_t K, const __fp16 *A,
//                                const __fp16 *B, float *C)
//   x0: K (must be a multiple of 4)
//   x1: A (packed K x 8: 8 fp16 values per K step)
//   x2: B (packed K x 32: one 64-byte fp16 row per K step)
//   x3: C (8 x 32 fp32, contiguous, overwritten)
// FP16 FMLA accumulates into z16-z23; every 8 K steps the partials are
// widened to FP32 and folded into z0-z15 to bound fp16 rounding/overflow.
// Assumes a 512-bit SVE vector length (A64FX).
// Clobbers: x4-x6, z0-z31, p0-p1; preserves d8-d15 per AAPCS64.
// ============================================================================
.align 4
.global hgemm_kernel_8x32_unroll4
.type hgemm_kernel_8x32_unroll4, %function
// Widen FP16 accumulator z\hacc (32 partial sums) and fold it into FP32
// accumulators z\flo (columns 0-15) and z\fhi (columns 16-31).
// FCVT .s<-.h converts the even-indexed fp16 halves, LSR #16 exposes the
// odd-indexed halves, and ZIP1/ZIP2 restore the original column order.
// (The previous fcvt+ext scheme dropped every odd-indexed partial sum.)
// Scratch: z24-z27.
.macro cvt_fold16 hacc, flo, fhi
fcvt z24.s, p0/m, z\hacc\().h // even fp16 elements -> fp32
lsr z25.s, z\hacc\().s, #16 // move odd elements into low halves
fcvt z25.s, p0/m, z25.h // odd fp16 elements -> fp32
zip1 z26.s, z24.s, z25.s // columns 0..15, in order
zip2 z27.s, z24.s, z25.s // columns 16..31, in order
fadd z\flo\().s, p0/m, z\flo\().s, z26.s
fadd z\fhi\().s, p0/m, z\fhi\().s, z27.s
.endm
hgemm_kernel_8x32_unroll4:
// AAPCS64: preserve the low 64 bits of v8-v15
stp d8, d9, [sp, #-64]!
stp d10, d11, [sp, #16]
stp d12, d13, [sp, #32]
stp d14, d15, [sp, #48]
ptrue p0.s // all 16 fp32 lanes
ptrue p1.h // all 32 fp16 lanes
// FP32 accumulators: 8 rows x 2 vectors = 16 vectors (z0-z15)
mov z0.d, #0
mov z1.d, #0
mov z2.d, #0
mov z3.d, #0
mov z4.d, #0
mov z5.d, #0
mov z6.d, #0
mov z7.d, #0
mov z8.d, #0
mov z9.d, #0
mov z10.d, #0
mov z11.d, #0
mov z12.d, #0
mov z13.d, #0
mov z14.d, #0
mov z15.d, #0
// FP16 temp accumulators: one per row (z16-z23)
mov z16.d, #0
mov z17.d, #0
mov z18.d, #0
mov z19.d, #0
mov z20.d, #0
mov z21.d, #0
mov z22.d, #0
mov z23.d, #0
lsr x4, x0, #2 // outer trip count = K / 4
mov x5, #0 // K steps since the last FP32 fold
.loop_k8x32:
cbz x4, .final8x32
// Prefetch B two unrolled iterations ahead (8 rows x 64 bytes)
add x6, x2, #512
prfh pldl1keep, p1, [x6]
// Four K steps; each broadcasts 8 A values and reuses one 32-wide B row.
// Only 7 broadcast registers (z25-z31) are free alongside B in z24, so
// row 7's A value is loaded into z25 after z25's first FMLA retires.
.rept 4
ld1h z24.h, p1/z, [x2] // B row: 32 fp16
ld1rh {z25.h}, p1/z, [x1, #0] // A row 0
ld1rh {z26.h}, p1/z, [x1, #2] // A row 1
ld1rh {z27.h}, p1/z, [x1, #4] // A row 2
ld1rh {z28.h}, p1/z, [x1, #6] // A row 3
ld1rh {z29.h}, p1/z, [x1, #8] // A row 4
ld1rh {z30.h}, p1/z, [x1, #10] // A row 5
ld1rh {z31.h}, p1/z, [x1, #12] // A row 6
fmla z16.h, p1/m, z25.h, z24.h
fmla z17.h, p1/m, z26.h, z24.h
fmla z18.h, p1/m, z27.h, z24.h
fmla z19.h, p1/m, z28.h, z24.h
fmla z20.h, p1/m, z29.h, z24.h
fmla z21.h, p1/m, z30.h, z24.h
fmla z22.h, p1/m, z31.h, z24.h
ld1rh {z25.h}, p1/z, [x1, #14] // A row 7 (z25 is dead again)
fmla z23.h, p1/m, z25.h, z24.h
add x1, x1, #16 // 8 fp16 = 16 bytes
add x2, x2, #64 // 32 fp16 = 64 bytes
.endr
sub x4, x4, #1
add x5, x5, #4
// Every 8 K steps, fold the fp16 partials into the fp32 accumulators
ands x6, x5, #7
b.ne .loop_k8x32
cvt_fold16 16, 0, 1
mov z16.d, #0
cvt_fold16 17, 2, 3
mov z17.d, #0
cvt_fold16 18, 4, 5
mov z18.d, #0
cvt_fold16 19, 6, 7
mov z19.d, #0
cvt_fold16 20, 8, 9
mov z20.d, #0
cvt_fold16 21, 10, 11
mov z21.d, #0
cvt_fold16 22, 12, 13
mov z22.d, #0
cvt_fold16 23, 14, 15
mov z23.d, #0
mov x5, #0
b .loop_k8x32
.final8x32:
// Fold any K steps accumulated since the last conversion.
// No need to re-zero z16-z23: the kernel returns right after the stores.
cvt_fold16 16, 0, 1
cvt_fold16 17, 2, 3
cvt_fold16 18, 4, 5
cvt_fold16 19, 6, 7
cvt_fold16 20, 8, 9
cvt_fold16 21, 10, 11
cvt_fold16 22, 12, 13
cvt_fold16 23, 14, 15
// Store 8 x 32 FP32 results, contiguous (16 vectors of 64 bytes).
// st1w's scaled immediate only reaches #7 mul vl, so the pointer is
// rebased twice (512 B = 8 vectors, then 256 B = 4 more).
st1w z0.s, p0, [x3, #0, mul vl]
st1w z1.s, p0, [x3, #1, mul vl]
st1w z2.s, p0, [x3, #2, mul vl]
st1w z3.s, p0, [x3, #3, mul vl]
st1w z4.s, p0, [x3, #4, mul vl]
st1w z5.s, p0, [x3, #5, mul vl]
st1w z6.s, p0, [x3, #6, mul vl]
st1w z7.s, p0, [x3, #7, mul vl]
add x6, x3, #512
st1w z8.s, p0, [x6, #0, mul vl]
st1w z9.s, p0, [x6, #1, mul vl]
st1w z10.s, p0, [x6, #2, mul vl]
st1w z11.s, p0, [x6, #3, mul vl]
add x6, x6, #256
st1w z12.s, p0, [x6, #0, mul vl]
st1w z13.s, p0, [x6, #1, mul vl]
st1w z14.s, p0, [x6, #2, mul vl]
st1w z15.s, p0, [x6, #3, mul vl]
ldp d14, d15, [sp, #48]
ldp d12, d13, [sp, #32]
ldp d10, d11, [sp, #16]
ldp d8, d9, [sp], #64
ret
.size hgemm_kernel_8x32_unroll4, .-hgemm_kernel_8x32_unroll4
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment