-
-
Save banach-space/91a868f992a5747dd58b17a85f584ec8 to your computer and use it in GitHub Desktop.
// The following matmul maps very nicely onto SME with SVL=128 bits. For f32,
// there are four 4x4 tiles that can be assembled into an 8x8 matrix.
//
// %mat_A_tr - transpose of A, laid out as [K x M] = 6x8
// %mat_B    - B, laid out as [K x N] = 6x8
// %mat_C    - C, laid out as [M x N] = 8x8; accumulated into: C += A * B
//
// Since A arrives pre-transposed, the kernel computes
//   C(m, n) += A_tr(k, m) * B(k, n).
// Plain `linalg.matmul` expects its LHS as [M x K] and fails shape verification
// for these operands, so the transposed-A named op is used instead.
func.func @matmul(%mat_A_tr: memref<6x8xf32>, %mat_B: memref<6x8xf32>, %mat_C: memref<8x8xf32>) {
  linalg.matmul_transpose_a ins(%mat_A_tr, %mat_B : memref<6x8xf32>, memref<6x8xf32>)
                            outs(%mat_C : memref<8x8xf32>)
  return
}
//-------------------------------------------------
// DESIRED LOWERING FOR SME - HAND WRITTEN
// linalg.matmul as vector.contract
//
// Overview:
//  * element type - f32
//  * number of vector.contract accumulators - 4
//  * C is partitioned into 4 tiles
//  * A and B are partitioned into 2 halves
//  * A is transposed before entering the kernel
//-------------------------------------------------
// %mat_A_tr - transpose of A, laid out as [K x M] = 6x8
// %mat_B    - B, laid out as [K x N] = 6x8
// %mat_C    - C, laid out as [M x N] = 8x8; every 4x4 tile is read, updated
//             by one vector.contract and written back in place.

// Indexing maps of a plain (non-transposed) matmul over (d0, d1, d2) = (m, n, k).
#map = affine_map<(d0, d1, d2) -> (d0, d2)>   // LHS: (m, k)
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>  // RHS: (k, n)
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>  // ACC: (m, n)

module attributes { transform.with_named_sequence } {
  // NOTE(review): despite its name this variant is expressed with
  // vector.contract, not vector.outerproduct; consider e.g. @contract_matmul.
  func.func @outerproduct_matmul(%mat_A_tr: memref<6x8xf32>, %mat_B: memref<6x8xf32>, %mat_C: memref<8x8xf32>) {
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    // Padding value for the transfer reads; never needed here since every read
    // is marked in_bounds.
    %cst = arith.constant 0.000000e+00 : f32

    // Accumulators: the four 4x4 tiles of C.
    // Tile 0 - rows [0, 4), cols [0, 4)
    %tile0_C = vector.transfer_read %mat_C[%c0, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 1 - rows [0, 4), cols [4, 8)
    %tile1_C = vector.transfer_read %mat_C[%c0, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 2 - rows [4, 8), cols [0, 4)
    %tile2_C = vector.transfer_read %mat_C[%c4, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 3 - rows [4, 8), cols [4, 8)
    %tile3_C = vector.transfer_read %mat_C[%c4, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>

    // Inputs: all K = 6 rows, split into column halves [0, 4) and [4, 8).
    // Tile upper - A (columns [0, 4) of A^T == rows [0, 4) of A)
    %tile_A_hi = vector.transfer_read %mat_A_tr[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile lower - A (columns [4, 8) of A^T == rows [4, 8) of A)
    %tile_A_lo = vector.transfer_read %mat_A_tr[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile left - B
    %tile_B_left = vector.transfer_read %mat_B[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile right - B
    %tile_B_right = vector.transfer_read %mat_B[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>

    // The A^T halves are [K x 4]; transpose them to [4 x K] so they match the
    // (m, k) LHS map #map.
    %tile_A_hi_tr = vector.transpose %tile_A_hi, [1, 0] : vector<6x4xf32> to vector<4x6xf32>
    %tile_A_lo_tr = vector.transpose %tile_A_lo, [1, 0] : vector<6x4xf32> to vector<4x6xf32>

    // Compute upper half of C
    %tile0_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_hi_tr, %tile_B_left, %tile0_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
    %tile1_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_hi_tr, %tile_B_right, %tile1_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
    // Compute lower half of C
    %tile2_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_lo_tr, %tile_B_left, %tile2_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
    %tile3_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_lo_tr, %tile_B_right, %tile3_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>

    // Write the updated tiles back to the same positions in C.
    vector.transfer_write %tile0_matmul, %mat_C[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    vector.transfer_write %tile1_matmul, %mat_C[%c0, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    vector.transfer_write %tile2_matmul, %mat_C[%c4, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    vector.transfer_write %tile3_matmul, %mat_C[%c4, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    return
  }
}
//-------------------------------------------------
// DESIRED LOWERING FOR SME - HAND WRITTEN
// linalg.matmul as vector.outerproduct
//
// Overview:
//  * element type - f32
//  * number of OP accumulators - 4
//  * C is partitioned into 4 tiles
//  * A and B are partitioned into 2 halves
//  * A is transposed before entering the kernel
//  * each tile of C receives K = 6 rank-1 updates, one per k in [0, 6):
//      tile += outerproduct(row k of the A^T half, row k of the B half)
//-------------------------------------------------
// LEGEND:
// Matrix C. Each tile is [SVL_s x SVL_s] (4x4xf32 in this example)
// -----------------------------
// |             |             |
// |             |             |
// |   tile0_C   |   tile1_C   |
// |             |             |
// |             |             |
// -----------------------------
// |             |             |
// |             |             |
// |   tile2_C   |   tile3_C   |
// |             |             |
// |             |             |
// -----------------------------
// Matrix B. Each tile is [K x SVL_s] (6x4xf32 in this example)
// ---------------------------------
// |               |               |
// |               |               |
// |  tile_B_left  | tile_B_right  |
// |               |               |
// |               |               |
// ---------------------------------
// Matrix A transpose. Each tile is [K x SVL_s] (6x4xf32 in this example)
// ---------------------------------
// |               |               |
// |               |               |
// |   tile_A_hi   |   tile_A_lo   |
// |   ("upper")   |   ("lower")   |
// |               |               |
// ---------------------------------
// Columns of A are rows of A^T, so row k of tile_A_hi / tile_A_lo
// (%A_col_k_0 / %A_col_k_1) is column k of the upper / lower half of A.
// Row 3 of matrix B. Each half is SVL_s long (4xf32 in this example)
// -----------------------------
// |  B_row_3_0  |  B_row_3_1  |
// -----------------------------
//     half 0        half 1
// %mat_A_tr - transpose of A, laid out as [K x M] = 6x8
module {
  func.func @outerproduct_matmul(%mat_A_tr: memref<6x8xf32>, %mat_B: memref<6x8xf32>, %mat_C: memref<8x8xf32>) {
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    // Padding value for the transfer reads; never needed since all reads are in_bounds.
    %cst = arith.constant 0.000000e+00 : f32
    // Tile 0
    %tile0_C = vector.transfer_read %mat_C[%c0, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 1
    %tile1_C = vector.transfer_read %mat_C[%c0, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 2
    %tile2_C = vector.transfer_read %mat_C[%c4, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 3
    %tile3_C = vector.transfer_read %mat_C[%c4, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile upper - A
    %tile_A_hi = vector.transfer_read %mat_A_tr[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile lower - A
    %tile_A_lo = vector.transfer_read %mat_A_tr[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile left - B
    %tile_B_left = vector.transfer_read %mat_B[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile right - B
    %tile_B_right = vector.transfer_read %mat_B[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // ===================================================================================================================
    // 1 COMPUTE UPPER HALF OF C
    // ===================================================================================================================
    // 1.1 Input tile "upper" from A and "left" from B - accumulate to ZA0.s (tile 0 of C)
    %A_col_0_0 = vector.extract %tile_A_hi[0] : vector<6x4xf32>
    %A_col_1_0 = vector.extract %tile_A_hi[1] : vector<6x4xf32>
    %A_col_2_0 = vector.extract %tile_A_hi[2] : vector<6x4xf32>
    %A_col_3_0 = vector.extract %tile_A_hi[3] : vector<6x4xf32>
    %A_col_4_0 = vector.extract %tile_A_hi[4] : vector<6x4xf32>
    %A_col_5_0 = vector.extract %tile_A_hi[5] : vector<6x4xf32>
    // One B row per k: row index must match the suffix (k = 0..5).
    %B_row_0_0 = vector.extract %tile_B_left[0] : vector<6x4xf32>
    %B_row_1_0 = vector.extract %tile_B_left[1] : vector<6x4xf32>
    %B_row_2_0 = vector.extract %tile_B_left[2] : vector<6x4xf32>
    %B_row_3_0 = vector.extract %tile_B_left[3] : vector<6x4xf32>
    %B_row_4_0 = vector.extract %tile_B_left[4] : vector<6x4xf32>
    %B_row_5_0 = vector.extract %tile_B_left[5] : vector<6x4xf32>
    %op_0 = vector.outerproduct %A_col_0_0, %B_row_0_0, %tile0_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_1 = vector.outerproduct %A_col_1_0, %B_row_1_0, %op_0 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_2 = vector.outerproduct %A_col_2_0, %B_row_2_0, %op_1 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_3 = vector.outerproduct %A_col_3_0, %B_row_3_0, %op_2 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_4 = vector.outerproduct %A_col_4_0, %B_row_4_0, %op_3 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_5 = vector.outerproduct %A_col_5_0, %B_row_5_0, %op_4 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    vector.transfer_write %op_5, %mat_C[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    // 1.2 Input tile "upper" from A and "right" from B - accumulate to ZA1.s (tile 1 of C)
    %B_row_0_1 = vector.extract %tile_B_right[0] : vector<6x4xf32>
    %B_row_1_1 = vector.extract %tile_B_right[1] : vector<6x4xf32>
    %B_row_2_1 = vector.extract %tile_B_right[2] : vector<6x4xf32>
    %B_row_3_1 = vector.extract %tile_B_right[3] : vector<6x4xf32>
    %B_row_4_1 = vector.extract %tile_B_right[4] : vector<6x4xf32>
    %B_row_5_1 = vector.extract %tile_B_right[5] : vector<6x4xf32>
    %op_6 = vector.outerproduct %A_col_0_0, %B_row_0_1, %tile1_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_7 = vector.outerproduct %A_col_1_0, %B_row_1_1, %op_6 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_8 = vector.outerproduct %A_col_2_0, %B_row_2_1, %op_7 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_9 = vector.outerproduct %A_col_3_0, %B_row_3_1, %op_8 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_10 = vector.outerproduct %A_col_4_0, %B_row_4_1, %op_9 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_11 = vector.outerproduct %A_col_5_0, %B_row_5_1, %op_10 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    vector.transfer_write %op_11, %mat_C[%c0, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    // ===================================================================================================================
    // 2 COMPUTE LOWER HALF OF C
    // ===================================================================================================================
    // 2.1 Input tile "lower" from A and tile "left" from B - accumulate to ZA2.s (tile 2 of C)
    %A_col_0_1 = vector.extract %tile_A_lo[0] : vector<6x4xf32>
    %A_col_1_1 = vector.extract %tile_A_lo[1] : vector<6x4xf32>
    %A_col_2_1 = vector.extract %tile_A_lo[2] : vector<6x4xf32>
    %A_col_3_1 = vector.extract %tile_A_lo[3] : vector<6x4xf32>
    %A_col_4_1 = vector.extract %tile_A_lo[4] : vector<6x4xf32>
    %A_col_5_1 = vector.extract %tile_A_lo[5] : vector<6x4xf32>
    // Each update pairs A column k with B row k (k = 0..5).
    %op_12 = vector.outerproduct %A_col_0_1, %B_row_0_0, %tile2_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_13 = vector.outerproduct %A_col_1_1, %B_row_1_0, %op_12 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_14 = vector.outerproduct %A_col_2_1, %B_row_2_0, %op_13 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_15 = vector.outerproduct %A_col_3_1, %B_row_3_0, %op_14 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_16 = vector.outerproduct %A_col_4_1, %B_row_4_0, %op_15 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_17 = vector.outerproduct %A_col_5_1, %B_row_5_0, %op_16 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    vector.transfer_write %op_17, %mat_C[%c4, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    // 2.2 Input tile "lower" from A and tile "right" from B - accumulate to ZA3.s (tile 3 of C)
    %op_18 = vector.outerproduct %A_col_0_1, %B_row_0_1, %tile3_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_19 = vector.outerproduct %A_col_1_1, %B_row_1_1, %op_18 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_20 = vector.outerproduct %A_col_2_1, %B_row_2_1, %op_19 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_21 = vector.outerproduct %A_col_3_1, %B_row_3_1, %op_20 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_22 = vector.outerproduct %A_col_4_1, %B_row_4_1, %op_21 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_23 = vector.outerproduct %A_col_5_1, %B_row_5_1, %op_22 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    vector.transfer_write %op_23, %mat_C[%c4, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    return
  }
}
// Lower `vector.contract` to `vector.outerproduct`.
//
// Matches every vector.contract in the payload and takes the closest
// isolated-from-above parent of each one. Several contractions typically share
// the same parent, so those handles are de-duplicated before the
// contraction-lowering patterns are applied with the "outerproduct" strategy.
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
  %contracts = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
  %parents = transform.get_closest_isolated_parent %contracts : (!transform.any_op) -> !transform.any_op
  %unique_parents = transform.merge_handles %parents { deduplicate } : !transform.any_op
  transform.apply_patterns to %unique_parents {
    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
  } : !transform.any_op
}
// Lower `linalg.matmul` to `vector.outerproduct` - use 4 accumulators.
//
// Steps:
//  1. Tile the matmul with sizes [4, 4] on its first two loops (M and N).
//  2. Vectorize the closest isolated-from-above parent of the tiled op.
//  3. Unroll every resulting scf.for by 2. For an 8x8 C, each tile loop has
//     2 iterations, so this fully unrolls them into 4 accumulator tiles.
//
// `linalg.matmul_transpose_a` is matched as well as `linalg.matmul`. It is the
// named op that verifies when A is passed pre-transposed as [K x M].
transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
  %matmul = transform.structured.match ops{["linalg.matmul", "linalg.matmul_transpose_a"]} in %arg1 : (!transform.any_op) -> !transform.any_op
  %tiled, %loops:2 = transform.structured.tile %matmul [4, 4] : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.op<"scf.for">)
  %parent = transform.get_closest_isolated_parent %tiled : (!transform.any_op) -> !transform.any_op
  %vectorized = transform.structured.vectorize %parent : (!transform.any_op) -> !transform.any_op
  %for_ops = transform.structured.match ops{["scf.for"]} in %vectorized : (!transform.any_op) -> !transform.op<"scf.for">
  transform.loop.unroll %for_ops { factor = 2 } : !transform.op<"scf.for">
}
Hey there, I wanted to ask some questions about lowering matmul to the SME dialect that I was hoping you could help with.
Sure, feel free to DM me on Discord or Discourse!
Thank you so much, do you have a discord username that I can add?
Search for my surname or banach-space @ https://discourse.llvm.org/ ;-)
Couldn't find your name on Discord; could you try sending a friend request to @danikhan632 on Discord? Thank you.
Hey there, I wanted to ask some questions about lowering matmul to the SME dialect that I was hoping you could help with.
Sure, feel free to DM me on Discord or Discourse!
Thank you so much, do you have a discord username that I can add?
Search for my surname or banach-space @ https://discourse.llvm.org/ ;-)
Couldn't find your name on Discord; could you try sending a friend request to @danikhan632 on Discord? Thank you.
Sorry, I wasn't able to find you. Let's just use good old-fashioned e-mail: [email protected] :)
Search for my surname or banach-space @ https://discourse.llvm.org/ ;-)