Created
June 20, 2025 03:10
-
-
Save bjacob/a132e70ea0f7ceba4554f352a45bcd1c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
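commit 5a9ea4399bdf0eafa4136b86e97831b7da6279b3
Author: Benoit Jacob <[email protected]>
Date: Thu Jun 19 19:37:00 2025 -0700
bufferization-fixes
Make unknownTypeConverterFn take a TensorType (dropping the arith::ConstantOp static-identity-layout special case), cast getBufferType results to BaseMemRefType, and pass an explicit tensor result type when creating bufferization::ToTensorOp.
Signed-off-by: Benoit Jacob <[email protected]>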
diff --git a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp | |
index 7ca1f240f4..714f9bf681 100644 | |
--- a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp | |
+++ b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp | |
@@ -148,16 +148,8 @@ static IREEOneShotBufferizationOptions getBufferizationOptions() { | |
// This type converter converts tensor types to memref types when no exact | |
// memref type can be inferred from the context. | |
- options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, | |
+ options.unknownTypeConverterFn = [](TensorType tensorType, Attribute memorySpace, | |
const BufferizationOptions &options) { | |
- auto tensorType = llvm::cast<TensorType>(value.getType()); | |
- | |
- // Special rule for ConstantOps: These always lower to some memref with a | |
- // static identity layout. | |
- if (value.getDefiningOp<arith::ConstantOp>()) | |
- return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, | |
- memorySpace); | |
- | |
// Default case: Fully dynamic layout map for best compatibility. | |
return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, | |
memorySpace); | |
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp | |
index 886e1a6ac5..f4c6e101de 100644 | |
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp | |
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp | |
@@ -886,16 +886,8 @@ static IREEOneShotBufferizationOptions getBufferizationOptions() { | |
// This type converter converts tensor types to memref types when no exact | |
// memref type can be inferred from the context. | |
- options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, | |
+ options.unknownTypeConverterFn = [](TensorType tensorType, Attribute memorySpace, | |
const BufferizationOptions &options) { | |
- auto tensorType = llvm::cast<TensorType>(value.getType()); | |
- | |
- // Special rule for ConstantOps: These always lower to some memref with a | |
- // static identity layout. | |
- if (value.getDefiningOp<arith::ConstantOp>()) | |
- return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, | |
- memorySpace); | |
- | |
// Default case: Fully dynamic layout map for best compatibility. | |
return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, | |
memorySpace); | |
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp | |
index 3ee12df080..683b870daa 100644 | |
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp | |
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp | |
@@ -14,6 +14,7 @@ | |
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" | |
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h" | |
#include "mlir/Dialect/GPU/IR/GPUDialect.h" | |
+#include "mlir/Dialect/MemRef/IR/MemRef.h" | |
#include "mlir/IR/BuiltinTypes.h" | |
#include "mlir/IR/Value.h" | |
@@ -108,7 +109,7 @@ struct BarrierRegionOpBufferizationInterface | |
SmallVector<Value> &invocationStack) const { | |
auto barrierOp = cast<IREE::GPU::BarrierRegionOp>(op); | |
- FailureOr<BaseMemRefType> memrefType = failure(); | |
+ FailureOr<mlir::bufferization::BufferLikeType> memrefType = failure(); | |
if (auto opResult = dyn_cast<OpResult>(value)) { | |
int64_t resultNum = opResult.getResultNumber(); | |
memrefType = bufferization::getBufferType( | |
@@ -121,7 +122,7 @@ struct BarrierRegionOpBufferizationInterface | |
} | |
if (failed(memrefType)) | |
return failure(); | |
- return memrefType; | |
+ return cast<BaseMemRefType>(*memrefType); | |
} | |
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, | |
@@ -148,7 +149,7 @@ struct BarrierRegionOpBufferizationInterface | |
} | |
tensorizedOperands.push_back(rewriter | |
.create<bufferization::ToTensorOp>( | |
- replacement.getLoc(), replacement) | |
+ replacement.getLoc(), memref::getTensorTypeFromMemRefType(replacement.getType()), replacement) | |
.getResult()); | |
} | |
@@ -205,7 +206,7 @@ struct ValueBarrierOpBufferizationInterface | |
state, invocationStack); | |
if (failed(srcMemrefType)) | |
return failure(); | |
- return srcMemrefType; | |
+ return cast<BaseMemRefType>(*srcMemrefType); | |
} | |
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, | |
@@ -341,8 +342,9 @@ struct BufferResourceCastOpBufferizationInterface | |
if (failed(srcMemrefType)) | |
return failure(); | |
- if (!hasStorageBufferMemSpace(srcMemrefType.value())) { | |
- return srcMemrefType; | |
+ auto baseMemrefType = cast<BaseMemRefType>(srcMemrefType.value()); | |
+ if (!hasStorageBufferMemSpace(baseMemrefType)) { | |
+ return baseMemrefType; | |
} | |
auto rankedSrcType = cast<MemRefType>(srcMemrefType.value()); |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment