Last active
June 3, 2024 07:11
-
-
Save banach-space/91a868f992a5747dd58b17a85f584ec8 to your computer and use it in GitHub Desktop.
C += A*B
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// The following matmul maps very nicely onto SME with SVL=128bits. For f32, there are 4 4x4 tiles,
// that can be assembled as a 8x8 matrix.
// %mat_A_tr - transpose of A (K x M = 6x8). The op computes C += A * B with
// A = %mat_A_tr^T. Plain `linalg.matmul` would require its first operand to be
// M x K (8x6), so the transposed-A named op is used to keep this signature.
func.func @matmul(%mat_A_tr: memref<6x8xf32>, %mat_B: memref<6x8xf32>, %mat_C: memref<8x8xf32>) {
  linalg.matmul_transpose_a ins(%mat_A_tr, %mat_B : memref<6x8xf32>, memref<6x8xf32>)
                            outs(%mat_C : memref<8x8xf32>)
  return
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//-------------------------------------------------
// DESIRED LOWERING FOR SME - HAND WRITTEN
// linalg.matmul as vector.contract
//
// Overview:
// * element type - f32
// * number of vector.contract accumulators - 4
// * C is partitioned into 4 tiles
// * A and B are partitioned into 2 halves
// * A is transposed before entering the kernel
//-------------------------------------------------
// %mat_A_tr - transpose of A
// Indexing maps for C[d0, d1] += A[d0, d2] * B[d2, d1], where d2 is the
// reduction dimension (K = 6).
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module attributes { transform.with_named_sequence } {
  func.func @outerproduct_matmul(%mat_A_tr: memref<6x8xf32>, %mat_B: memref<6x8xf32>, %mat_C: memref<8x8xf32>) {
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %cst = arith.constant 0.000000e+00 : f32
    // Tile 0
    %tile0_C = vector.transfer_read %mat_C[%c0, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 1
    %tile1_C = vector.transfer_read %mat_C[%c0, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 2
    %tile2_C = vector.transfer_read %mat_C[%c4, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 3
    %tile3_C = vector.transfer_read %mat_C[%c4, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile upper - A (columns 0-3 of A^T, i.e. rows 0-3 of A)
    %tile_A_hi = vector.transfer_read %mat_A_tr[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile lower - A (columns 4-7 of A^T, i.e. rows 4-7 of A)
    %tile_A_lo = vector.transfer_read %mat_A_tr[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile left - B
    %tile_B_left = vector.transfer_read %mat_B[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile right - B
    %tile_B_right = vector.transfer_read %mat_B[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Transpose each 6x4 half of A^T back into a 4x6 slice of A, as #map expects.
    %tile_A_hi_tr = vector.transpose %tile_A_hi, [1, 0] : vector<6x4xf32> to vector<4x6xf32>
    %tile_A_lo_tr = vector.transpose %tile_A_lo, [1, 0] : vector<6x4xf32> to vector<4x6xf32>
    // Compute upper half of C
    %tile0_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_hi_tr, %tile_B_left, %tile0_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
    %tile1_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_hi_tr, %tile_B_right, %tile1_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
    // Compute lower half of C
    %tile2_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_lo_tr, %tile_B_left, %tile2_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
    %tile3_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_lo_tr, %tile_B_right, %tile3_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
    // Store the four accumulated tiles back into C.
    vector.transfer_write %tile0_matmul, %mat_C[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    vector.transfer_write %tile1_matmul, %mat_C[%c0, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    vector.transfer_write %tile2_matmul, %mat_C[%c4, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    vector.transfer_write %tile3_matmul, %mat_C[%c4, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    return
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//-------------------------------------------------
// DESIRED LOWERING FOR SME - HAND WRITTEN
// linalg.matmul as vector.outerproduct
//
// Overview:
// * element type - f32
// * number of OP accumulators - 4
// * C is partitioned into 4 tiles
// * A and B are partitioned into 2 halves
// * A is transposed before entering the kernel
//
// Each tile of C accumulates K (= 6) outer products: for k = 0..5, the outer
// product of column k of the matching half of A (= row k of the A^T tile) with
// row k of the matching half of B.
//-------------------------------------------------
// LEGEND:
// Matrix C. Each tile is [SVL_s x SVL_s] (4x4xf32 in this example)
// -----------------------------
// |             |             |
// |             |             |
// |   tile0_C   |   tile1_C   |
// |             |             |
// |             |             |
// -----------------------------
// |             |             |
// |             |             |
// |   tile2_C   |   tile3_C   |
// |             |             |
// |             |             |
// -----------------------------
// Matrix B. Each tile is [K x SVL_s] (6x4xf32 in this example)
// -----------------------------------
// |                |                |
// |                |                |
// |  tile_B_left   |  tile_B_right  |
// |                |                |
// |                |                |
// -----------------------------------
// Matrix A transpose. Each tile is [K x SVL_s] (6x4xf32 in this example)
// -----------------------------------
// |                |                |
// |                |                |
// |   tile_A_hi    |   tile_A_lo    |
// |                |                |
// |                |                |
// -----------------------------------
// Columns of A are rows of A^T.
// Row 3 of matrix B. Each half is SVL_s long (4xf32 in this example)
// -----------------------------
// |  B_row_3_0  |  B_row_3_1  |
// -----------------------------
//     half 0        half 1
// %mat_A_tr - transpose of A
module {
  func.func @outerproduct_matmul(%mat_A_tr: memref<6x8xf32>, %mat_B: memref<6x8xf32>, %mat_C: memref<8x8xf32>) {
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %cst = arith.constant 0.000000e+00 : f32
    // Tile 0
    %tile0_C = vector.transfer_read %mat_C[%c0, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 1
    %tile1_C = vector.transfer_read %mat_C[%c0, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 2
    %tile2_C = vector.transfer_read %mat_C[%c4, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 3
    %tile3_C = vector.transfer_read %mat_C[%c4, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile upper - A
    %tile_A_hi = vector.transfer_read %mat_A_tr[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile lower - A
    %tile_A_lo = vector.transfer_read %mat_A_tr[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile left - B
    %tile_B_left = vector.transfer_read %mat_B[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile right - B
    %tile_B_right = vector.transfer_read %mat_B[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // ===================================================================================================================
    // 1 COMPUTE UPPER HALF OF C
    // ===================================================================================================================
    // 1.1 Input tile "upper" from A and "left" from B - accumulate to ZA0.s (tile 0 of C)
    // Row k of %tile_A_hi is A^T[k, 0:4], i.e. the upper half of column k of A.
    %A_col_0_0 = vector.extract %tile_A_hi[0] : vector<6x4xf32>
    %A_col_1_0 = vector.extract %tile_A_hi[1] : vector<6x4xf32>
    %A_col_2_0 = vector.extract %tile_A_hi[2] : vector<6x4xf32>
    %A_col_3_0 = vector.extract %tile_A_hi[3] : vector<6x4xf32>
    %A_col_4_0 = vector.extract %tile_A_hi[4] : vector<6x4xf32>
    %A_col_5_0 = vector.extract %tile_A_hi[5] : vector<6x4xf32>
    // Row k of %tile_B_left is B[k, 0:4], i.e. the left half of row k of B.
    %B_row_0_0 = vector.extract %tile_B_left[0] : vector<6x4xf32>
    %B_row_1_0 = vector.extract %tile_B_left[1] : vector<6x4xf32>
    %B_row_2_0 = vector.extract %tile_B_left[2] : vector<6x4xf32>
    %B_row_3_0 = vector.extract %tile_B_left[3] : vector<6x4xf32>
    %B_row_4_0 = vector.extract %tile_B_left[4] : vector<6x4xf32>
    %B_row_5_0 = vector.extract %tile_B_left[5] : vector<6x4xf32>
    %op_0 = vector.outerproduct %A_col_0_0, %B_row_0_0, %tile0_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_1 = vector.outerproduct %A_col_1_0, %B_row_1_0, %op_0 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_2 = vector.outerproduct %A_col_2_0, %B_row_2_0, %op_1 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_3 = vector.outerproduct %A_col_3_0, %B_row_3_0, %op_2 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_4 = vector.outerproduct %A_col_4_0, %B_row_4_0, %op_3 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_5 = vector.outerproduct %A_col_5_0, %B_row_5_0, %op_4 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    vector.transfer_write %op_5, %mat_C[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    // 1.2 Input tile "upper" from A and "right" from B - accumulate to ZA1.s (tile 1 of C)
    // Row k of %tile_B_right is B[k, 4:8], i.e. the right half of row k of B.
    %B_row_0_1 = vector.extract %tile_B_right[0] : vector<6x4xf32>
    %B_row_1_1 = vector.extract %tile_B_right[1] : vector<6x4xf32>
    %B_row_2_1 = vector.extract %tile_B_right[2] : vector<6x4xf32>
    %B_row_3_1 = vector.extract %tile_B_right[3] : vector<6x4xf32>
    %B_row_4_1 = vector.extract %tile_B_right[4] : vector<6x4xf32>
    %B_row_5_1 = vector.extract %tile_B_right[5] : vector<6x4xf32>
    %op_6 = vector.outerproduct %A_col_0_0, %B_row_0_1, %tile1_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_7 = vector.outerproduct %A_col_1_0, %B_row_1_1, %op_6 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_8 = vector.outerproduct %A_col_2_0, %B_row_2_1, %op_7 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_9 = vector.outerproduct %A_col_3_0, %B_row_3_1, %op_8 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_10 = vector.outerproduct %A_col_4_0, %B_row_4_1, %op_9 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_11 = vector.outerproduct %A_col_5_0, %B_row_5_1, %op_10 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    vector.transfer_write %op_11, %mat_C[%c0, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    // ===================================================================================================================
    // 2 COMPUTE LOWER HALF OF C
    // ===================================================================================================================
    // 2.1 Input tile "lower" from A and tile "left" from B - accumulate to ZA2.s (tile 2 of C)
    // Row k of %tile_A_lo is A^T[k, 4:8], i.e. the lower half of column k of A.
    %A_col_0_1 = vector.extract %tile_A_lo[0] : vector<6x4xf32>
    %A_col_1_1 = vector.extract %tile_A_lo[1] : vector<6x4xf32>
    %A_col_2_1 = vector.extract %tile_A_lo[2] : vector<6x4xf32>
    %A_col_3_1 = vector.extract %tile_A_lo[3] : vector<6x4xf32>
    %A_col_4_1 = vector.extract %tile_A_lo[4] : vector<6x4xf32>
    %A_col_5_1 = vector.extract %tile_A_lo[5] : vector<6x4xf32>
    %op_12 = vector.outerproduct %A_col_0_1, %B_row_0_0, %tile2_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_13 = vector.outerproduct %A_col_1_1, %B_row_1_0, %op_12 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_14 = vector.outerproduct %A_col_2_1, %B_row_2_0, %op_13 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_15 = vector.outerproduct %A_col_3_1, %B_row_3_0, %op_14 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_16 = vector.outerproduct %A_col_4_1, %B_row_4_0, %op_15 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_17 = vector.outerproduct %A_col_5_1, %B_row_5_0, %op_16 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    vector.transfer_write %op_17, %mat_C[%c4, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    // 2.2 Input tile "lower" from A and tile "right" from B - accumulate to ZA3.s (tile 3 of C)
    %op_18 = vector.outerproduct %A_col_0_1, %B_row_0_1, %tile3_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_19 = vector.outerproduct %A_col_1_1, %B_row_1_1, %op_18 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_20 = vector.outerproduct %A_col_2_1, %B_row_2_1, %op_19 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_21 = vector.outerproduct %A_col_3_1, %B_row_3_1, %op_20 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_22 = vector.outerproduct %A_col_4_1, %B_row_4_1, %op_21 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_23 = vector.outerproduct %A_col_5_1, %B_row_5_1, %op_22 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    vector.transfer_write %op_23, %mat_C[%c4, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    return
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Lower `vector.contract` to `vector.outerproduct`.
// Steps: collect every vector.contract, climb to each one's closest
// isolated-from-above parent (deduplicated, so a parent holding several
// contractions is visited once), then apply the contraction-lowering patterns
// with the "outerproduct" strategy inside those parents.
transform.sequence failures(propagate) {
^bb0(%root: !transform.any_op):
  %contracts = transform.structured.match ops{["vector.contract"]} in %root : (!transform.any_op) -> !transform.any_op
  %parents = transform.get_closest_isolated_parent %contracts : (!transform.any_op) -> !transform.any_op
  %unique_parents = transform.merge_handles %parents { deduplicate } : !transform.any_op
  transform.apply_patterns to %unique_parents {
    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
  } : !transform.any_op
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Lower `linalg.matmul` to `vector.outerproduct` - use 4 accumulators.
// Tiling the two parallel (M, N) dimensions by 4 yields two nested scf.for
// loops. For an 8x8 C, each loop has 2 iterations, so vectorizing and then
// unrolling both loops by 2 leaves 4 independent 4x4 accumulators, one per
// tile of C.
// The match also accepts `linalg.matmul_transpose_a`, the form of the op that
// takes A pre-transposed (K x M).
transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
  %0 = transform.structured.match ops{["linalg.matmul", "linalg.matmul_transpose_a"]} in %arg1 : (!transform.any_op) -> !transform.any_op
  %tiled, %loops:2 = transform.structured.tile %0 [4, 4] : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.op<"scf.for">)
  %1 = transform.get_closest_isolated_parent %tiled : (!transform.any_op) -> !transform.any_op
  %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op
  %3 = transform.structured.match ops{["scf.for"]} in %2 : (!transform.any_op) -> !transform.op<"scf.for">
  transform.loop.unroll %3 { factor = 2 } : !transform.op<"scf.for">
}
Hey there, I wanted to ask a few questions about lowering matmul to the SME dialect that I was hoping you could help with.
Sure, feel free to DM me on Discord or Discourse!
Thank you so much! Do you have a Discord username that I can add?
Search for my surname or banach-space @ https://discourse.llvm.org/ ;-)
Couldn't find your name on Discord. Could you try sending a request to @danikhan632 on Discord? Thank you.
Sorry, I wasn't able to find you. Let's just use good old-fashioned e-mail: [email protected] :)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Couldn't find your name on Discord. Could you try sending a request to @danikhan632 on Discord? Thank you.