Skip to content

Instantly share code, notes, and snippets.

@banach-space
Last active June 3, 2024 07:11
Show Gist options
  • Save banach-space/91a868f992a5747dd58b17a85f584ec8 to your computer and use it in GitHub Desktop.
Save banach-space/91a868f992a5747dd58b17a85f584ec8 to your computer and use it in GitHub Desktop.
C += A*B
// The following matmul maps very nicely onto SME with SVL=128bits. For f32, there are 4 4x4 tiles,
// that can be assembled as a 8x8 matrix.
// %mat_A_tr - transpose of A
// Reference kernel: a single named Linalg matmul on memrefs that accumulates
// into %mat_C in place (the function returns nothing).
//   %mat_A_tr : 6x8xf32 - the transpose of A (see the comment above)
//   %mat_B    : 6x8xf32
//   %mat_C    : 8x8xf32 - both input and output accumulator (`outs`)
// NOTE(review): `linalg.matmul` documents its inputs as (MxK, KxN); here both
// inputs are 6x8, i.e. A is passed already transposed. Confirm this passes the
// op verifier on the targeted MLIR revision (a transposed-A matmul variant may
// be what is intended).
func.func @matmul(%mat_A_tr: memref<6x8xf32>, %mat_B: memref<6x8xf32>, %mat_C: memref<8x8xf32>) {
linalg.matmul ins(%mat_A_tr, %mat_B: memref<6x8xf32>, memref<6x8xf32>)
outs(%mat_C: memref<8x8xf32>)
return
}
//-------------------------------------------------
// DESIRED LOWERING FOR SME - HAND WRITTEN
// linalg.matmul as vector.contract
//
// Overview:
// * element type - f32
// * number of vector.contract accumulators - 4
// * C is partitioned into 4 tiles
// * A and B are partitioned into 2 halves
// * A is transposed before entering the kernel
//-------------------------------------------------
// %mat_A_tr - transpose of A
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module attributes { transform.with_named_sequence } {
  // Hand-written `vector.contract` form of C += A * B.
  //
  //   %mat_A_tr : 6x8xf32 - transpose of A (K x M)
  //   %mat_B    : 6x8xf32 - B (K x N)
  //   %mat_C    : 8x8xf32 - accumulator, updated in place
  //
  // C is split into four 4x4 tiles; A^T and B are each split into two 6x4
  // halves. Every C tile is produced by one `vector.contract` that reduces
  // over K = 6 and uses the current contents of that C tile as accumulator.
  // Note: despite its name, this function contains no `vector.outerproduct`.
  func.func @outerproduct_matmul(%mat_A_tr: memref<6x8xf32>, %mat_B: memref<6x8xf32>, %mat_C: memref<8x8xf32>) {
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    // Padding value for the transfer_reads (all reads are marked in-bounds).
    %cst = arith.constant 0.000000e+00 : f32

    // Load the four 4x4 accumulator tiles of C.
    // Tile 0 - rows 0..3, cols 0..3
    %tile0_C = vector.transfer_read %mat_C[%c0, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 1 - rows 0..3, cols 4..7
    %tile1_C = vector.transfer_read %mat_C[%c0, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 2 - rows 4..7, cols 0..3
    %tile2_C = vector.transfer_read %mat_C[%c4, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 3 - rows 4..7, cols 4..7
    %tile3_C = vector.transfer_read %mat_C[%c4, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>

    // Load the two 6x4 halves of A^T. Columns 0..3 of A^T feed the upper half
    // of C, columns 4..7 feed the lower half.
    // Tile upper - A
    %tile_A_hi = vector.transfer_read %mat_A_tr[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile lower - A
    %tile_A_lo = vector.transfer_read %mat_A_tr[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>

    // Load the two 6x4 halves of B.
    // Tile left - B
    %tile_B_left = vector.transfer_read %mat_B[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile right - B
    %tile_B_right = vector.transfer_read %mat_B[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>

    // Undo the transposition of A so that the LHS of each contraction is
    // 4x6 (M x K), as required by #map = (d0, d2).
    %tile_A_hi_tr = vector.transpose %tile_A_hi, [1, 0] : vector<6x4xf32> to vector<4x6xf32>
    %tile_A_lo_tr = vector.transpose %tile_A_lo, [1, 0] : vector<6x4xf32> to vector<4x6xf32>

    // Compute upper half of C
    %tile0_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_hi_tr, %tile_B_left, %tile0_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
    %tile1_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_hi_tr, %tile_B_right, %tile1_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>

    // Compute lower half of C
    %tile2_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_lo_tr, %tile_B_left, %tile2_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
    %tile3_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_lo_tr, %tile_B_right, %tile3_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>

    // Store every result tile back to the position it was loaded from.
    vector.transfer_write %tile0_matmul, %mat_C[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    vector.transfer_write %tile1_matmul, %mat_C[%c0, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    vector.transfer_write %tile2_matmul, %mat_C[%c4, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    vector.transfer_write %tile3_matmul, %mat_C[%c4, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    return
  }
}
//-------------------------------------------------
// DESIRED LOWERING FOR SME - HAND WRITTEN
// linalg.matmul as vector.outerproduct
//
// Overview:
// * element type - f32
// * number of OP accumulators - 4
// * C is partitioned into 4 tiles
// * A and B are partitioned into 2 halves
// * A is transposed before entering the kernel
//-------------------------------------------------
// LEGEND:
// Matrix C. Each tile is [SVL_s x SVL_s] (4x4xf32 in this example)
// -----------------------------
// | | |
// | | |
// | tile0_C | tile1_C |
// | | |
// | | |
// -----------------------------
// | | |
// | | |
// | tile2_C | tile3_C |
// | | |
// | | |
// -----------------------------
// Matrix B. Each tile is [K x SVL_s] (6x4xf32 in this example)
// ----------------------------------
// | | |
// | | |
// | tile_B_left | tile_B_right |
// | | |
// | | |
// ---------------------------------|
// Matrix A transpose. Each tile is [K x SVL_s] (6x4xf32 in this example)
// ----------------------------------
// | | |
// | | |
// | tile_A_upper | tile_A_lower |
// | | |
// | | |
// ---------------------------------|
// Columns of A are rows of A^T.
// Row 3 of matrix B. Each half is SVL_s long (4xf32 in this example)
// -----------------------------
// | B_row_3_0 | B_row_3_1 |
// -----------------------------
// half 0 half 1
// %mat_A_tr - transpose of A
module {
  // Hand-written `vector.outerproduct` form of C += A * B.
  //
  //   %mat_A_tr : 6x8xf32 - transpose of A (K x M)
  //   %mat_B    : 6x8xf32 - B (K x N)
  //   %mat_C    : 8x8xf32 - accumulator, updated in place
  //
  // Each 4x4 tile of C is computed as a chain of K = 6 rank-1 updates:
  //   tile += A_col_k (x) B_row_k      for k = 0 .. 5
  // where A_col_k is row k of the A^T half (i.e. a column of A) and B_row_k
  // is row k of the matching B half. The chain for every tile starts from the
  // current contents of that C tile.
  func.func @outerproduct_matmul(%mat_A_tr: memref<6x8xf32>, %mat_B: memref<6x8xf32>, %mat_C: memref<8x8xf32>) {
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    // Padding value for the transfer_reads (all reads are marked in-bounds).
    %cst = arith.constant 0.000000e+00 : f32

    // Tile 0
    %tile0_C = vector.transfer_read %mat_C[%c0, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 1
    %tile1_C = vector.transfer_read %mat_C[%c0, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 2
    %tile2_C = vector.transfer_read %mat_C[%c4, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>
    // Tile 3
    %tile3_C = vector.transfer_read %mat_C[%c4, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32>

    // Tile upper - A
    %tile_A_hi = vector.transfer_read %mat_A_tr[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile lower - A
    %tile_A_lo = vector.transfer_read %mat_A_tr[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile left - B
    %tile_B_left = vector.transfer_read %mat_B[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>
    // Tile right - B
    %tile_B_right = vector.transfer_read %mat_B[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32>

    // ===================================================================================================================
    // 1 COMPUTE UPPER HALF OF C
    // ===================================================================================================================
    // 1.1 Input tile "upper" from A and "left" from B - accumulate to ZA0.s (tile 0 of C)
    // Row k of the A^T tile is column k of A (restricted to rows 0..3 of A).
    %A_col_0_0 = vector.extract %tile_A_hi[0] : vector<6x4xf32>
    %A_col_1_0 = vector.extract %tile_A_hi[1] : vector<6x4xf32>
    %A_col_2_0 = vector.extract %tile_A_hi[2] : vector<6x4xf32>
    %A_col_3_0 = vector.extract %tile_A_hi[3] : vector<6x4xf32>
    %A_col_4_0 = vector.extract %tile_A_hi[4] : vector<6x4xf32>
    %A_col_5_0 = vector.extract %tile_A_hi[5] : vector<6x4xf32>

    // Row k of the left half of B; the extracted index must match the row
    // number in the value name (rows 4 and 5 included).
    %B_row_0_0 = vector.extract %tile_B_left[0] : vector<6x4xf32>
    %B_row_1_0 = vector.extract %tile_B_left[1] : vector<6x4xf32>
    %B_row_2_0 = vector.extract %tile_B_left[2] : vector<6x4xf32>
    %B_row_3_0 = vector.extract %tile_B_left[3] : vector<6x4xf32>
    %B_row_4_0 = vector.extract %tile_B_left[4] : vector<6x4xf32>
    %B_row_5_0 = vector.extract %tile_B_left[5] : vector<6x4xf32>

    %op_0 = vector.outerproduct %A_col_0_0, %B_row_0_0, %tile0_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_1 = vector.outerproduct %A_col_1_0, %B_row_1_0, %op_0 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_2 = vector.outerproduct %A_col_2_0, %B_row_2_0, %op_1 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_3 = vector.outerproduct %A_col_3_0, %B_row_3_0, %op_2 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_4 = vector.outerproduct %A_col_4_0, %B_row_4_0, %op_3 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_5 = vector.outerproduct %A_col_5_0, %B_row_5_0, %op_4 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    vector.transfer_write %op_5, %mat_C[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>

    // 1.2 Input tile "upper" from A and "right" from B - accumulate to ZA1.s (tile 1 of C)
    %B_row_0_1 = vector.extract %tile_B_right[0] : vector<6x4xf32>
    %B_row_1_1 = vector.extract %tile_B_right[1] : vector<6x4xf32>
    %B_row_2_1 = vector.extract %tile_B_right[2] : vector<6x4xf32>
    %B_row_3_1 = vector.extract %tile_B_right[3] : vector<6x4xf32>
    %B_row_4_1 = vector.extract %tile_B_right[4] : vector<6x4xf32>
    %B_row_5_1 = vector.extract %tile_B_right[5] : vector<6x4xf32>

    %op_6 = vector.outerproduct %A_col_0_0, %B_row_0_1, %tile1_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_7 = vector.outerproduct %A_col_1_0, %B_row_1_1, %op_6 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_8 = vector.outerproduct %A_col_2_0, %B_row_2_1, %op_7 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_9 = vector.outerproduct %A_col_3_0, %B_row_3_1, %op_8 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_10 = vector.outerproduct %A_col_4_0, %B_row_4_1, %op_9 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_11 = vector.outerproduct %A_col_5_0, %B_row_5_1, %op_10 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    vector.transfer_write %op_11, %mat_C[%c0, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>

    // ===================================================================================================================
    // 2 COMPUTE LOWER HALF OF C
    // ===================================================================================================================
    // 2.1 Input tile "lower" for A and tile left from B - accumulate to ZA2.s (tile 2 of C)
    %A_col_0_1 = vector.extract %tile_A_lo[0] : vector<6x4xf32>
    %A_col_1_1 = vector.extract %tile_A_lo[1] : vector<6x4xf32>
    %A_col_2_1 = vector.extract %tile_A_lo[2] : vector<6x4xf32>
    %A_col_3_1 = vector.extract %tile_A_lo[3] : vector<6x4xf32>
    %A_col_4_1 = vector.extract %tile_A_lo[4] : vector<6x4xf32>
    %A_col_5_1 = vector.extract %tile_A_lo[5] : vector<6x4xf32>

    // Step k pairs A column k with B row k; steps 4 and 5 use
    // %A_col_4_1 / %A_col_5_1 (not %A_col_3_1 again).
    %op_12 = vector.outerproduct %A_col_0_1, %B_row_0_0, %tile2_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_13 = vector.outerproduct %A_col_1_1, %B_row_1_0, %op_12 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_14 = vector.outerproduct %A_col_2_1, %B_row_2_0, %op_13 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_15 = vector.outerproduct %A_col_3_1, %B_row_3_0, %op_14 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_16 = vector.outerproduct %A_col_4_1, %B_row_4_0, %op_15 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_17 = vector.outerproduct %A_col_5_1, %B_row_5_0, %op_16 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    vector.transfer_write %op_17, %mat_C[%c4, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>

    // 2.2 Input tile "lower" from A and tile "right" from B - accumulate to ZA3.s (tile 3 of C)
    %op_18 = vector.outerproduct %A_col_0_1, %B_row_0_1, %tile3_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_19 = vector.outerproduct %A_col_1_1, %B_row_1_1, %op_18 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_20 = vector.outerproduct %A_col_2_1, %B_row_2_1, %op_19 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_21 = vector.outerproduct %A_col_3_1, %B_row_3_1, %op_20 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_22 = vector.outerproduct %A_col_4_1, %B_row_4_1, %op_21 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    %op_23 = vector.outerproduct %A_col_5_1, %B_row_5_1, %op_22 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
    vector.transfer_write %op_23, %mat_C[%c4, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32>
    return
  }
}
// Lower `vector.contract` to `vector.outerproduct`
transform.sequence failures(propagate) {
^bb1(%payload_root: !transform.any_op):
  // Collect every vector.contract found under the payload root.
  %contract_ops = transform.structured.match ops{["vector.contract"]} in %payload_root : (!transform.any_op) -> !transform.any_op
  // Step up to the closest isolated-from-above parent of each match and
  // merge the resulting handles with deduplication.
  %enclosing_ops = transform.get_closest_isolated_parent %contract_ops : (!transform.any_op) -> !transform.any_op
  %unique_enclosing_ops = transform.merge_handles %enclosing_ops { deduplicate } : !transform.any_op
  // Rewrite the contractions using the "outerproduct" lowering strategy.
  transform.apply_patterns to %unique_enclosing_ops {
    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
  } : !transform.any_op
}
// Lower `linalg.matmul` to `vector.outerproduct` - use 4 accumulators.
transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
  // Find the matmul(s) to transform.
  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
  // Tile the first two dimensions by 4; this yields the tiled op plus two
  // `scf.for` loops.
  %tiled, %loops:2 = transform.structured.tile %0 [4, 4] : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.op<"scf.for">)
  // Vectorize the enclosing isolated-from-above op. The op name is spelled
  // with the `transform.` prefix, like every other op in this sequence.
  %1 = transform.get_closest_isolated_parent %tiled : (!transform.any_op) -> !transform.any_op
  %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op
  // Unroll the `scf.for` loops found in the vectorized code by a factor of 2.
  // NOTE(review): for the 8x8 output with 4x4 tiles this presumably unrolls
  // both loops fully, giving the 4 accumulators - confirm.
  %3 = transform.structured.match ops{["scf.for"]} in %2 : (!transform.any_op) -> !transform.op<"scf.for">
  transform.loop.unroll %3 { factor = 2 } : !transform.op<"scf.for">
}
@danikhan632
Copy link

How to lower linalg.matmul for SME

This Gist demonstrates potential lowering of linalg.matmul onto Arm's SME.

linalg_matmul_as_vector_op.mlir

This is the key file in this Gist. It's a hand-written lowering of linalg.matmul onto vector.outerproduct that we expect to map nicely onto SME. There is plenty of comments that should help to map that back onto the original matmul (from "matmul.mlir").

linalg_matmul_as_vector_contract.mlir

Hand-written lowering of linalg.matmul onto vector.contract. This can already be lowered to something similar to "linalg_matmul_as_vector_op.mlir".

matmul.mlir

The original linalg.matmul that we are attempting to lower in this Gist.

transform_seq_matmul.mlir

Transform dialect sequence that can be used to lower matmul.mlir to something similar as the hand-written code in "linalg_matmul_as_vector_contract.mlir".

transform_seq_contract.mlir

Transform dialect sequence that can be used to lower the output from "linalg_matmul_as_vector_contract.mlir" into something similar as the hand-written code from "linalg_matmul_as_vector_op.mlir".

Steps to reproduce

This how you can use the Transform dialect to generate something that should map nicely onto SME:

mlir-opt matmul.mlir --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=transform_seq_matmul.mlir})" | mlir-opt --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=transform_seq_contract.mlir})"

Set-up

Tested with: 0eb0fecbc544

_ No description provided. _

Hey there, wanted to ask about some questions regarding SME dialect lowering to matmul that I was hoping you could help with

@banach-space
Copy link
Author

Hey there, wanted to ask about some questions regarding SME dialect lowering to matmul that I was hoping you could help with

Sure, feel free to DM me on Discord or Discourse!

@danikhan632
Copy link

Hey there, wanted to ask about some questions regarding SME dialect lowering to matmul that I was hoping you could help with

Sure, feel free to DM me on Discord or Discourse!

Thank you so much, do you have a discord username that I can add?

@banach-space
Copy link
Author

Hey there, wanted to ask about some questions regarding SME dialect lowering to matmul that I was hoping you could help with

Sure, feel free to DM me on Discord or Discourse!

Thank you so much, do you have a discord username that I can add?

Search for my surname or banach-space @ https://discourse.llvm.org/ ;-)

@danikhan632
Copy link

Hey there, wanted to ask about some questions regarding SME dialect lowering to matmul that I was hoping you could help with

Sure, feel free to DM me on Discord or Discourse!

Thank you so much, do you have a discord username that I can add?

Search for my surname or banach-space @ https://discourse.llvm.org/ ;-)

Couldn't find your name on discord, could you try requesting @danikhan632 on discord. Thank you

@banach-space
Copy link
Author

Hey there, wanted to ask about some questions regarding SME dialect lowering to matmul that I was hoping you could help with

Sure, feel free to DM me on Discord or Discourse!

Thank you so much, do you have a discord username that I can add?

Search for my surname or banach-space @ https://discourse.llvm.org/ ;-)

Couldn't find your name on discord, could you try requesting @danikhan632 on discord. Thank you

Sorry, I wasn't able to find you. Lets just use good old fashioned e-mail: [email protected] :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment