Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created April 20, 2023 19:19
Show Gist options
  • Save AmosLewis/41081e79cb30b4ef8ec736b0bc8f6bf3 to your computer and use it in GitHub Desktop.
Save AmosLewis/41081e79cb30b4ef8ec736b0bc8f6bf3 to your computer and use it in GitHub Desktop.
// Reproducer for torch.aten._index_put_impl on a [1,4] si64 tensor.
// Dimension 0 is left unindexed (a `none` entry in the index list) and
// dimension 1 is indexed by the 3-element si64 tensor %idx. The values written
// come from the zero-rank si64 tensor %value. Both trailing bool operands are
// false (presumably `accumulate` and `unsafe` in the aten schema — confirm
// against the op definition).
func.func @torch.aten._index_put_impl(
    %src: !torch.vtensor<[1,4],si64>,
    %idx: !torch.vtensor<[3],si64>,
    %value: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,4],si64> {
  %no_index = torch.constant.none
  %flag_false = torch.constant.bool false
  // Two-entry optional-tensor index list: `none` for dim 0, %idx for dim 1.
  %index_list = torch.prim.ListConstruct %no_index, %idx : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
  %result = torch.aten._index_put_impl %src, %index_list, %value, %flag_false, %flag_false : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
  return %result : !torch.vtensor<[1,4],si64>
}
@AmosLewis
Copy link
Author

AmosLewis commented Aug 1, 2023

StableHLO lowering (FileCheck expectations for the zero-rank-index variant):

// FileCheck expectations for lowering torch.aten._index_put_impl to StableHLO
// when the index operand is zero-rank (!torch.vtensor<[],si64>).
//
// The scatter dimension-numbers attribute is spelled out as
// #stablehlo.scatter<...>. A `#[[?]]` placeholder is not a valid FileCheck
// substitution and makes FileCheck reject the file.
//
// NOTE(review): as captured here, the scatter indices (VAL_14) are built by
// reshaping VAL_4, which is the builtin tensor of the fill value (VAL_2), not
// of the index operand. The index's builtin tensor (VAL_8) and the constant 0
// (VAL_9) are produced but never consumed. Presumably the row index should come
// from VAL_9 (dim 0 is unindexed and has size 1) and the column index from
// VAL_8. This looks like a lowering bug being pinned by the test; confirm
// before relying on these expectations.
// CHECK-LABEL:   func.func @torch.aten._index_put_impl.zerorank(
// CHECK-SAME:                                                   %[[VAL_0:.*]]: !torch.vtensor<[1,4],si64>,
// CHECK-SAME:                                                   %[[VAL_1:.*]]: !torch.vtensor<[],si64>,
// CHECK-SAME:                                                   %[[VAL_2:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,4],si64> {
// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4],si64> -> tensor<1x4xi64>
// CHECK:           %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],si64> -> tensor<i64>
// CHECK:           %[[VAL_5:.*]] = torch.constant.bool false
// CHECK:           %[[VAL_6:.*]] = torch.constant.none
// CHECK:           %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_1]] : (!torch.none, !torch.vtensor<[],si64>) -> !torch.list<optional<vtensor>>
// CHECK:           %[[VAL_8:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[],si64> -> tensor<i64>
// CHECK:           %[[VAL_9:.*]] = stablehlo.constant dense<0> : tensor<i64>
// CHECK:           %[[VAL_10:.*]] = stablehlo.reshape %[[VAL_4]] : (tensor<i64>) -> tensor<1xi64>
// CHECK:           %[[VAL_11:.*]] = stablehlo.reshape %[[VAL_10]] : (tensor<1xi64>) -> tensor<1x1xi64>
// CHECK:           %[[VAL_12:.*]] = stablehlo.reshape %[[VAL_4]] : (tensor<i64>) -> tensor<1xi64>
// CHECK:           %[[VAL_13:.*]] = stablehlo.reshape %[[VAL_12]] : (tensor<1xi64>) -> tensor<1x1xi64>
// CHECK:           %[[VAL_14:.*]] = stablehlo.concatenate %[[VAL_11]], %[[VAL_13]], dim = 1 : (tensor<1x1xi64>, tensor<1x1xi64>) -> tensor<1x2xi64>
// CHECK:           %[[VAL_15:.*]] = stablehlo.reshape %[[VAL_4]] : (tensor<i64>) -> tensor<1xi64>
// CHECK:           %[[VAL_16:.*]] = stablehlo.reshape %[[VAL_15]] : (tensor<1xi64>) -> tensor<1x1xi64>
// CHECK:           %[[VAL_17:.*]] = stablehlo.reshape %[[VAL_16]] : (tensor<1x1xi64>) -> tensor<1x1xi64>
// CHECK:           %[[VAL_18:.*]] = "stablehlo.scatter"(%[[VAL_3]], %[[VAL_14]], %[[VAL_17]]) ({
// CHECK:           ^bb0(%[[VAL_19:.*]]: tensor<i64>, %[[VAL_20:.*]]: tensor<i64>):
// CHECK:             stablehlo.return %[[VAL_20]] : tensor<i64>
// CHECK:           }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<1x4xi64>, tensor<1x2xi64>, tensor<1x1xi64>) -> tensor<1x4xi64>
// CHECK:           %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x4xi64> -> !torch.vtensor<[1,4],si64>
// CHECK:           return %[[VAL_21]] : !torch.vtensor<[1,4],si64>
// CHECK:         }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment