Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created April 12, 2023 23:26
Show Gist options
  • Save AmosLewis/e637465b23f6f96965487ad6e7e8f7b1 to your computer and use it in GitHub Desktop.
Save AmosLewis/e637465b23f6f96965487ad6e7e8f7b1 to your computer and use it in GitHub Desktop.
// Input IR for `torch-mlir-opt -convert-torch-to-tosa` (presumably the
// contents of test_abs.mlir): a single aten.abs on a static 15x15 si64 tensor.
func.func @torch.aten.abs(%arg0: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64> {
  %abs = torch.aten.abs %arg0 : !torch.vtensor<[15,15],si64> -> !torch.vtensor<[15,15],si64>
  return %abs : !torch.vtensor<[15,15],si64>
}
@AmosLewis
Copy link
Author

// Lowers `torch.aten.abs` to `tosa.abs`.
//
// The (already type-converted) `self` operand is forwarded unchanged to a new
// tosa::AbsOp whose result type comes from the pattern's type converter.
// Fails to match (rather than asserting) when the operand is not a tensor or
// when the result type has no conversion.
template <>
LogicalResult ConvertAtenOp<AtenAbsOp>::matchAndRewrite(
    AtenAbsOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Not a tensor type.
  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
  if (!selfType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor types input are currently supported");

  // TypeConverter::convertType yields a null Type when no conversion exists;
  // creating an op with a null result type would crash, so report a match
  // failure instead.
  auto outType = getTypeConverter()->convertType(op.getType());
  if (!outType)
    return rewriter.notifyMatchFailure(
        op, "Failed to convert the result type of aten.abs");

  rewriter.replaceOpWithNewOp<tosa::AbsOp>(op, outType, adaptor.getSelf());

  return success();
}

@AmosLewis
Copy link
Author

(mlir_venv) nod% torch-mlir-opt -convert-torch-to-tosa  ./test_abs.mlir     
module {
  func.func @torch.aten.abs(%arg0: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64> {
    %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[15,15],si64> -> tensor<15x15xi64>
    %1 = "tosa.abs"(%0) : (tensor<15x15xi64>) -> tensor<15x15xi64>
    %2 = torch_c.from_builtin_tensor %1 : tensor<15x15xi64> -> !torch.vtensor<[15,15],si64>
    return %2 : !torch.vtensor<[15,15],si64>
  }
}

@AmosLewis
Copy link
Author

// Expected -convert-torch-to-tosa output for @torch.aten.abs. Each `.*`
// capture binds whatever SSA name the printer chooses (e.g. %arg0, %0), so the
// directives do not depend on value numbering.
// CHECK-LABEL: func.func @torch.aten.abs(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[15,15],si64> -> tensor<15x15xi64>
// CHECK: %[[VAL_2:.*]] = "tosa.abs"(%[[VAL_1]]) : (tensor<15x15xi64>) -> tensor<15x15xi64>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<15x15xi64> -> !torch.vtensor<[15,15],si64>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[15,15],si64>
// CHECK: }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment