Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created November 4, 2022 04:46
Show Gist options
  • Save AmosLewis/ff018db6312174a3fba8fd209c4df888 to your computer and use it in GitHub Desktop.
Save AmosLewis/ff018db6312174a3fba8fd209c4df888 to your computer and use it in GitHub Desktop.
// Generic conversion pattern that wraps a Torch ATen op in a single
// `tosa.custom` op instead of decomposing it.
//
// Operands are forwarded in their original order:
//   * tensor operands are passed through using their type-converted value;
//   * scalar operands must be defined by `torch.constant.bool`,
//     `torch.constant.int` or `torch.constant.float` and are materialized as
//     one-element constant tensors through `tosa::getConstTensor`.
// Any other operand makes the pattern fail to match.
//
// The custom op's identifier is the value returned by
// `op.getOperationName()` and its result type is the type-converted result
// type of the matched op.
template <typename AtenOpT>
class ConvertSelectiveAtenOpToTosaCustom : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;

  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *rawOp = op.getOperation();
    ValueRange convertedOperands = adaptor.getOperands();

    auto outType = this->getTypeConverter()->convertType(op.getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "failed to convert result type");

    SmallVector<Value> customInputs;
    customInputs.reserve(convertedOperands.size());

    // Appends a freshly built constant tensor, or reports a match failure
    // when the helper could not produce one.
    auto appendConstant = [&](auto maybeTensor) -> LogicalResult {
      if (!maybeTensor)
        return rewriter.notifyMatchFailure(
            op, "failed to materialize scalar operand as a constant tensor");
      customInputs.push_back(maybeTensor.value());
      return success();
    };

    for (unsigned idx = 0, numOperands = convertedOperands.size();
         idx < numOperands; ++idx) {
      Value converted = convertedOperands[idx];
      // The constant matchers below are applied to the op's original operand.
      // The adaptor value is the type-converted one, which is not expected to
      // be defined by the torch.constant.* op itself.
      Value original = rawOp->getOperand(idx);
      Type convertedType = converted.getType();

      if (convertedType.isa<TensorType>()) {
        customInputs.push_back(converted);
        continue;
      }

      // i1 has to be tested before the generic integer case, otherwise the
      // bool branch can never be reached.
      if (convertedType.isInteger(1)) {
        bool boolValue;
        if (!matchPattern(original, m_TorchConstantBool(&boolValue)))
          return rewriter.notifyMatchFailure(
              op, "unimplemented: operand should be a torch.constant.bool");
        // Booleans are encoded as a 0/1 i64 tensor.
        const int64_t encodedBool = boolValue ? 1 : 0;
        if (failed(appendConstant(tosa::getConstTensor<int64_t>(
                rewriter, rawOp, encodedBool, {1}))))
          return failure();
        continue;
      }

      if (convertedType.isa<IntegerType>()) {
        int64_t intValue;
        if (!matchPattern(original, m_TorchConstantInt(&intValue)))
          return rewriter.notifyMatchFailure(
              op, "unimplemented: operand should be a torch.constant.int");
        if (failed(appendConstant(tosa::getConstTensor<int64_t>(
                rewriter, rawOp, intValue, {1}))))
          return failure();
        continue;
      }

      if (convertedType.isa<FloatType>()) {
        double floatValue;
        if (!matchPattern(original, m_TorchConstantFloat(&floatValue)))
          return rewriter.notifyMatchFailure(
              op, "unimplemented: operand should be a torch.constant.float");
        // Build a float tensor: routing the value through an int64_t tensor
        // would silently drop the fractional part.
        const float narrowedFloat = static_cast<float>(floatValue);
        if (failed(appendConstant(tosa::getConstTensor<float>(
                rewriter, rawOp, narrowedFloat, {1}))))
          return failure();
        continue;
      }

      return rewriter.notifyMatchFailure(
          op, "unimplemented: inputs type. The input has to be int/bool/float ");
    }

    rewriter.replaceOpWithNewOp<tosa::CustomOp>(
        op, outType, op.getOperationName(), ValueRange(customInputs));
    return success();
  }
};
@AmosLewis
Copy link
Author

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include <unordered_map>
#include <iostream>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

// Generic conversion pattern that wraps a Torch ATen op in a single
// `tosa.custom` op instead of decomposing it.
//
// Operands are forwarded in their original order:
//   * tensor operands are passed through using their type-converted value;
//   * scalar operands must be defined by `torch.constant.bool`,
//     `torch.constant.int` or `torch.constant.float` and are materialized as
//     one-element constant tensors through `tosa::getConstTensor`.
// Any other operand makes the pattern fail to match.
//
// The custom op's identifier is the value returned by
// `op.getOperationName()` and its result type is the type-converted result
// type of the matched op.
template <typename AtenOpT>
class ConvertSelectiveAtenOpToTosaCustom : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;

  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *rawOp = op.getOperation();
    ValueRange convertedOperands = adaptor.getOperands();

    auto outType = this->getTypeConverter()->convertType(op.getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "failed to convert result type");

    SmallVector<Value> customInputs;
    customInputs.reserve(convertedOperands.size());

    // Appends a freshly built constant tensor, or reports a match failure
    // when the helper could not produce one.
    auto appendConstant = [&](auto maybeTensor) -> LogicalResult {
      if (!maybeTensor)
        return rewriter.notifyMatchFailure(
            op, "failed to materialize scalar operand as a constant tensor");
      customInputs.push_back(maybeTensor.value());
      return success();
    };

    // Flat operand indices are used for both the original and the converted
    // operand so the two always refer to the same position.
    for (unsigned idx = 0, numOperands = convertedOperands.size();
         idx < numOperands; ++idx) {
      Value converted = convertedOperands[idx];
      // The constant matchers below are applied to the op's original operand.
      // The adaptor value is the type-converted one, which is not expected to
      // be defined by the torch.constant.* op itself.
      Value original = rawOp->getOperand(idx);
      Type convertedType = converted.getType();

      if (convertedType.isa<TensorType>()) {
        customInputs.push_back(converted);
        continue;
      }

      // i1 has to be tested before the generic integer case, otherwise the
      // bool branch can never be reached.
      if (convertedType.isInteger(1)) {
        bool boolValue;
        if (!matchPattern(original, m_TorchConstantBool(&boolValue)))
          return rewriter.notifyMatchFailure(
              op, "unimplemented: operand should be a torch.constant.bool");
        // Booleans are encoded as a 0/1 i64 tensor.
        const int64_t encodedBool = boolValue ? 1 : 0;
        if (failed(appendConstant(tosa::getConstTensor<int64_t>(
                rewriter, rawOp, encodedBool, {1}))))
          return failure();
        continue;
      }

      if (convertedType.isa<IntegerType>()) {
        int64_t intValue;
        if (!matchPattern(original, m_TorchConstantInt(&intValue)))
          return rewriter.notifyMatchFailure(
              op, "unimplemented: operand should be a torch.constant.int");
        if (failed(appendConstant(tosa::getConstTensor<int64_t>(
                rewriter, rawOp, intValue, {1}))))
          return failure();
        continue;
      }

      if (convertedType.isa<FloatType>()) {
        double floatValue;
        if (!matchPattern(original, m_TorchConstantFloat(&floatValue)))
          return rewriter.notifyMatchFailure(
              op, "unimplemented: operand should be a torch.constant.float");
        // Build a float tensor: routing the value through an int64_t tensor
        // would silently drop the fractional part.
        const float narrowedFloat = static_cast<float>(floatValue);
        if (failed(appendConstant(tosa::getConstTensor<float>(
                rewriter, rawOp, narrowedFloat, {1}))))
          return failure();
        continue;
      }

      return rewriter.notifyMatchFailure(
          op, "unimplemented: inputs type. The input has to be int/bool/float ");
    }

    rewriter.replaceOpWithNewOp<tosa::CustomOp>(
        op, outType, op.getOperationName(), ValueRange(customInputs));
    return success();
  }
};

// Conversion-pattern skeleton for ATen ops that need a hand-written
// legalization. `matchAndRewrite` is only declared here; an op is supported
// by providing an explicit specialization of it for that op type.
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
  // Adaptor type giving access to the op's type-converted operands.
  using OpAdaptor = typename AtenOpT::Adaptor;
  // Reuse the base pattern's constructors.
  using OpConversionPattern<AtenOpT>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

// Legalizes `torch.aten.softmax.int` to one `tosa.custom` op whose inputs are
// the converted `self` tensor followed by `dim` and `dtype`, each wrapped in a
// one-element i64 constant tensor. Fails to match when `self` is not a tensor
// or when `dim` / `dtype` are not defined by `torch.constant.int`.
template <>
LogicalResult ConvertAtenOp<AtenSoftmaxIntOp>::matchAndRewrite(
    AtenSoftmaxIntOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Math: exp(%x) / sum(exp(%x), %dim)
  // Torch format:
  //     "aten.softmax.int"(%x,%dim): (tensor<2x3xf32>, int) -> tensor<2x3xf32>
  // Decompose tosa format: with -torch-decompose-complex-ops flag
  //     https://gist.github.com/AmosLewis/e668c3bfd2472e9f9f045e012362d831
  //     %2 = "tosa.exp"(%x) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  //     %3 = "tosa.reduce_sum"(%2) {axis = %dim : i64} : (tensor<2x3xf32>) ->
  //     tensor<2x1xf32> %4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) ->
  //     tensor<2x1xf32> %5 = "tosa.mul"(%2, %4) {shift = 0 : i32} :
  //     (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  // No-Decompose TOSA format: without -torch-decompose-complex-ops flag
  //     "tosa.custom"(%x, %dim, %dtype) {identifier = "aten.softmax.int"} :
  //     (tensor<2x3xf32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2x3xf32>

  // Check AtenSoftmaxIntOp first input is a tensor type.
  Value self = adaptor.self();
  if (!self.getType().isa<TensorType>())
    return rewriter.notifyMatchFailure(
        op, "unimplemented: value 'self' should be tensor types ");

  // `dim` and `dtype` are read from the original Torch operands, which is
  // where the torch.constant.int definitions are matched.
  int64_t dim;
  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
    return rewriter.notifyMatchFailure(
        op, "unimplemented: value `dim` should be a torch.constant.int");

  int64_t dtype;
  if (!matchPattern(op.dtype(), m_TorchConstantInt(&dtype)))
    return rewriter.notifyMatchFailure(
        op, "unimplemented: value `dtype` should be a torch.constant.int");

  // Create output type for tosa::CustomOp.
  auto outType = getTypeConverter()->convertType(op.getType());
  if (!outType)
    return rewriter.notifyMatchFailure(op, "failed to convert result type");

  // Wrap the scalar arguments in one-element constant tensors. The helper
  // returns an optional, so check it before dereferencing.
  auto dimTensor =
      tosa::getConstTensor<int64_t>(rewriter, op.getOperation(), dim, {1});
  auto dtypeTensor =
      tosa::getConstTensor<int64_t>(rewriter, op.getOperation(), dtype, {1});
  if (!dimTensor || !dtypeTensor)
    return rewriter.notifyMatchFailure(
        op, "failed to materialize `dim`/`dtype` as constant tensors");

  // Create name attribute and multi-args for tosa::CustomOp input.
  StringAttr nameValueAttr = rewriter.getStringAttr("aten.softmax.int");
  rewriter.replaceOpWithNewOp<tosa::CustomOp>(
      op, outType, nameValueAttr,
      ValueRange{self, dimTensor.value(), dtypeTensor.value()});
  return success();
}

} // namespace

// -----------------------------------------------------------------------------
// TorchToTosaCustom Pass
// -----------------------------------------------------------------------------

namespace {
// Pass that rewrites a caller-selected set of Torch ops into `tosa.custom`
// ops. The selection comes from the `customOps` option; only op names that
// appear there get a conversion pattern registered. All registered patterns
// use the generic ConvertSelectiveAtenOpToTosaCustom lowering.
class ConvertTorchToTosaCustom
    : public ConvertTorchToTosaCustomBase<ConvertTorchToTosaCustom> {
public:
  ConvertTorchToTosaCustom() = default;
  // Builds the pass with an explicit list of Torch op names to convert.
  explicit ConvertTorchToTosaCustom(ArrayRef<std::string> customOps) {
    this->customOps = customOps;
  }

  // Registers the dialects whose ops this pass (and the backend type
  // conversion it sets up) may create.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<tosa::TosaDialect>();
    registry.insert<tensor::TensorDialect>();
    registry.insert<arith::ArithDialect>();
    TorchConversion::getBackendTypeConversionDependentDialects(registry);
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ConversionTarget target(*context);
    target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
                           arith::ArithDialect>();

    // Types are kept as-is unless the backend type conversion rewrites them.
    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });
    TorchConversion::setupBackendTypeConversion(target, typeConverter);

    RewritePatternSet patterns(context);

    // Pure membership query over the option list: nothing is copied and no
    // lookup structure is mutated as a side effect of asking.
    auto isRequested = [this](StringRef opName) {
      for (const std::string &requested : customOps)
        if (opName == StringRef(requested))
          return true;
      return false;
    };

#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
  patterns.add<ConvertSelectiveAtenOpToTosaCustom<AtenOp>>(typeConverter,      \
                                                           context);
    if (isRequested("torch.aten.softmax.int")) {
      INSERT_ATENOP_PATTERN(AtenSoftmaxIntOp);
    }
#undef INSERT_ATENOP_PATTERN

    // Partial conversion: ops without a matching pattern are left in place.
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToTosaCustomPass(
    ArrayRef<std::string> customOps) {
  return std::make_unique<ConvertTorchToTosaCustom>(customOps);
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment