Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active April 27, 2023 18:53
Show Gist options
  • Save AmosLewis/448fbe07b161389ec6b1bfbdd54b6ec5 to your computer and use it in GitHub Desktop.
// Torch-dialect reproducer for aten::_index_put_impl.
// Writes the three values of %values into the columns of the 1x4 tensor %self
// that are selected by %cols. Dim 0 is not indexed (it gets a `none` entry).
// Roughly the PyTorch equivalent of: out = self.clone(); out[:, cols] = values
func.func @torch.aten._index_put_impl(%self: !torch.vtensor<[1,4],si64>, %cols: !torch.vtensor<[3],si64>, %values: !torch.vtensor<[1,3],si64>) -> !torch.vtensor<[1,4],si64> {
  // One optional index tensor per dimension of %self:
  // `none` for dim 0 and %cols for dim 1.
  %whole_dim = torch.constant.none
  %index_list = torch.prim.ListConstruct %whole_dim, %cols : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
  // A single constant feeds both trailing flags: accumulate = false, unsafe = false.
  %flag_off = torch.constant.bool false
  %result = torch.aten._index_put_impl %self, %index_list, %values, %flag_off, %flag_off : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[1,3],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
  return %result : !torch.vtensor<[1,4],si64>
}
@AmosLewis
Copy link
Author

AmosLewis commented Apr 19, 2023

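TOSA lowering: the same input function, the `torch-mlir-opt -convert-torch-to-tosa` command, and the resulting output.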

// Torch-dialect reproducer fed to `torch-mlir-opt -convert-torch-to-tosa`.
// Writes the 3 values of %fillValues into the columns of the 1x4 %input that
// are selected by %index. Dim 0 is not indexed (`none` entry in %indices).
// SSA names are kept as-is: the conversion output echoes %false and %none.
func.func @torch.aten._index_put_impl(%input: !torch.vtensor<[1,4],si64>, %index: !torch.vtensor<[3],si64>, %fillValues: !torch.vtensor<[1,3],si64>) -> !torch.vtensor<[1,4],si64>{
  // Shared by both trailing boolean operands of _index_put_impl.
  %false = torch.constant.bool false
  %none = torch.constant.none
  // One optional index tensor per dim of %input: none for dim 0, %index for dim 1.
  %indices = torch.prim.ListConstruct %none, %index : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
  // Trailing operands are (accumulate = false, unsafe = false): values overwrite, they are not summed.
  %out = torch.aten._index_put_impl %input, %indices, %fillValues, %false, %false : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[1,3],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
  return %out : !torch.vtensor<[1,4],si64>
}

torch-mlir-opt -convert-torch-to-tosa ./t5small/test_indexput.mlir

// Output of: torch-mlir-opt -convert-torch-to-tosa ./t5small/test_indexput.mlir
//
// Lowering strategy visible below:
//   1. The index list (none, %arg1) becomes a 3x2 table of (row, col) coordinates.
//   2. Each coordinate pair is flattened to a linear offset using strides [4, 1].
//   3. A single tosa.scatter writes the three fill values into a [1, 4, 1]
//      view of the input.
module {
  func.func @torch.aten._index_put_impl(%arg0: !torch.vtensor<[1,4],si64>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[1,3],si64>) -> !torch.vtensor<[1,4],si64> {
    // Bridge the input and fill-value tensors from torch-dialect to builtin tensors.
    %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4],si64> -> tensor<1x4xi64>
    %1 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[1,3],si64> -> tensor<1x3xi64>
    // Leftover torch-dialect ops from the input IR. Nothing below uses %false
    // or %2, and %none only feeds the unused %2. A later cleanup pass would
    // presumably remove them.
    %false = torch.constant.bool false
    %none = torch.constant.none
    %2 = torch.prim.ListConstruct %none, %arg1 : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
    // Column indices as a builtin tensor.
    %3 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[3],si64> -> tensor<3xi64>
    // Dim-0 coordinate for each of the 3 updates. The `none` entry covers the
    // size-1 dim 0, so every row coordinate is the constant 0.
    %4 = "tosa.const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32>
    %5 = "tosa.reshape"(%4) {new_shape = array<i64: 3, 1>} : (tensor<3xi32>) -> tensor<3x1xi32>
    // Dim-1 coordinates: narrow the i64 indices to i32, then reshape to a column.
    // NOTE(review): no range check or negative-index wrap-around is emitted
    // between %3 and %15. A negative index (legal in PyTorch) or one outside
    // [0, 4) would produce an out-of-range scatter offset. Confirm that callers
    // guarantee in-range, non-negative indices.
    %6 = "tosa.cast"(%3) : (tensor<3xi64>) -> tensor<3xi32>
    %7 = "tosa.reshape"(%6) {new_shape = array<i64: 3, 1>} : (tensor<3xi32>) -> tensor<3x1xi32>
    // (row, col) coordinate pairs, shape 3x2.
    %8 = "tosa.concat"(%5, %7) {axis = 1 : i64} : (tensor<3x1xi32>, tensor<3x1xi32>) -> tensor<3x2xi32>
    // Reshape the input and fill values to tosa.scatter's 3-D operand layouts:
    // values_in [N=1, K=4, C=1] and input [N=1, W=3, C=1].
    %9 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 4, 1>} : (tensor<1x4xi64>) -> tensor<1x4x1xi64>
    %10 = "tosa.reshape"(%1) {new_shape = array<i64: 1, 3, 1>} : (tensor<1x3xi64>) -> tensor<1x3x1xi64>
    // Identity reshape (3x2 -> 3x2); it has no effect.
    %11 = "tosa.reshape"(%8) {new_shape = array<i64: 3, 2>} : (tensor<3x2xi32>) -> tensor<3x2xi32>
    // Row-major strides of the 1x4 input, giving offset = row * 4 + col * 1.
    %12 = "tosa.const"() {value = dense<[4, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
    // NOTE(review): %11 is rank 2 but %12 is rank 1. The TOSA spec expects
    // equal-rank operands for elementwise ops, so newer tosa verifiers may
    // reject this mul. Confirm against the MLIR version in use.
    %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x2xi32>
    // Sum each coordinate row to get its linear offset, shape 3x1.
    %14 = "tosa.reduce_sum"(%13) {axis = 1 : i64} : (tensor<3x2xi32>) -> tensor<3x1xi32>
    // Reshape the offsets to the scatter indices layout [N=1, W=3].
    %15 = "tosa.reshape"(%14) {new_shape = array<i64: 1, 3>} : (tensor<3x1xi32>) -> tensor<1x3xi32>
    // Write the 3 fill values into the flattened input at the computed offsets.
    // NOTE(review): the TOSA spec does not allow one scatter to write the same
    // location twice, so duplicate entries in %arg1 are presumably unsupported
    // by this lowering. Confirm.
    %16 = "tosa.scatter"(%9, %15, %10) : (tensor<1x4x1xi64>, tensor<1x3xi32>, tensor<1x3x1xi64>) -> tensor<1x4x1xi64>
    // Restore the [1, 4] shape and convert back to a torch-dialect tensor.
    %17 = "tosa.reshape"(%16) {new_shape = array<i64: 1, 4>} : (tensor<1x4x1xi64>) -> tensor<1x4xi64>
    %18 = torch_c.from_builtin_tensor %17 : tensor<1x4xi64> -> !torch.vtensor<[1,4],si64>
    return %18 : !torch.vtensor<[1,4],si64>
  }
}


Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment