Skip to content

Instantly share code, notes, and snippets.

@bjacob
Created November 5, 2024 20:39
Show Gist options
  • Save bjacob/7d7995176d3c1638fe9a28d6aadbc5ab to your computer and use it in GitHub Desktop.
Save bjacob/7d7995176d3c1638fe9a28d6aadbc5ab to your computer and use it in GitHub Desktop.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 56f5a927cc..3959baec1b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -693,33 +693,22 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::VMFMA_F32_32x32x16_F16: {
// Generate mfma's for K with unrolled kernels.
- const int64_t unrollKFactor = 2;
- auto [m, n, k] = getMNKShape();
- // Compute actual/native intrinsic's K size.
- int64_t nativeKSize = k / unrollKFactor;
-
- auto [aType, bType, cType] = getABCVectorTypes();
- if (aType.getShape()[0] != bType.getShape()[0]) {
- // Currently only support case where lhs and rhs
- // has same vectorWidth.
- return failure();
- }
- int64_t vectorWidth = aType.getShape()[0] / unrollKFactor;
- for (int i = 0; i < unrollKFactor; i++) {
- int64_t offset = vectorWidth * i;
- Value sliced_lhs = builder.create<vector::ExtractStridedSliceOp>(
- loc, lhs, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{vectorWidth},
- ArrayRef<int64_t>{1});
- Value sliced_rhs = builder.create<vector::ExtractStridedSliceOp>(
- loc, rhs, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{vectorWidth},
- ArrayRef<int64_t>{1});
- acc = builder
- .create<amdgpu::MFMAOp>(loc, resultType, m, n, nativeKSize,
- getBlockSize(), sliced_lhs, sliced_rhs,
- acc)
- .getResult();
- }
- return acc;
+ auto realIntrinsic =
+ getIntrinsic().getValue() == MMAIntrinsic::VMFMA_F32_16x16x32_F16
+ ? MMAIntrinsic::MFMA_F32_16x16x16_F16
+ : MMAIntrinsic::MFMA_F32_32x32x8_F16;
+ auto unrolled = DataTiledMMAAttr::get(
+ builder.getContext(),
+ MMAIntrinsicAttr::get(builder.getContext(), realIntrinsic),
+ /*other unroll factors...=*/1, 1, 1, 1, /*unroll_k=*/2);
+ auto [aType, bType, cType] = unrolled.getABCVectorTypes();
+ Value xlhs = builder.create<vector::ShapeCastOp>(loc, aType, lhs);
+ Value xrhs = builder.create<vector::ShapeCastOp>(loc, bType, rhs);
+ Value xacc = builder.create<vector::ShapeCastOp>(loc, cType, acc);
+ Value xmma =
+ *unrolled.buildMmaOperation(builder, loc, cType, xlhs, xrhs, xacc);
+ return builder.create<vector::ShapeCastOp>(loc, resultType, xmma)
+ .getResult();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment