Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created October 20, 2022 01:19
Show Gist options
  • Save AmosLewis/e668c3bfd2472e9f9f045e012362d831 to your computer and use it in GitHub Desktop.
Save AmosLewis/e668c3bfd2472e9f9f045e012362d831 to your computer and use it in GitHub Desktop.
// Softmax along dim 1 of a [2,3] f32 value tensor, written out as decomposed
// Torch-dialect ops: exp(x - max(x)) / sum(exp(x - max(x))).
// Every tensor type here is !torch.vtensor (value semantics).
func.func @torch.aten.softmax.int$cst_dim(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
// %none is not referenced by any op in this function.
%none = torch.constant.none
// Reduction dimension, shared by the max and sum reductions below.
%int1 = torch.constant.int 1
%true = torch.constant.bool true
// Max along dim 1; the reduced dim is kept, so the results are [2,1].
// Only %values is used afterwards; %indices has no users in this function.
%values, %indices = torch.aten.max.dim %arg0, %int1, %true : !torch.vtensor<[2,3],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,1],f32>, !torch.vtensor<[2,1],si64>
// Scalar third operand of sub.Tensor (alpha in aten::sub.Tensor), set to 1.0.
%float1.000000e00 = torch.constant.float 1.000000e+00
// x - max(x): the [2,1] max is subtracted from the [2,3] input.
%0 = torch.aten.sub.Tensor %arg0, %values, %float1.000000e00 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,1],f32>, !torch.float -> !torch.vtensor<[2,3],f32>
%1 = torch.aten.exp %0 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
// Dimension list [1] for the sum reduction.
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%true_0 = torch.constant.bool true
%none_1 = torch.constant.none
// Sum of the exponentials along dim 1; result is [2,1]. The last operand is none.
%3 = torch.aten.sum.dim_IntList %1, %2, %true_0, %none_1 : !torch.vtensor<[2,3],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,1],f32>
// Normalize: [2,3] exponentials divided by the [2,1] sums.
%4 = torch.aten.div.Tensor %1, %3 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,1],f32> -> !torch.vtensor<[2,3],f32>
// Source and destination types of this cast are identical.
%5 = torch.tensor_static_info_cast %4 : !torch.vtensor<[2,3],f32> to !torch.vtensor<[2,3],f32>
return %5 : !torch.vtensor<[2,3],f32>
}
@AmosLewis
Copy link
Author

AmosLewis commented Oct 20, 2022

The MLIR file above was obtained by running:

(mlir_venv) nod% torch-mlir-opt -torch-decompose-complex-ops  test/Dialect/Torch/decompose-complex-ops-legal.mlir
 
module {
  func.func @torch.aten.softmax.int$cst_dim(%arg0: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
    %none = torch.constant.none
    %int1 = torch.constant.int 1
    %true = torch.constant.bool true
    %values, %indices = torch.aten.max.dim %arg0, %int1, %true : !torch.tensor<[2,3],f32>, !torch.int, !torch.bool -> !torch.tensor<[2,1],f32>, !torch.tensor<[2,1],si64>
    %float1.000000e00 = torch.constant.float 1.000000e+00
    %0 = torch.aten.sub.Tensor %arg0, %values, %float1.000000e00 : !torch.tensor<[2,3],f32>, !torch.tensor<[2,1],f32>, !torch.float -> !torch.tensor<[2,3],f32>
    %1 = torch.aten.exp %0 : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
    %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
    %true_0 = torch.constant.bool true
    %none_1 = torch.constant.none
    %3 = torch.aten.sum.dim_IntList %1, %2, %true_0, %none_1 : !torch.tensor<[2,3],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.tensor<[2,1],f32>
    %4 = torch.aten.div.Tensor %1, %3 : !torch.tensor<[2,3],f32>, !torch.tensor<[2,1],f32> -> !torch.tensor<[2,3],f32>
    %5 = torch.tensor_static_info_cast %4 : !torch.tensor<[2,3],f32> to !torch.tensor<[2,3],f32>
    return %5 : !torch.tensor<[2,3],f32>
  }
}

Then you have to manually change every `!torch.tensor` type in the output to `!torch.vtensor`, but do not rename the `torch.tensor_static_info_cast` op itself. After that, the file can be lowered to TOSA:

(mlir_venv) nod% torch-mlir-opt -torch-backend-to-tosa-backend-pipeline    /tmp/softmax-decompose.mlir

module {
  func.func @torch.aten.softmax.int$cst_dim(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
    %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
    %1 = "tosa.sub"(%arg0, %0) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
    %2 = "tosa.exp"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
    %3 = "tosa.reduce_sum"(%2) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
    %4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
    return %5 : tensor<2x3xf32>
  }
}

@AmosLewis
Copy link
Author

Here is the IR printed after each pass during lowering:

(mlir_venv) nod% torch-mlir-opt -torch-backend-to-tosa-backend-pipeline    /tmp/softmax-decompose.mlir --mlir-print-ir-after-all  

// -----// IR Dump After ConvertTorchToTosa (convert-torch-to-tosa) //----- //
func.func @torch.aten.softmax.int$cst_dim(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
  %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
  %none = torch.constant.none
  %int1 = torch.constant.int 1
  %1 = torch_c.to_i64 %int1
  %true = torch.constant.bool true
  %2 = "tosa.reduce_max"(%0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
  %3 = "tosa.argmax"(%0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2xi64>
  %4 = "tosa.reshape"(%3) {new_shape = [2, 1]} : (tensor<2xi64>) -> tensor<2x1xi64>
  %float1.000000e00 = torch.constant.float 1.000000e+00
  %5 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %6 = "tosa.mul"(%2, %5) {shift = 0 : i32} : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
  %7 = "tosa.sub"(%0, %6) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  %8 = "tosa.exp"(%7) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  %9 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
  %true_0 = torch.constant.bool true
  %none_1 = torch.constant.none
  %10 = "tosa.reduce_sum"(%8) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
  %11 = "tosa.reciprocal"(%10) : (tensor<2x1xf32>) -> tensor<2x1xf32>
  %12 = "tosa.mul"(%8, %11) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  %13 = torch_c.from_builtin_tensor %12 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
  return %13 : !torch.vtensor<[2,3],f32>
}

// -----// IR Dump After TosaMakeBroadcastable (tosa-make-broadcastable) //----- //
func.func @torch.aten.softmax.int$cst_dim(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
  %0 = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
  %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
  %2 = "tosa.reduce_max"(%1) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
  %3 = "tosa.mul"(%2, %0) {shift = 0 : i32} : (tensor<2x1xf32>, tensor<1x1xf32>) -> tensor<2x1xf32>
  %4 = "tosa.sub"(%1, %3) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  %5 = "tosa.exp"(%4) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
  %7 = "tosa.reciprocal"(%6) : (tensor<2x1xf32>) -> tensor<2x1xf32>
  %8 = "tosa.mul"(%5, %7) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  %9 = torch_c.from_builtin_tensor %8 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
  return %9 : !torch.vtensor<[2,3],f32>
}

// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @torch.aten.softmax.int$cst_dim(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
  %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
  %1 = "tosa.reduce_max"(%0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
  %2 = "tosa.sub"(%0, %1) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  %3 = "tosa.exp"(%2) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  %4 = "tosa.reduce_sum"(%3) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
  %5 = "tosa.reciprocal"(%4) : (tensor<2x1xf32>) -> tensor<2x1xf32>
  %6 = "tosa.mul"(%3, %5) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  %7 = torch_c.from_builtin_tensor %6 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
  return %7 : !torch.vtensor<[2,3],f32>
}

// -----// IR Dump After CSE (cse) //----- //
func.func @torch.aten.softmax.int$cst_dim(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
  %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
  %1 = "tosa.reduce_max"(%0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
  %2 = "tosa.sub"(%0, %1) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  %3 = "tosa.exp"(%2) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  %4 = "tosa.reduce_sum"(%3) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
  %5 = "tosa.reciprocal"(%4) : (tensor<2x1xf32>) -> tensor<2x1xf32>
  %6 = "tosa.mul"(%3, %5) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  %7 = torch_c.from_builtin_tensor %6 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
  return %7 : !torch.vtensor<[2,3],f32>
}

// -----// IR Dump After FuncBackendTypeConversion (torch-func-backend-type-conversion) //----- //
module {
  func.func @torch.aten.softmax.int$cst_dim(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
    %0 = torch_c.from_builtin_tensor %arg0 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
    %1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
    %2 = "tosa.reduce_max"(%1) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
    %3 = "tosa.sub"(%1, %2) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
    %4 = "tosa.exp"(%3) : (tensor<2x3xf32>) -> tensor<2x3xf32>
    %5 = "tosa.reduce_sum"(%4) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
    %6 = "tosa.reciprocal"(%5) : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %7 = "tosa.mul"(%4, %6) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
    %8 = torch_c.from_builtin_tensor %7 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
    %9 = torch_c.to_builtin_tensor %8 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
    return %9 : tensor<2x3xf32>
  }
}


// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @torch.aten.softmax.int$cst_dim(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %0 = torch_c.from_builtin_tensor %arg0 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
  %1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
  %2 = "tosa.reduce_max"(%1) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
  %3 = "tosa.sub"(%1, %2) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  %4 = "tosa.exp"(%3) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  %5 = "tosa.reduce_sum"(%4) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
  %6 = "tosa.reciprocal"(%5) : (tensor<2x1xf32>) -> tensor<2x1xf32>
  %7 = "tosa.mul"(%4, %6) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  %8 = torch_c.from_builtin_tensor %7 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
  %9 = torch_c.to_builtin_tensor %8 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
  return %9 : tensor<2x3xf32>
}

// -----// IR Dump After FinalizingBackendTypeConversion (torch-finalizing-backend-type-conversion) //----- //
func.func @torch.aten.softmax.int$cst_dim(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
  %1 = "tosa.sub"(%arg0, %0) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  %2 = "tosa.exp"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  %3 = "tosa.reduce_sum"(%2) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
  %4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32>
  %5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
  return %5 : tensor<2x3xf32>
}

// -----// IR Dump After VerifyTosaBackendContract (torch-verify-tosa-backend-contract) //----- //
module {
  func.func @torch.aten.softmax.int$cst_dim(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
    %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
    %1 = "tosa.sub"(%arg0, %0) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
    %2 = "tosa.exp"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
    %3 = "tosa.reduce_sum"(%2) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
    %4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
    return %5 : tensor<2x3xf32>
  }
}


module {
  func.func @torch.aten.softmax.int$cst_dim(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
    %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
    %1 = "tosa.sub"(%arg0, %0) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
    %2 = "tosa.exp"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
    %3 = "tosa.reduce_sum"(%2) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
    %4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
    return %5 : tensor<2x3xf32>
  }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment