Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created October 6, 2022 17:36
Show Gist options
  • Save AmosLewis/32847885f8b3ff27b7ef6564154fec59 to your computer and use it in GitHub Desktop.
wheremlirrun
// Test input for lowering `torch.aten.where.self`: take elements of %input
// where %condition is true, and %other elsewhere.
// Broadcasting exercised here:
//   - %condition is [1,1,5,5] and must expand along dim 1 to [1,12,5,5];
//   - %other is a rank-0 f32 tensor that must expand to the full result shape.
// NOTE(review): in the TOSA lowering of this input, tosa.select still receives
// a rank-0 operand after tosa-make-broadcastable. Confirm whether TOSA
// verification requires equal operand ranks for tosa.select.
func.func @torch.aten.where.self(%condition: !torch.vtensor<[1,1,5,5],i1>,
                                 %input: !torch.vtensor<[1,12,5,5],f32>,
                                 %other: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> {
  %selected = torch.aten.where.self %condition, %input, %other : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,12,5,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,5,5],f32>
  return %selected : !torch.vtensor<[1,12,5,5],f32>
}
@AmosLewis
Copy link
Author

AmosLewis commented Oct 6, 2022

(mlir_venv) nod% torch-mlir-opt -pass-pipeline='torch-backend-to-tosa-backend-pipeline' /tmp/where.mlir -mlir-print-ir-after-all -mlir-pretty-debuginfo -mlir-disable-threading
// -----// IR Dump After ConvertTorchToTosa (convert-torch-to-tosa) //----- //
func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !torch.vtensor<[1,12,5,5],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> {
  %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1>
  %1 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32>
  %2 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[],f32> -> tensor<f32>
  %3 = "tosa.select"(%0, %1, %2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
  %4 = torch_c.from_builtin_tensor %3 : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32>
  return %4 : !torch.vtensor<[1,12,5,5],f32>
}

// -----// IR Dump After TosaMakeBroadcastable (tosa-make-broadcastable) //----- //
func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !torch.vtensor<[1,12,5,5],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> {
  %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1>
  %1 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32>
  %2 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[],f32> -> tensor<f32>
  %3 = "tosa.select"(%0, %1, %2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
  %4 = torch_c.from_builtin_tensor %3 : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32>
  return %4 : !torch.vtensor<[1,12,5,5],f32>
}

// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !torch.vtensor<[1,12,5,5],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> {
  %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1>
  %1 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32>
  %2 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[],f32> -> tensor<f32>
  %3 = "tosa.select"(%0, %1, %2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
  %4 = torch_c.from_builtin_tensor %3 : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32>
  return %4 : !torch.vtensor<[1,12,5,5],f32>
}

// -----// IR Dump After CSE (cse) //----- //
func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !torch.vtensor<[1,12,5,5],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> {
  %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1>
  %1 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32>
  %2 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[],f32> -> tensor<f32>
  %3 = "tosa.select"(%0, %1, %2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
  %4 = torch_c.from_builtin_tensor %3 : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32>
  return %4 : !torch.vtensor<[1,12,5,5],f32>
}

// -----// IR Dump After FuncBackendTypeConversion (torch-func-backend-type-conversion) //----- //
module {
  func.func @torch.aten.where.self(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %arg2: tensor<f32>) -> tensor<1x12x5x5xf32> {
    %0 = torch_c.from_builtin_tensor %arg2 : tensor<f32> -> !torch.vtensor<[],f32>
    %1 = torch_c.from_builtin_tensor %arg1 : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32>
    %2 = torch_c.from_builtin_tensor %arg0 : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1>
    %3 = torch_c.to_builtin_tensor %2 : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1>
    %4 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32>
    %5 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
    %6 = "tosa.select"(%3, %4, %5) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
    %7 = torch_c.from_builtin_tensor %6 : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32>
    %8 = torch_c.to_builtin_tensor %7 : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32>
    return %8 : tensor<1x12x5x5xf32>
  }
}


// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @torch.aten.where.self(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %arg2: tensor<f32>) -> tensor<1x12x5x5xf32> {
  %0 = torch_c.from_builtin_tensor %arg2 : tensor<f32> -> !torch.vtensor<[],f32>
  %1 = torch_c.from_builtin_tensor %arg1 : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32>
  %2 = torch_c.from_builtin_tensor %arg0 : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1>
  %3 = torch_c.to_builtin_tensor %2 : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1>
  %4 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32>
  %5 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
  %6 = "tosa.select"(%3, %4, %5) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
  %7 = torch_c.from_builtin_tensor %6 : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32>
  %8 = torch_c.to_builtin_tensor %7 : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32>
  return %8 : tensor<1x12x5x5xf32>
}

// -----// IR Dump After FinalizingBackendTypeConversion (torch-finalizing-backend-type-conversion) //----- //
func.func @torch.aten.where.self(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %arg2: tensor<f32>) -> tensor<1x12x5x5xf32> {
  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
  return %0 : tensor<1x12x5x5xf32>
}

// -----// IR Dump After VerifyTosaBackendContract (torch-verify-tosa-backend-contract) //----- //
module {
  func.func @torch.aten.where.self(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %arg2: tensor<f32>) -> tensor<1x12x5x5xf32> {
    %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
    return %0 : tensor<1x12x5x5xf32>
  }
}


module {
  func.func @torch.aten.where.self(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %arg2: tensor<f32>) -> tensor<1x12x5x5xf32> {
    %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
    return %0 : tensor<1x12x5x5xf32>
  }
}

(mlir_venv) nod% 


@AmosLewis
Copy link
Author

(mlir_venv) nod% torch-mlir-opt -pass-pipeline='torch-backend-to-tosa-backend-pipeline' /tmp/where.mlir  | externals/llvm-project/mlir/utils/generate-test-checks.py
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py

// The script is designed to make adding checks to
// a test case fast, it is *not* designed to be authoritative
// about what constitutes a good test! The CHECK should be
// minimized and named to reflect the test intent.



// CHECK-LABEL:   func.func @torch.aten.where.self(
// CHECK-SAME:                                     %[[VAL_0:.*]]: tensor<1x1x5x5xi1>,
// CHECK-SAME:                                     %[[VAL_1:.*]]: tensor<1x12x5x5xf32>,
// CHECK-SAME:                                     %[[VAL_2:.*]]: tensor<f32>) -> tensor<1x12x5x5xf32> {
// CHECK:           %[[VAL_3:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
// CHECK:           return %[[VAL_3]] : tensor<1x12x5x5xf32>
// CHECK:         }

(mlir_venv) nod% 

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment