Skip to content

Instantly share code, notes, and snippets.

@bjacob
Created November 4, 2024 18:29
Show Gist options
  • Save bjacob/7c4d05c79d567b4f4ffc624902387ac2 to your computer and use it in GitHub Desktop.
Save bjacob/7c4d05c79d567b4f4ffc624902387ac2 to your computer and use it in GitHub Desktop.
commit f09ea33589816270efc12e9a5371da8b80cd6e2b
Author: Benoit Jacob <[email protected]>
Date: Thu Nov 14 02:11:18 2024 -0800
more fixes
Signed-off-by: Benoit Jacob <[email protected]>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 4406de2e9f..a89f7fced8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -987,6 +987,29 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides(
return d.kind == TileSwizzle::Dim::Kind::CrossThread;
});
+ // Adjustment needed on RDNA3 where the same data is read by 2 threads and the
+ // intrinsic thread-grid is correspondingly 2x smaller than subgroup size.
+ // We can't recover that from the unrolled tileSwizzle, where these
+ // intrinsic-level dimensions are mixed with expansion to multiple subgroups,
+ // so we have to go back to the intrinsicSwizzle here.
+ TileSwizzle intrinsicSwizzle =
+ getIntrinsicSwizzle(getIntrinsic().getValue(), fragment);
+ SmallVector<int64_t> intrinsicThreadSizes =
+ sliceSwizzledShape(intrinsicSwizzle, [](TileSwizzle::Dim d) {
+ return d.kind == TileSwizzle::Dim::Kind::CrossThread;
+ });
+ int64_t subgroupThreadIdWrappingValue =
+ ShapedType::getNumElements(intrinsicThreadSizes);
+ if (subgroupThreadIdWrappingValue != getSubgroupSize()) {
+ // For now only support the special case that happens on RDNA3.
+ assert(getSubgroupSize() == 2 * subgroupThreadIdWrappingValue);
+ assert(llvm::isPowerOf2_64(subgroupThreadIdWrappingValue));
+ threadId =
+ builder.create<arith::AndIOp>(loc, threadId,
+ builder.create<arith::ConstantIndexOp>(
+ loc, ~subgroupThreadIdWrappingValue));
+ }
+
// Bound for threadId is the product of tileOffsetsBasis.
OpFoldResult threadIdBound =
builder.getIndexAttr(ShapedType::getNumElements(tileOffsetsBasis));
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index ab3c15744d..22a701efde 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -1809,6 +1809,9 @@ iree_generated_e2e_runner_test(
"--iree-opt-data-tiling"
"--iree-global-opt-experimental-rocm-data-tiling"
"--iree-global-opt-enable-early-materialization=true"
+ RUNNER_ARGS
+ "--require_exact_results=false"
+ "--acceptable_fp_delta=1e-04"
LABELS
"noasan"
"nomsan"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment