
@AmosLewis
Created December 15, 2022 21:30
func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32> {
%int-1 = torch.constant.int -1
%false = torch.constant.bool false
%0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[1,4,3],f32>, !torch.int, !torch.vtensor<[1,4,2],si64>, !torch.bool -> !torch.vtensor<[1,4,2],f32>
return %0 : !torch.vtensor<[1,4,2],f32>
}
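
For reference, a minimal PyTorch sketch of what this IR computes (an assumption based on the op semantics, not part of the gist): the shapes and dim = -1 match the function above, and the %false operand is the sparse_grad flag.

import torch

x = torch.randn(1, 4, 3)                              # matches !torch.vtensor<[1,4,3],f32>
idx = torch.randint(0, 3, (1, 4, 2))                  # int64 indices, matches <[1,4,2],si64>
out = torch.gather(x, dim=-1, index=idx, sparse_grad=False)
print(out.shape)                                      # torch.Size([1, 4, 2])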

AmosLewis commented Dec 15, 2022

torch-mlir-opt -convert-torch-to-tosa /tmp/gathernd_torch.mlir   | externals/llvm-project/mlir/utils/generate-test-checks.py
// CHECK-LABEL:   func.func @torch.aten.gather(
// CHECK-SAME:                                 %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,
// CHECK-SAME:                                 %[[VAL_1:.*]]: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32> {
// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4,3],f32> -> tensor<1x4x3xf32>
// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64>
// CHECK:           %[[VAL_4:.*]] = torch.constant.int -1
// CHECK:           %[[VAL_5:.*]] = torch.constant.bool false
// CHECK:           %[[VAL_6:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [1, 4, 2, 1]} : (tensor<1x4x2xi64>) -> tensor<1x4x2x1xi64>
// CHECK:           %[[VAL_7:.*]] = "tosa.cast"(%[[VAL_6]]) : (tensor<1x4x2x1xi64>) -> tensor<1x4x2x1xi32>
// CHECK:           %[[VAL_8:.*]] = "tosa.const"() {value = dense<0> : tensor<1x4x2x1xi32>} : () -> tensor<1x4x2x1xi32>
// CHECK:           %[[VAL_9:.*]] = "tosa.const"() {value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>} : () -> tensor<1x4x2x1xi32>
// CHECK:           %[[VAL_10:.*]] = "tosa.concat"(%[[VAL_8]], %[[VAL_9]], %[[VAL_7]]) {axis = 3 : i64} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32>
// CHECK:           %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = [1, 12, 1]} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32>
// CHECK:           %[[VAL_12:.*]] = "tosa.reshape"(%[[VAL_10]]) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32>
// CHECK:           %[[VAL_13:.*]] = "tosa.const"() {value = dense<[12, 3, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK:           %[[VAL_14:.*]] = "tosa.mul"(%[[VAL_12]], %[[VAL_13]]) {shift = 0 : i32} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32>
// CHECK:           %[[VAL_15:.*]] = "tosa.reduce_sum"(%[[VAL_14]]) {axis = 1 : i64} : (tensor<8x3xi32>) -> tensor<8x1xi32>
// CHECK:           %[[VAL_16:.*]] = "tosa.reshape"(%[[VAL_15]]) {new_shape = [1, 8]} : (tensor<8x1xi32>) -> tensor<1x8xi32>
// CHECK:           %[[VAL_17:.*]] = "tosa.gather"(%[[VAL_11]], %[[VAL_16]]) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32>
// CHECK:           %[[VAL_18:.*]] = "tosa.reshape"(%[[VAL_17]]) {new_shape = [1, 4, 2]} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32>
// CHECK:           %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
// CHECK:           return %[[VAL_19]] : !torch.vtensor<[1,4,2],f32>
// CHECK:         }
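
For context, CHECK lines like the ones above are normally pasted into a lit test alongside a RUN header; one plausible header (the exact file layout and flags are assumptions, not shown in this gist) would be:

// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s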


AmosLewis commented Dec 16, 2022

➜  torch-mlir git:(gather-deberta) ✗ torch-mlir-opt -convert-torch-to-tosa /tmp/gathernd_torch.mlir   
module {
  func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32> {
    %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4,3],f32> -> tensor<1x4x3xf32>
    %1 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64>
    %int-1 = torch.constant.int -1
    %false = torch.constant.bool false
    %2 = "tosa.reshape"(%1) {new_shape = [1, 4, 2, 1]} : (tensor<1x4x2xi64>) -> tensor<1x4x2x1xi64>
    %3 = "tosa.cast"(%2) : (tensor<1x4x2x1xi64>) -> tensor<1x4x2x1xi32>
    %4 = "tosa.const"() {value = dense<0> : tensor<1x4x2x1xi32>} : () -> tensor<1x4x2x1xi32>
    %5 = "tosa.const"() {value = dense<[[[[0], [0]], [[1], [1]], [[2], [2]], [[3], [3]]]]> : tensor<1x4x2x1xi32>} : () -> tensor<1x4x2x1xi32>
    %6 = "tosa.concat"(%4, %5, %3) {axis = 3 : i64} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32>
    %7 = "tosa.reshape"(%0) {new_shape = [1, 12, 1]} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32>
    %8 = "tosa.reshape"(%6) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32>
    %9 = "tosa.const"() {value = dense<[12, 3, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
    %10 = "tosa.mul"(%8, %9) {shift = 0 : i32} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32>
    %11 = "tosa.reduce_sum"(%10) {axis = 1 : i64} : (tensor<8x3xi32>) -> tensor<8x1xi32>
    %12 = "tosa.reshape"(%11) {new_shape = [1, 8]} : (tensor<8x1xi32>) -> tensor<1x8xi32>
    %13 = "tosa.gather"(%7, %12) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32>
    %14 = "tosa.reshape"(%13) {new_shape = [1, 4, 2]} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32>
    %15 = torch_c.from_builtin_tensor %14 : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
    return %15 : !torch.vtensor<[1,4,2],f32>
  }
}
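
The lowering above flattens the gather: it builds full (batch, row, col) coordinates, linearizes them with the input strides [12, 3, 1], and performs a single tosa.gather on the input reshaped to tensor<1x12x1xf32>. Below is a NumPy sketch of that index arithmetic (the concrete idx values are made up for illustration; it mirrors the IR step by step).

import numpy as np

x = np.arange(12, dtype=np.float32).reshape(1, 4, 3)        # input  [1,4,3]
idx = np.array([[[0, 2], [1, 0], [2, 2], [0, 1]]])          # gather indices [1,4,2] along dim -1

b = np.zeros((1, 4, 2, 1), dtype=np.int32)                  # batch coordinate ("tosa.const" of zeros)
r = np.broadcast_to(np.arange(4).reshape(1, 4, 1, 1),
                    (1, 4, 2, 1)).astype(np.int32)          # row coordinate (the [[0],[0]],[[1],[1]],... constant)
c = idx.reshape(1, 4, 2, 1).astype(np.int32)                # column coordinate (the original indices)

coords = np.concatenate([b, r, c], axis=3).reshape(8, 3)    # "tosa.concat" + "tosa.reshape" -> [8,3]
flat_idx = (coords * np.array([12, 3, 1])).sum(axis=1)      # "tosa.mul" by strides + "tosa.reduce_sum"
out = x.reshape(1, 12, 1)[0, flat_idx, :].reshape(1, 4, 2)  # "tosa.gather" on flattened input + final reshape

print(np.allclose(out, np.take_along_axis(x, idx, axis=-1)))  # True: matches gather along the last dim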
