Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created September 16, 2024 04:48
Show Gist options
  • Save AmosLewis/ff1ec4411b7cc5662b95ddc030ddcad0 to your computer and use it in GitHub Desktop.
module {
  // Single-op graph imported from ONNX: onnx.ScatterElements on axis 0 with
  // reduction = "add". Operands are passed in ONNX order (data, indices,
  // updates), and the result has the same type as %data.
  func.func @main_graph(
      %data: !torch.vtensor<[2708],f32>,
      %indices: !torch.vtensor<[?],si64>,
      %updates: !torch.vtensor<[?],f32>
  ) -> !torch.vtensor<[2708],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
    %scattered = torch.operator "onnx.ScatterElements"(%data, %indices, %updates) {torch.onnx.axis = 0 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[2708],f32>, !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[2708],f32>
    return %scattered : !torch.vtensor<[2708],f32>
  }
}
@AmosLewis
Copy link
Author

# Run the listed passes, in this order, on model1.mlir (presumably the
# onnx.ScatterElements module above). The IR printed below is the result:
# the ONNX op has become torch.aten.scatter.reduce plus index wrapping.
torch-mlir-opt --convert-torch-onnx-to-torch --torch-decompose-complex-ops --cse --canonicalize model1.mlir

// Lowered form of the onnx.ScatterElements graph: negative indices are wrapped
// into range first, then a scatter with reduce mode "add" is applied on dim 0.
module {
  func.func @main_graph(%arg0: !torch.vtensor<[2708],f32>, %arg1: !torch.vtensor<[?],si64>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2708],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
    // Reduce mode passed to scatter.reduce below.
    %str = torch.constant.str "add"
    // Used both as the scatter dimension and as the "< 0" comparand.
    %int0 = torch.constant.int 0
    // alpha for add.Scalar (adds 1 * 2708).
    %int1 = torch.constant.int 1
    // Static size of %arg0 along dim 0, taken from its type [2708].
    %int2708 = torch.constant.int 2708
    // %0 = %arg1 + 2708: the wrapped value for a negative index.
    %0 = torch.aten.add.Scalar %arg1, %int2708, %int1 : !torch.vtensor<[?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?],si64>
    // %1 = mask of indices that are negative.
    %1 = torch.aten.lt.Scalar %arg1, %int0 : !torch.vtensor<[?],si64>, !torch.int -> !torch.vtensor<[?],i1>
    // %2 = negative indices replaced by their wrapped value; others unchanged.
    %2 = torch.aten.where.self %1, %0, %arg1 : !torch.vtensor<[?],i1>, !torch.vtensor<[?],si64>, !torch.vtensor<[?],si64> -> !torch.vtensor<[?],si64>
    // Scatter %arg2 into %arg0 along dim 0 at positions %2 with reduce mode "add".
    %3 = torch.aten.scatter.reduce %arg0, %int0, %2, %arg2, %str : !torch.vtensor<[2708],f32>, !torch.int, !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.str -> !torch.vtensor<[2708],f32>
    return %3 : !torch.vtensor<[2708],f32>
  }
}

@AmosLewis
Copy link
Author

AmosLewis commented Sep 16, 2024

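model2.mlir: the same computation written directly as a single torch.aten.scatter.reduce (dim 0, reduce mode "add"), without the ONNX op and without the negative-index wrapping. As the debug log below shows, no conversion pattern legalizes this op: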

module {
  // Standalone scatter: %updates is scattered into %data along dim 0 at
  // positions %indices with reduce mode "add". Unlike the lowered ONNX form,
  // %indices is used as-is, with no wrapping of negative values.
  func.func @main_graph(
      %data: !torch.vtensor<[2708],f32>,
      %indices: !torch.vtensor<[?],si64>,
      %updates: !torch.vtensor<[?],f32>
  ) -> !torch.vtensor<[2708],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
    %dim0 = torch.constant.int 0
    %reduce_add = torch.constant.str "add"
    %scattered = torch.aten.scatter.reduce %data, %dim0, %indices, %updates, %reduce_add : !torch.vtensor<[2708],f32>, !torch.int, !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.str -> !torch.vtensor<[2708],f32>
    return %scattered : !torch.vtensor<[2708],f32>
  }
}
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.scatter.reduce'(0x55cbc3908420) {
  %2 = "torch.aten.scatter.reduce"(%arg0, %1, %arg1, %arg2, %0) : (!torch.vtensor<[2708],f32>, !torch.int, !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.str) -> !torch.vtensor<[2708],f32>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.scatter.reduce -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenScalarToTensorLike"
    ** Failure : not a supported Scalar to Tensor like op
"(anonymous namespace)::ConvertAtenScalarToTensorLike" result 0
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.scatter.reduce -> ()' {
Trying to match "(anonymous namespace)::ConvertElementwiseOp"
    ** Failure : not a supported elementwise op
"(anonymous namespace)::ConvertElementwiseOp" result 0
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.scatter.reduce -> ()' {
Trying to match "(anonymous namespace)::ConvertReductionOp"
    ** Failure : not a supported reduce op
"(anonymous namespace)::ConvertReductionOp" result 0
  } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment