@AmosLewis
Created April 20, 2023 19:19
func.func @torch.aten._index_put_impl(%input: !torch.vtensor<[1,4],si64>, %index: !torch.vtensor<[3],si64>, %fillValues: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,4],si64> {
  %false = torch.constant.bool false
  %none = torch.constant.none
  %indices = torch.prim.ListConstruct %none, %index : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
  %out = torch.aten._index_put_impl %input, %indices, %fillValues, %false, %false : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
  return %out : !torch.vtensor<[1,4],si64>
}
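
For what this op computes, a minimal PyTorch sketch (my illustration, not part of the gist): the None entry in the index list means "every position along dim 0", and the rank-0 fill value broadcasts across all selected positions.

import torch

x = torch.arange(4, dtype=torch.int64).reshape(1, 4)  # input: [[0, 1, 2, 3]]
idx = torch.tensor([0, 2, 3])                         # index into dim 1
v = torch.tensor(9, dtype=torch.int64)                # rank-0 fillValues

x[:, idx] = v        # advanced-indexing assignment; dispatches to index_put
print(x)             # tensor([[9, 1, 2, 9]])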

AmosLewis commented Apr 20, 2023

torch-mlir-opt -convert-torch-to-tosa ./t5small/test_indexput_zerorank.mlir
TOSA output:

module {
  func.func @torch.aten._index_put_impl(%arg0: !torch.vtensor<[1,4],si64>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,4],si64> {
    %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4],si64> -> tensor<1x4xi64>
    %1 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[],si64> -> tensor<i64>
    %false = torch.constant.bool false
    %none = torch.constant.none
    %2 = torch.prim.ListConstruct %none, %arg1 : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
    %3 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[3],si64> -> tensor<3xi64>
    %4 = "tosa.const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32>
    %5 = "tosa.reshape"(%4) {new_shape = array<i64: 3, 1>} : (tensor<3xi32>) -> tensor<3x1xi32>
    %6 = "tosa.cast"(%3) : (tensor<3xi64>) -> tensor<3xi32>
    %7 = "tosa.reshape"(%6) {new_shape = array<i64: 3, 1>} : (tensor<3xi32>) -> tensor<3x1xi32>
    %8 = "tosa.concat"(%5, %7) {axis = 1 : i64} : (tensor<3x1xi32>, tensor<3x1xi32>) -> tensor<3x2xi32>
    %9 = "tosa.reshape"(%1) {new_shape = array<i64: 1>} : (tensor<i64>) -> tensor<1xi64>
    %10 = "tosa.tile"(%9) {multiples = array<i64: 3>} : (tensor<1xi64>) -> tensor<3xi64>
    %11 = "tosa.reshape"(%10) {new_shape = array<i64: 1, 3>} : (tensor<3xi64>) -> tensor<1x3xi64>
    %12 = "tosa.reshape"(%11) {new_shape = array<i64: 1, 3, 1>} : (tensor<1x3xi64>) -> tensor<1x3x1xi64>
    %13 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 4, 1>} : (tensor<1x4xi64>) -> tensor<1x4x1xi64>
    %14 = "tosa.reshape"(%8) {new_shape = array<i64: 3, 2>} : (tensor<3x2xi32>) -> tensor<3x2xi32>
    %15 = "tosa.const"() {value = dense<[4, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
    %16 = "tosa.mul"(%14, %15) {shift = 0 : i32} : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x2xi32>
    %17 = "tosa.reduce_sum"(%16) {axis = 1 : i64} : (tensor<3x2xi32>) -> tensor<3x1xi32>
    %18 = "tosa.reshape"(%17) {new_shape = array<i64: 1, 3>} : (tensor<3x1xi32>) -> tensor<1x3xi32>
    %19 = "tosa.scatter"(%13, %18, %12) : (tensor<1x4x1xi64>, tensor<1x3xi32>, tensor<1x3x1xi64>) -> tensor<1x4x1xi64>
    %20 = "tosa.reshape"(%19) {new_shape = array<i64: 1, 4>} : (tensor<1x4x1xi64>) -> tensor<1x4xi64>
    %21 = torch_c.from_builtin_tensor %20 : tensor<1x4xi64> -> !torch.vtensor<[1,4],si64>
    return %21 : !torch.vtensor<[1,4],si64>
  }
}
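
The arithmetic in %4 through %18 linearizes multi-dimensional coordinates into the flat per-batch indices that tosa.scatter expects. A NumPy sketch of that step (my own, assuming tosa.scatter writes each update row into the data tensor at the given index): per-element (dim0, dim1) coordinates are concatenated, multiplied by the row-major strides [4, 1] of the 1x4 input, and reduce-summed.

import numpy as np

inp = np.array([[0, 1, 2, 3]], dtype=np.int64)        # %0: tensor<1x4xi64>
index = np.array([0, 2, 3], dtype=np.int64)           # %3: tensor<3xi64>
fill = np.int64(9)                                    # %1: rank-0 fill value

zeros = np.zeros((3, 1), dtype=np.int32)              # %5: dim-0 coords (all 0)
coords = np.concatenate([zeros, index.reshape(3, 1).astype(np.int32)], axis=1)  # %8
strides = np.array([4, 1], dtype=np.int32)            # %15: row-major strides
flat = (coords * strides).sum(axis=1)                 # %16/%17: [0, 2, 3]

out = inp.reshape(1, 4, 1).copy()                     # %13: scatter data
out[0, flat, 0] = fill                                # %19: scatter tiled fills
print(out.reshape(1, 4))                              # [[9 1 2 9]]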


AmosLewis commented Aug 1, 2023

StableHLO lowering:

func.func @torch.aten._index_put_impl(%input: !torch.vtensor<[1,4],si64>, %index: !torch.vtensor<[],si64>, %fillValues: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,4],si64> {
  %false = torch.constant.bool false
  %none = torch.constant.none
  %indices = torch.prim.ListConstruct %none, %index : (!torch.none, !torch.vtensor<[],si64>) -> !torch.list<optional<vtensor>>
  %out = torch.aten._index_put_impl %input, %indices, %fillValues, %false, %false : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
  return %out : !torch.vtensor<[1,4],si64>
}
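
In PyTorch terms (my illustration, not part of the gist), this zero-rank variant selects and overwrites a single column before the lowering shown below:

import torch

x = torch.arange(4, dtype=torch.int64).reshape(1, 4)
idx = torch.tensor(2)                  # rank-0 index
v = torch.tensor(9, dtype=torch.int64)
x[:, idx] = v
print(x)                               # tensor([[0, 1, 9, 3]])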

--->

module {
  func.func @torch.aten._index_put_impl(%arg0: !torch.vtensor<[1,4],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,4],si64> {
    %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4],si64> -> tensor<1x4xi64>
    %1 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[],si64> -> tensor<i64>
    %false = torch.constant.bool false
    %none = torch.constant.none
    %2 = torch.prim.ListConstruct %none, %arg1 : (!torch.none, !torch.vtensor<[],si64>) -> !torch.list<optional<vtensor>>
    %3 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[],si64> -> tensor<i64>
    %4 = stablehlo.constant dense<0> : tensor<i64>
    %5 = stablehlo.reshape %1 : (tensor<i64>) -> tensor<1xi64>
    %6 = stablehlo.reshape %5 : (tensor<1xi64>) -> tensor<1x1xi64>
    %7 = stablehlo.reshape %1 : (tensor<i64>) -> tensor<1xi64>
    %8 = stablehlo.reshape %7 : (tensor<1xi64>) -> tensor<1x1xi64>
    %9 = stablehlo.concatenate %6, %8, dim = 1 : (tensor<1x1xi64>, tensor<1x1xi64>) -> tensor<1x2xi64>
    %10 = stablehlo.reshape %1 : (tensor<i64>) -> tensor<1xi64>
    %11 = stablehlo.reshape %10 : (tensor<1xi64>) -> tensor<1x1xi64>
    %12 = stablehlo.reshape %11 : (tensor<1x1xi64>) -> tensor<1x1xi64>
    %13 = "stablehlo.scatter"(%0, %9, %12) ({
    ^bb0(%arg3: tensor<i64>, %arg4: tensor<i64>):
      stablehlo.return %arg4 : tensor<i64>
    }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<1x4xi64>, tensor<1x2xi64>, tensor<1x1xi64>) -> tensor<1x4xi64>
    %14 = torch_c.from_builtin_tensor %13 : tensor<1x4xi64> -> !torch.vtensor<[1,4],si64>
    return %14 : !torch.vtensor<[1,4],si64>
  }
}
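
Reading the scatter_dimension_numbers above (my own NumPy paraphrase, not a full stablehlo.scatter implementation): with index_vector_dim = 1 and scatter_dims_to_operand_dims = [0, 1], each row of the 1x2 indices tensor is a complete (dim0, dim1) coordinate into the 1x4 operand, and the matching 1x1 update row overwrites that element because the region simply returns the update value.

import numpy as np

def scatter_overwrite(operand, indices, updates):
    # Overwrite semantics: the scatter region is `stablehlo.return %arg4`,
    # i.e. the update value replaces the operand value at each coordinate.
    out = operand.copy()
    for coord, upd in zip(indices, updates):
        out[tuple(coord)] = upd[0]
    return out

operand = np.array([[0, 1, 2, 3]], dtype=np.int64)
indices = np.array([[0, 2]], dtype=np.int64)   # one (row, col) coordinate
updates = np.array([[9]], dtype=np.int64)      # one 1-element update window
print(scatter_overwrite(operand, indices, updates))  # [[0 1 9 3]]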


AmosLewis commented Aug 1, 2023

StableHLO FileCheck test:

// CHECK-LABEL:   func.func @torch.aten._index_put_impl.zerorank(
// CHECK-SAME:                                                   %[[VAL_0:.*]]: !torch.vtensor<[1,4],si64>,
// CHECK-SAME:                                                   %[[VAL_1:.*]]: !torch.vtensor<[],si64>,
// CHECK-SAME:                                                   %[[VAL_2:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,4],si64> {
// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4],si64> -> tensor<1x4xi64>
// CHECK:           %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],si64> -> tensor<i64>
// CHECK:           %[[VAL_5:.*]] = torch.constant.bool false
// CHECK:           %[[VAL_6:.*]] = torch.constant.none
// CHECK:           %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_1]] : (!torch.none, !torch.vtensor<[],si64>) -> !torch.list<optional<vtensor>>
// CHECK:           %[[VAL_8:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[],si64> -> tensor<i64>
// CHECK:           %[[VAL_9:.*]] = stablehlo.constant dense<0> : tensor<i64>
// CHECK:           %[[VAL_10:.*]] = stablehlo.reshape %[[VAL_4]] : (tensor<i64>) -> tensor<1xi64>
// CHECK:           %[[VAL_11:.*]] = stablehlo.reshape %[[VAL_10]] : (tensor<1xi64>) -> tensor<1x1xi64>
// CHECK:           %[[VAL_12:.*]] = stablehlo.reshape %[[VAL_4]] : (tensor<i64>) -> tensor<1xi64>
// CHECK:           %[[VAL_13:.*]] = stablehlo.reshape %[[VAL_12]] : (tensor<1xi64>) -> tensor<1x1xi64>
// CHECK:           %[[VAL_14:.*]] = stablehlo.concatenate %[[VAL_11]], %[[VAL_13]], dim = 1 : (tensor<1x1xi64>, tensor<1x1xi64>) -> tensor<1x2xi64>
// CHECK:           %[[VAL_15:.*]] = stablehlo.reshape %[[VAL_4]] : (tensor<i64>) -> tensor<1xi64>
// CHECK:           %[[VAL_16:.*]] = stablehlo.reshape %[[VAL_15]] : (tensor<1xi64>) -> tensor<1x1xi64>
// CHECK:           %[[VAL_17:.*]] = stablehlo.reshape %[[VAL_16]] : (tensor<1x1xi64>) -> tensor<1x1xi64>
// CHECK:           %[[VAL_18:.*]] = "stablehlo.scatter"(%[[VAL_3]], %[[VAL_14]], %[[VAL_17]]) ({
// CHECK:           ^bb0(%[[VAL_19:.*]]: tensor<i64>, %[[VAL_20:.*]]: tensor<i64>):
// CHECK:             stablehlo.return %[[VAL_20]] : tensor<i64>
// CHECK:           }) {indices_are_sorted = false, scatter_dimension_numbers = #[[?]]<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<1x4xi64>, tensor<1x2xi64>, tensor<1x1xi64>) -> tensor<1x4xi64>
// CHECK:           %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x4xi64> -> !torch.vtensor<[1,4],si64>
// CHECK:           return %[[VAL_21]] : !torch.vtensor<[1,4],si64>
// CHECK:         }
