@bjacob
Created January 30, 2021 05:03
// I'm trying to generate a simple loop nest performing matrix multiplication,
// while staying at the Tensor level.
// Everything worked well while this code was working at the memref level.
// But now I don't know how to end this rewrite pass:
// I can't eraseOp because it says that the old op's result is still used.
// If I don't do anything at the end, my pass has no effect, presumably
// because I haven't actually replaced the old op by the new loop nest.
// So I need something like replaceOp at the end of this function, but I don't
// know how to pass to it the new loop nest that we just constructed.
// I don't understand the notion of "result" of a loop nest. I've tried
// getResult() on the loop nest but apparently it returns empty.
// Also I wasn't sure if my loop bodies are supposed to be returning something.
// See how the innermost loop body has a return statement and not the outer loops.
// Obviously I have NO idea what I'm doing :-)
class MehMatmulToSCFPattern : public OpRewritePattern<meh::MatmulOp> {
public:
  using OpRewritePattern<meh::MatmulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(meh::MatmulOp op,
                                PatternRewriter &rewriter) const override {
    auto lhsVal = op.lhs();
    auto rhsVal = op.rhs();
    auto dstVal = op.dst();
    auto dstValType = dstVal.getType().cast<ShapedType>();
    auto dstShape = dstValType.getShape(); // ArrayRef<int64_t>
    auto M = dstShape[0];
    auto N = dstShape[1];
    // K is the reduction dimension, read from the RHS operand (assumed K x N).
    auto rhsValType = rhsVal.getType().cast<ShapedType>();
    auto rhsShape = rhsValType.getShape();
    edsc::ScopedContext scope(rewriter, op.getLoc());
    auto K = rhsShape[0];
    Value zero = edsc::intrinsics::std_constant_index(0);
    Value one = edsc::intrinsics::std_constant_index(1);
    Value boundM = edsc::intrinsics::std_constant_index(M);
    Value boundN = edsc::intrinsics::std_constant_index(N);
    Value boundK = edsc::intrinsics::std_constant_index(K);
    edsc::loopNestBuilder(zero, boundM, one, [&](Value m) {
      edsc::loopNestBuilder(zero, boundN, one, [&](Value n) {
        edsc::loopNestBuilder(zero, boundK, one, [&](Value k) {
          // 1x1 slice of the LHS at (m, k).
          Value lhs_entry = rewriter.create<SubTensorOp>(
              op.getLoc(), lhsVal,
              ValueRange{m, k},
              ValueRange{one, one},
              ValueRange{one, one});
          // 1x1 slice of the RHS at (k, n).
          Value rhs_entry = rewriter.create<SubTensorOp>(
              op.getLoc(), rhsVal,
              ValueRange{k, n},
              ValueRange{one, one},
              ValueRange{one, one});
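          // TODO: the output operand and result type look wrong here; they
          // should presumably be the 1x1 slice of dstVal at (m, n) and its
          // type, rather than lhs_entry and the full dstVal type.
          auto product = rewriter.create<linalg::MatmulOp>(
              op.getLoc(), TypeRange{dstVal.getType()},
              ValueRange{lhs_entry, rhs_entry}, lhs_entry);
          return rewriter.create<SubTensorInsertOp>(
              op.getLoc(), product.getResults()[0], dstVal,
              ValueRange{m, n},
              ValueRange{one, one},
              ValueRange{one, one});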
        });
      });
    });
    rewriter.replaceOp(op, /* How do I replace 'op' by the above loop nest? */);
    return success();
  }
};

struct ConvertMehMatmulToSCFPass
    : public PassWrapper<ConvertMehMatmulToSCFPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<meh::MehDialect, scf::SCFDialect>();
  }
  void runOnFunction() override;
};