Created
January 16, 2025 15:55
-
-
Save makslevental/49da1d17597950933498288bb1dcd8c6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Index: third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp | |
IDEA additional info: | |
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | |
<+>UTF-8 | |
=================================================================== | |
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp | |
--- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp (revision f5e11cc519ea223de2c74c3c14a32f0beb327fae) | |
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp (date 1737041209117) | |
@@ -451,6 +451,19 @@ | |
llvm::SetVector<Operation *> &opToRewrite; | |
}; | |
+static SmallVector<Value> | |
+maybeCreateInitialFatPtr(SmallVector<Value> ptr, | |
+ ConversionPatternRewriter &rewriter, | |
+ FatPointers &fatPtrs) { | |
+ if (ptr.size() == 1) { | |
+ auto fatPtrOffset = | |
+ rewriter.create<arith::ConstantIntOp>(ptr[0].getLoc(), 0, 64); | |
+ ptr.push_back(fatPtrOffset); | |
+ fatPtrs[{ptr[0], fatPtrOffset}].canNarrow = true; | |
+ } | |
+ return {ptr[0], ptr[1]}; | |
+} | |
+ | |
/// splat integer offset, keep base | |
class ConvertSplatOp : public PointerCanonicalizationPattern<tt::SplatOp> { | |
public: | |
@@ -459,12 +472,16 @@ | |
LogicalResult | |
matchAndRewrite_(tt::SplatOp splatOp, OneToNOpAdaptor adaptor, | |
ConversionPatternRewriter &rewriter) const override { | |
- ValueRange remappedOperands = adaptor.getSrc(); | |
- if (remappedOperands.size() != 2) | |
- return rewriter.notifyMatchFailure( | |
- splatOp, "expected SplatOp src to have already been remapped"); | |
- Value fatPtrBase = remappedOperands[0]; | |
- Value fatPtrOffset = remappedOperands[1]; | |
+ ValueRange remappedSrc = adaptor.getSrc(); | |
+ Value fatPtrBase = remappedSrc[0]; | |
+ Value fatPtrOffset; | |
+ if (remappedSrc.size() == 2) { | |
+ fatPtrOffset = remappedSrc[1]; | |
+ } else { | |
+ fatPtrOffset = | |
+ rewriter.create<arith::ConstantIntOp>(fatPtrBase.getLoc(), 0, 64); | |
+ fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow = true; | |
+ } | |
if (!llvm::isa<tt::PointerType>(fatPtrBase.getType())) | |
return rewriter.notifyMatchFailure(splatOp, | |
"non tt.ptr base unimplemented"); | |
@@ -535,10 +552,14 @@ | |
LogicalResult | |
matchAndRewrite_(tt::AddPtrOp addPtrOp, OneToNOpAdaptor adaptor, | |
ConversionPatternRewriter &rewriter) const override { | |
- ValueRange remappedPtr = adaptor.getPtr(); | |
- if (remappedPtr.size() != 2) | |
- return rewriter.notifyMatchFailure( | |
- addPtrOp, "expected AddPtrOp Ptr to have already been remapped"); | |
+ SmallVector<Value> remappedPtr = adaptor.getPtr(); | |
+ if (remappedPtr.size() == 1) { | |
+ auto fatPtrOffset = | |
+ rewriter.create<arith::ConstantIntOp>(remappedPtr[0].getLoc(), 0, 64); | |
+ remappedPtr.push_back(fatPtrOffset); | |
+ fatPtrs[{remappedPtr[0], fatPtrOffset}].canNarrow = true; | |
+ } | |
+ | |
ValueRange nonRemappedOffset = adaptor.getOffset(); | |
if (nonRemappedOffset.size() != 1) | |
return rewriter.notifyMatchFailure( | |
@@ -901,14 +922,22 @@ | |
matchAndRewrite_(arith::SelectOp selectOp, OneToNOpAdaptor adaptor, | |
ConversionPatternRewriter &rewriter) const override { | |
ArrayRef<ValueRange> remappedOperands = adaptor.getOperands(); | |
- if (remappedOperands[1].size() != 2 || remappedOperands[2].size() != 2) | |
- return rewriter.notifyMatchFailure( | |
- selectOp, "expected adaptor to have had both true and false operands " | |
- "already remapped"); | |
// If both have been traversed, then we can rewrite select of pointers as a | |
// select of base and offset | |
- ValueRange fatPtrTrue = remappedOperands[1]; | |
- ValueRange fatPtrFalse = remappedOperands[2]; | |
+ SmallVector<Value> fatPtrTrue = remappedOperands[1]; | |
+ if (fatPtrTrue.size() == 1) { | |
+ auto fatPtrOffset = | |
+ rewriter.create<arith::ConstantIntOp>(fatPtrTrue[0].getLoc(), 0, 64); | |
+ fatPtrTrue.push_back(fatPtrOffset); | |
+ fatPtrs[{fatPtrTrue[0], fatPtrOffset}].canNarrow = true; | |
+ } | |
+ SmallVector<Value> fatPtrFalse = remappedOperands[2]; | |
+ if (fatPtrFalse.size() == 1) { | |
+ auto fatPtrOffset = | |
+ rewriter.create<arith::ConstantIntOp>(fatPtrFalse[0].getLoc(), 0, 64); | |
+ fatPtrFalse.push_back(fatPtrOffset); | |
+ fatPtrs[{fatPtrFalse[0], fatPtrOffset}].canNarrow = true; | |
+ } | |
// Simple case of a scalar select: update the base pointer | |
if (!isa<RankedTensorType>(selectOp.getType())) { | |
auto newSelectOp = rewriter.create<arith::SelectOp>( | |
@@ -1046,10 +1075,13 @@ | |
LogicalResult | |
matchAndRewrite_(tt::LoadOp loadOp, OneToNOpAdaptor adaptor, | |
ConversionPatternRewriter &rewriter) const override { | |
- ValueRange fatPtr = adaptor.getPtr(); | |
- if (fatPtr.size() != 2) | |
- return rewriter.notifyMatchFailure( | |
- loadOp, "expected LoadOp ptr to have already been remapped"); | |
+ SmallVector<Value> fatPtr = adaptor.getPtr(); | |
+ if (fatPtr.size() == 1) { | |
+ auto fatPtrOffset = | |
+ rewriter.create<arith::ConstantIntOp>(fatPtr[0].getLoc(), 0, 64); | |
+ fatPtr.push_back(fatPtrOffset); | |
+ fatPtrs[{fatPtr[0], fatPtrOffset}].canNarrow = true; | |
+ } | |
Value fatPtrBase = fatPtr[0]; | |
Value fatPtrOffset = fatPtr[1]; | |
Location curLoc = loadOp.getLoc(); | |
@@ -1104,44 +1136,6 @@ | |
} | |
}; | |
-/// tt.func gets rewritten differently from all of the other ops - the op itself | |
-/// is not rewritten but all tt.ptr args are rewritten (all uses) to be | |
-/// %1 = unrealize_cast(%arg0: tt.ptr, c0: i32) -> tt.ptr. | |
-/// This unrealized_cast remains through out the first pass of the dialect | |
-/// conversion and is then materialized in the second pass | |
-/// (ConvertUnrealizedConversionCastOp). | |
-class ConvertFuncOp : public PointerCanonicalizationPattern<tt::FuncOp> { | |
-public: | |
- using PointerCanonicalizationPattern::PointerCanonicalizationPattern; | |
- | |
- LogicalResult | |
- matchAndRewrite_(tt::FuncOp funcOp, OneToNOpAdaptor adaptor, | |
- ConversionPatternRewriter &rewriter) const override { | |
- int64_t bitness = 64; | |
- rewriter.setInsertionPointToStart(&funcOp.getBody().front()); | |
- rewriter.modifyOpInPlace(funcOp, [&] { | |
- for (auto [idx, arg] : llvm::enumerate(funcOp.getArguments())) { | |
- // The pointer argument needs to be a scalar | |
- if (!isa<tt::PointerType>(arg.getType())) | |
- continue; | |
- if (auto pointerRangeAttr = | |
- funcOp.getArgAttrOfType<IntegerAttr>(idx, "tt.pointer_range")) | |
- bitness = pointerRangeAttr.getInt(); | |
- Value zeroOffset = | |
- rewriter.create<arith::ConstantIntOp>(funcOp.getLoc(), 0, bitness); | |
- auto dummyCast = rewriter.create<UnrealizedConversionCastOp>( | |
- arg.getLoc(), TypeRange{arg.getType()}, ValueRange{arg}); | |
- rewriter.replaceUsesOfBlockArgument(arg, dummyCast.getResult(0)); | |
- // TODO(max): why is this true? | |
- fatPtrs[{arg, zeroOffset}].canNarrow = true; | |
- rewriter.replaceOpWithMultiple(dummyCast, {{arg, zeroOffset}}); | |
- } | |
- }); | |
- | |
- return success(); | |
- } | |
-}; | |
- | |
/// No-op to make conversion framework happy. | |
class ConvertReturnOp : public PointerCanonicalizationPattern<tt::ReturnOp> { | |
public: | |
@@ -1305,7 +1299,6 @@ | |
tt::FuncOp func = funcOps[0]; | |
llvm::SetVector<Operation *> opsToRewrite; | |
- opsToRewrite.insert(func); | |
for (auto arg : func.getArguments()) { | |
if (llvm::isa<tt::PointerType>(arg.getType())) { | |
// NB: reusing the same SetVector invalidates the topo order implied by | |
@@ -1338,12 +1331,12 @@ | |
FatPointers fatPrs; | |
- patterns.add<ConvertFuncOp, ConvertBroadcastOp, ConvertSplatOp, | |
- ConvertAddPtrOp, ConvertLoadOp, ConvertStoreOp, ConvertSCFForOp, | |
- ConvertSCFYieldOp, ConvertSCFIfOp, ConvertSCFConditionOp, | |
- ConvertSCFWhileOp, ConvertCFCondBranch, ConvertCFBranch, | |
- ConvertArithSelectOp, ConvertReturnOp>(patterns.getContext(), | |
- opsToRewrite, fatPrs); | |
+ patterns | |
+ .add<ConvertBroadcastOp, ConvertSplatOp, ConvertAddPtrOp, ConvertLoadOp, | |
+ ConvertStoreOp, ConvertSCFForOp, ConvertSCFYieldOp, ConvertSCFIfOp, | |
+ ConvertSCFConditionOp, ConvertSCFWhileOp, ConvertCFCondBranch, | |
+ ConvertCFBranch, ConvertArithSelectOp, ConvertReturnOp>( | |
+ patterns.getContext(), opsToRewrite, fatPrs); | |
ConversionConfig config; | |
config.buildMaterializations = false; | |
if (failed(applyPartialConversion(func, target, std::move(patterns), config))) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment