@AmosLewis
Created December 5, 2022 06:54
func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.vtensor<[1,4,2,3],si64>) -> !torch.vtensor<[1,4,2],f32> {
%int-1 = torch.constant.int -1
%false = torch.constant.bool false
%0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[1,4,3],f32>, !torch.int, !torch.vtensor<[1,4,2,3],si64>, !torch.bool -> !torch.vtensor<[1,4,2],f32>
return %0 : !torch.vtensor<[1,4,2],f32>
}
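
Note the unusual signature: with standard torch.gather the index tensor has the same rank as the input and the result takes the index's shape, whereas here %arg1 carries an extra trailing dimension of size 3 and the result is [1,4,2]. As the comment below says, these are TF-style indices: the last dimension holds a full coordinate into %arg0, as in tf.gather_nd. A minimal NumPy sketch of those semantics (the names and random data are illustrative assumptions, not part of torch-mlir):

import numpy as np

params = np.arange(1 * 4 * 3, dtype=np.float32).reshape(1, 4, 3)
rng = np.random.default_rng(0)
# indices has shape (1, 4, 2, 3); each trailing 3-vector is a full
# (d0, d1, d2) coordinate into params, tf.gather_nd style.
indices = np.stack(
    [rng.integers(0, s, size=(1, 4, 2)) for s in params.shape], axis=-1
)

out = np.empty(indices.shape[:-1], dtype=params.dtype)  # shape (1, 4, 2)
for pos in np.ndindex(*out.shape):
    out[pos] = params[tuple(indices[pos])]
print(out.shape)  # (1, 4, 2)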
AmosLewis commented Dec 5, 2022

These are TF-style (tf.gather_nd) indices: the last dimension of %arg1 holds a full coordinate into %arg0. The lowering to TOSA:

module {
  func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.vtensor<[1,4,2,3],si64>) -> !torch.vtensor<[1,4,2],f32> {
    %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4,3],f32> -> tensor<1x4x3xf32>
    %1 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,4,2,3],si64> -> tensor<1x4x2x3xi64>
    %int-1 = torch.constant.int -1
    %false = torch.constant.bool false
    %2 = "tosa.reshape"(%0) {new_shape = [1, 12, 1]} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32>
    %3 = "tosa.reshape"(%1) {new_shape = [8, 3]} : (tensor<1x4x2x3xi64>) -> tensor<8x3xi64>
    %4 = "tosa.const"() {value = dense<[12, 3, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
    %5 = "tosa.mul"(%3, %4) {shift = 0 : i32} : (tensor<8x3xi64>, tensor<3xi32>) -> tensor<8x3xi64>
    %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi64>) -> tensor<?x?xi64>
    %7 = "tosa.reshape"(%6) {new_shape = [1, 8]} : (tensor<?x?xi64>) -> tensor<1x8xi64>
    %8 = "tosa.cast"(%7) : (tensor<1x8xi64>) -> tensor<1x8xi32>
    %9 = "tosa.gather"(%2, %8) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32>
    %10 = "tosa.reshape"(%9) {new_shape = [1, 4, 2]} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32>
    %11 = torch_c.from_builtin_tensor %10 : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
    return %11 : !torch.vtensor<[1,4,2],f32>
  }
}

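The lowering reshapes the params to (N=1, K=12, C=1), collapses the coordinate indices to (8, 3), turns each coordinate into a flat offset using the row-major strides [12, 3, 1] (the tosa.const, tosa.mul, and tosa.reduce_sum ops), then applies tosa.gather and reshapes the result back to [1,4,2]. A NumPy sketch of the same arithmetic, under the illustrative setup from the sketch above (again, the names and data are assumptions, not torch-mlir code):

import numpy as np

params = np.arange(1 * 4 * 3, dtype=np.float32).reshape(1, 4, 3)
rng = np.random.default_rng(0)
indices = np.stack(
    [rng.integers(0, s, size=(1, 4, 2)) for s in params.shape], axis=-1
)

# tosa.reshape %0: view params as (N=1, K=12, C=1).
flat_params = params.reshape(1, 12, 1)
# tosa.reshape %1: collapse the coordinate indices to (W=8, 3).
flat_idx = indices.reshape(8, 3)
# tosa.const + tosa.mul + tosa.reduce_sum: the row-major strides of
# shape (1, 4, 3) are (12, 3, 1), so the dot product is a flat offset into K.
strides = np.array([12, 3, 1])
linear = (flat_idx * strides).sum(axis=1)   # shape (8,)
# tosa.gather: pick rows of flat_params along K using the linear indices.
gathered = flat_params[0, linear, :]        # shape (8, 1)
# final tosa.reshape: restore the (1, 4, 2) output shape.
out = gathered.reshape(1, 4, 2)

# The result matches a direct coordinate lookup.
ref = np.empty((1, 4, 2), dtype=params.dtype)
for pos in np.ndindex(*ref.shape):
    ref[pos] = params[tuple(indices[pos])]
assert np.array_equal(out, ref)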
The corresponding FileCheck test pattern for this lowering:

// CHECK-LABEL:   func.func @torch.aten.gather(
// CHECK-SAME:                                 %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,
// CHECK-SAME:                                 %[[VAL_1:.*]]: !torch.vtensor<[1,4,2,3],si64>) -> !torch.vtensor<[1,4,2],f32> {
// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4,3],f32> -> tensor<1x4x3xf32>
// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,4,2,3],si64> -> tensor<1x4x2x3xi64>
// CHECK:           %[[VAL_4:.*]] = torch.constant.int -1
// CHECK:           %[[VAL_5:.*]] = torch.constant.bool false
// CHECK:           %[[VAL_6:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = [1, 12, 1]} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32>
// CHECK:           %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [8, 3]} : (tensor<1x4x2x3xi64>) -> tensor<8x3xi64>
// CHECK:           %[[VAL_8:.*]] = "tosa.const"() {value = dense<[12, 3, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK:           %[[VAL_9:.*]] = "tosa.mul"(%[[VAL_7]], %[[VAL_8]]) {shift = 0 : i32} : (tensor<8x3xi64>, tensor<3xi32>) -> tensor<8x3xi64>
// CHECK:           %[[VAL_10:.*]] = "tosa.reduce_sum"(%[[VAL_9]]) {axis = 1 : i64} : (tensor<8x3xi64>) -> tensor<?x?xi64>
// CHECK:           %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_10]]) {new_shape = [1, 8]} : (tensor<?x?xi64>) -> tensor<1x8xi64>
// CHECK:           %[[VAL_12:.*]] = "tosa.cast"(%[[VAL_11]]) : (tensor<1x8xi64>) -> tensor<1x8xi32>
// CHECK:           %[[VAL_13:.*]] = "tosa.gather"(%[[VAL_6]], %[[VAL_12]]) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32>
// CHECK:           %[[VAL_14:.*]] = "tosa.reshape"(%[[VAL_13]]) {new_shape = [1, 4, 2]} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32>
// CHECK:           %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
// CHECK:           return %[[VAL_15]] : !torch.vtensor<[1,4,2],f32>
// CHECK:         }

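Patterns like this typically live in torch-mlir's TorchToTosa lit tests (for example test/Conversion/TorchToTosa/basic.mlir), where a RUN line pipes the file through torch-mlir-opt with the -convert-torch-to-tosa pass into FileCheck; the exact test location and RUN line are assumptions here, not taken from the gist.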