Last active: November 4, 2024 17:14
-
-
Save bjacob/20dbbb202f67ab48c167dd8191ad2d87 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | |
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | |
@@ -987,6 +987,29 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides( | |
return d.kind == TileSwizzle::Dim::Kind::CrossThread; | |
}); | |
+ // Adjustment needed on RDNA3 where the same data is read by 2 threads and the | |
+ // intrinsic thread-grid is correspondingly 2x smaller than subgroup size. | |
+ // We can't recover that from the unrolled tileSwizzle, where these intrinsic | |
+ // level dimensions are mixed with expansion to multiple subgroups, so we have | |
+ // to go back to the intrinsicSwizzle here. | |
+ TileSwizzle intrinsicSwizzle = | |
+ getIntrinsicSwizzle(getIntrinsic().getValue(), fragment); | |
+ SmallVector<int64_t> intrinsicThreadSizes = | |
+ sliceSwizzledShape(intrinsicSwizzle, [](TileSwizzle::Dim d) { | |
+ return d.kind == TileSwizzle::Dim::Kind::CrossThread; | |
+ }); | |
+ int64_t subgroupThreadIdWrappingValue = | |
+ ShapedType::getNumElements(intrinsicThreadSizes); | |
+ if (subgroupThreadIdWrappingValue != getSubgroupSize()) { | |
+ // For now only support the special case that happens on RDNA3. | |
+ assert(getSubgroupSize() == 2 * subgroupThreadIdWrappingValue); | |
+ assert(llvm::isPowerOf2_64(subgroupThreadIdWrappingValue)); | |
+ threadId = | |
+ builder.create<arith::AndIOp>(loc, threadId, | |
+ builder.create<arith::ConstantIndexOp>( | |
+ loc, ~subgroupThreadIdWrappingValue)); | |
+ } | |
+ |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.