Created
November 5, 2024 20:39
-
-
Save bjacob/7d7995176d3c1638fe9a28d6aadbc5ab to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | |
index 56f5a927cc..3959baec1b 100644 | |
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | |
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | |
@@ -693,33 +693,22 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc, | |
case MMAIntrinsic::VMFMA_F32_16x16x32_F16: | |
case MMAIntrinsic::VMFMA_F32_32x32x16_F16: { | |
// Generate mfma's for K with unrolled kernels. | |
- const int64_t unrollKFactor = 2; | |
- auto [m, n, k] = getMNKShape(); | |
- // Compute actual/native intrinsic's K size. | |
- int64_t nativeKSize = k / unrollKFactor; | |
- | |
- auto [aType, bType, cType] = getABCVectorTypes(); | |
- if (aType.getShape()[0] != bType.getShape()[0]) { | |
- // Currently only support case where lhs and rhs | |
- // has same vectorWidth. | |
- return failure(); | |
- } | |
- int64_t vectorWidth = aType.getShape()[0] / unrollKFactor; | |
- for (int i = 0; i < unrollKFactor; i++) { | |
- int64_t offset = vectorWidth * i; | |
- Value sliced_lhs = builder.create<vector::ExtractStridedSliceOp>( | |
- loc, lhs, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{vectorWidth}, | |
- ArrayRef<int64_t>{1}); | |
- Value sliced_rhs = builder.create<vector::ExtractStridedSliceOp>( | |
- loc, rhs, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{vectorWidth}, | |
- ArrayRef<int64_t>{1}); | |
- acc = builder | |
- .create<amdgpu::MFMAOp>(loc, resultType, m, n, nativeKSize, | |
- getBlockSize(), sliced_lhs, sliced_rhs, | |
- acc) | |
- .getResult(); | |
- } | |
- return acc; | |
+ auto realIntrinsic = | |
+ getIntrinsic().getValue() == MMAIntrinsic::VMFMA_F32_16x16x32_F16 | |
+ ? MMAIntrinsic::MFMA_F32_16x16x16_F16 | |
+ : MMAIntrinsic::MFMA_F32_32x32x8_F16; | |
+ auto unrolled = DataTiledMMAAttr::get( | |
+ builder.getContext(), | |
+ MMAIntrinsicAttr::get(builder.getContext(), realIntrinsic), | |
+ /*other unroll factors...=*/1, 1, 1, 1, /*unroll_k=*/2); | |
+ auto [aType, bType, cType] = unrolled.getABCVectorTypes(); | |
+ Value xlhs = builder.create<vector::ShapeCastOp>(loc, aType, lhs); | |
+ Value xrhs = builder.create<vector::ShapeCastOp>(loc, bType, rhs); | |
+ Value xacc = builder.create<vector::ShapeCastOp>(loc, cType, acc); | |
+ Value xmma = | |
+ *unrolled.buildMmaOperation(builder, loc, cType, xlhs, xrhs, xacc); | |
+ return builder.create<vector::ShapeCastOp>(loc, resultType, xmma) | |
+ .getResult(); |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment