Created
          January 28, 2021 16:57 
        
      - 
      
- 
        Save bjacob/2ddd75ec3998a771a3852c3ca4233fd5 to your computer and use it in GitHub Desktop. 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | // meh::MatmulOp takes two Tensors as arguments lhs() and rhs(), and returns a Tensor, dst(). | |
| // The code below was originally written against memrefs, and no longer works now that meh::MatmulOp deals with tensors. | |
| // The question is how to fix it to work with tensors. | |
| class MehMatmulToSCFPattern : public OpRewritePattern<meh::MatmulOp> { | |
| public: | |
| using OpRewritePattern<meh::MatmulOp>::OpRewritePattern; | |
| LogicalResult matchAndRewrite(meh::MatmulOp op, | |
| PatternRewriter &rewriter) const override { | |
| auto lhsVal = op.lhs(); | |
| auto rhsVal = op.rhs(); | |
| auto dstVal = op.dst(); | |
| auto dstValType = dstVal.getType().cast<RankedTensorType>(); | |
| auto dstShape = dstValType.getShape(); // ArrayRef<int64_t> | |
| auto M = dstShape[0]; | |
| auto N = dstShape[1]; | |
| auto rhsValType = dstVal.getType().cast<RankedTensorType>(); | |
| auto rhsShape = rhsValType.getShape(); | |
| edsc::ScopedContext scope(rewriter, op.getLoc()); | |
| auto K = rhsShape[0]; | |
| Value zero = edsc::intrinsics::std_constant_index(0); | |
| Value step = edsc::intrinsics::std_constant_index(1); | |
| Value boundM = edsc::intrinsics::std_constant_index(M); | |
| Value boundN = edsc::intrinsics::std_constant_index(N); | |
| Value boundK = edsc::intrinsics::std_constant_index(K); | |
| edsc::loopNestBuilder(zero, boundM, step, [&](Value m) { | |
| edsc::loopNestBuilder(zero, boundN, step, [&](Value n) { | |
| edsc::loopNestBuilder(zero, boundK, step, [&](Value k) { | |
| // XXX Assertion failure, type mismatch in Cast in the next line. | |
| Value lhs_val = | |
| edsc::intrinsics::std_load(lhsVal, ArrayRef<Value>{m, k}); | |
| Value rhs_val = | |
| edsc::intrinsics::std_load(rhsVal, ArrayRef<Value>{k, n}); | |
| Value dst_val = | |
| edsc::intrinsics::std_load(dstVal, ArrayRef<Value>{m, n}); | |
| Value mul_f = edsc::intrinsics::std_mulf(lhs_val, rhs_val); | |
| Value res = edsc::intrinsics::std_addf(dst_val, mul_f); | |
| edsc::intrinsics::std_store(res, dstVal, ArrayRef<Value>{m, n}); | |
| }); | |
| }); | |
| }); | |
| rewriter.eraseOp(op); | |
| return success(); | |
| } | |
| }; | |
| struct ConvertMehMatmulToSCFPass | |
| : public PassWrapper<ConvertMehMatmulToSCFPass, FunctionPass> { | |
| void getDependentDialects(DialectRegistry ®istry) const override { | |
| registry.insert<meh::MehDialect, scf::SCFDialect>(); | |
| } | |
| void runOnFunction() override; | |
| }; | 
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment