Skip to content

Instantly share code, notes, and snippets.

@qedawkins
Created July 3, 2024 03:42
Show Gist options
  • Save qedawkins/11a10aeefef6168829f6bffdf949fe1e to your computer and use it in GitHub Desktop.
Save qedawkins/11a10aeefef6168829f6bffdf949fe1e to your computer and use it in GitHub Desktop.
// 1-D map i -> i * 32; applied to the scf.forall induction variable to get a row offset.
#map = affine_map<(i) -> (i * 32)>
// 2-D identity map (i, j) -> (i, j); used as the indexing map for both linalg.generic operands.
#map1 = affine_map<(i, j) -> (i, j)>
"builtin.module"() ({
  // Payload module: contains the function that the transform script below operates on.
  "builtin.module"() ({
    // Takes four tensors and returns nothing; %arg3 is never used in the body.
    "func.func"() <{function_type = (tensor<128xf32>, tensor<128x128xf16>, tensor<128x64xf32>, tensor<128x128xf32>) -> (), sym_name = "loop_sibling_fusion"}> ({
    ^bb0(%arg1: tensor<128xf32>, %arg2: tensor<128x128xf16>, %arg3: tensor<128x64xf32>, %arg4: tensor<128x128xf32>):
      // FIX: %4 is read inside the scf.forall body below (operand of %9), so it
      // must be defined before the loop. It was previously defined after the
      // scf.forall, a use-before-def that violates SSA dominance and makes the
      // module fail verification.
      %4 = "tensor.empty"() : () -> tensor<128x128xf16>
      // One scf.forall over [0, 4) with step 1, mapped to #gpu.warp<linear_dim_0>,
      // with two shared outputs initialised from %arg1 and %arg2. The body holds
      // two independent computations; NOTE(review): these presumably come from
      // two sibling loops fused by the transform script below (the pre-fusion
      // input is not shown here).
      %3:2 = "scf.forall"(%arg1, %arg2) <{mapping = [#gpu.warp<linear_dim_0>], operandSegmentSizes = array<i32: 0, 0, 0, 2>, staticLowerBound = array<i64: 0>, staticStep = array<i64: 1>, staticUpperBound = array<i64: 4>}> ({
      ^bb0(%arg5: index, %arg6: tensor<128xf32>, %arg7: tensor<128x128xf16>):
        // First computation: row offset %arg5 * 32, then a 32x1 slice of %arg4
        // at column 0, rank-reduced to tensor<32xf32>. The offset value
        // -9223372036854775808 marks an offset supplied by an SSA operand.
        %5 = "affine.apply"(%arg5) <{map = #map}> : (index) -> index
        %6 = "tensor.extract_slice"(%arg4, %5) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 32, 1>, static_strides = array<i64: 1, 1>}> : (tensor<128x128xf32>, index) -> tensor<32xf32>
        // Second computation: recomputes the same row offset (identical to %5, a
        // CSE candidate), then takes 32x128 slices of %arg4 and of the empty
        // tensor %4.
        %7 = "affine.apply"(%arg5) <{map = #map}> : (index) -> index
        %8 = "tensor.extract_slice"(%arg4, %7) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 32, 128>, static_strides = array<i64: 1, 1>}> : (tensor<128x128xf32>, index) -> tensor<32x128xf32>
        %9 = "tensor.extract_slice"(%4, %7) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 32, 128>, static_strides = array<i64: 1, 1>}> : (tensor<128x128xf16>, index) -> tensor<32x128xf16>
        // Elementwise f32 -> f16 truncation: %8 is the single input, %9 the single
        // init/output operand (operandSegmentSizes = [1, 1]).
        %10 = "linalg.generic"(%8, %9) <{indexing_maps = [#map1, #map1], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 1>}> ({
        ^bb0(%arg8: f32, %arg9: f16):
          %11 = "arith.truncf"(%arg8) : (f32) -> f16
          "linalg.yield"(%11) : (f16) -> ()
        }) : (tensor<32x128xf32>, tensor<32x128xf16>) -> tensor<32x128xf16>
        // Write each computation's tile back into its shared output at the same row offset.
        "scf.forall.in_parallel"() ({
          "tensor.parallel_insert_slice"(%6, %arg6, %5) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 32>, static_strides = array<i64: 1>}> : (tensor<32xf32>, tensor<128xf32>, index) -> ()
          "tensor.parallel_insert_slice"(%10, %arg7, %7) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 32, 128>, static_strides = array<i64: 1, 1>}> : (tensor<32x128xf16>, tensor<128x128xf16>, index) -> ()
        }) : () -> ()
      }) : (tensor<128xf32>, tensor<128x128xf16>) -> (tensor<128xf32>, tensor<128x128xf16>)
      "func.return"() : () -> ()
    }) : () -> ()
  }) : () -> ()
  // Transform-dialect module (marked transform.with_named_sequence) holding the
  // entry point __transform_main.
  "builtin.module"() ({
    "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
    ^bb0(%arg0: !transform.any_op):
      // Match every scf.forall under the payload root.
      %0 = "transform.structured.match"(%arg0) <{ops = ["scf.forall"]}> : (!transform.any_op) -> !transform.any_op
      // Split the matched ops into two handles (fail_on_payload_too_small = true).
      %1:2 = "transform.split_handle"(%0) <{fail_on_payload_too_small = true, pass_through_empty_handle = true}> : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      // Fuse the two sibling loops; the resulting handle %2 is not used afterwards.
      %2 = "transform.loop.fuse_sibling"(%1#0, %1#1) : (!transform.any_op, !transform.any_op) -> !transform.any_op
      // Print the payload root after fusion.
      "transform.print"(%arg0) : (!transform.any_op) -> ()
      "transform.yield"() : () -> ()
    }) : () -> ()
  }) {transform.with_named_sequence} : () -> ()
}) : () -> ()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment