Skip to content

Instantly share code, notes, and snippets.

@bjacob
Created November 4, 2024 20:51
Show Gist options
  • Save bjacob/3bacc2dc3aa4a2c833e5e562edcac347 to your computer and use it in GitHub Desktop.
commit 8bd7e0c21a5571928e3d918076dd598fc8d2f3b9
Author: Benoit Jacob <[email protected]>
Date: Thu Nov 14 05:51:31 2024 -0800
fix
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index a89f7fced8..9b661effb5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -963,7 +963,6 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides(
Value threadId, ArrayRef<int64_t> permutation,
SmallVector<OpFoldResult> &offsets, SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) const {
- // Get the swizzle describing the internal layout of this fragment.
TileSwizzle swizzle = getSwizzle(*this, fragment);
LLVM_DEBUG({
@@ -972,59 +971,69 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides(
DBGS() << " swizzle: " << swizzle << "\n";
});
- // Populate tile sizes.
MLIRContext *ctx = builder.getContext();
SmallVector<OpFoldResult> tileSizes = getAsIndexOpFoldResult(
ctx, sliceSwizzledShape(swizzle, [](TileSwizzle::Dim d) {
return d.kind != TileSwizzle::Dim::Kind::CrossThread;
}));
- // Populate tile offsets by delinearizing threadId over the CrossThread dims.
- // Since the AffineDelinearizeIndexOp does not bound the input index, we
- // must bound the threadId by the product of the offset ranges.
- SmallVector<int64_t> tileOffsetsBasis =
+ // Most of the rest of this function is the computation of the offsets.
+ // The basic idea is to delinearize the threadId over the basis of
+ // cross-thread dimensions. The main subtlety is that the TileSwizzles refer
+ // only to the layout dims, and do not reflect a possible additional
+ // thread-distribution-only dimension present on some architectures (RDNA3).
+ // When such an extra dim exists, multiple threads are reading the same data.
+ // So we need to distinguish layoutThreadSizes vs. distributionThreadSizes.
+ SmallVector<int64_t> layoutThreadSizes =
sliceSwizzledShape(swizzle, [](TileSwizzle::Dim d) {
return d.kind == TileSwizzle::Dim::Kind::CrossThread;
});
-
- // Adjustment needed on RDNA3 where the same data is read by 2 threads and the
- // intrinsic thread-grid is correspondingly 2x smaller than subgroup size.
- // We can't recover that from the unrolled tileSwizzle, where these intrinsic
- // level dimensions are mixed with expansion to multiple subgroups, so we have
- // to go back to the intrinsicSwizzle here.
+ // In layoutThreadSizes, intrinsic level dimensions are mixed with expansion
+ // to multiple subgroups, so in order to tell if there are additional
+ // distribution-only thread dimensions, we need to get back to the intrinsic.
TileSwizzle intrinsicSwizzle =
getIntrinsicSwizzle(getIntrinsic().getValue(), fragment);
- SmallVector<int64_t> intrinsicThreadSizes =
+ SmallVector<int64_t> intrinsicLayoutThreadSizes =
sliceSwizzledShape(intrinsicSwizzle, [](TileSwizzle::Dim d) {
return d.kind == TileSwizzle::Dim::Kind::CrossThread;
});
- int64_t subgroupThreadIdWrappingValue =
- ShapedType::getNumElements(intrinsicThreadSizes);
- if (subgroupThreadIdWrappingValue != getSubgroupSize()) {
- // For now only support the special case that happens on RDNA3.
- assert(getSubgroupSize() == 2 * subgroupThreadIdWrappingValue);
- assert(llvm::isPowerOf2_64(subgroupThreadIdWrappingValue));
- threadId =
- builder.create<arith::AndIOp>(loc, threadId,
- builder.create<arith::ConstantIndexOp>(
- loc, ~subgroupThreadIdWrappingValue));
- }
-
- // Bound for threadId is the product of tileOffsetsBasis.
+ int64_t intrinsicLayoutThreadBound =
+ ShapedType::getNumElements(intrinsicLayoutThreadSizes);
+ SmallVector<int64_t> distributionThreadSizes = layoutThreadSizes;
+ int distributionOnlyDimIdx =
+ distributionThreadSizes.size() - intrinsicLayoutThreadSizes.size();
+ // Now we are able to tell if there is an extra distribution-only dimension.
+ bool hasDistributionOnlyDim = intrinsicLayoutThreadBound < getSubgroupSize();
+ if (hasDistributionOnlyDim) {
+ // Insert the extra distribution-only dimension. This will need to be paired
+ // below with erasing the corresponding dim out of the delinearized indices.
+ distributionThreadSizes.insert(
+ distributionThreadSizes.begin() + distributionOnlyDimIdx,
+ getSubgroupSize() / intrinsicLayoutThreadBound);
+ }
+
+ // AffineDelinearizeIndexOp requires an in-bounds input index, so we bound it.
OpFoldResult threadIdBound =
- builder.getIndexAttr(ShapedType::getNumElements(tileOffsetsBasis));
+ builder.getIndexAttr(ShapedType::getNumElements(distributionThreadSizes));
AffineExpr d0 = builder.getAffineDimExpr(0), d1 = builder.getAffineDimExpr(1);
OpFoldResult boundedThreadId = affine::makeComposedFoldedAffineApply(
builder, loc, {d0 % d1}, {threadId, threadIdBound});
+ // Obtain the offsets from delinearization along the distributionThreadSizes.
SmallVector<OpFoldResult> tileOffsets =
builder
.create<affine::AffineDelinearizeIndexOp>(
loc,
getValueOrCreateConstantIndexOp(builder, loc, boundedThreadId),
- getAsIndexOpFoldResult(ctx, tileOffsetsBasis))
+ getAsIndexOpFoldResult(ctx, distributionThreadSizes))
->getResults();
+ if (hasDistributionOnlyDim) {
+ // Erase the delinearized index that corresponds to the extra distribution
+ // dimension that we had inserted above.
+ tileOffsets.erase(tileOffsets.begin() + distributionOnlyDimIdx);
+ }
+
// Strides are trivial: each slice is contiguous along the *expanded* dims
// even if it may not be contiguous in the flattened layout.
SmallVector<OpFoldResult> tileStrides(tileSizes.size(),
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index a87ac3a5ed..dd387f3114 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -196,9 +196,7 @@ def get_test_shapes(shapes_id: ShapesId):
# disabled to improve the trade-off between test coverage and build
# latency.
if shapes_id == ShapesId.DEFAULT:
- return [
- TestShape(m=16, k=16, n=16, accumulate=True),
- ]
+ return get_test_shapes(ShapesId.SMALL) + get_test_shapes(ShapesId.LARGE)
if shapes_id == ShapesId.SMALL:
return [
# square matrices. Start by the simplest case of 1x1x1.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment