Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created February 3, 2025 12:19
Show Gist options
Save pashu123/8823f629a1784f11600bbd332fabc1fe to your computer and use it in GitHub Desktop.
// Indexing maps for the batched contraction generic. Iteration space is
// (b, m, n, k): the first three dims are parallel, k is the reduction dim.
#map = affine_map<(b, m, n, k) -> (b, m, k)>    // LHS operand
#map1 = affine_map<(b, m, n, k) -> (b, n, k)>   // RHS operands (indexed by n, k)
#map2 = affine_map<(b, m, n, k) -> (b, m, n)>   // accumulator / result operands
// Identity map for the rank-3 elementwise consumer.
#map3 = affine_map<(b, m, n) -> (b, m, n)>
module {
  // Reduced repro: a tiled, two-result contraction (scf.forall) whose results feed
  // a single elementwise consumer. The function name suggests the expected outcome
  // is that consumer fusion does NOT happen.
  // NOTE(review): confirm the expected outcome against the test's CHECK lines.
  //
  // %arg0 is the LHS shared by both contractions; %arg2 supplies their RHS tiles.
  // NOTE(review): %arg1 is never used in this function, and both RHS operands of
  // the generic below are %extracted_slice_0 (sliced from %arg2). Confirm whether
  // one of them was meant to be sliced from %arg1 instead.
  func.func @dont_fuse_when_same_trunc_op_dispatch_0_generic_2x4096x640x640_i8xi8xi8xi32xi32(%arg0: tensor<2x4096x640xi8>, %arg1: tensor<2x640x640xi8>, %arg2: tensor<2x640x640xi8>) -> tensor<2x4096x640xf16> {
    %c0_i32 = arith.constant 0 : i32
    %0 = tensor.empty() : tensor<2x4096x640xf16>
    %1 = tensor.empty() : tensor<2x4096x640xi32>
    // Producer loop over (d0, d1, d2) with 1x64x64 tiles. Both i32 results start
    // from the same empty tensor %1.
    %2:2 = scf.forall (%arg3, %arg4, %arg5) = (0, 0, 0) to (2, 4096, 640) step (1, 64, 64) shared_outs(%arg6 = %1, %arg7 = %1) -> (tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) {
      // LHS tile: 64 rows starting at %arg4, full 640-wide reduction dim.
      %extracted_slice = tensor.extract_slice %arg0[%arg3, %arg4, 0] [1, 64, 640] [1, 1, 1] : tensor<2x4096x640xi8> to tensor<1x64x640xi8>
      // RHS tile: 64 rows starting at %arg5, full 640-wide reduction dim.
      %extracted_slice_0 = tensor.extract_slice %arg2[%arg3, %arg5, 0] [1, 64, 640] [1, 1, 1] : tensor<2x640x640xi8> to tensor<1x64x640xi8>
      // Zero-initialize the two accumulator tiles.
      %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4, %arg5] [1, 64, 64] [1, 1, 1] : tensor<2x4096x640xi32> to tensor<1x64x64xi32>
      %4 = linalg.fill ins(%c0_i32 : i32) outs(%extracted_slice_1 : tensor<1x64x64xi32>) -> tensor<1x64x64xi32>
      %extracted_slice_2 = tensor.extract_slice %arg7[%arg3, %arg4, %arg5] [1, 64, 64] [1, 1, 1] : tensor<2x4096x640xi32> to tensor<1x64x64xi32>
      %5 = linalg.fill ins(%c0_i32 : i32) outs(%extracted_slice_2 : tensor<1x64x64xi32>) -> tensor<1x64x64xi32>
      // Two sign-extended multiply-accumulates reducing over d3:
      //   out   += extsi(lhs) * extsi(rhs_a)
      //   out_5 += extsi(lhs) * extsi(rhs_b)
      %6:2 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %extracted_slice_0, %extracted_slice_0 : tensor<1x64x640xi8>, tensor<1x64x640xi8>, tensor<1x64x640xi8>) outs(%4, %5 : tensor<1x64x64xi32>, tensor<1x64x64xi32>) {
      ^bb0(%in: i8, %in_3: i8, %in_4: i8, %out: i32, %out_5: i32):
        %7 = arith.extsi %in : i8 to i32
        %8 = arith.extsi %in_3 : i8 to i32
        %9 = arith.muli %7, %8 : i32
        %10 = arith.addi %out, %9 : i32
        %11 = arith.extsi %in_4 : i8 to i32
        %12 = arith.muli %7, %11 : i32
        %13 = arith.addi %out_5, %12 : i32
        linalg.yield %10, %13 : i32, i32
      } -> (tensor<1x64x64xi32>, tensor<1x64x64xi32>)
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %6#0 into %arg6[%arg3, %arg4, %arg5] [1, 64, 64] [1, 1, 1] : tensor<1x64x64xi32> into tensor<2x4096x640xi32>
        tensor.parallel_insert_slice %6#1 into %arg7[%arg3, %arg4, %arg5] [1, 64, 64] [1, 1, 1] : tensor<1x64x64xi32> into tensor<2x4096x640xi32>
      }
    }
    // Consumer of BOTH loop results. Each i32 input goes through sitofp -> f32 and
    // truncf -> f16, and the two f16 values are added. The operand order is
    // %2#1 first, then %2#0.
    %3 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2#1, %2#0 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) outs(%0 : tensor<2x4096x640xf16>) {
    ^bb0(%in: i32, %in_0: i32, %out: f16):
      %4 = arith.sitofp %in : i32 to f32
      %5 = arith.truncf %4 : f32 to f16
      %6 = arith.sitofp %in_0 : i32 to f32
      %7 = arith.truncf %6 : f32 to f16
      %8 = arith.addf %5, %7 : f16
      linalg.yield %8 : f16
    } -> tensor<2x4096x640xf16>
    return %3 : tensor<2x4096x640xf16>
  }
  // Transform script driving the consumer-fusion attempt.
  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      // Collect the tensor.parallel_insert_slice ops in the payload (there are two,
      // in the forall terminator) and split them into one handle each.
      %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      // Attempt to fuse the consumer of the loop result written by the first
      // handle. Presumably that is the insert of %6#0, in program order — verify.
      %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment