Skip to content

Instantly share code, notes, and snippets.

@bjacob
Created June 20, 2025 03:10
Show Gist options
  • Save bjacob/a132e70ea0f7ceba4554f352a45bcd1c to your computer and use it in GitHub Desktop.
commit 5a9ea4399bdf0eafa4136b86e97831b7da6279b3
Author: Benoit Jacob <[email protected]>
Date: Thu Jun 19 19:37:00 2025 -0700
bufferization-fixes
Signed-off-by: Benoit Jacob <[email protected]>
diff --git a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
index 7ca1f240f4..714f9bf681 100644
--- a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
@@ -148,16 +148,8 @@ static IREEOneShotBufferizationOptions getBufferizationOptions() {
// This type converter converts tensor types to memref types when no exact
// memref type can be inferred from the context.
- options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
+ options.unknownTypeConverterFn = [](TensorType tensorType, Attribute memorySpace,
const BufferizationOptions &options) {
- auto tensorType = llvm::cast<TensorType>(value.getType());
-
- // Special rule for ConstantOps: These always lower to some memref with a
- // static identity layout.
- if (value.getDefiningOp<arith::ConstantOp>())
- return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
- memorySpace);
-
// Default case: Fully dynamic layout map for best compatibility.
return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
memorySpace);
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 886e1a6ac5..f4c6e101de 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -886,16 +886,8 @@ static IREEOneShotBufferizationOptions getBufferizationOptions() {
// This type converter converts tensor types to memref types when no exact
// memref type can be inferred from the context.
- options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
+ options.unknownTypeConverterFn = [](TensorType tensorType, Attribute memorySpace,
const BufferizationOptions &options) {
- auto tensorType = llvm::cast<TensorType>(value.getType());
-
- // Special rule for ConstantOps: These always lower to some memref with a
- // static identity layout.
- if (value.getDefiningOp<arith::ConstantOp>())
- return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
- memorySpace);
-
// Default case: Fully dynamic layout map for best compatibility.
return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
memorySpace);
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
index 3ee12df080..683b870daa 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
@@ -108,7 +109,7 @@ struct BarrierRegionOpBufferizationInterface
SmallVector<Value> &invocationStack) const {
auto barrierOp = cast<IREE::GPU::BarrierRegionOp>(op);
- FailureOr<BaseMemRefType> memrefType = failure();
+ FailureOr<mlir::bufferization::BufferLikeType> memrefType = failure();
if (auto opResult = dyn_cast<OpResult>(value)) {
int64_t resultNum = opResult.getResultNumber();
memrefType = bufferization::getBufferType(
@@ -121,7 +122,7 @@ struct BarrierRegionOpBufferizationInterface
}
if (failed(memrefType))
return failure();
- return memrefType;
+ return cast<BaseMemRefType>(*memrefType);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -148,7 +149,7 @@ struct BarrierRegionOpBufferizationInterface
}
tensorizedOperands.push_back(rewriter
.create<bufferization::ToTensorOp>(
- replacement.getLoc(), replacement)
+ replacement.getLoc(), memref::getTensorTypeFromMemRefType(replacement.getType()), replacement)
.getResult());
}
@@ -205,7 +206,7 @@ struct ValueBarrierOpBufferizationInterface
state, invocationStack);
if (failed(srcMemrefType))
return failure();
- return srcMemrefType;
+ return cast<BaseMemRefType>(*srcMemrefType);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -341,8 +342,9 @@ struct BufferResourceCastOpBufferizationInterface
if (failed(srcMemrefType))
return failure();
- if (!hasStorageBufferMemSpace(srcMemrefType.value())) {
- return srcMemrefType;
+ auto baseMemrefType = cast<BaseMemRefType>(srcMemrefType.value());
+ if (!hasStorageBufferMemSpace(baseMemrefType)) {
+ return baseMemrefType;
}
auto rankedSrcType = cast<MemRefType>(srcMemrefType.value());
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment