@AmosLewis
Created December 17, 2022 20:36
func.func @torch.aten.gather(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?],f32> {
%int-1 = torch.constant.int -1
%false = torch.constant.bool false
%0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.vtensor<[?,?,?,?],si64>, !torch.bool -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32>
}
AmosLewis commented Dec 17, 2022

➜  torch-mlir git:(gather-deberta) ✗ torch-mlir-opt -convert-torch-to-tosa /tmp/gather_dynamic.mlir --debug 
Args: /home/chi/src/ubuntu20/shark/torch-mlir/build/bin/torch-mlir-opt -convert-torch-to-tosa /tmp/gather_dynamic.mlir --debug 
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SubElementTypeInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemRefLayoutAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SubElementAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ElementsAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TypedAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionKindInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CastOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConditionallySpeculatable)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemoryEffectOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ResourceBlobManagerDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BytecodeDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineBinaryOpExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineConstantExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineDimExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineMapStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::IntegerSetStorage)
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DebugActionManager::GenericHandler)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroOperands<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneRegion<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroResults<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroSuccessors<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NoRegionArguments<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NoTerminator<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SingleBlock<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OpInvariants<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AffineScope<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::IsIsolatedFromAbove<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SymbolTable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionKindInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasOnlyGraphRegion<Empty>)
Load new dialect in Context func
Load new dialect in Context cf
Load new dialect in Context arith
ImplicitTypeIDRegistry::lookupOrInsert(mlir::arith::ArithFastMathInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::VectorUnrollOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferIntRangeInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DialectInlinerInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::bufferization::BufferizableOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BranchOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolUserOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallableOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::FunctionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AutomaticAllocationScope<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallableOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::FunctionOpInterface::Trait<Empty>)
Load new dialect in Context torch
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchTerminatorOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroRegions<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneResult<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::IntType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ConstantLike<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConditionallySpeculatable::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AlwaysSpeculatableImplTrait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemoryEffectOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::AllowedInModuleInitializer<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferTypeOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::BoolType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::Type>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NOperands<4>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::AllowsTypeRefinement<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::HasValueSemantics<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::ReadOnly<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::VariadicOperands<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasParent<mlir::func::FuncOp>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::MemRefsNormalizable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ReturnLike<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::IsTerminator<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::OpToOpPassAdaptor)
Load new dialect in Context tensor
Load new dialect in Context affine
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineDmaStartOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineMapAccessInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineDmaWaitOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::LoopLikeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineReadOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::AffineWriteOpInterface)
Load new dialect in Context complex
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedDimOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ReifyRankedShapedTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OffsetSizeAndStrideOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DestinationStyleOpInterface)
Load new dialect in Context linalg
Load new dialect in Context math
Load new dialect in Context memref
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CopyOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ViewLikeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::LinalgOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::ContractionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::ConvolutionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::FillOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TilingInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::PartialReductionOpInterface)
Ignoring repeated interface registration
Ignoring repeated interface registration
Load new dialect in Context torch_c
Load new dialect in Context tosa
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferShapedTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::tosa::TosaOp)

//===-------------------------------------------===//
Legalizing operation : 'func.func'(0x9f03760) {
  * Fold {
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DialectFoldInterface)
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'torch.constant.int'(0x9ebdb50) {
  %0 = "torch.constant.int"() {value = -1 : i64} : () -> !torch.int

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'torch.constant.bool'(0x9f03c20) {
  %1 = "torch.constant.bool"() {value = false} : () -> !torch.bool

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'torch.aten.gather'(0x9f15890) {
  %2 = "torch.aten.gather"(%arg0, %0, %arg1, %1) : (!torch.vtensor<[?,?,?],f32>, !torch.int, !torch.vtensor<[?,?,?,?],si64>, !torch.bool) -> !torch.vtensor<[?,?,?],f32>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.gather -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenGatherOp>"
    ** Insert  : 'tosa.cast'(0x9f66360)
    ** Insert  : 'tosa.reshape'(0x9f71a70)
    ** Insert  : 'tosa.reshape'(0x9f744e0)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::TensorType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::tosa::TosaOp::Trait<Empty>)
    ** Insert  : 'tosa.const'(0x9f5cec0)
    ** Insert  : 'tosa.mul'(0x9ebc410)
    ** Insert  : 'tosa.reduce_sum'(0x9ebc9a0)
    ** Insert  : 'tosa.reshape'(0x9f7bfb0)
    ** Insert  : 'tosa.gather'(0x9f7c040)
    ** Insert  : 'tosa.reshape'(0x9f7ffc0)
    ** Replace : 'torch.aten.gather'(0x9f15890)
"(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenGatherOp>" result 1

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.cast'(0x9f66360) {
      %6 = "tosa.cast"(%0) : (tensor<?x?x?x?xi64>) -> tensor<?x?x?x?xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reshape'(0x9f71a70) {
      %7 = "tosa.reshape"(%1) {new_shape = [1, 1, 0]} : (tensor<?x?x?xf32>) -> tensor<1x1x0xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reshape'(0x9f744e0) {
      %8 = "tosa.reshape"(%6) {new_shape = [0, 0]} : (tensor<?x?x?x?xi32>) -> tensor<0x0xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.const'(0x9f5cec0) {
      %9 = "tosa.const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.mul'(0x9ebc410) {
      %10 = "tosa.mul"(%8, %9) {shift = 0 : i32} : (tensor<0x0xi32>, tensor<1xi32>) -> tensor<0x0xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reduce_sum'(0x9ebc9a0) {
      %11 = "tosa.reduce_sum"(%10) {axis = 1 : i64} : (tensor<0x0xi32>) -> tensor<0x1xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reshape'(0x9f7bfb0) {
      %12 = "tosa.reshape"(%11) {new_shape = [1, 0]} : (tensor<0x1xi32>) -> tensor<1x0xi32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.gather'(0x9f7c040) {
      %13 = "tosa.gather"(%7, %12) : (tensor<1x1x0xf32>, tensor<1x0xi32>) -> tensor<1x0x0xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'tosa.reshape'(0x9f7ffc0) {
      %14 = "tosa.reshape"(%13) {new_shape = [-9223372036854775808, -9223372036854775808, -9223372036854775808]} : (tensor<1x0x0xf32>) -> tensor<0x0x0xf32>

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//
  } -> SUCCESS : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @torch.aten.gather(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?],f32> {
  %0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[?,?,?,?],si64> to tensor<?x?x?x?xi64>
  %1 = builtin.unrealized_conversion_cast %arg0 : !torch.vtensor<[?,?,?],f32> to tensor<?x?x?xf32>
  %int-1 = torch.constant.int -1
  %2 = builtin.unrealized_conversion_cast %int-1 : !torch.int to i64
  %false = torch.constant.bool false
  %3 = builtin.unrealized_conversion_cast %false : !torch.bool to i1
  %4 = "tosa.cast"(%0) : (tensor<?x?x?x?xi64>) -> tensor<?x?x?x?xi32>
  %5 = "tosa.reshape"(%1) {new_shape = [1, 1, 0]} : (tensor<?x?x?xf32>) -> tensor<1x1x0xf32>
  %6 = "tosa.reshape"(%4) {new_shape = [0, 0]} : (tensor<?x?x?x?xi32>) -> tensor<0x0xi32>
  %7 = "tosa.const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  %8 = "tosa.mul"(%6, %7) {shift = 0 : i32} : (tensor<0x0xi32>, tensor<1xi32>) -> tensor<0x0xi32>
  %9 = "tosa.reduce_sum"(%8) {axis = 1 : i64} : (tensor<0x0xi32>) -> tensor<0x1xi32>
  %10 = "tosa.reshape"(%9) {new_shape = [1, 0]} : (tensor<0x1xi32>) -> tensor<1x0xi32>
  %11 = "tosa.gather"(%5, %10) : (tensor<1x1x0xf32>, tensor<1x0xi32>) -> tensor<1x0x0xf32>
  %12 = "tosa.reshape"(%11) {new_shape = [-9223372036854775808, -9223372036854775808, -9223372036854775808]} : (tensor<1x0x0xf32>) -> tensor<0x0x0xf32>
  %13 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.vtensor<[?,?,?,?],si64>, !torch.bool -> !torch.vtensor<[?,?,?],f32>
  return %13 : !torch.vtensor<[?,?,?],f32>
}


} -> SUCCESS
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'func.return'(0x9f159d0) {
  "func.return"(%15) : (!torch.vtensor<[?,?,?],f32>) -> ()

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
** Insert  : 'torch_c.to_builtin_tensor'(0x9f82740)
** Insert  : 'torch_c.to_builtin_tensor'(0x9f827f0)
** Insert  : 'torch_c.from_builtin_tensor'(0x9f82880)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::PreservedAnalyses::AllAnalysesType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::ValueTensorType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneOperand<Empty>)
/tmp/gather_dynamic.mlir:4:8: error: operand and result must have the same size and dtype
  %0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.vtensor<[?,?,?,?],si64>, !torch.bool -> !torch.vtensor<[?,?,?],f32>
       ^
/tmp/gather_dynamic.mlir:4:8: note: see current operation: %13 = "torch_c.from_builtin_tensor"(%12) : (tensor<0x0x0xf32>) -> !torch.vtensor<[?,?,?],f32>

@AmosLewis

if (inputShape[axis_val] < 0)
  op->emitOpError("Failed convertReduceMean: support for dynamic input "
                  "shape not implemented");

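The aten.gather → tosa.gather lowering could use the same kind of guard before it computes static new_shape values. Below is a minimal sketch of such a check, assuming a helper called from ConvertAtenOp<AtenGatherOp>::matchAndRewrite; the name checkStaticGatherOperands and the self/index arguments are illustrative, not the actual torch-mlir code.

// Hedged sketch: helper the gather lowering could call before building static
// new_shape attributes. Names here are illustrative, not actual torch-mlir code.
static LogicalResult
checkStaticGatherOperands(Operation *op, ConversionPatternRewriter &rewriter,
                          Value self, Value index) {
  auto selfType = self.getType().dyn_cast<RankedTensorType>();
  auto indexType = index.getType().dyn_cast<RankedTensorType>();
  if (!selfType || !indexType)
    return rewriter.notifyMatchFailure(op, "expected ranked tensor operands");
  // The tosa.reshape ops in this lowering carry static new_shape attributes;
  // with dynamic inputs the dynamic-size sentinel (INT64_MIN, printed as
  // -9223372036854775808 in the debug log above) leaks into that shape
  // arithmetic, so reject dynamic shapes up front.
  if (!selfType.hasStaticShape() || !indexType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        op, "dynamic input/index shapes not yet supported when lowering "
            "aten.gather to tosa.gather");
  return success();
}

With a check like this the pattern would fail to match cleanly instead of producing the mismatched tensor<0x0x0xf32> and the "operand and result must have the same size and dtype" error seen above.
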
AmosLewis commented Dec 17, 2022

    auto lhsTy = lhs.getType().cast<RankedTensorType>();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    auto lhsRank = lhsTy.getRank();
    auto rhsRank = rhsTy.getRank();

    auto lhsShape = makeShapeTorchCompatible(lhsTy.getShape());
    auto rhsShape = makeShapeTorchCompatible(rhsTy.getShape());

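Since makeShapeTorchCompatible reports unknown dimensions with a sentinel value (kUnknownSize in torch-mlir), a dynamic-shape check over these shapes could look like the following sketch; the lambda and the surrounding op/rewriter context are assumptions for illustration.

// Hedged sketch: detect dynamic dimensions in the torch-compatible shapes
// before doing static shape arithmetic. kUnknownSize is assumed to be the
// torch-mlir sentinel for an unknown dimension.
auto hasDynamicDim = [](ArrayRef<int64_t> shape) {
  return llvm::any_of(shape,
                      [](int64_t dim) { return dim == kUnknownSize; });
};
if (hasDynamicDim(lhsShape) || hasDynamicDim(rhsShape))
  return rewriter.notifyMatchFailure(
      op, "dynamic lhs/rhs shapes are not supported by this lowering");
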