Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active March 2, 2023 16:53
Show Gist options
  • Save AmosLewis/20895fa06e8b61ce83e9a2b9c2ee8ca6 to your computer and use it in GitHub Desktop.
Save AmosLewis/20895fa06e8b61ce83e9a2b9c2ee8ca6 to your computer and use it in GitHub Desktop.
// Original input IR for the FX `_lambda` module. Every tensor here is a
// non-value tensor (!torch.tensor). In PyTorch terms, forward computes roughly:
//   zeros = %arg2.new_zeros([1, 4], 4, 0, "cpu", False)
//   zeros[:, 1:].copy_(%arg2[:, 0:-1].clone())
//   return zeros.select(1, 0)
// %arg0 is the module object and %arg1 (bound [1,15],si64) is not used.
module attributes {torch.debug_module_name = "_lambda"} {
func.func private @__torch__.torch.fx.graph_module._lambda.forward(%arg0: !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda">, %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,15],si64>}, %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[1,4],si64>}) -> !torch.tensor {
// Constants. %int-100 is not referenced anywhere in this function.
%none_1 = torch.constant.none
%int-1 = torch.constant.int -1
%false = torch.constant.bool false
%cpu = torch.constant.device "cpu"
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%int0 = torch.constant.int 0
%int-100 = torch.constant.int -100
// 9223372036854775807 = 2^63 - 1 (INT64_MAX). It is used below as the slice
// end, i.e. "slice to the end of the dimension".
%int9223372036854775807 = torch.constant.int 9223372036854775807
// Size list [1, 4] for new_zeros.
%133 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
// new_zeros(%arg2, size=[1,4], 4, 0, cpu, false).
// NOTE(review): 4 / 0 are presumably the ScalarType (int64) and Layout
// (strided) enum codes, and false is pin_memory. Confirm against the aten schema.
%134 = torch.aten.new_zeros %arg2, %133, %int4, %int0, %cpu, %false : !torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
// %arg2[:, 0:-1]: dim 1, start 0, end -1, step 1.
%135 = torch.aten.slice.Tensor %arg2, %int1, %int0, %int-1, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
// Clone of the slice (memory_format = none).
%136 = torch.aten.clone %135, %none_1 : !torch.tensor, !torch.none -> !torch.tensor
// %134[:, 1:]: dim 1, start 1, end INT64_MAX, step 1. This is presumably a view
// of %134 (PyTorch slice semantics).
%137 = torch.aten.slice.Tensor %134, %int1, %int1, %int9223372036854775807, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
// In-place copy of %136 into %137 (non_blocking = false). Its result %138 is
// never used. The mutation is only observable through %134, which %137 slices.
%138 = torch.aten.copy_ %137, %136, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor
// select(dim=1, index=0) on %134. Column 0 lies outside the [:, 1:] region
// written by copy_ above, so under PyTorch view semantics this reads zeros.
%141 = torch.aten.select.int %134, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
return %141 : !torch.tensor
}
// Class type for the module: binds its `forward` method to the function above.
torch.class_type @__torch__.torch.fx.graph_module._lambda {
torch.method "forward", @__torch__.torch.fx.graph_module._lambda.forward
}
// The module instance. It has no attributes.
%132 = torch.nn_module {
} : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda">
}
@AmosLewis
Copy link
Author

AmosLewis commented Feb 28, 2023

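torch.tensor_static_info_cast
Initially, its result type `!torch.vtensor` carries no shape or dtype information: the cast drops the `[1,4],si64` static info that its operand `!torch.vtensor<[1,4],si64>` has.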

//===-------------------------------------------===//
** Insert  : 'torch.tensor_static_info_cast'(0x55ea4a6426f0)
** Insert  : 'torch.copy.to_tensor'(0x55ea4a652030)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneOperand<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CastOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::NonValueTensorType>::Impl<Empty>)

//===-------------------------------------------===//
Processing operation : 'torch.tensor_static_info_cast'(0x55ea4a6426f0) {
  %7 = "torch.tensor_static_info_cast"(%arg1) : (!torch.vtensor<[1,4],si64>) -> !torch.vtensor


  * Pattern  : 'torch.tensor_static_info_cast -> ()' {
Trying to match ""
"" result 0
  } -> failure : pattern failed to match

  * Pattern  : 'torch.tensor_static_info_cast -> ()' {
Trying to match ""
"" result 0
  } -> failure : pattern failed to match
} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'torch.copy.to_tensor'(0x55ea4a652030) {
  %8 = "torch.copy.to_tensor"(%7) : (!torch.vtensor) -> !torch.tensor

} -> failure : pattern failed to match
//===-------------------------------------------===//

@AmosLewis
Copy link
Author

AmosLewis commented Feb 28, 2023

torch.aten.new_zeros

  %0 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[1,4],si64> to !torch.vtensor
  %1 = torch.copy.to_tensor %0 : !torch.tensor

  %3 = torch.copy.to_vtensor %1 : !torch.vtensor
  %4 = torch.aten.new_zeros %3, %2, %int4, %int0, %cpu, %false : !torch.vtensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor
  %5 = torch.copy.to_tensor %4 : !torch.tensor
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.new_zeros'(0x55ea4a63f1d0) {
  %11 = "torch.aten.new_zeros"(%8, %10, %4, %5, %9, %0) : (!torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool) -> !torch.tensor

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.new_zeros -> ()' {
Trying to match "(anonymous namespace)::ConvertHasValueSemanticsOpsToValueTensors"
    ** Insert  : 'torch.copy.to_vtensor'(0x55ea4a646630)
    ** Insert  : 'torch.copy.to_tensor'(0x55ea4a642550)
"(anonymous namespace)::ConvertHasValueSemanticsOpsToValueTensors" result 1

    //===-------------------------------------------===//
    Legalizing operation : 'torch.aten.new_zeros'(0x55ea4a63f1d0) {
      %12 = "torch.aten.new_zeros"(%11, %10, %4, %5, %9, %0) : (!torch.vtensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool) -> !torch.vtensor

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'torch.copy.to_vtensor'(0x55ea4a646630) {
      %11 = "torch.copy.to_vtensor"(%8) : (!torch.tensor) -> !torch.vtensor

ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::torch::Torch::ValueTensorType>::Impl<Empty>)
    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'torch.copy.to_tensor'(0x55ea4a642550) {
      %13 = "torch.copy.to_tensor"(%12) : (!torch.vtensor) -> !torch.tensor

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//
  } -> SUCCESS : pattern applied successfully

@AmosLewis
Copy link
Author

AmosLewis commented Feb 28, 2023

torch.aten.copy_

  %11 = torch.copy.to_vtensor %10 : !torch.vtensor  
  %12 = torch.copy.to_vtensor %9 : !torch.vtensor
  %13 = torch.aten.copy %11, %12, %false : !torch.vtensor, !torch.vtensor, !torch.bool -> !torch.vtensor
  %14 = torch.copy.to_tensor %13 : !torch.tensor
  %15 = torch.copy.to_vtensor %14 : !torch.vtensor
  torch.overwrite.tensor.contents %15 overwrites %10 : !torch.vtensor, !torch.tensor

The `torch.overwrite.tensor.contents` op is created along this path:
-> ReduceOpVariants.cpp
-> ReduceTrailingUnderscoreInplaceVariant
-> createOverwriteTensorContents
It replaces the contents of the overwritten tensor %10 (a Torch_NonValueTensorType produced by the slice) with the corresponding values from %15 (a Torch_ValueTensorType). The values in %15 ultimately come from the clone %8.

// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func

// IR after ReduceTrailingUnderscoreInplaceVariant has rewritten the in-place
// aten.copy_ (%16) into aten.copy (%13) plus overwrite.tensor.contents.
// Function arguments are now value tensors (!torch.vtensor). Ops that still
// work on non-value tensors are bridged with copy.to_tensor / copy.to_vtensor.
// NOTE(review): %arg1 here has type [1,4],si64, so it presumably corresponds
// to %arg2 of the original module-level function.
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.tensor {
  %false = torch.constant.bool false
  %int-1 = torch.constant.int -1
  %none = torch.constant.none
  %int1 = torch.constant.int 1
  %int4 = torch.constant.int 4
  %int0 = torch.constant.int 0
  %int9223372036854775807 = torch.constant.int 9223372036854775807
  // Erase %arg1's static shape/dtype, then expose it as a non-value tensor.
  %0 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[1,4],si64> to !torch.vtensor
  %1 = torch.copy.to_tensor %0 : !torch.tensor
  %cpu = torch.constant.device "cpu"
  %2 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
  // new_zeros now runs on value tensors (ConvertHasValueSemanticsOpsToValueTensors).
  // %5 converts the result back for its non-value users.
  %3 = torch.copy.to_vtensor %1 : !torch.vtensor
  %4 = torch.aten.new_zeros %3, %2, %int4, %int0, %cpu, %false : !torch.vtensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor
  %5 = torch.copy.to_tensor %4 : !torch.tensor
  // The source slice stays on non-value tensors; the clone runs on value tensors.
  %6 = torch.aten.slice.Tensor %1, %int1, %int0, %int-1, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
  %7 = torch.copy.to_vtensor %6 : !torch.vtensor
  %8 = torch.aten.clone %7, %none : !torch.vtensor, !torch.none -> !torch.vtensor
  %9 = torch.copy.to_tensor %8 : !torch.tensor
  // Destination slice %5[:, 1:]. This is the tensor that gets overwritten.
  %10 = torch.aten.slice.Tensor %5, %int1, %int1, %int9223372036854775807, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
  // Value-semantic copy. %8 -> %9 -> %12 and %13 -> %14 -> %15 are
  // to_tensor/to_vtensor round trips.
  %11 = torch.copy.to_vtensor %10 : !torch.vtensor
  %12 = torch.copy.to_vtensor %9 : !torch.vtensor
  %13 = torch.aten.copy %11, %12, %false : !torch.vtensor, !torch.vtensor, !torch.bool -> !torch.vtensor
  %14 = torch.copy.to_tensor %13 : !torch.tensor
  %15 = torch.copy.to_vtensor %14 : !torch.vtensor
  // Write %15's values into the non-value slice %10.
  torch.overwrite.tensor.contents %15 overwrites %10 : !torch.vtensor, !torch.tensor
  // The original in-place op. It still appears in this dump because the
  // conversion marks it "Replace"; presumably it is erased when the conversion commits.
  %16 = torch.aten.copy_ %10, %9, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor
  %17 = torch.aten.select.int %5, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
  return %17 : !torch.tensor
}
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.copy_'(0x55ea4a645e60) {
  %19 = "torch.aten.copy_"(%18, %17, %0) : (!torch.tensor, !torch.tensor, !torch.bool) -> !torch.tensor

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.copy_ -> ()' {
Trying to match "(anonymous namespace)::ConvertHasValueSemanticsOpsToValueTensors"
    ** Failure : does not have value semantics
"(anonymous namespace)::ConvertHasValueSemanticsOpsToValueTensors" result 0
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.copy_ -> ()' {
Trying to match "(anonymous namespace)::ReduceTrailingUnderscoreInplaceVariant"
    ** Insert  : 'torch.aten.copy'(0x55ea4a652640)
    ** Insert  : 'torch.copy.to_vtensor'(0x55ea4a66eed0)
    ** Insert  : 'torch.overwrite.tensor.contents'(0x55ea4a66ef50)
    ** Replace : 'torch.aten.copy_'(0x55ea4a645e60)
"(anonymous namespace)::ReduceTrailingUnderscoreInplaceVariant" result 1

    //===-------------------------------------------===//
    Legalizing operation : 'torch.aten.copy'(0x55ea4a652640) {
      %19 = "torch.aten.copy"(%18, %17, %0) : (!torch.tensor, !torch.tensor, !torch.bool) -> !torch.tensor

      * Fold {
      } -> FAILURE : unable to fold

      * Pattern : 'torch.aten.copy -> ()' {
Trying to match "(anonymous namespace)::ConvertHasValueSemanticsOpsToValueTensors"
        ** Insert  : 'torch.copy.to_vtensor'(0x55ea4a66f020)
        ** Insert  : 'torch.copy.to_vtensor'(0x55ea4a673380)
        ** Insert  : 'torch.copy.to_tensor'(0x55ea4a673410)
"(anonymous namespace)::ConvertHasValueSemanticsOpsToValueTensors" result 1

        //===-------------------------------------------===//
        Legalizing operation : 'torch.aten.copy'(0x55ea4a652640) {
          %21 = "torch.aten.copy"(%19, %20, %0) : (!torch.vtensor, !torch.vtensor, !torch.bool) -> !torch.vtensor

        } -> SUCCESS : operation marked legal by the target
        //===-------------------------------------------===//

        //===-------------------------------------------===//
        Legalizing operation : 'torch.copy.to_vtensor'(0x55ea4a66f020) {
          %19 = "torch.copy.to_vtensor"(%18) : (!torch.tensor) -> !torch.vtensor

        } -> SUCCESS : operation marked legal by the target
        //===-------------------------------------------===//

        //===-------------------------------------------===//
        Legalizing operation : 'torch.copy.to_vtensor'(0x55ea4a673380) {
          %20 = "torch.copy.to_vtensor"(%17) : (!torch.tensor) -> !torch.vtensor

        } -> SUCCESS : operation marked legal by the target
        //===-------------------------------------------===//

        //===-------------------------------------------===//
        Legalizing operation : 'torch.copy.to_tensor'(0x55ea4a673410) {
          %22 = "torch.copy.to_tensor"(%21) : (!torch.vtensor) -> !torch.tensor

        } -> SUCCESS : operation marked legal by the target
        //===-------------------------------------------===//
      } -> SUCCESS : pattern applied successfully
    } -> SUCCESS
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'torch.copy.to_vtensor'(0x55ea4a66eed0) {
      %23 = "torch.copy.to_vtensor"(%22) : (!torch.tensor) -> !torch.vtensor

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//

    //===-------------------------------------------===//
    Legalizing operation : 'torch.overwrite.tensor.contents'(0x55ea4a66ef50) {
      "torch.overwrite.tensor.contents"(%23, %18) : (!torch.vtensor, !torch.tensor) -> ()

    } -> SUCCESS : operation marked legal by the target
    //===-------------------------------------------===//
  } -> SUCCESS : pattern applied successfully

@AmosLewis
Copy link
Author

AmosLewis commented Feb 28, 2023

Redundant copies erased by AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock (numbering from the previous IR dump):

%9 = torch.copy.to_tensor %8 : !torch.tensor
%12 = torch.copy.to_vtensor %9 : !torch.vtensor
%14 = torch.copy.to_tensor %13 : !torch.tensor
%15 = torch.copy.to_vtensor %14 : !torch.vtensor
// IR after the redundant round trips around the copy were removed. aten.copy
// now consumes the clone %8 directly, and overwrite.tensor.contents writes the
// aten.copy result %11 straight into the slice %9.
// %1, %5, %6, %9 and %12 are still non-value tensors (!torch.tensor).
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.tensor {
  %false = torch.constant.bool false
  %int-1 = torch.constant.int -1
  %none = torch.constant.none
  %int1 = torch.constant.int 1
  %int4 = torch.constant.int 4
  %int0 = torch.constant.int 0
  %int9223372036854775807 = torch.constant.int 9223372036854775807
  %0 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[1,4],si64> to !torch.vtensor
  %1 = torch.copy.to_tensor %0 : !torch.tensor
  %cpu = torch.constant.device "cpu"
  %2 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.copy.to_vtensor %1 : !torch.vtensor
  %4 = torch.aten.new_zeros %3, %2, %int4, %int0, %cpu, %false : !torch.vtensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor
  // %5 has two users: the slice %9 (overwritten below) and the select %12.
  %5 = torch.copy.to_tensor %4 : !torch.tensor
  %6 = torch.aten.slice.Tensor %1, %int1, %int0, %int-1, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
  %7 = torch.copy.to_vtensor %6 : !torch.vtensor
  %8 = torch.aten.clone %7, %none : !torch.vtensor, !torch.none -> !torch.vtensor
  %9 = torch.aten.slice.Tensor %5, %int1, %int1, %int9223372036854775807, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
  %10 = torch.copy.to_vtensor %9 : !torch.vtensor
  %11 = torch.aten.copy %10, %8, %false : !torch.vtensor, !torch.vtensor, !torch.bool -> !torch.vtensor
  // NOTE(review): the overwrite target is %9, a slice of %5, not %5 itself.
  torch.overwrite.tensor.contents %11 overwrites %9 : !torch.vtensor, !torch.tensor
  %12 = torch.aten.select.int %5, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
  return %12 : !torch.tensor
}

@AmosLewis
Copy link
Author

%5 = torch.copy.to_tensor %4 : !torch.tensor

//===-------------------------------------------===//
Processing operation : 'torch.copy.to_tensor'(0x55ea4a642550) {
  %13 = "torch.copy.to_tensor"(%12) : (!torch.vtensor) -> !torch.tensor


  * Pattern (anonymous namespace)::AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock : 'torch.copy.to_tensor -> ()' {
Trying to match "(anonymous namespace)::AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock"
mlir-asm-printer: Verifying operation: func.func
%5 = torch.copy.to_tensor %4 : !torch.tensor
 - Bf OverwriteTensorContentsOp: 
mlir-asm-printer: Verifying operation: func.func
torch.overwrite.tensor.contents %11 overwrites %9 : !torch.vtensor, !torch.tensor
 - Af OverwriteTensorContentsOp: 
    ** Failure : operand of op is not a valid tensor alias
"(anonymous namespace)::AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock" result 0
  } -> failure : pattern failed to match

  * Pattern (anonymous namespace)::RewriteViewLikeSubgraph : 'torch.copy.to_tensor -> ()' {
Trying to match "(anonymous namespace)::RewriteViewLikeSubgraph"
    ** Failure : can only handle these transitive user ops
"(anonymous namespace)::RewriteViewLikeSubgraph" result 0
  } -> failure : pattern failed to match
} -> failure : pattern failed to match
//===-------------------------------------------===//

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment