Skip to content

Instantly share code, notes, and snippets.

@makslevental
Created January 16, 2025 15:55
Show Gist options
  • Save makslevental/49da1d17597950933498288bb1dcd8c6 to your computer and use it in GitHub Desktop.
Index: third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp
--- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp (revision f5e11cc519ea223de2c74c3c14a32f0beb327fae)
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp (date 1737041209117)
@@ -451,6 +451,19 @@
llvm::SetVector<Operation *> &opToRewrite;
};
+static SmallVector<Value>
+maybeCreateInitialFatPtr(SmallVector<Value> ptr,
+ ConversionPatternRewriter &rewriter,
+ FatPointers &fatPtrs) {
+ if (ptr.size() == 1) {
+ auto fatPtrOffset =
+ rewriter.create<arith::ConstantIntOp>(ptr[0].getLoc(), 0, 64);
+ ptr.push_back(fatPtrOffset);
+ fatPtrs[{ptr[0], fatPtrOffset}].canNarrow = true;
+ }
+ return {ptr[0], ptr[1]};
+}
+
/// splat integer offset, keep base
class ConvertSplatOp : public PointerCanonicalizationPattern<tt::SplatOp> {
public:
@@ -459,12 +472,16 @@
LogicalResult
matchAndRewrite_(tt::SplatOp splatOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- ValueRange remappedOperands = adaptor.getSrc();
- if (remappedOperands.size() != 2)
- return rewriter.notifyMatchFailure(
- splatOp, "expected SplatOp src to have already been remapped");
- Value fatPtrBase = remappedOperands[0];
- Value fatPtrOffset = remappedOperands[1];
+ ValueRange remappedSrc = adaptor.getSrc();
+ Value fatPtrBase = remappedSrc[0];
+ Value fatPtrOffset;
+ if (remappedSrc.size() == 2) {
+ fatPtrOffset = remappedSrc[1];
+ } else {
+ fatPtrOffset =
+ rewriter.create<arith::ConstantIntOp>(fatPtrBase.getLoc(), 0, 64);
+ fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow = true;
+ }
if (!llvm::isa<tt::PointerType>(fatPtrBase.getType()))
return rewriter.notifyMatchFailure(splatOp,
"non tt.ptr base unimplemented");
@@ -535,10 +552,14 @@
LogicalResult
matchAndRewrite_(tt::AddPtrOp addPtrOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- ValueRange remappedPtr = adaptor.getPtr();
- if (remappedPtr.size() != 2)
- return rewriter.notifyMatchFailure(
- addPtrOp, "expected AddPtrOp Ptr to have already been remapped");
+ SmallVector<Value> remappedPtr = adaptor.getPtr();
+ if (remappedPtr.size() == 1) {
+ auto fatPtrOffset =
+ rewriter.create<arith::ConstantIntOp>(remappedPtr[0].getLoc(), 0, 64);
+ remappedPtr.push_back(fatPtrOffset);
+ fatPtrs[{remappedPtr[0], fatPtrOffset}].canNarrow = true;
+ }
+
ValueRange nonRemappedOffset = adaptor.getOffset();
if (nonRemappedOffset.size() != 1)
return rewriter.notifyMatchFailure(
@@ -901,14 +922,22 @@
matchAndRewrite_(arith::SelectOp selectOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ArrayRef<ValueRange> remappedOperands = adaptor.getOperands();
- if (remappedOperands[1].size() != 2 || remappedOperands[2].size() != 2)
- return rewriter.notifyMatchFailure(
- selectOp, "expected adaptor to have had both true and false operands "
- "already remapped");
// If both have been traversed, then we can rewrite select of pointers as a
// select of base and offset
- ValueRange fatPtrTrue = remappedOperands[1];
- ValueRange fatPtrFalse = remappedOperands[2];
+ SmallVector<Value> fatPtrTrue = remappedOperands[1];
+ if (fatPtrTrue.size() == 1) {
+ auto fatPtrOffset =
+ rewriter.create<arith::ConstantIntOp>(fatPtrTrue[0].getLoc(), 0, 64);
+ fatPtrTrue.push_back(fatPtrOffset);
+ fatPtrs[{fatPtrTrue[0], fatPtrOffset}].canNarrow = true;
+ }
+ SmallVector<Value> fatPtrFalse = remappedOperands[2];
+ if (fatPtrFalse.size() == 1) {
+ auto fatPtrOffset =
+ rewriter.create<arith::ConstantIntOp>(fatPtrFalse[0].getLoc(), 0, 64);
+ fatPtrFalse.push_back(fatPtrOffset);
+ fatPtrs[{fatPtrFalse[0], fatPtrOffset}].canNarrow = true;
+ }
// Simple case of a scalar select: update the base pointer
if (!isa<RankedTensorType>(selectOp.getType())) {
auto newSelectOp = rewriter.create<arith::SelectOp>(
@@ -1046,10 +1075,13 @@
LogicalResult
matchAndRewrite_(tt::LoadOp loadOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- ValueRange fatPtr = adaptor.getPtr();
- if (fatPtr.size() != 2)
- return rewriter.notifyMatchFailure(
- loadOp, "expected LoadOp ptr to have already been remapped");
+ SmallVector<Value> fatPtr = adaptor.getPtr();
+ if (fatPtr.size() == 1) {
+ auto fatPtrOffset =
+ rewriter.create<arith::ConstantIntOp>(fatPtr[0].getLoc(), 0, 64);
+ fatPtr.push_back(fatPtrOffset);
+ fatPtrs[{fatPtr[0], fatPtrOffset}].canNarrow = true;
+ }
Value fatPtrBase = fatPtr[0];
Value fatPtrOffset = fatPtr[1];
Location curLoc = loadOp.getLoc();
@@ -1104,44 +1136,6 @@
}
};
-/// tt.func gets rewritten differently from all of the other ops - the op itself
-/// is not rewritten but all tt.ptr args are rewritten (all uses) to be
-/// %1 = unrealize_cast(%arg0: tt.ptr, c0: i32) -> tt.ptr.
-/// This unrealized_cast remains through out the first pass of the dialect
-/// conversion and is then materialized in the second pass
-/// (ConvertUnrealizedConversionCastOp).
-class ConvertFuncOp : public PointerCanonicalizationPattern<tt::FuncOp> {
-public:
- using PointerCanonicalizationPattern::PointerCanonicalizationPattern;
-
- LogicalResult
- matchAndRewrite_(tt::FuncOp funcOp, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- int64_t bitness = 64;
- rewriter.setInsertionPointToStart(&funcOp.getBody().front());
- rewriter.modifyOpInPlace(funcOp, [&] {
- for (auto [idx, arg] : llvm::enumerate(funcOp.getArguments())) {
- // The pointer argument needs to be a scalar
- if (!isa<tt::PointerType>(arg.getType()))
- continue;
- if (auto pointerRangeAttr =
- funcOp.getArgAttrOfType<IntegerAttr>(idx, "tt.pointer_range"))
- bitness = pointerRangeAttr.getInt();
- Value zeroOffset =
- rewriter.create<arith::ConstantIntOp>(funcOp.getLoc(), 0, bitness);
- auto dummyCast = rewriter.create<UnrealizedConversionCastOp>(
- arg.getLoc(), TypeRange{arg.getType()}, ValueRange{arg});
- rewriter.replaceUsesOfBlockArgument(arg, dummyCast.getResult(0));
- // TODO(max): why is this true?
- fatPtrs[{arg, zeroOffset}].canNarrow = true;
- rewriter.replaceOpWithMultiple(dummyCast, {{arg, zeroOffset}});
- }
- });
-
- return success();
- }
-};
-
/// No-op to make conversion framework happy.
class ConvertReturnOp : public PointerCanonicalizationPattern<tt::ReturnOp> {
public:
@@ -1305,7 +1299,6 @@
tt::FuncOp func = funcOps[0];
llvm::SetVector<Operation *> opsToRewrite;
- opsToRewrite.insert(func);
for (auto arg : func.getArguments()) {
if (llvm::isa<tt::PointerType>(arg.getType())) {
// NB: reusing the same SetVector invalidates the topo order implied by
@@ -1338,12 +1331,12 @@
FatPointers fatPrs;
- patterns.add<ConvertFuncOp, ConvertBroadcastOp, ConvertSplatOp,
- ConvertAddPtrOp, ConvertLoadOp, ConvertStoreOp, ConvertSCFForOp,
- ConvertSCFYieldOp, ConvertSCFIfOp, ConvertSCFConditionOp,
- ConvertSCFWhileOp, ConvertCFCondBranch, ConvertCFBranch,
- ConvertArithSelectOp, ConvertReturnOp>(patterns.getContext(),
- opsToRewrite, fatPrs);
+ patterns
+ .add<ConvertBroadcastOp, ConvertSplatOp, ConvertAddPtrOp, ConvertLoadOp,
+ ConvertStoreOp, ConvertSCFForOp, ConvertSCFYieldOp, ConvertSCFIfOp,
+ ConvertSCFConditionOp, ConvertSCFWhileOp, ConvertCFCondBranch,
+ ConvertCFBranch, ConvertArithSelectOp, ConvertReturnOp>(
+ patterns.getContext(), opsToRewrite, fatPrs);
ConversionConfig config;
config.buildMaterializations = false;
if (failed(applyPartialConversion(func, target, std::move(patterns), config)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment