Skip to content

Instantly share code, notes, and snippets.

@bjacob
Last active November 4, 2024 17:14
Show Gist options
  • Save bjacob/20dbbb202f67ab48c167dd8191ad2d87 to your computer and use it in GitHub Desktop.
Save bjacob/20dbbb202f67ab48c167dd8191ad2d87 to your computer and use it in GitHub Desktop.
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -987,6 +987,29 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides(
return d.kind == TileSwizzle::Dim::Kind::CrossThread;
});
+ // Adjustment needed on RDNA3, where the same data is read by 2 threads and the
+ // intrinsic thread-grid is correspondingly 2x smaller than the subgroup size.
+ // We can't recover that from the unrolled tileSwizzle, where these
+ // intrinsic-level dimensions are mixed with expansion to multiple subgroups,
+ // so we have to go back to the intrinsicSwizzle here.
+ TileSwizzle intrinsicSwizzle =
+ getIntrinsicSwizzle(getIntrinsic().getValue(), fragment);
+ SmallVector<int64_t> intrinsicThreadSizes =
+ sliceSwizzledShape(intrinsicSwizzle, [](TileSwizzle::Dim d) {
+ return d.kind == TileSwizzle::Dim::Kind::CrossThread;
+ });
+ int64_t subgroupThreadIdWrappingValue =
+ ShapedType::getNumElements(intrinsicThreadSizes);
+ if (subgroupThreadIdWrappingValue != getSubgroupSize()) {
+ // For now, only support the RDNA3 case: subgroup size == 2x a power-of-2 grid.
+ assert(getSubgroupSize() == 2 * subgroupThreadIdWrappingValue);
+ assert(llvm::isPowerOf2_64(subgroupThreadIdWrappingValue));
+ threadId =
+ builder.create<arith::AndIOp>(loc, threadId,
+ builder.create<arith::ConstantIndexOp>(
+ loc, ~subgroupThreadIdWrappingValue));
+ }
+
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment