Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created November 17, 2022 22:08
Show Gist options
  • Save AmosLewis/27208f893fec9e2fec02f2405d484b84 to your computer and use it in GitHub Desktop.
Save AmosLewis/27208f893fec9e2fec02f2405d484b84 to your computer and use it in GitHub Desktop.
// Reproducer input for `torch-mlir-opt -convert-torch-to-tosa`: a single
// torch.aten.gather on statically shaped value tensors.
//
//   %arg0  : f32 data tensor of shape [12,128,512]
//   %arg1  : si64 index tensor of shape [1,128,128]
//   result : f32 tensor declared as [12,128,128]
//
// The index element type is si64, while the tosa.gather verifier reports that
// its operand #1 must be a 2D tensor of 32-bit signless integers.
//
// NOTE(review): PyTorch documents gather's output as having the same shape as
// `index` ([1,128,128] here), but the declared result is [12,128,128] --
// confirm that this mismatch is intended for the reproducer.
func.func @torch.aten.gather(%arg0: !torch.vtensor<[12,128,512],f32>, %arg1: !torch.vtensor<[1,128,128],si64>) -> !torch.vtensor<[12,128,128],f32> {
// Gather dimension: -1 (presumably normalized to the last axis, size 512, of %arg0).
%int-1 = torch.constant.int -1
// Fourth operand of torch.aten.gather (the `sparse_grad` flag in PyTorch's
// gather signature -- TODO confirm against the op definition).
%false = torch.constant.bool false
// Operand order: data tensor, dim, index tensor, bool flag.
%0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool -> !torch.vtensor<[12,128,128],f32>
return %0 : !torch.vtensor<[12,128,128],f32>
}
@AmosLewis
Copy link
Author

(mlir_venv) nod% torch-mlir-opt -convert-torch-to-tosa  /tmp/gather.mlir

/tmp/gather.mlir:4:8: error: 'tosa.gather' op operand #1 must be 2D tensor of 32-bit signless integer values, but got 'tensor<1x128x128xi64>'
  %0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool -> !torch.vtensor<[12,128,128],f32>
       ^
/tmp/gather.mlir:4:8: note: see current operation: %4 = "tosa.gather"(%0, %1) : (tensor<12x128x512xf32>, tensor<1x128x128xi64>) -> tensor<12x128x128xf32>

@AmosLewis
Copy link
Author

//===-------------------------------------------===//
Legalizing operation : 'func.return'(0x97a9600) {
  "func.return"(%14) : (!torch.vtensor<[12,128,128],f32>) -> ()

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
** Insert  : 'torch_c.to_builtin_tensor'(0x9813730)
** Insert  : 'torch_c.from_builtin_tensor'(0x9814aa0)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::PreservedAnalyses::AllAnalysesType)
/tmp/gather.mlir:4:8: error: 'tosa.gather' op operand #1 must be 2D tensor of 32-bit signless integer values, but got 'tensor<1x786432xf32>'
  %0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool -> !torch.vtensor<[12,128,128],f32>
       ^
/tmp/gather.mlir:4:8: note: see current operation: %8 = "tosa.gather"(%6, %7) : (tensor<1x512x1536xf32>, tensor<1x786432xf32>) -> tensor<?x?x?xf32>
// -----// IR Dump After ConvertTorchToTosa Failed (convert-torch-to-tosa) //----- //
mlir-asm-printer: Verifying operation: func.func
'tosa.gather' op operand #1 must be 2D tensor of 32-bit signless integer values, but got 'tensor<1x786432xf32>'
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() ({
^bb0(%arg0: !torch.vtensor<[12,128,512],f32>, %arg1: !torch.vtensor<[1,128,128],si64>):
  %0 = "torch_c.to_builtin_tensor"(%arg0) : (!torch.vtensor<[12,128,512],f32>) -> tensor<12x128x512xf32>
  %1 = "torch.constant.int"() {value = -1 : i64} : () -> !torch.int
  %2 = "torch.constant.bool"() {value = false} : () -> !torch.bool
  %3 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
  %4 = "tosa.const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>} : () -> tensor<5xi32>
  %5 = "tosa.transpose"(%0, %3) : (tensor<12x128x512xf32>, tensor<3xi32>) -> tensor<512x12x128xf32>
  %6 = "tosa.reshape"(%5) {new_shape = [1, 512, 1536]} : (tensor<512x12x128xf32>) -> tensor<1x512x1536xf32>
  %7 = "tosa.reshape"(%0) {new_shape = [1, 786432]} : (tensor<12x128x512xf32>) -> tensor<1x786432xf32>
  %8 = "tosa.gather"(%6, %7) : (tensor<1x512x1536xf32>, tensor<1x786432xf32>) -> tensor<?x?x?xf32>
  %9 = "tosa.reshape"(%8) {new_shape = [12, 128, 512, 12, 128]} : (tensor<?x?x?xf32>) -> tensor<12x128x512x12x128xf32>
  %10 = "tosa.transpose"(%9, %4) : (tensor<12x128x512x12x128xf32>, tensor<5xi32>) -> tensor<*xf32>
  %11 = "torch_c.from_builtin_tensor"(%10) : (tensor<*xf32>) -> !torch.vtensor<[12,128,128],f32>
  "func.return"(%11) : (!torch.vtensor<[12,128,128],f32>) -> ()
}) {function_type = (!torch.vtensor<[12,128,512],f32>, !torch.vtensor<[1,128,128],si64>) -> !torch.vtensor<[12,128,128],f32>, sym_name = "torch.aten.gather"} : () -> ()

@AmosLewis
Copy link
Author

AmosLewis commented Dec 1, 2022

After passing outType directly into tosa::convertGatherOp:

➜  externals git:(mix) ✗ torch-mlir-opt -convert-torch-to-tosa /tmp/gather.mlir -mlir-print-ir-after-all -mlir-disable-threading --debug  --mlir-print-ir-before-all
Args: /home/chi/src/ubuntu20/shark/torch-mlir/build/bin/torch-mlir-opt -convert-torch-to-tosa /tmp/gather.mlir -mlir-print-ir-after-all -mlir-disable-threading --debug --mlir-print-ir-before-all 
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SubElementTypeInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemRefLayoutAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SubElementAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ElementsAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TypedAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionKindInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CastOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConditionallySpeculatable)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemoryEffectOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ResourceBlobManagerDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BytecodeDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineBinaryOpExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineConstantExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineDimExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineMapStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::IntegerSetStorage)
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DebugActionManager::GenericHandler)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroOperands<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneRegion<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroResults<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroSuccessors<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NoRegionArguments<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NoTerminator<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SingleBlock<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OpInvariants<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AffineScope<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::IsIsolatedFromAbove<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SymbolTable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionKindInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasOnlyGraphRegion<Empty>)
Load new dialect in Context func
Load new dialect in Context cf
Load new dialect in Context arith
ImplicitTypeIDRegistry::lookupOrInsert(mlir::arith::ArithFastMathInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::VectorUnrollOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferIntRangeInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DialectInlinerInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::bufferization::BufferizableOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BranchOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolUserOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallableOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::FunctionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AutomaticAllocationScope<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallableOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::FunctionOpInterface::Trait<Empty>)
Load new dialect in Context torch
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchTerminatorOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroRegions<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneResult<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::IntType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ConstantLike<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConditionallySpeculatable::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AlwaysSpeculatableImplTrait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemoryEffectOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::AllowedInModuleInitializer<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferTypeOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::BoolType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::Type>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NOperands<4>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::AllowsTypeRefinement<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::HasValueSemantics<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::ReadOnly<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::VariadicOperands<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasParent<mlir::func::FuncOp>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::MemRefsNormalizable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ReturnLike<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::IsTerminator<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::OpToOpPassAdaptor)
Load new dialect in Context tensor
Load new dialect in Context affine
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineDmaStartOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineMapAccessInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineDmaWaitOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::LoopLikeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineReadOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineWriteOpInterface)
Load new dialect in Context complex
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedDimOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ReifyRankedShapedTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OffsetSizeAndStrideOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DestinationStyleOpInterface)
Load new dialect in Context linalg
Load new dialect in Context math
Load new dialect in Context memref
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CopyOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ViewLikeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::LinalgOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::ContractionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::ConvolutionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::FillOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TilingInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::PartialReductionOpInterface)
Ignoring repeated interface registrationIgnoring repeated interface registrationLoad new dialect in Context torch_c
Load new dialect in Context tosa
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferShapedTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::tosa::TosaOp)
// -----// IR Dump Before ConvertTorchToTosa (convert-torch-to-tosa) //----- //
mlir-asm-printer: Verifying operation: func.func
func.func @torch.aten.gather(%arg0: !torch.vtensor<[12,128,512],f32>, %arg1: !torch.vtensor<[1,128,128],si64>) -> !torch.vtensor<[12,128,128],f32> {
  %int-1 = torch.constant.int -1
  %false = torch.constant.bool false
  %0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool -> !torch.vtensor<[12,128,128],f32>
  return %0 : !torch.vtensor<[12,128,128],f32>
}


//===-------------------------------------------===//
Legalizing operation : 'func.func'(0x979ff70) {
  * Fold {
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DialectFoldInterface)
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'torch.constant.int'(0x974dd00) {
  %0 = "torch.constant.int"() {value = -1 : i64} : () -> !torch.int

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'torch.constant.bool'(0x97a6a20) {
  %1 = "torch.constant.bool"() {value = false} : () -> !torch.bool

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'torch.aten.gather'(0x97a7fd0) {
  %2 = "torch.aten.gather"(%arg0, %0, %arg1, %1) : (!torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool) -> !torch.vtensor<[12,128,128],f32>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.gather -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenGatherOp>"
CHI DEBUG:   Only tensor types input are currently supported. 
CHI DEBUG:   unimplemented: value `dim` should be a torch constant int 
CHI DEBUG:   llvm::Optional<Value> result = tosa::convertGatherOp(. 
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::TensorType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::tosa::TosaOp::Trait<Empty>)
    ** Insert  : 'tosa.const'(0x97e0480)
    ** Insert  : 'tosa.const'(0x97ae360)
    ** Insert  : 'tosa.transpose'(0x974cb70)
    ** Insert  : 'tosa.reshape'(0x980eb40)
    ** Insert  : 'tosa.reshape'(0x98100c0)
    ** Insert  : 'tosa.gather'(0x9810150)
    ** Insert  : 'tosa.reshape'(0x98116f0)
CHI DEBUG:   tosa::CreateOpAndInfer<tosa::TransposeOp>. 
    ** Insert  : 'tosa.transpose'(0x9811800)
CHI DEBUG:   !result 
CHI DEBUG:   rewriter.replaceOp(op, {result.getValue()}); 
    ** Replace : 'torch.aten.gather'(0x97a7fd0)
CHI DEBUG:   success 
"(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenGatherOp>" result 1

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.const'(0x97e0480) {
      %6 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.const'(0x97ae360) {
      %7 = "tosa.const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>} : () -> tensor<5xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.transpose'(0x974cb70) {
      %8 = "tosa.transpose"(%1, %6) : (tensor<12x128x512xf32>, tensor<3xi32>) -> tensor<512x12x128xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reshape'(0x980eb40) {
      %9 = "tosa.reshape"(%8) {new_shape = [1, 512, 1536]} : (tensor<512x12x128xf32>) -> tensor<1x512x1536xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reshape'(0x98100c0) {
      %10 = "tosa.reshape"(%1) {new_shape = [1, 786432]} : (tensor<12x128x512xf32>) -> tensor<1x786432xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.gather'(0x9810150) {
      %11 = "tosa.gather"(%9, %10) : (tensor<1x512x1536xf32>, tensor<1x786432xf32>) -> tensor<?x?x?xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reshape'(0x98116f0) {
      %12 = "tosa.reshape"(%11) {new_shape = [12, 128, 512, 12, 128]} : (tensor<?x?x?xf32>) -> tensor<12x128x512x12x128xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.transpose'(0x9811800) {
      %13 = "tosa.transpose"(%12, %7) : (tensor<12x128x512x12x128xf32>, tensor<5xi32>) -> tensor<*xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//
  } -> SUCCESS : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NOperands<2>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferShapedTypeOpInterface::Trait<Empty>)
'tosa.gather' op operand #1 must be 2D tensor of 32-bit signless integer values, but got 'tensor<1x786432xf32>'
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() ({
^bb0(%arg0: !torch.vtensor<[12,128,512],f32>, %arg1: !torch.vtensor<[1,128,128],si64>):
  %0 = "builtin.unrealized_conversion_cast"(%arg1) : (!torch.vtensor<[1,128,128],si64>) -> tensor<1x128x128xi64>
  %1 = "builtin.unrealized_conversion_cast"(%arg0) : (!torch.vtensor<[12,128,512],f32>) -> tensor<12x128x512xf32>
  %2 = "torch.constant.int"() {value = -1 : i64} : () -> !torch.int
  %3 = "builtin.unrealized_conversion_cast"(%2) : (!torch.int) -> i64
  %4 = "torch.constant.bool"() {value = false} : () -> !torch.bool
  %5 = "builtin.unrealized_conversion_cast"(%4) : (!torch.bool) -> i1
  %6 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
  %7 = "tosa.const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>} : () -> tensor<5xi32>
  %8 = "tosa.transpose"(%1, %6) : (tensor<12x128x512xf32>, tensor<3xi32>) -> tensor<512x12x128xf32>
  %9 = "tosa.reshape"(%8) {new_shape = [1, 512, 1536]} : (tensor<512x12x128xf32>) -> tensor<1x512x1536xf32>
  %10 = "tosa.reshape"(%1) {new_shape = [1, 786432]} : (tensor<12x128x512xf32>) -> tensor<1x786432xf32>
  %11 = "tosa.gather"(%9, %10) : (tensor<1x512x1536xf32>, tensor<1x786432xf32>) -> tensor<?x?x?xf32>
  %12 = "tosa.reshape"(%11) {new_shape = [12, 128, 512, 12, 128]} : (tensor<?x?x?xf32>) -> tensor<12x128x512x12x128xf32>
  %13 = "tosa.transpose"(%12, %7) : (tensor<12x128x512x12x128xf32>, tensor<5xi32>) -> tensor<*xf32>
  %14 = "torch.aten.gather"(%arg0, %2, %arg1, %4) : (!torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool) -> !torch.vtensor<[12,128,128],f32>
  "func.return"(%14) : (!torch.vtensor<[12,128,128],f32>) -> ()
}) {function_type = (!torch.vtensor<[12,128,512],f32>, !torch.vtensor<[1,128,128],si64>) -> !torch.vtensor<[12,128,128],f32>, sym_name = "torch.aten.gather"} : () -> ()


} -> SUCCESS
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'func.return'(0x97a9600) {
  "func.return"(%14) : (!torch.vtensor<[12,128,128],f32>) -> ()

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
** Insert  : 'torch_c.to_builtin_tensor'(0x9813730)
** Insert  : 'torch_c.from_builtin_tensor'(0x9814aa0)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::PreservedAnalyses::AllAnalysesType)
/tmp/gather.mlir:4:8: error: 'tosa.gather' op operand #1 must be 2D tensor of 32-bit signless integer values, but got 'tensor<1x786432xf32>'
  %0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool -> !torch.vtensor<[12,128,128],f32>
       ^
/tmp/gather.mlir:4:8: note: see current operation: %8 = "tosa.gather"(%6, %7) : (tensor<1x512x1536xf32>, tensor<1x786432xf32>) -> tensor<?x?x?xf32>
// -----// IR Dump After ConvertTorchToTosa Failed (convert-torch-to-tosa) //----- //
mlir-asm-printer: Verifying operation: func.func
'tosa.gather' op operand #1 must be 2D tensor of 32-bit signless integer values, but got 'tensor<1x786432xf32>'
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() ({
^bb0(%arg0: !torch.vtensor<[12,128,512],f32>, %arg1: !torch.vtensor<[1,128,128],si64>):
  %0 = "torch_c.to_builtin_tensor"(%arg0) : (!torch.vtensor<[12,128,512],f32>) -> tensor<12x128x512xf32>
  %1 = "torch.constant.int"() {value = -1 : i64} : () -> !torch.int
  %2 = "torch.constant.bool"() {value = false} : () -> !torch.bool
  %3 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
  %4 = "tosa.const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>} : () -> tensor<5xi32>
  %5 = "tosa.transpose"(%0, %3) : (tensor<12x128x512xf32>, tensor<3xi32>) -> tensor<512x12x128xf32>
  %6 = "tosa.reshape"(%5) {new_shape = [1, 512, 1536]} : (tensor<512x12x128xf32>) -> tensor<1x512x1536xf32>
  %7 = "tosa.reshape"(%0) {new_shape = [1, 786432]} : (tensor<12x128x512xf32>) -> tensor<1x786432xf32>
  %8 = "tosa.gather"(%6, %7) : (tensor<1x512x1536xf32>, tensor<1x786432xf32>) -> tensor<?x?x?xf32>
  %9 = "tosa.reshape"(%8) {new_shape = [12, 128, 512, 12, 128]} : (tensor<?x?x?xf32>) -> tensor<12x128x512x12x128xf32>
  %10 = "tosa.transpose"(%9, %4) : (tensor<12x128x512x12x128xf32>, tensor<5xi32>) -> tensor<*xf32>
  %11 = "torch_c.from_builtin_tensor"(%10) : (tensor<*xf32>) -> !torch.vtensor<[12,128,128],f32>
  "func.return"(%11) : (!torch.vtensor<[12,128,128],f32>) -> ()
}) {function_type = (!torch.vtensor<[12,128,512],f32>, !torch.vtensor<[1,128,128],si64>) -> !torch.vtensor<[12,128,128],f32>, sym_name = "torch.aten.gather"} : () -> ()

@AmosLewis
Copy link
Author

After adding a cast op (tosa.cast) from i64 to i32:

➜  externals git:(mix) ✗ torch-mlir-opt -convert-torch-to-tosa /tmp/gather.mlir -mlir-print-ir-after-all -mlir-disable-threading --debug  --mlir-print-ir-before-all
Args: /home/chi/src/ubuntu20/shark/torch-mlir/build/bin/torch-mlir-opt -convert-torch-to-tosa /tmp/gather.mlir -mlir-print-ir-after-all -mlir-disable-threading --debug --mlir-print-ir-before-all 
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SubElementTypeInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemRefLayoutAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SubElementAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ElementsAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TypedAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionKindInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CastOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConditionallySpeculatable)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemoryEffectOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ResourceBlobManagerDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BytecodeDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineBinaryOpExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineConstantExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineDimExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineMapStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::IntegerSetStorage)
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DebugActionManager::GenericHandler)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroOperands<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneRegion<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroResults<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroSuccessors<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NoRegionArguments<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NoTerminator<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SingleBlock<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OpInvariants<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AffineScope<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::IsIsolatedFromAbove<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SymbolTable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionKindInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasOnlyGraphRegion<Empty>)
Load new dialect in Context func
Load new dialect in Context cf
Load new dialect in Context arith
ImplicitTypeIDRegistry::lookupOrInsert(mlir::arith::ArithFastMathInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::VectorUnrollOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferIntRangeInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DialectInlinerInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::bufferization::BufferizableOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BranchOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolUserOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallableOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::FunctionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AutomaticAllocationScope<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallableOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::FunctionOpInterface::Trait<Empty>)
Load new dialect in Context torch
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchTerminatorOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroRegions<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneResult<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::IntType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ConstantLike<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConditionallySpeculatable::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AlwaysSpeculatableImplTrait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemoryEffectOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::AllowedInModuleInitializer<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferTypeOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::BoolType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::Type>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NOperands<4>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::AllowsTypeRefinement<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::HasValueSemantics<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::ReadOnly<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::VariadicOperands<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasParent<mlir::func::FuncOp>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::MemRefsNormalizable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ReturnLike<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::IsTerminator<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::OpToOpPassAdaptor)
Load new dialect in Context tensor
Load new dialect in Context affine
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineDmaStartOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineMapAccessInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineDmaWaitOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::LoopLikeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineReadOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineWriteOpInterface)
Load new dialect in Context complex
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedDimOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ReifyRankedShapedTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OffsetSizeAndStrideOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DestinationStyleOpInterface)
Load new dialect in Context linalg
Load new dialect in Context math
Load new dialect in Context memref
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CopyOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ViewLikeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::LinalgOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::ContractionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::ConvolutionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::FillOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TilingInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::PartialReductionOpInterface)
Ignoring repeated interface registrationIgnoring repeated interface registrationLoad new dialect in Context torch_c
Load new dialect in Context tosa
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferShapedTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::tosa::TosaOp)
// -----// IR Dump Before ConvertTorchToTosa (convert-torch-to-tosa) //----- //
mlir-asm-printer: Verifying operation: func.func
func.func @torch.aten.gather(%arg0: !torch.vtensor<[12,128,512],f32>, %arg1: !torch.vtensor<[1,128,128],si64>) -> !torch.vtensor<[12,128,128],f32> {
  %int-1 = torch.constant.int -1
  %false = torch.constant.bool false
  %0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool -> !torch.vtensor<[12,128,128],f32>
  return %0 : !torch.vtensor<[12,128,128],f32>
}


//===-------------------------------------------===//
Legalizing operation : 'func.func'(0x91bc440) {
  * Fold {
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DialectFoldInterface)
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'torch.constant.int'(0x9174d00) {
  %0 = "torch.constant.int"() {value = -1 : i64} : () -> !torch.int

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'torch.constant.bool'(0x91cc520) {
  %1 = "torch.constant.bool"() {value = false} : () -> !torch.bool

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'torch.aten.gather'(0x91cc5e0) {
  %2 = "torch.aten.gather"(%arg0, %0, %arg1, %1) : (!torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool) -> !torch.vtensor<[12,128,128],f32>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.gather -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenGatherOp>"
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::TensorType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::tosa::TosaOp::Trait<Empty>)
    ** Insert  : 'tosa.const'(0x9203cc0)
    ** Insert  : 'tosa.const'(0x91d0540)
    ** Insert  : 'tosa.transpose'(0x9173b70)
    ** Insert  : 'tosa.reshape'(0x9231030)
    ** Insert  : 'tosa.reshape'(0x9233aa0)
    ** Insert  : 'tosa.cast'(0x9233b30)
    ** Insert  : 'tosa.gather'(0x9233bc0)
    ** Insert  : 'tosa.reshape'(0x9237b40)
    ** Insert  : 'tosa.transpose'(0x9237c50)
    ** Replace : 'torch.aten.gather'(0x91cc5e0)
"(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenGatherOp>" result 1

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.const'(0x9203cc0) {
      %6 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.const'(0x91d0540) {
      %7 = "tosa.const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>} : () -> tensor<5xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.transpose'(0x9173b70) {
      %8 = "tosa.transpose"(%1, %6) : (tensor<12x128x512xf32>, tensor<3xi32>) -> tensor<512x12x128xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reshape'(0x9231030) {
      %9 = "tosa.reshape"(%8) {new_shape = [1, 512, 1536]} : (tensor<512x12x128xf32>) -> tensor<1x512x1536xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reshape'(0x9233aa0) {
      %10 = "tosa.reshape"(%0) {new_shape = [1, 16384]} : (tensor<1x128x128xi64>) -> tensor<1x16384xi64>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.cast'(0x9233b30) {
      %11 = "tosa.cast"(%10) : (tensor<1x16384xi64>) -> tensor<1x16384xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.gather'(0x9233bc0) {
      %12 = "tosa.gather"(%9, %11) : (tensor<1x512x1536xf32>, tensor<1x16384xi32>) -> tensor<?x?x?xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reshape'(0x9237b40) {
      %13 = "tosa.reshape"(%12) {new_shape = [1, 128, 128, 12, 128]} : (tensor<?x?x?xf32>) -> tensor<1x128x128x12x128xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.transpose'(0x9237c50) {
      %14 = "tosa.transpose"(%13, %7) : (tensor<1x128x128x12x128xf32>, tensor<5xi32>) -> tensor<*xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//
  } -> SUCCESS : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @torch.aten.gather(%arg0: !torch.vtensor<[12,128,512],f32>, %arg1: !torch.vtensor<[1,128,128],si64>) -> !torch.vtensor<[12,128,128],f32> {
  %0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[1,128,128],si64> to tensor<1x128x128xi64>
  %1 = builtin.unrealized_conversion_cast %arg0 : !torch.vtensor<[12,128,512],f32> to tensor<12x128x512xf32>
  %int-1 = torch.constant.int -1
  %2 = builtin.unrealized_conversion_cast %int-1 : !torch.int to i64
  %false = torch.constant.bool false
  %3 = builtin.unrealized_conversion_cast %false : !torch.bool to i1
  %4 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
  %5 = "tosa.const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>} : () -> tensor<5xi32>
  %6 = "tosa.transpose"(%1, %4) : (tensor<12x128x512xf32>, tensor<3xi32>) -> tensor<512x12x128xf32>
  %7 = "tosa.reshape"(%6) {new_shape = [1, 512, 1536]} : (tensor<512x12x128xf32>) -> tensor<1x512x1536xf32>
  %8 = "tosa.reshape"(%0) {new_shape = [1, 16384]} : (tensor<1x128x128xi64>) -> tensor<1x16384xi64>
  %9 = "tosa.cast"(%8) : (tensor<1x16384xi64>) -> tensor<1x16384xi32>
  %10 = "tosa.gather"(%7, %9) : (tensor<1x512x1536xf32>, tensor<1x16384xi32>) -> tensor<?x?x?xf32>
  %11 = "tosa.reshape"(%10) {new_shape = [1, 128, 128, 12, 128]} : (tensor<?x?x?xf32>) -> tensor<1x128x128x12x128xf32>
  %12 = "tosa.transpose"(%11, %5) : (tensor<1x128x128x12x128xf32>, tensor<5xi32>) -> tensor<*xf32>
  %13 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool -> !torch.vtensor<[12,128,128],f32>
  return %13 : !torch.vtensor<[12,128,128],f32>
}


} -> SUCCESS
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'func.return'(0x91cdc10) {
  "func.return"(%15) : (!torch.vtensor<[12,128,128],f32>) -> ()

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
** Insert  : 'torch_c.to_builtin_tensor'(0x9239b20)
** Insert  : 'torch_c.to_builtin_tensor'(0x923adc0)
** Insert  : 'torch_c.from_builtin_tensor'(0x923ae50)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::PreservedAnalyses::AllAnalysesType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::ValueTensorType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneOperand<Empty>)
/tmp/gather.mlir:4:8: error: operand and result must have the same size and dtype
  %0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool -> !torch.vtensor<[12,128,128],f32>
       ^
/tmp/gather.mlir:4:8: note: see current operation: %13 = "torch_c.from_builtin_tensor"(%12) : (tensor<*xf32>) -> !torch.vtensor<[12,128,128],f32>
// -----// IR Dump After ConvertTorchToTosa Failed (convert-torch-to-tosa) //----- //
mlir-asm-printer: Verifying operation: func.func
operand and result must have the same size and dtype
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() ({
^bb0(%arg0: !torch.vtensor<[12,128,512],f32>, %arg1: !torch.vtensor<[1,128,128],si64>):
  %0 = "torch_c.to_builtin_tensor"(%arg0) : (!torch.vtensor<[12,128,512],f32>) -> tensor<12x128x512xf32>
  %1 = "torch_c.to_builtin_tensor"(%arg1) : (!torch.vtensor<[1,128,128],si64>) -> tensor<1x128x128xi64>
  %2 = "torch.constant.int"() {value = -1 : i64} : () -> !torch.int
  %3 = "torch.constant.bool"() {value = false} : () -> !torch.bool
  %4 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
  %5 = "tosa.const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>} : () -> tensor<5xi32>
  %6 = "tosa.transpose"(%0, %4) : (tensor<12x128x512xf32>, tensor<3xi32>) -> tensor<512x12x128xf32>
  %7 = "tosa.reshape"(%6) {new_shape = [1, 512, 1536]} : (tensor<512x12x128xf32>) -> tensor<1x512x1536xf32>
  %8 = "tosa.reshape"(%1) {new_shape = [1, 16384]} : (tensor<1x128x128xi64>) -> tensor<1x16384xi64>
  %9 = "tosa.cast"(%8) : (tensor<1x16384xi64>) -> tensor<1x16384xi32>
  %10 = "tosa.gather"(%7, %9) : (tensor<1x512x1536xf32>, tensor<1x16384xi32>) -> tensor<?x?x?xf32>
  %11 = "tosa.reshape"(%10) {new_shape = [1, 128, 128, 12, 128]} : (tensor<?x?x?xf32>) -> tensor<1x128x128x12x128xf32>
  %12 = "tosa.transpose"(%11, %5) : (tensor<1x128x128x12x128xf32>, tensor<5xi32>) -> tensor<*xf32>
  %13 = "torch_c.from_builtin_tensor"(%12) : (tensor<*xf32>) -> !torch.vtensor<[12,128,128],f32>
  "func.return"(%13) : (!torch.vtensor<[12,128,128],f32>) -> ()
}) {function_type = (!torch.vtensor<[12,128,512],f32>, !torch.vtensor<[1,128,128],si64>) -> !torch.vtensor<[12,128,128],f32>, sym_name = "torch.aten.gather"} : () -> ()

@AmosLewis
Copy link
Author

➜  externals git:(mix) ✗ torch-mlir-opt -convert-torch-to-tosa /tmp/gather.mlir -mlir-print-ir-after-all -mlir-disable-threading --debug  --mlir-print-ir-before-all
Args: /home/chi/src/ubuntu20/shark/torch-mlir/build/bin/torch-mlir-opt -convert-torch-to-tosa /tmp/gather.mlir -mlir-print-ir-after-all -mlir-disable-threading --debug --mlir-print-ir-before-all 
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SubElementTypeInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemRefLayoutAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SubElementAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ElementsAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TypedAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionKindInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CastOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConditionallySpeculatable)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemoryEffectOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ResourceBlobManagerDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BytecodeDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineBinaryOpExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineConstantExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineDimExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineMapStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::IntegerSetStorage)
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DebugActionManager::GenericHandler)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroOperands<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneRegion<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroResults<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroSuccessors<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NoRegionArguments<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NoTerminator<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SingleBlock<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OpInvariants<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AffineScope<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::IsIsolatedFromAbove<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SymbolTable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionKindInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasOnlyGraphRegion<Empty>)
Load new dialect in Context func
Load new dialect in Context cf
Load new dialect in Context arith
ImplicitTypeIDRegistry::lookupOrInsert(mlir::arith::ArithFastMathInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::VectorUnrollOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferIntRangeInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DialectInlinerInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::bufferization::BufferizableOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BranchOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolUserOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallableOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::FunctionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AutomaticAllocationScope<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallableOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::FunctionOpInterface::Trait<Empty>)
Load new dialect in Context torch
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchTerminatorOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroRegions<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneResult<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::IntType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ConstantLike<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConditionallySpeculatable::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AlwaysSpeculatableImplTrait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemoryEffectOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::AllowedInModuleInitializer<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferTypeOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::BoolType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::Type>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NOperands<4>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::AllowsTypeRefinement<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::HasValueSemantics<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::ReadOnly<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::VariadicOperands<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasParent<mlir::func::FuncOp>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::MemRefsNormalizable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ReturnLike<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::IsTerminator<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::OpToOpPassAdaptor)
Load new dialect in Context tensor
Load new dialect in Context affine
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineDmaStartOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineMapAccessInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineDmaWaitOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::LoopLikeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineReadOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineWriteOpInterface)
Load new dialect in Context complex
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedDimOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ReifyRankedShapedTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OffsetSizeAndStrideOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DestinationStyleOpInterface)
Load new dialect in Context linalg
Load new dialect in Context math
Load new dialect in Context memref
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CopyOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ViewLikeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::LinalgOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::ContractionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::ConvolutionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::FillOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TilingInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::PartialReductionOpInterface)
Ignoring repeated interface registrationIgnoring repeated interface registrationLoad new dialect in Context torch_c
Load new dialect in Context tosa
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferShapedTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::tosa::TosaOp)
// -----// IR Dump Before ConvertTorchToTosa (convert-torch-to-tosa) //----- //
mlir-asm-printer: Verifying operation: func.func
func.func @torch.aten.gather(%arg0: !torch.vtensor<[12,128,512],f32>, %arg1: !torch.vtensor<[1,128,128],si64>) -> !torch.vtensor<[12,128,128],f32> {
  %int-1 = torch.constant.int -1
  %false = torch.constant.bool false
  %0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool -> !torch.vtensor<[12,128,128],f32>
  return %0 : !torch.vtensor<[12,128,128],f32>
}


//===-------------------------------------------===//
Legalizing operation : 'func.func'(0x88ebe50) {
  * Fold {
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DialectFoldInterface)
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'torch.constant.int'(0x88976e0) {
  %0 = "torch.constant.int"() {value = -1 : i64} : () -> !torch.int

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'torch.constant.bool'(0x88f28e0) {
  %1 = "torch.constant.bool"() {value = false} : () -> !torch.bool

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'torch.aten.gather'(0x88f3e90) {
  %2 = "torch.aten.gather"(%arg0, %0, %arg1, %1) : (!torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool) -> !torch.vtensor<[12,128,128],f32>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.gather -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenGatherOp>"
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::TensorType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::tosa::TosaOp::Trait<Empty>)
    ** Insert  : 'tosa.const'(0x892af10)
    ** Insert  : 'tosa.const'(0x88f7d20)
    ** Insert  : 'tosa.transpose'(0x8896550)
    ** Insert  : 'tosa.reshape'(0x89573a0)
    ** Insert  : 'tosa.reshape'(0x8958920)
    ** Insert  : 'tosa.cast'(0x8959ea0)
    ** Insert  : 'tosa.gather'(0x8959f30)
    ** Insert  : 'tosa.reshape'(0x895c9c0)
    ** Insert  : 'tosa.transpose'(0x895cad0)
    ** Replace : 'torch.aten.gather'(0x88f3e90)
"(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenGatherOp>" result 1

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.const'(0x892af10) {
      %6 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.const'(0x88f7d20) {
      %7 = "tosa.const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>} : () -> tensor<5xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.transpose'(0x8896550) {
      %8 = "tosa.transpose"(%1, %6) : (tensor<12x128x512xf32>, tensor<3xi32>) -> tensor<512x12x128xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reshape'(0x89573a0) {
      %9 = "tosa.reshape"(%8) {new_shape = [1, 512, 1536]} : (tensor<512x12x128xf32>) -> tensor<1x512x1536xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reshape'(0x8958920) {
      %10 = "tosa.reshape"(%0) {new_shape = [1, 16384]} : (tensor<1x128x128xi64>) -> tensor<1x16384xi64>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.cast'(0x8959ea0) {
      %11 = "tosa.cast"(%10) : (tensor<1x16384xi64>) -> tensor<1x16384xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.gather'(0x8959f30) {
      %12 = "tosa.gather"(%9, %11) : (tensor<1x512x1536xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1536xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reshape'(0x895c9c0) {
      %13 = "tosa.reshape"(%12) {new_shape = [1, 128, 128, 12, 128]} : (tensor<1x16384x1536xf32>) -> tensor<1x128x128x12x128xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.transpose'(0x895cad0) {
      %14 = "tosa.transpose"(%13, %7) : (tensor<1x128x128x12x128xf32>, tensor<5xi32>) -> tensor<*xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//
  } -> SUCCESS : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @torch.aten.gather(%arg0: !torch.vtensor<[12,128,512],f32>, %arg1: !torch.vtensor<[1,128,128],si64>) -> !torch.vtensor<[12,128,128],f32> {
  %0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[1,128,128],si64> to tensor<1x128x128xi64>
  %1 = builtin.unrealized_conversion_cast %arg0 : !torch.vtensor<[12,128,512],f32> to tensor<12x128x512xf32>
  %int-1 = torch.constant.int -1
  %2 = builtin.unrealized_conversion_cast %int-1 : !torch.int to i64
  %false = torch.constant.bool false
  %3 = builtin.unrealized_conversion_cast %false : !torch.bool to i1
  %4 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
  %5 = "tosa.const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>} : () -> tensor<5xi32>
  %6 = "tosa.transpose"(%1, %4) : (tensor<12x128x512xf32>, tensor<3xi32>) -> tensor<512x12x128xf32>
  %7 = "tosa.reshape"(%6) {new_shape = [1, 512, 1536]} : (tensor<512x12x128xf32>) -> tensor<1x512x1536xf32>
  %8 = "tosa.reshape"(%0) {new_shape = [1, 16384]} : (tensor<1x128x128xi64>) -> tensor<1x16384xi64>
  %9 = "tosa.cast"(%8) : (tensor<1x16384xi64>) -> tensor<1x16384xi32>
  %10 = "tosa.gather"(%7, %9) : (tensor<1x512x1536xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1536xf32>
  %11 = "tosa.reshape"(%10) {new_shape = [1, 128, 128, 12, 128]} : (tensor<1x16384x1536xf32>) -> tensor<1x128x128x12x128xf32>
  %12 = "tosa.transpose"(%11, %5) : (tensor<1x128x128x12x128xf32>, tensor<5xi32>) -> tensor<*xf32>
  %13 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool -> !torch.vtensor<[12,128,128],f32>
  return %13 : !torch.vtensor<[12,128,128],f32>
}


} -> SUCCESS
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'func.return'(0x88f3fd0) {
  "func.return"(%15) : (!torch.vtensor<[12,128,128],f32>) -> ()

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
** Insert  : 'torch_c.to_builtin_tensor'(0x895e9a0)
** Insert  : 'torch_c.to_builtin_tensor'(0x895fc40)
** Insert  : 'torch_c.from_builtin_tensor'(0x895fcd0)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::PreservedAnalyses::AllAnalysesType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::ValueTensorType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneOperand<Empty>)
/tmp/gather.mlir:4:8: error: operand and result must have the same size and dtype
  %0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool -> !torch.vtensor<[12,128,128],f32>
       ^
/tmp/gather.mlir:4:8: note: see current operation: %13 = "torch_c.from_builtin_tensor"(%12) : (tensor<*xf32>) -> !torch.vtensor<[12,128,128],f32>
// -----// IR Dump After ConvertTorchToTosa Failed (convert-torch-to-tosa) //----- //
mlir-asm-printer: Verifying operation: func.func
operand and result must have the same size and dtype
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() ({
^bb0(%arg0: !torch.vtensor<[12,128,512],f32>, %arg1: !torch.vtensor<[1,128,128],si64>):
  %0 = "torch_c.to_builtin_tensor"(%arg0) : (!torch.vtensor<[12,128,512],f32>) -> tensor<12x128x512xf32>
  %1 = "torch_c.to_builtin_tensor"(%arg1) : (!torch.vtensor<[1,128,128],si64>) -> tensor<1x128x128xi64>
  %2 = "torch.constant.int"() {value = -1 : i64} : () -> !torch.int
  %3 = "torch.constant.bool"() {value = false} : () -> !torch.bool
  %4 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
  %5 = "tosa.const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>} : () -> tensor<5xi32>
  %6 = "tosa.transpose"(%0, %4) : (tensor<12x128x512xf32>, tensor<3xi32>) -> tensor<512x12x128xf32>
  %7 = "tosa.reshape"(%6) {new_shape = [1, 512, 1536]} : (tensor<512x12x128xf32>) -> tensor<1x512x1536xf32>
  %8 = "tosa.reshape"(%1) {new_shape = [1, 16384]} : (tensor<1x128x128xi64>) -> tensor<1x16384xi64>
  %9 = "tosa.cast"(%8) : (tensor<1x16384xi64>) -> tensor<1x16384xi32>
  %10 = "tosa.gather"(%7, %9) : (tensor<1x512x1536xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1536xf32>
  %11 = "tosa.reshape"(%10) {new_shape = [1, 128, 128, 12, 128]} : (tensor<1x16384x1536xf32>) -> tensor<1x128x128x12x128xf32>
  %12 = "tosa.transpose"(%11, %5) : (tensor<1x128x128x12x128xf32>, tensor<5xi32>) -> tensor<*xf32>
  %13 = "torch_c.from_builtin_tensor"(%12) : (tensor<*xf32>) -> !torch.vtensor<[12,128,128],f32>
  "func.return"(%13) : (!torch.vtensor<[12,128,128],f32>) -> ()
}) {function_type = (!torch.vtensor<[12,128,512],f32>, !torch.vtensor<[1,128,128],si64>) -> !torch.vtensor<[12,128,128],f32>, sym_name = "torch.aten.gather"} : () -> ()

@AmosLewis
Copy link
Author

With convertTorchIndexToTfIndices, the lowering produces:

// CHECK-LABEL:   func.func @torch.aten.gather(
// CHECK-SAME:                                 %[[VAL_0:.*]]: !torch.vtensor<[12,128,512],f32>,
// CHECK-SAME:                                 %[[VAL_1:.*]]: !torch.vtensor<[1,128,128],si64>) -> !torch.vtensor<[12,128,128],f32> {
// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[12,128,512],f32> -> tensor<12x128x512xf32>
// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,128,128],si64> -> tensor<1x128x128xi64>
// CHECK:           %[[VAL_4:.*]] = torch.constant.int -1
// CHECK:           %[[VAL_5:.*]] = torch.constant.bool false
// CHECK:           %[[VAL_6:.*]] = "tosa.const"() {value = dense<0> : tensor<1x128x128x1xi32>} : () -> tensor<1x128x128x1xi32>
// CHECK:           %[[VAL_7:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<1x128x128x1xi32>} : () -> tensor<1x128x128x1xi32>
// CHECK:           %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [1, 128, 128, 1]} : (tensor<1x128x128xi64>) -> tensor<1x128x128x1xi64>
// CHECK:           %[[VAL_9:.*]] = "tosa.concat"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {axis = 3 : i64} : (tensor<1x128x128x1xi32>, tensor<1x128x128x1xi32>, tensor<1x128x128x1xi64>) -> tensor<1x128x128x3xi64>
// CHECK:           %[[VAL_10:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = [1, 786432, 1]} : (tensor<12x128x512xf32>) -> tensor<1x786432x1xf32>
// CHECK:           %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_9]]) {new_shape = [16384, 3]} : (tensor<1x128x128x3xi64>) -> tensor<16384x3xi64>
// CHECK:           %[[VAL_12:.*]] = "tosa.const"() {value = dense<[65536, 512, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK:           %[[VAL_13:.*]] = "tosa.mul"(%[[VAL_11]], %[[VAL_12]]) {shift = 0 : i32} : (tensor<16384x3xi64>, tensor<3xi32>) -> tensor<16384x3xi64>
// CHECK:           %[[VAL_14:.*]] = "tosa.reduce_sum"(%[[VAL_13]]) {axis = 1 : i64} : (tensor<16384x3xi64>) -> tensor<?x?xi64>
// CHECK:           %[[VAL_15:.*]] = "tosa.reshape"(%[[VAL_14]]) {new_shape = [1, 16384]} : (tensor<?x?xi64>) -> tensor<1x16384xi64>
// CHECK:           %[[VAL_16:.*]] = "tosa.cast"(%[[VAL_15]]) : (tensor<1x16384xi64>) -> tensor<1x16384xi32>
// CHECK:           %[[VAL_17:.*]] = "tosa.gather"(%[[VAL_10]], %[[VAL_16]]) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
// CHECK:           %[[VAL_18:.*]] = "tosa.reshape"(%[[VAL_17]]) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
// CHECK:           %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<12x128x128xf32> -> !torch.vtensor<[12,128,128],f32>
// CHECK:           return %[[VAL_19]] : !torch.vtensor<[12,128,128],f32>
// CHECK:         }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment