Created
November 4, 2024 20:51
-
-
Save bjacob/3bacc2dc3aa4a2c833e5e562edcac347 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
commit 8bd7e0c21a5571928e3d918076dd598fc8d2f3b9 | |
Author: Benoit Jacob <[email protected]> | |
Date: Thu Nov 14 05:51:31 2024 -0800 | |
fix | |
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | |
index a89f7fced8..9b661effb5 100644 | |
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | |
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | |
@@ -963,7 +963,6 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides( | |
Value threadId, ArrayRef<int64_t> permutation, | |
SmallVector<OpFoldResult> &offsets, SmallVector<OpFoldResult> &sizes, | |
SmallVector<OpFoldResult> &strides) const { | |
- // Get the swizzle describing the internal layout of this fragment. | |
TileSwizzle swizzle = getSwizzle(*this, fragment); | |
LLVM_DEBUG({ | |
@@ -972,59 +971,69 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides( | |
DBGS() << " swizzle: " << swizzle << "\n"; | |
}); | |
- // Populate tile sizes. | |
MLIRContext *ctx = builder.getContext(); | |
SmallVector<OpFoldResult> tileSizes = getAsIndexOpFoldResult( | |
ctx, sliceSwizzledShape(swizzle, [](TileSwizzle::Dim d) { | |
return d.kind != TileSwizzle::Dim::Kind::CrossThread; | |
})); | |
- // Populate tile offsets by delinearizing threadId over the CrossThread dims. | |
- // Since the AffineDelinearizeIndexOp does not bound the input index, we | |
- // must bound the threadId by the product of the offset ranges. | |
- SmallVector<int64_t> tileOffsetsBasis = | |
+ // Most of the rest of this function is the computation of the offsets. | |
+ // The basic idea is to delinearize the threadId over the basis of | |
+ // cross-thread dimensions. The main subtlety is that the TileSwizzles refer | |
+ // only to the layout dims, and do not reflect a possible additional | |
+ // thread-distribution-only dimension present on some architectures (RDNA3). | |
+ // When such an extra dim exists, multiple threads are reading the same data. | |
+ // So we need to distinguish layoutThreadSizes vs. distributionThreadSizes. | |
+ SmallVector<int64_t> layoutThreadSizes = | |
sliceSwizzledShape(swizzle, [](TileSwizzle::Dim d) { | |
return d.kind == TileSwizzle::Dim::Kind::CrossThread; | |
}); | |
- | |
- // Adjustment needed on RDNA3 where the same data is read by 2 threads and the | |
- // intrinsic thread-grid is correspondingly 2x smaller than subgroup size. | |
- // We can't recover that from the unrolled tileSwizzle, where these intrinsic | |
- // level dimensions are mixed with expansion to multiple subgroups, so we have | |
- // to go back to the intrinsicSwizzle here. | |
+ // In layoutThreadSizes, intrinsic-level dimensions are mixed with expansion | |
+ // to multiple subgroups, so in order to tell if there are additional | |
+ // distribution-only thread dimensions, we go back to the intrinsic swizzle. | |
TileSwizzle intrinsicSwizzle = | |
getIntrinsicSwizzle(getIntrinsic().getValue(), fragment); | |
- SmallVector<int64_t> intrinsicThreadSizes = | |
+ SmallVector<int64_t> intrinsicLayoutThreadSizes = | |
sliceSwizzledShape(intrinsicSwizzle, [](TileSwizzle::Dim d) { | |
return d.kind == TileSwizzle::Dim::Kind::CrossThread; | |
}); | |
- int64_t subgroupThreadIdWrappingValue = | |
- ShapedType::getNumElements(intrinsicThreadSizes); | |
- if (subgroupThreadIdWrappingValue != getSubgroupSize()) { | |
- // For now only support the special case that happens on RDNA3. | |
- assert(getSubgroupSize() == 2 * subgroupThreadIdWrappingValue); | |
- assert(llvm::isPowerOf2_64(subgroupThreadIdWrappingValue)); | |
- threadId = | |
- builder.create<arith::AndIOp>(loc, threadId, | |
- builder.create<arith::ConstantIndexOp>( | |
- loc, ~subgroupThreadIdWrappingValue)); | |
- } | |
- | |
- // Bound for threadId is the product of tileOffsetsBasis. | |
+ int64_t intrinsicLayoutThreadBound = | |
+ ShapedType::getNumElements(intrinsicLayoutThreadSizes); | |
+ SmallVector<int64_t> distributionThreadSizes = layoutThreadSizes; | |
+ int distributionOnlyDimIdx = | |
+ distributionThreadSizes.size() - intrinsicLayoutThreadSizes.size(); | |
+ // Now we are able to tell if there is an extra distribution-only dimension. | |
+ bool hasDistributionOnlyDim = intrinsicLayoutThreadBound < getSubgroupSize(); | |
+ if (hasDistributionOnlyDim) { | |
+ // Insert the extra distribution-only dimension. This will need to be paired | |
+ // below with erasing the corresponding dim out of the delinearized indices. | |
+ distributionThreadSizes.insert( | |
+ distributionThreadSizes.begin() + distributionOnlyDimIdx, | |
+ getSubgroupSize() / intrinsicLayoutThreadBound); | |
+ } | |
+ | |
+ // AffineDelinearizeIndexOp requires an in-bounds input index, so we bound it. | |
OpFoldResult threadIdBound = | |
- builder.getIndexAttr(ShapedType::getNumElements(tileOffsetsBasis)); | |
+ builder.getIndexAttr(ShapedType::getNumElements(distributionThreadSizes)); | |
AffineExpr d0 = builder.getAffineDimExpr(0), d1 = builder.getAffineDimExpr(1); | |
OpFoldResult boundedThreadId = affine::makeComposedFoldedAffineApply( | |
builder, loc, {d0 % d1}, {threadId, threadIdBound}); | |
+ // Obtain the offsets from delinearization along the distributionThreadSizes. | |
SmallVector<OpFoldResult> tileOffsets = | |
builder | |
.create<affine::AffineDelinearizeIndexOp>( | |
loc, | |
getValueOrCreateConstantIndexOp(builder, loc, boundedThreadId), | |
- getAsIndexOpFoldResult(ctx, tileOffsetsBasis)) | |
+ getAsIndexOpFoldResult(ctx, distributionThreadSizes)) | |
->getResults(); | |
+ if (hasDistributionOnlyDim) { | |
+ // Erase the delinearized index that corresponds to the extra distribution | |
+ // dimension that we had inserted above. | |
+ tileOffsets.erase(tileOffsets.begin() + distributionOnlyDimIdx); | |
+ } | |
+ | |
// Strides are trivial: each slice is contiguous along the *expanded* dims | |
// even if it may not be contiguous in the flattened layout. | |
SmallVector<OpFoldResult> tileStrides(tileSizes.size(), | |
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py | |
index a87ac3a5ed..dd387f3114 100644 | |
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py | |
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py | |
@@ -196,9 +196,7 @@ def get_test_shapes(shapes_id: ShapesId): | |
# disabled to improve the trade-off between test coverage and build | |
# latency. | |
if shapes_id == ShapesId.DEFAULT: | |
- return [ | |
- TestShape(m=16, k=16, n=16, accumulate=True), | |
- ] | |
+ return get_test_shapes(ShapesId.SMALL) + get_test_shapes(ShapesId.LARGE) | |
if shapes_id == ShapesId.SMALL: | |
return [ | |
# square matrices. Start by the simplest case of 1x1x1. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment