Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active April 13, 2023 00:15
Show Gist options
  • Save AmosLewis/a0b9cd193a103301bb8168eadae9e0b1 to your computer and use it in GitHub Desktop.
Save AmosLewis/a0b9cd193a103301bb8168eadae9e0b1 to your computer and use it in GitHub Desktop.
func.func @torch.aten.le.Tensor(%arg0: !torch.vtensor<[1,4,4],si64>, %arg1: !torch.vtensor<[1,4,1],si64>) -> !torch.vtensor<[1,4,4],i1>{
%0 = torch.aten.le.Tensor %arg0, %arg1 : !torch.vtensor<[1,4,4],si64>, !torch.vtensor<[1,4,1],si64> -> !torch.vtensor<[1,4,4],i1>
return %0 : !torch.vtensor<[1,4,4],i1>
}
@AmosLewis
Copy link
Author

template <>
LogicalResult ConvertAtenOp<AtenLeTensorOp>::matchAndRewrite(
    AtenLeTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Not a tensor type.
  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
  if (!selfType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor types input are currently supported");
  auto otherType = adaptor.getOther().getType().dyn_cast<TensorType>();
  if (!otherType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor types condition are currently supported");

  auto outType = getTypeConverter()->convertType(op.getType());

  auto greaterOp = rewriter.create<tosa::GreaterOp>(op.getLoc(), outType, adaptor.getSelf(), adaptor.getOther());

  rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(
      op, outType, greaterOp.getOutput());

  return success();
}

@AmosLewis
Copy link
Author

AmosLewis commented Apr 13, 2023

(mlir_venv) nod% torch-mlir-opt -convert-torch-to-tosa  ./test_le_tensor.mlir
module {
  func.func @torch.aten.le.Tensor(%arg0: !torch.vtensor<[1,4,4],si64>, %arg1: !torch.vtensor<[1,4,1],si64>) -> !torch.vtensor<[1,4,4],i1> {
    %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4,4],si64> -> tensor<1x4x4xi64>
    %1 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,4,1],si64> -> tensor<1x4x1xi64>
    %2 = "tosa.greater"(%0, %1) : (tensor<1x4x4xi64>, tensor<1x4x1xi64>) -> tensor<1x4x4xi1>
    %3 = "tosa.logical_not"(%2) : (tensor<1x4x4xi1>) -> tensor<1x4x4xi1>
    %4 = torch_c.from_builtin_tensor %3 : tensor<1x4x4xi1> -> !torch.vtensor<[1,4,4],i1>
    return %4 : !torch.vtensor<[1,4,4],i1>
  }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment