@bjacob
Created May 15, 2025 20:26
diff --git a/compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_gfx942.mlir b/compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_gfx942.mlir
index 53a82acdd7..d3d72d62c6 100644
--- a/compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_gfx942.mlir
+++ b/compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_gfx942.mlir
@@ -33,6 +33,11 @@
!mshared_f8 = memref<128x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>
!mshared_exp_f8 = memref<8x16x4x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>
+!unified_flat_shared_f8 = memref<49152xf8E4M3FNUZ, #gpu.address_space<workgroup>>
+!flat_shared_f8_at_offset_16384 = memref<32768xf8E4M3FNUZ, strided<[1], offset: 16384>, #gpu.address_space<workgroup>>
+!shared_f8_at_offset_16384 = memref<256x128xf8E4M3FNUZ, strided<[128, 1], offset: 16384>, #gpu.address_space<workgroup>>
+!shared_exp_f8_at_offset_16384 = memref<16x16x4x32xf8E4M3FNUZ, strided<[2048, 128, 32, 1], offset: 16384>, #gpu.address_space<workgroup>>
+
#contraction_accesses = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (j, k)>,
@@ -920,19 +925,26 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%c256 = arith.constant 256 : index
+ %c8192 = arith.constant 8192 : index
+
%cst = arith.constant 0.0 : f8E4M3FNUZ
- %lhs_shared_base = memref.alloc() : !mflat_shared_f8
- %rhs_shared_base = memref.alloc() : !flat_shared_f8
+ %unified_shared_alloc = memref.alloc() : !unified_flat_shared_f8
+
+ %lhs_shared_base = memref.subview %unified_shared_alloc[0][16384][1] : !unified_flat_shared_f8 to !mflat_shared_f8
+ %rhs_shared_base = memref.subview %unified_shared_alloc[16384][32768][1] : !unified_flat_shared_f8 to !flat_shared_f8_at_offset_16384
%dim = tensor.dim %rhs_base, %c1 : !in_ty_f8
- %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !mexp_in_ty_f8
- %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty_f8
+ %dim_times_4 = arith.muli %dim, %c4 : index
+ %swizzle_stride = arith.minsi %dim_times_4, %c8192 : index
+
+ %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%swizzle_stride) : !mexp_in_ty_f8
+ %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%swizzle_stride) : !in_ty_f8
%lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<128, 8>] : !mflat_shared_f8
- %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<128, 8>] : !flat_shared_f8
+ %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<128, 8>] : !flat_shared_f8_at_offset_16384
%lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [128, 128] : !mflat_shared_f8 into !mshared_f8
- %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 128] : !flat_shared_f8 into !shared_f8
+ %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 128] : !flat_shared_f8_at_offset_16384 into !shared_f8_at_offset_16384
%lhs_init = tensor.extract_slice %lhs [0, 0, 0] [1, 128, 128] [1, 1, 1] : !mexp_in_ty_f8 to !mexp_block_in_f8
%rhs_init = tensor.extract_slice %rhs [0, 0] [256, 128] [1, 1] : !in_ty_f8 to !block_in_f8
@@ -949,11 +961,11 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
%vec = arith.muli %delin#1, %c16 : index
%rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
%rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
- vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+ vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8_at_offset_16384
} {mapping = [#gpu.thread<linear_dim_0>]}
%lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [8, 16, 4, 32] : !mshared_f8 into !mshared_exp_f8
- %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 32] : !shared_f8 into !shared_exp_f8
+ %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 32] : !shared_f8_at_offset_16384 into !shared_exp_f8_at_offset_16384
%0 = tensor.empty() : tensor<1x8x16x16x16xf32>
%1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x8x16x16x16xf32> {
@@ -989,7 +1001,7 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
%3 = scf.for %i = %c128 to %dim step %c128 iter_args(%iter = %2) -> vector<4x4x1x4xf32> {
%lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_f8, vector<4x1x2x8xf8E4M3FNUZ>
- %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x2x8xf8E4M3FNUZ>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8_at_offset_16384, vector<4x1x2x8xf8E4M3FNUZ>
%lhs_vec_0_t = vector.transpose %lhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x8xf8E4M3FNUZ> to vector<4x2x1x8xf8E4M3FNUZ>
%rhs_vec_0_t = vector.transpose %rhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x8xf8E4M3FNUZ> to vector<4x2x1x8xf8E4M3FNUZ>
@@ -1009,7 +1021,7 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
rocdl.sched.barrier 0
%lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_f8, vector<4x1x2x8xf8E4M3FNUZ>
- %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x2x8xf8E4M3FNUZ>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8_at_offset_16384, vector<4x1x2x8xf8E4M3FNUZ>
%lhs_vec_2_t = vector.transpose %lhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x8xf8E4M3FNUZ> to vector<4x2x1x8xf8E4M3FNUZ>
%rhs_vec_2_t = vector.transpose %rhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x8xf8E4M3FNUZ> to vector<4x2x1x8xf8E4M3FNUZ>
@@ -1036,10 +1048,10 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
gpu.barrier
rocdl.sched.barrier 0
- vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
- vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
- vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
- vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+ vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8_at_offset_16384
+ vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8_at_offset_16384
+ vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8_at_offset_16384
+ vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8_at_offset_16384
vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0_lhs, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !mshared_f8
vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1_lhs, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !mshared_f8
@@ -1066,7 +1078,7 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
// Epilogue
%lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_f8, vector<4x1x2x8xf8E4M3FNUZ>
- %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x2x8xf8E4M3FNUZ>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8_at_offset_16384, vector<4x1x2x8xf8E4M3FNUZ>
%lhs_vec_0_t = vector.transpose %lhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x8xf8E4M3FNUZ> to vector<4x2x1x8xf8E4M3FNUZ>
%rhs_vec_0_t = vector.transpose %rhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x8xf8E4M3FNUZ> to vector<4x2x1x8xf8E4M3FNUZ>
@@ -1077,7 +1089,7 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
} : vector<4x2x1x8xf8E4M3FNUZ>, vector<4x2x1x8xf8E4M3FNUZ> into vector<4x4x1x4xf32>
%lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_f8, vector<4x1x2x8xf8E4M3FNUZ>
- %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x2x8xf8E4M3FNUZ>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8_at_offset_16384, vector<4x1x2x8xf8E4M3FNUZ>
%lhs_vec_2_t = vector.transpose %lhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x8xf8E4M3FNUZ> to vector<4x2x1x8xf8E4M3FNUZ>
%rhs_vec_2_t = vector.transpose %rhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x8xf8E4M3FNUZ> to vector<4x2x1x8xf8E4M3FNUZ>
@@ -1787,9 +1799,9 @@ transform.named_sequence
attributes { iree_codegen.tuning_spec_entrypoint } {
%res = transform.foreach_match in %variant_op
// Match pingpong variants.
- @match_mmt_f16_f16_f32_large_expanded -> @apply_expanded_pingpong_op_config,
- @match_mmt_f8_f8_f32_large_expanded -> @apply_expanded_f8_pingpong_op_config,
- @match_mmt_f16_f16_f32_large -> @apply_pingpong_op_config,
+ //@match_mmt_f16_f16_f32_large_expanded -> @apply_expanded_pingpong_op_config,
+ //@match_mmt_f8_f8_f32_large_expanded -> @apply_expanded_f8_pingpong_op_config,
+ //@match_mmt_f16_f16_f32_large -> @apply_pingpong_op_config,
// Medium pingpong variants are lower priority.
@match_mmt_f16_f16_f32_medium_expanded -> @apply_expanded_medium_pingpong_op_config,
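
For reference, a minimal standalone MLIR sketch of the unified-allocation pattern this patch introduces: the two previously separate workgroup allocations (16384 elements for LHS, 32768 for RHS) become subviews into a single 49152-element buffer, with the RHS offset carried in a strided layout. The function name is hypothetical and func.func is used instead of util.func to keep it self-contained; the types mirror the aliases added at the top of the patch.

!unified_flat = memref<49152xf8E4M3FNUZ, #gpu.address_space<workgroup>>
!lhs_flat = memref<16384xf8E4M3FNUZ, #gpu.address_space<workgroup>>
!rhs_flat = memref<32768xf8E4M3FNUZ, strided<[1], offset: 16384>, #gpu.address_space<workgroup>>

func.func @unified_shared_alloc_sketch() -> (!lhs_flat, !rhs_flat) {
  // One allocation sized for both operands: 16384 + 32768 = 49152 elements.
  %buf = memref.alloc() : !unified_flat
  // LHS occupies elements [0, 16384); a static offset of 0 keeps the
  // identity layout, so the result type needs no strided attribute.
  %lhs = memref.subview %buf[0] [16384] [1] : !unified_flat to !lhs_flat
  // RHS occupies elements [16384, 49152); the static offset is carried in
  // the strided layout of the result type.
  %rhs = memref.subview %buf[16384] [32768] [1] : !unified_flat to !rhs_flat
  return %lhs, %rhs : !lhs_flat, !rhs_flat
}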
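
Likewise, a sketch of the new cache-swizzle stride computation: instead of passing the raw K dimension straight to cacheSwizzleStride, the patch uses min(4 * dim, 8192). The constants are taken verbatim from the patch; the function name and standalone framing are illustrative only, and reading the cap as "bound the swizzle stride" is my interpretation rather than something the diff states.

func.func @clamped_swizzle_stride(%dim: index) -> index {
  %c4 = arith.constant 4 : index
  %c8192 = arith.constant 8192 : index
  // stride = min(dim * 4, 8192): scale with the K extent, but cap it so
  // the stride handed to cacheSwizzleStride stays bounded.
  %dim_x4 = arith.muli %dim, %c4 : index
  %stride = arith.minsi %dim_x4, %c8192 : index
  return %stride : index
}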