// Generic conversion pattern that lowers a selected Aten op to a single
// tosa::CustomOp, forwarding tensor operands as-is and materializing scalar
// int/bool/float operands as rank-1 constant tensors.
template <typename AtenOpT>
class ConvertSelectiveAtenOpToTosaCustom : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ValueRange operands = adaptor.getOperands();
    int num_operands = operands.size();
    std::vector<mlir::Value> inputs_vec;
std::cout << " Start print type: " << std::endl; | |
// for(auto operand : operands){ | |
for(int i = 0; i < num_operands; i++){ | |
auto operand = *adaptor.getODSOperands(i).begin(); | |
auto operand_type = operands[i].getType(); | |
std::cout << " operand_type: " << std::endl; | |
operand_type.dump(); // tensor<2x3xf32> , i64, i64 | |
std::cout << std::endl; | |
std::cout << " operand: " << std::endl; | |
operand.dump(); // tensor<2x3xf32> , i64, i64 | |
std::cout << std::endl; | |
      if (operand_type.isInteger(1)) { // Torch::ConstantBoolOp
        // Check i1 before the general integer case so bools are not
        // treated as ints.
        llvm::errs() << "process bool op\n";
        bool operand_tosa;
        if (!matchPattern(operand, m_TorchConstantBool(&operand_tosa)))
          return rewriter.notifyMatchFailure(
              op, "unimplemented: operand should be a torch.constant.bool");
        auto operand_tensor_bool = tosa::getConstTensor<int64_t>(
            rewriter, op.getOperation(), operand_tosa, {1});
        inputs_vec.push_back(operand_tensor_bool.value());
      } else if (operand_type.isa<mlir::IntegerType>()) { // Torch::ConstantIntOp
        llvm::errs() << "process int op\n";
        int64_t operand_tosa;
        if (!matchPattern(operand, m_TorchConstantInt(&operand_tosa)))
          return rewriter.notifyMatchFailure(
              op, "unimplemented: operand should be a torch.constant.int");
        auto operand_tensor_int = tosa::getConstTensor<int64_t>(
            rewriter, op.getOperation(), operand_tosa, {1});
        inputs_vec.push_back(operand_tensor_int.value());
      } else if (operand_type.isa<mlir::FloatType>()) { // Torch::ConstantFloatOp
        llvm::errs() << "process float op\n";
        double operand_tosa;
        if (!matchPattern(operand, m_TorchConstantFloat(&operand_tosa)))
          return rewriter.notifyMatchFailure(
              op, "unimplemented: operand should be a torch.constant.float");
        // Keep the value as a float constant tensor instead of truncating
        // it to int64.
        auto operand_tensor_float = tosa::getConstTensor<float>(
            rewriter, op.getOperation(), {static_cast<float>(operand_tosa)},
            {1});
        inputs_vec.push_back(operand_tensor_float.value());
      } else if (operand_type.isa<mlir::TensorType>()) { // Torch::ValueTensorType
        // Tensor operands (e.g. `self`) are forwarded to tosa.custom as-is.
        llvm::errs() << "process ValueTensorType op\n";
        inputs_vec.push_back(operand);
      } else {
        return rewriter.notifyMatchFailure(
            op, "unimplemented: input type must be tensor/int/bool/float");
      }
    }
    // Debug output: the converted `self` type is tensor<2x3xf32>, the original
    // result type is !torch.vtensor<[2,3],f32>. Query both before the op is
    // replaced so we do not touch an already-rewritten op.
    llvm::errs() << "adaptor.self().getType(): " << adaptor.self().getType()
                 << "\n";
    llvm::errs() << "op.getType(): " << op.getType() << "\n";
    // Create the output type for the tosa::CustomOp result.
    auto outType = this->getTypeConverter()->convertType(op.getType());
    llvm::ArrayRef<mlir::Value> ref(inputs_vec.data(), inputs_vec.size());
    ValueRange custom_inputs(ref);
    // Replace the Aten op with tosa.custom, using the Aten op name as the
    // custom op identifier.
    rewriter.replaceOpWithNewOp<tosa::CustomOp>(
        op, outType, op.getOperationName(), custom_inputs);
    // Earlier, AtenSoftmaxIntOp-specific version of this lowering, kept for
    // reference:
    //
    // auto selfType = adaptor.self().getType().template dyn_cast<TensorType>();
    // if (!selfType)
    //   return rewriter.notifyMatchFailure(
    //       op, "unimplemented: value 'self' should be tensor types");
    //
    // // Get the dim int64_t value from AtenSoftmaxIntOp's second input; the
    // // type needs converting from mlir::TypedValue<mlir::torch::Torch::IntType>.
    // int64_t dim;
    // if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
    //   return rewriter.notifyMatchFailure(
    //       op, "unimplemented: value `dim` should be a torch.constant.int");
    //
    // int64_t dtype;
    // if (!matchPattern(op.dtype(), m_TorchConstantInt(&dtype)))
    //   return rewriter.notifyMatchFailure(
    //       op, "unimplemented: value `dtype` should be a torch.constant.int");
    //
    // // Create the output type, the name attribute, and the argument tensors
    // // for the tosa::CustomOp.
    // auto outType = this->getTypeConverter()->convertType(op.getType());
    // StringAttr nameValueAttr = rewriter.getStringAttr("aten.softmax.int");
    // auto dimTensor =
    //     tosa::getConstTensor<int64_t>(rewriter, op.getOperation(), dim, {1});
    // auto dtypeTensor =
    //     tosa::getConstTensor<int64_t>(rewriter, op.getOperation(), dtype, {1});
    //
    // rewriter.replaceOpWithNewOp<tosa::CustomOp>(
    //     op, outType, nameValueAttr,
    //     ValueRange{adaptor.self(), dimTensor.value(), dtypeTensor.value()});
    return success();
  }
};
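// The helper below is a minimal sketch, not part of the original gist, of how
// this pattern might be registered in the Torch-to-TOSA lowering. The function
// name `populateSelectiveAtenToTosaCustomPatterns` is hypothetical, the op
// list is only an example, and it assumes the usual TorchToTosa includes and
// using-declarations (mlir, mlir::torch, mlir::torch::Torch) are in scope.
static void populateSelectiveAtenToTosaCustomPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, MLIRContext *context) {
  // Mark the chosen Aten op illegal so the conversion must rewrite it, then
  // register one instantiation of the templated pattern for that op.
  target.addIllegalOp<AtenSoftmaxIntOp>();
  patterns.add<ConvertSelectiveAtenOpToTosaCustom<AtenSoftmaxIntOp>>(
      typeConverter, context);
}
// Each op matched this way is replaced by a single tosa.custom op whose
// identifier is the Aten op name and whose extra inputs are the rank-1
// constant tensors built above for the scalar operands.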