Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created March 10, 2025 17:04
Show Gist options
  • Save pashu123/568dc8aa664243a47635c5e0e7b38b38 to your computer and use it in GitHub Desktop.
Save pashu123/568dc8aa664243a47635c5e0e7b38b38 to your computer and use it in GitHub Desktop.
// Identity access: iteration point (d0, d1) reads element (d0, d1).
#map = affine_map<(d0, d1) -> (d0, d1)>
// Transposed access: iteration point (d0, d1) reads element (d1, d0).
#map1 = affine_map<(d0, d1) -> (d1, d0)>
// Payload: C = A * B computed in 2x2 tiles, followed by D = C + transpose(C).
module {
  func.func @matmul_add_transpose(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
    // Destination of the elementwise add. Its contents are never read, because
    // the generic body below does not use %out.
    %0 = tensor.empty() : tensor<4x4xf32>
    %c2 = arith.constant 2 : index
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    // linalg.matmul accumulates into its init operand (C += A * B), and the
    // contents of tensor.empty are unspecified. Zero-fill the accumulator so
    // the product is well defined.
    %1 = tensor.empty() : tensor<4x4xf32>
    %filled = linalg.fill ins(%cst : f32) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
    // 2x2 tiling of the 4x4 result: %arg2 steps over rows, %arg4 over columns.
    %2 = scf.for %arg2 = %c0 to %c4 step %c2 iter_args(%arg3 = %filled) -> (tensor<4x4xf32>) {
      %4 = scf.for %arg4 = %c0 to %c4 step %c2 iter_args(%arg5 = %arg3) -> (tensor<4x4xf32>) {
        // Rows [%arg2, %arg2 + 2) of A and columns [%arg4, %arg4 + 2) of B.
        %extracted_slice = tensor.extract_slice %arg0[%arg2, 0] [2, 4] [1, 1] : tensor<4x4xf32> to tensor<2x4xf32>
        %extracted_slice_0 = tensor.extract_slice %arg1[0, %arg4] [4, 2] [1, 1] : tensor<4x4xf32> to tensor<4x2xf32>
        %extracted_slice_1 = tensor.extract_slice %arg5[%arg2, %arg4] [2, 2] [1, 1] : tensor<4x4xf32> to tensor<2x2xf32>
        %5 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<2x4xf32>, tensor<4x2xf32>) outs(%extracted_slice_1 : tensor<2x2xf32>) -> tensor<2x2xf32>
        // Write the finished 2x2 tile back into the loop-carried result.
        %inserted_slice = tensor.insert_slice %5 into %arg5[%arg2, %arg4] [2, 2] [1, 1] : tensor<2x2xf32> into tensor<4x4xf32>
        scf.yield %inserted_slice : tensor<4x4xf32>
      }
      scf.yield %4 : tensor<4x4xf32>
    }
    // D[i, j] = C[i, j] + C[j, i]. The second operand reads %2 through #map1,
    // so each element of D depends on the transposed-position element of C.
    %3 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %2 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%0 : tensor<4x4xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %4 = arith.addf %in, %in_0 : f32
      linalg.yield %4 : f32
    } -> tensor<4x4xf32>
    return %3 : tensor<4x4xf32>
  }
}
// Transform script: locate the tile write-back op(s) in the payload and hand
// them to the test-only consumer-fusion transform.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    // Handle to every tensor.insert_slice nested under %root.
    %tile_insert = transform.structured.match ops{["tensor.insert_slice"]} in %root : (!transform.any_op) -> !transform.any_op
    // Test-only op that returns two handles; neither is used afterwards.
    // NOTE(review): presumably the original and the fused consumer — confirm
    // against the TestFuseConsumerOp definition.
    %consumer_a, %consumer_b = transform.test.fuse_consumer %tile_insert : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment