@AmosLewis
Created August 29, 2023 06:16
(mlir_venv) ➜ torch-mlir git:(decompose) ✗ torch-mlir-opt --convert-torch-to-stablehlo /nodclouddata/chi/src/models/t5/slicecopy/test_indexputhackedtwin.mlir --debug
Args: torch-mlir-opt --convert-torch-to-stablehlo /nodclouddata/chi/src/models/t5/slicecopy/test_indexputhackedtwin.mlir --debug
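For context, here is a minimal sketch of what `test_indexputhackedtwin.mlir` likely contains, reconstructed from the first pre-conversion IR dump further down. The op sequence is taken verbatim from that dump; the SSA value names and the absence of a module wrapper are assumptions:

```mlir
// Reconstructed input (assumption: mirrors the first func.func dump below).
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
  %none = torch.constant.none
  %int4 = torch.constant.int 4
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %int-1 = torch.constant.int -1
  %false = torch.constant.bool false
  // Destination buffer: zeros of shape [1, 4].
  %0 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.zeros %0, %int4, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,4],si64>
  // Slice-copy source: arg1[:, 0:-1].
  %2 = torch.aten.slice.Tensor %arg1, %int1, %int0, %int-1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],si64>
  %3 = torch.aten.clone %2, %none : !torch.vtensor<[1,3],si64>, !torch.none -> !torch.vtensor<[1,3],si64>
  // Index tensors: arange(1, 4) and arange(0, 1) unsqueezed to [1, 1].
  %4 = torch.aten.arange.start_step %int1, %int4, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
  %5 = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
  %6 = torch.aten.unsqueeze %5, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
  %7 = torch.prim.ListConstruct %6, %4 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
  %8 = torch.aten.index_put.hacked_twin %1, %7, %3, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64>
  return %8 : !torch.vtensor<[1,4],si64>
}
```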
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemRefLayoutAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TypedAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ElementsAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DistinctAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BytecodeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionKindInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConditionallySpeculatable)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemoryEffectOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ResourceBlobManagerDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BytecodeDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineBinaryOpExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineConstantExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineDimExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineMapStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::IntegerSetStorage)
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CastOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DestructurableTypeInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroOperands<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneRegion<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroResults<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroSuccessors<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NoRegionArguments<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NoTerminator<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SingleBlock<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OpInvariants<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BytecodeOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AffineScope<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::IsIsolatedFromAbove<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SymbolTable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionKindInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasOnlyGraphRegion<Empty>)
Load new dialect in Context torch
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchTerminatorOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolUserOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DialectInlinerInterface)
Load new dialect in Context func
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallableOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::FunctionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::bufferization::BufferizableOpInterface)
Load new dialect in Context cf
Load new dialect in Context arith
ImplicitTypeIDRegistry::lookupOrInsert(mlir::arith::ArithFastMathInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::VectorUnrollOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferIntRangeInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ValueBoundsOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConvertToLLVMPatternInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BranchOpInterface)
[dialect] repeated interface registration for dialect func
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AutomaticAllocationScope<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallableOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::FunctionOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroRegions<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneResult<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::NoneType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ConstantLike<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConditionallySpeculatable::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AlwaysSpeculatableImplTrait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemoryEffectOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::AllowedInModuleInitializer<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferTypeOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::IntType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::BoolType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::ListType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::VariadicOperands<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::AllowsTypeRefinement<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::Type>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NOperands<5>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::HasValueSemantics<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::torch::Torch::OpTrait::ReadOnly<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NOperands<2>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NOperands<7>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NOperands<4>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasParent<mlir::func::FuncOp>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::MemRefsNormalizable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchTerminatorOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ReturnLike<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::IsTerminator<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DataLayoutSpecInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::OpToOpPassAdaptor)
Load new dialect in Context chlo
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferShapedTypeOpInterface)
Load new dialect in Context shape
Load new dialect in Context tensor
Load new dialect in Context affine
ImplicitTypeIDRegistry::lookupOrInsert(mlir::affine::AffineDmaStartOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::affine::AffineMapAccessInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::affine::AffineDmaWaitOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::LoopLikeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::affine::AffineReadOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::affine::AffineWriteOpInterface)
Load new dialect in Context complex
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedDimOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ReifyRankedShapedTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OffsetSizeAndStrideOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DestinationStyleOpInterface)
Load new dialect in Context linalg
Load new dialect in Context math
Load new dialect in Context memref
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CopyOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::PromotableMemOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DestructurableAccessorOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::PromotableAllocationOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DestructurableAllocationOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ViewLikeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RuntimeVerifiableOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::AggregatedOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TilingInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::LinalgOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::ContractionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::ConvolutionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::FillOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::PartialReductionOpInterface)
Ignoring repeated interface registration
Ignoring repeated interface registration
ImplicitTypeIDRegistry::lookupOrInsert(mlir::transform::FindPayloadReplacementOpInterface)
Load new dialect in Context stablehlo
ImplicitTypeIDRegistry::lookupOrInsert(mlir::hlo::HloDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::stablehlo::TokenType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::VerifiableTensorEncoding)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::hlo::BoundedAttrInterface)
Load new dialect in Context torch_c
//===-------------------------------------------===//
Legalizing operation : 'func.func'(0xf3ee5f0) {
* Fold {
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DialectFoldInterface)
} -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
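Note: the `-> FAILURE : no matched legalization pattern` entries here and for the `torch.constant.*` / `torch.prim.ListConstruct` ops below appear to be expected noise. These ops are not targets of this conversion, so no pattern matches and the partial conversion leaves them in place; their values are bridged to the converted ops through the `builtin.unrealized_conversion_cast` ops visible in the IR dumps.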
//===-------------------------------------------===//
Legalizing operation : 'torch.constant.none'(0xf3d4970) {
%0 = "torch.constant.none"() : () -> !torch.none
* Fold {
} -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.constant.int'(0xf3d3460) {
%1 = "torch.constant.int"() {value = 4 : i64} : () -> !torch.int
* Fold {
} -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.constant.int'(0xf3eb190) {
%2 = "torch.constant.int"() {value = 1 : i64} : () -> !torch.int
* Fold {
} -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.constant.int'(0xf3eb6d0) {
%3 = "torch.constant.int"() {value = 0 : i64} : () -> !torch.int
* Fold {
} -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.constant.int'(0xf3ec090) {
%4 = "torch.constant.int"() {value = -1 : i64} : () -> !torch.int
* Fold {
} -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.constant.bool'(0xf3eca50) {
%5 = "torch.constant.bool"() {value = false} : () -> !torch.bool
* Fold {
} -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.prim.ListConstruct'(0xf3ed410) {
%6 = "torch.prim.ListConstruct"(%2, %1) : (!torch.int, !torch.int) -> !torch.list<int>
* Fold {
} -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.zeros'(0xf393750) {
%7 = "torch.aten.zeros"(%6, %1, %0, %0, %5) : (!torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool) -> !torch.vtensor<[1,4],si64>
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.aten.zeros -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenConstPatternOp<mlir::torch::Torch::AtenZerosOp, 0>"
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::RankedTensorType>::Impl<Empty>)
** Insert : 'stablehlo.constant'(0xf454490)
** Insert : 'stablehlo.convert'(0xf461bc0)
** Replace : 'torch.aten.zeros'(0xf393750)
"(anonymous namespace)::ConvertAtenConstPatternOp<mlir::torch::Torch::AtenZerosOp, 0>" result 1
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.constant'(0xf454490) {
%9 = "stablehlo.constant"() {value = dense<0> : tensor<1x4xi32>} : () -> tensor<1x4xi32>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.convert'(0xf461bc0) {
%10 = "stablehlo.convert"(%9) : (tensor<1x4xi32>) -> tensor<1x4xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
%none = torch.constant.none
%int4 = torch.constant.int 4
%0 = builtin.unrealized_conversion_cast %int4 : !torch.int to i64
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%int-1 = torch.constant.int -1
%false = torch.constant.bool false
%1 = builtin.unrealized_conversion_cast %false : !torch.bool to i1
%2 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = stablehlo.constant dense<0> : tensor<1x4xi32>
%4 = stablehlo.convert %3 : (tensor<1x4xi32>) -> tensor<1x4xi64>
%5 = torch.aten.zeros %2, %int4, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,4],si64>
%6 = torch.aten.slice.Tensor %arg1, %int1, %int0, %int-1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],si64>
%7 = torch.aten.clone %6, %none : !torch.vtensor<[1,3],si64>, !torch.none -> !torch.vtensor<[1,3],si64>
%8 = torch.aten.arange.start_step %int1, %int4, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
%9 = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
%10 = torch.aten.unsqueeze %9, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
%11 = torch.prim.ListConstruct %10, %8 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
%12 = torch.aten.index_put.hacked_twin %5, %11, %7, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64>
return %12 : !torch.vtensor<[1,4],si64>
}
} -> SUCCESS
//===-------------------------------------------===//
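Net effect of the `ConvertAtenConstPatternOp` rewrite above, distilled from the trace: `torch.aten.zeros` becomes a splat constant followed by an element-type conversion. Note that the constant is materialized as `i32` and only then widened to the requested `si64`:

```mlir
%cst = stablehlo.constant dense<0> : tensor<1x4xi32>
%res = stablehlo.convert %cst : (tensor<1x4xi32>) -> tensor<1x4xi64>
```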
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.slice.Tensor'(0xf391e20) {
%12 = "torch.aten.slice.Tensor"(%arg1, %3, %4, %5, %3) : (!torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3],si64>
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.aten.slice.Tensor -> ()' {
Trying to match "mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenSliceTensorOp>"
ImplicitTypeIDRegistry::lookupOrInsert(mlir::arith::detail::ConstantOpGenericAdaptorBase::Properties)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferIntRangeInterface::Trait<Empty>)
** Insert : 'arith.constant'(0xf4626f0)
** Insert : 'tensor.dim'(0xf462760)
** Insert : 'arith.index_cast'(0xf462810)
** Insert : 'arith.constant'(0xf4628a0)
** Insert : 'arith.subi'(0xf462910)
** Insert : 'arith.maxsi'(0xf462a30)
** Insert : 'arith.minsi'(0xf462ae0)
** Insert : 'arith.addi'(0xf462b90)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::arith::detail::CmpIOpGenericAdaptorBase::Properties)
** Insert : 'arith.cmpi'(0xf462c40)
** Insert : 'arith.select'(0xf3e8a10)
** Insert : 'arith.constant'(0xf462cf0)
** Insert : 'arith.subi'(0xf462d60)
** Insert : 'arith.maxsi'(0xf462ef0)
** Insert : 'arith.minsi'(0xf462fa0)
** Insert : 'arith.addi'(0xf463050)
** Insert : 'arith.cmpi'(0xf463100)
** Insert : 'arith.select'(0xf4631b0)
** Insert : 'arith.constant'(0xf463690)
** Insert : 'tensor.dim'(0xf463700)
** Insert : 'arith.index_cast'(0xf4637b0)
** Insert : 'arith.constant'(0xf466e40)
** Insert : 'tensor.dim'(0xf466eb0)
** Insert : 'arith.index_cast'(0xf466f60)
** Insert : 'arith.constant'(0xf466ff0)
** Insert : 'arith.constant'(0xf467060)
** Insert : 'arith.cmpi'(0xf4670d0)
** Insert : 'arith.select'(0xf462e10)
** Insert : 'tensor.from_elements'(0xf4677c0)
** Insert : 'tensor.from_elements'(0xf467870)
** Insert : 'tensor.from_elements'(0xf467920)
** Insert : 'stablehlo.real_dynamic_slice'(0xf3f1100)
** Replace : 'torch.aten.slice.Tensor'(0xf391e20)
"mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenSliceTensorOp>" result 1
//===-------------------------------------------===//
Legalizing operation : 'arith.constant'(0xf4626f0) {
%16 = "arith.constant"() <{value = 1 : index}> : () -> index
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.dim'(0xf462760) {
%17 = "tensor.dim"(%0, %16) : (tensor<1x4xi64>, index) -> index
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.index_cast'(0xf462810) {
%18 = "arith.index_cast"(%17) : (index) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.constant'(0xf4628a0) {
%19 = "arith.constant"() <{value = 0 : i64}> : () -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.subi'(0xf462910) {
%20 = "arith.subi"(%19, %18) : (i64, i64) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.maxsi'(0xf462a30) {
%21 = "arith.maxsi"(%20, %7) : (i64, i64) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.minsi'(0xf462ae0) {
%22 = "arith.minsi"(%18, %21) : (i64, i64) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.addi'(0xf462b90) {
%23 = "arith.addi"(%18, %22) : (i64, i64) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.cmpi'(0xf462c40) {
%24 = "arith.cmpi"(%22, %19) <{predicate = 5 : i64}> : (i64, i64) -> i1
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.select'(0xf3e8a10) {
%25 = "arith.select"(%24, %22, %23) : (i1, i64, i64) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.constant'(0xf462cf0) {
%26 = "arith.constant"() <{value = 0 : i64}> : () -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.subi'(0xf462d60) {
%27 = "arith.subi"(%26, %18) : (i64, i64) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.maxsi'(0xf462ef0) {
%28 = "arith.maxsi"(%27, %9) : (i64, i64) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.minsi'(0xf462fa0) {
%29 = "arith.minsi"(%18, %28) : (i64, i64) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.addi'(0xf463050) {
%30 = "arith.addi"(%18, %29) : (i64, i64) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.cmpi'(0xf463100) {
%31 = "arith.cmpi"(%29, %26) <{predicate = 5 : i64}> : (i64, i64) -> i1
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.select'(0xf4631b0) {
%32 = "arith.select"(%31, %29, %30) : (i1, i64, i64) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.constant'(0xf463690) {
%33 = "arith.constant"() <{value = 0 : index}> : () -> index
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.dim'(0xf463700) {
%34 = "tensor.dim"(%0, %33) : (tensor<1x4xi64>, index) -> index
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.index_cast'(0xf4637b0) {
%35 = "arith.index_cast"(%34) : (index) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.constant'(0xf466e40) {
%36 = "arith.constant"() <{value = 1 : index}> : () -> index
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.dim'(0xf466eb0) {
%37 = "tensor.dim"(%0, %36) : (tensor<1x4xi64>, index) -> index
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.index_cast'(0xf466f60) {
%38 = "arith.index_cast"(%37) : (index) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.constant'(0xf466ff0) {
%39 = "arith.constant"() <{value = 0 : i64}> : () -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.constant'(0xf467060) {
%40 = "arith.constant"() <{value = 1 : i64}> : () -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.cmpi'(0xf4670d0) {
%41 = "arith.cmpi"(%32, %39) <{predicate = 0 : i64}> : (i64, i64) -> i1
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.select'(0xf462e10) {
%42 = "arith.select"(%41, %38, %32) : (i1, i64, i64) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.from_elements'(0xf4677c0) {
%43 = "tensor.from_elements"(%39, %25) : (i64, i64) -> tensor<2xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.from_elements'(0xf467870) {
%44 = "tensor.from_elements"(%35, %42) : (i64, i64) -> tensor<2xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.from_elements'(0xf467920) {
%45 = "tensor.from_elements"(%40, %5) : (i64, i64) -> tensor<2xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.real_dynamic_slice'(0xf3f1100) {
%46 = "stablehlo.real_dynamic_slice"(%0, %43, %44, %45) : (tensor<1x4xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
%0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[1,4],si64> to tensor<1x4xi64>
%none = torch.constant.none
%int4 = torch.constant.int 4
%1 = builtin.unrealized_conversion_cast %int4 : !torch.int to i64
%int1 = torch.constant.int 1
%2 = builtin.unrealized_conversion_cast %int1 : !torch.int to i64
%int0 = torch.constant.int 0
%3 = builtin.unrealized_conversion_cast %int0 : !torch.int to i64
%int-1 = torch.constant.int -1
%4 = builtin.unrealized_conversion_cast %int-1 : !torch.int to i64
%false = torch.constant.bool false
%5 = builtin.unrealized_conversion_cast %false : !torch.bool to i1
%6 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%7 = stablehlo.constant dense<0> : tensor<1x4xi32>
%8 = stablehlo.convert %7 : (tensor<1x4xi32>) -> tensor<1x4xi64>
%9 = torch.aten.zeros %6, %int4, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,4],si64>
%c1 = arith.constant 1 : index
%dim = tensor.dim %0, %c1 : tensor<1x4xi64>
%10 = arith.index_cast %dim : index to i64
%c0_i64 = arith.constant 0 : i64
%11 = arith.subi %c0_i64, %10 : i64
%12 = arith.maxsi %11, %3 : i64
%13 = arith.minsi %10, %12 : i64
%14 = arith.addi %10, %13 : i64
%15 = arith.cmpi sge, %13, %c0_i64 : i64
%16 = arith.select %15, %13, %14 : i64
%c0_i64_0 = arith.constant 0 : i64
%17 = arith.subi %c0_i64_0, %10 : i64
%18 = arith.maxsi %17, %4 : i64
%19 = arith.minsi %10, %18 : i64
%20 = arith.addi %10, %19 : i64
%21 = arith.cmpi sge, %19, %c0_i64_0 : i64
%22 = arith.select %21, %19, %20 : i64
%c0 = arith.constant 0 : index
%dim_1 = tensor.dim %0, %c0 : tensor<1x4xi64>
%23 = arith.index_cast %dim_1 : index to i64
%c1_2 = arith.constant 1 : index
%dim_3 = tensor.dim %0, %c1_2 : tensor<1x4xi64>
%24 = arith.index_cast %dim_3 : index to i64
%c0_i64_4 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%25 = arith.cmpi eq, %22, %c0_i64_4 : i64
%26 = arith.select %25, %24, %22 : i64
%from_elements = tensor.from_elements %c0_i64_4, %16 : tensor<2xi64>
%from_elements_5 = tensor.from_elements %23, %26 : tensor<2xi64>
%from_elements_6 = tensor.from_elements %c1_i64, %2 : tensor<2xi64>
%27 = stablehlo.real_dynamic_slice %0, %from_elements, %from_elements_5, %from_elements_6 : (tensor<1x4xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xi64>
%28 = torch.aten.slice.Tensor %arg1, %int1, %int0, %int-1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],si64>
%29 = torch.aten.clone %28, %none : !torch.vtensor<[1,3],si64>, !torch.none -> !torch.vtensor<[1,3],si64>
%30 = torch.aten.arange.start_step %int1, %int4, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
%31 = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
%32 = torch.aten.unsqueeze %31, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
%33 = torch.prim.ListConstruct %32, %30 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
%34 = torch.aten.index_put.hacked_twin %9, %33, %29, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64>
return %34 : !torch.vtensor<[1,4],si64>
}
} -> SUCCESS
//===-------------------------------------------===//
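The `AtenSliceTensorOp` pattern above computes clamped start/end indices at runtime and hands them to `stablehlo.real_dynamic_slice`. A distilled sketch of the per-dimension index arithmetic it emits (value names are illustrative; `%idx`, `%start1`, `%end0`, `%end1`, `%one`, `%step1` stand for values produced earlier in the trace):

```mlir
// Size of the sliced axis, as i64.
%c1   = arith.constant 1 : index
%dim  = tensor.dim %input, %c1 : tensor<1x4xi64>
%dimI = arith.index_cast %dim : index to i64
// Clamp a possibly negative index %idx into [0, dim], wrapping negatives.
%zero    = arith.constant 0 : i64
%negD    = arith.subi %zero, %dimI : i64      // -dim
%lo      = arith.maxsi %negD, %idx : i64      // max(-dim, idx)
%hi      = arith.minsi %dimI, %lo : i64       // min(dim, max(-dim, idx))
%wrap    = arith.addi %dimI, %hi : i64        // dim + idx, for negative idx
%isPos   = arith.cmpi sge, %hi, %zero : i64
%clamped = arith.select %isPos, %hi, %wrap : i64
// Starts/ends/steps packed per dimension, then the slice itself.
%starts = tensor.from_elements %zero, %start1 : tensor<2xi64>
%ends   = tensor.from_elements %end0, %end1 : tensor<2xi64>
%steps  = tensor.from_elements %one, %step1 : tensor<2xi64>
%out = stablehlo.real_dynamic_slice %input, %starts, %ends, %steps : (tensor<1x4xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xi64>
```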
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.clone'(0xf3edd00) {
%48 = "torch.aten.clone"(%47, %1) : (!torch.vtensor<[1,3],si64>, !torch.none) -> !torch.vtensor<[1,3],si64>
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.aten.clone -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenUnaryOp<mlir::torch::Torch::AtenCloneOp, mlir::stablehlo::ConvertOp>"
** Insert : 'stablehlo.convert'(0xf468fe0)
** Replace : 'torch.aten.clone'(0xf3edd00)
"(anonymous namespace)::ConvertAtenUnaryOp<mlir::torch::Torch::AtenCloneOp, mlir::stablehlo::ConvertOp>" result 1
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.convert'(0xf468fe0) {
%48 = "stablehlo.convert"(%46) : (tensor<1x3xi64>) -> tensor<1x3xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
%0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[1,4],si64> to tensor<1x4xi64>
%none = torch.constant.none
%int4 = torch.constant.int 4
%1 = builtin.unrealized_conversion_cast %int4 : !torch.int to i64
%int1 = torch.constant.int 1
%2 = builtin.unrealized_conversion_cast %int1 : !torch.int to i64
%int0 = torch.constant.int 0
%3 = builtin.unrealized_conversion_cast %int0 : !torch.int to i64
%int-1 = torch.constant.int -1
%4 = builtin.unrealized_conversion_cast %int-1 : !torch.int to i64
%false = torch.constant.bool false
%5 = builtin.unrealized_conversion_cast %false : !torch.bool to i1
%6 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%7 = stablehlo.constant dense<0> : tensor<1x4xi32>
%8 = stablehlo.convert %7 : (tensor<1x4xi32>) -> tensor<1x4xi64>
%9 = torch.aten.zeros %6, %int4, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,4],si64>
%c1 = arith.constant 1 : index
%dim = tensor.dim %0, %c1 : tensor<1x4xi64>
%10 = arith.index_cast %dim : index to i64
%c0_i64 = arith.constant 0 : i64
%11 = arith.subi %c0_i64, %10 : i64
%12 = arith.maxsi %11, %3 : i64
%13 = arith.minsi %10, %12 : i64
%14 = arith.addi %10, %13 : i64
%15 = arith.cmpi sge, %13, %c0_i64 : i64
%16 = arith.select %15, %13, %14 : i64
%c0_i64_0 = arith.constant 0 : i64
%17 = arith.subi %c0_i64_0, %10 : i64
%18 = arith.maxsi %17, %4 : i64
%19 = arith.minsi %10, %18 : i64
%20 = arith.addi %10, %19 : i64
%21 = arith.cmpi sge, %19, %c0_i64_0 : i64
%22 = arith.select %21, %19, %20 : i64
%c0 = arith.constant 0 : index
%dim_1 = tensor.dim %0, %c0 : tensor<1x4xi64>
%23 = arith.index_cast %dim_1 : index to i64
%c1_2 = arith.constant 1 : index
%dim_3 = tensor.dim %0, %c1_2 : tensor<1x4xi64>
%24 = arith.index_cast %dim_3 : index to i64
%c0_i64_4 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%25 = arith.cmpi eq, %22, %c0_i64_4 : i64
%26 = arith.select %25, %24, %22 : i64
%from_elements = tensor.from_elements %c0_i64_4, %16 : tensor<2xi64>
%from_elements_5 = tensor.from_elements %23, %26 : tensor<2xi64>
%from_elements_6 = tensor.from_elements %c1_i64, %2 : tensor<2xi64>
%27 = stablehlo.real_dynamic_slice %0, %from_elements, %from_elements_5, %from_elements_6 : (tensor<1x4xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xi64>
%28 = torch.aten.slice.Tensor %arg1, %int1, %int0, %int-1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],si64>
%29 = stablehlo.convert %27 : tensor<1x3xi64>
%30 = torch.aten.clone %28, %none : !torch.vtensor<[1,3],si64>, !torch.none -> !torch.vtensor<[1,3],si64>
%31 = torch.aten.arange.start_step %int1, %int4, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
%32 = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
%33 = torch.aten.unsqueeze %32, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
%34 = torch.prim.ListConstruct %33, %31 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
%35 = torch.aten.index_put.hacked_twin %9, %34, %30, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64>
return %35 : !torch.vtensor<[1,4],si64>
}
} -> SUCCESS
//===-------------------------------------------===//
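`torch.aten.clone` is handled by the generic `ConvertAtenUnaryOp` template, which maps it to a single same-type `stablehlo.convert`, i.e. a plain copy:

```mlir
%copy = stablehlo.convert %src : tensor<1x3xi64>
```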
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.arange.start_step'(0xf3ee310) {
%50 = "torch.aten.arange.start_step"(%4, %2, %4, %1, %1, %1, %1) : (!torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[3],si64>
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.aten.arange.start_step -> ()' {
Trying to match "mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenArangeStartStepOp>"
** Insert : 'tensor.from_elements'(0xf469130)
** Insert : 'stablehlo.convert'(0xf469e70)
** Insert : 'stablehlo.reshape'(0xf468240)
** Insert : 'tensor.from_elements'(0xf4682d0)
** Insert : 'stablehlo.convert'(0xf468360)
** Insert : 'stablehlo.reshape'(0xf4683f0)
** Insert : 'tensor.from_elements'(0xf468480)
** Insert : 'stablehlo.convert'(0xf468510)
** Insert : 'stablehlo.reshape'(0xf4685a0)
** Insert : 'stablehlo.subtract'(0xf468630)
** Insert : 'stablehlo.divide'(0xf4686e0)
** Insert : 'stablehlo.reshape'(0xf468790)
** Insert : 'stablehlo.dynamic_iota'(0xf468820)
** Insert : 'chlo.broadcast_multiply'(0xf4688b0)
** Insert : 'chlo.broadcast_add'(0xf468960)
** Replace : 'torch.aten.arange.start_step'(0xf3ee310)
"mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenArangeStartStepOp>" result 1
//===-------------------------------------------===//
Legalizing operation : 'tensor.from_elements'(0xf469130) {
%50 = "tensor.from_elements"(%5) : (i64) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.convert'(0xf469e70) {
%51 = "stablehlo.convert"(%50) : (tensor<1xi64>) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.reshape'(0xf468240) {
%52 = "stablehlo.reshape"(%51) : (tensor<1xi64>) -> tensor<i64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.from_elements'(0xf4682d0) {
%53 = "tensor.from_elements"(%3) : (i64) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.convert'(0xf468360) {
%54 = "stablehlo.convert"(%53) : (tensor<1xi64>) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.reshape'(0xf4683f0) {
%55 = "stablehlo.reshape"(%54) : (tensor<1xi64>) -> tensor<i64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.from_elements'(0xf468480) {
%56 = "tensor.from_elements"(%5) : (i64) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.convert'(0xf468510) {
%57 = "stablehlo.convert"(%56) : (tensor<1xi64>) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.reshape'(0xf4685a0) {
%58 = "stablehlo.reshape"(%57) : (tensor<1xi64>) -> tensor<i64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.subtract'(0xf468630) {
%59 = "stablehlo.subtract"(%55, %52) : (tensor<i64>, tensor<i64>) -> tensor<i64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.divide'(0xf4686e0) {
%60 = "stablehlo.divide"(%59, %58) : (tensor<i64>, tensor<i64>) -> tensor<i64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.reshape'(0xf468790) {
%61 = "stablehlo.reshape"(%60) : (tensor<i64>) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.dynamic_iota'(0xf468820) {
%62 = "stablehlo.dynamic_iota"(%61) {iota_dimension = 0 : i64} : (tensor<1xi64>) -> tensor<3xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'chlo.broadcast_multiply'(0xf4688b0) {
%63 = "chlo.broadcast_multiply"(%62, %58) : (tensor<3xi64>, tensor<i64>) -> tensor<3xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'chlo.broadcast_add'(0xf468960) {
%64 = "chlo.broadcast_add"(%63, %52) : (tensor<3xi64>, tensor<i64>) -> tensor<3xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneOperand<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SameOperandsAndResultElementType<Empty>)
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
%0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[1,4],si64> to tensor<1x4xi64>
%none = torch.constant.none
%int4 = torch.constant.int 4
%1 = builtin.unrealized_conversion_cast %int4 : !torch.int to i64
%int1 = torch.constant.int 1
%2 = builtin.unrealized_conversion_cast %int1 : !torch.int to i64
%int0 = torch.constant.int 0
%3 = builtin.unrealized_conversion_cast %int0 : !torch.int to i64
%int-1 = torch.constant.int -1
%4 = builtin.unrealized_conversion_cast %int-1 : !torch.int to i64
%false = torch.constant.bool false
%5 = builtin.unrealized_conversion_cast %false : !torch.bool to i1
%6 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%7 = stablehlo.constant dense<0> : tensor<1x4xi32>
%8 = stablehlo.convert %7 : (tensor<1x4xi32>) -> tensor<1x4xi64>
%9 = torch.aten.zeros %6, %int4, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,4],si64>
%c1 = arith.constant 1 : index
%dim = tensor.dim %0, %c1 : tensor<1x4xi64>
%10 = arith.index_cast %dim : index to i64
%c0_i64 = arith.constant 0 : i64
%11 = arith.subi %c0_i64, %10 : i64
%12 = arith.maxsi %11, %3 : i64
%13 = arith.minsi %10, %12 : i64
%14 = arith.addi %10, %13 : i64
%15 = arith.cmpi sge, %13, %c0_i64 : i64
%16 = arith.select %15, %13, %14 : i64
%c0_i64_0 = arith.constant 0 : i64
%17 = arith.subi %c0_i64_0, %10 : i64
%18 = arith.maxsi %17, %4 : i64
%19 = arith.minsi %10, %18 : i64
%20 = arith.addi %10, %19 : i64
%21 = arith.cmpi sge, %19, %c0_i64_0 : i64
%22 = arith.select %21, %19, %20 : i64
%c0 = arith.constant 0 : index
%dim_1 = tensor.dim %0, %c0 : tensor<1x4xi64>
%23 = arith.index_cast %dim_1 : index to i64
%c1_2 = arith.constant 1 : index
%dim_3 = tensor.dim %0, %c1_2 : tensor<1x4xi64>
%24 = arith.index_cast %dim_3 : index to i64
%c0_i64_4 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%25 = arith.cmpi eq, %22, %c0_i64_4 : i64
%26 = arith.select %25, %24, %22 : i64
%from_elements = tensor.from_elements %c0_i64_4, %16 : tensor<2xi64>
%from_elements_5 = tensor.from_elements %23, %26 : tensor<2xi64>
%from_elements_6 = tensor.from_elements %c1_i64, %2 : tensor<2xi64>
%27 = stablehlo.real_dynamic_slice %0, %from_elements, %from_elements_5, %from_elements_6 : (tensor<1x4xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xi64>
%28 = torch.aten.slice.Tensor %arg1, %int1, %int0, %int-1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],si64>
%29 = stablehlo.convert %27 : tensor<1x3xi64>
%30 = torch.aten.clone %28, %none : !torch.vtensor<[1,3],si64>, !torch.none -> !torch.vtensor<[1,3],si64>
%from_elements_7 = tensor.from_elements %2 : tensor<1xi64>
%31 = stablehlo.convert %from_elements_7 : tensor<1xi64>
%32 = stablehlo.reshape %31 : (tensor<1xi64>) -> tensor<i64>
%from_elements_8 = tensor.from_elements %1 : tensor<1xi64>
%33 = stablehlo.convert %from_elements_8 : tensor<1xi64>
%34 = stablehlo.reshape %33 : (tensor<1xi64>) -> tensor<i64>
%from_elements_9 = tensor.from_elements %2 : tensor<1xi64>
%35 = stablehlo.convert %from_elements_9 : tensor<1xi64>
%36 = stablehlo.reshape %35 : (tensor<1xi64>) -> tensor<i64>
%37 = stablehlo.subtract %34, %32 : tensor<i64>
%38 = stablehlo.divide %37, %36 : tensor<i64>
%39 = stablehlo.reshape %38 : (tensor<i64>) -> tensor<1xi64>
%40 = stablehlo.dynamic_iota %39, dim = 0 : (tensor<1xi64>) -> tensor<3xi64>
%41 = chlo.broadcast_multiply %40, %36 : (tensor<3xi64>, tensor<i64>) -> tensor<3xi64>
%42 = chlo.broadcast_add %41, %32 : (tensor<3xi64>, tensor<i64>) -> tensor<3xi64>
%43 = torch.aten.arange.start_step %int1, %int4, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
%44 = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
%45 = torch.aten.unsqueeze %44, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
%46 = torch.prim.ListConstruct %45, %43 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
%47 = torch.aten.index_put.hacked_twin %9, %46, %30, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64>
return %47 : !torch.vtensor<[1,4],si64>
}
} -> SUCCESS
//===-------------------------------------------===//
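The `AtenArangeStartStepOp` pattern above lowers `arange(start, end, step)` without constant folding: start, end, and step are packed into rank-1 tensors, reshaped to scalars, the length is computed as `(end - start) / step`, and a dynamic iota over that length is scaled by the step and offset by the start. Distilled from the trace for `arange(1, 4, 1)`:

```mlir
%start = stablehlo.reshape %start1d : (tensor<1xi64>) -> tensor<i64>
%end   = stablehlo.reshape %end1d   : (tensor<1xi64>) -> tensor<i64>
%step  = stablehlo.reshape %step1d  : (tensor<1xi64>) -> tensor<i64>
%span  = stablehlo.subtract %end, %start : tensor<i64>   // end - start
%len   = stablehlo.divide %span, %step : tensor<i64>     // (end - start) / step
%shape = stablehlo.reshape %len : (tensor<i64>) -> tensor<1xi64>
%iota  = stablehlo.dynamic_iota %shape, dim = 0 : (tensor<1xi64>) -> tensor<3xi64>
%mul   = chlo.broadcast_multiply %iota, %step : (tensor<3xi64>, tensor<i64>) -> tensor<3xi64>
%range = chlo.broadcast_add %mul, %start : (tensor<3xi64>, tensor<i64>) -> tensor<3xi64>
```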
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.arange.start_step'(0xf3eeca0) {
%66 = "torch.aten.arange.start_step"(%6, %4, %4, %2, %1, %1, %1) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[1],si64>
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.aten.arange.start_step -> ()' {
Trying to match "mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenArangeStartStepOp>"
** Insert : 'tensor.from_elements'(0xf46a990)
** Insert : 'stablehlo.convert'(0xf46aa90)
** Insert : 'stablehlo.reshape'(0xf46ab20)
** Insert : 'tensor.from_elements'(0xf46abb0)
** Insert : 'stablehlo.convert'(0xf46ac40)
** Insert : 'stablehlo.reshape'(0xf46acd0)
** Insert : 'tensor.from_elements'(0xf46ad60)
** Insert : 'stablehlo.convert'(0xf469070)
** Insert : 'stablehlo.reshape'(0xf4642b0)
** Insert : 'stablehlo.subtract'(0xf468ea0)
** Insert : 'stablehlo.divide'(0xf467180)
** Insert : 'stablehlo.reshape'(0xf467230)
** Insert : 'stablehlo.dynamic_iota'(0xf46ded0)
** Insert : 'chlo.broadcast_multiply'(0xf46df60)
** Insert : 'chlo.broadcast_add'(0xf46e010)
** Replace : 'torch.aten.arange.start_step'(0xf3eeca0)
"mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenArangeStartStepOp>" result 1
//===-------------------------------------------===//
Legalizing operation : 'tensor.from_elements'(0xf46a990) {
%66 = "tensor.from_elements"(%7) : (i64) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.convert'(0xf46aa90) {
%67 = "stablehlo.convert"(%66) : (tensor<1xi64>) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.reshape'(0xf46ab20) {
%68 = "stablehlo.reshape"(%67) : (tensor<1xi64>) -> tensor<i64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.from_elements'(0xf46abb0) {
%69 = "tensor.from_elements"(%5) : (i64) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.convert'(0xf46ac40) {
%70 = "stablehlo.convert"(%69) : (tensor<1xi64>) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.reshape'(0xf46acd0) {
%71 = "stablehlo.reshape"(%70) : (tensor<1xi64>) -> tensor<i64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.from_elements'(0xf46ad60) {
%72 = "tensor.from_elements"(%5) : (i64) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.convert'(0xf469070) {
%73 = "stablehlo.convert"(%72) : (tensor<1xi64>) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.reshape'(0xf4642b0) {
%74 = "stablehlo.reshape"(%73) : (tensor<1xi64>) -> tensor<i64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.subtract'(0xf468ea0) {
%75 = "stablehlo.subtract"(%71, %68) : (tensor<i64>, tensor<i64>) -> tensor<i64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.divide'(0xf467180) {
%76 = "stablehlo.divide"(%75, %74) : (tensor<i64>, tensor<i64>) -> tensor<i64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.reshape'(0xf467230) {
%77 = "stablehlo.reshape"(%76) : (tensor<i64>) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.dynamic_iota'(0xf46ded0) {
%78 = "stablehlo.dynamic_iota"(%77) {iota_dimension = 0 : i64} : (tensor<1xi64>) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'chlo.broadcast_multiply'(0xf46df60) {
%79 = "chlo.broadcast_multiply"(%78, %74) : (tensor<1xi64>, tensor<i64>) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'chlo.broadcast_add'(0xf46e010) {
%80 = "chlo.broadcast_add"(%79, %68) : (tensor<1xi64>, tensor<i64>) -> tensor<1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
%0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[1,4],si64> to tensor<1x4xi64>
%none = torch.constant.none
%int4 = torch.constant.int 4
%1 = builtin.unrealized_conversion_cast %int4 : !torch.int to i64
%int1 = torch.constant.int 1
%2 = builtin.unrealized_conversion_cast %int1 : !torch.int to i64
%int0 = torch.constant.int 0
%3 = builtin.unrealized_conversion_cast %int0 : !torch.int to i64
%int-1 = torch.constant.int -1
%4 = builtin.unrealized_conversion_cast %int-1 : !torch.int to i64
%false = torch.constant.bool false
%5 = builtin.unrealized_conversion_cast %false : !torch.bool to i1
%6 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%7 = stablehlo.constant dense<0> : tensor<1x4xi32>
%8 = stablehlo.convert %7 : (tensor<1x4xi32>) -> tensor<1x4xi64>
%9 = torch.aten.zeros %6, %int4, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,4],si64>
%c1 = arith.constant 1 : index
%dim = tensor.dim %0, %c1 : tensor<1x4xi64>
%10 = arith.index_cast %dim : index to i64
%c0_i64 = arith.constant 0 : i64
%11 = arith.subi %c0_i64, %10 : i64
%12 = arith.maxsi %11, %3 : i64
%13 = arith.minsi %10, %12 : i64
%14 = arith.addi %10, %13 : i64
%15 = arith.cmpi sge, %13, %c0_i64 : i64
%16 = arith.select %15, %13, %14 : i64
%c0_i64_0 = arith.constant 0 : i64
%17 = arith.subi %c0_i64_0, %10 : i64
%18 = arith.maxsi %17, %4 : i64
%19 = arith.minsi %10, %18 : i64
%20 = arith.addi %10, %19 : i64
%21 = arith.cmpi sge, %19, %c0_i64_0 : i64
%22 = arith.select %21, %19, %20 : i64
%c0 = arith.constant 0 : index
%dim_1 = tensor.dim %0, %c0 : tensor<1x4xi64>
%23 = arith.index_cast %dim_1 : index to i64
%c1_2 = arith.constant 1 : index
%dim_3 = tensor.dim %0, %c1_2 : tensor<1x4xi64>
%24 = arith.index_cast %dim_3 : index to i64
%c0_i64_4 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%25 = arith.cmpi eq, %22, %c0_i64_4 : i64
%26 = arith.select %25, %24, %22 : i64
%from_elements = tensor.from_elements %c0_i64_4, %16 : tensor<2xi64>
%from_elements_5 = tensor.from_elements %23, %26 : tensor<2xi64>
%from_elements_6 = tensor.from_elements %c1_i64, %2 : tensor<2xi64>
%27 = stablehlo.real_dynamic_slice %0, %from_elements, %from_elements_5, %from_elements_6 : (tensor<1x4xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xi64>
%28 = torch.aten.slice.Tensor %arg1, %int1, %int0, %int-1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],si64>
%29 = stablehlo.convert %27 : tensor<1x3xi64>
%30 = torch.aten.clone %28, %none : !torch.vtensor<[1,3],si64>, !torch.none -> !torch.vtensor<[1,3],si64>
%from_elements_7 = tensor.from_elements %2 : tensor<1xi64>
%31 = stablehlo.convert %from_elements_7 : tensor<1xi64>
%32 = stablehlo.reshape %31 : (tensor<1xi64>) -> tensor<i64>
%from_elements_8 = tensor.from_elements %1 : tensor<1xi64>
%33 = stablehlo.convert %from_elements_8 : tensor<1xi64>
%34 = stablehlo.reshape %33 : (tensor<1xi64>) -> tensor<i64>
%from_elements_9 = tensor.from_elements %2 : tensor<1xi64>
%35 = stablehlo.convert %from_elements_9 : tensor<1xi64>
%36 = stablehlo.reshape %35 : (tensor<1xi64>) -> tensor<i64>
%37 = stablehlo.subtract %34, %32 : tensor<i64>
%38 = stablehlo.divide %37, %36 : tensor<i64>
%39 = stablehlo.reshape %38 : (tensor<i64>) -> tensor<1xi64>
%40 = stablehlo.dynamic_iota %39, dim = 0 : (tensor<1xi64>) -> tensor<3xi64>
%41 = chlo.broadcast_multiply %40, %36 : (tensor<3xi64>, tensor<i64>) -> tensor<3xi64>
%42 = chlo.broadcast_add %41, %32 : (tensor<3xi64>, tensor<i64>) -> tensor<3xi64>
%43 = torch.aten.arange.start_step %int1, %int4, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
%from_elements_10 = tensor.from_elements %3 : tensor<1xi64>
%44 = stablehlo.convert %from_elements_10 : tensor<1xi64>
%45 = stablehlo.reshape %44 : (tensor<1xi64>) -> tensor<i64>
%from_elements_11 = tensor.from_elements %2 : tensor<1xi64>
%46 = stablehlo.convert %from_elements_11 : tensor<1xi64>
%47 = stablehlo.reshape %46 : (tensor<1xi64>) -> tensor<i64>
%from_elements_12 = tensor.from_elements %2 : tensor<1xi64>
%48 = stablehlo.convert %from_elements_12 : tensor<1xi64>
%49 = stablehlo.reshape %48 : (tensor<1xi64>) -> tensor<i64>
%50 = stablehlo.subtract %47, %45 : tensor<i64>
%51 = stablehlo.divide %50, %49 : tensor<i64>
%52 = stablehlo.reshape %51 : (tensor<i64>) -> tensor<1xi64>
%53 = stablehlo.dynamic_iota %52, dim = 0 : (tensor<1xi64>) -> tensor<1xi64>
%54 = chlo.broadcast_multiply %53, %49 : (tensor<1xi64>, tensor<i64>) -> tensor<1xi64>
%55 = chlo.broadcast_add %54, %45 : (tensor<1xi64>, tensor<i64>) -> tensor<1xi64>
%56 = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
%57 = torch.aten.unsqueeze %56, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
%58 = torch.prim.ListConstruct %57, %43 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
%59 = torch.aten.index_put.hacked_twin %9, %58, %30, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64>
return %59 : !torch.vtensor<[1,4],si64>
}
} -> SUCCESS
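The dump above is the result of lowering the first torch.aten.slice.Tensor: the pattern materializes the start/end/step clamping as arith ops (%10..%26), packs start/limit/stride into shape tensors with tensor.from_elements, and emits stablehlo.real_dynamic_slice. For this fully static 1x4 input the computed values are start = [0, 0], limit = [1, 3], strides = [1, 1], so the whole sequence is equivalent to a plain static slice. The form below is a hand-written condensation for reference, not something the pass emits:

  // rows 0:1, cols 0:3 of the 1x4 input, i.e. drop the last column
  %slice = stablehlo.slice %0 [0:1, 0:3] : (tensor<1x4xi64>) -> tensor<1x3xi64>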
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.unsqueeze'(0xf3ee7a0) {
%82 = "torch.aten.unsqueeze"(%81, %8) : (!torch.vtensor<[1],si64>, !torch.int) -> !torch.vtensor<[1,1],si64>
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.aten.unsqueeze -> ()' {
Trying to match "mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenUnsqueezeOp>"
** Insert : 'arith.constant'(0xf46aa20)
** Insert : 'tensor.dim'(0xf46dca0)
** Insert : 'arith.index_cast'(0xf46dc10)
** Insert : 'arith.constant'(0xf46dd50)
** Insert : 'tensor.from_elements'(0xf46b170)
** Insert : 'stablehlo.dynamic_reshape'(0xf46b220)
** Replace : 'torch.aten.unsqueeze'(0xf3ee7a0)
"mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenUnsqueezeOp>" result 1
//===-------------------------------------------===//
Legalizing operation : 'arith.constant'(0xf46aa20) {
%82 = "arith.constant"() <{value = 0 : index}> : () -> index
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.dim'(0xf46dca0) {
%83 = "tensor.dim"(%80, %82) : (tensor<1xi64>, index) -> index
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.index_cast'(0xf46dc10) {
%84 = "arith.index_cast"(%83) : (index) -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'arith.constant'(0xf46dd50) {
%85 = "arith.constant"() <{value = 1 : i64}> : () -> i64
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'tensor.from_elements'(0xf46b170) {
%86 = "tensor.from_elements"(%84, %85) : (i64, i64) -> tensor<2xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.dynamic_reshape'(0xf46b220) {
%87 = "stablehlo.dynamic_reshape"(%80, %86) : (tensor<1xi64>, tensor<2xi64>) -> tensor<1x1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
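All six ops inserted by the unsqueeze pattern legalize immediately: the conversion reads the runtime extent of the input's only dimension, appends a trailing unit dimension, and reshapes. Condensed from the trace, with %t standing for the already-converted tensor<1xi64> input and comments added by hand:

  %c0 = arith.constant 0 : index
  %dim = tensor.dim %t, %c0 : tensor<1xi64>                    // runtime extent of dim 0
  %d0 = arith.index_cast %dim : index to i64
  %c1_i64 = arith.constant 1 : i64                             // the new unit dim
  %shape = tensor.from_elements %d0, %c1_i64 : tensor<2xi64>   // target shape [d0, 1]
  %r = stablehlo.dynamic_reshape %t, %shape : (tensor<1xi64>, tensor<2xi64>) -> tensor<1x1xi64>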
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
%0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[1,4],si64> to tensor<1x4xi64>
%none = torch.constant.none
%int4 = torch.constant.int 4
%1 = builtin.unrealized_conversion_cast %int4 : !torch.int to i64
%int1 = torch.constant.int 1
%2 = builtin.unrealized_conversion_cast %int1 : !torch.int to i64
%int0 = torch.constant.int 0
%3 = builtin.unrealized_conversion_cast %int0 : !torch.int to i64
%int-1 = torch.constant.int -1
%4 = builtin.unrealized_conversion_cast %int-1 : !torch.int to i64
%false = torch.constant.bool false
%5 = builtin.unrealized_conversion_cast %false : !torch.bool to i1
%6 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%7 = stablehlo.constant dense<0> : tensor<1x4xi32>
%8 = stablehlo.convert %7 : (tensor<1x4xi32>) -> tensor<1x4xi64>
%9 = torch.aten.zeros %6, %int4, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,4],si64>
%c1 = arith.constant 1 : index
%dim = tensor.dim %0, %c1 : tensor<1x4xi64>
%10 = arith.index_cast %dim : index to i64
%c0_i64 = arith.constant 0 : i64
%11 = arith.subi %c0_i64, %10 : i64
%12 = arith.maxsi %11, %3 : i64
%13 = arith.minsi %10, %12 : i64
%14 = arith.addi %10, %13 : i64
%15 = arith.cmpi sge, %13, %c0_i64 : i64
%16 = arith.select %15, %13, %14 : i64
%c0_i64_0 = arith.constant 0 : i64
%17 = arith.subi %c0_i64_0, %10 : i64
%18 = arith.maxsi %17, %4 : i64
%19 = arith.minsi %10, %18 : i64
%20 = arith.addi %10, %19 : i64
%21 = arith.cmpi sge, %19, %c0_i64_0 : i64
%22 = arith.select %21, %19, %20 : i64
%c0 = arith.constant 0 : index
%dim_1 = tensor.dim %0, %c0 : tensor<1x4xi64>
%23 = arith.index_cast %dim_1 : index to i64
%c1_2 = arith.constant 1 : index
%dim_3 = tensor.dim %0, %c1_2 : tensor<1x4xi64>
%24 = arith.index_cast %dim_3 : index to i64
%c0_i64_4 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%25 = arith.cmpi eq, %22, %c0_i64_4 : i64
%26 = arith.select %25, %24, %22 : i64
%from_elements = tensor.from_elements %c0_i64_4, %16 : tensor<2xi64>
%from_elements_5 = tensor.from_elements %23, %26 : tensor<2xi64>
%from_elements_6 = tensor.from_elements %c1_i64, %2 : tensor<2xi64>
%27 = stablehlo.real_dynamic_slice %0, %from_elements, %from_elements_5, %from_elements_6 : (tensor<1x4xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xi64>
%28 = torch.aten.slice.Tensor %arg1, %int1, %int0, %int-1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],si64>
%29 = stablehlo.convert %27 : tensor<1x3xi64>
%30 = torch.aten.clone %28, %none : !torch.vtensor<[1,3],si64>, !torch.none -> !torch.vtensor<[1,3],si64>
%from_elements_7 = tensor.from_elements %2 : tensor<1xi64>
%31 = stablehlo.convert %from_elements_7 : tensor<1xi64>
%32 = stablehlo.reshape %31 : (tensor<1xi64>) -> tensor<i64>
%from_elements_8 = tensor.from_elements %1 : tensor<1xi64>
%33 = stablehlo.convert %from_elements_8 : tensor<1xi64>
%34 = stablehlo.reshape %33 : (tensor<1xi64>) -> tensor<i64>
%from_elements_9 = tensor.from_elements %2 : tensor<1xi64>
%35 = stablehlo.convert %from_elements_9 : tensor<1xi64>
%36 = stablehlo.reshape %35 : (tensor<1xi64>) -> tensor<i64>
%37 = stablehlo.subtract %34, %32 : tensor<i64>
%38 = stablehlo.divide %37, %36 : tensor<i64>
%39 = stablehlo.reshape %38 : (tensor<i64>) -> tensor<1xi64>
%40 = stablehlo.dynamic_iota %39, dim = 0 : (tensor<1xi64>) -> tensor<3xi64>
%41 = chlo.broadcast_multiply %40, %36 : (tensor<3xi64>, tensor<i64>) -> tensor<3xi64>
%42 = chlo.broadcast_add %41, %32 : (tensor<3xi64>, tensor<i64>) -> tensor<3xi64>
%43 = torch.aten.arange.start_step %int1, %int4, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
%from_elements_10 = tensor.from_elements %3 : tensor<1xi64>
%44 = stablehlo.convert %from_elements_10 : tensor<1xi64>
%45 = stablehlo.reshape %44 : (tensor<1xi64>) -> tensor<i64>
%from_elements_11 = tensor.from_elements %2 : tensor<1xi64>
%46 = stablehlo.convert %from_elements_11 : tensor<1xi64>
%47 = stablehlo.reshape %46 : (tensor<1xi64>) -> tensor<i64>
%from_elements_12 = tensor.from_elements %2 : tensor<1xi64>
%48 = stablehlo.convert %from_elements_12 : tensor<1xi64>
%49 = stablehlo.reshape %48 : (tensor<1xi64>) -> tensor<i64>
%50 = stablehlo.subtract %47, %45 : tensor<i64>
%51 = stablehlo.divide %50, %49 : tensor<i64>
%52 = stablehlo.reshape %51 : (tensor<i64>) -> tensor<1xi64>
%53 = stablehlo.dynamic_iota %52, dim = 0 : (tensor<1xi64>) -> tensor<1xi64>
%54 = chlo.broadcast_multiply %53, %49 : (tensor<1xi64>, tensor<i64>) -> tensor<1xi64>
%55 = chlo.broadcast_add %54, %45 : (tensor<1xi64>, tensor<i64>) -> tensor<1xi64>
%56 = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
%c0_13 = arith.constant 0 : index
%dim_14 = tensor.dim %55, %c0_13 : tensor<1xi64>
%57 = arith.index_cast %dim_14 : index to i64
%c1_i64_15 = arith.constant 1 : i64
%from_elements_16 = tensor.from_elements %57, %c1_i64_15 : tensor<2xi64>
%58 = stablehlo.dynamic_reshape %55, %from_elements_16 : (tensor<1xi64>, tensor<2xi64>) -> tensor<1x1xi64>
%59 = torch.aten.unsqueeze %56, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
%60 = torch.prim.ListConstruct %59, %43 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
%61 = torch.aten.index_put.hacked_twin %9, %60, %30, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64>
return %61 : !torch.vtensor<[1,4],si64>
}
} -> SUCCESS
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.prim.ListConstruct'(0xf3efea0) {
%89 = "torch.prim.ListConstruct"(%88, %65) : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
* Fold {
} -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
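Under partial conversion this FAILURE should be non-fatal: torch.prim.ListConstruct has no standalone lowering in this pass and is simply left in place. The list exists only to feed the index tensors to the index_put op, as in the input file:

  %7 = torch.prim.ListConstruct %6, %4 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
  %8 = torch.aten.index_put.hacked_twin %1, %7, %3, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64>

The index_put pattern (next block) unpacks the list's operands directly, so the list op would become dead once the consumer converts.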
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.index_put.hacked_twin'(0xf3ed6c0) {
%90 = "torch.aten.index_put.hacked_twin"(%15, %89, %49, %10) : (!torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool) -> !torch.vtensor<[1,4],si64>
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.aten.index_put.hacked_twin -> ()' {
Trying to match "mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenIndexPutHackedTwinOp>"
** Insert : 'torch_c.to_builtin_tensor'(0xf46d950)
** Insert : 'torch_c.to_builtin_tensor'(0xf46d9e0)
** Insert : 'stablehlo.reshape'(0xf46db40)
** Insert : 'stablehlo.reshape'(0xf46b2d0)
ii = 0
indexesShape[0]: 1, 3,
indexShapeOneDim: 1, 3,
ii = 1
indexesShape[0]: 1, 3,
indexShapeOneDim: 1, 3,
** Insert : 'stablehlo.concatenate'(0xf46b7e0)
** Insert : 'stablehlo.reshape'(0xf46b890)
** Insert : 'stablehlo.reshape'(0xf46b920)
** Insert : 'stablehlo.scatter'(0xf42a550)
** Insert : 'stablehlo.return'(0xf45f830)
** Replace : 'torch.aten.index_put.hacked_twin'(0xf3ed6c0)
"mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenIndexPutHackedTwinOp>" result 1
//===-------------------------------------------===//
Legalizing operation : 'torch_c.to_builtin_tensor'(0xf46d950) {
%90 = "torch_c.to_builtin_tensor"(%88) : (!torch.vtensor<[1,1],si64>) -> tensor<1x1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch_c.to_builtin_tensor'(0xf46d9e0) {
%91 = "torch_c.to_builtin_tensor"(%65) : (!torch.vtensor<[3],si64>) -> tensor<3xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.reshape'(0xf46db40) {
%92 = "stablehlo.reshape"(%90) : (tensor<1x1xi64>) -> tensor<1x3x1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.reshape'(0xf46b2d0) {
%93 = "stablehlo.reshape"(%91) : (tensor<3xi64>) -> tensor<1x3x1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.concatenate'(0xf46b7e0) {
%94 = "stablehlo.concatenate"(%92, %93) {dimension = 2 : i64} : (tensor<1x3x1xi64>, tensor<1x3x1xi64>) -> tensor<1x3x2xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.reshape'(0xf46b890) {
%95 = "stablehlo.reshape"(%48) : (tensor<1x3xi64>) -> tensor<3x1xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.reshape'(0xf46b920) {
%96 = "stablehlo.reshape"(%94) : (tensor<1x3x2xi64>) -> tensor<3x2xi64>
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.scatter'(0xf42a550) {
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'stablehlo.return'(0xf45f830) {
"stablehlo.return"(%arg3) : (tensor<i64>) -> ()
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
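The hacked_twin pattern builds a stablehlo.scatter from the unpacked index list: each index tensor is brought to the common 1x3 shape (the indexesShape/indexShapeOneDim lines above), given a trailing unit dim, concatenated along that dim into a 1x3x2 coordinate tensor, and flattened to 3x2 scatter indices; the 1x3 values flatten to 3x1 updates, and the scatter region returns the update so the selected elements are overwritten. In outline (ops condensed from the trace, comments added):

  %92 = stablehlo.reshape %90 : (tensor<1x1xi64>) -> tensor<1x3x1xi64>   // row indices
  %93 = stablehlo.reshape %91 : (tensor<3xi64>) -> tensor<1x3x1xi64>     // column indices
  %94 = stablehlo.concatenate %92, %93, dim = 2 : (tensor<1x3x1xi64>, tensor<1x3x1xi64>) -> tensor<1x3x2xi64>
  %96 = stablehlo.reshape %94 : (tensor<1x3x2xi64>) -> tensor<3x2xi64>   // (row, col) pairs
  %95 = stablehlo.reshape %48 : (tensor<1x3xi64>) -> tensor<3x1xi64>     // updates
  // %97 = stablehlo.scatter(%14, %96, %95) with a region that returns the update

Note that the first reshape already carries the bug that surfaces below: it grows 1x1 into 1x3x1.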
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::VariadicResults<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AtLeastNOperands<1>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasRecursiveMemoryEffects<Empty>)
number of output elements (3) doesn't match expected number of elements (1)
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (!torch.vtensor<[1,15],si64>, !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64>, sym_name = "forward"}> ({
^bb0(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>):
%0 = "builtin.unrealized_conversion_cast"(%arg1) : (!torch.vtensor<[1,4],si64>) -> tensor<1x4xi64>
%1 = "torch.constant.none"() : () -> !torch.none
%2 = "torch.constant.int"() {value = 4 : i64} : () -> !torch.int
%3 = "builtin.unrealized_conversion_cast"(%2) : (!torch.int) -> i64
%4 = "torch.constant.int"() {value = 1 : i64} : () -> !torch.int
%5 = "builtin.unrealized_conversion_cast"(%4) : (!torch.int) -> i64
%6 = "torch.constant.int"() {value = 0 : i64} : () -> !torch.int
%7 = "builtin.unrealized_conversion_cast"(%6) : (!torch.int) -> i64
%8 = "torch.constant.int"() {value = -1 : i64} : () -> !torch.int
%9 = "builtin.unrealized_conversion_cast"(%8) : (!torch.int) -> i64
%10 = "torch.constant.bool"() {value = false} : () -> !torch.bool
%11 = "builtin.unrealized_conversion_cast"(%10) : (!torch.bool) -> i1
%12 = "torch.prim.ListConstruct"(%4, %2) : (!torch.int, !torch.int) -> !torch.list<int>
%13 = "stablehlo.constant"() {value = dense<0> : tensor<1x4xi32>} : () -> tensor<1x4xi32>
%14 = "stablehlo.convert"(%13) : (tensor<1x4xi32>) -> tensor<1x4xi64>
%15 = "torch.aten.zeros"(%12, %2, %1, %1, %10) : (!torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool) -> !torch.vtensor<[1,4],si64>
%16 = "arith.constant"() <{value = 1 : index}> : () -> index
%17 = "tensor.dim"(%0, %16) : (tensor<1x4xi64>, index) -> index
%18 = "arith.index_cast"(%17) : (index) -> i64
%19 = "arith.constant"() <{value = 0 : i64}> : () -> i64
%20 = "arith.subi"(%19, %18) : (i64, i64) -> i64
%21 = "arith.maxsi"(%20, %7) : (i64, i64) -> i64
%22 = "arith.minsi"(%18, %21) : (i64, i64) -> i64
%23 = "arith.addi"(%18, %22) : (i64, i64) -> i64
%24 = "arith.cmpi"(%22, %19) <{predicate = 5 : i64}> : (i64, i64) -> i1
%25 = "arith.select"(%24, %22, %23) : (i1, i64, i64) -> i64
%26 = "arith.constant"() <{value = 0 : i64}> : () -> i64
%27 = "arith.subi"(%26, %18) : (i64, i64) -> i64
%28 = "arith.maxsi"(%27, %9) : (i64, i64) -> i64
%29 = "arith.minsi"(%18, %28) : (i64, i64) -> i64
%30 = "arith.addi"(%18, %29) : (i64, i64) -> i64
%31 = "arith.cmpi"(%29, %26) <{predicate = 5 : i64}> : (i64, i64) -> i1
%32 = "arith.select"(%31, %29, %30) : (i1, i64, i64) -> i64
%33 = "arith.constant"() <{value = 0 : index}> : () -> index
%34 = "tensor.dim"(%0, %33) : (tensor<1x4xi64>, index) -> index
%35 = "arith.index_cast"(%34) : (index) -> i64
%36 = "arith.constant"() <{value = 1 : index}> : () -> index
%37 = "tensor.dim"(%0, %36) : (tensor<1x4xi64>, index) -> index
%38 = "arith.index_cast"(%37) : (index) -> i64
%39 = "arith.constant"() <{value = 0 : i64}> : () -> i64
%40 = "arith.constant"() <{value = 1 : i64}> : () -> i64
%41 = "arith.cmpi"(%32, %39) <{predicate = 0 : i64}> : (i64, i64) -> i1
%42 = "arith.select"(%41, %38, %32) : (i1, i64, i64) -> i64
%43 = "tensor.from_elements"(%39, %25) : (i64, i64) -> tensor<2xi64>
%44 = "tensor.from_elements"(%35, %42) : (i64, i64) -> tensor<2xi64>
%45 = "tensor.from_elements"(%40, %5) : (i64, i64) -> tensor<2xi64>
%46 = "stablehlo.real_dynamic_slice"(%0, %43, %44, %45) : (tensor<1x4xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xi64>
%47 = "torch.aten.slice.Tensor"(%arg1, %4, %6, %8, %4) : (!torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3],si64>
%48 = "stablehlo.convert"(%46) : (tensor<1x3xi64>) -> tensor<1x3xi64>
%49 = "torch.aten.clone"(%47, %1) : (!torch.vtensor<[1,3],si64>, !torch.none) -> !torch.vtensor<[1,3],si64>
%50 = "tensor.from_elements"(%5) : (i64) -> tensor<1xi64>
%51 = "stablehlo.convert"(%50) : (tensor<1xi64>) -> tensor<1xi64>
%52 = "stablehlo.reshape"(%51) : (tensor<1xi64>) -> tensor<i64>
%53 = "tensor.from_elements"(%3) : (i64) -> tensor<1xi64>
%54 = "stablehlo.convert"(%53) : (tensor<1xi64>) -> tensor<1xi64>
%55 = "stablehlo.reshape"(%54) : (tensor<1xi64>) -> tensor<i64>
%56 = "tensor.from_elements"(%5) : (i64) -> tensor<1xi64>
%57 = "stablehlo.convert"(%56) : (tensor<1xi64>) -> tensor<1xi64>
%58 = "stablehlo.reshape"(%57) : (tensor<1xi64>) -> tensor<i64>
%59 = "stablehlo.subtract"(%55, %52) : (tensor<i64>, tensor<i64>) -> tensor<i64>
%60 = "stablehlo.divide"(%59, %58) : (tensor<i64>, tensor<i64>) -> tensor<i64>
%61 = "stablehlo.reshape"(%60) : (tensor<i64>) -> tensor<1xi64>
%62 = "stablehlo.dynamic_iota"(%61) {iota_dimension = 0 : i64} : (tensor<1xi64>) -> tensor<3xi64>
%63 = "chlo.broadcast_multiply"(%62, %58) : (tensor<3xi64>, tensor<i64>) -> tensor<3xi64>
%64 = "chlo.broadcast_add"(%63, %52) : (tensor<3xi64>, tensor<i64>) -> tensor<3xi64>
%65 = "torch.aten.arange.start_step"(%4, %2, %4, %1, %1, %1, %1) : (!torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[3],si64>
%66 = "tensor.from_elements"(%7) : (i64) -> tensor<1xi64>
%67 = "stablehlo.convert"(%66) : (tensor<1xi64>) -> tensor<1xi64>
%68 = "stablehlo.reshape"(%67) : (tensor<1xi64>) -> tensor<i64>
%69 = "tensor.from_elements"(%5) : (i64) -> tensor<1xi64>
%70 = "stablehlo.convert"(%69) : (tensor<1xi64>) -> tensor<1xi64>
%71 = "stablehlo.reshape"(%70) : (tensor<1xi64>) -> tensor<i64>
%72 = "tensor.from_elements"(%5) : (i64) -> tensor<1xi64>
%73 = "stablehlo.convert"(%72) : (tensor<1xi64>) -> tensor<1xi64>
%74 = "stablehlo.reshape"(%73) : (tensor<1xi64>) -> tensor<i64>
%75 = "stablehlo.subtract"(%71, %68) : (tensor<i64>, tensor<i64>) -> tensor<i64>
%76 = "stablehlo.divide"(%75, %74) : (tensor<i64>, tensor<i64>) -> tensor<i64>
%77 = "stablehlo.reshape"(%76) : (tensor<i64>) -> tensor<1xi64>
%78 = "stablehlo.dynamic_iota"(%77) {iota_dimension = 0 : i64} : (tensor<1xi64>) -> tensor<1xi64>
%79 = "chlo.broadcast_multiply"(%78, %74) : (tensor<1xi64>, tensor<i64>) -> tensor<1xi64>
%80 = "chlo.broadcast_add"(%79, %68) : (tensor<1xi64>, tensor<i64>) -> tensor<1xi64>
%81 = "torch.aten.arange.start_step"(%6, %4, %4, %2, %1, %1, %1) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[1],si64>
%82 = "arith.constant"() <{value = 0 : index}> : () -> index
%83 = "tensor.dim"(%80, %82) : (tensor<1xi64>, index) -> index
%84 = "arith.index_cast"(%83) : (index) -> i64
%85 = "arith.constant"() <{value = 1 : i64}> : () -> i64
%86 = "tensor.from_elements"(%84, %85) : (i64, i64) -> tensor<2xi64>
%87 = "stablehlo.dynamic_reshape"(%80, %86) : (tensor<1xi64>, tensor<2xi64>) -> tensor<1x1xi64>
%88 = "torch.aten.unsqueeze"(%81, %8) : (!torch.vtensor<[1],si64>, !torch.int) -> !torch.vtensor<[1,1],si64>
%89 = "torch.prim.ListConstruct"(%88, %65) : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
%90 = "torch_c.to_builtin_tensor"(%88) : (!torch.vtensor<[1,1],si64>) -> tensor<1x1xi64>
%91 = "torch_c.to_builtin_tensor"(%65) : (!torch.vtensor<[3],si64>) -> tensor<3xi64>
%92 = "stablehlo.reshape"(%90) : (tensor<1x1xi64>) -> tensor<1x3x1xi64>
%93 = "stablehlo.reshape"(%91) : (tensor<3xi64>) -> tensor<1x3x1xi64>
%94 = "stablehlo.concatenate"(%92, %93) {dimension = 2 : i64} : (tensor<1x3x1xi64>, tensor<1x3x1xi64>) -> tensor<1x3x2xi64>
%95 = "stablehlo.reshape"(%48) : (tensor<1x3xi64>) -> tensor<3x1xi64>
%96 = "stablehlo.reshape"(%94) : (tensor<1x3x2xi64>) -> tensor<3x2xi64>
%97 = "stablehlo.scatter"(%14, %96, %95) ({
^bb0(%arg2: tensor<i64>, %arg3: tensor<i64>):
"stablehlo.return"(%arg3) : (tensor<i64>) -> ()
}) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false} : (tensor<1x4xi64>, tensor<3x2xi64>, tensor<3x1xi64>) -> tensor<1x4xi64>
%98 = "torch.aten.index_put.hacked_twin"(%15, %89, %49, %10) : (!torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool) -> !torch.vtensor<[1,4],si64>
"func.return"(%98) : (!torch.vtensor<[1,4],si64>) -> ()
}) : () -> ()
} -> SUCCESS
//===-------------------------------------------===//
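The generic-form dump pinpoints the verifier failure: %92 reshapes tensor<1x1xi64> (1 element) into tensor<1x3x1xi64> (3 elements), and stablehlo.reshape requires the element count to be preserved, hence "number of output elements (3) doesn't match expected number of elements (1)". The offending op, copied from the dump:

  %92 = "stablehlo.reshape"(%90) : (tensor<1x1xi64>) -> tensor<1x3x1xi64>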
//===-------------------------------------------===//
Legalizing operation : 'func.return'(0xf3f0480) {
"func.return"(%98) : (!torch.vtensor<[1,4],si64>) -> ()
* Fold {
} -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
** Insert : 'torch_c.to_i64'(0xf470b20)
** Insert : 'torch_c.to_i64'(0xf46c240)
** Insert : 'torch_c.to_i64'(0xf46c2d0)
** Insert : 'torch_c.to_builtin_tensor'(0xf46c3a0)
** Insert : 'torch_c.to_i64'(0xf46c430)
** Insert : 'torch_c.from_builtin_tensor'(0xf46c4c0)
** Insert : 'torch_c.from_builtin_tensor'(0xf46c550)
** Insert : 'torch_c.from_builtin_tensor'(0xf46c5e0)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::PreservedAnalyses::AllAnalysesType)
/nodclouddata/chi/src/models/t5/slicecopy/test_indexputhackedtwin.mlir:17:10: error: number of output elements (3) doesn't match expected number of elements (1)
%8 = torch.aten.index_put.hacked_twin %1, %7, %3, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64>
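So the ConvertAtenOp<AtenIndexPutHackedTwinOp> pattern mishandles index tensors that need broadcasting: the [1,1] index must be expanded to the common [1,3] shape before the trailing unit dim is added, and a size-changing stablehlo.reshape cannot do that. A shape-correct alternative would be a broadcast followed by a reshape; the following is a hand-written sketch of what the pattern would need to emit, not the pass's current output:

  // 1x1 -> 1x3: dim 1 has size 1 in the operand, so it broadcasts to 3
  %b = stablehlo.broadcast_in_dim %90, dims = [0, 1] : (tensor<1x1xi64>) -> tensor<1x3xi64>
  // 1x3 -> 1x3x1: element count preserved (3 == 3), so this reshape verifies
  %e = stablehlo.reshape %b : (tensor<1x3xi64>) -> tensor<1x3x1xi64>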
@AmosLewis (Author) attached the input file:

test_indexputhackedtwin.mlir

module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
    %none = torch.constant.none
    %int4 = torch.constant.int 4
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %int-1 = torch.constant.int -1
    %false = torch.constant.bool false
    %0 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.zeros %0, %int4, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,4],si64>
    %2 = torch.aten.slice.Tensor %arg1, %int1, %int0, %int-1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],si64>
    %3 = torch.aten.clone %2, %none : !torch.vtensor<[1,3],si64>, !torch.none -> !torch.vtensor<[1,3],si64>
    %4 = torch.aten.arange.start_step %int1, %int4, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
    %5 = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
    %6 = torch.aten.unsqueeze %5, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
    %7 = torch.prim.ListConstruct %6, %4 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
    %8 = torch.aten.index_put.hacked_twin %1, %7, %3, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64>
    return %8 : !torch.vtensor<[1,4],si64>
  }
}
