Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active October 6, 2022 17:48
Show Gist options
  • Save AmosLewis/21d951947609c60ce967dc2e6b6dd748 to your computer and use it in GitHub Desktop.
Save AmosLewis/21d951947609c60ce967dc2e6b6dd748 to your computer and use it in GitHub Desktop.
ElementwiseAtenWhereSelfModule.mlir
// Alias for the unknown location; referenced by the return op, the function and the module below.
#loc0 = loc(unknown)
// Single-function module; the attribute records the debug module name "ElementwiseAtenWhereSelfModule".
module attributes {torch.debug_module_name = "ElementwiseAtenWhereSelfModule"} {
// @forward takes an i1 tensor (1x1x5x5), an f32 tensor (1x12x5x5) and an f32 tensor with one
// dynamic dimension, and returns a 1x12x5x5 f32 tensor.
func.func @forward(%arg0: tensor<1x1x5x5xi1> loc(unknown), %arg1: tensor<1x12x5x5xf32> loc(unknown), %arg2: tensor<?xf32> loc(unknown)) -> tensor<1x12x5x5xf32> {
// tosa.select with %arg0 as the condition operand, %arg1 as the on-true operand and %arg2 as the
// on-false operand (operand roles per the TOSA dialect definition of select).
// NOTE(review): %arg2 is rank-1 with a dynamic extent while the other operands and the result are
// rank-4 -- confirm tosa.select accepts this mix. The lowered output quoted later in this gist
// shows tensor<f32> for this argument, so the IR that was actually lowered may differ from this one.
%0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<?xf32>) -> tensor<1x12x5x5xf32> loc(#loc1)
return %0 : tensor<1x12x5x5xf32> loc(#loc0)
} loc(#loc0)
} loc(#loc0)
// Source position attached to the tosa.select op: line 150, column 15 of elementwise.py.
#loc1 = loc("/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/elementwise.py":150:15)
@AmosLewis
Copy link
Author

(mlir_venv) nod% torch-mlir-opt -pass-pipeline='func.func(tosa-to-linalg)' /tmp/ElementwiseAtenWhereSelfModule.mlir  | externals/llvm-project/mlir/utils/generate-test-checks.py
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py

// The script is designed to make adding checks to
// a test case fast, it is *not* designed to be authoritative
// about what constitutes a good test! The CHECK should be
// minimized and named to reflect the test intent.



// CHECK-LABEL:   func.func @forward(
// CHECK-SAME:                       %[[VAL_0:.*]]: tensor<1x1x5x5xi1>,
// CHECK-SAME:                       %[[VAL_1:.*]]: tensor<1x12x5x5xf32>,
// CHECK-SAME:                       %[[VAL_2:.*]]: tensor<f32>) -> tensor<1x12x5x5xf32> {
// CHECK:           %[[VAL_3:.*]] = linalg.init_tensor [1, 12, 5, 5] : tensor<1x12x5x5xf32>
// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0, 1], [2], [3]] : tensor<1x1x5x5xi1> into tensor<1x5x5xi1>
// CHECK:           %[[VAL_5:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_4]], %[[VAL_1]], %[[VAL_2]] : tensor<1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) outs(%[[VAL_3]] : tensor<1x12x5x5xf32>) {
// CHECK:           ^bb0(%[[VAL_6:.*]]: i1, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32):
// CHECK:             %[[VAL_10:.*]] = arith.select %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : f32
// CHECK:             linalg.yield %[[VAL_10]] : f32
// CHECK:           } -> tensor<1x12x5x5xf32>
// CHECK:           return %[[VAL_11:.*]] : tensor<1x12x5x5xf32>
// CHECK:         }

(mlir_venv) nod% 

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment