@bjacob · Created June 6, 2025 16:44
https://gist.github.com/bjacob/84947a1794dfa87720205db55bc9aa59
index 9b8292d052..ddc1fbc04d 100644
--- a/compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_gfx942.mlir
+++ b/compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_gfx942.mlir
@@ -925,6 +925,9 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%c256 = arith.constant 256 : index
+ %c16384 = arith.constant 16384 : index
+ %c32768 = arith.constant 32768 : index
+
%cst = arith.constant 0.0 : f8E4M3FNUZ
%lhs_shared_base = memref.alloc() : !mflat_shared_f8
%rhs_shared_base = memref.alloc() : !flat_shared_f8
@@ -933,6 +936,9 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
%lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !mexp_in_ty_f8
%rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty_f8
+ %lhs_flat = tensor.collapse_shape %lhs [[0, 1, 2]] : !mexp_in_ty_f8 into tensor<?xf8E4M3FNUZ>
+ %rhs_flat = tensor.collapse_shape %rhs [[0, 1]] : !in_ty_f8 into tensor<?xf8E4M3FNUZ>
+
%lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<128, 8>] : !mflat_shared_f8
%rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<128, 8>] : !flat_shared_f8
@@ -1001,7 +1007,9 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
rocdl.sched.barrier 0
// Global loads of rhs.
- %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 128] [1, 1] : !in_ty_f8 to !block_in_f8
+ %i_times_rhs_flat_block_size = arith.muli %i, %c32768 : index
+ %rhs_flat_block = tensor.extract_slice %rhs_flat [%i_times_rhs_flat_block_size] [32768] [1] : tensor<?xf8E4M3FNUZ> to tensor<32768xf8E4M3FNUZ>
+ %rhs_block = tensor.expand_shape %rhs_flat_block [[0, 1]] output_shape [256, 128] : tensor<32768xf8E4M3FNUZ> into !block_in_f8
%rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
%rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
%rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
@@ -1021,7 +1029,9 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
rocdl.sched.barrier 0
// Global loads of lhs.
- %lhs_block = tensor.extract_slice %lhs [0, 0, %i] [1, 128, 128] [1, 1, 1] : !mexp_in_ty_f8 to !mexp_block_in_f8
+ %i_times_lhs_flat_block_size = arith.muli %i, %c16384 : index
+ %lhs_flat_block = tensor.extract_slice %lhs_flat [%i_times_lhs_flat_block_size] [16384] [1] : tensor<?xf8E4M3FNUZ> to tensor<16384xf8E4M3FNUZ>
+ %lhs_block = tensor.expand_shape %lhs_flat_block [[0, 1, 2]] output_shape [1, 128, 128] : tensor<16384xf8E4M3FNUZ> into !mexp_block_in_f8
%lhs_thread_0 = tensor.extract_slice %lhs_block [0, %glb0_lhs, %gko] [1, 1, 16] [1, 1, 1] : !mexp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
%lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
%lhs_thread_1 = tensor.extract_slice %lhs_block [0, %glb1_lhs, %gko] [1, 1, 16] [1, 1, 1] : !mexp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
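For readers skimming the diff: both hunks apply the same rewrite. Each input is first collapsed to 1-D with tensor.collapse_shape; the per-iteration block load then becomes a single contiguous 1-D tensor.extract_slice at offset %i * block_size (16384 = 1*128*128 elements for lhs, 32768 = 256*128 for rhs), and tensor.expand_shape restores the block shape that the downstream per-thread slices expect. Below is a minimal standalone sketch of the rhs-side pattern, not taken from the gist: the function name @flat_block_slice, the tensor<256x?xf8E4M3FNUZ> argument type, and the assumption that %i counts whole 32768-element blocks (as the diff's %i_times_rhs_flat_block_size naming suggests) are all illustrative.

// Illustrative sketch only; names and argument types are assumptions, not from the gist.
util.func private @flat_block_slice(%rhs: tensor<256x?xf8E4M3FNUZ>, %i: index)
    -> tensor<256x128xf8E4M3FNUZ> {
  %c32768 = arith.constant 32768 : index
  // Collapse to 1-D so the block load becomes one contiguous slice.
  %rhs_flat = tensor.collapse_shape %rhs [[0, 1]]
      : tensor<256x?xf8E4M3FNUZ> into tensor<?xf8E4M3FNUZ>
  // Assumed: %i indexes whole blocks, so each step advances by 256*128 = 32768 elements.
  %off = arith.muli %i, %c32768 : index
  %flat_block = tensor.extract_slice %rhs_flat [%off] [32768] [1]
      : tensor<?xf8E4M3FNUZ> to tensor<32768xf8E4M3FNUZ>
  // Restore the 2-D block shape expected by the per-thread extract_slice ops.
  %block = tensor.expand_shape %flat_block [[0, 1]] output_shape [256, 128]
      : tensor<32768xf8E4M3FNUZ> into tensor<256x128xf8E4M3FNUZ>
  util.return %block : tensor<256x128xf8E4M3FNUZ>
}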