Skip to content

Instantly share code, notes, and snippets.

@bjacob
Created January 28, 2021 16:57
Show Gist options
  • Save bjacob/2ddd75ec3998a771a3852c3ca4233fd5 to your computer and use it in GitHub Desktop.
Save bjacob/2ddd75ec3998a771a3852c3ca4233fd5 to your computer and use it in GitHub Desktop.
// MehMatmulToSCFPattern lowers meh::MatmulOp to an SCF loop nest.
//
// meh::MatmulOp takes two Tensors as arguments, lhs() and rhs(), and returns a
// Tensor, dst(). The lowering was originally written against memrefs, where
// dst was a buffer that std.load/std.store could update in place. Now that
// meh::MatmulOp deals with tensors, that no longer works:
//   * std.load/std.store only accept memref operands (this is the cast
//     assertion that fired on the first std_load);
//   * dst() is the op's own result, so it cannot be read or written while the
//     op is being lowered, and the op cannot simply be erased while its result
//     still has uses.
//
// The pattern therefore bufferizes locally:
//   1. lhs/rhs get read-only memref views via std.tensor_to_memref;
//   2. a fresh dst buffer is created with std.alloc;
//   3. for every (m, n), dst[m][n] is set to 0 and then accumulates
//      lhs[m][k] * rhs[k][n] over k;
//   4. the buffer is converted back with std.tensor_load and replaces all uses
//      of the op's result.
//
// Only rank-2, statically shaped tensors with a common floating-point element
// type are handled (loop bounds are emitted as constants, and std.mulf/addf
// need floats); anything else is reported as a match failure.
// NOTE(review): the std.alloc'ed buffer is not freed here; run the
// buffer-deallocation pass afterwards (or add a std.dealloc once the
// tensor_load result is dead).
class MehMatmulToSCFPattern : public OpRewritePattern<meh::MatmulOp> {
public:
  using OpRewritePattern<meh::MatmulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(meh::MatmulOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhsTensor = op.lhs();
    Value rhsTensor = op.rhs();

    auto lhsType = lhsTensor.getType().dyn_cast<RankedTensorType>();
    auto rhsType = rhsTensor.getType().dyn_cast<RankedTensorType>();
    auto dstType = op.dst().getType().dyn_cast<RankedTensorType>();
    if (!lhsType || !rhsType || !dstType)
      return rewriter.notifyMatchFailure(op, "expected ranked tensor types");
    if (lhsType.getRank() != 2 || rhsType.getRank() != 2 ||
        dstType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "expected rank-2 tensors");
    // Loop bounds below are materialized as constants, so dynamic dims
    // (reported as -1) cannot be lowered by this pattern.
    if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape() ||
        !dstType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "expected static shapes");
    Type elementType = dstType.getElementType();
    if (!elementType.isa<FloatType>() ||
        lhsType.getElementType() != elementType ||
        rhsType.getElementType() != elementType)
      return rewriter.notifyMatchFailure(
          op, "expected a common floating-point element type");

    // dst is MxN and rhs is KxN, so the reduction extent K comes from rhs.
    // (The previous version read it from dst's type, i.e. used M.)
    int64_t M = dstType.getDimSize(0);
    int64_t N = dstType.getDimSize(1);
    int64_t K = rhsType.getDimSize(0);
    if (lhsType.getDimSize(0) != M || lhsType.getDimSize(1) != K ||
        rhsType.getDimSize(1) != N)
      return rewriter.notifyMatchFailure(op, "incompatible matmul shapes");

    auto toMemRefType = [](RankedTensorType type) {
      return MemRefType::get(type.getShape(), type.getElementType());
    };

    edsc::ScopedContext scope(rewriter, loc);

    // Read-only memref views of the input tensors; they are never stored to.
    Value lhsBuf = rewriter.create<TensorToMemrefOp>(
        loc, toMemRefType(lhsType), lhsTensor);
    Value rhsBuf = rewriter.create<TensorToMemrefOp>(
        loc, toMemRefType(rhsType), rhsTensor);
    // Fresh buffer that receives the product.
    Value dstBuf = rewriter.create<AllocOp>(loc, toMemRefType(dstType));

    Value zeroIndex = edsc::intrinsics::std_constant_index(0);
    Value step = edsc::intrinsics::std_constant_index(1);
    Value boundM = edsc::intrinsics::std_constant_index(M);
    Value boundN = edsc::intrinsics::std_constant_index(N);
    Value boundK = edsc::intrinsics::std_constant_index(K);
    Value zeroElement =
        rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elementType));

    edsc::loopNestBuilder(zeroIndex, boundM, step, [&](Value m) {
      edsc::loopNestBuilder(zeroIndex, boundN, step, [&](Value n) {
        // The tensor result is lhs * rhs (not an accumulation into an
        // existing buffer), so each output element starts from zero.
        edsc::intrinsics::std_store(zeroElement, dstBuf,
                                    ArrayRef<Value>{m, n});
        edsc::loopNestBuilder(zeroIndex, boundK, step, [&](Value k) {
          Value lhsElt =
              edsc::intrinsics::std_load(lhsBuf, ArrayRef<Value>{m, k});
          Value rhsElt =
              edsc::intrinsics::std_load(rhsBuf, ArrayRef<Value>{k, n});
          Value acc =
              edsc::intrinsics::std_load(dstBuf, ArrayRef<Value>{m, n});
          Value prod = edsc::intrinsics::std_mulf(lhsElt, rhsElt);
          Value sum = edsc::intrinsics::std_addf(acc, prod);
          edsc::intrinsics::std_store(sum, dstBuf, ArrayRef<Value>{m, n});
        });
      });
    });

    // Turn the filled buffer back into a tensor and redirect all users of the
    // op's result to it; replaceOp also erases the op.
    Value result = rewriter.create<TensorLoadOp>(loc, dstBuf);
    rewriter.replaceOp(op, result);
    return success();
  }
};
// Function pass for the meh matmul -> SCF lowering. runOnFunction() is only
// declared here; its definition is not part of this snippet (presumably it
// applies MehMatmulToSCFPattern — confirm against the definition).
struct ConvertMehMatmulToSCFPass
    : public PassWrapper<ConvertMehMatmulToSCFPass, FunctionPass> {
  // Every dialect whose ops may be *created* by this pass must be declared
  // so it is loaded before the pass runs. The lowering builds scf.for loops
  // and std ops (constants, load/store, mulf/addf, ...), so the Standard
  // dialect is required alongside SCF; meh is kept as before.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<meh::MehDialect, scf::SCFDialect, StandardOpsDialect>();
  }
  void runOnFunction() override;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment