Created
          January 30, 2021 05:03 
        
      - 
      
- 
        Save bjacob/a7a1ccd7fc1d02865c58dc184f6665a5 to your computer and use it in GitHub Desktop. 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | // I'm trying to generate a simple loop nest performing matrix multiplication, | |
| // while staying at the Tensor level. | |
| // Everything worked well while this code was working at the memref level. | |
| // But now I don't know how to end this rewrite pass: | |
| // I can't eraseOp because it says that the old op's result is still used. | |
| // If I don't do anything at the end, my pass has no effect, presumably | |
| // because I haven't actually replaced the old op by the new loop nest. | |
| // So I need something like replaceOp at the end of this function, but I don't | |
| // know how to pass to it the new loop nest that we just constructed. | |
| // I don't understand the notion of "result" of a loop nest. I've tried | |
| // getResult() on the loop nest but apparently it returns empty. | |
| // Also I wasn't sure if my loop bodies are supposed to be returning something. | |
| // See how the innermost loop body has a return statement and not the outer loops. | |
| // Obviously I have NO idea what I'm doing :-) | |
| class MehMatmulToSCFPattern : public OpRewritePattern<meh::MatmulOp> { | |
| public: | |
| using OpRewritePattern<meh::MatmulOp>::OpRewritePattern; | |
| LogicalResult matchAndRewrite(meh::MatmulOp op, | |
| PatternRewriter &rewriter) const override { | |
| auto lhsVal = op.lhs(); | |
| auto rhsVal = op.rhs(); | |
| auto dstVal = op.dst(); | |
| auto dstValType = dstVal.getType().cast<ShapedType>(); | |
| auto dstShape = dstValType.getShape(); // ArrayRef<int64_t> | |
| auto M = dstShape[0]; | |
| auto N = dstShape[1]; | |
| auto rhsValType = dstVal.getType().cast<ShapedType>(); | |
| auto rhsShape = rhsValType.getShape(); | |
| edsc::ScopedContext scope(rewriter, op.getLoc()); | |
| auto K = rhsShape[0]; | |
| Value zero = edsc::intrinsics::std_constant_index(0); | |
| Value one = edsc::intrinsics::std_constant_index(1); | |
| Value boundM = edsc::intrinsics::std_constant_index(M); | |
| Value boundN = edsc::intrinsics::std_constant_index(N); | |
| Value boundK = edsc::intrinsics::std_constant_index(K); | |
| edsc::loopNestBuilder(zero, boundM, one, [&](Value m) { | |
| edsc::loopNestBuilder(zero, boundN, one, [&](Value n) { | |
| edsc::loopNestBuilder(zero, boundK, one, [&](Value k) { | |
| Value lhs_entry = rewriter.create<SubTensorOp>( | |
| op.getLoc(), lhsVal, | |
| ValueRange{m, k}, | |
| ValueRange{one, one}, | |
| ValueRange{one, one}); | |
| Value rhs_entry = rewriter.create<SubTensorOp>( | |
| op.getLoc(), lhsVal, | |
| ValueRange{k, n}, | |
| ValueRange{one, one}, | |
| ValueRange{one, one}); | |
| auto product = rewriter.create<linalg::MatmulOp>( | |
| op.getLoc(), TypeRange{dstVal.getType()}, | |
| ValueRange{lhs_entry, rhs_entry}, lhs_entry); | |
| return rewriter.create<SubTensorInsertOp>(op.getLoc(), product.getResults()[0], dstVal, ValueRange{m, n}, | |
| ValueRange{one, one}, | |
| ValueRange{one, one}); | |
| }); | |
| }); | |
| }); | |
| rewriter.replaceOp(op, /* How do I replace 'op' by the above loop nest? */); | |
| return success(); | |
| } | |
| }; | |
| struct ConvertMehMatmulToSCFPass | |
| : public PassWrapper<ConvertMehMatmulToSCFPass, FunctionPass> { | |
| void getDependentDialects(DialectRegistry ®istry) const override { | |
| registry.insert<meh::MehDialect, scf::SCFDialect>(); | |
| } | |
| void runOnFunction() override; | |
| }; | 
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment