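Reproducer (gist AmosLewis/20895fa06e8b61ce83e9a2b9c2ee8ca6): an in-place `torch.aten.copy_` into a slice of a `torch.aten.new_zeros` result fails torch-mlir's backend contract with "tensor with unknown rank".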
// TorchScript-imported reproducer (fed to torch-mlir-opt as ./test_torchscript.mlir).
// The trailing "| |" gist-rendering residue has been removed so the module parses.
module attributes {torch.debug_module_name = "_lambda"} {
  // forward(self, arg1: [1,15] si64, arg2: [1,4] si64) -> tensor.
  // arg1 is not used by the body; the result is computed from arg2 only.
  func.func private @__torch__.torch.fx.graph_module._lambda.forward(%arg0: !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda">, %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,15],si64>}, %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[1,4],si64>}) -> !torch.tensor {
    %none_1 = torch.constant.none
    %int-1 = torch.constant.int -1
    %false = torch.constant.bool false
    %cpu = torch.constant.device "cpu"
    %int1 = torch.constant.int 1
    %int4 = torch.constant.int 4
    %int0 = torch.constant.int 0
    // Unused by the body; kept as emitted by the importer.
    %int-100 = torch.constant.int -100
    // INT64_MAX as a slice end: slice up to the end of the dimension.
    %int9223372036854775807 = torch.constant.int 9223372036854775807
    // %134 = arg2.new_zeros([1, 4], dtype=4, layout=0, device="cpu", pin_memory=false).
    // The lowered IR types this result as [1,4] si64.
    %133 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %134 = torch.aten.new_zeros %arg2, %133, %int4, %int0, %cpu, %false : !torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
    // %136 = arg2[:, 0:-1].clone()
    %135 = torch.aten.slice.Tensor %arg2, %int1, %int0, %int-1, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
    %136 = torch.aten.clone %135, %none_1 : !torch.tensor, !torch.none -> !torch.tensor
    // %137 = %134[:, 1:], a slice of %134. copy_ writes %136 into it in place,
    // so the contents of %134 change as well. %138 (copy_'s result) is unused.
    %137 = torch.aten.slice.Tensor %134, %int1, %int1, %int9223372036854775807, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
    %138 = torch.aten.copy_ %137, %136, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor
    // Result: %134[:, 0], read after the in-place update above.
    %141 = torch.aten.select.int %134, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
    return %141 : !torch.tensor
  }
  torch.class_type @__torch__.torch.fx.graph_module._lambda {
    torch.method "forward", @__torch__.torch.fx.graph_module._lambda.forward
  }
  %132 = torch.nn_module {
  } : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda">
}
torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints,torch.aten.native_layer_norm,torch.aten.linear})' ./test_torchscript.mlir -mlir-print-ir-after-failure -mlir-disable-threading
./test_torchscript.mlir:12:12: error: unsupported by backend contract: tensor with unknown rank
%134 = torch.aten.new_zeros %arg2, %133, %int4, %int0, %cpu, %false : !torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
^
./test_torchscript.mlir:12:12: note: see current operation: %11 = "torch.tensor_static_info_cast"(%10) : (!torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64>
./test_torchscript.mlir:12:12: note: this is likely due to a missing transfer function in abstract_interp_lib_gen.py
// -----// IR Dump After LowerToBackendContract Failed (torch-lower-to-backend-contract) //----- //
// Annotated: IR left behind when LowerToBackendContract failed. Several values,
// including the function result, still have unknown rank (`*`), which is what
// "unsupported by backend contract: tensor with unknown rank" rejects; the error
// note points at a tensor_static_info_cast of the same form as %2 below.
module attributes {torch.debug_module_name = "_lambda"} {
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64> {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%int4 = torch.constant.int 4
%int-1 = torch.constant.int -1
%none = torch.constant.none
%int9223372036854775807 = torch.constant.int 9223372036854775807
%cpu = torch.constant.device "cpu"
// The input's aten.new_zeros appears here as aten.zeros with a static [1,4] si64 type...
%0 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.zeros %0, %int4, %int0, %cpu, %false : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
// ...but the rank is erased immediately by this cast to `*`, and %3 turns the
// result into a non-value (!torch.tensor) tensor. %3 is written through the
// slice %6/%7 and read again by %11. NOTE(review): this is presumably why it was
// not converted back to value semantics (see the "operand of op is not a valid
// tensor alias" failure later in this log).
%2 = torch.tensor_static_info_cast %1 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
%3 = torch.copy.to_tensor %2 : !torch.tensor<*,si64>
// %arg1[:, 0:-1].clone(), statically typed [1,3].
%4 = torch.aten.slice.Tensor %arg1, %int1, %int0, %int-1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],si64>
%5 = torch.aten.clone %4, %none : !torch.vtensor<[1,3],si64>, !torch.none -> !torch.vtensor<[1,3],si64>
// %3[:, 1:], the destination of the original copy_; typed [1,3], then cast back to `*`.
%6 = torch.aten.slice.Tensor %3, %int1, %int1, %int9223372036854775807, %int1 : !torch.tensor<*,si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[1,3],si64>
%7 = torch.tensor_static_info_cast %6 : !torch.tensor<[1,3],si64> to !torch.tensor<*,si64>
// The original in-place copy_: the source is broadcast to the [1,3]
// destination shape, then written into the slice %7.
%8 = torch.prim.ListConstruct %int1, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%9 = torch.aten.broadcast_to %5, %8 : !torch.vtensor<[1,3],si64>, !torch.list<int> -> !torch.vtensor<[1,3],si64>
%10 = torch.tensor_static_info_cast %9 : !torch.vtensor<[1,3],si64> to !torch.vtensor<*,si64>
torch.overwrite.tensor.contents %10 overwrites %7 : !torch.vtensor<*,si64>, !torch.tensor<*,si64>
// The original select.int(dim 1, index 0), expressed as slice [0:1] on dim 1
// followed by squeeze of dim 1.
%11 = torch.aten.slice.Tensor %3, %int1, %int0, %int1, %int1 : !torch.tensor<*,si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[1,1],si64>
%12 = torch.aten.squeeze.dim %11, %int1 : !torch.tensor<[1,1],si64>, !torch.int -> !torch.tensor<[1],si64>
// The result is cast back to `*` before being returned, so the function's
// return type has unknown rank.
%13 = torch.tensor_static_info_cast %12 : !torch.tensor<[1],si64> to !torch.tensor<*,si64>
%14 = torch.copy.to_vtensor %13 : !torch.vtensor<*,si64>
return %14 : !torch.vtensor<*,si64>
}
}
torch.tensor_static_info_cast
At first, its result type `!torch.vtensor` carries no shape or dtype information: the cast of `%arg1` from `!torch.vtensor<[1,4],si64>` erases the static type.
//===-------------------------------------------===//
** Insert : 'torch.tensor_static_info_cast'(0x55ea4a6426f0)
** Insert : 'torch.copy.to_tensor'(0x55ea4a652030)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneOperand<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CastOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::NonValueTensorType>::Impl<Empty>)
//===-------------------------------------------===//
Processing operation : 'torch.tensor_static_info_cast'(0x55ea4a6426f0) {
%7 = "torch.tensor_static_info_cast"(%arg1) : (!torch.vtensor<[1,4],si64>) -> !torch.vtensor
* Pattern : 'torch.tensor_static_info_cast -> ()' {
Trying to match ""
"" result 0
} -> failure : pattern failed to match
* Pattern : 'torch.tensor_static_info_cast -> ()' {
Trying to match ""
"" result 0
} -> failure : pattern failed to match
} -> failure : pattern failed to match
//===-------------------------------------------===//
//===-------------------------------------------===//
Processing operation : 'torch.copy.to_tensor'(0x55ea4a652030) {
%8 = "torch.copy.to_tensor"(%7) : (!torch.vtensor) -> !torch.tensor
} -> failure : pattern failed to match
//===-------------------------------------------===//
torch.aten.new_zeros — relevant IR after the op is converted to value semantics:
%0 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[1,4],si64> to !torch.vtensor
%1 = torch.copy.to_tensor %0 : !torch.tensor
%3 = torch.copy.to_vtensor %1 : !torch.vtensor
%4 = torch.aten.new_zeros %3, %2, %int4, %int0, %cpu, %false : !torch.vtensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor
%5 = torch.copy.to_tensor %4 : !torch.tensor
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.new_zeros'(0x55ea4a63f1d0) {
%11 = "torch.aten.new_zeros"(%8, %10, %4, %5, %9, %0) : (!torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool) -> !torch.tensor
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.aten.new_zeros -> ()' {
Trying to match "(anonymous namespace)::ConvertHasValueSemanticsOpsToValueTensors"
** Insert : 'torch.copy.to_vtensor'(0x55ea4a646630)
** Insert : 'torch.copy.to_tensor'(0x55ea4a642550)
"(anonymous namespace)::ConvertHasValueSemanticsOpsToValueTensors" result 1
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.new_zeros'(0x55ea4a63f1d0) {
%12 = "torch.aten.new_zeros"(%11, %10, %4, %5, %9, %0) : (!torch.vtensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool) -> !torch.vtensor
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.copy.to_vtensor'(0x55ea4a646630) {
%11 = "torch.copy.to_vtensor"(%8) : (!torch.tensor) -> !torch.vtensor
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::ValueTensorType>::Impl<Empty>)
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.copy.to_tensor'(0x55ea4a642550) {
%13 = "torch.copy.to_tensor"(%12) : (!torch.vtensor) -> !torch.tensor
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
torch.aten.copy_ — relevant IR after ReduceTrailingUnderscoreInplaceVariant rewrites it:
%11 = torch.copy.to_vtensor %10 : !torch.vtensor
%12 = torch.copy.to_vtensor %9 : !torch.vtensor
%13 = torch.aten.copy %11, %12, %false : !torch.vtensor, !torch.vtensor, !torch.bool -> !torch.vtensor
%14 = torch.copy.to_tensor %13 : !torch.tensor
%15 = torch.copy.to_vtensor %14 : !torch.vtensor
torch.overwrite.tensor.contents %15 overwrites %10 : !torch.vtensor, !torch.tensor
The `torch.overwrite.tensor.contents` op is created by
-> ReduceOpVariants.cpp
-> ReduceTrailingUnderscoreInplaceVariant
-> createOverwriteTensorContents
It replaces the contents of the overwritten tensor
%10 (a Torch_NonValueTensorType produced by slice) with the values of the value tensor
%15 (a Torch_ValueTensorType holding the aten.copy result, whose values come from the clone %8).
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
// Annotated: the function after aten.new_zeros and aten.copy_ were rewritten
// (debug dump taken during pattern application). %16 is the original in-place
// aten.copy_, which the log shows being replaced ("Replace : 'torch.aten.copy_'").
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.tensor {
%false = torch.constant.bool false
%int-1 = torch.constant.int -1
%none = torch.constant.none
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%int0 = torch.constant.int 0
%int9223372036854775807 = torch.constant.int 9223372036854775807
// The static type of %arg1 is erased here; %1 is a non-value (!torch.tensor) copy of it.
%0 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[1,4],si64> to !torch.vtensor
%1 = torch.copy.to_tensor %0 : !torch.tensor
%cpu = torch.constant.device "cpu"
%2 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
// aten.new_zeros in value semantics (ConvertHasValueSemanticsOpsToValueTensors):
// the operand is wrapped in copy.to_vtensor and the result in copy.to_tensor.
%3 = torch.copy.to_vtensor %1 : !torch.vtensor
%4 = torch.aten.new_zeros %3, %2, %int4, %int0, %cpu, %false : !torch.vtensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor
%5 = torch.copy.to_tensor %4 : !torch.tensor
// %1[:, 0:-1].clone(), with the clone in value semantics.
%6 = torch.aten.slice.Tensor %1, %int1, %int0, %int-1, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
%7 = torch.copy.to_vtensor %6 : !torch.vtensor
%8 = torch.aten.clone %7, %none : !torch.vtensor, !torch.none -> !torch.vtensor
%9 = torch.copy.to_tensor %8 : !torch.tensor
// %5[:, 1:]: the slice that copy_ writes into.
%10 = torch.aten.slice.Tensor %5, %int1, %int1, %int9223372036854775807, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
// aten.copy_ as rewritten by ReduceTrailingUnderscoreInplaceVariant: a
// value-semantic aten.copy, whose result overwrite.tensor.contents stores into %10.
%11 = torch.copy.to_vtensor %10 : !torch.vtensor
%12 = torch.copy.to_vtensor %9 : !torch.vtensor
%13 = torch.aten.copy %11, %12, %false : !torch.vtensor, !torch.vtensor, !torch.bool -> !torch.vtensor
%14 = torch.copy.to_tensor %13 : !torch.tensor
%15 = torch.copy.to_vtensor %14 : !torch.vtensor
torch.overwrite.tensor.contents %15 overwrites %10 : !torch.vtensor, !torch.tensor
%16 = torch.aten.copy_ %10, %9, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor
// Result: %5[:, 0], still an unranked !torch.tensor.
%17 = torch.aten.select.int %5, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
return %17 : !torch.tensor
}
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.copy_'(0x55ea4a645e60) {
%19 = "torch.aten.copy_"(%18, %17, %0) : (!torch.tensor, !torch.tensor, !torch.bool) -> !torch.tensor
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.aten.copy_ -> ()' {
Trying to match "(anonymous namespace)::ConvertHasValueSemanticsOpsToValueTensors"
** Failure : does not have value semantics
"(anonymous namespace)::ConvertHasValueSemanticsOpsToValueTensors" result 0
} -> FAILURE : pattern failed to match
* Pattern : 'torch.aten.copy_ -> ()' {
Trying to match "(anonymous namespace)::ReduceTrailingUnderscoreInplaceVariant"
** Insert : 'torch.aten.copy'(0x55ea4a652640)
** Insert : 'torch.copy.to_vtensor'(0x55ea4a66eed0)
** Insert : 'torch.overwrite.tensor.contents'(0x55ea4a66ef50)
** Replace : 'torch.aten.copy_'(0x55ea4a645e60)
"(anonymous namespace)::ReduceTrailingUnderscoreInplaceVariant" result 1
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.copy'(0x55ea4a652640) {
%19 = "torch.aten.copy"(%18, %17, %0) : (!torch.tensor, !torch.tensor, !torch.bool) -> !torch.tensor
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.aten.copy -> ()' {
Trying to match "(anonymous namespace)::ConvertHasValueSemanticsOpsToValueTensors"
** Insert : 'torch.copy.to_vtensor'(0x55ea4a66f020)
** Insert : 'torch.copy.to_vtensor'(0x55ea4a673380)
** Insert : 'torch.copy.to_tensor'(0x55ea4a673410)
"(anonymous namespace)::ConvertHasValueSemanticsOpsToValueTensors" result 1
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.copy'(0x55ea4a652640) {
%21 = "torch.aten.copy"(%19, %20, %0) : (!torch.vtensor, !torch.vtensor, !torch.bool) -> !torch.vtensor
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.copy.to_vtensor'(0x55ea4a66f020) {
%19 = "torch.copy.to_vtensor"(%18) : (!torch.tensor) -> !torch.vtensor
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.copy.to_vtensor'(0x55ea4a673380) {
%20 = "torch.copy.to_vtensor"(%17) : (!torch.tensor) -> !torch.vtensor
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.copy.to_tensor'(0x55ea4a673410) {
%22 = "torch.copy.to_tensor"(%21) : (!torch.vtensor) -> !torch.tensor
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
} -> SUCCESS
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.copy.to_vtensor'(0x55ea4a66eed0) {
%23 = "torch.copy.to_vtensor"(%22) : (!torch.tensor) -> !torch.vtensor
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
//===-------------------------------------------===//
Legalizing operation : 'torch.overwrite.tensor.contents'(0x55ea4a66ef50) {
"torch.overwrite.tensor.contents"(%23, %18) : (!torch.vtensor, !torch.tensor) -> ()
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
Redundant copies erased by AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock:
%9 = torch.copy.to_tensor %8 : !torch.tensor
%11 = torch.copy.to_vtensor %10 : !torch.vtensor
%14 = torch.copy.to_tensor %13 : !torch.tensor
%15 = torch.copy.to_vtensor %14 : !torch.vtensor
// Annotated: the same function after the copy.to_tensor/copy.to_vtensor round
// trips around the clone (%8) and aten.copy (%11) results were removed;
// aten.copy and overwrite.tensor.contents now use those values directly.
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.tensor {
%false = torch.constant.bool false
%int-1 = torch.constant.int -1
%none = torch.constant.none
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%int0 = torch.constant.int 0
%int9223372036854775807 = torch.constant.int 9223372036854775807
%0 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[1,4],si64> to !torch.vtensor
%1 = torch.copy.to_tensor %0 : !torch.tensor
%cpu = torch.constant.device "cpu"
%2 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.copy.to_vtensor %1 : !torch.vtensor
%4 = torch.aten.new_zeros %3, %2, %int4, %int0, %cpu, %false : !torch.vtensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor
// %5 (the non-value result of new_zeros) has two users: the slice %9, whose
// contents are overwritten below, and the select %12. The log shows
// AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock failing on this op
// with "operand of op is not a valid tensor alias". NOTE(review): presumably
// this is because the overwrite targets %9, a slice of %5, rather than %5 itself;
// confirm against the pattern's source.
%5 = torch.copy.to_tensor %4 : !torch.tensor
%6 = torch.aten.slice.Tensor %1, %int1, %int0, %int-1, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
%7 = torch.copy.to_vtensor %6 : !torch.vtensor
%8 = torch.aten.clone %7, %none : !torch.vtensor, !torch.none -> !torch.vtensor
// %5[:, 1:] is overwritten with aten.copy(%10, %8), whose values come from the clone %8.
%9 = torch.aten.slice.Tensor %5, %int1, %int1, %int9223372036854775807, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
%10 = torch.copy.to_vtensor %9 : !torch.vtensor
%11 = torch.aten.copy %10, %8, %false : !torch.vtensor, !torch.vtensor, !torch.bool -> !torch.vtensor
torch.overwrite.tensor.contents %11 overwrites %9 : !torch.vtensor, !torch.tensor
// Result: %5[:, 0], read after the overwrite above.
%12 = torch.aten.select.int %5, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
return %12 : !torch.tensor
}
%5 = torch.copy.to_tensor %4 : !torch.tensor
//===-------------------------------------------===//
Processing operation : 'torch.copy.to_tensor'(0x55ea4a642550) {
%13 = "torch.copy.to_tensor"(%12) : (!torch.vtensor) -> !torch.tensor
* Pattern (anonymous namespace)::AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock : 'torch.copy.to_tensor -> ()' {
Trying to match "(anonymous namespace)::AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock"
mlir-asm-printer: Verifying operation: func.func
%5 = torch.copy.to_tensor %4 : !torch.tensor
- Bf OverwriteTensorContentsOp:
mlir-asm-printer: Verifying operation: func.func
torch.overwrite.tensor.contents %11 overwrites %9 : !torch.vtensor, !torch.tensor
- Af OverwriteTensorContentsOp:
** Failure : operand of op is not a valid tensor alias
"(anonymous namespace)::AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock" result 0
} -> failure : pattern failed to match
* Pattern (anonymous namespace)::RewriteViewLikeSubgraph : 'torch.copy.to_tensor -> ()' {
Trying to match "(anonymous namespace)::RewriteViewLikeSubgraph"
** Failure : can only handle these transitive user ops
"(anonymous namespace)::RewriteViewLikeSubgraph" result 0
} -> failure : pattern failed to match
} -> failure : pattern failed to match
//===-------------------------------------------===//
`torch.aten.copy_ %137, %136` copies %136 into %137. Because %137 is a slice (view) of %134, this in-place copy also changes the contents of %134.