Created April 20, 2023 19:19
-
-
Save AmosLewis/41081e79cb30b4ef8ec736b0bc8f6bf3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Rank-1 index variant of aten._index_put_impl.
// The index list is [none, %index], so dim 0 of %input is left unindexed and
// the 3-element %index addresses positions along dim 1. %fillValues is a 0-d
// tensor written at every indexed position.
func.func @torch.aten._index_put_impl(%input: !torch.vtensor<[1,4],si64>, %index: !torch.vtensor<[3],si64>, %fillValues: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,4],si64> {
  %false = torch.constant.bool false
  %none = torch.constant.none
  %indices = torch.prim.ListConstruct %none, %index : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
  // Trailing bools follow the aten::_index_put_impl schema: accumulate = false, unsafe = false.
  %out = torch.aten._index_put_impl %input, %indices, %fillValues, %false, %false : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
  return %out : !torch.vtensor<[1,4],si64>
}
StableHLO
// Zero-rank index variant of aten._index_put_impl.
// %index is a 0-d tensor placed in slot 1 of the index list, so it selects a
// single position along dim 1; slot 0 is `none`, leaving dim 0 unindexed.
// NOTE: the constant order (bool, then none) is kept deliberately, because the
// lowered-output checks expect the constants in that order.
func.func @torch.aten._index_put_impl(%input: !torch.vtensor<[1,4],si64>, %index: !torch.vtensor<[],si64>, %fillValues: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,4],si64> {
  %flag_off = torch.constant.bool false
  %no_index = torch.constant.none
  %index_list = torch.prim.ListConstruct %no_index, %index : (!torch.none, !torch.vtensor<[],si64>) -> !torch.list<optional<vtensor>>
  // accumulate = false, unsafe = false
  %result = torch.aten._index_put_impl %input, %index_list, %fillValues, %flag_off, %flag_off : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
  return %result : !torch.vtensor<[1,4],si64>
}
--->
// Lowered form of the zero-rank-index function. Torch values are bridged to
// builtin tensors with torch_c.to_builtin_tensor, the index_put becomes one
// stablehlo.scatter, and the result is bridged back with
// torch_c.from_builtin_tensor.
module {
func.func @torch.aten._index_put_impl(%arg0: !torch.vtensor<[1,4],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,4],si64> {
// %0: builtin tensor for %arg0, the tensor being updated (scatter operand).
%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4],si64> -> tensor<1x4xi64>
// %1: builtin tensor for %arg2, the 0-d fill value.
%1 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[],si64> -> tensor<i64>
// Leftover torch constants and index list; %false and %2 have no users here.
%false = torch.constant.bool false
%none = torch.constant.none
%2 = torch.prim.ListConstruct %none, %arg1 : (!torch.none, !torch.vtensor<[],si64>) -> !torch.list<optional<vtensor>>
// %3: builtin tensor for %arg1, the 0-d index.
// NOTE(review): %3 has no users below, so the index never reaches the scatter.
%3 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[],si64> -> tensor<i64>
// NOTE(review): %4 (constant 0) is also unused; it looks like it was meant to
// be the dim-0 coordinate of the scatter index.
%4 = stablehlo.constant dense<0> : tensor<i64>
// Build the 1x2 scatter index tensor %9 by concatenating two 1x1 tensors.
// NOTE(review): both halves are reshapes of %1 (the fill value), so the
// scatter coordinate is (fill, fill). Presumably it should be (%4, %3), i.e.
// (0, index). Verify the Torch-to-StableHLO index_put lowering.
%5 = stablehlo.reshape %1 : (tensor<i64>) -> tensor<1xi64>
%6 = stablehlo.reshape %5 : (tensor<1xi64>) -> tensor<1x1xi64>
%7 = stablehlo.reshape %1 : (tensor<i64>) -> tensor<1xi64>
%8 = stablehlo.reshape %7 : (tensor<1xi64>) -> tensor<1x1xi64>
%9 = stablehlo.concatenate %6, %8, dim = 1 : (tensor<1x1xi64>, tensor<1x1xi64>) -> tensor<1x2xi64>
// Scatter updates: the fill value reshaped to 1x1 (%12 is an identity reshape
// of %11).
%10 = stablehlo.reshape %1 : (tensor<i64>) -> tensor<1xi64>
%11 = stablehlo.reshape %10 : (tensor<1xi64>) -> tensor<1x1xi64>
%12 = stablehlo.reshape %11 : (tensor<1x1xi64>) -> tensor<1x1xi64>
// The combiner returns the update value (%arg4), so existing elements are
// overwritten, not accumulated. With index_vector_dim = 1 and
// scatter_dims_to_operand_dims = [0, 1], each row of %9 is a full
// (dim0, dim1) coordinate into %0.
%13 = "stablehlo.scatter"(%0, %9, %12) ({
^bb0(%arg3: tensor<i64>, %arg4: tensor<i64>):
stablehlo.return %arg4 : tensor<i64>
}) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<1x4xi64>, tensor<1x2xi64>, tensor<1x1xi64>) -> tensor<1x4xi64>
// Bridge the scattered result back to a torch value-tensor.
%14 = torch_c.from_builtin_tensor %13 : tensor<1x4xi64> -> !torch.vtensor<[1,4],si64>
return %14 : !torch.vtensor<[1,4],si64>
}
}
StableHLO (FileCheck expectations)
// Expected StableHLO lowering for aten._index_put_impl with a zero-rank index.
// NOTE(review): the label expects a function named
// @torch.aten._index_put_impl.zerorank; make sure the input function in the
// test file carries that name.
// NOTE(review): VAL_10 and VAL_12 reshape VAL_4 (the fill value, from the third
// argument), not VAL_8 (the index, from the second argument). These lines
// therefore pin the scatter coordinate to (fill, fill) and leave VAL_8 and the
// zero constant VAL_9 unused. Confirm this is the intended lowering before
// relying on the test.
// CHECK-LABEL: func.func @torch.aten._index_put_impl.zerorank(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4],si64>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],si64>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,4],si64> {
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4],si64> -> tensor<1x4xi64>
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],si64> -> tensor<i64>
// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
// CHECK: %[[VAL_6:.*]] = torch.constant.none
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_1]] : (!torch.none, !torch.vtensor<[],si64>) -> !torch.list<optional<vtensor>>
// CHECK: %[[VAL_8:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[],si64> -> tensor<i64>
// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0> : tensor<i64>
// CHECK: %[[VAL_10:.*]] = stablehlo.reshape %[[VAL_4]] : (tensor<i64>) -> tensor<1xi64>
// CHECK: %[[VAL_11:.*]] = stablehlo.reshape %[[VAL_10]] : (tensor<1xi64>) -> tensor<1x1xi64>
// CHECK: %[[VAL_12:.*]] = stablehlo.reshape %[[VAL_4]] : (tensor<i64>) -> tensor<1xi64>
// CHECK: %[[VAL_13:.*]] = stablehlo.reshape %[[VAL_12]] : (tensor<1xi64>) -> tensor<1x1xi64>
// CHECK: %[[VAL_14:.*]] = stablehlo.concatenate %[[VAL_11]], %[[VAL_13]], dim = 1 : (tensor<1x1xi64>, tensor<1x1xi64>) -> tensor<1x2xi64>
// CHECK: %[[VAL_15:.*]] = stablehlo.reshape %[[VAL_4]] : (tensor<i64>) -> tensor<1xi64>
// CHECK: %[[VAL_16:.*]] = stablehlo.reshape %[[VAL_15]] : (tensor<1xi64>) -> tensor<1x1xi64>
// CHECK: %[[VAL_17:.*]] = stablehlo.reshape %[[VAL_16]] : (tensor<1x1xi64>) -> tensor<1x1xi64>
// CHECK: %[[VAL_18:.*]] = "stablehlo.scatter"(%[[VAL_3]], %[[VAL_14]], %[[VAL_17]]) ({
// CHECK: ^bb0(%[[VAL_19:.*]]: tensor<i64>, %[[VAL_20:.*]]: tensor<i64>):
// CHECK: stablehlo.return %[[VAL_20]] : tensor<i64>
// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<1x4xi64>, tensor<1x2xi64>, tensor<1x1xi64>) -> tensor<1x4xi64>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x4xi64> -> !torch.vtensor<[1,4],si64>
// CHECK: return %[[VAL_21]] : !torch.vtensor<[1,4],si64>
// CHECK: }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
torch-mlir-opt -convert-torch-to-tosa ./t5small/test_indexput_zerorank.mlir
TOSA