Created
April 15, 2020 16:11
-
-
Save jroelofs/b6294975d3fe5dc0e7b61910cea883d7 to your computer and use it in GitHub Desktop.
Matrix-Matrix multiply chain re-association for MLIR's Toy example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
commit 8ad8f04b2c4abec358d69d45870e80f50e3ccc22 | |
Author: Jon Roelofs <[email protected]> | |
Date: Tue Apr 14 16:23:13 2020 -0600 | |
WIP: reassociate matrix-matrix multiply chains | |
diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp | |
index fafc3876db2..ee9a1205845 100644 | |
--- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp | |
+++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp | |
@@ -1,92 +1,229 @@ | |
//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// | |
// | |
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |
// See https://llvm.org/LICENSE.txt for license information. | |
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |
// | |
//===----------------------------------------------------------------------===// | |
// | |
// This file implements a set of simple combiners for optimizing operations in | |
// the Toy dialect. | |
// | |
//===----------------------------------------------------------------------===// | |
#include "mlir/IR/Matchers.h" | |
#include "mlir/IR/PatternMatch.h" | |
#include "toy/Dialect.h" | |
+#include "llvm/Support/Debug.h" | |
#include <numeric> | |
using namespace mlir; | |
using namespace toy; | |
namespace { | |
/// Include the patterns defined in the Declarative Rewrite framework. | |
#include "ToyCombine.inc" | |
} // end anonymous namespace | |
/// Fold simple cast operations that return the same type as the input.
///
/// Delegates to MLIR's generic cast folder, which replaces the cast with its
/// operand when the result type matches the input type.
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return mlir::impl::foldCastOp(*this);
}
/// Fold constants.
///
/// toy.constant carries its payload in the "value" attribute, so folding is
/// simply returning that attribute.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
/// Fold struct constants.
///
/// Like toy.constant, the struct constant's payload lives in its "value"
/// attribute; folding returns it directly.
OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
  return value();
}
/// Fold simple struct access operations that access into a constant.
///
/// When the struct operand is itself a constant (materialized by the folding
/// framework as an ArrayAttr), the access collapses to the selected element.
/// Otherwise, returning nullptr signals "no fold".
OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
  if (auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>())
    return structAttr[index().getZExtValue()];
  return nullptr;
}
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
  /// We register this pattern to match every toy.transpose in the IR.
  /// The "benefit" is used by the framework to order the patterns and process
  /// them in order of profitability.
  SimplifyRedundantTranspose(mlir::MLIRContext *context)
      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}

  /// This method attempts to match a pattern and rewrite it. The rewriter
  /// argument is the orchestrator of the sequence of rewrites. The pattern is
  /// expected to interact with it to perform any changes to the IR from here.
  mlir::LogicalResult
  matchAndRewrite(TransposeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Is this transpose's operand itself produced by another transpose?
    // (getDefiningOp() is null for block arguments, hence dyn_cast_or_null.)
    auto innerTranspose = llvm::dyn_cast_or_null<TransposeOp>(
        op.getOperand().getDefiningOp());
    if (!innerTranspose)
      return failure();

    // transpose(transpose(x)) == x: forward the inner transpose's operand.
    rewriter.replaceOp(op, {innerTranspose.getOperand()});
    return success();
  }
};
/// Register our patterns as "canonicalization" patterns on the TransposeOp so
/// that they can be picked up by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<SimplifyRedundantTranspose>(context);
}
+static ArrayRef<int64_t> getShape(Value v) { | |
+ return v.getType().cast<RankedTensorType>().getShape(); | |
+} | |
+ | |
+struct ReassociateMatMul : public mlir::OpRewritePattern<MulOp> { | |
+ /// Operations to ignore when this pattern is asked to analyze operations it | |
+ /// has already re-associated. Kind of a hack, since it has to be mutable to | |
+ /// work around the const-ness of matchAndRewrite. | |
+ mutable llvm::SmallDenseSet<Value, 16> ignore; | |
+ | |
+ ReassociateMatMul(mlir::MLIRContext *context) | |
+ : OpRewritePattern<MulOp>(context, /*benefit=*/1) {} | |
+ | |
+ mlir::LogicalResult | |
+ matchAndRewrite(MulOp op, mlir::PatternRewriter &rewriter) const override { | |
+ SmallVector<Value, 4> chain; | |
+ | |
+ // pre-order walk of the entire tree of multiplies. | |
+ std::function<void(Value)> walk = [&chain, &walk, this](Value v) -> void { | |
+ // Don't bother with bits of the tree that already have some re-use, since | |
+ // that significantly complicates this analysis (and it's an open research | |
+ // problem to solve that). | |
+ if (!v.hasOneUse() || !v.getDefiningOp() || ignore.count(v)) { | |
+ chain.push_back(v); | |
+ return; | |
+ } | |
+ | |
+ if (auto mul = dyn_cast<MulOp>(v.getDefiningOp())) { | |
+ walk(mul.getOperand(0)); | |
+ walk(mul.getOperand(1)); | |
+ return; | |
+ } | |
+ | |
+ chain.push_back(v); | |
+ }; | |
+ | |
+ walk(op); | |
+ | |
+ if (chain.size() < 3) | |
+ return rewriter.notifyMatchFailure( | |
+ op, "Too few MulOp's in chain for reassociation."); | |
+ | |
+ DEBUG_WITH_TYPE("reassociate-matmul", { | |
+ llvm::outs() << "Chain: [\n "; | |
+ mlir::interleave(chain, llvm::outs(), "\n "); | |
+ llvm::outs() << "\n]\n"; | |
+ }); | |
+ | |
+ // Only handle 2d matrix multiplies for now. | |
+ if (!llvm::all_of(chain, [](Value v) { | |
+ auto ty = v.getType().dyn_cast<RankedTensorType>(); | |
+ if (!ty) | |
+ return false; | |
+ return ty.getShape().size() == 2; | |
+ })) | |
+ return rewriter.notifyMatchFailure( | |
+ op, "One or more operations in the chain weren't traditional " | |
+ "matrix-matrix multiplications."); | |
+ | |
+ // Collect all the dimensions such that the matrix at chain[i] has dimension | |
+ // dims[i] x dims[i+1] for i in 0..N. | |
+ SmallVector<int64_t, 4> dims; | |
+ dims.reserve(chain.size() + 1); | |
+ | |
+ dims.push_back(getShape(chain[0])[0]); | |
+ for (const Value &v : chain) | |
+ dims.push_back(getShape(v)[1]); | |
+ | |
+ const int N = dims.size() - 1; | |
+ | |
+ // mults[i * N + j] = Minimum number of multiplications needed to compute | |
+ // the matrix: chain[i]*chain[i+i]...chain[j] = chain[i..j] | |
+ SmallVector<int, 4 * 4> mults(N * N, 0); | |
+ auto M = [&mults, N](int Y, int X) -> int & { return mults[Y * N + X]; }; | |
+ | |
+ // Index of the split that achieved minimal cost. | |
+ SmallVector<int, 4 * 4> seq(N * N, 0); | |
+ auto S = [&seq, N](int Y, int X) -> int & { return seq[Y * N + X]; }; | |
+ | |
+ // Find the optimal re-association with dynamic programming. | |
+ for (int len = 1; len < N; ++len) { | |
+ for (int i = 0; i < N - len; ++i) { | |
+ int j = i + len; | |
+ M(i, j) = std::numeric_limits<int>::max(); | |
+ | |
+ // Find the best sub-sequence split point k to insert parens: | |
+ // (chain[i]*chain[i+1]*...chain[k]) * (chain[k+1]*...chain[j]) | |
+ for (int k = i; k < j; ++k) { | |
+ int64_t cost = | |
+ M(i, k) + M(k + 1, j) + dims[i] * dims[k + 1] * dims[j + 1]; | |
+ if (cost < M(i, j)) { | |
+ M(i, j) = cost; | |
+ S(i, j) = k; | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+#ifndef NDEBUG | |
+ std::function<std::string(int, int)> toString = [&](int i, | |
+ int j) -> std::string { | |
+ if (i == j) { | |
+ // TODO: would be nicer to just print the SSA name of them, but there | |
+ // doesn't seem to be a convenient API that does that. *shrug* | |
+ return ("C" + Twine(i)).str(); | |
+ } | |
+ return "(" + toString(i, S(i, j)) + "*" + toString(S(i, j) + 1, j) + ")"; | |
+ }; | |
+ | |
+ DEBUG_WITH_TYPE("reassociate-matmul", | |
+ llvm::dbgs() << "optimal: " << toString(0, N - 1) << "\n"); | |
+#endif | |
+ | |
+ std::function<Value(int, int)> emitOptimalChain = | |
+ [&, this](int i, int j) -> Value { | |
+ if (i == j) | |
+ return chain[i]; | |
+ | |
+ Value lhs = emitOptimalChain(i, S(i, j)); | |
+ Value rhs = emitOptimalChain(S(i, j) + 1, j); | |
+ MulOp mul = rewriter.create<MulOp>(op.getLoc(), lhs, rhs); | |
+ mul.inferShapes(); | |
+ | |
+ // Remind ourselves for later that this chain has already been | |
+ // re-associated. | |
+ ignore.insert(mul.getResult()); | |
+ | |
+ return mul.getResult(); | |
+ }; | |
+ | |
+ Value res = emitOptimalChain(0, N - 1); | |
+ rewriter.replaceOp(op, {dyn_cast<MulOp>(res.getDefiningOp())}); | |
+ return success(); | |
+ } | |
+}; | |
+ | |
/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
/// that they can be picked up by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
                 FoldConstantReshapeOptPattern, ReassociateMatMul>(context);
}
\ No newline at end of file |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment