The input file /tmp/softmax-decompose.mlir, a decomposition of torch.aten.softmax.int:
func.func @torch.aten.softmax.int$cst_dim(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
%none = torch.constant.none
%int1 = torch.constant.int 1
%true = torch.constant.bool true
%values, %indices = torch.aten.max.dim %arg0, %int1, %true : !torch.vtensor<[2,3],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,1],f32>, !torch.vtensor<[2,1],si64>
%float1.000000e00 = torch.constant.float 1.000000e+00
%0 = torch.aten.sub.Tensor %arg0, %values, %float1.000000e00 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,1],f32>, !torch.float -> !torch.vtensor<[2,3],f32>
%1 = torch.aten.exp %0 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%true_0 = torch.constant.bool true
%none_1 = torch.constant.none
%3 = torch.aten.sum.dim_IntList %1, %2, %true_0, %none_1 : !torch.vtensor<[2,3],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,1],f32>
%4 = torch.aten.div.Tensor %1, %3 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,1],f32> -> !torch.vtensor<[2,3],f32>
%5 = torch.tensor_static_info_cast %4 : !torch.vtensor<[2,3],f32> to !torch.vtensor<[2,3],f32>
return %5 : !torch.vtensor<[2,3],f32>
}
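For reference, the decomposed IR above is the standard numerically stable softmax: subtract the per-row max, exponentiate, and normalize by the row sum. A minimal NumPy sketch of the same computation (illustration only, not part of the gist):

import numpy as np

def stable_softmax(x, dim=1):
    x_max = np.max(x, axis=dim, keepdims=True)  # torch.aten.max.dim, keepdim=true
    shifted = x - 1.0 * x_max                   # torch.aten.sub.Tensor, alpha=1.0
    e = np.exp(shifted)                         # torch.aten.exp
    s = np.sum(e, axis=dim, keepdims=True)      # torch.aten.sum.dim_IntList, keepdim=true
    return e / s                                # torch.aten.div.Tensor

x = np.random.randn(2, 3).astype(np.float32)
print(stable_softmax(x).sum(axis=1))            # each row sums to ~1.0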
Here is all the IR when lowering it through the TOSA backend pipeline:
(mlir_venv) nod% torch-mlir-opt -torch-backend-to-tosa-backend-pipeline /tmp/softmax-decompose.mlir --mlir-print-ir-after-all
// -----// IR Dump After ConvertTorchToTosa (convert-torch-to-tosa) //----- //
func.func @torch.aten.softmax.int$cst_dim(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
%none = torch.constant.none
%int1 = torch.constant.int 1
%1 = torch_c.to_i64 %int1
%true = torch.constant.bool true
%2 = "tosa.reduce_max"(%0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%3 = "tosa.argmax"(%0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2xi64>
%4 = "tosa.reshape"(%3) {new_shape = [2, 1]} : (tensor<2xi64>) -> tensor<2x1xi64>
%float1.000000e00 = torch.constant.float 1.000000e+00
%5 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%6 = "tosa.mul"(%2, %5) {shift = 0 : i32} : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
%7 = "tosa.sub"(%0, %6) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%8 = "tosa.exp"(%7) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%9 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%true_0 = torch.constant.bool true
%none_1 = torch.constant.none
%10 = "tosa.reduce_sum"(%8) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%11 = "tosa.reciprocal"(%10) : (tensor<2x1xf32>) -> tensor<2x1xf32>
%12 = "tosa.mul"(%8, %11) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%13 = torch_c.from_builtin_tensor %12 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
return %13 : !torch.vtensor<[2,3],f32>
}
// -----// IR Dump After TosaMakeBroadcastable (tosa-make-broadcastable) //----- //
func.func @torch.aten.softmax.int$cst_dim(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
%0 = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
%2 = "tosa.reduce_max"(%1) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%3 = "tosa.mul"(%2, %0) {shift = 0 : i32} : (tensor<2x1xf32>, tensor<1x1xf32>) -> tensor<2x1xf32>
%4 = "tosa.sub"(%1, %3) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%5 = "tosa.exp"(%4) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%7 = "tosa.reciprocal"(%6) : (tensor<2x1xf32>) -> tensor<2x1xf32>
%8 = "tosa.mul"(%5, %7) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%9 = torch_c.from_builtin_tensor %8 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
return %9 : !torch.vtensor<[2,3],f32>
}
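The effect of TosaMakeBroadcastable is visible in the constant: the rank-0 tensor<f32> splat from the previous dump is re-emitted as rank-2 tensor<1x1xf32>, so both tosa.mul operands have equal rank and broadcast only along size-1 dimensions. In NumPy terms (an illustrative sketch, not the pass itself):

import numpy as np

a = np.random.randn(2, 1).astype(np.float32)
c0 = np.array(1.0, dtype=np.float32)  # rank-0 splat, like tensor<f32>
c2 = c0.reshape(1, 1)                 # rank-aligned, like tensor<1x1xf32>
assert (a * c0 == a * c2).all()       # same values; TOSA just requires equal rank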
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @torch.aten.softmax.int$cst_dim(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
%1 = "tosa.reduce_max"(%0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%2 = "tosa.sub"(%0, %1) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%3 = "tosa.exp"(%2) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%4 = "tosa.reduce_sum"(%3) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%5 = "tosa.reciprocal"(%4) : (tensor<2x1xf32>) -> tensor<2x1xf32>
%6 = "tosa.mul"(%3, %5) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%7 = torch_c.from_builtin_tensor %6 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
return %7 : !torch.vtensor<[2,3],f32>
}
// -----// IR Dump After CSE (cse) //----- //
func.func @torch.aten.softmax.int$cst_dim(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
%1 = "tosa.reduce_max"(%0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%2 = "tosa.sub"(%0, %1) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%3 = "tosa.exp"(%2) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%4 = "tosa.reduce_sum"(%3) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%5 = "tosa.reciprocal"(%4) : (tensor<2x1xf32>) -> tensor<2x1xf32>
%6 = "tosa.mul"(%3, %5) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%7 = torch_c.from_builtin_tensor %6 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
return %7 : !torch.vtensor<[2,3],f32>
}
// -----// IR Dump After FuncBackendTypeConversion (torch-func-backend-type-conversion) //----- //
module {
func.func @torch.aten.softmax.int$cst_dim(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
%1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
%2 = "tosa.reduce_max"(%1) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%3 = "tosa.sub"(%1, %2) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%4 = "tosa.exp"(%3) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%5 = "tosa.reduce_sum"(%4) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%6 = "tosa.reciprocal"(%5) : (tensor<2x1xf32>) -> tensor<2x1xf32>
%7 = "tosa.mul"(%4, %6) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%8 = torch_c.from_builtin_tensor %7 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
%9 = torch_c.to_builtin_tensor %8 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
return %9 : tensor<2x3xf32>
}
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @torch.aten.softmax.int$cst_dim(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
%1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
%2 = "tosa.reduce_max"(%1) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%3 = "tosa.sub"(%1, %2) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%4 = "tosa.exp"(%3) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%5 = "tosa.reduce_sum"(%4) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%6 = "tosa.reciprocal"(%5) : (tensor<2x1xf32>) -> tensor<2x1xf32>
%7 = "tosa.mul"(%4, %6) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%8 = torch_c.from_builtin_tensor %7 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
%9 = torch_c.to_builtin_tensor %8 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
return %9 : tensor<2x3xf32>
}
// -----// IR Dump After FinalizingBackendTypeConversion (torch-finalizing-backend-type-conversion) //----- //
func.func @torch.aten.softmax.int$cst_dim(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%1 = "tosa.sub"(%arg0, %0) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%2 = "tosa.exp"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%3 = "tosa.reduce_sum"(%2) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32>
%5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
return %5 : tensor<2x3xf32>
}
// -----// IR Dump After VerifyTosaBackendContract (torch-verify-tosa-backend-contract) //----- //
module {
func.func @torch.aten.softmax.int$cst_dim(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%1 = "tosa.sub"(%arg0, %0) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
%2 = "tosa.exp"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%3 = "tosa.reduce_sum"(%2) {axis = 1 : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32>
%5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
return %5 : tensor<2x3xf32>
}
}
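The end result is the same stable softmax, with the division split into tosa.reciprocal plus tosa.mul. A NumPy sketch (illustration only) that mirrors the six ops of the final graph and checks them against a naive softmax:

import numpy as np

x = np.random.randn(2, 3).astype(np.float32)

m = np.max(x, axis=1, keepdims=True)  # tosa.reduce_max {axis = 1}
d = x - m                             # tosa.sub, 2x1 broadcast over 2x3
e = np.exp(d)                         # tosa.exp
s = np.sum(e, axis=1, keepdims=True)  # tosa.reduce_sum {axis = 1}
r = 1.0 / s                           # tosa.reciprocal
y = e * r                             # tosa.mul {shift = 0}

# Subtracting the max only improves numerical stability; for well-scaled
# inputs the result matches a naive softmax.
naive = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)
assert np.allclose(y, naive, atol=1e-6)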
The MLIR file above is obtained by
Then you have to manually change every torch.tensor into torch.vtensor, but don't change the torch.tensor_static_info_cast.
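One way to make that edit mechanically (a hedged sketch; the path and the idea of anchoring on the leading "!" are my assumptions, not from the gist): the type syntax is !torch.tensor<...>, while the op name torch.tensor_static_info_cast has no "!", so replacing "!torch.tensor" rewrites only the types and leaves the cast op's name alone.

# Hedged sketch: rewrite non-value tensor types to value-tensor types.
# Matching "!torch.tensor" (type syntax, leading "!") avoids renaming the
# op "torch.tensor_static_info_cast", which has no "!" prefix.
path = "/tmp/softmax-decompose.mlir"  # path assumed from the command above
text = open(path).read()
open(path, "w").write(text.replace("!torch.tensor", "!torch.vtensor"))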