Skip to content

Instantly share code, notes, and snippets.

@jroelofs
Created April 15, 2020 16:11
Show Gist options
  • Save jroelofs/b6294975d3fe5dc0e7b61910cea883d7 to your computer and use it in GitHub Desktop.
Save jroelofs/b6294975d3fe5dc0e7b61910cea883d7 to your computer and use it in GitHub Desktop.
Matrix-Matrix multiply chain re-association for MLIR's Toy example
commit 8ad8f04b2c4abec358d69d45870e80f50e3ccc22
Author: Jon Roelofs <[email protected]>
Date: Tue Apr 14 16:23:13 2020 -0600
WIP: reassociate matrix-matrix multiply chains
diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
index fafc3876db2..ee9a1205845 100644
--- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
@@ -1,92 +1,229 @@
//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a set of simple combiners for optimizing operations in
// the Toy dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "toy/Dialect.h"
+#include "llvm/Support/Debug.h"
#include <numeric>
using namespace mlir;
using namespace toy;
namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "ToyCombine.inc"
} // end anonymous namespace
/// Fold away casts whose result type is the same as the operand's type.
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  // Delegate to MLIR's generic cast-folding helper.
  return mlir::impl::foldCastOp(*this);
}
/// A constant folds to the attribute that holds its value.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
  return value();
}
/// A struct constant likewise folds to its value attribute.
OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
/// Fold an access into a constant struct down to the accessed element.
OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
  // The fold only applies when the struct operand folded to a constant,
  // in which case it is presented here as an ArrayAttr of its elements.
  auto elements = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
  if (!elements)
    return nullptr;
  return elements[index().getZExtValue()];
}
/// C++ rewrite pattern that collapses a pair of back-to-back transposes:
/// transpose(transpose(x)) is replaced by x itself.
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
  /// Match every toy.transpose in the IR. The "benefit" value lets the
  /// framework order competing patterns by profitability.
  SimplifyRedundantTranspose(mlir::MLIRContext *context)
      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}

  /// Attempt the match and, on success, perform the rewrite through the
  /// supplied rewriter, which orchestrates all IR mutation.
  mlir::LogicalResult
  matchAndRewrite(TransposeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Only match when our operand is itself produced by a transpose.
    auto innerTranspose =
        llvm::dyn_cast_or_null<TransposeOp>(op.getOperand().getDefiningOp());
    if (!innerTranspose)
      return failure();

    // transpose(transpose(x)) == x: forward the inner transpose's operand.
    rewriter.replaceOp(op, {innerTranspose.getOperand()});
    return success();
  }
};
/// Hook our transpose patterns into the canonicalization framework so the
/// canonicalizer pass picks them up.
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<SimplifyRedundantTranspose>(context);
}
+/// Return the dimensions of `v`. The cast is unchecked: callers must have
+/// already verified that `v` has a RankedTensorType (as the chain filter in
+/// ReassociateMatMul does).
+static ArrayRef<int64_t> getShape(Value v) {
+ return v.getType().cast<RankedTensorType>().getShape();
+}
+
+/// Rewrite pattern that re-parenthesizes a chain of matrix multiplies
+/// (A*B*C*...) into the association that minimizes the total number of
+/// scalar multiplications, via the classic matrix-chain-order dynamic
+/// program.
+struct ReassociateMatMul : public mlir::OpRewritePattern<MulOp> {
+ /// Operations to ignore when this pattern is asked to analyze operations it
+ /// has already re-associated. Kind of a hack, since it has to be mutable to
+ /// work around the const-ness of matchAndRewrite.
+ mutable llvm::SmallDenseSet<Value, 16> ignore;
+
+ ReassociateMatMul(mlir::MLIRContext *context)
+ : OpRewritePattern<MulOp>(context, /*benefit=*/1) {}
+
+ /// Collect the leaves of a multiply tree rooted at `op`, solve the
+ /// matrix-chain-order DP over their dimensions, and emit the optimal
+ /// association in place of the original chain.
+ mlir::LogicalResult
+ matchAndRewrite(MulOp op, mlir::PatternRewriter &rewriter) const override {
+ // Leaves of the multiply tree, in left-to-right order.
+ SmallVector<Value, 4> chain;
+
+ // pre-order walk of the entire tree of multiplies.
+ std::function<void(Value)> walk = [&chain, &walk, this](Value v) -> void {
+ // Don't bother with bits of the tree that already have some re-use, since
+ // that significantly complicates this analysis (and it's an open research
+ // problem to solve that).
+ // NOTE(review): this bail-out also applies to the root multiply itself,
+ // so a root whose result has multiple uses collapses the whole chain to
+ // a single leaf and never matches — confirm whether that is intended.
+ if (!v.hasOneUse() || !v.getDefiningOp() || ignore.count(v)) {
+ chain.push_back(v);
+ return;
+ }
+
+ // Interior multiply: recurse into both operands, left first.
+ if (auto mul = dyn_cast<MulOp>(v.getDefiningOp())) {
+ walk(mul.getOperand(0));
+ walk(mul.getOperand(1));
+ return;
+ }
+
+ // Any non-multiply producer is a leaf of the chain.
+ chain.push_back(v);
+ };
+
+ walk(op);
+
+ // With fewer than three matrices there is only one association.
+ if (chain.size() < 3)
+ return rewriter.notifyMatchFailure(
+ op, "Too few MulOp's in chain for reassociation.");
+
+ DEBUG_WITH_TYPE("reassociate-matmul", {
+ llvm::outs() << "Chain: [\n ";
+ mlir::interleave(chain, llvm::outs(), "\n ");
+ llvm::outs() << "\n]\n";
+ });
+
+ // Only handle 2d matrix multiplies for now.
+ if (!llvm::all_of(chain, [](Value v) {
+ auto ty = v.getType().dyn_cast<RankedTensorType>();
+ if (!ty)
+ return false;
+ return ty.getShape().size() == 2;
+ }))
+ return rewriter.notifyMatchFailure(
+ op, "One or more operations in the chain weren't traditional "
+ "matrix-matrix multiplications.");
+
+ // Collect all the dimensions such that the matrix at chain[i] has dimension
+ // dims[i] x dims[i+1] for i in 0..N.
+ // NOTE(review): this only reads the shapes; conformability of adjacent
+ // matrices (cols(chain[i]) == rows(chain[i+1])) is assumed, not checked.
+ SmallVector<int64_t, 4> dims;
+ dims.reserve(chain.size() + 1);
+
+ dims.push_back(getShape(chain[0])[0]);
+ for (const Value &v : chain)
+ dims.push_back(getShape(v)[1]);
+
+ const int N = dims.size() - 1;
+
+ // mults[i * N + j] = Minimum number of multiplications needed to compute
+ // the matrix: chain[i]*chain[i+1]...chain[j] = chain[i..j]
+ // NOTE(review): `cost` below is computed as int64_t but stored into an
+ // int cell here — narrowing that can overflow for large dimensions;
+ // consider making the table int64_t.
+ SmallVector<int, 4 * 4> mults(N * N, 0);
+ auto M = [&mults, N](int Y, int X) -> int & { return mults[Y * N + X]; };
+
+ // Index of the split that achieved minimal cost.
+ SmallVector<int, 4 * 4> seq(N * N, 0);
+ auto S = [&seq, N](int Y, int X) -> int & { return seq[Y * N + X]; };
+
+ // Find the optimal re-association with dynamic programming.
+ // Bottom-up over subchain length; len == 0 entries stay 0 (a single
+ // matrix costs nothing to "compute").
+ for (int len = 1; len < N; ++len) {
+ for (int i = 0; i < N - len; ++i) {
+ int j = i + len;
+ M(i, j) = std::numeric_limits<int>::max();
+
+ // Find the best sub-sequence split point k to insert parens:
+ // (chain[i]*chain[i+1]*...chain[k]) * (chain[k+1]*...chain[j])
+ for (int k = i; k < j; ++k) {
+ int64_t cost =
+ M(i, k) + M(k + 1, j) + dims[i] * dims[k + 1] * dims[j + 1];
+ if (cost < M(i, j)) {
+ M(i, j) = cost;
+ S(i, j) = k;
+ }
+ }
+ }
+ }
+
+#ifndef NDEBUG
+ // Pretty-print the chosen association, e.g. "((C0*C1)*C2)", for debugging.
+ std::function<std::string(int, int)> toString = [&](int i,
+ int j) -> std::string {
+ if (i == j) {
+ // TODO: would be nicer to just print the SSA name of them, but there
+ // doesn't seem to be a convenient API that does that. *shrug*
+ return ("C" + Twine(i)).str();
+ }
+ return "(" + toString(i, S(i, j)) + "*" + toString(S(i, j) + 1, j) + ")";
+ };
+
+ DEBUG_WITH_TYPE("reassociate-matmul",
+ llvm::dbgs() << "optimal: " << toString(0, N - 1) << "\n");
+#endif
+
+ // Recursively materialize the optimal association recorded in S, creating
+ // fresh MulOps for the interior nodes.
+ std::function<Value(int, int)> emitOptimalChain =
+ [&, this](int i, int j) -> Value {
+ if (i == j)
+ return chain[i];
+
+ Value lhs = emitOptimalChain(i, S(i, j));
+ Value rhs = emitOptimalChain(S(i, j) + 1, j);
+ MulOp mul = rewriter.create<MulOp>(op.getLoc(), lhs, rhs);
+ mul.inferShapes();
+
+ // Remind ourselves for later that this chain has already been
+ // re-associated.
+ ignore.insert(mul.getResult());
+
+ return mul.getResult();
+ };
+
+ Value res = emitOptimalChain(0, N - 1);
+ rewriter.replaceOp(op, {dyn_cast<MulOp>(res.getDefiningOp())});
+ return success();
+ }
+};
+
/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
/// that they can be picked up by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
- FoldConstantReshapeOptPattern>(context);
-}
+ // ReassociateMatMul is registered here too so matmul chains get
+ // re-associated whenever the canonicalizer runs.
+ FoldConstantReshapeOptPattern, ReassociateMatMul>(context);
+}
\ No newline at end of file
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment