Created June 6, 2025 16:44
-
-
Save bjacob/84947a1794dfa87720205db55bc9aa59 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
index 9b8292d052..ddc1fbc04d 100644
--- a/compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_gfx942.mlir
+++ b/compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_gfx942.mlir
@@ -925,6 +925,9 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%c256 = arith.constant 256 : index
+ %c16384 = arith.constant 16384 : index
+ %c32768 = arith.constant 32768 : index
+
%cst = arith.constant 0.0 : f8E4M3FNUZ
%lhs_shared_base = memref.alloc() : !mflat_shared_f8
%rhs_shared_base = memref.alloc() : !flat_shared_f8
@@ -933,6 +936,9 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !mexp_in_ty_f8
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty_f8
+ %lhs_flat = tensor.collapse_shape %lhs [[0, 1, 2]] : !mexp_in_ty_f8 into tensor<?xf8E4M3FNUZ>
+ %rhs_flat = tensor.collapse_shape %rhs [[0, 1]] : !in_ty_f8 into tensor<?xf8E4M3FNUZ>
+
%lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<128, 8>] : !mflat_shared_f8
%rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<128, 8>] : !flat_shared_f8
@@ -1001,7 +1007,9 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
rocdl.sched.barrier 0
// Global loads of rhs.
- %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 128] [1, 1] : !in_ty_f8 to !block_in_f8
+ %i_times_rhs_flat_block_size = arith.muli %i, %c32768 : index
+ %rhs_flat_block = tensor.extract_slice %rhs_flat [%i_times_rhs_flat_block_size] [32768] [1] : tensor<?xf8E4M3FNUZ> to tensor<32768xf8E4M3FNUZ>
+ %rhs_block = tensor.expand_shape %rhs_flat_block [[0, 1]] output_shape [256, 128] : tensor<32768xf8E4M3FNUZ> into !block_in_f8
%rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
%rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
%rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
@@ -1021,7 +1029,9 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
rocdl.sched.barrier 0
// Global loads of lhs.
- %lhs_block = tensor.extract_slice %lhs [0, 0, %i] [1, 128, 128] [1, 1, 1] : !mexp_in_ty_f8 to !mexp_block_in_f8
+ %i_times_lhs_flat_block_size = arith.muli %i, %c16384 : index
+ %lhs_flat_block = tensor.extract_slice %lhs_flat [%i_times_lhs_flat_block_size] [16384] [1] : tensor<?xf8E4M3FNUZ> to tensor<16384xf8E4M3FNUZ>
+ %lhs_block = tensor.expand_shape %lhs_flat_block [[0, 1, 2]] output_shape [1, 128, 128] : tensor<16384xf8E4M3FNUZ> into !mexp_block_in_f8
%lhs_thread_0 = tensor.extract_slice %lhs_block [0, %glb0_lhs, %gko] [1, 1, 16] [1, 1, 1] : !mexp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
%lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
%lhs_thread_1 = tensor.extract_slice %lhs_block [0, %glb1_lhs, %gko] [1, 1, 16] [1, 1, 1] : !mexp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.