Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created October 29, 2024 14:21
Show Gist options
  • Save pashu123/fb9a9d29b9f199d6f10bfb3c2d55ed49 to your computer and use it in GitHub Desktop.
Save pashu123/fb9a9d29b9f199d6f10bfb3c2d55ed49 to your computer and use it in GitHub Desktop.
This file has been truncated, but you can view the full file.
// -----// IR Dump After ConvertTorchOnnxToTorch (convert-torch-onnx-to-torch) //----- //
// main_graph as emitted by ConvertTorchOnnxToTorch.
// %arg0 and %arg2 are not referenced anywhere in this body.
// The final op (%68) is a logical_and with an all-false literal (%67), so the
// returned [?,?,128,384] i1 tensor is all false whatever the inputs are.
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
// Pad %arg1 with the scalar literal 0 (%0 -> %9). The four pad amounts are read
// element-wise from %arg4 (%2 = arg4[0], %4 = arg4[1], %6 = arg4[2], %8 = arg4[3])
// and passed to constant_pad_nd as [arg4[1], arg4[3], arg4[0], arg4[2]].
%0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%int0_0 = torch.constant.int 0
%1 = torch.aten.select.int %arg4, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
%2 = torch.aten.item %1 : !torch.vtensor<[],si64> -> !torch.int
%int1 = torch.constant.int 1
%3 = torch.aten.select.int %arg4, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
%4 = torch.aten.item %3 : !torch.vtensor<[],si64> -> !torch.int
%int2 = torch.constant.int 2
%5 = torch.aten.select.int %arg4, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[],si64> -> !torch.int
%int3 = torch.constant.int 3
%7 = torch.aten.select.int %arg4, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[],si64> -> !torch.int
%9 = torch.aten.item %0 : !torch.vtensor<[],si64> -> !torch.int
%10 = torch.prim.ListConstruct %4, %8, %2, %6 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%11 = torch.aten.constant_pad_nd %arg1, %10, %9 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
// Select %arg5[1]. The index literal %12 (= 1) would only be wrapped by
// size(%arg5, 0) if it were negative (%13 / %15 / %16), so %20 holds %arg5[1].
%12 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
%int0_1 = torch.constant.int 0
%int0_2 = torch.constant.int 0
%int1_3 = torch.constant.int 1
%13 = torch.aten.lt.Scalar %12, %int0_2 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],i1>
%14 = torch.aten.size.int %arg5, %int0_1 : !torch.vtensor<[2],si64>, !torch.int -> !torch.int
%15 = torch.aten.add.Scalar %12, %14, %int1_3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
%16 = torch.aten.where.self %13, %15, %12 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
// %17 (an empty int list) is not used.
%17 = torch.prim.ListConstruct : () -> !torch.list<int>
%18 = torch.aten.unsqueeze %16, %int0_2 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.aten.index_select %arg5, %int0_1, %18 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%20 = torch.aten.squeeze.dim %19, %int0_1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
// NOTE(review): %20 is divided by the literal 0 (%21), i.e. an integer division
// by zero. Confirm against the source ONNX model; this looks like a problem to
// investigate rather than intended behavior.
%21 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%22 = torch.aten.div.Tensor %20, %21 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
// Assemble the 3-element target shape %28 = [0, %22, 8]. %24 is not used.
%23 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%24 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int0_4 = torch.constant.int 0
%25 = torch.aten.unsqueeze %22, %int0_4 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%26 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%27 = torch.prim.ListConstruct %23, %25, %26 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%int0_5 = torch.constant.int 0
%28 = torch.aten.cat %27, %int0_5 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
// Turn %28 into the reshape size list [%38, %48, %50]. For elements 0 and 1, a
// requested value of 0 is replaced by the matching dim of the padded tensor %11
// (where.self on %34 / %44). Element 2 (%50) is used as-is; its eq / Int.bool
// results (%51, %52) are not used.
// NOTE(review): select.int on the rank-1 %28 is typed as returning [1] rather
// than a rank-0 tensor; confirm this is what the Reshape lowering intends.
%int0_6 = torch.constant.int 0
%int0_7 = torch.constant.int 0
%29 = torch.aten.select.int %28, %int0_6, %int0_7 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[1],si64> -> !torch.int
%31 = torch.aten.eq.int %30, %int0_6 : !torch.int, !torch.int -> !torch.bool
%32 = torch.aten.Int.bool %31 : !torch.bool -> !torch.int
%int0_8 = torch.constant.int 0
%33 = torch.aten.size.int %11, %int0_8 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%34 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],i1>
%35 = torch.prim.NumToTensor.Scalar %33 : !torch.int -> !torch.vtensor<[],si64>
%36 = torch.prim.NumToTensor.Scalar %30 : !torch.int -> !torch.vtensor<[],si64>
%37 = torch.aten.where.self %34, %35, %36 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%38 = torch.aten.item %37 : !torch.vtensor<[],si64> -> !torch.int
%int1_9 = torch.constant.int 1
%39 = torch.aten.select.int %28, %int0_6, %int1_9 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[1],si64> -> !torch.int
%41 = torch.aten.eq.int %40, %int0_6 : !torch.int, !torch.int -> !torch.bool
%42 = torch.aten.Int.bool %41 : !torch.bool -> !torch.int
%int1_10 = torch.constant.int 1
%43 = torch.aten.size.int %11, %int1_10 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%44 = torch.prim.NumToTensor.Scalar %42 : !torch.int -> !torch.vtensor<[],i1>
%45 = torch.prim.NumToTensor.Scalar %43 : !torch.int -> !torch.vtensor<[],si64>
%46 = torch.prim.NumToTensor.Scalar %40 : !torch.int -> !torch.vtensor<[],si64>
%47 = torch.aten.where.self %44, %45, %46 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%48 = torch.aten.item %47 : !torch.vtensor<[],si64> -> !torch.int
%int2_11 = torch.constant.int 2
%49 = torch.aten.select.int %28, %int0_6, %int2_11 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%50 = torch.aten.item %49 : !torch.vtensor<[1],si64> -> !torch.int
%51 = torch.aten.eq.int %50, %int0_6 : !torch.int, !torch.int -> !torch.bool
%52 = torch.aten.Int.bool %51 : !torch.bool -> !torch.int
%53 = torch.prim.ListConstruct %38, %48, %50 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%54 = torch.aten.reshape %11, %53 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
// %55-%59 and %61 are unused literals.
%55 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%56 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%57 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%58 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%59 = torch.vtensor.literal(dense<-1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
// Broadcast-and: %60 is %54 with a trailing unit dim ([?,?,?,1]) and %62 is %arg3
// with a unit dim at index 2 ([?,?,1,?]). Both are cast to i1 (dtype code 11, per
// the result types) and combined by logical_and into %65 : [?,?,?,?].
%int3_12 = torch.constant.int 3
%60 = torch.aten.unsqueeze %54, %int3_12 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%61 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int2_13 = torch.constant.int 2
%62 = torch.aten.unsqueeze %arg3, %int2_13 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%int11 = torch.constant.int 11
%none = torch.constant.none
%false = torch.constant.bool false
%63 = torch.aten.to.dtype %60, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%int11_14 = torch.constant.int 11
%none_15 = torch.constant.none
%false_16 = torch.constant.bool false
%64 = torch.aten.to.dtype %62, %int11_14, %false_16, %false_16, %none_15 : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%65 = torch.aten.logical_and %63, %64 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
// %66 casts %65 to i1 again. %65 is already i1, so this cast does not change the dtype.
%int11_17 = torch.constant.int 11
%none_18 = torch.constant.none
%false_19 = torch.constant.bool false
%66 = torch.aten.to.dtype %65, %int11_17, %false_19, %false_19, %none_18 : !torch.vtensor<[?,?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],i1>
// logical_and with the all-false [1,1,128,384] literal %67: every element of %68 is false.
%67 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%68 = torch.aten.logical_and %66, %67 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %68 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After DecomposeComplexOps (torch-decompose-complex-ops) //----- //
// main_graph after DecomposeComplexOps. Compared with this function's
// ConvertTorchOnnxToTorch dump in this log:
// - select.int has become slice.Tensor, followed by squeeze.dim where the result is rank 0;
// - the %arg5 index normalization has been folded to the literal index %0 = [1];
// - reshape has become view;
// - the i1 -> i1 to.dtype has been dropped;
// - unused values have been removed.
// %arg0 and %arg2 are not referenced. The result is all false because of the
// final logical_and with the all-false literal %1.
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
// Read %arg4[0..3] as scalars: slice [k, k+1) along dim 0, squeeze, item.
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.squeeze.dim %5, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%7 = torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
%8 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%9 = torch.aten.squeeze.dim %8, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.squeeze.dim %11, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[],si64> -> !torch.int
%14 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%15 = torch.aten.squeeze.dim %14, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%16 = torch.aten.item %15 : !torch.vtensor<[],si64> -> !torch.int
// Pad %arg1 with 0; pad amounts are passed as [arg4[1], arg4[3], arg4[0], arg4[2]].
%17 = torch.prim.ListConstruct %10, %16, %7, %13 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%18 = torch.aten.constant_pad_nd %arg1, %17, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
// NOTE(review): %arg5[1] is divided by the literal 0 (%4), i.e. an integer
// division by zero. Confirm against the source model.
%19 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%20 = torch.aten.squeeze.dim %19, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%21 = torch.aten.div.Tensor %20, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
// Target shape %24 = [0, %21, 8].
%22 = torch.aten.unsqueeze %21, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%23 = torch.prim.ListConstruct %3, %22, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%24 = torch.aten.cat %23, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
// Elements 0 and 1 of %24: a value of 0 is replaced by the matching dim of %18
// (where.self on %30 / %40). Element 2 (%46) is used as-is; %47 / %48 are not used.
%25 = torch.aten.slice.Tensor %24, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%26 = torch.aten.item %25 : !torch.vtensor<[1],si64> -> !torch.int
%27 = torch.aten.eq.int %26, %int0 : !torch.int, !torch.int -> !torch.bool
%28 = torch.aten.Int.bool %27 : !torch.bool -> !torch.int
%29 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%30 = torch.prim.NumToTensor.Scalar %28 : !torch.int -> !torch.vtensor<[],i1>
%31 = torch.prim.NumToTensor.Scalar %29 : !torch.int -> !torch.vtensor<[],si64>
%32 = torch.prim.NumToTensor.Scalar %26 : !torch.int -> !torch.vtensor<[],si64>
%33 = torch.aten.where.self %30, %31, %32 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%34 = torch.aten.item %33 : !torch.vtensor<[],si64> -> !torch.int
%35 = torch.aten.slice.Tensor %24, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%36 = torch.aten.item %35 : !torch.vtensor<[1],si64> -> !torch.int
%37 = torch.aten.eq.int %36, %int0 : !torch.int, !torch.int -> !torch.bool
%38 = torch.aten.Int.bool %37 : !torch.bool -> !torch.int
%39 = torch.aten.size.int %18, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%40 = torch.prim.NumToTensor.Scalar %38 : !torch.int -> !torch.vtensor<[],i1>
%41 = torch.prim.NumToTensor.Scalar %39 : !torch.int -> !torch.vtensor<[],si64>
%42 = torch.prim.NumToTensor.Scalar %36 : !torch.int -> !torch.vtensor<[],si64>
%43 = torch.aten.where.self %40, %41, %42 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%44 = torch.aten.item %43 : !torch.vtensor<[],si64> -> !torch.int
%45 = torch.aten.slice.Tensor %24, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%46 = torch.aten.item %45 : !torch.vtensor<[1],si64> -> !torch.int
%47 = torch.aten.eq.int %46, %int0 : !torch.int, !torch.int -> !torch.bool
%48 = torch.aten.Int.bool %47 : !torch.bool -> !torch.int
// Reshape (as view) the padded tensor %18 to [%34, %44, %46].
%49 = torch.prim.ListConstruct %34, %44, %46 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%50 = torch.aten.view %18, %49 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
// Broadcast-and of %50 unsqueezed at dim 3 ([?,?,?,1]) with %arg3 unsqueezed at
// dim 2 ([?,?,1,?]). Both are cast to i1 (dtype code 11) first.
%51 = torch.aten.unsqueeze %50, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%52 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%53 = torch.aten.to.dtype %51, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%54 = torch.aten.to.dtype %52, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%55 = torch.aten.logical_and %53, %54 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
// logical_and with the all-false literal %1: every element of %56 is false.
%56 = torch.aten.logical_and %55, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %56 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// main_graph after Canonicalizer. This body is textually identical to this
// function's DecomposeComplexOps dump in this log, so canonicalize changed
// nothing here. In particular, the division by the literal zero (%21 uses %4)
// was not folded.
// %arg0 and %arg2 are not referenced. The result is all false because of the
// final logical_and with the all-false literal %1.
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
// Read %arg4[0..3] as scalars: slice [k, k+1) along dim 0, squeeze, item.
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.squeeze.dim %5, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%7 = torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
%8 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%9 = torch.aten.squeeze.dim %8, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.squeeze.dim %11, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[],si64> -> !torch.int
%14 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%15 = torch.aten.squeeze.dim %14, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%16 = torch.aten.item %15 : !torch.vtensor<[],si64> -> !torch.int
// Pad %arg1 with 0; pad amounts are passed as [arg4[1], arg4[3], arg4[0], arg4[2]].
%17 = torch.prim.ListConstruct %10, %16, %7, %13 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%18 = torch.aten.constant_pad_nd %arg1, %17, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
// NOTE(review): %arg5[1] is divided by the literal 0 (%4), i.e. an integer
// division by zero. Confirm against the source model.
%19 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%20 = torch.aten.squeeze.dim %19, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%21 = torch.aten.div.Tensor %20, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
// Target shape %24 = [0, %21, 8].
%22 = torch.aten.unsqueeze %21, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%23 = torch.prim.ListConstruct %3, %22, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%24 = torch.aten.cat %23, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
// Elements 0 and 1 of %24: a value of 0 is replaced by the matching dim of %18
// (where.self on %30 / %40). Element 2 (%46) is used as-is; %47 / %48 are not used.
%25 = torch.aten.slice.Tensor %24, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%26 = torch.aten.item %25 : !torch.vtensor<[1],si64> -> !torch.int
%27 = torch.aten.eq.int %26, %int0 : !torch.int, !torch.int -> !torch.bool
%28 = torch.aten.Int.bool %27 : !torch.bool -> !torch.int
%29 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%30 = torch.prim.NumToTensor.Scalar %28 : !torch.int -> !torch.vtensor<[],i1>
%31 = torch.prim.NumToTensor.Scalar %29 : !torch.int -> !torch.vtensor<[],si64>
%32 = torch.prim.NumToTensor.Scalar %26 : !torch.int -> !torch.vtensor<[],si64>
%33 = torch.aten.where.self %30, %31, %32 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%34 = torch.aten.item %33 : !torch.vtensor<[],si64> -> !torch.int
%35 = torch.aten.slice.Tensor %24, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%36 = torch.aten.item %35 : !torch.vtensor<[1],si64> -> !torch.int
%37 = torch.aten.eq.int %36, %int0 : !torch.int, !torch.int -> !torch.bool
%38 = torch.aten.Int.bool %37 : !torch.bool -> !torch.int
%39 = torch.aten.size.int %18, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%40 = torch.prim.NumToTensor.Scalar %38 : !torch.int -> !torch.vtensor<[],i1>
%41 = torch.prim.NumToTensor.Scalar %39 : !torch.int -> !torch.vtensor<[],si64>
%42 = torch.prim.NumToTensor.Scalar %36 : !torch.int -> !torch.vtensor<[],si64>
%43 = torch.aten.where.self %40, %41, %42 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%44 = torch.aten.item %43 : !torch.vtensor<[],si64> -> !torch.int
%45 = torch.aten.slice.Tensor %24, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%46 = torch.aten.item %45 : !torch.vtensor<[1],si64> -> !torch.int
%47 = torch.aten.eq.int %46, %int0 : !torch.int, !torch.int -> !torch.bool
%48 = torch.aten.Int.bool %47 : !torch.bool -> !torch.int
// Reshape (as view) the padded tensor %18 to [%34, %44, %46].
%49 = torch.prim.ListConstruct %34, %44, %46 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%50 = torch.aten.view %18, %49 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
// Broadcast-and of %50 unsqueezed at dim 3 ([?,?,?,1]) with %arg3 unsqueezed at
// dim 2 ([?,?,1,?]). Both are cast to i1 (dtype code 11) first.
%51 = torch.aten.unsqueeze %50, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%52 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%53 = torch.aten.to.dtype %51, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%54 = torch.aten.to.dtype %52, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%55 = torch.aten.logical_and %53, %54 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
// logical_and with the all-false literal %1: every element of %56 is false.
%56 = torch.aten.logical_and %55, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %56 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After ReifyShapeCalculations (torch-reify-shape-calculations) //----- //
module {
// Shape function for padding ops.
// %arg0: input shape. %arg1: flat pad list of low-high pairs.
// Raises if len(%arg1) is odd, or if len(%arg1) / 2 exceeds len(%arg0).
// For pair i, dim -(i + 1) of %arg0 grows by %arg1[2i] + %arg1[2i + 1], so the
// pairs apply from the last dim backwards. %arg0 is updated in place with
// _set_item.t and then returned.
func.func private @__torch__.pad_shape_fn(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%true = torch.constant.bool true
%str = torch.constant.str "AssertionError: Number of padded dimensions must be less than or equal to the input dimension"
%none = torch.constant.none
%str_0 = torch.constant.str "AssertionError: Must have paired low-high pad amount values"
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// Check that len(%arg1) is even.
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%1 = torch.aten.remainder.int %0, %int2 : !torch.int, !torch.int -> !torch.int
%2 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %2 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Check that the number of pairs does not exceed the input rank.
%3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%4 = torch.aten.floordiv.int %3, %int2 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.aten.le.int %4, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %6 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%7 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int
// Loop over pairs: %arg2 = i, %10 = -(i + 1), %16 = low + high.
torch.prim.Loop %8, %true, init() {
^bb0(%arg2: !torch.int):
%9 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int
%10 = torch.aten.neg.int %9 : !torch.int -> !torch.int
%11 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int
%15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.add.int %17, %16 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten._set_item.t %arg0, %10, %18 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %arg0 : !torch.list<int>
}
// Shape function for aten.constant_pad_nd. It delegates to @__torch__.pad_shape_fn
// with the input shape %arg0 and the pad list %arg1. The pad value %arg2 is not used.
func.func private @__torch_mlir_shape_fn.aten.constant_pad_nd(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {
%0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// Shape function for index_select. Arguments: input shape %arg0, dim %arg1,
// index shape %arg2.
// - %arg1 is range-checked against the input rank (a rank of 0 is treated as 1)
//   and wrapped to a non-negative %9.
// - %11 is the element count of the index: the product of %arg2, or 1 if %arg2 is empty.
// - Raises if the index rank exceeds 1, or if %9 is neither 0 nor less than len(%arg0).
// Result: a new list equal to %arg0 with dim %9 replaced by %11.
func.func private @__torch__.torch.jit._shape_functions.index_select(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
// %2 = max(len(%arg0), 1), used as the rank for dim wrapping.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.prim.If %1 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %0 : !torch.int
}
// Require -%2 <= %arg1 <= %2 - 1.
%3 = torch.aten.neg.int %2 : !torch.int -> !torch.int
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.lt.int %arg1, %3 : !torch.int, !torch.int -> !torch.bool
%6 = torch.prim.If %5 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%18 = torch.aten.gt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %18 : !torch.bool
}
%7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool
torch.prim.If %7 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %9: non-negative dim.
%8 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%9 = torch.prim.If %8 -> (!torch.int) {
%18 = torch.aten.add.int %arg1, %2 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %18 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// %11: product of the index dims (starts at 1).
%10 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%11 = torch.prim.Loop %10, %true, init(%int1) {
^bb0(%arg3: !torch.int, %arg4: !torch.int):
%18 = torch.aten.__getitem__.t %arg2, %arg3 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %arg4, %18 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%19 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// The index must have rank <= 1.
%12 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%13 = torch.aten.le.int %12, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %13 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Require %9 == 0 or %9 < len(%arg0).
%14 = torch.aten.eq.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%18 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%19 = torch.aten.lt.int %9, %18 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %19 : !torch.bool
}
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Build the result: copy %arg0, substituting %11 at position %9.
%16 = torch.prim.ListConstruct : () -> !torch.list<int>
%17 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %17, %true, init() {
^bb0(%arg3: !torch.int):
%18 = torch.aten.eq.int %9, %arg3 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %18 -> () {
%19 = torch.aten.append.t %16, %11 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
%19 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %16 : !torch.list<int>
}
// Shape function for aten.index_select. It forwards all arguments to
// @__torch__.torch.jit._shape_functions.index_select.
func.func private @__torch_mlir_shape_fn.aten.index_select(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// Shape function for squeezing one dim. Arguments: input shape %arg0, dim %arg1.
// %arg1 is range-checked against the input rank (a rank of 0 is treated as 1)
// and wrapped to a non-negative %10.
// The result is a new list copying %arg0, dropping dim %10 only when its size is 1.
func.func private @__torch__.torch.jit._shape_functions.squeeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
// %3 = max(len(%arg0), 1), used as the rank for dim wrapping.
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %1 : !torch.int
}
// Require -%3 <= %arg1 <= %3 - 1.
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%12 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %12 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %10: non-negative dim.
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%12 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %12 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// Copy every dim, except that dim %10 is skipped when its size is 1.
%11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %11, %true, init() {
^bb0(%arg2: !torch.int):
%12 = torch.aten.eq.int %arg2, %10 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %12 -> () {
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %14 -> () {
%15 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.append.t %0, %15 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.append.t %0, %13 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %0 : !torch.list<int>
}
// Shape function for aten.squeeze.dim. It forwards to
// @__torch__.torch.jit._shape_functions.squeeze.
func.func private @__torch_mlir_shape_fn.aten.squeeze.dim(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.squeeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
// Shape function for aten.div.Tensor. The result shape comes from
// @__torch__.torch.jit._shape_functions.broadcast applied to the two operand
// shapes; that function is not visible in this part of the dump.
func.func private @__torch_mlir_shape_fn.aten.div.Tensor(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// Shape function for concatenation. %arg0 holds the shapes of the input
// tensors and %arg1 is the concatenation dimension; the result is the
// output shape. The symbol name indicates this was lowered from
// torch.jit._shape_functions.cat. Invalid inputs reach
// torch.prim.RaiseException with one of the string constants below.
func.func private @__torch__.torch.jit._shape_functions.cat(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension"
%str_0 = torch.constant.str "AssertionError: Tensors must have same number of dimensions"
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%true = torch.constant.bool true
%none = torch.constant.none
%str_1 = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
// Check 1: every input shape must have rank > 0.
%0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
torch.prim.Loop %0, %true, init() {
^bb0(%arg2: !torch.int):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%15 = torch.aten.gt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Wrap the concat dim against the rank of the FIRST input whose shape is
// not exactly [0]. The loop-carried optional (%arg3) starts as None, is set
// once such an input is found, and is passed through unchanged afterwards.
%1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%2 = torch.derefine %none : !torch.none to !torch.optional<int>
%3 = torch.prim.Loop %1, %true, init(%2) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<int>):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
// %16 is true when this shape is exactly [0] (rank 1 and size 0); such
// shapes do not take part in dim wrapping.
%15 = torch.aten.eq.int %14, %int1 : !torch.int, !torch.int -> !torch.bool
%16 = torch.prim.If %15 -> (!torch.bool) {
%19 = torch.aten.__getitem__.t %13, %int0 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%17 = torch.aten.__not__ %16 : !torch.bool -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.optional<int>) {
%19 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%20 = torch.prim.If %19 -> (!torch.int) {
// Rank used for wrapping, clamped to at least 1.
%22 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%23 = torch.aten.le.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.prim.If %23 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %22 : !torch.int
}
// Range check: require -rank <= %arg1 <= rank - 1.
%25 = torch.aten.neg.int %24 : !torch.int -> !torch.int
%26 = torch.aten.sub.int %24, %int1 : !torch.int, !torch.int -> !torch.int
%27 = torch.aten.lt.int %arg1, %25 : !torch.int, !torch.int -> !torch.bool
%28 = torch.prim.If %27 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%32 = torch.aten.gt.int %arg1, %26 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %32 : !torch.bool
}
%29 = torch.aten.__not__ %28 : !torch.bool -> !torch.bool
torch.prim.If %29 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Negative dims are shifted up by rank.
%30 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%31 = torch.prim.If %30 -> (!torch.int) {
%32 = torch.aten.add.int %arg1, %24 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %32 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
torch.prim.If.yield %31 : !torch.int
} else {
%22 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %22 : !torch.int
}
%21 = torch.derefine %20 : !torch.int to !torch.optional<int>
torch.prim.If.yield %21 : !torch.optional<int>
} else {
torch.prim.If.yield %arg3 : !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%18 : !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>
// %5: the wrapped concat dim, or %arg1 unchanged when every input was [0].
%4 = torch.aten.__is__ %3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%5 = torch.prim.If %4 -> (!torch.int) {
torch.prim.If.yield %arg1 : !torch.int
} else {
%13 = torch.prim.unchecked_cast %3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %13 : !torch.int
}
// Check 2: at least one input shape is required.
%6 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %7 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Pick the LAST input that is not skipped, where "skipped" means
// numel == 0 and rank == 1. The inner loop computes numel as the product
// of the sizes. The result is None if every input is skipped.
%8 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%9 = torch.derefine %none : !torch.none to !torch.optional<list<int>>
%10 = torch.prim.Loop %8, %true, init(%9) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<list<int>>):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%15 = torch.prim.Loop %14, %true, init(%int1) {
^bb0(%arg4: !torch.int, %arg5: !torch.int):
%20 = torch.aten.__getitem__.t %13, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%21 = torch.aten.mul.int %arg5, %20 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%21 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%16 = torch.aten.eq.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.bool) {
%20 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%21 = torch.aten.eq.int %20, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %21 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%18 = torch.aten.__not__ %17 : !torch.bool -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.optional<list<int>>) {
%20 = torch.derefine %13 : !torch.list<int> to !torch.optional<list<int>>
torch.prim.If.yield %20 : !torch.optional<list<int>>
} else {
torch.prim.If.yield %arg3 : !torch.optional<list<int>>
}
torch.prim.Loop.condition %true, iter(%19 : !torch.optional<list<int>>)
} : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>
// If every input was skipped, the output shape is [0].
%11 = torch.aten.__is__ %10, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%12 = torch.prim.If %11 -> (!torch.list<int>) {
%13 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
torch.prim.If.yield %13 : !torch.list<int>
} else {
%13 = torch.prim.unchecked_cast %10 : !torch.optional<list<int>> -> !torch.list<int>
// Sum the sizes along dim %5 over all non-skipped inputs (same skip rule
// as above). Each such input must have the same rank as the reference
// shape %13 and must match it in every dimension other than %5.
%14 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%15 = torch.prim.Loop %14, %true, init(%int0) {
^bb0(%arg2: !torch.int, %arg3: !torch.int):
%19 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%20 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%21 = torch.prim.Loop %20, %true, init(%int1) {
^bb0(%arg4: !torch.int, %arg5: !torch.int):
%26 = torch.aten.__getitem__.t %19, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%27 = torch.aten.mul.int %arg5, %26 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%27 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%22 = torch.aten.eq.int %21, %int0 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.bool) {
%26 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%27 = torch.aten.eq.int %26, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %27 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%24 = torch.aten.__not__ %23 : !torch.bool -> !torch.bool
%25 = torch.prim.If %24 -> (!torch.int) {
%26 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%27 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%28 = torch.aten.eq.int %26, %27 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %28 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%29 = torch.aten.__range_length %int0, %26, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %29, %true, init() {
^bb0(%arg4: !torch.int):
%32 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%33 = torch.aten.ne.int %32, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %33 -> () {
%34 = torch.aten.__getitem__.t %13, %32 : !torch.list<int>, !torch.int -> !torch.int
%35 = torch.aten.__getitem__.t %19, %32 : !torch.list<int>, !torch.int -> !torch.int
%36 = torch.aten.eq.int %34, %35 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %36 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%30 = torch.aten.__getitem__.t %19, %5 : !torch.list<int>, !torch.int -> !torch.int
%31 = torch.aten.add.int %arg3, %30 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %31 : !torch.int
} else {
torch.prim.If.yield %arg3 : !torch.int
}
torch.prim.Loop.condition %true, iter(%25 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// Result: a copy of the reference shape %13 with dim %5 overwritten by
// the accumulated size %15.
%16 = torch.prim.ListConstruct : () -> !torch.list<int>
%17 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
torch.prim.Loop %17, %true, init() {
^bb0(%arg2: !torch.int):
%19 = torch.aten.__getitem__.t %13, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%18 = torch.aten._set_item.t %16, %5, %15 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield %16 : !torch.list<int>
}
return %12 : !torch.list<int>
}
// Shape function for aten.cat: passes the list of input shapes %arg0 and
// the concat dim %arg1 to the cat helper and returns its result.
func.func private @__torch_mlir_shape_fn.aten.cat(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
%concat_shape = func.call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>
func.return %concat_shape : !torch.list<int>
}
// Shape function for prim.NumToTensor.Scalar: always returns an empty shape
// list. The scalar argument %arg0 is not read.
func.func private @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%arg0: !torch.float) -> !torch.list<int> {
%empty_shape = torch.prim.ListConstruct : () -> !torch.list<int>
func.return %empty_shape : !torch.list<int>
}
// Shape function for aten.where.self. The two value shapes (%arg1, %arg2)
// are broadcast first, and that result is then broadcast against the
// condition shape %arg0.
func.func private @__torch_mlir_shape_fn.aten.where.self(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
%values_shape = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
%result_shape = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %values_shape) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
func.return %result_shape : !torch.list<int>
}
// Shape function for slicing shape %arg0 along dim %arg1, with optional
// start %arg2, optional end %arg3 and step %arg4. The result is the input
// shape with the sliced dim resized. The symbol name indicates this was
// lowered from torch.jit._shape_functions.slice.
func.func private @__torch__.torch.jit._shape_functions.slice(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {
%int9223372036854775807 = torch.constant.int 9223372036854775807
%true = torch.constant.bool true
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// A rank-0 input shape is rejected.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %1 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Wrap %arg1 into a non-negative dim index (%10). The rank used is clamped
// to at least 1 (%3), the range -%3 <= dim <= %3 - 1 is enforced, and
// negative dims are then shifted up by %3.
%2 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %0 : !torch.int
}
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%33 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %33 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%33 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// Defaults for absent bounds: start (%12) = 0, end (%14) = 9223372036854775807.
%11 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool
%12 = torch.prim.If %11 -> (!torch.int) {
%33 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int0 : !torch.int
}
%13 = torch.aten.__isnot__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%14 = torch.prim.If %13 -> (!torch.int) {
%33 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int9223372036854775807 : !torch.int
}
// The step must be strictly positive.
%15 = torch.aten.gt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// A start equal to 9223372036854775807 is treated as 0.
%16 = torch.aten.eq.int %12, %int9223372036854775807 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
torch.prim.If.yield %12 : !torch.int
}
// A negative start or end is offset by the size of the sliced dim.
%18 = torch.aten.lt.int %17, %int0 : !torch.int, !torch.int -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.int) {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.add.int %17, %33 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %34 : !torch.int
} else {
torch.prim.If.yield %17 : !torch.int
}
%20 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
%21 = torch.prim.If %20 -> (!torch.int) {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.add.int %14, %33 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %34 : !torch.int
} else {
torch.prim.If.yield %14 : !torch.int
}
// Clamp start (%23) into [0, size of the sliced dim].
%22 = torch.aten.lt.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.gt.int %19, %33 : !torch.int, !torch.int -> !torch.bool
%35 = torch.prim.If %34 -> (!torch.int) {
%36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %36 : !torch.int
} else {
torch.prim.If.yield %19 : !torch.int
}
torch.prim.If.yield %35 : !torch.int
}
// Clamp end (%25). If it is below start it becomes start, and if it is
// >= the dim size it becomes the dim size.
%24 = torch.aten.lt.int %21, %23 : !torch.int, !torch.int -> !torch.bool
%25 = torch.prim.If %24 -> (!torch.int) {
torch.prim.If.yield %23 : !torch.int
} else {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.ge.int %21, %33 : !torch.int, !torch.int -> !torch.bool
%35 = torch.prim.If %34 -> (!torch.int) {
%36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %36 : !torch.int
} else {
torch.prim.If.yield %21 : !torch.int
}
torch.prim.If.yield %35 : !torch.int
}
// Result: a copy of the input shape with dim %10 set to
// (end - start + step - 1) floordiv step.
%26 = torch.aten.sub.int %25, %23 : !torch.int, !torch.int -> !torch.int
%27 = torch.prim.ListConstruct : () -> !torch.list<int>
%28 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %28, %true, init() {
^bb0(%arg5: !torch.int):
%33 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.append.t %27, %33 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%29 = torch.aten.add.int %26, %arg4 : !torch.int, !torch.int -> !torch.int
%30 = torch.aten.sub.int %29, %int1 : !torch.int, !torch.int -> !torch.int
%31 = torch.aten.floordiv.int %30, %arg4 : !torch.int, !torch.int -> !torch.int
%32 = torch.aten._set_item.t %27, %10, %31 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
return %27 : !torch.list<int>
}
// Shape function for aten.slice.Tensor: passes the input shape, dim,
// optional start/end and step straight through to the slice helper.
func.func private @__torch_mlir_shape_fn.aten.slice.Tensor(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {
%sliced_shape = func.call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
func.return %sliced_shape : !torch.list<int>
}
// Shape function for view. %arg0 is the input shape and %arg1 the requested
// shape; at most one entry of %arg1 may be -1 and is then inferred. The
// symbol name indicates this was lowered from
// torch.jit._shape_functions.view.
func.func private @__torch__.torch.jit._shape_functions.view(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: invalid shape"
%false = torch.constant.bool false
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: only one dimension can be inferred"
%int-1 = torch.constant.int -1
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%true = torch.constant.bool true
// %1: element count of the input (product of the sizes in %arg0).
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.prim.Loop %0, %true, init(%int1) {
^bb0(%arg2: !torch.int, %arg3: !torch.int):
%12 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.mul.int %arg3, %12 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%13 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// %2 is a placeholder yielded only on the path that has already raised.
%2 = torch.prim.Uninitialized : !torch.int
// Scan the requested shape. %5#0 is the product of the entries other
// than -1, and %5#1 is the index of the -1 entry (None if there is none).
%3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%4 = torch.derefine %none : !torch.none to !torch.optional<int>
%5:2 = torch.prim.Loop %3, %true, init(%int1, %4) {
^bb0(%arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.optional<int>):
%12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.eq.int %12, %int-1 : !torch.int, !torch.int -> !torch.bool
%14:2 = torch.prim.If %13 -> (!torch.int, !torch.optional<int>) {
// A second -1 entry is rejected.
%15 = torch.aten.__isnot__ %arg4, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%16 = torch.derefine %arg2 : !torch.int to !torch.optional<int>
torch.prim.If.yield %arg3, %16 : !torch.int, !torch.optional<int>
} else {
// Non-negative entries multiply into the product; any other negative
// value is rejected.
%15 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.ge.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
%18 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %arg3, %18 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %19 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %2 : !torch.int
}
torch.prim.If.yield %17, %arg4 : !torch.int, !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%14#0, %14#1 : !torch.int, !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.int, !torch.optional<int>) -> (!torch.int, !torch.optional<int>)
// The shape is accepted when the element counts match. Otherwise it needs
// a -1 entry, a product > 0, and the input element count divisible by
// that product.
%6 = torch.aten.eq.int %1, %5#0 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%12 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.bool) {
%15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%16 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %16 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%14 = torch.prim.If %13 -> (!torch.bool) {
%15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%16 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %17 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %14 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// Result: a copy of %arg1 in which the -1 entry (if present) is replaced
// by the input element count floordiv the product %5#0.
%9 = torch.prim.ListConstruct : () -> !torch.list<int>
%10 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
torch.prim.Loop %10, %true, init() {
^bb0(%arg2: !torch.int):
%12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.append.t %9, %12 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%11 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %11 -> () {
%12 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%13 = torch.aten.floordiv.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten._set_item.t %9, %12, %13 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
return %9 : !torch.list<int>
}
// Shape function for aten.view: passes the input shape and the requested
// shape to the view helper and returns the shape it computes.
func.func private @__torch_mlir_shape_fn.aten.view(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%viewed_shape = func.call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
func.return %viewed_shape : !torch.list<int>
}
// Shape function for unsqueeze: returns a copy of %arg0 with a size-1
// entry inserted at dim %arg1. The dim is wrapped against len(%arg0) + 1
// because the output has one more dimension than the input.
func.func private @__torch__.torch.jit._shape_functions.unsqueeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// %3: the wrapping rank, i.e. len(%arg0) + 1 clamped to at least 1.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.add.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %1 : !torch.int
}
// Range check: require -%3 <= %arg1 <= %3 - 1.
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%13 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %13 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %10: the wrapped dim, where negative values are shifted up by %3.
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%13 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// Copy the input shape, then insert 1 at position %10.
%11 = torch.prim.ListConstruct : () -> !torch.list<int>
%12 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %12, %true, init() {
^bb0(%arg2: !torch.int):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.append.t %11, %13 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %11, %10, %int1 : !torch.list<int>, !torch.int, !torch.int
return %11 : !torch.list<int>
}
// Shape function for aten.unsqueeze: passes the input shape and the dim to
// the unsqueeze helper and returns its result.
func.func private @__torch_mlir_shape_fn.aten.unsqueeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%unsqueezed_shape = func.call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>
func.return %unsqueezed_shape : !torch.list<int>
}
// Shape function for elementwise unary ops. It builds a new list and
// appends the sizes of %arg0 one at a time, so the result is a fresh copy
// of the input shape.
func.func private @__torch__.torch.jit._shape_functions.unary(%arg0: !torch.list<int>) -> !torch.list<int> {
%true = torch.constant.bool true
%copy = torch.prim.ListConstruct : () -> !torch.list<int>
%rank = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %rank, %true, init() {
^bb0(%idx: !torch.int):
%size = torch.aten.__getitem__.t %arg0, %idx : !torch.list<int>, !torch.int -> !torch.int
%appended = torch.aten.append.t %copy, %size : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
func.return %copy : !torch.list<int>
}
// Shape function for aten.to.dtype. Only the input shape %arg0 is used
// (the output shape equals it); %arg1..%arg4 are not read.
func.func private @__torch_mlir_shape_fn.aten.to.dtype(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
%same_shape = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
func.return %same_shape : !torch.list<int>
}
// Broadcast shape of %arg0 and %arg1. Dimensions are aligned from the
// right, and a dimension missing from the shorter shape is treated as size 1.
// Raises with a formatted message when two aligned sizes differ and neither
// is 1.
func.func private @__torch__.torch.jit._shape_functions.broadcast(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
// NOTE(review): the first placeholder is not parenthesized while the second
// is; this runtime string is kept byte-identical here.
%str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%false = torch.constant.bool false
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
// %2: output rank, the larger of the two input ranks.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %2, %true, init() {
^bb0(%arg2: !torch.int):
// %5: distance of output dim %arg2 from the last dim. %7 and %9 are the
// matching indices into %arg0 and %arg1 (negative when that input has
// fewer dims).
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.sub.int %4, %arg2 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%7 = torch.aten.sub.int %6, %5 : !torch.int, !torch.int -> !torch.int
%8 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int
%9 = torch.aten.sub.int %8, %5 : !torch.int, !torch.int -> !torch.int
// %11 and %13: size from each input, or 1 when that input lacks the dim.
%10 = torch.aten.ge.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
%11 = torch.prim.If %10 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg0, %7 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%12 = torch.aten.ge.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
// Error when the sizes differ and neither is 1. The message is built
// from %str_0 with both sizes and the output index %arg2, then prefixed
// with %str.
%14 = torch.aten.ne.int %11, %13 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
%20 = torch.aten.ne.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%16 = torch.prim.If %15 -> (!torch.bool) {
%20 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %16 -> () {
%20 = torch.aten.format(%str_0, %11, %13, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%21 = torch.aten.add.str %str, %20 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %21, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// Output size: %13 when %11 is 1, otherwise %11.
%17 = torch.aten.eq.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.int) {
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %11 : !torch.int
}
%19 = torch.aten.append.t %3, %18 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %3 : !torch.list<int>
}
// Shape function for aten.logical_and: the result shape is the broadcast of
// the two operand shapes %arg0 and %arg1.
func.func private @__torch_mlir_shape_fn.aten.logical_and(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%broadcast_shape = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
func.return %broadcast_shape : !torch.list<int>
}
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.aten.size %arg4 : !torch.vtensor<[4],si64> -> !torch.list<int>
%58 = torch.derefine %int0 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %5, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.size %5 : !torch.vtensor<[1],si64> -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.squeeze.dim(%57, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%7 = torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
%8 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.aten.size %arg4 : !torch.vtensor<[4],si64> -> !torch.list<int>
%58 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%9 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %8, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.size %8 : !torch.vtensor<[1],si64> -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.squeeze.dim(%57, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[],si64> -> !torch.int
%11 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.aten.size %arg4 : !torch.vtensor<[4],si64> -> !torch.list<int>
%58 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int3 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %11, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.size %11 : !torch.vtensor<[1],si64> -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.squeeze.dim(%57, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[],si64> -> !torch.int
%14 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.aten.size %arg4 : !torch.vtensor<[4],si64> -> !torch.list<int>
%58 = torch.derefine %int3 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int4 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%15 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %14, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.size %14 : !torch.vtensor<[1],si64> -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.squeeze.dim(%57, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%16 = torch.aten.item %15 : !torch.vtensor<[],si64> -> !torch.int
%17 = torch.prim.ListConstruct %10, %16, %7, %13 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%18 = torch.shape.calculate {
%57 = torch.aten.constant_pad_nd %arg1, %17, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?],si64>
} shapes {
%57 = torch.aten.size %arg1 : !torch.vtensor<[?,?],si64> -> !torch.list<int>
%58 = torch.aten.Float.Scalar %int0 : !torch.int -> !torch.float
%59 = func.call @__torch_mlir_shape_fn.aten.constant_pad_nd(%57, %17, %58) : (!torch.list<int>, !torch.list<int>, !torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
%19 = torch.shape.calculate {
%57 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.aten.size %arg5 : !torch.vtensor<[2],si64> -> !torch.list<int>
%58 = torch.aten.size %0 : !torch.vtensor<[1],si64> -> !torch.list<int>
%59 = func.call @__torch_mlir_shape_fn.aten.index_select(%57, %int0, %58) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%20 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %19, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.size %19 : !torch.vtensor<[1],si64> -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.squeeze.dim(%57, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%21 = torch.shape.calculate {
%57 = torch.aten.div.Tensor %20, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.size %20 : !torch.vtensor<[],si64> -> !torch.list<int>
%58 = torch.aten.size %4 : !torch.vtensor<[],si64> -> !torch.list<int>
%59 = func.call @__torch_mlir_shape_fn.aten.div.Tensor(%57, %58) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[],si64>
%22 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %21, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.aten.size %21 : !torch.vtensor<[],si64> -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.unsqueeze(%57, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%23 = torch.prim.ListConstruct %3, %22, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%24 = torch.shape.calculate {
%57 = torch.aten.cat %23, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[3],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<list<int>>
%58 = torch.aten.len.t %23 : !torch.list<vtensor> -> !torch.int
%true = torch.constant.bool true
torch.prim.Loop %58, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %23, %arg6 : !torch.list<vtensor>, !torch.int -> !torch.vtensor
%61 = torch.aten.size %60 : !torch.vtensor -> !torch.list<int>
%62 = torch.aten.append.t %57, %61 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = func.call @__torch_mlir_shape_fn.aten.cat(%57, %int0) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[3],si64>
%25 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.aten.size %24 : !torch.vtensor<[3],si64> -> !torch.list<int>
%58 = torch.derefine %int0 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%26 = torch.aten.item %25 : !torch.vtensor<[1],si64> -> !torch.int
%27 = torch.aten.eq.int %26, %int0 : !torch.int, !torch.int -> !torch.bool
%28 = torch.aten.Int.bool %27 : !torch.bool -> !torch.int
%29 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%30 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %28 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %28 : !torch.int -> !torch.float
%58 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%57) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%31 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %29 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %29 : !torch.int -> !torch.float
%58 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%57) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%32 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %26 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %26 : !torch.int -> !torch.float
%58 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%57) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%33 = torch.shape.calculate {
%57 = torch.aten.where.self %30, %31, %32 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.size %30 : !torch.vtensor<[],i1> -> !torch.list<int>
%58 = torch.aten.size %31 : !torch.vtensor<[],si64> -> !torch.list<int>
%59 = torch.aten.size %32 : !torch.vtensor<[],si64> -> !torch.list<int>
%60 = func.call @__torch_mlir_shape_fn.aten.where.self(%57, %58, %59) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[],si64>
%34 = torch.aten.item %33 : !torch.vtensor<[],si64> -> !torch.int
%35 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.aten.size %24 : !torch.vtensor<[3],si64> -> !torch.list<int>
%58 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%36 = torch.aten.item %35 : !torch.vtensor<[1],si64> -> !torch.int
%37 = torch.aten.eq.int %36, %int0 : !torch.int, !torch.int -> !torch.bool
%38 = torch.aten.Int.bool %37 : !torch.bool -> !torch.int
%39 = torch.aten.size.int %18, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%40 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %38 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %38 : !torch.int -> !torch.float
%58 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%57) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%41 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %39 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %39 : !torch.int -> !torch.float
%58 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%57) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%42 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %36 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %36 : !torch.int -> !torch.float
%58 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%57) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%43 = torch.shape.calculate {
%57 = torch.aten.where.self %40, %41, %42 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.size %40 : !torch.vtensor<[],i1> -> !torch.list<int>
%58 = torch.aten.size %41 : !torch.vtensor<[],si64> -> !torch.list<int>
%59 = torch.aten.size %42 : !torch.vtensor<[],si64> -> !torch.list<int>
%60 = func.call @__torch_mlir_shape_fn.aten.where.self(%57, %58, %59) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[],si64>
%44 = torch.aten.item %43 : !torch.vtensor<[],si64> -> !torch.int
%45 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.aten.size %24 : !torch.vtensor<[3],si64> -> !torch.list<int>
%58 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int3 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%46 = torch.aten.item %45 : !torch.vtensor<[1],si64> -> !torch.int
%47 = torch.aten.eq.int %46, %int0 : !torch.int, !torch.int -> !torch.bool
%48 = torch.aten.Int.bool %47 : !torch.bool -> !torch.int
%49 = torch.prim.ListConstruct %34, %44, %46 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%50 = torch.shape.calculate {
%57 = torch.aten.view %18, %49 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?],si64>
} shapes {
%57 = torch.aten.size %18 : !torch.vtensor<[?,?],si64> -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.view(%57, %49) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
%51 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %50, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%57 = torch.aten.size %50 : !torch.vtensor<[?,?,?],si64> -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.unsqueeze(%57, %int3) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
%52 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%57 = torch.aten.size %arg3 : !torch.vtensor<[?,?,?],si64> -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.unsqueeze(%57, %int2) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
%53 = torch.shape.calculate {
%57 = torch.aten.to.dtype %51, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%57 = torch.aten.size %51 : !torch.vtensor<[?,?,?,1],si64> -> !torch.list<int>
%58 = torch.derefine %none : !torch.none to !torch.optional<int>
%59 = func.call @__torch_mlir_shape_fn.aten.to.dtype(%57, %int11, %false, %false, %58) : (!torch.list<int>, !torch.int, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%54 = torch.shape.calculate {
%57 = torch.aten.to.dtype %52, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%57 = torch.aten.size %52 : !torch.vtensor<[?,?,1,?],si64> -> !torch.list<int>
%58 = torch.derefine %none : !torch.none to !torch.optional<int>
%59 = func.call @__torch_mlir_shape_fn.aten.to.dtype(%57, %int11, %false, %false, %58) : (!torch.list<int>, !torch.int, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
%55 = torch.shape.calculate {
%57 = torch.aten.logical_and %53, %54 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%57 = torch.aten.size %53 : !torch.vtensor<[?,?,?,1],i1> -> !torch.list<int>
%58 = torch.aten.size %54 : !torch.vtensor<[?,?,1,?],i1> -> !torch.list<int>
%59 = func.call @__torch_mlir_shape_fn.aten.logical_and(%57, %58) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
%56 = torch.shape.calculate {
%57 = torch.aten.logical_and %55, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%57 = torch.aten.size %55 : !torch.vtensor<[?,?,?,?],i1> -> !torch.list<int>
%58 = torch.aten.size %1 : !torch.vtensor<[1,1,128,384],i1> -> !torch.list<int>
%59 = func.call @__torch_mlir_shape_fn.aten.logical_and(%57, %58) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %56 : !torch.vtensor<[?,?,128,384],i1>
}
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch__.pad_shape_fn(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%true = torch.constant.bool true
%str = torch.constant.str "AssertionError: Number of padded dimensions must be less than or equal to the input dimension"
%none = torch.constant.none
%str_0 = torch.constant.str "AssertionError: Must have paired low-high pad amount values"
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%1 = torch.aten.remainder.int %0, %int2 : !torch.int, !torch.int -> !torch.int
%2 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %2 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%4 = torch.aten.floordiv.int %3, %int2 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.aten.le.int %4, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %6 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%7 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop %8, %true, init() {
^bb0(%arg2: !torch.int):
%9 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int
%10 = torch.aten.neg.int %9 : !torch.int -> !torch.int
%11 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int
%15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.add.int %17, %16 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten._set_item.t %arg0, %10, %18 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %arg0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.constant_pad_nd(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {
%0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.constant_pad_nd(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {
%true = torch.constant.bool true
%str = torch.constant.str "AssertionError: Number of padded dimensions must be less than or equal to the input dimension"
%none = torch.constant.none
%str_0 = torch.constant.str "AssertionError: Must have paired low-high pad amount values"
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%1 = torch.aten.remainder.int %0, %int2 : !torch.int, !torch.int -> !torch.int
%2 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %2 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%4 = torch.aten.floordiv.int %3, %int2 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.aten.le.int %4, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %6 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%7 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop %8, %true, init() {
^bb0(%arg3: !torch.int):
%9 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int
%10 = torch.aten.neg.int %9 : !torch.int -> !torch.int
%11 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int
%15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.add.int %17, %16 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten._set_item.t %arg0, %10, %18 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %arg0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch__.torch.jit._shape_functions.index_select(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.prim.If %1 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %0 : !torch.int
}
%3 = torch.aten.neg.int %2 : !torch.int -> !torch.int
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.lt.int %arg1, %3 : !torch.int, !torch.int -> !torch.bool
%6 = torch.prim.If %5 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%18 = torch.aten.gt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %18 : !torch.bool
}
%7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool
torch.prim.If %7 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%8 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%9 = torch.prim.If %8 -> (!torch.int) {
%18 = torch.aten.add.int %arg1, %2 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %18 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
%10 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%11 = torch.prim.Loop %10, %true, init(%int1) {
^bb0(%arg3: !torch.int, %arg4: !torch.int):
%18 = torch.aten.__getitem__.t %arg2, %arg3 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %arg4, %18 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%19 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%12 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%13 = torch.aten.le.int %12, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %13 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%14 = torch.aten.eq.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%18 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%19 = torch.aten.lt.int %9, %18 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %19 : !torch.bool
}
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%16 = torch.prim.ListConstruct : () -> !torch.list<int>
%17 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %17, %true, init() {
^bb0(%arg3: !torch.int):
%18 = torch.aten.eq.int %9, %arg3 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %18 -> () {
%19 = torch.aten.append.t %16, %11 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
%19 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %16 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.index_select(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.index_select(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.prim.If %1 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %0 : !torch.int
}
%3 = torch.aten.neg.int %2 : !torch.int -> !torch.int
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.lt.int %arg1, %3 : !torch.int, !torch.int -> !torch.bool
%6 = torch.prim.If %5 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%18 = torch.aten.gt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %18 : !torch.bool
}
%7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool
torch.prim.If %7 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%8 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%9 = torch.prim.If %8 -> (!torch.int) {
%18 = torch.aten.add.int %arg1, %2 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %18 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
%10 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%11 = torch.prim.Loop %10, %true, init(%int1) {
^bb0(%arg3: !torch.int, %arg4: !torch.int):
%18 = torch.aten.__getitem__.t %arg2, %arg3 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %arg4, %18 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%19 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%12 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%13 = torch.aten.le.int %12, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %13 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%14 = torch.aten.eq.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%18 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%19 = torch.aten.lt.int %9, %18 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %19 : !torch.bool
}
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%16 = torch.prim.ListConstruct : () -> !torch.list<int>
%17 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %17, %true, init() {
^bb0(%arg3: !torch.int):
%18 = torch.aten.eq.int %9, %arg3 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %18 -> () {
%19 = torch.aten.append.t %16, %11 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
%19 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %16 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch__.torch.jit._shape_functions.squeeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %1 : !torch.int
}
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%12 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %12 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%12 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %12 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
%11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %11, %true, init() {
^bb0(%arg2: !torch.int):
%12 = torch.aten.eq.int %arg2, %10 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %12 -> () {
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %14 -> () {
%15 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.append.t %0, %15 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.append.t %0, %13 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.squeeze.dim: forwards the input shape list (%arg0)
// and the dim argument (%arg1) to the squeeze shape helper and returns the
// helper's result unchanged.
func.func private @__torch_mlir_shape_fn.aten.squeeze.dim(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.squeeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.squeeze.dim with the squeeze helper's body inlined.
// Given input shape %arg0 and dim %arg1, returns a copy of %arg0 in which the
// entry at the (wrapped) index %arg1 is dropped if it equals 1; every other
// entry is copied through in order.
// Raises "AssertionError: " unless -R <= %arg1 <= R-1, where
// R = len(%arg0), or 1 when %arg0 is empty.
func.func private @__torch_mlir_shape_fn.aten.squeeze.dim(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%true = torch.constant.bool true
%int1 = torch.constant.int 1
// %0 is the result list; the loop below appends to it.
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
// %3 = R: len(%arg0), clamped up to 1 so an empty shape still accepts dim -1 or 0.
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %1 : !torch.int
}
// Bounds check: %7 is true when %arg1 < -R or %arg1 > R-1; raise in that case.
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%12 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %12 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %10 = %arg1 wrapped to a non-negative index by adding R when negative.
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%12 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %12 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// Copy every entry of %arg0 into %0, except the entry at index %10 when it is 1.
%11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %11, %true, init() {
^bb0(%arg2: !torch.int):
%12 = torch.aten.eq.int %arg2, %10 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %12 -> () {
// At the squeezed index: keep the entry only if it is not 1.
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %14 -> () {
%15 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.append.t %0, %15 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
// Any other index: copy the entry through.
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.append.t %0, %13 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Broadcast shape helper: computes the broadcast of shapes %arg0 and %arg1.
// The shapes are aligned at their trailing end. A position missing from the
// shorter shape is treated as size 1. At each position the two sizes must be
// equal, or one of them must be 1; otherwise the op raises
// "AssertionError: " followed by the formatted %str_0 message.
// The result has max(len(%arg0), len(%arg1)) entries. Each entry is the
// non-1 size, or 1 when both sizes are 1.
func.func private @__torch__.torch.jit._shape_functions.broadcast(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%false = torch.constant.bool false
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
// %2 = result rank; %3 = result list, appended to once per loop iteration.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %2, %true, init() {
^bb0(%arg2: !torch.int):
// %5 = distance of output position %arg2 from the last dimension.
// %7 and %9 are the matching indices into %arg0 and %arg1; they are
// negative when that shape is too short to have this position.
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.sub.int %4, %arg2 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%7 = torch.aten.sub.int %6, %5 : !torch.int, !torch.int -> !torch.int
%8 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int
%9 = torch.aten.sub.int %8, %5 : !torch.int, !torch.int -> !torch.int
// %11 / %13 = size taken from %arg0 / %arg1, or 1 when the index is negative.
%10 = torch.aten.ge.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
%11 = torch.prim.If %10 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg0, %7 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%12 = torch.aten.ge.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
// %16 is true when %11 != %13, %11 != 1 and %13 != 1 (short-circuit chain).
%14 = torch.aten.ne.int %11, %13 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
%20 = torch.aten.ne.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%16 = torch.prim.If %15 -> (!torch.bool) {
%20 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %16 -> () {
// NOTE(review): the message reports the loop index %arg2, which is the
// output position. The first size is printed without the parentheses
// that wrap the second.
%20 = torch.aten.format(%str_0, %11, %13, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%21 = torch.aten.add.str %str, %20 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %21, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// Output size: %13 when %11 is 1, otherwise %11.
%17 = torch.aten.eq.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.int) {
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %11 : !torch.int
}
%19 = torch.aten.append.t %3, %18 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %3 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.div.Tensor: the result shape is the broadcast of
// the two operand shapes (%arg0, %arg1), computed by the broadcast helper.
func.func private @__torch_mlir_shape_fn.aten.div.Tensor(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.div.Tensor with the broadcast helper inlined.
// It broadcasts %arg0 against %arg1 with the shapes aligned at their trailing
// end, treating missing positions as 1. At each position the sizes must be
// equal or one of them must be 1, else it raises "AssertionError: " followed
// by the formatted %str_0 message.
func.func private @__torch_mlir_shape_fn.aten.div.Tensor(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%false = torch.constant.bool false
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
// %2 = result rank; %3 = result list, appended to once per loop iteration.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %2, %true, init() {
^bb0(%arg2: !torch.int):
// %5 = distance of output position %arg2 from the last dimension.
// %7 / %9 = matching indices into %arg0 / %arg1; they are negative when
// that shape is too short.
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.sub.int %4, %arg2 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%7 = torch.aten.sub.int %6, %5 : !torch.int, !torch.int -> !torch.int
%8 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int
%9 = torch.aten.sub.int %8, %5 : !torch.int, !torch.int -> !torch.int
// %11 / %13 = size from %arg0 / %arg1, or 1 when the index is negative.
%10 = torch.aten.ge.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
%11 = torch.prim.If %10 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg0, %7 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%12 = torch.aten.ge.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
// %16 is true when %11 != %13, %11 != 1 and %13 != 1 (short-circuit chain).
%14 = torch.aten.ne.int %11, %13 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
%20 = torch.aten.ne.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%16 = torch.prim.If %15 -> (!torch.bool) {
%20 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %16 -> () {
// NOTE(review): the message reports the loop index %arg2, which is the
// output position. The first size is printed without the parentheses
// that wrap the second.
%20 = torch.aten.format(%str_0, %11, %13, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%21 = torch.aten.add.str %str, %20 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %21, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// Output size: %13 when %11 is 1, otherwise %11.
%17 = torch.aten.eq.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.int) {
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %11 : !torch.int
}
%19 = torch.aten.append.t %3, %18 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %3 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Cat shape helper: computes the result shape of concatenating tensors whose
// shapes are the elements of %arg0, along dim %arg1.
// Throughout, a shape equal to [0] (rank 1, single entry 0) is "skipped".
//  1. Every shape must have rank > 0.
//  2. %5 = %arg1 wrapped against the rank of the first non-skipped shape,
//     or %arg1 unchanged if every shape is skipped.
//  3. %arg0 must be non-empty.
//  4. %10 = the LAST non-skipped shape (the reference). If none exists, the
//     result is [0].
//  5. Every non-skipped shape must have the reference's rank and match it in
//     every dim except %5. Their sizes along %5 are summed.
//  6. Result = copy of the reference with dim %5 set to that sum.
func.func private @__torch__.torch.jit._shape_functions.cat(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension"
%str_0 = torch.constant.str "AssertionError: Tensors must have same number of dimensions"
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%true = torch.constant.bool true
%none = torch.constant.none
%str_1 = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
// Step 1: raise "AssertionError: " if any shape has rank 0.
%0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
torch.prim.Loop %0, %true, init() {
^bb0(%arg2: !torch.int):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%15 = torch.aten.gt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Step 2: %3 = optional wrapped dim, set by the first shape that is not [0].
%1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%2 = torch.derefine %none : !torch.none to !torch.optional<int>
%3 = torch.prim.Loop %1, %true, init(%2) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<int>):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
// %16 is true when this shape is exactly [0]; such shapes leave %arg3 as is.
%15 = torch.aten.eq.int %14, %int1 : !torch.int, !torch.int -> !torch.bool
%16 = torch.prim.If %15 -> (!torch.bool) {
%19 = torch.aten.__getitem__.t %13, %int0 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%17 = torch.aten.__not__ %16 : !torch.bool -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.optional<int>) {
// Only compute the wrapped dim if no earlier shape has set it yet.
%19 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%20 = torch.prim.If %19 -> (!torch.int) {
// R = max(len(shape), 1). Raise unless -R <= %arg1 <= R-1, then add R
// to a negative %arg1.
%22 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%23 = torch.aten.le.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.prim.If %23 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %22 : !torch.int
}
%25 = torch.aten.neg.int %24 : !torch.int -> !torch.int
%26 = torch.aten.sub.int %24, %int1 : !torch.int, !torch.int -> !torch.int
%27 = torch.aten.lt.int %arg1, %25 : !torch.int, !torch.int -> !torch.bool
%28 = torch.prim.If %27 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%32 = torch.aten.gt.int %arg1, %26 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %32 : !torch.bool
}
%29 = torch.aten.__not__ %28 : !torch.bool -> !torch.bool
torch.prim.If %29 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%30 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%31 = torch.prim.If %30 -> (!torch.int) {
%32 = torch.aten.add.int %arg1, %24 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %32 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
torch.prim.If.yield %31 : !torch.int
} else {
%22 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %22 : !torch.int
}
%21 = torch.derefine %20 : !torch.int to !torch.optional<int>
torch.prim.If.yield %21 : !torch.optional<int>
} else {
torch.prim.If.yield %arg3 : !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%18 : !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>
// %5 = concat dim: the wrapped value, or %arg1 if every shape was [0].
%4 = torch.aten.__is__ %3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%5 = torch.prim.If %4 -> (!torch.int) {
torch.prim.If.yield %arg1 : !torch.int
} else {
%13 = torch.prim.unchecked_cast %3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %13 : !torch.int
}
// Step 3: raise "AssertionError: " if %arg0 is empty.
%6 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %7 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Step 4: %10 = last shape that is not (rank 1 with element count 0).
// Each qualifying shape overwrites the carried value.
%8 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%9 = torch.derefine %none : !torch.none to !torch.optional<list<int>>
%10 = torch.prim.Loop %8, %true, init(%9) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<list<int>>):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
// %15 = product of all entries of this shape.
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%15 = torch.prim.Loop %14, %true, init(%int1) {
^bb0(%arg4: !torch.int, %arg5: !torch.int):
%20 = torch.aten.__getitem__.t %13, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%21 = torch.aten.mul.int %arg5, %20 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%21 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%16 = torch.aten.eq.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.bool) {
%20 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%21 = torch.aten.eq.int %20, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %21 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%18 = torch.aten.__not__ %17 : !torch.bool -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.optional<list<int>>) {
%20 = torch.derefine %13 : !torch.list<int> to !torch.optional<list<int>>
torch.prim.If.yield %20 : !torch.optional<list<int>>
} else {
torch.prim.If.yield %arg3 : !torch.optional<list<int>>
}
torch.prim.Loop.condition %true, iter(%19 : !torch.optional<list<int>>)
} : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>
// No reference shape: the result is [0].
%11 = torch.aten.__is__ %10, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%12 = torch.prim.If %11 -> (!torch.list<int>) {
%13 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
torch.prim.If.yield %13 : !torch.list<int>
} else {
%13 = torch.prim.unchecked_cast %10 : !torch.optional<list<int>> -> !torch.list<int>
// Step 5: %15 = sum of sizes along %5 over non-skipped shapes, validated
// against the reference %13.
%14 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%15 = torch.prim.Loop %14, %true, init(%int0) {
^bb0(%arg2: !torch.int, %arg3: !torch.int):
%19 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%20 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%21 = torch.prim.Loop %20, %true, init(%int1) {
^bb0(%arg4: !torch.int, %arg5: !torch.int):
%26 = torch.aten.__getitem__.t %19, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%27 = torch.aten.mul.int %arg5, %26 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%27 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%22 = torch.aten.eq.int %21, %int0 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.bool) {
%26 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%27 = torch.aten.eq.int %26, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %27 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%24 = torch.aten.__not__ %23 : !torch.bool -> !torch.bool
%25 = torch.prim.If %24 -> (!torch.int) {
// Ranks must match, else raise %str_0.
%26 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%27 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%28 = torch.aten.eq.int %26, %27 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %28 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// For j in range(0, rank): every j != %5 must match, else raise %str.
%29 = torch.aten.__range_length %int0, %26, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %29, %true, init() {
^bb0(%arg4: !torch.int):
%32 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%33 = torch.aten.ne.int %32, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %33 -> () {
%34 = torch.aten.__getitem__.t %13, %32 : !torch.list<int>, !torch.int -> !torch.int
%35 = torch.aten.__getitem__.t %19, %32 : !torch.list<int>, !torch.int -> !torch.int
%36 = torch.aten.eq.int %34, %35 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %36 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Add this shape's size along the concat dim to the running sum.
%30 = torch.aten.__getitem__.t %19, %5 : !torch.list<int>, !torch.int -> !torch.int
%31 = torch.aten.add.int %arg3, %30 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %31 : !torch.int
} else {
torch.prim.If.yield %arg3 : !torch.int
}
torch.prim.Loop.condition %true, iter(%25 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// Step 6: %16 = element-wise copy of %13, with entry %5 set to %15.
%16 = torch.prim.ListConstruct : () -> !torch.list<int>
%17 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
torch.prim.Loop %17, %true, init() {
^bb0(%arg2: !torch.int):
%19 = torch.aten.__getitem__.t %13, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%18 = torch.aten._set_item.t %16, %5, %15 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield %16 : !torch.list<int>
}
return %12 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.cat: forwards the list of input shapes (%arg0) and
// the concat dim (%arg1) to the cat shape helper and returns its result.
func.func private @__torch_mlir_shape_fn.aten.cat(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.cat with the cat shape helper inlined.
// Throughout, a shape equal to [0] (rank 1, single entry 0) is "skipped".
//  1. Every shape in %arg0 must have rank > 0.
//  2. %5 = %arg1 wrapped against the rank of the first non-skipped shape,
//     or %arg1 unchanged if every shape is skipped.
//  3. %arg0 must be non-empty.
//  4. %10 = the LAST non-skipped shape (the reference). If none exists, the
//     result is [0].
//  5. Every non-skipped shape must have the reference's rank and match it in
//     every dim except %5. Their sizes along %5 are summed.
//  6. Result = copy of the reference with dim %5 set to that sum.
func.func private @__torch_mlir_shape_fn.aten.cat(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension"
%str_0 = torch.constant.str "AssertionError: Tensors must have same number of dimensions"
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%true = torch.constant.bool true
%none = torch.constant.none
%str_1 = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
// Step 1: raise "AssertionError: " if any shape has rank 0.
%0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
torch.prim.Loop %0, %true, init() {
^bb0(%arg2: !torch.int):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%15 = torch.aten.gt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Step 2: %3 = optional wrapped dim, set by the first shape that is not [0].
%1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%2 = torch.derefine %none : !torch.none to !torch.optional<int>
%3 = torch.prim.Loop %1, %true, init(%2) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<int>):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
// %16 is true when this shape is exactly [0]; such shapes leave %arg3 as is.
%15 = torch.aten.eq.int %14, %int1 : !torch.int, !torch.int -> !torch.bool
%16 = torch.prim.If %15 -> (!torch.bool) {
%19 = torch.aten.__getitem__.t %13, %int0 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%17 = torch.aten.__not__ %16 : !torch.bool -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.optional<int>) {
// Only compute the wrapped dim if no earlier shape has set it yet.
%19 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%20 = torch.prim.If %19 -> (!torch.int) {
// R = max(len(shape), 1). Raise unless -R <= %arg1 <= R-1, then add R
// to a negative %arg1.
%22 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%23 = torch.aten.le.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.prim.If %23 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %22 : !torch.int
}
%25 = torch.aten.neg.int %24 : !torch.int -> !torch.int
%26 = torch.aten.sub.int %24, %int1 : !torch.int, !torch.int -> !torch.int
%27 = torch.aten.lt.int %arg1, %25 : !torch.int, !torch.int -> !torch.bool
%28 = torch.prim.If %27 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%32 = torch.aten.gt.int %arg1, %26 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %32 : !torch.bool
}
%29 = torch.aten.__not__ %28 : !torch.bool -> !torch.bool
torch.prim.If %29 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%30 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%31 = torch.prim.If %30 -> (!torch.int) {
%32 = torch.aten.add.int %arg1, %24 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %32 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
torch.prim.If.yield %31 : !torch.int
} else {
%22 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %22 : !torch.int
}
%21 = torch.derefine %20 : !torch.int to !torch.optional<int>
torch.prim.If.yield %21 : !torch.optional<int>
} else {
torch.prim.If.yield %arg3 : !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%18 : !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>
// %5 = concat dim: the wrapped value, or %arg1 if every shape was [0].
%4 = torch.aten.__is__ %3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%5 = torch.prim.If %4 -> (!torch.int) {
torch.prim.If.yield %arg1 : !torch.int
} else {
%13 = torch.prim.unchecked_cast %3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %13 : !torch.int
}
// Step 3: raise "AssertionError: " if %arg0 is empty.
%6 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %7 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Step 4: %10 = last shape that is not (rank 1 with element count 0).
// Each qualifying shape overwrites the carried value.
%8 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%9 = torch.derefine %none : !torch.none to !torch.optional<list<int>>
%10 = torch.prim.Loop %8, %true, init(%9) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<list<int>>):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
// %15 = product of all entries of this shape.
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%15 = torch.prim.Loop %14, %true, init(%int1) {
^bb0(%arg4: !torch.int, %arg5: !torch.int):
%20 = torch.aten.__getitem__.t %13, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%21 = torch.aten.mul.int %arg5, %20 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%21 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%16 = torch.aten.eq.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.bool) {
%20 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%21 = torch.aten.eq.int %20, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %21 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%18 = torch.aten.__not__ %17 : !torch.bool -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.optional<list<int>>) {
%20 = torch.derefine %13 : !torch.list<int> to !torch.optional<list<int>>
torch.prim.If.yield %20 : !torch.optional<list<int>>
} else {
torch.prim.If.yield %arg3 : !torch.optional<list<int>>
}
torch.prim.Loop.condition %true, iter(%19 : !torch.optional<list<int>>)
} : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>
// No reference shape: the result is [0].
%11 = torch.aten.__is__ %10, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%12 = torch.prim.If %11 -> (!torch.list<int>) {
%13 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
torch.prim.If.yield %13 : !torch.list<int>
} else {
%13 = torch.prim.unchecked_cast %10 : !torch.optional<list<int>> -> !torch.list<int>
// Step 5: %15 = sum of sizes along %5 over non-skipped shapes, validated
// against the reference %13.
%14 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%15 = torch.prim.Loop %14, %true, init(%int0) {
^bb0(%arg2: !torch.int, %arg3: !torch.int):
%19 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%20 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%21 = torch.prim.Loop %20, %true, init(%int1) {
^bb0(%arg4: !torch.int, %arg5: !torch.int):
%26 = torch.aten.__getitem__.t %19, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%27 = torch.aten.mul.int %arg5, %26 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%27 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%22 = torch.aten.eq.int %21, %int0 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.bool) {
%26 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%27 = torch.aten.eq.int %26, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %27 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%24 = torch.aten.__not__ %23 : !torch.bool -> !torch.bool
%25 = torch.prim.If %24 -> (!torch.int) {
// Ranks must match, else raise %str_0.
%26 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%27 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%28 = torch.aten.eq.int %26, %27 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %28 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// For j in range(0, rank): every j != %5 must match, else raise %str.
%29 = torch.aten.__range_length %int0, %26, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %29, %true, init() {
^bb0(%arg4: !torch.int):
%32 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%33 = torch.aten.ne.int %32, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %33 -> () {
%34 = torch.aten.__getitem__.t %13, %32 : !torch.list<int>, !torch.int -> !torch.int
%35 = torch.aten.__getitem__.t %19, %32 : !torch.list<int>, !torch.int -> !torch.int
%36 = torch.aten.eq.int %34, %35 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %36 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Add this shape's size along the concat dim to the running sum.
%30 = torch.aten.__getitem__.t %19, %5 : !torch.list<int>, !torch.int -> !torch.int
%31 = torch.aten.add.int %arg3, %30 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %31 : !torch.int
} else {
torch.prim.If.yield %arg3 : !torch.int
}
torch.prim.Loop.condition %true, iter(%25 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// Step 6: %16 = element-wise copy of %13, with entry %5 set to %15.
%16 = torch.prim.ListConstruct : () -> !torch.list<int>
%17 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
torch.prim.Loop %17, %true, init() {
^bb0(%arg2: !torch.int):
%19 = torch.aten.__getitem__.t %13, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%18 = torch.aten._set_item.t %16, %5, %15 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield %16 : !torch.list<int>
}
return %12 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for prim.NumToTensor.Scalar: always returns an empty shape
// list (a rank-0 result). The scalar operand %arg0 is not read.
func.func private @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%arg0: !torch.float) -> !torch.list<int> {
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.where.self: the result shape is
// broadcast(%arg0, broadcast(%arg1, %arg2)), with both steps computed by the
// broadcast shape helper.
func.func private @__torch_mlir_shape_fn.aten.where.self(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
%1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %1 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.where.self with both broadcast calls inlined.
// The first loop broadcasts %arg1 with %arg2 into %3. The second loop
// broadcasts %arg0 with %3 into %7, which is returned.
// NOTE(review): the error message formats the first size as "a {}" without
// the parentheses used for "b ({})". It is runtime text, so it is left verbatim.
func.func private @__torch_mlir_shape_fn.aten.where.self(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%false = torch.constant.bool false
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
// First broadcast: %arg1 vs %arg2. The output rank is %2 = max(len(%arg1), len(%arg2)).
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%1 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
// Iterates %2 times; the continue condition is always %true.
torch.prim.Loop %2, %true, init() {
^bb0(%arg3: !torch.int):
// %9 is the distance of output dim %arg3 from the last dim.
// %11 and %13 are the matching indices into %arg1 and %arg2.
// An index is negative when that input has fewer dims.
%8 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%9 = torch.aten.sub.int %8, %arg3 : !torch.int, !torch.int -> !torch.int
%10 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%11 = torch.aten.sub.int %10, %9 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int
%13 = torch.aten.sub.int %12, %9 : !torch.int, !torch.int -> !torch.int
// Missing leading dims are treated as size 1 (%15 for %arg1, %17 for %arg2).
%14 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.int) {
%24 = torch.aten.__getitem__.t %arg1, %11 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %24 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%16 = torch.aten.ge.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
%24 = torch.aten.__getitem__.t %arg2, %13 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %24 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
// The sizes conflict (%20) only when they differ and neither is 1.
%18 = torch.aten.ne.int %15, %17 : !torch.int, !torch.int -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.bool) {
%24 = torch.aten.ne.int %15, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %24 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%20 = torch.prim.If %19 -> (!torch.bool) {
%24 = torch.aten.ne.int %17, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %24 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
// On conflict, raise with both sizes and the loop index %arg3.
torch.prim.If %20 -> () {
%24 = torch.aten.format(%str_0, %15, %17, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%25 = torch.aten.add.str %str, %24 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %25, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// Output dim: %17 if %15 is 1, otherwise %15.
%21 = torch.aten.eq.int %15, %int1 : !torch.int, !torch.int -> !torch.bool
%22 = torch.prim.If %21 -> (!torch.int) {
torch.prim.If.yield %17 : !torch.int
} else {
torch.prim.If.yield %15 : !torch.int
}
%23 = torch.aten.append.t %3, %22 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Second broadcast: %arg0 vs %3, accumulated into %7.
// The per-dimension logic is the same as in the first loop.
%4 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%5 = torch.aten.len.t %3 : !torch.list<int> -> !torch.int
%6 = torch.prim.max.int %4, %5 : !torch.int, !torch.int -> !torch.int
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %6, %true, init() {
^bb0(%arg3: !torch.int):
%8 = torch.aten.sub.int %6, %int1 : !torch.int, !torch.int -> !torch.int
%9 = torch.aten.sub.int %8, %arg3 : !torch.int, !torch.int -> !torch.int
%10 = torch.aten.sub.int %4, %int1 : !torch.int, !torch.int -> !torch.int
%11 = torch.aten.sub.int %10, %9 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.sub.int %5, %int1 : !torch.int, !torch.int -> !torch.int
%13 = torch.aten.sub.int %12, %9 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.int) {
%24 = torch.aten.__getitem__.t %arg0, %11 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %24 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%16 = torch.aten.ge.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
%24 = torch.aten.__getitem__.t %3, %13 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %24 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%18 = torch.aten.ne.int %15, %17 : !torch.int, !torch.int -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.bool) {
%24 = torch.aten.ne.int %15, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %24 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%20 = torch.prim.If %19 -> (!torch.bool) {
%24 = torch.aten.ne.int %17, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %24 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %20 -> () {
%24 = torch.aten.format(%str_0, %15, %17, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%25 = torch.aten.add.str %str, %24 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %25, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%21 = torch.aten.eq.int %15, %int1 : !torch.int, !torch.int -> !torch.bool
%22 = torch.prim.If %21 -> (!torch.int) {
torch.prim.If.yield %17 : !torch.int
} else {
torch.prim.If.yield %15 : !torch.int
}
%23 = torch.aten.append.t %7, %22 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %7 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for slicing. It returns a copy of %arg0 in which dimension
// %arg1 is replaced by the length of the slice from start %arg2 to end %arg3
// (exclusive) with step %arg4.
// It raises (RaiseException) when len(%arg0) is 0, when %arg1 is out of
// range, or when the step is not positive.
func.func private @__torch__.torch.jit._shape_functions.slice(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {
%int9223372036854775807 = torch.constant.int 9223372036854775807
%true = torch.constant.bool true
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// The input list must be non-empty.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %1 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %3 is the rank, clamped to at least 1.
// %arg1 must lie in [-%3, %3 - 1].
%2 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %0 : !torch.int
}
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%33 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %33 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %10 is %arg1 normalized to a non-negative dim index.
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%33 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// When an argument is None, start (%12) defaults to 0 and end (%14)
// defaults to 9223372036854775807.
%11 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool
%12 = torch.prim.If %11 -> (!torch.int) {
%33 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int0 : !torch.int
}
%13 = torch.aten.__isnot__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%14 = torch.prim.If %13 -> (!torch.int) {
%33 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int9223372036854775807 : !torch.int
}
// The step must be strictly positive.
%15 = torch.aten.gt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// A start equal to 9223372036854775807 is treated as 0.
%16 = torch.aten.eq.int %12, %int9223372036854775807 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
torch.prim.If.yield %12 : !torch.int
}
// A negative start (%19) or end (%21) is offset by the size of dim %10.
%18 = torch.aten.lt.int %17, %int0 : !torch.int, !torch.int -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.int) {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.add.int %17, %33 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %34 : !torch.int
} else {
torch.prim.If.yield %17 : !torch.int
}
%20 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
%21 = torch.prim.If %20 -> (!torch.int) {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.add.int %14, %33 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %34 : !torch.int
} else {
torch.prim.If.yield %14 : !torch.int
}
// %23 is start clamped to [0, size of dim %10].
%22 = torch.aten.lt.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.gt.int %19, %33 : !torch.int, !torch.int -> !torch.bool
%35 = torch.prim.If %34 -> (!torch.int) {
%36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %36 : !torch.int
} else {
torch.prim.If.yield %19 : !torch.int
}
torch.prim.If.yield %35 : !torch.int
}
// %25 is end clamped to [%23, size of dim %10].
%24 = torch.aten.lt.int %21, %23 : !torch.int, !torch.int -> !torch.bool
%25 = torch.prim.If %24 -> (!torch.int) {
torch.prim.If.yield %23 : !torch.int
} else {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.ge.int %21, %33 : !torch.int, !torch.int -> !torch.bool
%35 = torch.prim.If %34 -> (!torch.int) {
%36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %36 : !torch.int
} else {
torch.prim.If.yield %21 : !torch.int
}
torch.prim.If.yield %35 : !torch.int
}
// %26 is the number of elements between the clamped start and end.
%26 = torch.aten.sub.int %25, %23 : !torch.int, !torch.int -> !torch.int
// Copy the input shape element by element into %27.
%27 = torch.prim.ListConstruct : () -> !torch.list<int>
%28 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %28, %true, init() {
^bb0(%arg5: !torch.int):
%33 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.append.t %27, %33 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// The new size of dim %10 is the ceiling division (%26 + step - 1) floordiv step.
%29 = torch.aten.add.int %26, %arg4 : !torch.int, !torch.int -> !torch.int
%30 = torch.aten.sub.int %29, %int1 : !torch.int, !torch.int -> !torch.int
%31 = torch.aten.floordiv.int %30, %arg4 : !torch.int, !torch.int -> !torch.int
%32 = torch.aten._set_item.t %27, %10, %31 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
return %27 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.slice.Tensor (call form, before inlining).
// It forwards all operands unchanged to _shape_functions.slice.
func.func private @__torch_mlir_shape_fn.aten.slice.Tensor(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.slice.Tensor with the slice shape function inlined.
// It returns a copy of %arg0 in which dim %arg1 is replaced by
// ceil((clamped end - clamped start) / %arg4).
// It raises when len(%arg0) is 0, when %arg1 is out of range, or when
// %arg4 <= 0.
func.func private @__torch_mlir_shape_fn.aten.slice.Tensor(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {
%int9223372036854775807 = torch.constant.int 9223372036854775807
%true = torch.constant.bool true
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// The input list must be non-empty.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %1 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %3 is the rank, clamped to at least 1.
// %arg1 must lie in [-%3, %3 - 1].
%2 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %0 : !torch.int
}
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%33 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %33 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %10 is %arg1 normalized to a non-negative dim index.
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%33 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// When an argument is None, start (%12) defaults to 0 and end (%14)
// defaults to 9223372036854775807.
%11 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool
%12 = torch.prim.If %11 -> (!torch.int) {
%33 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int0 : !torch.int
}
%13 = torch.aten.__isnot__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%14 = torch.prim.If %13 -> (!torch.int) {
%33 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int9223372036854775807 : !torch.int
}
// The step must be strictly positive.
%15 = torch.aten.gt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// A start equal to 9223372036854775807 is treated as 0.
%16 = torch.aten.eq.int %12, %int9223372036854775807 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
torch.prim.If.yield %12 : !torch.int
}
// A negative start (%19) or end (%21) is offset by the size of dim %10.
%18 = torch.aten.lt.int %17, %int0 : !torch.int, !torch.int -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.int) {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.add.int %17, %33 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %34 : !torch.int
} else {
torch.prim.If.yield %17 : !torch.int
}
%20 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
%21 = torch.prim.If %20 -> (!torch.int) {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.add.int %14, %33 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %34 : !torch.int
} else {
torch.prim.If.yield %14 : !torch.int
}
// %23 is start clamped to [0, size of dim %10].
%22 = torch.aten.lt.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.gt.int %19, %33 : !torch.int, !torch.int -> !torch.bool
%35 = torch.prim.If %34 -> (!torch.int) {
%36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %36 : !torch.int
} else {
torch.prim.If.yield %19 : !torch.int
}
torch.prim.If.yield %35 : !torch.int
}
// %25 is end clamped to [%23, size of dim %10].
%24 = torch.aten.lt.int %21, %23 : !torch.int, !torch.int -> !torch.bool
%25 = torch.prim.If %24 -> (!torch.int) {
torch.prim.If.yield %23 : !torch.int
} else {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.ge.int %21, %33 : !torch.int, !torch.int -> !torch.bool
%35 = torch.prim.If %34 -> (!torch.int) {
%36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %36 : !torch.int
} else {
torch.prim.If.yield %21 : !torch.int
}
torch.prim.If.yield %35 : !torch.int
}
// %26 is the number of elements between the clamped start and end.
%26 = torch.aten.sub.int %25, %23 : !torch.int, !torch.int -> !torch.int
// Copy the input shape element by element into %27.
%27 = torch.prim.ListConstruct : () -> !torch.list<int>
%28 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %28, %true, init() {
^bb0(%arg5: !torch.int):
%33 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.append.t %27, %33 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// The new size of dim %10 is the ceiling division (%26 + step - 1) floordiv step.
%29 = torch.aten.add.int %26, %arg4 : !torch.int, !torch.int -> !torch.int
%30 = torch.aten.sub.int %29, %int1 : !torch.int, !torch.int -> !torch.int
%31 = torch.aten.floordiv.int %30, %arg4 : !torch.int, !torch.int -> !torch.int
%32 = torch.aten._set_item.t %27, %10, %31 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
return %27 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for reshaping a tensor of shape %arg0 to the requested
// shape %arg1. It returns a copy of %arg1 in which a single -1 entry, if
// present, is replaced by the inferred size.
// It raises on any of the following:
//   - more than one -1 entry;
//   - any other negative entry;
//   - element counts that cannot be reconciled.
func.func private @__torch__.torch.jit._shape_functions.view(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: invalid shape"
%false = torch.constant.bool false
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: only one dimension can be inferred"
%int-1 = torch.constant.int -1
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%true = torch.constant.bool true
// %1 is the product of all entries of %arg0 (the input element count).
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.prim.Loop %0, %true, init(%int1) {
^bb0(%arg2: !torch.int, %arg3: !torch.int):
%12 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.mul.int %arg3, %12 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%13 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// %2 is a placeholder, yielded only on a path that has already raised.
%2 = torch.prim.Uninitialized : !torch.int
// Scan %arg1 and produce two values:
//   %5#0 is the product of all entries that are not -1;
//   %5#1 is the index of the -1 entry, or None if there is none.
%3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%4 = torch.derefine %none : !torch.none to !torch.optional<int>
%5:2 = torch.prim.Loop %3, %true, init(%int1, %4) {
^bb0(%arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.optional<int>):
%12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.eq.int %12, %int-1 : !torch.int, !torch.int -> !torch.bool
%14:2 = torch.prim.If %13 -> (!torch.int, !torch.optional<int>) {
// A second -1 is an error. Otherwise, record this index as the inferred dim.
%15 = torch.aten.__isnot__ %arg4, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%16 = torch.derefine %arg2 : !torch.int to !torch.optional<int>
torch.prim.If.yield %arg3, %16 : !torch.int, !torch.optional<int>
} else {
// A non-negative entry is multiplied in; a negative one raises.
%15 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.ge.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
%18 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %arg3, %18 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %19 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %2 : !torch.int
}
torch.prim.If.yield %17, %arg4 : !torch.int, !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%14#0, %14#1 : !torch.int, !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.int, !torch.optional<int>) -> (!torch.int, !torch.optional<int>)
// %7 is valid when either:
//   - the products match; or
//   - a dim is inferred, %5#0 > 0, and %1 % %5#0 == 0.
// The unchecked_cast results (%15) in these branches are unused.
%6 = torch.aten.eq.int %1, %5#0 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%12 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.bool) {
%15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%16 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %16 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%14 = torch.prim.If %13 -> (!torch.bool) {
%15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%16 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %17 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %14 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// Copy %arg1 element by element into %9.
%9 = torch.prim.ListConstruct : () -> !torch.list<int>
%10 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
torch.prim.Loop %10, %true, init() {
^bb0(%arg2: !torch.int):
%12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.append.t %9, %12 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// If a dim was inferred, overwrite it with %1 floordiv %5#0.
%11 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %11 -> () {
%12 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%13 = torch.aten.floordiv.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten._set_item.t %9, %12, %13 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
return %9 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.view (call form, before inlining).
// It forwards both operands unchanged to _shape_functions.view.
func.func private @__torch_mlir_shape_fn.aten.view(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.view with the view shape function inlined.
// It returns a copy of the requested shape %arg1. A single -1 entry is
// replaced by (product of %arg0) floordiv (product of the other %arg1 entries).
// It raises on repeated -1, on other negative entries, or on a count mismatch.
func.func private @__torch_mlir_shape_fn.aten.view(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: invalid shape"
%false = torch.constant.bool false
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: only one dimension can be inferred"
%int-1 = torch.constant.int -1
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%true = torch.constant.bool true
// %1 is the product of all entries of %arg0 (the input element count).
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.prim.Loop %0, %true, init(%int1) {
^bb0(%arg2: !torch.int, %arg3: !torch.int):
%12 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.mul.int %arg3, %12 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%13 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// %2 is a placeholder, yielded only on a path that has already raised.
%2 = torch.prim.Uninitialized : !torch.int
// Scan %arg1 and produce two values:
//   %5#0 is the product of all entries that are not -1;
//   %5#1 is the index of the -1 entry, or None if there is none.
%3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%4 = torch.derefine %none : !torch.none to !torch.optional<int>
%5:2 = torch.prim.Loop %3, %true, init(%int1, %4) {
^bb0(%arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.optional<int>):
%12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.eq.int %12, %int-1 : !torch.int, !torch.int -> !torch.bool
%14:2 = torch.prim.If %13 -> (!torch.int, !torch.optional<int>) {
// A second -1 is an error. Otherwise, record this index as the inferred dim.
%15 = torch.aten.__isnot__ %arg4, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%16 = torch.derefine %arg2 : !torch.int to !torch.optional<int>
torch.prim.If.yield %arg3, %16 : !torch.int, !torch.optional<int>
} else {
// A non-negative entry is multiplied in; a negative one raises.
%15 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.ge.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
%18 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %arg3, %18 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %19 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %2 : !torch.int
}
torch.prim.If.yield %17, %arg4 : !torch.int, !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%14#0, %14#1 : !torch.int, !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.int, !torch.optional<int>) -> (!torch.int, !torch.optional<int>)
// %7 is valid when either:
//   - the products match; or
//   - a dim is inferred, %5#0 > 0, and %1 % %5#0 == 0.
// The unchecked_cast results (%15) in these branches are unused.
%6 = torch.aten.eq.int %1, %5#0 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%12 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.bool) {
%15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%16 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %16 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%14 = torch.prim.If %13 -> (!torch.bool) {
%15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%16 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %17 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %14 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// Copy %arg1 element by element into %9.
%9 = torch.prim.ListConstruct : () -> !torch.list<int>
%10 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
torch.prim.Loop %10, %true, init() {
^bb0(%arg2: !torch.int):
%12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.append.t %9, %12 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// If a dim was inferred, overwrite it with %1 floordiv %5#0.
%11 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %11 -> () {
%12 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%13 = torch.aten.floordiv.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten._set_item.t %9, %12, %13 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
return %9 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for unsqueeze. It returns a copy of %arg0 with a size-1
// dimension inserted at position %arg1.
// %arg1 may be negative. It is validated and normalized against the output
// rank len(%arg0) + 1.
func.func private @__torch__.torch.jit._shape_functions.unsqueeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// %3 is the output rank (len + 1), clamped to at least 1.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.add.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %1 : !torch.int
}
// %arg1 must lie in [-%3, %3 - 1]; otherwise the function raises.
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%13 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %13 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %10 is %arg1 normalized to a non-negative insert position.
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%13 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// Copy the input shape, then insert a 1 at position %10.
%11 = torch.prim.ListConstruct : () -> !torch.list<int>
%12 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %12, %true, init() {
^bb0(%arg2: !torch.int):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.append.t %11, %13 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %11, %10, %int1 : !torch.list<int>, !torch.int, !torch.int
return %11 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.unsqueeze (call form, before inlining).
// It forwards both operands unchanged to _shape_functions.unsqueeze.
func.func private @__torch_mlir_shape_fn.aten.unsqueeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.unsqueeze with the unsqueeze shape function inlined.
// It returns a copy of %arg0 with a 1 inserted at the normalized position %arg1.
func.func private @__torch_mlir_shape_fn.aten.unsqueeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// %3 is the output rank (len + 1), clamped to at least 1.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.add.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %1 : !torch.int
}
// %arg1 must lie in [-%3, %3 - 1]; otherwise the function raises.
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%13 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %13 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %10 is %arg1 normalized to a non-negative insert position.
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%13 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// Copy the input shape, then insert a 1 at position %10.
%11 = torch.prim.ListConstruct : () -> !torch.list<int>
%12 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %12, %true, init() {
^bb0(%arg2: !torch.int):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.append.t %11, %13 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %11, %10, %int1 : !torch.list<int>, !torch.int, !torch.int
return %11 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch__.torch.jit._shape_functions.unary(%arg0: !torch.list<int>) -> !torch.list<int> {
// Shape helper for element-wise unary ops: returns a fresh list holding the
// same entries as the input shape %arg0, in the same order.
%true = torch.constant.bool true
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
// Copy each dimension size of %arg0 into %0.
torch.prim.Loop %1, %true, init() {
^bb0(%arg1: !torch.int):
%2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int
%3 = torch.aten.append.t %0, %2 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.to.dtype(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
// Shape function for aten.to.dtype: the output shape equals the input shape,
// so this delegates to the unary helper. Only %arg0 is used; the dtype and
// flag arguments (%arg1..%arg4) do not affect the shape here.
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.to.dtype(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
// Shape function for aten.to.dtype with the unary copy loop inlined: returns
// a fresh copy of the input shape %arg0. %arg1..%arg4 are not referenced.
%true = torch.constant.bool true
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
// Copy each dimension size of %arg0 into %0.
torch.prim.Loop %1, %true, init() {
^bb0(%arg5: !torch.int):
%2 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int
%3 = torch.aten.append.t %0, %2 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.logical_and(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
// Shape function for aten.logical_and: the output shape is the broadcast of
// the two operand shapes, computed by the shared broadcast helper.
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.logical_and(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
// Shape function for aten.logical_and with the broadcast helper inlined.
// Returns the broadcast of shapes %arg0 and %arg1, aligning dims from the
// right; a dim missing from the shorter operand is treated as size 1.
// Raises (torch.prim.RaiseException) when two aligned sizes differ and
// neither is 1. The message's last {} is filled with the loop index %arg2.
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%false = torch.constant.bool false
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
// Output rank is max(len(%arg0), len(%arg1)).
%2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %2, %true, init() {
^bb0(%arg2: !torch.int):
// %5 = offset of this output dim from the right (%2 - 1 - %arg2).
// %7 / %9 = matching indices into %arg0 / %arg1; negative when that
// operand has fewer dims than the output.
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.sub.int %4, %arg2 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%7 = torch.aten.sub.int %6, %5 : !torch.int, !torch.int -> !torch.int
%8 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int
%9 = torch.aten.sub.int %8, %5 : !torch.int, !torch.int -> !torch.int
// %11 = size of %arg0 at this position, or 1 if the dim is absent.
%10 = torch.aten.ge.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
%11 = torch.prim.If %10 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg0, %7 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
// %13 = size of %arg1 at this position, or 1 if the dim is absent.
%12 = torch.aten.ge.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
// %16 = (%11 != %13) && (%11 != 1) && (%13 != 1), short-circuited.
%14 = torch.aten.ne.int %11, %13 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
%20 = torch.aten.ne.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%16 = torch.prim.If %15 -> (!torch.bool) {
%20 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
// Incompatible sizes: raise with a formatted message.
torch.prim.If %16 -> () {
%20 = torch.aten.format(%str_0, %11, %13, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%21 = torch.aten.add.str %str, %20 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %21, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// Broadcast result for this dim: %13 if %11 == 1, otherwise %11.
%17 = torch.aten.eq.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.int) {
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %11 : !torch.int
}
%19 = torch.aten.append.t %3, %18 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %3 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
// Imported ONNX entry function after canonicalization. Each op is wrapped in
// a torch.shape.calculate region: the body computes the value, and the
// `shapes` region calls the op's shape function to compute its result shape.
// NOTE(review): %arg0 and %arg2 are not referenced anywhere in this body.
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%float0.000000e00 = torch.constant.float 0.000000e+00
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
// Extract %arg4[0..3] as scalars %7, %10, %13, %16 via slice -> squeeze -> item.
%5 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%58 = torch.derefine %int0 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %5, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.squeeze.dim(%57, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%7 = torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
%8 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%58 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%9 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %8, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.squeeze.dim(%57, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[],si64> -> !torch.int
%11 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%58 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int3 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %11, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.squeeze.dim(%57, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[],si64> -> !torch.int
%14 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%58 = torch.derefine %int3 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int4 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%15 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %14, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.squeeze.dim(%57, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%16 = torch.aten.item %15 : !torch.vtensor<[],si64> -> !torch.int
// Pad list is [%arg4[1], %arg4[3], %arg4[0], %arg4[2]]. aten.constant_pad_nd
// reads (before, after) pairs starting from the last dim, so presumably the
// last dim gets (%arg4[1], %arg4[3]) and dim 0 gets (%arg4[0], %arg4[2]).
// This looks like ONNX [x1_begin, x2_begin, x1_end, x2_end] pads reordered
// into that layout — TODO confirm.
%17 = torch.prim.ListConstruct %10, %16, %7, %13 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// %18 = %arg1 padded with value 0.
%18 = torch.shape.calculate {
%57 = torch.aten.constant_pad_nd %arg1, %17, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?],si64>
} shapes {
%57 = torch.aten.size %arg1 : !torch.vtensor<[?,?],si64> -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.constant_pad_nd(%57, %17, %float0.000000e00) : (!torch.list<int>, !torch.list<int>, !torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
// %20 = %arg5[1] as a 0-d tensor (index_select with index literal %0 = [1]).
%19 = torch.shape.calculate {
%57 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%59 = func.call @__torch_mlir_shape_fn.aten.index_select(%57, %int0, %58) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%20 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %19, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.squeeze.dim(%57, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
// NOTE(review): the divisor %4 is the literal dense<0>, so this is an si64
// division by zero, typed as an si64 result. Confirm whether this comes from
// the source ONNX graph or from the importer.
%21 = torch.shape.calculate {
%57 = torch.aten.div.Tensor %20, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
%59 = func.call @__torch_mlir_shape_fn.aten.div.Tensor(%57, %58) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[],si64>
%22 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %21, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.unsqueeze(%57, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
// Target shape %24 = cat([%3, %22, %2]) = [0, %22, 8].
%23 = torch.prim.ListConstruct %3, %22, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%24 = torch.shape.calculate {
%57 = torch.aten.cat %23, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[3],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<list<int>>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.__getitem__.t %23, %arg6 : !torch.list<vtensor>, !torch.int -> !torch.vtensor
%60 = torch.aten.size %59 : !torch.vtensor -> !torch.list<int>
%61 = torch.aten.append.t %57, %60 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%58 = func.call @__torch_mlir_shape_fn.aten.cat(%57, %int0) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[3],si64>
// Output dim 0 (%34): size(%18, 0) if %24[0] == 0, otherwise %24[0].
%25 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%58 = torch.derefine %int0 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%26 = torch.aten.item %25 : !torch.vtensor<[1],si64> -> !torch.int
%27 = torch.aten.eq.int %26, %int0 : !torch.int, !torch.int -> !torch.bool
%28 = torch.aten.Int.bool %27 : !torch.bool -> !torch.int
%29 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%30 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %28 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %28 : !torch.int -> !torch.float
%58 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%57) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%31 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %29 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %29 : !torch.int -> !torch.float
%58 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%57) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%32 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %26 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %26 : !torch.int -> !torch.float
%58 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%57) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%33 = torch.shape.calculate {
%57 = torch.aten.where.self %30, %31, %32 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
%59 = torch.prim.ListConstruct : () -> !torch.list<int>
%60 = func.call @__torch_mlir_shape_fn.aten.where.self(%57, %58, %59) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[],si64>
%34 = torch.aten.item %33 : !torch.vtensor<[],si64> -> !torch.int
// Output dim 1 (%44): size(%18, 1) if %24[1] == 0, otherwise %24[1].
%35 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%58 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%36 = torch.aten.item %35 : !torch.vtensor<[1],si64> -> !torch.int
%37 = torch.aten.eq.int %36, %int0 : !torch.int, !torch.int -> !torch.bool
%38 = torch.aten.Int.bool %37 : !torch.bool -> !torch.int
%39 = torch.aten.size.int %18, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%40 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %38 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %38 : !torch.int -> !torch.float
%58 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%57) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%41 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %39 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %39 : !torch.int -> !torch.float
%58 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%57) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%42 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %36 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %36 : !torch.int -> !torch.float
%58 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%57) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%43 = torch.shape.calculate {
%57 = torch.aten.where.self %40, %41, %42 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
%59 = torch.prim.ListConstruct : () -> !torch.list<int>
%60 = func.call @__torch_mlir_shape_fn.aten.where.self(%57, %58, %59) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[],si64>
%44 = torch.aten.item %43 : !torch.vtensor<[],si64> -> !torch.int
// Output dim 2 (%46): %24[2] is used directly. %47 / %48 are computed but
// not referenced by any later op in this body (%18 has only 2 dims).
%45 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%58 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%59 = torch.derefine %int3 : !torch.int to !torch.optional<int>
%60 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%57, %int0, %58, %59, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%46 = torch.aten.item %45 : !torch.vtensor<[1],si64> -> !torch.int
%47 = torch.aten.eq.int %46, %int0 : !torch.int, !torch.int -> !torch.bool
%48 = torch.aten.Int.bool %47 : !torch.bool -> !torch.int
// Reshape the padded %18 to [%34, %44, %46].
%49 = torch.prim.ListConstruct %34, %44, %46 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%50 = torch.shape.calculate {
%57 = torch.aten.view %18, %49 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?],si64>
} shapes {
%57 = torch.aten.size %18 : !torch.vtensor<[?,?],si64> -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.view(%57, %49) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
// %51: %50 unsqueezed at dim 3 -> [?,?,?,1].
%51 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %50, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%57 = torch.aten.size %50 : !torch.vtensor<[?,?,?],si64> -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.unsqueeze(%57, %int3) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
// %52: %arg3 unsqueezed at dim 2 -> [?,?,1,?].
%52 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%57 = torch.aten.size %arg3 : !torch.vtensor<[?,?,?],si64> -> !torch.list<int>
%58 = func.call @__torch_mlir_shape_fn.aten.unsqueeze(%57, %int2) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
// Cast both operands with dtype code %int11; the result element type is i1.
%53 = torch.shape.calculate {
%57 = torch.aten.to.dtype %51, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%57 = torch.aten.size %51 : !torch.vtensor<[?,?,?,1],si64> -> !torch.list<int>
%58 = torch.derefine %none : !torch.none to !torch.optional<int>
%59 = func.call @__torch_mlir_shape_fn.aten.to.dtype(%57, %int11, %false, %false, %58) : (!torch.list<int>, !torch.int, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%54 = torch.shape.calculate {
%57 = torch.aten.to.dtype %52, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%57 = torch.aten.size %52 : !torch.vtensor<[?,?,1,?],si64> -> !torch.list<int>
%58 = torch.derefine %none : !torch.none to !torch.optional<int>
%59 = func.call @__torch_mlir_shape_fn.aten.to.dtype(%57, %int11, %false, %false, %58) : (!torch.list<int>, !torch.int, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
// Broadcast AND of [?,?,?,1] and [?,?,1,?] -> [?,?,?,?].
%55 = torch.shape.calculate {
%57 = torch.aten.logical_and %53, %54 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%57 = torch.aten.size %53 : !torch.vtensor<[?,?,?,1],i1> -> !torch.list<int>
%58 = torch.aten.size %54 : !torch.vtensor<[?,?,1,?],i1> -> !torch.list<int>
%59 = func.call @__torch_mlir_shape_fn.aten.logical_and(%57, %58) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
// AND with %1, an all-false [1,1,128,384] literal. The returned
// [?,?,128,384] tensor is therefore all false, whatever the inputs' values.
%56 = torch.shape.calculate {
%57 = torch.aten.logical_and %55, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%57 = torch.aten.size %55 : !torch.vtensor<[?,?,?,?],i1> -> !torch.list<int>
%58 = torch.prim.ListConstruct %int1, %int1, %int128, %int384 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%59 = func.call @__torch_mlir_shape_fn.aten.logical_and(%57, %58) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %56 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "AssertionError: Tensors must have same number of dimensions"
%str_3 = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension"
%str_4 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_5 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %5, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.eq.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %59 -> () {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %58, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%7 = torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
%8 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%9 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %8, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.eq.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %59 -> () {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %58, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[],si64> -> !torch.int
%11 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %11, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.eq.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %59 -> () {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %58, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[],si64> -> !torch.int
%14 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%15 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %14, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.eq.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %59 -> () {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %58, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%16 = torch.aten.item %15 : !torch.vtensor<[],si64> -> !torch.int
%17 = torch.prim.ListConstruct %10, %16, %7, %13 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%18 = torch.shape.calculate {
%57 = torch.aten.constant_pad_nd %arg1, %17, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?],si64>
} shapes {
%57 = torch.aten.size %arg1 : !torch.vtensor<[?,?],si64> -> !torch.list<int>
torch.prim.Loop %int2, %true, init() {
^bb0(%arg6: !torch.int):
%58 = torch.aten.add.int %arg6, %int1 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.neg.int %58 : !torch.int -> !torch.int
%60 = torch.aten.mul.int %int2, %arg6 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.__getitem__.t %17, %60 : !torch.list<int>, !torch.int -> !torch.int
%62 = torch.aten.mul.int %int2, %arg6 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.add.int %62, %int1 : !torch.int, !torch.int -> !torch.int
%64 = torch.aten.__getitem__.t %17, %63 : !torch.list<int>, !torch.int -> !torch.int
%65 = torch.aten.add.int %61, %64 : !torch.int, !torch.int -> !torch.int
%66 = torch.aten.size.int %arg1, %59 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%67 = torch.aten.add.int %66, %65 : !torch.int, !torch.int -> !torch.int
%68 = torch.aten._set_item.t %57, %59, %67 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
%19 = torch.shape.calculate {
%57 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%59 = torch.prim.Loop %int1, %true, init(%int1) {
^bb0(%arg6: !torch.int, %arg7: !torch.int):
%61 = torch.aten.__getitem__.t %58, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%62 = torch.aten.mul.int %arg7, %61 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%62 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%60 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%61 = torch.aten.eq.int %int0, %arg6 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.aten.append.t %60, %59 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
%62 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %60, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%20 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %19, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.eq.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %59 -> () {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %58, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%21 = torch.shape.calculate {
%57 = torch.aten.div.Tensor %20, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
%59 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.sub.int %int-1, %arg6 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.sub.int %int-1, %60 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.sub.int %int-1, %60 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.ge.int %61, %int0 : !torch.int, !torch.int -> !torch.bool
%64 = torch.prim.If %63 -> (!torch.int) {
%73 = torch.aten.__getitem__.t %57, %61 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%65 = torch.aten.ge.int %62, %int0 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
%73 = torch.aten.__getitem__.t %58, %62 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%67 = torch.aten.ne.int %64, %66 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.bool) {
%73 = torch.aten.ne.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%69 = torch.prim.If %68 -> (!torch.bool) {
%73 = torch.aten.ne.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %69 -> () {
%73 = torch.aten.format(%str_4, %64, %66, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%74 = torch.aten.add.str %str_5, %73 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %74, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%70 = torch.aten.eq.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.int) {
torch.prim.If.yield %66 : !torch.int
} else {
torch.prim.If.yield %64 : !torch.int
}
%72 = torch.aten.append.t %59, %71 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[],si64>
%22 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %21, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%60 = torch.aten.append.t %58, %59 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%23 = torch.prim.ListConstruct %3, %22, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%24 = torch.shape.calculate {
%57 = torch.aten.cat %23, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[3],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<list<int>>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%71 = torch.aten.__getitem__.t %23, %arg6 : !torch.list<vtensor>, !torch.int -> !torch.vtensor
%72 = torch.aten.size %71 : !torch.vtensor -> !torch.list<int>
%73 = torch.aten.append.t %57, %72 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%58 = torch.aten.len.t %57 : !torch.list<list<int>> -> !torch.int
torch.prim.Loop %58, %true, init() {
^bb0(%arg6: !torch.int):
%71 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%72 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
%73 = torch.aten.gt.int %72, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %73 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten.len.t %57 : !torch.list<list<int>> -> !torch.int
%60 = torch.derefine %none : !torch.none to !torch.optional<int>
%61 = torch.prim.Loop %59, %true, init(%60) {
^bb0(%arg6: !torch.int, %arg7: !torch.optional<int>):
%71 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%72 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
%73 = torch.aten.eq.int %72, %int1 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.bool) {
%77 = torch.aten.__getitem__.t %71, %int0 : !torch.list<int>, !torch.int -> !torch.int
%78 = torch.aten.eq.int %77, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%75 = torch.aten.__not__ %74 : !torch.bool -> !torch.bool
%76 = torch.prim.If %75 -> (!torch.optional<int>) {
%77 = torch.aten.__is__ %arg7, %none : !torch.optional<int>, !torch.none -> !torch.bool
%78 = torch.prim.If %77 -> (!torch.int) {
%80 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
%81 = torch.aten.le.int %80, %int0 : !torch.int, !torch.int -> !torch.bool
%82 = torch.prim.If %81 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %80 : !torch.int
}
%83 = torch.aten.neg.int %82 : !torch.int -> !torch.int
%84 = torch.aten.sub.int %82, %int1 : !torch.int, !torch.int -> !torch.int
%85 = torch.aten.lt.int %int0, %83 : !torch.int, !torch.int -> !torch.bool
%86 = torch.prim.If %85 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%88 = torch.aten.gt.int %int0, %84 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %88 : !torch.bool
}
%87 = torch.aten.__not__ %86 : !torch.bool -> !torch.bool
torch.prim.If %87 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield %int0 : !torch.int
} else {
%80 = torch.prim.unchecked_cast %arg7 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %80 : !torch.int
}
%79 = torch.derefine %78 : !torch.int to !torch.optional<int>
torch.prim.If.yield %79 : !torch.optional<int>
} else {
torch.prim.If.yield %arg7 : !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%76 : !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>
%62 = torch.aten.__is__ %61, %none : !torch.optional<int>, !torch.none -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
%71 = torch.prim.unchecked_cast %61 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %71 : !torch.int
}
%64 = torch.aten.len.t %57 : !torch.list<list<int>> -> !torch.int
%65 = torch.aten.gt.int %64, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %65 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%66 = torch.aten.len.t %57 : !torch.list<list<int>> -> !torch.int
%67 = torch.derefine %none : !torch.none to !torch.optional<list<int>>
%68 = torch.prim.Loop %66, %true, init(%67) {
^bb0(%arg6: !torch.int, %arg7: !torch.optional<list<int>>):
%71 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%72 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
%73 = torch.prim.Loop %72, %true, init(%int1) {
^bb0(%arg8: !torch.int, %arg9: !torch.int):
%78 = torch.aten.__getitem__.t %71, %arg8 : !torch.list<int>, !torch.int -> !torch.int
%79 = torch.aten.mul.int %arg9, %78 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%79 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%74 = torch.aten.eq.int %73, %int0 : !torch.int, !torch.int -> !torch.bool
%75 = torch.prim.If %74 -> (!torch.bool) {
%78 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
%79 = torch.aten.eq.int %78, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %79 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%76 = torch.aten.__not__ %75 : !torch.bool -> !torch.bool
%77 = torch.prim.If %76 -> (!torch.optional<list<int>>) {
%78 = torch.derefine %71 : !torch.list<int> to !torch.optional<list<int>>
torch.prim.If.yield %78 : !torch.optional<list<int>>
} else {
torch.prim.If.yield %arg7 : !torch.optional<list<int>>
}
torch.prim.Loop.condition %true, iter(%77 : !torch.optional<list<int>>)
} : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>
%69 = torch.aten.__is__ %68, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.list<int>) {
%71 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
torch.prim.If.yield %71 : !torch.list<int>
} else {
%71 = torch.prim.unchecked_cast %68 : !torch.optional<list<int>> -> !torch.list<int>
%72 = torch.aten.len.t %57 : !torch.list<list<int>> -> !torch.int
%73 = torch.prim.Loop %72, %true, init(%int0) {
^bb0(%arg6: !torch.int, %arg7: !torch.int):
%77 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%78 = torch.aten.len.t %77 : !torch.list<int> -> !torch.int
%79 = torch.prim.Loop %78, %true, init(%int1) {
^bb0(%arg8: !torch.int, %arg9: !torch.int):
%84 = torch.aten.__getitem__.t %77, %arg8 : !torch.list<int>, !torch.int -> !torch.int
%85 = torch.aten.mul.int %arg9, %84 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%85 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%80 = torch.aten.eq.int %79, %int0 : !torch.int, !torch.int -> !torch.bool
%81 = torch.prim.If %80 -> (!torch.bool) {
%84 = torch.aten.len.t %77 : !torch.list<int> -> !torch.int
%85 = torch.aten.eq.int %84, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %85 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%82 = torch.aten.__not__ %81 : !torch.bool -> !torch.bool
%83 = torch.prim.If %82 -> (!torch.int) {
%84 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
%85 = torch.aten.len.t %77 : !torch.list<int> -> !torch.int
%86 = torch.aten.eq.int %84, %85 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %86 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%87 = torch.aten.__range_length %int0, %84, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %87, %true, init() {
^bb0(%arg8: !torch.int):
%90 = torch.aten.__derive_index %arg8, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%91 = torch.aten.ne.int %90, %63 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %91 -> () {
%92 = torch.aten.__getitem__.t %71, %90 : !torch.list<int>, !torch.int -> !torch.int
%93 = torch.aten.__getitem__.t %77, %90 : !torch.list<int>, !torch.int -> !torch.int
%94 = torch.aten.eq.int %92, %93 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %94 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%88 = torch.aten.__getitem__.t %77, %63 : !torch.list<int>, !torch.int -> !torch.int
%89 = torch.aten.add.int %arg7, %88 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %89 : !torch.int
} else {
torch.prim.If.yield %arg7 : !torch.int
}
torch.prim.Loop.condition %true, iter(%83 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%74 = torch.prim.ListConstruct : () -> !torch.list<int>
%75 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
torch.prim.Loop %75, %true, init() {
^bb0(%arg6: !torch.int):
%77 = torch.aten.__getitem__.t %71, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%78 = torch.aten.append.t %74, %77 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%76 = torch.aten._set_item.t %74, %63, %73 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield %74 : !torch.list<int>
}
torch.shape.calculate.yield.shapes %70 : !torch.list<int>
} : !torch.vtensor<[3],si64>
%25 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%26 = torch.aten.item %25 : !torch.vtensor<[1],si64> -> !torch.int
%27 = torch.aten.eq.int %26, %int0 : !torch.int, !torch.int -> !torch.bool
%28 = torch.aten.Int.bool %27 : !torch.bool -> !torch.int
%29 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%30 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %28 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %28 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%31 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %29 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %29 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%32 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %26 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %26 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%33 = torch.shape.calculate {
%57 = torch.aten.where.self %30, %31, %32 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
%59 = torch.prim.ListConstruct : () -> !torch.list<int>
%60 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%64 = torch.aten.sub.int %int-1, %arg6 : !torch.int, !torch.int -> !torch.int
%65 = torch.aten.sub.int %int-1, %64 : !torch.int, !torch.int -> !torch.int
%66 = torch.aten.sub.int %int-1, %64 : !torch.int, !torch.int -> !torch.int
%67 = torch.aten.ge.int %65, %int0 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.int) {
%77 = torch.aten.__getitem__.t %58, %65 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %77 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%69 = torch.aten.ge.int %66, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%77 = torch.aten.__getitem__.t %59, %66 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %77 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%71 = torch.aten.ne.int %68, %70 : !torch.int, !torch.int -> !torch.bool
%72 = torch.prim.If %71 -> (!torch.bool) {
%77 = torch.aten.ne.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %77 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%73 = torch.prim.If %72 -> (!torch.bool) {
%77 = torch.aten.ne.int %70, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %77 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %73 -> () {
%77 = torch.aten.format(%str_4, %68, %70, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%78 = torch.aten.add.str %str_5, %77 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %78, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%74 = torch.aten.eq.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
%75 = torch.prim.If %74 -> (!torch.int) {
torch.prim.If.yield %70 : !torch.int
} else {
torch.prim.If.yield %68 : !torch.int
}
%76 = torch.aten.append.t %60, %75 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%61 = torch.aten.len.t %60 : !torch.list<int> -> !torch.int
%62 = torch.prim.max.int %int0, %61 : !torch.int, !torch.int -> !torch.int
%63 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %62, %true, init() {
^bb0(%arg6: !torch.int):
%64 = torch.aten.sub.int %62, %int1 : !torch.int, !torch.int -> !torch.int
%65 = torch.aten.sub.int %64, %arg6 : !torch.int, !torch.int -> !torch.int
%66 = torch.aten.sub.int %int-1, %65 : !torch.int, !torch.int -> !torch.int
%67 = torch.aten.sub.int %61, %int1 : !torch.int, !torch.int -> !torch.int
%68 = torch.aten.sub.int %67, %65 : !torch.int, !torch.int -> !torch.int
%69 = torch.aten.ge.int %66, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%79 = torch.aten.__getitem__.t %57, %66 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %79 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%71 = torch.aten.ge.int %68, %int0 : !torch.int, !torch.int -> !torch.bool
%72 = torch.prim.If %71 -> (!torch.int) {
%79 = torch.aten.__getitem__.t %60, %68 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %79 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%73 = torch.aten.ne.int %70, %72 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.bool) {
%79 = torch.aten.ne.int %70, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %79 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%75 = torch.prim.If %74 -> (!torch.bool) {
%79 = torch.aten.ne.int %72, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %79 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %75 -> () {
%79 = torch.aten.format(%str_4, %70, %72, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%80 = torch.aten.add.str %str_5, %79 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %80, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%76 = torch.aten.eq.int %70, %int1 : !torch.int, !torch.int -> !torch.bool
%77 = torch.prim.If %76 -> (!torch.int) {
torch.prim.If.yield %72 : !torch.int
} else {
torch.prim.If.yield %70 : !torch.int
}
%78 = torch.aten.append.t %63, %77 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %63 : !torch.list<int>
} : !torch.vtensor<[],si64>
%34 = torch.aten.item %33 : !torch.vtensor<[],si64> -> !torch.int
%35 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%36 = torch.aten.item %35 : !torch.vtensor<[1],si64> -> !torch.int
%37 = torch.aten.eq.int %36, %int0 : !torch.int, !torch.int -> !torch.bool
%38 = torch.aten.Int.bool %37 : !torch.bool -> !torch.int
%39 = torch.aten.size.int %18, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%40 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %38 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %38 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%41 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %39 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %39 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%42 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %36 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %36 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%43 = torch.shape.calculate {
%57 = torch.aten.where.self %40, %41, %42 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
%59 = torch.prim.ListConstruct : () -> !torch.list<int>
%60 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%64 = torch.aten.sub.int %int-1, %arg6 : !torch.int, !torch.int -> !torch.int
%65 = torch.aten.sub.int %int-1, %64 : !torch.int, !torch.int -> !torch.int
%66 = torch.aten.sub.int %int-1, %64 : !torch.int, !torch.int -> !torch.int
%67 = torch.aten.ge.int %65, %int0 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.int) {
%77 = torch.aten.__getitem__.t %58, %65 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %77 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%69 = torch.aten.ge.int %66, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%77 = torch.aten.__getitem__.t %59, %66 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %77 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%71 = torch.aten.ne.int %68, %70 : !torch.int, !torch.int -> !torch.bool
%72 = torch.prim.If %71 -> (!torch.bool) {
%77 = torch.aten.ne.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %77 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%73 = torch.prim.If %72 -> (!torch.bool) {
%77 = torch.aten.ne.int %70, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %77 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %73 -> () {
%77 = torch.aten.format(%str_4, %68, %70, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%78 = torch.aten.add.str %str_5, %77 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %78, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%74 = torch.aten.eq.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
%75 = torch.prim.If %74 -> (!torch.int) {
torch.prim.If.yield %70 : !torch.int
} else {
torch.prim.If.yield %68 : !torch.int
}
%76 = torch.aten.append.t %60, %75 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%61 = torch.aten.len.t %60 : !torch.list<int> -> !torch.int
%62 = torch.prim.max.int %int0, %61 : !torch.int, !torch.int -> !torch.int
%63 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %62, %true, init() {
^bb0(%arg6: !torch.int):
%64 = torch.aten.sub.int %62, %int1 : !torch.int, !torch.int -> !torch.int
%65 = torch.aten.sub.int %64, %arg6 : !torch.int, !torch.int -> !torch.int
%66 = torch.aten.sub.int %int-1, %65 : !torch.int, !torch.int -> !torch.int
%67 = torch.aten.sub.int %61, %int1 : !torch.int, !torch.int -> !torch.int
%68 = torch.aten.sub.int %67, %65 : !torch.int, !torch.int -> !torch.int
%69 = torch.aten.ge.int %66, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%79 = torch.aten.__getitem__.t %57, %66 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %79 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%71 = torch.aten.ge.int %68, %int0 : !torch.int, !torch.int -> !torch.bool
%72 = torch.prim.If %71 -> (!torch.int) {
%79 = torch.aten.__getitem__.t %60, %68 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %79 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%73 = torch.aten.ne.int %70, %72 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.bool) {
%79 = torch.aten.ne.int %70, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %79 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%75 = torch.prim.If %74 -> (!torch.bool) {
%79 = torch.aten.ne.int %72, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %79 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %75 -> () {
%79 = torch.aten.format(%str_4, %70, %72, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%80 = torch.aten.add.str %str_5, %79 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %80, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%76 = torch.aten.eq.int %70, %int1 : !torch.int, !torch.int -> !torch.bool
%77 = torch.prim.If %76 -> (!torch.int) {
torch.prim.If.yield %72 : !torch.int
} else {
torch.prim.If.yield %70 : !torch.int
}
%78 = torch.aten.append.t %63, %77 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %63 : !torch.list<int>
} : !torch.vtensor<[],si64>
%44 = torch.aten.item %43 : !torch.vtensor<[],si64> -> !torch.int
%45 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%46 = torch.aten.item %45 : !torch.vtensor<[1],si64> -> !torch.int
%47 = torch.aten.eq.int %46, %int0 : !torch.int, !torch.int -> !torch.bool
%48 = torch.aten.Int.bool %47 : !torch.bool -> !torch.int
%49 = torch.prim.ListConstruct %34, %44, %46 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%50 = torch.shape.calculate {
%57 = torch.aten.view %18, %49 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?],si64>
} shapes {
%57 = torch.prim.Loop %int2, %true, init(%int1) {
^bb0(%arg6: !torch.int, %arg7: !torch.int):
%66 = torch.aten.size.int %18, %arg6 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%67 = torch.aten.mul.int %arg7, %66 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%67 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%58 = torch.prim.Uninitialized : !torch.int
%59 = torch.derefine %none : !torch.none to !torch.optional<int>
%60:2 = torch.prim.Loop %int3, %true, init(%int1, %59) {
^bb0(%arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.optional<int>):
%66 = torch.aten.__getitem__.t %49, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%67 = torch.aten.eq.int %66, %int-1 : !torch.int, !torch.int -> !torch.bool
%68:2 = torch.prim.If %67 -> (!torch.int, !torch.optional<int>) {
%69 = torch.aten.__isnot__ %arg8, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %69 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%70 = torch.derefine %arg6 : !torch.int to !torch.optional<int>
torch.prim.If.yield %arg7, %70 : !torch.int, !torch.optional<int>
} else {
%69 = torch.aten.__getitem__.t %49, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%70 = torch.aten.ge.int %69, %int0 : !torch.int, !torch.int -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.int) {
%72 = torch.aten.__getitem__.t %49, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%73 = torch.aten.mul.int %arg7, %72 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %58 : !torch.int
}
torch.prim.If.yield %71, %arg8 : !torch.int, !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%68#0, %68#1 : !torch.int, !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.int, !torch.optional<int>) -> (!torch.int, !torch.optional<int>)
%61 = torch.aten.eq.int %57, %60#0 : !torch.int, !torch.int -> !torch.bool
%62 = torch.prim.If %61 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%66 = torch.aten.__isnot__ %60#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.bool) {
%69 = torch.prim.unchecked_cast %60#1 : !torch.optional<int> -> !torch.int
%70 = torch.aten.gt.int %60#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %70 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%68 = torch.prim.If %67 -> (!torch.bool) {
%69 = torch.prim.unchecked_cast %60#1 : !torch.optional<int> -> !torch.int
%70 = torch.aten.remainder.int %57, %60#0 : !torch.int, !torch.int -> !torch.int
%71 = torch.aten.eq.int %70, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %71 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %68 : !torch.bool
}
%63 = torch.aten.__not__ %62 : !torch.bool -> !torch.bool
torch.prim.If %63 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%64 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%66 = torch.aten.__getitem__.t %49, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%67 = torch.aten.append.t %64, %66 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%65 = torch.aten.__isnot__ %60#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %65 -> () {
%66 = torch.prim.unchecked_cast %60#1 : !torch.optional<int> -> !torch.int
%67 = torch.aten.floordiv.int %57, %60#0 : !torch.int, !torch.int -> !torch.int
%68 = torch.aten._set_item.t %64, %66, %67 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %64 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
%51 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %50, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%58 = torch.aten.size.int %50, %arg6 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.append.t %57, %58 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %57, %int3, %int1 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
%52 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%58 = torch.aten.size.int %arg3, %arg6 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.append.t %57, %58 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %57, %int2, %int1 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
%53 = torch.shape.calculate {
%57 = torch.aten.to.dtype %51, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%58 = torch.aten.size.int %51, %arg6 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%59 = torch.aten.append.t %57, %58 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%54 = torch.shape.calculate {
%57 = torch.aten.to.dtype %52, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%58 = torch.aten.size.int %52, %arg6 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.append.t %57, %58 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
%55 = torch.shape.calculate {
%57 = torch.aten.logical_and %53, %54 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%58 = torch.aten.sub.int %int3, %arg6 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.sub.int %int3, %58 : !torch.int, !torch.int -> !torch.int
%60 = torch.aten.sub.int %int3, %58 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.ge.int %59, %int0 : !torch.int, !torch.int -> !torch.bool
%62 = torch.prim.If %61 -> (!torch.int) {
%71 = torch.aten.size.int %53, %59 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
torch.prim.If.yield %71 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%63 = torch.aten.ge.int %60, %int0 : !torch.int, !torch.int -> !torch.bool
%64 = torch.prim.If %63 -> (!torch.int) {
%71 = torch.aten.size.int %54, %60 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
torch.prim.If.yield %71 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%65 = torch.aten.ne.int %62, %64 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.bool) {
%71 = torch.aten.ne.int %62, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %71 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%67 = torch.prim.If %66 -> (!torch.bool) {
%71 = torch.aten.ne.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %71 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %67 -> () {
%71 = torch.aten.format(%str_4, %62, %64, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%72 = torch.aten.add.str %str_5, %71 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %72, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%68 = torch.aten.eq.int %62, %int1 : !torch.int, !torch.int -> !torch.bool
%69 = torch.prim.If %68 -> (!torch.int) {
torch.prim.If.yield %64 : !torch.int
} else {
torch.prim.If.yield %62 : !torch.int
}
%70 = torch.aten.append.t %57, %69 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
%56 = torch.shape.calculate {
%57 = torch.aten.logical_and %55, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%57 = torch.prim.ListConstruct %int1, %int1, %int128, %int384 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.sub.int %int3, %arg6 : !torch.int, !torch.int -> !torch.int
%60 = torch.aten.sub.int %int3, %59 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.sub.int %int3, %59 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.ge.int %60, %int0 : !torch.int, !torch.int -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.int) {
%72 = torch.aten.size.int %55, %60 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
torch.prim.If.yield %72 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%64 = torch.aten.ge.int %61, %int0 : !torch.int, !torch.int -> !torch.bool
%65 = torch.prim.If %64 -> (!torch.int) {
%72 = torch.aten.__getitem__.t %57, %61 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %72 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%66 = torch.aten.ne.int %63, %65 : !torch.int, !torch.int -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.bool) {
%72 = torch.aten.ne.int %63, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %72 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%68 = torch.prim.If %67 -> (!torch.bool) {
%72 = torch.aten.ne.int %65, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %72 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %68 -> () {
%72 = torch.aten.format(%str_4, %63, %65, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%73 = torch.aten.add.str %str_5, %72 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %73, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%69 = torch.aten.eq.int %63, %int1 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
torch.prim.If.yield %65 : !torch.int
} else {
torch.prim.If.yield %63 : !torch.int
}
%71 = torch.aten.append.t %58, %70 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %56 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After Inliner (inline) //----- //
module {
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "AssertionError: Tensors must have same number of dimensions"
%str_3 = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension"
%str_4 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_5 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %5, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.eq.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %59 -> () {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %58, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%7 = torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
%8 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%9 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %8, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.eq.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %59 -> () {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %58, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[],si64> -> !torch.int
%11 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %11, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.eq.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %59 -> () {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %58, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[],si64> -> !torch.int
%14 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%15 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %14, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.eq.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %59 -> () {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %58, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%16 = torch.aten.item %15 : !torch.vtensor<[],si64> -> !torch.int
%17 = torch.prim.ListConstruct %10, %16, %7, %13 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%18 = torch.shape.calculate {
%57 = torch.aten.constant_pad_nd %arg1, %17, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?],si64>
} shapes {
%57 = torch.aten.size %arg1 : !torch.vtensor<[?,?],si64> -> !torch.list<int>
torch.prim.Loop %int2, %true, init() {
^bb0(%arg6: !torch.int):
%58 = torch.aten.add.int %arg6, %int1 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.neg.int %58 : !torch.int -> !torch.int
%60 = torch.aten.mul.int %int2, %arg6 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.__getitem__.t %17, %60 : !torch.list<int>, !torch.int -> !torch.int
%62 = torch.aten.mul.int %int2, %arg6 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.add.int %62, %int1 : !torch.int, !torch.int -> !torch.int
%64 = torch.aten.__getitem__.t %17, %63 : !torch.list<int>, !torch.int -> !torch.int
%65 = torch.aten.add.int %61, %64 : !torch.int, !torch.int -> !torch.int
%66 = torch.aten.size.int %arg1, %59 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%67 = torch.aten.add.int %66, %65 : !torch.int, !torch.int -> !torch.int
%68 = torch.aten._set_item.t %57, %59, %67 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
%19 = torch.shape.calculate {
%57 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%59 = torch.prim.Loop %int1, %true, init(%int1) {
^bb0(%arg6: !torch.int, %arg7: !torch.int):
%61 = torch.aten.__getitem__.t %58, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%62 = torch.aten.mul.int %arg7, %61 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%62 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%60 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%61 = torch.aten.eq.int %int0, %arg6 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.aten.append.t %60, %59 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
%62 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %60, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%20 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %19, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.eq.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %59 -> () {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %58, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%21 = torch.shape.calculate {
%57 = torch.aten.div.Tensor %20, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
%59 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.sub.int %int-1, %arg6 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.sub.int %int-1, %60 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.sub.int %int-1, %60 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.ge.int %61, %int0 : !torch.int, !torch.int -> !torch.bool
%64 = torch.prim.If %63 -> (!torch.int) {
%73 = torch.aten.__getitem__.t %57, %61 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%65 = torch.aten.ge.int %62, %int0 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
%73 = torch.aten.__getitem__.t %58, %62 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%67 = torch.aten.ne.int %64, %66 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.bool) {
%73 = torch.aten.ne.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%69 = torch.prim.If %68 -> (!torch.bool) {
%73 = torch.aten.ne.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %69 -> () {
%73 = torch.aten.format(%str_4, %64, %66, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%74 = torch.aten.add.str %str_5, %73 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %74, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%70 = torch.aten.eq.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.int) {
torch.prim.If.yield %66 : !torch.int
} else {
torch.prim.If.yield %64 : !torch.int
}
%72 = torch.aten.append.t %59, %71 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[],si64>
%22 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %21, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%60 = torch.aten.append.t %58, %59 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%23 = torch.prim.ListConstruct %3, %22, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%24 = torch.shape.calculate {
%57 = torch.aten.cat %23, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[3],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<list<int>>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%71 = torch.aten.__getitem__.t %23, %arg6 : !torch.list<vtensor>, !torch.int -> !torch.vtensor
%72 = torch.aten.size %71 : !torch.vtensor -> !torch.list<int>
%73 = torch.aten.append.t %57, %72 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%58 = torch.aten.len.t %57 : !torch.list<list<int>> -> !torch.int
torch.prim.Loop %58, %true, init() {
^bb0(%arg6: !torch.int):
%71 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%72 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
%73 = torch.aten.gt.int %72, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %73 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten.len.t %57 : !torch.list<list<int>> -> !torch.int
%60 = torch.derefine %none : !torch.none to !torch.optional<int>
%61 = torch.prim.Loop %59, %true, init(%60) {
^bb0(%arg6: !torch.int, %arg7: !torch.optional<int>):
%71 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%72 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
%73 = torch.aten.eq.int %72, %int1 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.bool) {
%77 = torch.aten.__getitem__.t %71, %int0 : !torch.list<int>, !torch.int -> !torch.int
%78 = torch.aten.eq.int %77, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%75 = torch.aten.__not__ %74 : !torch.bool -> !torch.bool
%76 = torch.prim.If %75 -> (!torch.optional<int>) {
%77 = torch.aten.__is__ %arg7, %none : !torch.optional<int>, !torch.none -> !torch.bool
%78 = torch.prim.If %77 -> (!torch.int) {
%80 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
%81 = torch.aten.le.int %80, %int0 : !torch.int, !torch.int -> !torch.bool
%82 = torch.prim.If %81 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %80 : !torch.int
}
%83 = torch.aten.neg.int %82 : !torch.int -> !torch.int
%84 = torch.aten.sub.int %82, %int1 : !torch.int, !torch.int -> !torch.int
%85 = torch.aten.lt.int %int0, %83 : !torch.int, !torch.int -> !torch.bool
%86 = torch.prim.If %85 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%88 = torch.aten.gt.int %int0, %84 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %88 : !torch.bool
}
%87 = torch.aten.__not__ %86 : !torch.bool -> !torch.bool
torch.prim.If %87 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield %int0 : !torch.int
} else {
%80 = torch.prim.unchecked_cast %arg7 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %80 : !torch.int
}
%79 = torch.derefine %78 : !torch.int to !torch.optional<int>
torch.prim.If.yield %79 : !torch.optional<int>
} else {
torch.prim.If.yield %arg7 : !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%76 : !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>
%62 = torch.aten.__is__ %61, %none : !torch.optional<int>, !torch.none -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
%71 = torch.prim.unchecked_cast %61 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %71 : !torch.int
}
%64 = torch.aten.len.t %57 : !torch.list<list<int>> -> !torch.int
%65 = torch.aten.gt.int %64, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %65 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%66 = torch.aten.len.t %57 : !torch.list<list<int>> -> !torch.int
%67 = torch.derefine %none : !torch.none to !torch.optional<list<int>>
%68 = torch.prim.Loop %66, %true, init(%67) {
^bb0(%arg6: !torch.int, %arg7: !torch.optional<list<int>>):
%71 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%72 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
%73 = torch.prim.Loop %72, %true, init(%int1) {
^bb0(%arg8: !torch.int, %arg9: !torch.int):
%78 = torch.aten.__getitem__.t %71, %arg8 : !torch.list<int>, !torch.int -> !torch.int
%79 = torch.aten.mul.int %arg9, %78 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%79 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%74 = torch.aten.eq.int %73, %int0 : !torch.int, !torch.int -> !torch.bool
%75 = torch.prim.If %74 -> (!torch.bool) {
%78 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
%79 = torch.aten.eq.int %78, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %79 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%76 = torch.aten.__not__ %75 : !torch.bool -> !torch.bool
%77 = torch.prim.If %76 -> (!torch.optional<list<int>>) {
%78 = torch.derefine %71 : !torch.list<int> to !torch.optional<list<int>>
torch.prim.If.yield %78 : !torch.optional<list<int>>
} else {
torch.prim.If.yield %arg7 : !torch.optional<list<int>>
}
torch.prim.Loop.condition %true, iter(%77 : !torch.optional<list<int>>)
} : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>
%69 = torch.aten.__is__ %68, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.list<int>) {
%71 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
torch.prim.If.yield %71 : !torch.list<int>
} else {
%71 = torch.prim.unchecked_cast %68 : !torch.optional<list<int>> -> !torch.list<int>
%72 = torch.aten.len.t %57 : !torch.list<list<int>> -> !torch.int
%73 = torch.prim.Loop %72, %true, init(%int0) {
^bb0(%arg6: !torch.int, %arg7: !torch.int):
%77 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%78 = torch.aten.len.t %77 : !torch.list<int> -> !torch.int
%79 = torch.prim.Loop %78, %true, init(%int1) {
^bb0(%arg8: !torch.int, %arg9: !torch.int):
%84 = torch.aten.__getitem__.t %77, %arg8 : !torch.list<int>, !torch.int -> !torch.int
%85 = torch.aten.mul.int %arg9, %84 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%85 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%80 = torch.aten.eq.int %79, %int0 : !torch.int, !torch.int -> !torch.bool
%81 = torch.prim.If %80 -> (!torch.bool) {
%84 = torch.aten.len.t %77 : !torch.list<int> -> !torch.int
%85 = torch.aten.eq.int %84, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %85 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%82 = torch.aten.__not__ %81 : !torch.bool -> !torch.bool
%83 = torch.prim.If %82 -> (!torch.int) {
%84 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
%85 = torch.aten.len.t %77 : !torch.list<int> -> !torch.int
%86 = torch.aten.eq.int %84, %85 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %86 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%87 = torch.aten.__range_length %int0, %84, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %87, %true, init() {
^bb0(%arg8: !torch.int):
%90 = torch.aten.__derive_index %arg8, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%91 = torch.aten.ne.int %90, %63 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %91 -> () {
%92 = torch.aten.__getitem__.t %71, %90 : !torch.list<int>, !torch.int -> !torch.int
%93 = torch.aten.__getitem__.t %77, %90 : !torch.list<int>, !torch.int -> !torch.int
%94 = torch.aten.eq.int %92, %93 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %94 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%88 = torch.aten.__getitem__.t %77, %63 : !torch.list<int>, !torch.int -> !torch.int
%89 = torch.aten.add.int %arg7, %88 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %89 : !torch.int
} else {
torch.prim.If.yield %arg7 : !torch.int
}
torch.prim.Loop.condition %true, iter(%83 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%74 = torch.prim.ListConstruct : () -> !torch.list<int>
%75 = torch.aten.len.t %71 : !torch.list<int> -> !torch.int
torch.prim.Loop %75, %true, init() {
^bb0(%arg6: !torch.int):
%77 = torch.aten.__getitem__.t %71, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%78 = torch.aten.append.t %74, %77 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%76 = torch.aten._set_item.t %74, %63, %73 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield %74 : !torch.list<int>
}
torch.shape.calculate.yield.shapes %70 : !torch.list<int>
} : !torch.vtensor<[3],si64>
%25 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%26 = torch.aten.item %25 : !torch.vtensor<[1],si64> -> !torch.int
%27 = torch.aten.eq.int %26, %int0 : !torch.int, !torch.int -> !torch.bool
%28 = torch.aten.Int.bool %27 : !torch.bool -> !torch.int
%29 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%30 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %28 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %28 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%31 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %29 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %29 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%32 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %26 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %26 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%33 = torch.shape.calculate {
%57 = torch.aten.where.self %30, %31, %32 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
%59 = torch.prim.ListConstruct : () -> !torch.list<int>
%60 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%64 = torch.aten.sub.int %int-1, %arg6 : !torch.int, !torch.int -> !torch.int
%65 = torch.aten.sub.int %int-1, %64 : !torch.int, !torch.int -> !torch.int
%66 = torch.aten.sub.int %int-1, %64 : !torch.int, !torch.int -> !torch.int
%67 = torch.aten.ge.int %65, %int0 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.int) {
%77 = torch.aten.__getitem__.t %58, %65 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %77 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%69 = torch.aten.ge.int %66, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%77 = torch.aten.__getitem__.t %59, %66 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %77 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%71 = torch.aten.ne.int %68, %70 : !torch.int, !torch.int -> !torch.bool
%72 = torch.prim.If %71 -> (!torch.bool) {
%77 = torch.aten.ne.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %77 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%73 = torch.prim.If %72 -> (!torch.bool) {
%77 = torch.aten.ne.int %70, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %77 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %73 -> () {
%77 = torch.aten.format(%str_4, %68, %70, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%78 = torch.aten.add.str %str_5, %77 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %78, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%74 = torch.aten.eq.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
%75 = torch.prim.If %74 -> (!torch.int) {
torch.prim.If.yield %70 : !torch.int
} else {
torch.prim.If.yield %68 : !torch.int
}
%76 = torch.aten.append.t %60, %75 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%61 = torch.aten.len.t %60 : !torch.list<int> -> !torch.int
%62 = torch.prim.max.int %int0, %61 : !torch.int, !torch.int -> !torch.int
%63 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %62, %true, init() {
^bb0(%arg6: !torch.int):
%64 = torch.aten.sub.int %62, %int1 : !torch.int, !torch.int -> !torch.int
%65 = torch.aten.sub.int %64, %arg6 : !torch.int, !torch.int -> !torch.int
%66 = torch.aten.sub.int %int-1, %65 : !torch.int, !torch.int -> !torch.int
%67 = torch.aten.sub.int %61, %int1 : !torch.int, !torch.int -> !torch.int
%68 = torch.aten.sub.int %67, %65 : !torch.int, !torch.int -> !torch.int
%69 = torch.aten.ge.int %66, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%79 = torch.aten.__getitem__.t %57, %66 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %79 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%71 = torch.aten.ge.int %68, %int0 : !torch.int, !torch.int -> !torch.bool
%72 = torch.prim.If %71 -> (!torch.int) {
%79 = torch.aten.__getitem__.t %60, %68 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %79 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%73 = torch.aten.ne.int %70, %72 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.bool) {
%79 = torch.aten.ne.int %70, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %79 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%75 = torch.prim.If %74 -> (!torch.bool) {
%79 = torch.aten.ne.int %72, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %79 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %75 -> () {
%79 = torch.aten.format(%str_4, %70, %72, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%80 = torch.aten.add.str %str_5, %79 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %80, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%76 = torch.aten.eq.int %70, %int1 : !torch.int, !torch.int -> !torch.bool
%77 = torch.prim.If %76 -> (!torch.int) {
torch.prim.If.yield %72 : !torch.int
} else {
torch.prim.If.yield %70 : !torch.int
}
%78 = torch.aten.append.t %63, %77 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %63 : !torch.list<int>
} : !torch.vtensor<[],si64>
%34 = torch.aten.item %33 : !torch.vtensor<[],si64> -> !torch.int
%35 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%36 = torch.aten.item %35 : !torch.vtensor<[1],si64> -> !torch.int
%37 = torch.aten.eq.int %36, %int0 : !torch.int, !torch.int -> !torch.bool
%38 = torch.aten.Int.bool %37 : !torch.bool -> !torch.int
%39 = torch.aten.size.int %18, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%40 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %38 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %38 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%41 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %39 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %39 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%42 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %36 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %36 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%43 = torch.shape.calculate {
%57 = torch.aten.where.self %40, %41, %42 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
%59 = torch.prim.ListConstruct : () -> !torch.list<int>
%60 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%64 = torch.aten.sub.int %int-1, %arg6 : !torch.int, !torch.int -> !torch.int
%65 = torch.aten.sub.int %int-1, %64 : !torch.int, !torch.int -> !torch.int
%66 = torch.aten.sub.int %int-1, %64 : !torch.int, !torch.int -> !torch.int
%67 = torch.aten.ge.int %65, %int0 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.int) {
%77 = torch.aten.__getitem__.t %58, %65 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %77 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%69 = torch.aten.ge.int %66, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%77 = torch.aten.__getitem__.t %59, %66 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %77 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%71 = torch.aten.ne.int %68, %70 : !torch.int, !torch.int -> !torch.bool
%72 = torch.prim.If %71 -> (!torch.bool) {
%77 = torch.aten.ne.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %77 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%73 = torch.prim.If %72 -> (!torch.bool) {
%77 = torch.aten.ne.int %70, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %77 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %73 -> () {
%77 = torch.aten.format(%str_4, %68, %70, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%78 = torch.aten.add.str %str_5, %77 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %78, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%74 = torch.aten.eq.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
%75 = torch.prim.If %74 -> (!torch.int) {
torch.prim.If.yield %70 : !torch.int
} else {
torch.prim.If.yield %68 : !torch.int
}
%76 = torch.aten.append.t %60, %75 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%61 = torch.aten.len.t %60 : !torch.list<int> -> !torch.int
%62 = torch.prim.max.int %int0, %61 : !torch.int, !torch.int -> !torch.int
%63 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %62, %true, init() {
^bb0(%arg6: !torch.int):
%64 = torch.aten.sub.int %62, %int1 : !torch.int, !torch.int -> !torch.int
%65 = torch.aten.sub.int %64, %arg6 : !torch.int, !torch.int -> !torch.int
%66 = torch.aten.sub.int %int-1, %65 : !torch.int, !torch.int -> !torch.int
%67 = torch.aten.sub.int %61, %int1 : !torch.int, !torch.int -> !torch.int
%68 = torch.aten.sub.int %67, %65 : !torch.int, !torch.int -> !torch.int
%69 = torch.aten.ge.int %66, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%79 = torch.aten.__getitem__.t %57, %66 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %79 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%71 = torch.aten.ge.int %68, %int0 : !torch.int, !torch.int -> !torch.bool
%72 = torch.prim.If %71 -> (!torch.int) {
%79 = torch.aten.__getitem__.t %60, %68 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %79 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%73 = torch.aten.ne.int %70, %72 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.bool) {
%79 = torch.aten.ne.int %70, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %79 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%75 = torch.prim.If %74 -> (!torch.bool) {
%79 = torch.aten.ne.int %72, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %79 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %75 -> () {
%79 = torch.aten.format(%str_4, %70, %72, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%80 = torch.aten.add.str %str_5, %79 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %80, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%76 = torch.aten.eq.int %70, %int1 : !torch.int, !torch.int -> !torch.bool
%77 = torch.prim.If %76 -> (!torch.int) {
torch.prim.If.yield %72 : !torch.int
} else {
torch.prim.If.yield %70 : !torch.int
}
%78 = torch.aten.append.t %63, %77 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %63 : !torch.list<int>
} : !torch.vtensor<[],si64>
%44 = torch.aten.item %43 : !torch.vtensor<[],si64> -> !torch.int
%45 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.__getitem__.t %57, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.append.t %58, %60 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%59 = torch.aten._set_item.t %58, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%46 = torch.aten.item %45 : !torch.vtensor<[1],si64> -> !torch.int
%47 = torch.aten.eq.int %46, %int0 : !torch.int, !torch.int -> !torch.bool
%48 = torch.aten.Int.bool %47 : !torch.bool -> !torch.int
%49 = torch.prim.ListConstruct %34, %44, %46 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%50 = torch.shape.calculate {
%57 = torch.aten.view %18, %49 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?],si64>
} shapes {
%57 = torch.prim.Loop %int2, %true, init(%int1) {
^bb0(%arg6: !torch.int, %arg7: !torch.int):
%66 = torch.aten.size.int %18, %arg6 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%67 = torch.aten.mul.int %arg7, %66 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%67 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%58 = torch.prim.Uninitialized : !torch.int
%59 = torch.derefine %none : !torch.none to !torch.optional<int>
%60:2 = torch.prim.Loop %int3, %true, init(%int1, %59) {
^bb0(%arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.optional<int>):
%66 = torch.aten.__getitem__.t %49, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%67 = torch.aten.eq.int %66, %int-1 : !torch.int, !torch.int -> !torch.bool
%68:2 = torch.prim.If %67 -> (!torch.int, !torch.optional<int>) {
%69 = torch.aten.__isnot__ %arg8, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %69 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%70 = torch.derefine %arg6 : !torch.int to !torch.optional<int>
torch.prim.If.yield %arg7, %70 : !torch.int, !torch.optional<int>
} else {
%69 = torch.aten.__getitem__.t %49, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%70 = torch.aten.ge.int %69, %int0 : !torch.int, !torch.int -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.int) {
%72 = torch.aten.__getitem__.t %49, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%73 = torch.aten.mul.int %arg7, %72 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %58 : !torch.int
}
torch.prim.If.yield %71, %arg8 : !torch.int, !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%68#0, %68#1 : !torch.int, !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.int, !torch.optional<int>) -> (!torch.int, !torch.optional<int>)
%61 = torch.aten.eq.int %57, %60#0 : !torch.int, !torch.int -> !torch.bool
%62 = torch.prim.If %61 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%66 = torch.aten.__isnot__ %60#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.bool) {
%69 = torch.prim.unchecked_cast %60#1 : !torch.optional<int> -> !torch.int
%70 = torch.aten.gt.int %60#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %70 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%68 = torch.prim.If %67 -> (!torch.bool) {
%69 = torch.prim.unchecked_cast %60#1 : !torch.optional<int> -> !torch.int
%70 = torch.aten.remainder.int %57, %60#0 : !torch.int, !torch.int -> !torch.int
%71 = torch.aten.eq.int %70, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %71 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %68 : !torch.bool
}
%63 = torch.aten.__not__ %62 : !torch.bool -> !torch.bool
torch.prim.If %63 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%64 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%66 = torch.aten.__getitem__.t %49, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%67 = torch.aten.append.t %64, %66 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%65 = torch.aten.__isnot__ %60#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %65 -> () {
%66 = torch.prim.unchecked_cast %60#1 : !torch.optional<int> -> !torch.int
%67 = torch.aten.floordiv.int %57, %60#0 : !torch.int, !torch.int -> !torch.int
%68 = torch.aten._set_item.t %64, %66, %67 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %64 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
%51 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %50, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%58 = torch.aten.size.int %50, %arg6 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.append.t %57, %58 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %57, %int3, %int1 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
%52 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%58 = torch.aten.size.int %arg3, %arg6 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.append.t %57, %58 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %57, %int2, %int1 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
%53 = torch.shape.calculate {
%57 = torch.aten.to.dtype %51, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%58 = torch.aten.size.int %51, %arg6 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%59 = torch.aten.append.t %57, %58 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%54 = torch.shape.calculate {
%57 = torch.aten.to.dtype %52, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%58 = torch.aten.size.int %52, %arg6 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.append.t %57, %58 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
%55 = torch.shape.calculate {
%57 = torch.aten.logical_and %53, %54 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%58 = torch.aten.sub.int %int3, %arg6 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.sub.int %int3, %58 : !torch.int, !torch.int -> !torch.int
%60 = torch.aten.sub.int %int3, %58 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.ge.int %59, %int0 : !torch.int, !torch.int -> !torch.bool
%62 = torch.prim.If %61 -> (!torch.int) {
%71 = torch.aten.size.int %53, %59 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
torch.prim.If.yield %71 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%63 = torch.aten.ge.int %60, %int0 : !torch.int, !torch.int -> !torch.bool
%64 = torch.prim.If %63 -> (!torch.int) {
%71 = torch.aten.size.int %54, %60 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
torch.prim.If.yield %71 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%65 = torch.aten.ne.int %62, %64 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.bool) {
%71 = torch.aten.ne.int %62, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %71 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%67 = torch.prim.If %66 -> (!torch.bool) {
%71 = torch.aten.ne.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %71 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %67 -> () {
%71 = torch.aten.format(%str_4, %62, %64, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%72 = torch.aten.add.str %str_5, %71 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %72, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%68 = torch.aten.eq.int %62, %int1 : !torch.int, !torch.int -> !torch.bool
%69 = torch.prim.If %68 -> (!torch.int) {
torch.prim.If.yield %64 : !torch.int
} else {
torch.prim.If.yield %62 : !torch.int
}
%70 = torch.aten.append.t %57, %69 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
%56 = torch.shape.calculate {
%57 = torch.aten.logical_and %55, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%57 = torch.prim.ListConstruct %int1, %int1, %int128, %int384 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%59 = torch.aten.sub.int %int3, %arg6 : !torch.int, !torch.int -> !torch.int
%60 = torch.aten.sub.int %int3, %59 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.sub.int %int3, %59 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.ge.int %60, %int0 : !torch.int, !torch.int -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.int) {
%72 = torch.aten.size.int %55, %60 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
torch.prim.If.yield %72 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%64 = torch.aten.ge.int %61, %int0 : !torch.int, !torch.int -> !torch.bool
%65 = torch.prim.If %64 -> (!torch.int) {
%72 = torch.aten.__getitem__.t %57, %61 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %72 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%66 = torch.aten.ne.int %63, %65 : !torch.int, !torch.int -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.bool) {
%72 = torch.aten.ne.int %63, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %72 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%68 = torch.prim.If %67 -> (!torch.bool) {
%72 = torch.aten.ne.int %65, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %72 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %68 -> () {
%72 = torch.aten.format(%str_4, %63, %65, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%73 = torch.aten.add.str %str_5, %72 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %73, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%69 = torch.aten.eq.int %63, %int1 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
torch.prim.If.yield %65 : !torch.int
} else {
torch.prim.If.yield %63 : !torch.int
}
%71 = torch.aten.append.t %58, %70 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %56 : !torch.vtensor<[?,?,128,384],i1>
}
}
// -----// IR Dump After SimplifyShapeCalculations (torch-simplify-shape-calculations) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%int-2 = torch.constant.int -2
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_3 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %5, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%7 = torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
%8 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%9 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %8, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[],si64> -> !torch.int
%11 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %11, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[],si64> -> !torch.int
%14 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%15 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %14, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%16 = torch.aten.item %15 : !torch.vtensor<[],si64> -> !torch.int
%17 = torch.prim.ListConstruct %10, %16, %7, %13 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%18 = torch.shape.calculate {
%57 = torch.aten.constant_pad_nd %arg1, %17, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?],si64>
} shapes {
%57 = torch.aten.size.int %arg1, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %arg1, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.add.int %10, %16 : !torch.int, !torch.int -> !torch.int
%60 = torch.aten.size.int %arg1, %int-1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%61 = torch.aten.add.int %60, %59 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.add.int %7, %13 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.size.int %arg1, %int-2 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%64 = torch.aten.add.int %63, %62 : !torch.int, !torch.int -> !torch.int
%65 = torch.prim.ListConstruct %64, %61 : (!torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %65 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
%19 = torch.shape.calculate {
%57 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%20 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %19, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%21 = torch.shape.calculate {
%57 = torch.aten.div.Tensor %20, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%22 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %21, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%23 = torch.prim.ListConstruct %3, %22, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%24 = torch.shape.calculate {
%57 = torch.aten.cat %23, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[3],si64>
} shapes {
%57 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[3],si64>
%25 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%26 = torch.aten.item %25 : !torch.vtensor<[1],si64> -> !torch.int
%27 = torch.aten.eq.int %26, %int0 : !torch.int, !torch.int -> !torch.bool
%28 = torch.aten.Int.bool %27 : !torch.bool -> !torch.int
%29 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%30 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %28 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %28 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%31 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %29 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %29 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%32 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %26 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %26 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%33 = torch.shape.calculate {
%57 = torch.aten.where.self %30, %31, %32 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%34 = torch.aten.item %33 : !torch.vtensor<[],si64> -> !torch.int
%35 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%36 = torch.aten.item %35 : !torch.vtensor<[1],si64> -> !torch.int
%37 = torch.aten.eq.int %36, %int0 : !torch.int, !torch.int -> !torch.bool
%38 = torch.aten.Int.bool %37 : !torch.bool -> !torch.int
%39 = torch.aten.size.int %18, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%40 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %38 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %38 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%41 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %39 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %39 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%42 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %36 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %36 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%43 = torch.shape.calculate {
%57 = torch.aten.where.self %40, %41, %42 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%44 = torch.aten.item %43 : !torch.vtensor<[],si64> -> !torch.int
%45 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%46 = torch.aten.item %45 : !torch.vtensor<[1],si64> -> !torch.int
%47 = torch.aten.eq.int %46, %int0 : !torch.int, !torch.int -> !torch.bool
%48 = torch.aten.Int.bool %47 : !torch.bool -> !torch.int
%49 = torch.prim.ListConstruct %34, %44, %46 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%50 = torch.shape.calculate {
%57 = torch.aten.view %18, %49 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?],si64>
} shapes {
%57 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.mul.int %int1, %57 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.size.int %18, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%60 = torch.aten.mul.int %58, %59 : !torch.int, !torch.int -> !torch.int
%61 = torch.prim.Uninitialized : !torch.int
%62 = torch.derefine %none : !torch.none to !torch.optional<int>
%63 = torch.aten.eq.int %34, %int-1 : !torch.int, !torch.int -> !torch.bool
%64:2 = torch.prim.If %63 -> (!torch.int, !torch.optional<int>) {
%74 = torch.derefine %int0 : !torch.int to !torch.optional<int>
torch.prim.If.yield %int1, %74 : !torch.int, !torch.optional<int>
} else {
%74 = torch.aten.ge.int %34, %int0 : !torch.int, !torch.int -> !torch.bool
%75 = torch.prim.If %74 -> (!torch.int) {
%76 = torch.aten.mul.int %int1, %34 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %76 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %61 : !torch.int
}
torch.prim.If.yield %75, %62 : !torch.int, !torch.optional<int>
}
%65 = torch.aten.eq.int %44, %int-1 : !torch.int, !torch.int -> !torch.bool
%66:2 = torch.prim.If %65 -> (!torch.int, !torch.optional<int>) {
%74 = torch.aten.__isnot__ %64#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %74 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%75 = torch.derefine %int1 : !torch.int to !torch.optional<int>
torch.prim.If.yield %64#0, %75 : !torch.int, !torch.optional<int>
} else {
%74 = torch.aten.ge.int %44, %int0 : !torch.int, !torch.int -> !torch.bool
%75 = torch.prim.If %74 -> (!torch.int) {
%76 = torch.aten.mul.int %64#0, %44 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %76 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %61 : !torch.int
}
torch.prim.If.yield %75, %64#1 : !torch.int, !torch.optional<int>
}
%67 = torch.aten.eq.int %46, %int-1 : !torch.int, !torch.int -> !torch.bool
%68:2 = torch.prim.If %67 -> (!torch.int, !torch.optional<int>) {
%74 = torch.aten.__isnot__ %66#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %74 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%75 = torch.derefine %int2 : !torch.int to !torch.optional<int>
torch.prim.If.yield %66#0, %75 : !torch.int, !torch.optional<int>
} else {
%74 = torch.aten.ge.int %46, %int0 : !torch.int, !torch.int -> !torch.bool
%75 = torch.prim.If %74 -> (!torch.int) {
%76 = torch.aten.mul.int %66#0, %46 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %76 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %61 : !torch.int
}
torch.prim.If.yield %75, %66#1 : !torch.int, !torch.optional<int>
}
%69 = torch.aten.eq.int %60, %68#0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%74 = torch.aten.__isnot__ %68#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%75 = torch.prim.If %74 -> (!torch.bool) {
%77 = torch.prim.unchecked_cast %68#1 : !torch.optional<int> -> !torch.int
%78 = torch.aten.gt.int %68#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%76 = torch.prim.If %75 -> (!torch.bool) {
%77 = torch.prim.unchecked_cast %68#1 : !torch.optional<int> -> !torch.int
%78 = torch.aten.remainder.int %60, %68#0 : !torch.int, !torch.int -> !torch.int
%79 = torch.aten.eq.int %78, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %79 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %76 : !torch.bool
}
%71 = torch.aten.__not__ %70 : !torch.bool -> !torch.bool
torch.prim.If %71 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%72 = torch.prim.ListConstruct %34, %44, %46 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%73 = torch.aten.__isnot__ %68#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %73 -> () {
%74 = torch.prim.unchecked_cast %68#1 : !torch.optional<int> -> !torch.int
%75 = torch.aten.floordiv.int %60, %68#0 : !torch.int, !torch.int -> !torch.int
%76 = torch.aten._set_item.t %72, %74, %75 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %72 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
%51 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %50, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%57 = torch.aten.size.int %50, %int0 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %50, %int1 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.size.int %50, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%60 = torch.prim.ListConstruct %57, %58, %59, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
%52 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%57 = torch.aten.size.int %arg3, %int0 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %arg3, %int1 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.size.int %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%60 = torch.prim.ListConstruct %57, %58, %int1, %59 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
%53 = torch.shape.calculate {
%57 = torch.aten.to.dtype %51, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%57 = torch.aten.size.int %51, %int0 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %51, %int1 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%59 = torch.aten.size.int %51, %int2 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%60 = torch.prim.ListConstruct %57, %58, %59, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%54 = torch.shape.calculate {
%57 = torch.aten.to.dtype %52, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%57 = torch.aten.size.int %52, %int0 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %52, %int1 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.size.int %52, %int3 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%60 = torch.prim.ListConstruct %57, %58, %int1, %59 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
%55 = torch.shape.calculate {
%57 = torch.aten.logical_and %53, %54 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%57 = torch.aten.size.int %53, %int0 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%58 = torch.aten.size.int %54, %int0 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%59 = torch.aten.ne.int %57, %58 : !torch.int, !torch.int -> !torch.bool
%60 = torch.prim.If %59 -> (!torch.bool) {
%78 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%61 = torch.prim.If %60 -> (!torch.bool) {
%78 = torch.aten.ne.int %58, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %61 -> () {
%78 = torch.aten.format(%str_2, %57, %58, %int0) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%79 = torch.aten.add.str %str_3, %78 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %79, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%62 = torch.aten.eq.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.int) {
torch.prim.If.yield %58 : !torch.int
} else {
torch.prim.If.yield %57 : !torch.int
}
%64 = torch.aten.size.int %53, %int1 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%65 = torch.aten.size.int %54, %int1 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%66 = torch.aten.ne.int %64, %65 : !torch.int, !torch.int -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.bool) {
%78 = torch.aten.ne.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%68 = torch.prim.If %67 -> (!torch.bool) {
%78 = torch.aten.ne.int %65, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %68 -> () {
%78 = torch.aten.format(%str_2, %64, %65, %int1) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%79 = torch.aten.add.str %str_3, %78 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %79, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%69 = torch.aten.eq.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
torch.prim.If.yield %65 : !torch.int
} else {
torch.prim.If.yield %64 : !torch.int
}
%71 = torch.aten.size.int %53, %int2 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%72 = torch.aten.ne.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %72 -> () {
%78 = torch.aten.ne.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%73 = torch.aten.eq.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %71 : !torch.int
}
%75 = torch.aten.size.int %54, %int3 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%76 = torch.aten.ne.int %int1, %75 : !torch.int, !torch.int -> !torch.bool
%77 = torch.prim.ListConstruct %63, %70, %74, %75 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %77 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
%56 = torch.shape.calculate {
%57 = torch.aten.logical_and %55, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%57 = torch.aten.size.int %55, %int0 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%58 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %58 -> () {
%78 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%59 = torch.aten.eq.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
%60 = torch.prim.If %59 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %57 : !torch.int
}
%61 = torch.aten.size.int %55, %int1 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%62 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %62 -> () {
%78 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%63 = torch.aten.eq.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
%64 = torch.prim.If %63 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %61 : !torch.int
}
%65 = torch.aten.size.int %55, %int2 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%66 = torch.aten.ne.int %65, %int128 : !torch.int, !torch.int -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.bool) {
%78 = torch.aten.ne.int %65, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%68 = torch.prim.If %67 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %68 -> () {
%78 = torch.aten.format(%str_2, %65, %int128, %int2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%79 = torch.aten.add.str %str_3, %78 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %79, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%69 = torch.aten.eq.int %65, %int1 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
torch.prim.If.yield %int128 : !torch.int
} else {
torch.prim.If.yield %65 : !torch.int
}
%71 = torch.aten.size.int %55, %int3 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%72 = torch.aten.ne.int %71, %int384 : !torch.int, !torch.int -> !torch.bool
%73 = torch.prim.If %72 -> (!torch.bool) {
%78 = torch.aten.ne.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%74 = torch.prim.If %73 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %74 -> () {
%78 = torch.aten.format(%str_2, %71, %int384, %int3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%79 = torch.aten.add.str %str_3, %78 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %79, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%75 = torch.aten.eq.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
%76 = torch.prim.If %75 -> (!torch.int) {
torch.prim.If.yield %int384 : !torch.int
} else {
torch.prim.If.yield %71 : !torch.int
}
%77 = torch.prim.ListConstruct %60, %64, %70, %76 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %77 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %56 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After CSE (cse) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%int-2 = torch.constant.int -2
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_3 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %5, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%7 = torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
%8 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%9 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %8, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[],si64> -> !torch.int
%11 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %11, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[],si64> -> !torch.int
%14 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%15 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %14, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%16 = torch.aten.item %15 : !torch.vtensor<[],si64> -> !torch.int
%17 = torch.prim.ListConstruct %10, %16, %7, %13 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%18 = torch.shape.calculate {
%57 = torch.aten.constant_pad_nd %arg1, %17, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?],si64>
} shapes {
%57 = torch.aten.size.int %arg1, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %arg1, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.add.int %10, %16 : !torch.int, !torch.int -> !torch.int
%60 = torch.aten.size.int %arg1, %int-1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%61 = torch.aten.add.int %60, %59 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.add.int %7, %13 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.size.int %arg1, %int-2 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%64 = torch.aten.add.int %63, %62 : !torch.int, !torch.int -> !torch.int
%65 = torch.prim.ListConstruct %64, %61 : (!torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %65 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
%19 = torch.shape.calculate {
%57 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%20 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %19, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%21 = torch.shape.calculate {
%57 = torch.aten.div.Tensor %20, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%22 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %21, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%23 = torch.prim.ListConstruct %3, %22, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%24 = torch.shape.calculate {
%57 = torch.aten.cat %23, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[3],si64>
} shapes {
%57 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[3],si64>
%25 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%26 = torch.aten.item %25 : !torch.vtensor<[1],si64> -> !torch.int
%27 = torch.aten.eq.int %26, %int0 : !torch.int, !torch.int -> !torch.bool
%28 = torch.aten.Int.bool %27 : !torch.bool -> !torch.int
%29 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%30 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %28 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %28 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%31 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %29 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %29 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%32 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %26 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %26 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%33 = torch.shape.calculate {
%57 = torch.aten.where.self %30, %31, %32 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%34 = torch.aten.item %33 : !torch.vtensor<[],si64> -> !torch.int
%35 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%36 = torch.aten.item %35 : !torch.vtensor<[1],si64> -> !torch.int
%37 = torch.aten.eq.int %36, %int0 : !torch.int, !torch.int -> !torch.bool
%38 = torch.aten.Int.bool %37 : !torch.bool -> !torch.int
%39 = torch.aten.size.int %18, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%40 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %38 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %38 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%41 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %39 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %39 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%42 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %36 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %36 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%43 = torch.shape.calculate {
%57 = torch.aten.where.self %40, %41, %42 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%44 = torch.aten.item %43 : !torch.vtensor<[],si64> -> !torch.int
%45 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%46 = torch.aten.item %45 : !torch.vtensor<[1],si64> -> !torch.int
%47 = torch.aten.eq.int %46, %int0 : !torch.int, !torch.int -> !torch.bool
%48 = torch.aten.Int.bool %47 : !torch.bool -> !torch.int
%49 = torch.prim.ListConstruct %34, %44, %46 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%50 = torch.shape.calculate {
%57 = torch.aten.view %18, %49 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?],si64>
} shapes {
%57 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.mul.int %int1, %57 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.size.int %18, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%60 = torch.aten.mul.int %58, %59 : !torch.int, !torch.int -> !torch.int
%61 = torch.prim.Uninitialized : !torch.int
%62 = torch.derefine %none : !torch.none to !torch.optional<int>
%63 = torch.aten.eq.int %34, %int-1 : !torch.int, !torch.int -> !torch.bool
%64:2 = torch.prim.If %63 -> (!torch.int, !torch.optional<int>) {
%73 = torch.derefine %int0 : !torch.int to !torch.optional<int>
torch.prim.If.yield %int1, %73 : !torch.int, !torch.optional<int>
} else {
%73 = torch.aten.ge.int %34, %int0 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.int) {
%75 = torch.aten.mul.int %int1, %34 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %61 : !torch.int
}
torch.prim.If.yield %74, %62 : !torch.int, !torch.optional<int>
}
%65 = torch.aten.eq.int %44, %int-1 : !torch.int, !torch.int -> !torch.bool
%66:2 = torch.prim.If %65 -> (!torch.int, !torch.optional<int>) {
%73 = torch.aten.__isnot__ %64#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %73 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%74 = torch.derefine %int1 : !torch.int to !torch.optional<int>
torch.prim.If.yield %64#0, %74 : !torch.int, !torch.optional<int>
} else {
%73 = torch.aten.ge.int %44, %int0 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.int) {
%75 = torch.aten.mul.int %64#0, %44 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %61 : !torch.int
}
torch.prim.If.yield %74, %64#1 : !torch.int, !torch.optional<int>
}
%67 = torch.aten.eq.int %46, %int-1 : !torch.int, !torch.int -> !torch.bool
%68:2 = torch.prim.If %67 -> (!torch.int, !torch.optional<int>) {
%73 = torch.aten.__isnot__ %66#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %73 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%74 = torch.derefine %int2 : !torch.int to !torch.optional<int>
torch.prim.If.yield %66#0, %74 : !torch.int, !torch.optional<int>
} else {
%73 = torch.aten.ge.int %46, %int0 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.int) {
%75 = torch.aten.mul.int %66#0, %46 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %61 : !torch.int
}
torch.prim.If.yield %74, %66#1 : !torch.int, !torch.optional<int>
}
%69 = torch.aten.eq.int %60, %68#0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%73 = torch.aten.__isnot__ %68#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.bool) {
%76 = torch.prim.unchecked_cast %68#1 : !torch.optional<int> -> !torch.int
%77 = torch.aten.gt.int %68#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %77 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%75 = torch.prim.If %74 -> (!torch.bool) {
%76 = torch.prim.unchecked_cast %68#1 : !torch.optional<int> -> !torch.int
%77 = torch.aten.remainder.int %60, %68#0 : !torch.int, !torch.int -> !torch.int
%78 = torch.aten.eq.int %77, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %75 : !torch.bool
}
%71 = torch.aten.__not__ %70 : !torch.bool -> !torch.bool
torch.prim.If %71 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%72 = torch.aten.__isnot__ %68#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %72 -> () {
%73 = torch.prim.unchecked_cast %68#1 : !torch.optional<int> -> !torch.int
%74 = torch.aten.floordiv.int %60, %68#0 : !torch.int, !torch.int -> !torch.int
%75 = torch.aten._set_item.t %49, %73, %74 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %49 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
%51 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %50, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%57 = torch.aten.size.int %50, %int0 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %50, %int1 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.size.int %50, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%60 = torch.prim.ListConstruct %57, %58, %59, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
%52 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%57 = torch.aten.size.int %arg3, %int0 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %arg3, %int1 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.size.int %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%60 = torch.prim.ListConstruct %57, %58, %int1, %59 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
%53 = torch.shape.calculate {
%57 = torch.aten.to.dtype %51, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%57 = torch.aten.size.int %51, %int0 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %51, %int1 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%59 = torch.aten.size.int %51, %int2 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%60 = torch.prim.ListConstruct %57, %58, %59, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%54 = torch.shape.calculate {
%57 = torch.aten.to.dtype %52, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%57 = torch.aten.size.int %52, %int0 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %52, %int1 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.size.int %52, %int3 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%60 = torch.prim.ListConstruct %57, %58, %int1, %59 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
%55 = torch.shape.calculate {
%57 = torch.aten.logical_and %53, %54 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%57 = torch.aten.size.int %53, %int0 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%58 = torch.aten.size.int %54, %int0 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%59 = torch.aten.ne.int %57, %58 : !torch.int, !torch.int -> !torch.bool
%60 = torch.prim.If %59 -> (!torch.bool) {
%78 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%61 = torch.prim.If %60 -> (!torch.bool) {
%78 = torch.aten.ne.int %58, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %61 -> () {
%78 = torch.aten.format(%str_2, %57, %58, %int0) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%79 = torch.aten.add.str %str_3, %78 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %79, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%62 = torch.aten.eq.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.int) {
torch.prim.If.yield %58 : !torch.int
} else {
torch.prim.If.yield %57 : !torch.int
}
%64 = torch.aten.size.int %53, %int1 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%65 = torch.aten.size.int %54, %int1 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%66 = torch.aten.ne.int %64, %65 : !torch.int, !torch.int -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.bool) {
%78 = torch.aten.ne.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%68 = torch.prim.If %67 -> (!torch.bool) {
%78 = torch.aten.ne.int %65, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %68 -> () {
%78 = torch.aten.format(%str_2, %64, %65, %int1) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%79 = torch.aten.add.str %str_3, %78 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %79, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%69 = torch.aten.eq.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
torch.prim.If.yield %65 : !torch.int
} else {
torch.prim.If.yield %64 : !torch.int
}
%71 = torch.aten.size.int %53, %int2 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%72 = torch.aten.ne.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %72 -> () {
%78 = torch.aten.ne.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%73 = torch.aten.eq.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %71 : !torch.int
}
%75 = torch.aten.size.int %54, %int3 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%76 = torch.aten.ne.int %int1, %75 : !torch.int, !torch.int -> !torch.bool
%77 = torch.prim.ListConstruct %63, %70, %74, %75 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %77 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
%56 = torch.shape.calculate {
%57 = torch.aten.logical_and %55, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%57 = torch.aten.size.int %55, %int0 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%58 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %58 -> () {
%78 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%59 = torch.aten.eq.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
%60 = torch.prim.If %59 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %57 : !torch.int
}
%61 = torch.aten.size.int %55, %int1 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%62 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %62 -> () {
%78 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%63 = torch.aten.eq.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
%64 = torch.prim.If %63 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %61 : !torch.int
}
%65 = torch.aten.size.int %55, %int2 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%66 = torch.aten.ne.int %65, %int128 : !torch.int, !torch.int -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.bool) {
%78 = torch.aten.ne.int %65, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%68 = torch.prim.If %67 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %68 -> () {
%78 = torch.aten.format(%str_2, %65, %int128, %int2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%79 = torch.aten.add.str %str_3, %78 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %79, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%69 = torch.aten.eq.int %65, %int1 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
torch.prim.If.yield %int128 : !torch.int
} else {
torch.prim.If.yield %65 : !torch.int
}
%71 = torch.aten.size.int %55, %int3 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%72 = torch.aten.ne.int %71, %int384 : !torch.int, !torch.int -> !torch.bool
%73 = torch.prim.If %72 -> (!torch.bool) {
%78 = torch.aten.ne.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%74 = torch.prim.If %73 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %74 -> () {
%78 = torch.aten.format(%str_2, %71, %int384, %int3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%79 = torch.aten.add.str %str_3, %78 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %79, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%75 = torch.aten.eq.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
%76 = torch.prim.If %75 -> (!torch.int) {
torch.prim.If.yield %int384 : !torch.int
} else {
torch.prim.If.yield %71 : !torch.int
}
%77 = torch.prim.ListConstruct %60, %64, %70, %76 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %77 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %56 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After SimplifyShapeCalculations (torch-simplify-shape-calculations) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%int-2 = torch.constant.int -2
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_3 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %5, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%7 = torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
%8 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%9 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %8, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[],si64> -> !torch.int
%11 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %11, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[],si64> -> !torch.int
%14 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%15 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %14, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%16 = torch.aten.item %15 : !torch.vtensor<[],si64> -> !torch.int
%17 = torch.prim.ListConstruct %10, %16, %7, %13 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%18 = torch.shape.calculate {
%57 = torch.aten.constant_pad_nd %arg1, %17, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?],si64>
} shapes {
%57 = torch.aten.size.int %arg1, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %arg1, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.add.int %10, %16 : !torch.int, !torch.int -> !torch.int
%60 = torch.aten.size.int %arg1, %int-1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%61 = torch.aten.add.int %60, %59 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.add.int %7, %13 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.size.int %arg1, %int-2 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%64 = torch.aten.add.int %63, %62 : !torch.int, !torch.int -> !torch.int
%65 = torch.prim.ListConstruct %64, %61 : (!torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %65 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
%19 = torch.shape.calculate {
%57 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%20 = torch.shape.calculate {
%57 = torch.aten.squeeze.dim %19, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%21 = torch.shape.calculate {
%57 = torch.aten.div.Tensor %20, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%22 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %21, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%23 = torch.prim.ListConstruct %3, %22, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%24 = torch.shape.calculate {
%57 = torch.aten.cat %23, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[3],si64>
} shapes {
%57 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[3],si64>
%25 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%26 = torch.aten.item %25 : !torch.vtensor<[1],si64> -> !torch.int
%27 = torch.aten.eq.int %26, %int0 : !torch.int, !torch.int -> !torch.bool
%28 = torch.aten.Int.bool %27 : !torch.bool -> !torch.int
%29 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%30 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %28 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %28 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%31 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %29 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %29 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%32 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %26 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %26 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%33 = torch.shape.calculate {
%57 = torch.aten.where.self %30, %31, %32 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%34 = torch.aten.item %33 : !torch.vtensor<[],si64> -> !torch.int
%35 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%36 = torch.aten.item %35 : !torch.vtensor<[1],si64> -> !torch.int
%37 = torch.aten.eq.int %36, %int0 : !torch.int, !torch.int -> !torch.bool
%38 = torch.aten.Int.bool %37 : !torch.bool -> !torch.int
%39 = torch.aten.size.int %18, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%40 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %38 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[],i1>
} shapes {
%57 = torch.aten.Float.Scalar %38 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],i1>
%41 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %39 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %39 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%42 = torch.shape.calculate {
%57 = torch.prim.NumToTensor.Scalar %36 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.aten.Float.Scalar %36 : !torch.int -> !torch.float
%58 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %58 : !torch.list<int>
} : !torch.vtensor<[],si64>
%43 = torch.shape.calculate {
%57 = torch.aten.where.self %40, %41, %42 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[],si64>
} shapes {
%57 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[],si64>
%44 = torch.aten.item %43 : !torch.vtensor<[],si64> -> !torch.int
%45 = torch.shape.calculate {
%57 = torch.aten.slice.Tensor %24, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[1],si64>
} shapes {
%57 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %57 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%46 = torch.aten.item %45 : !torch.vtensor<[1],si64> -> !torch.int
%47 = torch.aten.eq.int %46, %int0 : !torch.int, !torch.int -> !torch.bool
%48 = torch.aten.Int.bool %47 : !torch.bool -> !torch.int
%49 = torch.prim.ListConstruct %34, %44, %46 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%50 = torch.shape.calculate {
%57 = torch.aten.view %18, %49 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?],si64>
} shapes {
%57 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.mul.int %int1, %57 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.size.int %18, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%60 = torch.aten.mul.int %58, %59 : !torch.int, !torch.int -> !torch.int
%61 = torch.prim.Uninitialized : !torch.int
%62 = torch.derefine %none : !torch.none to !torch.optional<int>
%63 = torch.aten.eq.int %34, %int-1 : !torch.int, !torch.int -> !torch.bool
%64:2 = torch.prim.If %63 -> (!torch.int, !torch.optional<int>) {
%73 = torch.derefine %int0 : !torch.int to !torch.optional<int>
torch.prim.If.yield %int1, %73 : !torch.int, !torch.optional<int>
} else {
%73 = torch.aten.ge.int %34, %int0 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.int) {
%75 = torch.aten.mul.int %int1, %34 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %61 : !torch.int
}
torch.prim.If.yield %74, %62 : !torch.int, !torch.optional<int>
}
%65 = torch.aten.eq.int %44, %int-1 : !torch.int, !torch.int -> !torch.bool
%66:2 = torch.prim.If %65 -> (!torch.int, !torch.optional<int>) {
%73 = torch.aten.__isnot__ %64#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %73 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%74 = torch.derefine %int1 : !torch.int to !torch.optional<int>
torch.prim.If.yield %64#0, %74 : !torch.int, !torch.optional<int>
} else {
%73 = torch.aten.ge.int %44, %int0 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.int) {
%75 = torch.aten.mul.int %64#0, %44 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %61 : !torch.int
}
torch.prim.If.yield %74, %64#1 : !torch.int, !torch.optional<int>
}
%67 = torch.aten.eq.int %46, %int-1 : !torch.int, !torch.int -> !torch.bool
%68:2 = torch.prim.If %67 -> (!torch.int, !torch.optional<int>) {
%73 = torch.aten.__isnot__ %66#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %73 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%74 = torch.derefine %int2 : !torch.int to !torch.optional<int>
torch.prim.If.yield %66#0, %74 : !torch.int, !torch.optional<int>
} else {
%73 = torch.aten.ge.int %46, %int0 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.int) {
%75 = torch.aten.mul.int %66#0, %46 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %61 : !torch.int
}
torch.prim.If.yield %74, %66#1 : !torch.int, !torch.optional<int>
}
%69 = torch.aten.eq.int %60, %68#0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%73 = torch.aten.__isnot__ %68#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.bool) {
%76 = torch.prim.unchecked_cast %68#1 : !torch.optional<int> -> !torch.int
%77 = torch.aten.gt.int %68#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %77 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%75 = torch.prim.If %74 -> (!torch.bool) {
%76 = torch.prim.unchecked_cast %68#1 : !torch.optional<int> -> !torch.int
%77 = torch.aten.remainder.int %60, %68#0 : !torch.int, !torch.int -> !torch.int
%78 = torch.aten.eq.int %77, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %75 : !torch.bool
}
%71 = torch.aten.__not__ %70 : !torch.bool -> !torch.bool
torch.prim.If %71 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%72 = torch.aten.__isnot__ %68#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %72 -> () {
%73 = torch.prim.unchecked_cast %68#1 : !torch.optional<int> -> !torch.int
%74 = torch.aten.floordiv.int %60, %68#0 : !torch.int, !torch.int -> !torch.int
%75 = torch.aten._set_item.t %49, %73, %74 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %49 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
%51 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %50, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%57 = torch.aten.size.int %50, %int0 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %50, %int1 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.size.int %50, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%60 = torch.prim.ListConstruct %57, %58, %59, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
%52 = torch.shape.calculate {
%57 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%57 = torch.aten.size.int %arg3, %int0 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %arg3, %int1 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.size.int %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%60 = torch.prim.ListConstruct %57, %58, %int1, %59 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
%53 = torch.shape.calculate {
%57 = torch.aten.to.dtype %51, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%57 = torch.aten.size.int %51, %int0 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %51, %int1 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%59 = torch.aten.size.int %51, %int2 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%60 = torch.prim.ListConstruct %57, %58, %59, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%54 = torch.shape.calculate {
%57 = torch.aten.to.dtype %52, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%57 = torch.aten.size.int %52, %int0 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%58 = torch.aten.size.int %52, %int1 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%59 = torch.aten.size.int %52, %int3 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%60 = torch.prim.ListConstruct %57, %58, %int1, %59 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
%55 = torch.shape.calculate {
%57 = torch.aten.logical_and %53, %54 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%57 = torch.aten.size.int %53, %int0 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%58 = torch.aten.size.int %54, %int0 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%59 = torch.aten.ne.int %57, %58 : !torch.int, !torch.int -> !torch.bool
%60 = torch.prim.If %59 -> (!torch.bool) {
%78 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%61 = torch.prim.If %60 -> (!torch.bool) {
%78 = torch.aten.ne.int %58, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %61 -> () {
%78 = torch.aten.format(%str_2, %57, %58, %int0) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%79 = torch.aten.add.str %str_3, %78 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %79, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%62 = torch.aten.eq.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.int) {
torch.prim.If.yield %58 : !torch.int
} else {
torch.prim.If.yield %57 : !torch.int
}
%64 = torch.aten.size.int %53, %int1 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%65 = torch.aten.size.int %54, %int1 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%66 = torch.aten.ne.int %64, %65 : !torch.int, !torch.int -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.bool) {
%78 = torch.aten.ne.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%68 = torch.prim.If %67 -> (!torch.bool) {
%78 = torch.aten.ne.int %65, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %68 -> () {
%78 = torch.aten.format(%str_2, %64, %65, %int1) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%79 = torch.aten.add.str %str_3, %78 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %79, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%69 = torch.aten.eq.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
torch.prim.If.yield %65 : !torch.int
} else {
torch.prim.If.yield %64 : !torch.int
}
%71 = torch.aten.size.int %53, %int2 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%72 = torch.aten.ne.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %72 -> () {
%78 = torch.aten.ne.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%73 = torch.aten.eq.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %71 : !torch.int
}
%75 = torch.aten.size.int %54, %int3 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%76 = torch.aten.ne.int %int1, %75 : !torch.int, !torch.int -> !torch.bool
%77 = torch.prim.ListConstruct %63, %70, %74, %75 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %77 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
%56 = torch.shape.calculate {
%57 = torch.aten.logical_and %55, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %57 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%57 = torch.aten.size.int %55, %int0 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%58 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %58 -> () {
%78 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%59 = torch.aten.eq.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
%60 = torch.prim.If %59 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %57 : !torch.int
}
%61 = torch.aten.size.int %55, %int1 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%62 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %62 -> () {
%78 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%63 = torch.aten.eq.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
%64 = torch.prim.If %63 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %61 : !torch.int
}
%65 = torch.aten.size.int %55, %int2 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%66 = torch.aten.ne.int %65, %int128 : !torch.int, !torch.int -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.bool) {
%78 = torch.aten.ne.int %65, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%68 = torch.prim.If %67 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %68 -> () {
%78 = torch.aten.format(%str_2, %65, %int128, %int2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%79 = torch.aten.add.str %str_3, %78 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %79, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%69 = torch.aten.eq.int %65, %int1 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
torch.prim.If.yield %int128 : !torch.int
} else {
torch.prim.If.yield %65 : !torch.int
}
%71 = torch.aten.size.int %55, %int3 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%72 = torch.aten.ne.int %71, %int384 : !torch.int, !torch.int -> !torch.bool
%73 = torch.prim.If %72 -> (!torch.bool) {
%78 = torch.aten.ne.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %78 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%74 = torch.prim.If %73 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %74 -> () {
%78 = torch.aten.format(%str_2, %71, %int384, %int3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%79 = torch.aten.add.str %str_3, %78 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %79, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%75 = torch.aten.eq.int %71, %int1 : !torch.int, !torch.int -> !torch.bool
%76 = torch.prim.If %75 -> (!torch.int) {
torch.prim.If.yield %int384 : !torch.int
} else {
torch.prim.If.yield %71 : !torch.int
}
%77 = torch.prim.ListConstruct %60, %64, %70, %76 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %77 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %56 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After DropAbstractInterpCalculations (torch-drop-abstract-interp-calculations) //----- //
// @main_graph as printed after torch-drop-abstract-interp-calculations. Unlike
// the previous dump, no torch.shape.calculate regions remain; each op is inline.
//
// Data flow visible in this body:
//   1. %arg1 is padded with constant 0 using the four scalars of %arg4 (%18).
//   2. %18 is viewed as 3-D with target sizes taken from [0, %arg5[1] / %4, 8];
//      a 0 in the first two entries is replaced by the matching size of %18.
//   3. The 3-D tensor is unsqueezed to [?,?,?,1] and %arg3 to [?,?,1,?]; both
//      are cast to i1 and combined with a broadcasting logical_and (%55).
//   4. %55 is logical_and-ed with the constant %1 : [1,1,128,384] (%56).
// %arg0 and %arg2 are not referenced in this body.
//
// NOTE(review): %1 is dense<false>, so %56 is all-false for any inputs --
// confirm this constant is expected.
// NOTE(review): the divisor %4 is dense<0>, so %21 divides by zero; also
// aten.div.Tensor is true division in PyTorch yet the result is typed si64.
// Confirm whether this comes from the ONNX conversion / folding upstream.
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
// Constants. %int-2, %int-1, %str..%str_3, %int384, %int128 and %true are not
// referenced below; they appear to be left over from the dropped shape
// calculations.
%int-2 = torch.constant.int -2
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_3 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
// Extract %arg4[0..3] as scalars via slice + squeeze + item:
// %7 = %arg4[0], %10 = %arg4[1], %13 = %arg4[2], %16 = %arg4[3].
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.squeeze.dim %5, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%7 = torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
%8 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%9 = torch.aten.squeeze.dim %8, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.squeeze.dim %11, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[],si64> -> !torch.int
%14 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%15 = torch.aten.squeeze.dim %14, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%16 = torch.aten.item %15 : !torch.vtensor<[],si64> -> !torch.int
// Pad list [%arg4[1], %arg4[3], %arg4[0], %arg4[2]]; presumably reorders ONNX
// [begin0, begin1, end0, end1] pads into torch's (low, high) pairs starting
// from the last dim -- TODO confirm. Padding value is 0.
%17 = torch.prim.ListConstruct %10, %16, %7, %13 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%18 = torch.aten.constant_pad_nd %arg1, %17, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
// %22 = unsqueeze(%arg5[1] / %4). %4 is the literal 0 (see NOTE(review) above).
%19 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%20 = torch.aten.squeeze.dim %19, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%21 = torch.aten.div.Tensor %20, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%22 = torch.aten.unsqueeze %21, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
// Target-shape tensor %24 = [0, %22, 8].
%23 = torch.prim.ListConstruct %3, %22, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%24 = torch.aten.cat %23, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
// Dim 0: %34 = size(%18, 0) if %24[0] == 0, else %24[0] (selected with where.self).
%25 = torch.aten.slice.Tensor %24, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%26 = torch.aten.item %25 : !torch.vtensor<[1],si64> -> !torch.int
%27 = torch.aten.eq.int %26, %int0 : !torch.int, !torch.int -> !torch.bool
%28 = torch.aten.Int.bool %27 : !torch.bool -> !torch.int
%29 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%30 = torch.prim.NumToTensor.Scalar %28 : !torch.int -> !torch.vtensor<[],i1>
%31 = torch.prim.NumToTensor.Scalar %29 : !torch.int -> !torch.vtensor<[],si64>
%32 = torch.prim.NumToTensor.Scalar %26 : !torch.int -> !torch.vtensor<[],si64>
%33 = torch.aten.where.self %30, %31, %32 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%34 = torch.aten.item %33 : !torch.vtensor<[],si64> -> !torch.int
// Dim 1: same rule with %24[1] and size(%18, 1) -> %44.
%35 = torch.aten.slice.Tensor %24, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%36 = torch.aten.item %35 : !torch.vtensor<[1],si64> -> !torch.int
%37 = torch.aten.eq.int %36, %int0 : !torch.int, !torch.int -> !torch.bool
%38 = torch.aten.Int.bool %37 : !torch.bool -> !torch.int
%39 = torch.aten.size.int %18, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%40 = torch.prim.NumToTensor.Scalar %38 : !torch.int -> !torch.vtensor<[],i1>
%41 = torch.prim.NumToTensor.Scalar %39 : !torch.int -> !torch.vtensor<[],si64>
%42 = torch.prim.NumToTensor.Scalar %36 : !torch.int -> !torch.vtensor<[],si64>
%43 = torch.aten.where.self %40, %41, %42 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%44 = torch.aten.item %43 : !torch.vtensor<[],si64> -> !torch.int
// Dim 2: %24[2] is used as-is; %47/%48 are computed but %48 is never used.
%45 = torch.aten.slice.Tensor %24, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%46 = torch.aten.item %45 : !torch.vtensor<[1],si64> -> !torch.int
%47 = torch.aten.eq.int %46, %int0 : !torch.int, !torch.int -> !torch.bool
%48 = torch.aten.Int.bool %47 : !torch.bool -> !torch.int
// View the padded tensor as [%34, %44, %46].
%49 = torch.prim.ListConstruct %34, %44, %46 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%50 = torch.aten.view %18, %49 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
// Unsqueeze the view to [?,?,?,1] and %arg3 to [?,?,1,?], then cast both to i1
// (dtype code 11; the result types are i1).
%51 = torch.aten.unsqueeze %50, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%52 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%53 = torch.aten.to.dtype %51, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%54 = torch.aten.to.dtype %52, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
// Broadcasting AND -> [?,?,?,?], then AND with the constant [1,1,128,384] mask %1.
%55 = torch.aten.logical_and %53, %54 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%56 = torch.aten.logical_and %55, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %56 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After ScalarizeShapes (torch-scalarize-shapes) //----- //
// @main_graph as printed after torch-scalarize-shapes. The computation is the
// same as in the previous dump, with one visible difference: aten.item is
// applied directly to each 1-element slice of %arg4, so the preceding
// squeeze.dim ops are gone. SSA values are renumbered accordingly.
//
// Data flow visible in this body:
//   1. %arg1 is padded with constant 0 using the four scalars of %arg4 (%14).
//   2. %14 is viewed as 3-D with target sizes taken from [0, %arg5[1] / %4, 8];
//      a 0 in the first two entries is replaced by the matching size of %14.
//   3. The 3-D tensor is unsqueezed to [?,?,?,1] and %arg3 to [?,?,1,?]; both
//      are cast to i1 and combined with a broadcasting logical_and (%51).
//   4. %51 is logical_and-ed with the constant %1 : [1,1,128,384] (%52).
// %arg0 and %arg2 are not referenced in this body.
//
// NOTE(review): %1 is dense<false>, so %52 is all-false for any inputs --
// confirm this constant is expected.
// NOTE(review): the divisor %4 is dense<0>, so %17 divides by zero; also
// aten.div.Tensor is true division in PyTorch yet the result is typed si64.
// Confirm whether this comes from the ONNX conversion / folding upstream.
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
// Constants. %int-2, %int-1, %str..%str_3, %int384, %int128 and %true are not
// referenced below; they appear to be left over from earlier shape calculations.
%int-2 = torch.constant.int -2
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_3 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
// Extract %arg4[0..3] as scalars with item directly on each 1-element slice:
// %6 = %arg4[0], %8 = %arg4[1], %10 = %arg4[2], %12 = %arg4[3].
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
// Pad list [%arg4[1], %arg4[3], %arg4[0], %arg4[2]]; presumably reorders ONNX
// [begin0, begin1, end0, end1] pads into torch's (low, high) pairs starting
// from the last dim -- TODO confirm. Padding value is 0.
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
// %18 = unsqueeze(%arg5[1] / %4). %4 is the literal 0 (see NOTE(review) above).
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
// Target-shape tensor %20 = [0, %18, 8].
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
// Dim 0: %30 = size(%14, 0) if %20[0] == 0, else %20[0] (selected with where.self).
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
// Dim 1: same rule with %20[1] and size(%14, 1) -> %40.
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
// Dim 2: %20[2] is used as-is; %43/%44 are computed but %44 is never used.
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
// View the padded tensor as [%30, %40, %42].
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
// Unsqueeze the view to [?,?,?,1] and %arg3 to [?,?,1,?], then cast both to i1
// (dtype code 11; the result types are i1).
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
// Broadcasting AND -> [?,?,?,?], then AND with the constant [1,1,128,384] mask %1.
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After ReifyShapeCalculations (torch-reify-shape-calculations) //----- //
module {
// Shape function for padding.
//   %arg0: sizes of the input tensor.
//   %arg1: flat pad list.
// Returns the padded sizes.
// Raises if len(%arg1) is odd, or if len(%arg1) / 2 > len(%arg0).
// For i in [0, len(%arg1) / 2): %arg0[-(i + 1)] += %arg1[2i] + %arg1[2i + 1].
// That is, consecutive pad pairs apply starting from the last dimension.
// NOTE: %arg0 is updated in place via _set_item.t, and that same list is returned.
func.func private @__torch__.pad_shape_fn(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%true = torch.constant.bool true
%str = torch.constant.str "AssertionError: Number of padded dimensions must be less than or equal to the input dimension"
%none = torch.constant.none
%str_0 = torch.constant.str "AssertionError: Must have paired low-high pad amount values"
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// Check: len(pad) % 2 == 0.
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%1 = torch.aten.remainder.int %0, %int2 : !torch.int, !torch.int -> !torch.int
%2 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %2 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Check: len(pad) / 2 <= len(input sizes).
%3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%4 = torch.aten.floordiv.int %3, %int2 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.aten.le.int %4, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %6 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Loop over pad pairs. Inside: %10 = -(i + 1) indexes dims from the end, and
// %16 = pad[2i] + pad[2i + 1] is added to that dim.
%7 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop %8, %true, init() {
^bb0(%arg2: !torch.int):
%9 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int
%10 = torch.aten.neg.int %9 : !torch.int -> !torch.int
%11 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int
%15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.add.int %17, %16 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten._set_item.t %arg0, %10, %18 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %arg0 : !torch.list<int>
}
// Shape function for aten.constant_pad_nd. It forwards the input sizes (%arg0)
// and the pad list (%arg1) to @__torch__.pad_shape_fn. The fill value %arg2 does
// not affect the result shape and is unused.
func.func private @__torch_mlir_shape_fn.aten.constant_pad_nd(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {
%0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// Shape function for index_select(self, dim, index).
//   %arg0: sizes of self.
//   %arg1: dim.
//   %arg2: sizes of index.
// Returns a new list equal to self's sizes, with position `dim` replaced by the
// element count of index.
// Raises ("AssertionError: ") if any of these holds:
//   - dim is outside [-rank, rank - 1], where a rank <= 0 is treated as 1;
//   - index has more than one dimension;
//   - the wrapped dim is neither 0 nor < len(self).
func.func private @__torch__.torch.jit._shape_functions.index_select(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
// Effective rank %2: len(self), or 1 when len(self) <= 0.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.prim.If %1 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %0 : !torch.int
}
// Range check: -%2 <= dim <= %2 - 1.
%3 = torch.aten.neg.int %2 : !torch.int -> !torch.int
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.lt.int %arg1, %3 : !torch.int, !torch.int -> !torch.bool
%6 = torch.prim.If %5 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%18 = torch.aten.gt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %18 : !torch.bool
}
%7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool
torch.prim.If %7 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %9 = wrapped dim (a negative dim is shifted by %2).
%8 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%9 = torch.prim.If %8 -> (!torch.int) {
%18 = torch.aten.add.int %arg1, %2 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %18 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// %11 = product of index sizes (1 when index sizes are empty).
%10 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%11 = torch.prim.Loop %10, %true, init(%int1) {
^bb0(%arg3: !torch.int, %arg4: !torch.int):
%18 = torch.aten.__getitem__.t %arg2, %arg3 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %arg4, %18 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%19 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// index must be 0-D or 1-D.
%12 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%13 = torch.aten.le.int %12, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %13 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// The wrapped dim must be 0 or < len(self).
%14 = torch.aten.eq.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%18 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%19 = torch.aten.lt.int %9, %18 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %19 : !torch.bool
}
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Build the result: copy self's sizes, substituting %11 at position %9.
%16 = torch.prim.ListConstruct : () -> !torch.list<int>
%17 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %17, %true, init() {
^bb0(%arg3: !torch.int):
%18 = torch.aten.eq.int %9, %arg3 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %18 -> () {
%19 = torch.aten.append.t %16, %11 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
%19 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %16 : !torch.list<int>
}
// Shape function for aten.index_select. It forwards (self sizes, dim, index
// sizes) unchanged to @__torch__.torch.jit._shape_functions.index_select.
func.func private @__torch_mlir_shape_fn.aten.index_select(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// Shape function for squeeze(self, dim).
//   %arg0: sizes of self.
//   %arg1: dim.
// Returns a new list equal to %arg0, except that position `dim` is dropped when
// its size is 1. If that size is not 1, all sizes are kept.
// Raises ("AssertionError: ") if dim is outside [-rank, rank - 1], where a
// rank <= 0 is treated as 1.
func.func private @__torch__.torch.jit._shape_functions.squeeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
// Effective rank %3: len(self), or 1 when len(self) <= 0.
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %1 : !torch.int
}
// Range check: -%3 <= dim <= %3 - 1.
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%12 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %12 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %10 = wrapped dim (a negative dim is shifted by %3).
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%12 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %12 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// Copy sizes into %0. At position %10, the size is appended only if it is != 1.
%11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %11, %true, init() {
^bb0(%arg2: !torch.int):
%12 = torch.aten.eq.int %arg2, %10 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %12 -> () {
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %14 -> () {
%15 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.append.t %0, %15 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.append.t %0, %13 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %0 : !torch.list<int>
}
// Shape function for aten.squeeze.dim. It forwards (self sizes, dim) unchanged
// to @__torch__.torch.jit._shape_functions.squeeze.
func.func private @__torch_mlir_shape_fn.aten.squeeze.dim(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.squeeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
// Shape function for aten.div.Tensor. The result shape is
// broadcast(%arg0, %arg1). The broadcast helper it calls is defined outside
// this view; it is presumably PyTorch's standard broadcasting shape rule --
// confirm against its definition.
func.func private @__torch_mlir_shape_fn.aten.div.Tensor(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// Shape function for concatenation (lowered form of torch.jit._shape_functions.cat).
// %arg0: shapes of the tensors being concatenated; %arg1: concatenation dim (may be negative).
// Returns the output shape. Any input whose shape is exactly [0] (rank 1, size 0) is skipped,
// both when wrapping the dim and when checking/summing sizes.
func.func private @__torch__.torch.jit._shape_functions.cat(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension"
%str_0 = torch.constant.str "AssertionError: Tensors must have same number of dimensions"
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%true = torch.constant.bool true
%none = torch.constant.none
%str_1 = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
// Every input shape must have rank > 0.
%0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
torch.prim.Loop %0, %true, init() {
^bb0(%arg2: !torch.int):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%15 = torch.aten.gt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Legacy dim wrapping. %arg1 is wrapped against the rank of the FIRST input that is not [0].
// That rank is clamped to at least 1, and %arg1 must lie in [-rank, rank-1].
// The loop-carried optional stays None until such an input is found.
%1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%2 = torch.derefine %none : !torch.none to !torch.optional<int>
%3 = torch.prim.Loop %1, %true, init(%2) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<int>):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
// %16 is true when this shape is exactly [0] (rank 1 and size 0).
%15 = torch.aten.eq.int %14, %int1 : !torch.int, !torch.int -> !torch.bool
%16 = torch.prim.If %15 -> (!torch.bool) {
%19 = torch.aten.__getitem__.t %13, %int0 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%17 = torch.aten.__not__ %16 : !torch.bool -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.optional<int>) {
%19 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%20 = torch.prim.If %19 -> (!torch.int) {
%22 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%23 = torch.aten.le.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.prim.If %23 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %22 : !torch.int
}
%25 = torch.aten.neg.int %24 : !torch.int -> !torch.int
%26 = torch.aten.sub.int %24, %int1 : !torch.int, !torch.int -> !torch.int
%27 = torch.aten.lt.int %arg1, %25 : !torch.int, !torch.int -> !torch.bool
%28 = torch.prim.If %27 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%32 = torch.aten.gt.int %arg1, %26 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %32 : !torch.bool
}
%29 = torch.aten.__not__ %28 : !torch.bool -> !torch.bool
torch.prim.If %29 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%30 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%31 = torch.prim.If %30 -> (!torch.int) {
%32 = torch.aten.add.int %arg1, %24 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %32 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
torch.prim.If.yield %31 : !torch.int
} else {
%22 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %22 : !torch.int
}
%21 = torch.derefine %20 : !torch.int to !torch.optional<int>
torch.prim.If.yield %21 : !torch.optional<int>
} else {
torch.prim.If.yield %arg3 : !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%18 : !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>
// %5 is the wrapped dim, or %arg1 unchanged when every input was [0].
%4 = torch.aten.__is__ %3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%5 = torch.prim.If %4 -> (!torch.int) {
torch.prim.If.yield %arg1 : !torch.int
} else {
%13 = torch.prim.unchecked_cast %3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %13 : !torch.int
}
// At least one input shape is required.
%6 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %7 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Choose the reference shape %10: the LAST input that is not skipped.
// An input is skipped when its element count is 0 and its rank is 1.
%8 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%9 = torch.derefine %none : !torch.none to !torch.optional<list<int>>
%10 = torch.prim.Loop %8, %true, init(%9) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<list<int>>):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
// %15 = product of all sizes (element count) of this input.
%15 = torch.prim.Loop %14, %true, init(%int1) {
^bb0(%arg4: !torch.int, %arg5: !torch.int):
%20 = torch.aten.__getitem__.t %13, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%21 = torch.aten.mul.int %arg5, %20 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%21 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%16 = torch.aten.eq.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.bool) {
%20 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%21 = torch.aten.eq.int %20, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %21 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%18 = torch.aten.__not__ %17 : !torch.bool -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.optional<list<int>>) {
%20 = torch.derefine %13 : !torch.list<int> to !torch.optional<list<int>>
torch.prim.If.yield %20 : !torch.optional<list<int>>
} else {
torch.prim.If.yield %arg3 : !torch.optional<list<int>>
}
torch.prim.Loop.condition %true, iter(%19 : !torch.optional<list<int>>)
} : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>
// If every input was skipped, the result shape is [0].
%11 = torch.aten.__is__ %10, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%12 = torch.prim.If %11 -> (!torch.list<int>) {
%13 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
torch.prim.If.yield %13 : !torch.list<int>
} else {
%13 = torch.prim.unchecked_cast %10 : !torch.optional<list<int>> -> !torch.list<int>
// %15 = sum of sizes along dim %5 over the non-skipped inputs.
// Each such input must have the reference rank and match the reference in every other dim.
%14 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%15 = torch.prim.Loop %14, %true, init(%int0) {
^bb0(%arg2: !torch.int, %arg3: !torch.int):
%19 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%20 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%21 = torch.prim.Loop %20, %true, init(%int1) {
^bb0(%arg4: !torch.int, %arg5: !torch.int):
%26 = torch.aten.__getitem__.t %19, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%27 = torch.aten.mul.int %arg5, %26 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%27 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%22 = torch.aten.eq.int %21, %int0 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.bool) {
%26 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%27 = torch.aten.eq.int %26, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %27 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%24 = torch.aten.__not__ %23 : !torch.bool -> !torch.bool
%25 = torch.prim.If %24 -> (!torch.int) {
%26 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%27 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%28 = torch.aten.eq.int %26, %27 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %28 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%29 = torch.aten.__range_length %int0, %26, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %29, %true, init() {
^bb0(%arg4: !torch.int):
%32 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%33 = torch.aten.ne.int %32, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %33 -> () {
%34 = torch.aten.__getitem__.t %13, %32 : !torch.list<int>, !torch.int -> !torch.int
%35 = torch.aten.__getitem__.t %19, %32 : !torch.list<int>, !torch.int -> !torch.int
%36 = torch.aten.eq.int %34, %35 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %36 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%30 = torch.aten.__getitem__.t %19, %5 : !torch.list<int>, !torch.int -> !torch.int
%31 = torch.aten.add.int %arg3, %30 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %31 : !torch.int
} else {
torch.prim.If.yield %arg3 : !torch.int
}
torch.prim.Loop.condition %true, iter(%25 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// Result = copy of the reference shape with dim %5 set to the summed size.
%16 = torch.prim.ListConstruct : () -> !torch.list<int>
%17 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
torch.prim.Loop %17, %true, init() {
^bb0(%arg2: !torch.int):
%19 = torch.aten.__getitem__.t %13, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%18 = torch.aten._set_item.t %16, %5, %15 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield %16 : !torch.list<int>
}
return %12 : !torch.list<int>
}
func.func private @__torch_mlir_shape_fn.aten.cat(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
  // Shape of aten.cat: forward the list of input shapes (%arg0) and the concat dim (%arg1)
  // to the TorchScript `cat` shape function.
  %concat_shape = func.call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>
  func.return %concat_shape : !torch.list<int>
}
func.func private @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%arg0: !torch.float) -> !torch.list<int> {
  // The result shape is always the empty list (a 0-d tensor); %arg0 is not consulted.
  %empty_shape = torch.prim.ListConstruct : () -> !torch.list<int>
  func.return %empty_shape : !torch.list<int>
}
func.func private @__torch_mlir_shape_fn.aten.where.self(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
  // Broadcast the shapes of %arg1 and %arg2 first.
  // Then broadcast %arg0's shape against that intermediate result.
  %operands_shape = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
  %where_shape = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %operands_shape) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
  func.return %where_shape : !torch.list<int>
}
// Shape function for slicing (lowered form of torch.jit._shape_functions.slice).
// %arg0: input shape; %arg1: dim to slice (may be negative);
// %arg2: optional start (defaults to 0); %arg3: optional end (defaults to 9223372036854775807);
// %arg4: step, which must be > 0.
// Returns a copy of %arg0 whose sliced dim is replaced by the number of selected elements.
func.func private @__torch__.torch.jit._shape_functions.slice(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {
%int9223372036854775807 = torch.constant.int 9223372036854775807
%true = torch.constant.bool true
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// The input must not be 0-d.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %1 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Wrap the dim. Rank %3 is clamped to at least 1, and %arg1 must lie in [-%3, %3-1].
// %10 is the resulting non-negative dim index.
%2 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %0 : !torch.int
}
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%33 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %33 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%33 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// Resolve the optional bounds: %12 = start (default 0), %14 = end (default 9223372036854775807).
%11 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool
%12 = torch.prim.If %11 -> (!torch.int) {
%33 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int0 : !torch.int
}
%13 = torch.aten.__isnot__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%14 = torch.prim.If %13 -> (!torch.int) {
%33 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int9223372036854775807 : !torch.int
}
// The step must be strictly positive.
%15 = torch.aten.gt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// A start equal to 9223372036854775807 is treated as 0.
%16 = torch.aten.eq.int %12, %int9223372036854775807 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
torch.prim.If.yield %12 : !torch.int
}
// Negative start/end are offset by the size of the sliced dim.
%18 = torch.aten.lt.int %17, %int0 : !torch.int, !torch.int -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.int) {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.add.int %17, %33 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %34 : !torch.int
} else {
torch.prim.If.yield %17 : !torch.int
}
%20 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
%21 = torch.prim.If %20 -> (!torch.int) {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.add.int %14, %33 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %34 : !torch.int
} else {
torch.prim.If.yield %14 : !torch.int
}
// Clamp start into [0, size] (%23) and end into [start, size] (%25).
// Here size = %arg0[%10].
%22 = torch.aten.lt.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.gt.int %19, %33 : !torch.int, !torch.int -> !torch.bool
%35 = torch.prim.If %34 -> (!torch.int) {
%36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %36 : !torch.int
} else {
torch.prim.If.yield %19 : !torch.int
}
torch.prim.If.yield %35 : !torch.int
}
%24 = torch.aten.lt.int %21, %23 : !torch.int, !torch.int -> !torch.bool
%25 = torch.prim.If %24 -> (!torch.int) {
torch.prim.If.yield %23 : !torch.int
} else {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.ge.int %21, %33 : !torch.int, !torch.int -> !torch.bool
%35 = torch.prim.If %34 -> (!torch.int) {
%36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %36 : !torch.int
} else {
torch.prim.If.yield %21 : !torch.int
}
torch.prim.If.yield %35 : !torch.int
}
// Copy the input shape into %27.
%26 = torch.aten.sub.int %25, %23 : !torch.int, !torch.int -> !torch.int
%27 = torch.prim.ListConstruct : () -> !torch.list<int>
%28 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %28, %true, init() {
^bb0(%arg5: !torch.int):
%33 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.append.t %27, %33 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// New size = (end - start + step - 1) floordiv step, i.e. ceil((end - start) / step) for step > 0.
%29 = torch.aten.add.int %26, %arg4 : !torch.int, !torch.int -> !torch.int
%30 = torch.aten.sub.int %29, %int1 : !torch.int, !torch.int -> !torch.int
%31 = torch.aten.floordiv.int %30, %arg4 : !torch.int, !torch.int -> !torch.int
%32 = torch.aten._set_item.t %27, %10, %31 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
return %27 : !torch.list<int>
}
func.func private @__torch_mlir_shape_fn.aten.slice.Tensor(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {
  // Shape of aten.slice.Tensor: pass shape, dim, optional start, optional end and step
  // unchanged to the TorchScript `slice` shape function.
  %sliced_shape = func.call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
  func.return %sliced_shape : !torch.list<int>
}
// Shape function for reshaping via view (lowered form of torch.jit._shape_functions.view).
// %arg0: input shape; %arg1: requested shape, in which at most one entry may be -1 (inferred).
// Returns a copy of %arg1. If there is a -1 entry, it is replaced by
// numel(%arg0) floordiv (product of the other entries).
func.func private @__torch__.torch.jit._shape_functions.view(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: invalid shape"
%false = torch.constant.bool false
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: only one dimension can be inferred"
%int-1 = torch.constant.int -1
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%true = torch.constant.bool true
// %1 = element count of the input (product of %arg0).
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.prim.Loop %0, %true, init(%int1) {
^bb0(%arg2: !torch.int, %arg3: !torch.int):
%12 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.mul.int %arg3, %12 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%13 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// Placeholder value yielded only on the branch that raises "invalid shape dimensions".
%2 = torch.prim.Uninitialized : !torch.int
// Scan %arg1. %5#0 = product of the explicit (non -1) sizes.
// %5#1 = index of the -1 entry, or None if there is none.
// A second -1 raises; any other negative size raises.
%3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%4 = torch.derefine %none : !torch.none to !torch.optional<int>
%5:2 = torch.prim.Loop %3, %true, init(%int1, %4) {
^bb0(%arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.optional<int>):
%12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.eq.int %12, %int-1 : !torch.int, !torch.int -> !torch.bool
%14:2 = torch.prim.If %13 -> (!torch.int, !torch.optional<int>) {
%15 = torch.aten.__isnot__ %arg4, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%16 = torch.derefine %arg2 : !torch.int to !torch.optional<int>
torch.prim.If.yield %arg3, %16 : !torch.int, !torch.optional<int>
} else {
%15 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.ge.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
%18 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %arg3, %18 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %19 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %2 : !torch.int
}
torch.prim.If.yield %17, %arg4 : !torch.int, !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%14#0, %14#1 : !torch.int, !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.int, !torch.optional<int>) -> (!torch.int, !torch.optional<int>)
// The shape is valid if either condition holds:
//   (a) the element counts match;
//   (b) a dim is inferred, the explicit product is > 0, and it divides the input element count.
// Otherwise "invalid shape" is raised.
// NOTE: the unchecked_cast results (%15) in the two branches below are unused.
%6 = torch.aten.eq.int %1, %5#0 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%12 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.bool) {
%15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%16 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %16 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%14 = torch.prim.If %13 -> (!torch.bool) {
%15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%16 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %17 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %14 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// Copy the requested shape into %9.
%9 = torch.prim.ListConstruct : () -> !torch.list<int>
%10 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
torch.prim.Loop %10, %true, init() {
^bb0(%arg2: !torch.int):
%12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.append.t %9, %12 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// If a -1 was present, overwrite it with numel floordiv the explicit product.
%11 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %11 -> () {
%12 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%13 = torch.aten.floordiv.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten._set_item.t %9, %12, %13 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
return %9 : !torch.list<int>
}
func.func private @__torch_mlir_shape_fn.aten.view(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
  // Shape of aten.view: resolve the requested shape %arg1 against the input shape %arg0
  // using the TorchScript `view` shape function.
  %viewed_shape = func.call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
  func.return %viewed_shape : !torch.list<int>
}
// Shape function for unsqueeze (lowered form of torch.jit._shape_functions.unsqueeze).
// %arg0: input shape; %arg1: position of the new size-1 dim.
// The dim is wrapped against rank + 1 (clamped to at least 1), so %arg1 must lie in [-(rank+1), rank].
// Returns a copy of %arg0 with a 1 inserted at the wrapped position.
func.func private @__torch__.torch.jit._shape_functions.unsqueeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// %3 = max(rank + 1, 1) in effect: the bound used for wrapping %arg1.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.add.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %1 : !torch.int
}
// Assert %arg1 lies in [-%3, %3-1].
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%13 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %13 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %10 = non-negative insertion index.
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%13 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// Copy the input shape, then insert a 1 at %10.
%11 = torch.prim.ListConstruct : () -> !torch.list<int>
%12 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %12, %true, init() {
^bb0(%arg2: !torch.int):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.append.t %11, %13 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %11, %10, %int1 : !torch.list<int>, !torch.int, !torch.int
return %11 : !torch.list<int>
}
func.func private @__torch_mlir_shape_fn.aten.unsqueeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
  // Shape of aten.unsqueeze: delegate to the TorchScript `unsqueeze` shape function.
  %unsqueezed_shape = func.call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>
  func.return %unsqueezed_shape : !torch.list<int>
}
func.func private @__torch__.torch.jit._shape_functions.unary(%arg0: !torch.list<int>) -> !torch.list<int> {
  // Returns a freshly built list that copies every entry of the input shape %arg0, in order.
  %keep_going = torch.constant.bool true
  %shape_copy = torch.prim.ListConstruct : () -> !torch.list<int>
  %rank = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
  torch.prim.Loop %rank, %keep_going, init() {
  ^copy_dim(%dim_idx: !torch.int):
    %dim_size = torch.aten.__getitem__.t %arg0, %dim_idx : !torch.list<int>, !torch.int -> !torch.int
    %appended = torch.aten.append.t %shape_copy, %dim_size : !torch.list<int>, !torch.int -> !torch.list<int>
    torch.prim.Loop.condition %keep_going, iter()
  } : (!torch.int, !torch.bool) -> ()
  func.return %shape_copy : !torch.list<int>
}
func.func private @__torch_mlir_shape_fn.aten.to.dtype(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
  // Shape of aten.to.dtype: a copy of the input shape %arg0.
  // Arguments %arg1..%arg4 do not influence the result.
  %same_shape = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
  func.return %same_shape : !torch.list<int>
}
// Right-aligned broadcasting of two shapes (lowered form of torch.jit._shape_functions.broadcast).
// %arg0, %arg1: input shapes. The result has rank max(len(%arg0), len(%arg1)).
// Dims missing on the left of the shorter shape are treated as size 1.
// Raises when two aligned sizes differ and neither of them is 1.
func.func private @__torch__.torch.jit._shape_functions.broadcast(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%false = torch.constant.bool false
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
// For each output dim %arg2 in [0, %2), %5 is its offset from the right.
// %7 and %9 are the aligned indices into %arg0 and %arg1; a negative index means "size 1".
torch.prim.Loop %2, %true, init() {
^bb0(%arg2: !torch.int):
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.sub.int %4, %arg2 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%7 = torch.aten.sub.int %6, %5 : !torch.int, !torch.int -> !torch.int
%8 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int
%9 = torch.aten.sub.int %8, %5 : !torch.int, !torch.int -> !torch.int
%10 = torch.aten.ge.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
%11 = torch.prim.If %10 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg0, %7 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%12 = torch.aten.ge.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
// %16 is true when the sizes differ and neither is 1.
%14 = torch.aten.ne.int %11, %13 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
%20 = torch.aten.ne.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%16 = torch.prim.If %15 -> (!torch.bool) {
%20 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
// The error message reports both sizes and the output dim index %arg2.
torch.prim.If %16 -> () {
%20 = torch.aten.format(%str_0, %11, %13, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%21 = torch.aten.add.str %str, %20 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %21, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// The output size is %arg1's size when %arg0's size is 1; otherwise it is %arg0's size.
%17 = torch.aten.eq.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.int) {
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %11 : !torch.int
}
%19 = torch.aten.append.t %3, %18 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %3 : !torch.list<int>
}
func.func private @__torch_mlir_shape_fn.aten.logical_and(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
  // Shape of aten.logical_and: the broadcast of the two operand shapes.
  %broadcast_shape = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
  func.return %broadcast_shape : !torch.list<int>
}
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%int-2 = torch.constant.int -2
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_3 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.aten.size %arg4 : !torch.vtensor<[4],si64> -> !torch.list<int>
%54 = torch.derefine %int0 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.aten.size %arg4 : !torch.vtensor<[4],si64> -> !torch.list<int>
%54 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.aten.size %arg4 : !torch.vtensor<[4],si64> -> !torch.list<int>
%54 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int3 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.aten.size %arg4 : !torch.vtensor<[4],si64> -> !torch.list<int>
%54 = torch.derefine %int3 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int4 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.shape.calculate {
%53 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?],si64>
} shapes {
%53 = torch.aten.size %arg1 : !torch.vtensor<[?,?],si64> -> !torch.list<int>
%54 = torch.aten.Float.Scalar %int0 : !torch.int -> !torch.float
%55 = func.call @__torch_mlir_shape_fn.aten.constant_pad_nd(%53, %13, %54) : (!torch.list<int>, !torch.list<int>, !torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
%15 = torch.shape.calculate {
%53 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.aten.size %arg5 : !torch.vtensor<[2],si64> -> !torch.list<int>
%54 = torch.aten.size %0 : !torch.vtensor<[1],si64> -> !torch.list<int>
%55 = func.call @__torch_mlir_shape_fn.aten.index_select(%53, %int0, %54) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%16 = torch.shape.calculate {
%53 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.size %15 : !torch.vtensor<[1],si64> -> !torch.list<int>
%54 = func.call @__torch_mlir_shape_fn.aten.squeeze.dim(%53, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%17 = torch.shape.calculate {
%53 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.size %16 : !torch.vtensor<[],si64> -> !torch.list<int>
%54 = torch.aten.size %4 : !torch.vtensor<[],si64> -> !torch.list<int>
%55 = func.call @__torch_mlir_shape_fn.aten.div.Tensor(%53, %54) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[],si64>
%18 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.aten.size %17 : !torch.vtensor<[],si64> -> !torch.list<int>
%54 = func.call @__torch_mlir_shape_fn.aten.unsqueeze(%53, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.shape.calculate {
%53 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[3],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<list<int>>
%54 = torch.aten.len.t %19 : !torch.list<vtensor> -> !torch.int
%true_4 = torch.constant.bool true
torch.prim.Loop %54, %true_4, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %19, %arg6 : !torch.list<vtensor>, !torch.int -> !torch.vtensor
%57 = torch.aten.size %56 : !torch.vtensor -> !torch.list<int>
%58 = torch.aten.append.t %53, %57 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>
torch.prim.Loop.condition %true_4, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = func.call @__torch_mlir_shape_fn.aten.cat(%53, %int0) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[3],si64>
%21 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.aten.size %20 : !torch.vtensor<[3],si64> -> !torch.list<int>
%54 = torch.derefine %int0 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %24 : !torch.int -> !torch.float
%54 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%53) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%27 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %25 : !torch.int -> !torch.float
%54 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%53) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%28 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %22 : !torch.int -> !torch.float
%54 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%53) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%29 = torch.shape.calculate {
%53 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.size %26 : !torch.vtensor<[],i1> -> !torch.list<int>
%54 = torch.aten.size %27 : !torch.vtensor<[],si64> -> !torch.list<int>
%55 = torch.aten.size %28 : !torch.vtensor<[],si64> -> !torch.list<int>
%56 = func.call @__torch_mlir_shape_fn.aten.where.self(%53, %54, %55) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.aten.size %20 : !torch.vtensor<[3],si64> -> !torch.list<int>
%54 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %34 : !torch.int -> !torch.float
%54 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%53) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%37 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %35 : !torch.int -> !torch.float
%54 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%53) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%38 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %32 : !torch.int -> !torch.float
%54 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%53) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%39 = torch.shape.calculate {
%53 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.size %36 : !torch.vtensor<[],i1> -> !torch.list<int>
%54 = torch.aten.size %37 : !torch.vtensor<[],si64> -> !torch.list<int>
%55 = torch.aten.size %38 : !torch.vtensor<[],si64> -> !torch.list<int>
%56 = func.call @__torch_mlir_shape_fn.aten.where.self(%53, %54, %55) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.aten.size %20 : !torch.vtensor<[3],si64> -> !torch.list<int>
%54 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int3 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.shape.calculate {
%53 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?],si64>
} shapes {
%53 = torch.aten.size %14 : !torch.vtensor<[?,?],si64> -> !torch.list<int>
%54 = func.call @__torch_mlir_shape_fn.aten.view(%53, %45) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
%47 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%53 = torch.aten.size %46 : !torch.vtensor<[?,?,?],si64> -> !torch.list<int>
%54 = func.call @__torch_mlir_shape_fn.aten.unsqueeze(%53, %int3) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
%48 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%53 = torch.aten.size %arg3 : !torch.vtensor<[?,?,?],si64> -> !torch.list<int>
%54 = func.call @__torch_mlir_shape_fn.aten.unsqueeze(%53, %int2) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
%49 = torch.shape.calculate {
%53 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%53 = torch.aten.size %47 : !torch.vtensor<[?,?,?,1],si64> -> !torch.list<int>
%54 = torch.derefine %none : !torch.none to !torch.optional<int>
%55 = func.call @__torch_mlir_shape_fn.aten.to.dtype(%53, %int11, %false, %false, %54) : (!torch.list<int>, !torch.int, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%50 = torch.shape.calculate {
%53 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%53 = torch.aten.size %48 : !torch.vtensor<[?,?,1,?],si64> -> !torch.list<int>
%54 = torch.derefine %none : !torch.none to !torch.optional<int>
%55 = func.call @__torch_mlir_shape_fn.aten.to.dtype(%53, %int11, %false, %false, %54) : (!torch.list<int>, !torch.int, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
%51 = torch.shape.calculate {
%53 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%53 = torch.aten.size %49 : !torch.vtensor<[?,?,?,1],i1> -> !torch.list<int>
%54 = torch.aten.size %50 : !torch.vtensor<[?,?,1,?],i1> -> !torch.list<int>
%55 = func.call @__torch_mlir_shape_fn.aten.logical_and(%53, %54) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
%52 = torch.shape.calculate {
%53 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%53 = torch.aten.size %51 : !torch.vtensor<[?,?,?,?],i1> -> !torch.list<int>
%54 = torch.aten.size %1 : !torch.vtensor<[1,1,128,384],i1> -> !torch.list<int>
%55 = func.call @__torch_mlir_shape_fn.aten.logical_and(%53, %54) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for padding ops.
// Given the input shape (%arg0) and a flat list of pad amounts (%arg1),
// it returns the padded shape.
// NOTE: %arg0 is updated in place with torch.aten._set_item.t and then
// returned, so the result is the same list object as the input.
// Raises (torch.prim.RaiseException) in two cases:
//   - len(%arg1) is odd;
//   - len(%arg1) // 2 exceeds len(%arg0).
func.func private @__torch__.pad_shape_fn(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%true = torch.constant.bool true
%str = torch.constant.str "AssertionError: Number of padded dimensions must be less than or equal to the input dimension"
%none = torch.constant.none
%str_0 = torch.constant.str "AssertionError: Must have paired low-high pad amount values"
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// Check 1: len(%arg1) % 2 == 0, otherwise raise %str_0.
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%1 = torch.aten.remainder.int %0, %int2 : !torch.int, !torch.int -> !torch.int
%2 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %2 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Check 2: len(%arg1) // 2 <= len(%arg0), otherwise raise %str.
%3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%4 = torch.aten.floordiv.int %3, %int2 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.aten.le.int %4, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %6 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// For i in range(len(%arg1) // 2):
//   %arg0[-(i+1)] += %arg1[2*i] + %arg1[2*i+1]
// Pad pair i is applied to dim index -(i+1). With Python-style negative
// indexing (presumed), the first pair pads the last dimension.
%7 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop %8, %true, init() {
^bb0(%arg2: !torch.int):
%9 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int
%10 = torch.aten.neg.int %9 : !torch.int -> !torch.int
%11 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int
%15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.add.int %17, %16 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten._set_item.t %arg0, %10, %18 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Return the input shape list, updated in place above.
return %arg0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.constant_pad_nd.
// It forwards the input shape (%arg0) and pad list (%arg1) to
// @__torch__.pad_shape_fn and returns its result.
// The float fill value %arg2 is not used by this shape computation.
func.func private @__torch_mlir_shape_fn.aten.constant_pad_nd(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {
%0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.constant_pad_nd. This dump contains the padding
// shape logic directly rather than a call to @__torch__.pad_shape_fn.
//   %arg0: input shape (updated in place and returned).
//   %arg1: flat pad-amount list, consumed in pairs.
//   %arg2: fill value (not used here).
// Raises (torch.prim.RaiseException) in two cases:
//   - len(%arg1) is odd;
//   - len(%arg1) // 2 exceeds len(%arg0).
func.func private @__torch_mlir_shape_fn.aten.constant_pad_nd(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {
%true = torch.constant.bool true
%str = torch.constant.str "AssertionError: Number of padded dimensions must be less than or equal to the input dimension"
%none = torch.constant.none
%str_0 = torch.constant.str "AssertionError: Must have paired low-high pad amount values"
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// Check 1: len(%arg1) % 2 == 0, otherwise raise %str_0.
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%1 = torch.aten.remainder.int %0, %int2 : !torch.int, !torch.int -> !torch.int
%2 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %2 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Check 2: len(%arg1) // 2 <= len(%arg0), otherwise raise %str.
%3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%4 = torch.aten.floordiv.int %3, %int2 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.aten.le.int %4, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %6 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// For i in range(len(%arg1) // 2):
//   %arg0[-(i+1)] += %arg1[2*i] + %arg1[2*i+1]
// With Python-style negative indexing (presumed), the first pair pads the
// last dimension.
%7 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop %8, %true, init() {
^bb0(%arg3: !torch.int):
%9 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int
%10 = torch.aten.neg.int %9 : !torch.int -> !torch.int
%11 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int
%15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.add.int %17, %16 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten._set_item.t %arg0, %10, %18 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Return the input shape list, updated in place above.
return %arg0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for index_select.
//   %arg0: input shape.
//   %arg1: dim.
//   %arg2: index shape.
// Returns a new list equal to %arg0, except that the entry at the wrapped
// dim is replaced by the product of %arg2's entries.
// Raises "AssertionError: " in three cases:
//   - dim is out of range;
//   - len(%arg2) > 1;
//   - the wrapped dim is neither 0 nor < len(%arg0).
func.func private @__torch__.torch.jit._shape_functions.index_select(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
// %2 = len(%arg0) if it is > 0, else 1; this is the rank used for dim
// wrapping.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.prim.If %1 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %0 : !torch.int
}
// Require -%2 <= %arg1 <= %2 - 1, otherwise raise.
%3 = torch.aten.neg.int %2 : !torch.int -> !torch.int
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.lt.int %arg1, %3 : !torch.int, !torch.int -> !torch.bool
%6 = torch.prim.If %5 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%18 = torch.aten.gt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %18 : !torch.bool
}
%7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool
torch.prim.If %7 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %9 = non-negative dim: %arg1 + %2 when %arg1 < 0, else %arg1.
%8 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%9 = torch.prim.If %8 -> (!torch.int) {
%18 = torch.aten.add.int %arg1, %2 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %18 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// %11 = product of all entries of %arg2 (1 when %arg2 is empty).
%10 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%11 = torch.prim.Loop %10, %true, init(%int1) {
^bb0(%arg3: !torch.int, %arg4: !torch.int):
%18 = torch.aten.__getitem__.t %arg2, %arg3 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %arg4, %18 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%19 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// Require len(%arg2) <= 1, otherwise raise.
%12 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%13 = torch.aten.le.int %12, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %13 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Require %9 == 0 or %9 < len(%arg0), otherwise raise.
%14 = torch.aten.eq.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%18 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%19 = torch.aten.lt.int %9, %18 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %19 : !torch.bool
}
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Build the result in a fresh list: copy %arg0, substituting %11 at
// position %9.
%16 = torch.prim.ListConstruct : () -> !torch.list<int>
%17 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %17, %true, init() {
^bb0(%arg3: !torch.int):
%18 = torch.aten.eq.int %9, %arg3 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %18 -> () {
%19 = torch.aten.append.t %16, %11 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
%19 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %16 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.index_select.
// It forwards input shape %arg0, dim %arg1 and index shape %arg2 to
// @__torch__.torch.jit._shape_functions.index_select and returns its result.
func.func private @__torch_mlir_shape_fn.aten.index_select(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.index_select. This dump contains the
// index_select shape logic directly rather than a call to
// @__torch__.torch.jit._shape_functions.index_select.
//   %arg0: input shape.
//   %arg1: dim.
//   %arg2: index shape.
// Returns a new list equal to %arg0, except that the entry at the wrapped
// dim is replaced by the product of %arg2's entries.
// Raises "AssertionError: " in three cases:
//   - dim is out of range;
//   - len(%arg2) > 1;
//   - the wrapped dim is neither 0 nor < len(%arg0).
func.func private @__torch_mlir_shape_fn.aten.index_select(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
// %2 = len(%arg0) if it is > 0, else 1; this is the rank used for dim
// wrapping.
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.prim.If %1 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %0 : !torch.int
}
// Require -%2 <= %arg1 <= %2 - 1, otherwise raise.
%3 = torch.aten.neg.int %2 : !torch.int -> !torch.int
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.lt.int %arg1, %3 : !torch.int, !torch.int -> !torch.bool
%6 = torch.prim.If %5 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%18 = torch.aten.gt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %18 : !torch.bool
}
%7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool
torch.prim.If %7 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %9 = non-negative dim: %arg1 + %2 when %arg1 < 0, else %arg1.
%8 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%9 = torch.prim.If %8 -> (!torch.int) {
%18 = torch.aten.add.int %arg1, %2 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %18 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// %11 = product of all entries of %arg2 (1 when %arg2 is empty).
%10 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%11 = torch.prim.Loop %10, %true, init(%int1) {
^bb0(%arg3: !torch.int, %arg4: !torch.int):
%18 = torch.aten.__getitem__.t %arg2, %arg3 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %arg4, %18 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%19 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// Require len(%arg2) <= 1, otherwise raise.
%12 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%13 = torch.aten.le.int %12, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %13 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Require %9 == 0 or %9 < len(%arg0), otherwise raise.
%14 = torch.aten.eq.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%18 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%19 = torch.aten.lt.int %9, %18 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %19 : !torch.bool
}
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Build the result in a fresh list: copy %arg0, substituting %11 at
// position %9.
%16 = torch.prim.ListConstruct : () -> !torch.list<int>
%17 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %17, %true, init() {
^bb0(%arg3: !torch.int):
%18 = torch.aten.eq.int %9, %arg3 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %18 -> () {
%19 = torch.aten.append.t %16, %11 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
%19 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %16 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for squeezing a single dim.
//   %arg0: input shape.
//   %arg1: dim.
// Returns a new list containing every entry of %arg0. The entry at the
// wrapped dim is dropped when it equals 1 and kept otherwise.
// Raises "AssertionError: " if %arg1 lies outside [-r, r - 1], where
// r = len(%arg0) if it is > 0, else 1.
func.func private @__torch__.torch.jit._shape_functions.squeeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%true = torch.constant.bool true
%int1 = torch.constant.int 1
// %0 is the (initially empty) result list.
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
// %3 = len(%arg0) if it is > 0, else 1; this is the rank used for dim
// wrapping.
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %1 : !torch.int
}
// Require -%3 <= %arg1 <= %3 - 1, otherwise raise.
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%12 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %12 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %10 = non-negative dim: %arg1 + %3 when %arg1 < 0, else %arg1.
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%12 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %12 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// Copy %arg0 into %0, skipping the entry at %10 only if it equals 1.
%11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %11, %true, init() {
^bb0(%arg2: !torch.int):
%12 = torch.aten.eq.int %arg2, %10 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %12 -> () {
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %14 -> () {
%15 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.append.t %0, %15 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.append.t %0, %13 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.squeeze.dim.
// It forwards input shape %arg0 and dim %arg1 to
// @__torch__.torch.jit._shape_functions.squeeze and returns its result.
func.func private @__torch_mlir_shape_fn.aten.squeeze.dim(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.squeeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.squeeze.dim. This dump contains the squeeze
// shape logic directly rather than a call to
// @__torch__.torch.jit._shape_functions.squeeze.
//   %arg0: input shape.
//   %arg1: dim.
// Returns a new list containing every entry of %arg0. The entry at the
// wrapped dim is dropped when it equals 1 and kept otherwise.
// Raises "AssertionError: " if %arg1 lies outside [-r, r - 1], where
// r = len(%arg0) if it is > 0, else 1.
func.func private @__torch_mlir_shape_fn.aten.squeeze.dim(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%true = torch.constant.bool true
%int1 = torch.constant.int 1
// %0 is the (initially empty) result list.
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
// %3 = len(%arg0) if it is > 0, else 1; this is the rank used for dim
// wrapping.
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %1 : !torch.int
}
// Require -%3 <= %arg1 <= %3 - 1, otherwise raise.
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%12 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %12 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// %10 = non-negative dim: %arg1 + %3 when %arg1 < 0, else %arg1.
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%12 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %12 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
// Copy %arg0 into %0, skipping the entry at %10 only if it equals 1.
%11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %11, %true, init() {
^bb0(%arg2: !torch.int):
%12 = torch.aten.eq.int %arg2, %10 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %12 -> () {
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %14 -> () {
%15 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.append.t %0, %15 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.append.t %0, %13 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function: broadcast result shape of the two dimension-size lists
// %arg0 and %arg1.
// Dimensions are aligned from the right, and a missing leading dimension counts
// as size 1, so the result has rank max(len(%arg0), len(%arg1)).
// At each result position the two sizes must be equal or one of them must be 1.
// Otherwise torch.prim.RaiseException is raised with the formatted
// "AssertionError: The size of tensor a ..." message.
// The result size at each position is the operand size that is not 1, or the
// common size.
// NOTE(review): the dimension index passed to the error message is the loop
// index %arg2, i.e. counted from the left of the result shape.
func.func private @__torch__.torch.jit._shape_functions.broadcast(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%false = torch.constant.bool false
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
// %2: result rank = max of the two input ranks.
%2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int
// %3: result shape, filled by appending one size per loop iteration.
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
// Runs %2 iterations; the loop condition is always %true.
torch.prim.Loop %2, %true, init() {
^bb0(%arg2: !torch.int):
// %5: distance of this position from the right end of the result (rank - 1 - i).
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.sub.int %4, %arg2 : !torch.int, !torch.int -> !torch.int
// %7 / %9: matching indices into %arg0 / %arg1.
// They are negative when that operand has no dimension at this position.
%6 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%7 = torch.aten.sub.int %6, %5 : !torch.int, !torch.int -> !torch.int
%8 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int
%9 = torch.aten.sub.int %8, %5 : !torch.int, !torch.int -> !torch.int
// %11: size taken from %arg0, or 1 if %arg0 has no dimension here.
%10 = torch.aten.ge.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
%11 = torch.prim.If %10 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg0, %7 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
// %13: size taken from %arg1, or 1 if %arg1 has no dimension here.
%12 = torch.aten.ge.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
// %15 and %16 form a short-circuit "and": %11 != %13 && %11 != 1 && %13 != 1.
%14 = torch.aten.ne.int %11, %13 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
%20 = torch.aten.ne.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%16 = torch.prim.If %15 -> (!torch.bool) {
%20 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
// Incompatible sizes: raise with both sizes and the position %arg2.
torch.prim.If %16 -> () {
%20 = torch.aten.format(%str_0, %11, %13, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%21 = torch.aten.add.str %str, %20 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %21, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// Result size: %13 when %11 is 1, otherwise %11.
%17 = torch.aten.eq.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.int) {
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %11 : !torch.int
}
%19 = torch.aten.append.t %3, %18 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %3 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.div.Tensor in its call form.
// The result shape is the broadcast of the two operand shapes %arg0 and %arg1,
// delegated to @__torch__.torch.jit._shape_functions.broadcast.
func.func private @__torch_mlir_shape_fn.aten.div.Tensor(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.div.Tensor, in the form where the broadcast helper's
// ops appear directly in the body instead of a call.
// The op sequence is the same as in @__torch__.torch.jit._shape_functions.broadcast.
// The result is the broadcast of %arg0 and %arg1: dimensions are aligned from
// the right, missing dimensions count as 1, and incompatible sizes raise.
func.func private @__torch_mlir_shape_fn.aten.div.Tensor(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%false = torch.constant.bool false
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
// %2: result rank = max of the two input ranks; %3: result shape being built.
%2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %2, %true, init() {
^bb0(%arg2: !torch.int):
// %5: distance from the right end of the result.
// %7 / %9: indices into %arg0 / %arg1, negative when absent.
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.sub.int %4, %arg2 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%7 = torch.aten.sub.int %6, %5 : !torch.int, !torch.int -> !torch.int
%8 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int
%9 = torch.aten.sub.int %8, %5 : !torch.int, !torch.int -> !torch.int
// %11 / %13: operand sizes at this position, defaulting to 1 when absent.
%10 = torch.aten.ge.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
%11 = torch.prim.If %10 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg0, %7 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%12 = torch.aten.ge.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
// Short-circuit "and": sizes differ and neither is 1.
%14 = torch.aten.ne.int %11, %13 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
%20 = torch.aten.ne.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%16 = torch.prim.If %15 -> (!torch.bool) {
%20 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
// Incompatible sizes: raise with both sizes and the loop position %arg2.
torch.prim.If %16 -> () {
%20 = torch.aten.format(%str_0, %11, %13, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%21 = torch.aten.add.str %str, %20 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %21, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// Result size: %13 when %11 is 1, otherwise %11.
%17 = torch.aten.eq.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.int) {
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %11 : !torch.int
}
%19 = torch.aten.append.t %3, %18 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %3 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for concatenation.
// Inputs: %arg0 is the list of input shapes; %arg1 is the concatenation dimension.
// Output: the result shape.
// Steps:
//   1. Every input shape must have rank > 0.
//   2. Wrap %arg1 against the rank of the first input whose shape is not exactly
//      [0]. Negative dims get the rank added; out-of-range dims raise.
//      If every input is [0], %arg1 is used unchanged.
//   3. Select the last input that is not skippable (rank 1 with element count 0).
//      If there is none, the result is [0].
//   4. Every non-skippable input must match that reference shape in rank and in
//      every dimension except the concat dim. The result is the reference shape
//      with the concat dim set to the sum of those inputs' sizes along it.
func.func private @__torch__.torch.jit._shape_functions.cat(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension"
%str_0 = torch.constant.str "AssertionError: Tensors must have same number of dimensions"
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%true = torch.constant.bool true
%none = torch.constant.none
%str_1 = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
// Step 1: raise if any input shape has rank 0.
%0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
torch.prim.Loop %0, %true, init() {
^bb0(%arg2: !torch.int):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%15 = torch.aten.gt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Step 2: carry an optional wrapped dim (%arg3, initially None) through the inputs.
%1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%2 = torch.derefine %none : !torch.none to !torch.optional<int>
%3 = torch.prim.Loop %1, %true, init(%2) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<int>):
// %16: true when this input shape is exactly [0].
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%15 = torch.aten.eq.int %14, %int1 : !torch.int, !torch.int -> !torch.bool
%16 = torch.prim.If %15 -> (!torch.bool) {
%19 = torch.aten.__getitem__.t %13, %int0 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%17 = torch.aten.__not__ %16 : !torch.bool -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.optional<int>) {
// Only the first qualifying input computes the wrapped dim; later ones keep %arg3.
%19 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%20 = torch.prim.If %19 -> (!torch.int) {
// %24: rank used for wrapping; a rank <= 0 is treated as 1.
%22 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%23 = torch.aten.le.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.prim.If %23 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %22 : !torch.int
}
// Raise unless -%24 <= %arg1 <= %24 - 1.
%25 = torch.aten.neg.int %24 : !torch.int -> !torch.int
%26 = torch.aten.sub.int %24, %int1 : !torch.int, !torch.int -> !torch.int
%27 = torch.aten.lt.int %arg1, %25 : !torch.int, !torch.int -> !torch.bool
%28 = torch.prim.If %27 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%32 = torch.aten.gt.int %arg1, %26 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %32 : !torch.bool
}
%29 = torch.aten.__not__ %28 : !torch.bool -> !torch.bool
torch.prim.If %29 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Negative dims are wrapped by adding the rank.
%30 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%31 = torch.prim.If %30 -> (!torch.int) {
%32 = torch.aten.add.int %arg1, %24 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %32 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
torch.prim.If.yield %31 : !torch.int
} else {
%22 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %22 : !torch.int
}
%21 = torch.derefine %20 : !torch.int to !torch.optional<int>
torch.prim.If.yield %21 : !torch.optional<int>
} else {
torch.prim.If.yield %arg3 : !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%18 : !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>
// %5: the wrapped concat dim, or %arg1 unchanged if no input qualified.
%4 = torch.aten.__is__ %3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%5 = torch.prim.If %4 -> (!torch.int) {
torch.prim.If.yield %arg1 : !torch.int
} else {
%13 = torch.prim.unchecked_cast %3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %13 : !torch.int
}
// The list of input shapes itself must be non-empty.
%6 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %7 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Step 3: %10 = last input that is not skippable (element count 0 and rank 1).
// The loop has no early exit, so later matches overwrite earlier ones.
%8 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%9 = torch.derefine %none : !torch.none to !torch.optional<list<int>>
%10 = torch.prim.Loop %8, %true, init(%9) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<list<int>>):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
// %15: element count of this input (product of its dims, starting from 1).
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%15 = torch.prim.Loop %14, %true, init(%int1) {
^bb0(%arg4: !torch.int, %arg5: !torch.int):
%20 = torch.aten.__getitem__.t %13, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%21 = torch.aten.mul.int %arg5, %20 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%21 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%16 = torch.aten.eq.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.bool) {
%20 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%21 = torch.aten.eq.int %20, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %21 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%18 = torch.aten.__not__ %17 : !torch.bool -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.optional<list<int>>) {
%20 = torch.derefine %13 : !torch.list<int> to !torch.optional<list<int>>
torch.prim.If.yield %20 : !torch.optional<list<int>>
} else {
torch.prim.If.yield %arg3 : !torch.optional<list<int>>
}
torch.prim.Loop.condition %true, iter(%19 : !torch.optional<list<int>>)
} : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>
// If every input was skippable, the result shape is [0].
%11 = torch.aten.__is__ %10, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%12 = torch.prim.If %11 -> (!torch.list<int>) {
%13 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
torch.prim.If.yield %13 : !torch.list<int>
} else {
// Step 4: %13 is the reference shape.
// %15 accumulates the concat-dim size over the non-skippable inputs.
%13 = torch.prim.unchecked_cast %10 : !torch.optional<list<int>> -> !torch.list<int>
%14 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%15 = torch.prim.Loop %14, %true, init(%int0) {
^bb0(%arg2: !torch.int, %arg3: !torch.int):
%19 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
// Recompute the element count to apply the same skip rule as step 3.
%20 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%21 = torch.prim.Loop %20, %true, init(%int1) {
^bb0(%arg4: !torch.int, %arg5: !torch.int):
%26 = torch.aten.__getitem__.t %19, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%27 = torch.aten.mul.int %arg5, %26 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%27 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%22 = torch.aten.eq.int %21, %int0 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.bool) {
%26 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%27 = torch.aten.eq.int %26, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %27 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%24 = torch.aten.__not__ %23 : !torch.bool -> !torch.bool
%25 = torch.prim.If %24 -> (!torch.int) {
// Ranks must agree with the reference shape.
%26 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%27 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%28 = torch.aten.eq.int %26, %27 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %28 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// for %32 in range(0, rank): every dim other than %5 must match the reference.
%29 = torch.aten.__range_length %int0, %26, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %29, %true, init() {
^bb0(%arg4: !torch.int):
%32 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%33 = torch.aten.ne.int %32, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %33 -> () {
%34 = torch.aten.__getitem__.t %13, %32 : !torch.list<int>, !torch.int -> !torch.int
%35 = torch.aten.__getitem__.t %19, %32 : !torch.list<int>, !torch.int -> !torch.int
%36 = torch.aten.eq.int %34, %35 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %36 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Add this input's size along the concat dim.
%30 = torch.aten.__getitem__.t %19, %5 : !torch.list<int>, !torch.int -> !torch.int
%31 = torch.aten.add.int %arg3, %30 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %31 : !torch.int
} else {
torch.prim.If.yield %arg3 : !torch.int
}
torch.prim.Loop.condition %true, iter(%25 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// Result: a copy of the reference shape with dim %5 replaced by %15.
%16 = torch.prim.ListConstruct : () -> !torch.list<int>
%17 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
torch.prim.Loop %17, %true, init() {
^bb0(%arg2: !torch.int):
%19 = torch.aten.__getitem__.t %13, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%18 = torch.aten._set_item.t %16, %5, %15 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield %16 : !torch.list<int>
}
return %12 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.cat in its call form.
// It forwards the input shapes %arg0 and the concat dim %arg1 to
// @__torch__.torch.jit._shape_functions.cat and returns its result.
func.func private @__torch_mlir_shape_fn.aten.cat(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.cat, in the form where the cat helper's ops appear
// directly in the body instead of a call.
// The op sequence is the same as in @__torch__.torch.jit._shape_functions.cat:
//   1. Reject rank-0 inputs.
//   2. Wrap the concat dim %arg1 using the first input that is not [0].
//   3. Pick the last input that is not skippable (rank 1, zero elements);
//      if there is none, return [0].
//   4. Check rank and non-concat dims against that reference shape, and sum the
//      concat-dim sizes into the result.
func.func private @__torch_mlir_shape_fn.aten.cat(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension"
%str_0 = torch.constant.str "AssertionError: Tensors must have same number of dimensions"
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%true = torch.constant.bool true
%none = torch.constant.none
%str_1 = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
// Step 1: raise if any input shape has rank 0.
%0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
torch.prim.Loop %0, %true, init() {
^bb0(%arg2: !torch.int):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%15 = torch.aten.gt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Step 2: carry an optional wrapped dim (%arg3, initially None) through the inputs.
%1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%2 = torch.derefine %none : !torch.none to !torch.optional<int>
%3 = torch.prim.Loop %1, %true, init(%2) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<int>):
// %16: true when this input shape is exactly [0].
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%15 = torch.aten.eq.int %14, %int1 : !torch.int, !torch.int -> !torch.bool
%16 = torch.prim.If %15 -> (!torch.bool) {
%19 = torch.aten.__getitem__.t %13, %int0 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%17 = torch.aten.__not__ %16 : !torch.bool -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.optional<int>) {
// Only the first qualifying input computes the wrapped dim; later ones keep %arg3.
%19 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%20 = torch.prim.If %19 -> (!torch.int) {
// %24: rank used for wrapping; a rank <= 0 is treated as 1.
%22 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%23 = torch.aten.le.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.prim.If %23 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %22 : !torch.int
}
// Raise unless -%24 <= %arg1 <= %24 - 1.
%25 = torch.aten.neg.int %24 : !torch.int -> !torch.int
%26 = torch.aten.sub.int %24, %int1 : !torch.int, !torch.int -> !torch.int
%27 = torch.aten.lt.int %arg1, %25 : !torch.int, !torch.int -> !torch.bool
%28 = torch.prim.If %27 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%32 = torch.aten.gt.int %arg1, %26 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %32 : !torch.bool
}
%29 = torch.aten.__not__ %28 : !torch.bool -> !torch.bool
torch.prim.If %29 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Negative dims are wrapped by adding the rank.
%30 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%31 = torch.prim.If %30 -> (!torch.int) {
%32 = torch.aten.add.int %arg1, %24 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %32 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
torch.prim.If.yield %31 : !torch.int
} else {
%22 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %22 : !torch.int
}
%21 = torch.derefine %20 : !torch.int to !torch.optional<int>
torch.prim.If.yield %21 : !torch.optional<int>
} else {
torch.prim.If.yield %arg3 : !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%18 : !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>
// %5: the wrapped concat dim, or %arg1 unchanged if no input qualified.
%4 = torch.aten.__is__ %3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%5 = torch.prim.If %4 -> (!torch.int) {
torch.prim.If.yield %arg1 : !torch.int
} else {
%13 = torch.prim.unchecked_cast %3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %13 : !torch.int
}
// The list of input shapes itself must be non-empty.
%6 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %7 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// Step 3: %10 = last input that is not skippable (element count 0 and rank 1).
// The loop has no early exit, so later matches overwrite earlier ones.
%8 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%9 = torch.derefine %none : !torch.none to !torch.optional<list<int>>
%10 = torch.prim.Loop %8, %true, init(%9) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<list<int>>):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
// %15: element count of this input (product of its dims, starting from 1).
%14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%15 = torch.prim.Loop %14, %true, init(%int1) {
^bb0(%arg4: !torch.int, %arg5: !torch.int):
%20 = torch.aten.__getitem__.t %13, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%21 = torch.aten.mul.int %arg5, %20 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%21 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%16 = torch.aten.eq.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.bool) {
%20 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%21 = torch.aten.eq.int %20, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %21 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%18 = torch.aten.__not__ %17 : !torch.bool -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.optional<list<int>>) {
%20 = torch.derefine %13 : !torch.list<int> to !torch.optional<list<int>>
torch.prim.If.yield %20 : !torch.optional<list<int>>
} else {
torch.prim.If.yield %arg3 : !torch.optional<list<int>>
}
torch.prim.Loop.condition %true, iter(%19 : !torch.optional<list<int>>)
} : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>
// If every input was skippable, the result shape is [0].
%11 = torch.aten.__is__ %10, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%12 = torch.prim.If %11 -> (!torch.list<int>) {
%13 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
torch.prim.If.yield %13 : !torch.list<int>
} else {
// Step 4: %13 is the reference shape.
// %15 accumulates the concat-dim size over the non-skippable inputs.
%13 = torch.prim.unchecked_cast %10 : !torch.optional<list<int>> -> !torch.list<int>
%14 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
%15 = torch.prim.Loop %14, %true, init(%int0) {
^bb0(%arg2: !torch.int, %arg3: !torch.int):
%19 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
// Recompute the element count to apply the same skip rule as step 3.
%20 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%21 = torch.prim.Loop %20, %true, init(%int1) {
^bb0(%arg4: !torch.int, %arg5: !torch.int):
%26 = torch.aten.__getitem__.t %19, %arg4 : !torch.list<int>, !torch.int -> !torch.int
%27 = torch.aten.mul.int %arg5, %26 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%27 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%22 = torch.aten.eq.int %21, %int0 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.bool) {
%26 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%27 = torch.aten.eq.int %26, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %27 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%24 = torch.aten.__not__ %23 : !torch.bool -> !torch.bool
%25 = torch.prim.If %24 -> (!torch.int) {
// Ranks must agree with the reference shape.
%26 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
%27 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
%28 = torch.aten.eq.int %26, %27 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %28 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield
}
// for %32 in range(0, rank): every dim other than %5 must match the reference.
%29 = torch.aten.__range_length %int0, %26, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %29, %true, init() {
^bb0(%arg4: !torch.int):
%32 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%33 = torch.aten.ne.int %32, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %33 -> () {
%34 = torch.aten.__getitem__.t %13, %32 : !torch.list<int>, !torch.int -> !torch.int
%35 = torch.aten.__getitem__.t %19, %32 : !torch.list<int>, !torch.int -> !torch.int
%36 = torch.aten.eq.int %34, %35 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %36 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Add this input's size along the concat dim.
%30 = torch.aten.__getitem__.t %19, %5 : !torch.list<int>, !torch.int -> !torch.int
%31 = torch.aten.add.int %arg3, %30 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %31 : !torch.int
} else {
torch.prim.If.yield %arg3 : !torch.int
}
torch.prim.Loop.condition %true, iter(%25 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// Result: a copy of the reference shape with dim %5 replaced by %15.
%16 = torch.prim.ListConstruct : () -> !torch.list<int>
%17 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
torch.prim.Loop %17, %true, init() {
^bb0(%arg2: !torch.int):
%19 = torch.aten.__getitem__.t %13, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%18 = torch.aten._set_item.t %16, %5, %15 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield %16 : !torch.list<int>
}
return %12 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for prim.NumToTensor.Scalar.
// It always returns an empty shape list, i.e. a rank-0 result.
// The scalar argument %arg0 is not used.
func.func private @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%arg0: !torch.float) -> !torch.list<int> {
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.where.self in its call form.
// It broadcasts %arg1 with %arg2, then broadcasts %arg0 with that intermediate
// shape, and returns the final shape.
// NOTE(review): presumably %arg0/%arg1/%arg2 are the condition/self/other shapes
// per the aten.where.self schema; confirm.
func.func private @__torch_mlir_shape_fn.aten.where.self(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
%1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %1 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Shape function for aten.where.self, in the form where both broadcast helper
// calls appear as their ops directly in the body.
// First loop: broadcast %arg1 with %arg2 into %3.
// Second loop: broadcast %arg0 with %3 into %7, which is returned.
// Each loop aligns dimensions from the right, treats missing dimensions as 1,
// and raises on sizes that differ when neither is 1.
// NOTE(review): presumably %arg0/%arg1/%arg2 are the condition/self/other shapes
// per the aten.where.self schema; confirm.
func.func private @__torch_mlir_shape_fn.aten.where.self(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%false = torch.constant.bool false
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
// First broadcast: %arg1 with %arg2; rank %2 = max of their ranks; result %3.
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%1 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %2, %true, init() {
^bb0(%arg3: !torch.int):
// %9: distance from the right end of the result.
// %11 / %13: indices into %arg1 / %arg2, negative when absent.
%8 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%9 = torch.aten.sub.int %8, %arg3 : !torch.int, !torch.int -> !torch.int
%10 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%11 = torch.aten.sub.int %10, %9 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int
%13 = torch.aten.sub.int %12, %9 : !torch.int, !torch.int -> !torch.int
// %15 / %17: operand sizes at this position, defaulting to 1 when absent.
%14 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.int) {
%24 = torch.aten.__getitem__.t %arg1, %11 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %24 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%16 = torch.aten.ge.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
%24 = torch.aten.__getitem__.t %arg2, %13 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %24 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
// Short-circuit "and": sizes differ and neither is 1.
%18 = torch.aten.ne.int %15, %17 : !torch.int, !torch.int -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.bool) {
%24 = torch.aten.ne.int %15, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %24 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%20 = torch.prim.If %19 -> (!torch.bool) {
%24 = torch.aten.ne.int %17, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %24 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %20 -> () {
%24 = torch.aten.format(%str_0, %15, %17, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%25 = torch.aten.add.str %str, %24 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %25, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// Result size: %17 when %15 is 1, otherwise %15.
%21 = torch.aten.eq.int %15, %int1 : !torch.int, !torch.int -> !torch.bool
%22 = torch.prim.If %21 -> (!torch.int) {
torch.prim.If.yield %17 : !torch.int
} else {
torch.prim.If.yield %15 : !torch.int
}
%23 = torch.aten.append.t %3, %22 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Second broadcast: %arg0 with the intermediate shape %3; rank %6; result %7.
%4 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%5 = torch.aten.len.t %3 : !torch.list<int> -> !torch.int
%6 = torch.prim.max.int %4, %5 : !torch.int, !torch.int -> !torch.int
%7 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %6, %true, init() {
^bb0(%arg3: !torch.int):
// %11 / %13: indices into %arg0 / %3, negative when absent.
%8 = torch.aten.sub.int %6, %int1 : !torch.int, !torch.int -> !torch.int
%9 = torch.aten.sub.int %8, %arg3 : !torch.int, !torch.int -> !torch.int
%10 = torch.aten.sub.int %4, %int1 : !torch.int, !torch.int -> !torch.int
%11 = torch.aten.sub.int %10, %9 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.sub.int %5, %int1 : !torch.int, !torch.int -> !torch.int
%13 = torch.aten.sub.int %12, %9 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.int) {
%24 = torch.aten.__getitem__.t %arg0, %11 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %24 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%16 = torch.aten.ge.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
%24 = torch.aten.__getitem__.t %3, %13 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %24 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
// Short-circuit "and": sizes differ and neither is 1.
%18 = torch.aten.ne.int %15, %17 : !torch.int, !torch.int -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.bool) {
%24 = torch.aten.ne.int %15, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %24 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%20 = torch.prim.If %19 -> (!torch.bool) {
%24 = torch.aten.ne.int %17, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %24 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %20 -> () {
%24 = torch.aten.format(%str_0, %15, %17, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%25 = torch.aten.add.str %str, %24 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %25, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
// Result size: %17 when %15 is 1, otherwise %15.
%21 = torch.aten.eq.int %15, %int1 : !torch.int, !torch.int -> !torch.bool
%22 = torch.prim.If %21 -> (!torch.int) {
torch.prim.If.yield %17 : !torch.int
} else {
torch.prim.If.yield %15 : !torch.int
}
%23 = torch.aten.append.t %7, %22 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %7 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch__.torch.jit._shape_functions.slice(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {
%int9223372036854775807 = torch.constant.int 9223372036854775807
%true = torch.constant.bool true
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %1 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%2 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %0 : !torch.int
}
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%33 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %33 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%33 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
%11 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool
%12 = torch.prim.If %11 -> (!torch.int) {
%33 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int0 : !torch.int
}
%13 = torch.aten.__isnot__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%14 = torch.prim.If %13 -> (!torch.int) {
%33 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int9223372036854775807 : !torch.int
}
%15 = torch.aten.gt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%16 = torch.aten.eq.int %12, %int9223372036854775807 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
torch.prim.If.yield %12 : !torch.int
}
%18 = torch.aten.lt.int %17, %int0 : !torch.int, !torch.int -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.int) {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.add.int %17, %33 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %34 : !torch.int
} else {
torch.prim.If.yield %17 : !torch.int
}
%20 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
%21 = torch.prim.If %20 -> (!torch.int) {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.add.int %14, %33 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %34 : !torch.int
} else {
torch.prim.If.yield %14 : !torch.int
}
%22 = torch.aten.lt.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.gt.int %19, %33 : !torch.int, !torch.int -> !torch.bool
%35 = torch.prim.If %34 -> (!torch.int) {
%36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %36 : !torch.int
} else {
torch.prim.If.yield %19 : !torch.int
}
torch.prim.If.yield %35 : !torch.int
}
%24 = torch.aten.lt.int %21, %23 : !torch.int, !torch.int -> !torch.bool
%25 = torch.prim.If %24 -> (!torch.int) {
torch.prim.If.yield %23 : !torch.int
} else {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.ge.int %21, %33 : !torch.int, !torch.int -> !torch.bool
%35 = torch.prim.If %34 -> (!torch.int) {
%36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %36 : !torch.int
} else {
torch.prim.If.yield %21 : !torch.int
}
torch.prim.If.yield %35 : !torch.int
}
%26 = torch.aten.sub.int %25, %23 : !torch.int, !torch.int -> !torch.int
%27 = torch.prim.ListConstruct : () -> !torch.list<int>
%28 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %28, %true, init() {
^bb0(%arg5: !torch.int):
%33 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.append.t %27, %33 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%29 = torch.aten.add.int %26, %arg4 : !torch.int, !torch.int -> !torch.int
%30 = torch.aten.sub.int %29, %int1 : !torch.int, !torch.int -> !torch.int
%31 = torch.aten.floordiv.int %30, %arg4 : !torch.int, !torch.int -> !torch.int
%32 = torch.aten._set_item.t %27, %10, %31 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
return %27 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.slice.Tensor(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.slice.Tensor(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {
%int9223372036854775807 = torch.constant.int 9223372036854775807
%true = torch.constant.bool true
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %1 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%2 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %0 : !torch.int
}
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%33 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %33 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%33 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
%11 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool
%12 = torch.prim.If %11 -> (!torch.int) {
%33 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int0 : !torch.int
}
%13 = torch.aten.__isnot__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
%14 = torch.prim.If %13 -> (!torch.int) {
%33 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %33 : !torch.int
} else {
torch.prim.If.yield %int9223372036854775807 : !torch.int
}
%15 = torch.aten.gt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%16 = torch.aten.eq.int %12, %int9223372036854775807 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
torch.prim.If.yield %12 : !torch.int
}
%18 = torch.aten.lt.int %17, %int0 : !torch.int, !torch.int -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.int) {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.add.int %17, %33 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %34 : !torch.int
} else {
torch.prim.If.yield %17 : !torch.int
}
%20 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
%21 = torch.prim.If %20 -> (!torch.int) {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.add.int %14, %33 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %34 : !torch.int
} else {
torch.prim.If.yield %14 : !torch.int
}
%22 = torch.aten.lt.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.gt.int %19, %33 : !torch.int, !torch.int -> !torch.bool
%35 = torch.prim.If %34 -> (!torch.int) {
%36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %36 : !torch.int
} else {
torch.prim.If.yield %19 : !torch.int
}
torch.prim.If.yield %35 : !torch.int
}
%24 = torch.aten.lt.int %21, %23 : !torch.int, !torch.int -> !torch.bool
%25 = torch.prim.If %24 -> (!torch.int) {
torch.prim.If.yield %23 : !torch.int
} else {
%33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.ge.int %21, %33 : !torch.int, !torch.int -> !torch.bool
%35 = torch.prim.If %34 -> (!torch.int) {
%36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %36 : !torch.int
} else {
torch.prim.If.yield %21 : !torch.int
}
torch.prim.If.yield %35 : !torch.int
}
%26 = torch.aten.sub.int %25, %23 : !torch.int, !torch.int -> !torch.int
%27 = torch.prim.ListConstruct : () -> !torch.list<int>
%28 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %28, %true, init() {
^bb0(%arg5: !torch.int):
%33 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int
%34 = torch.aten.append.t %27, %33 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%29 = torch.aten.add.int %26, %arg4 : !torch.int, !torch.int -> !torch.int
%30 = torch.aten.sub.int %29, %int1 : !torch.int, !torch.int -> !torch.int
%31 = torch.aten.floordiv.int %30, %arg4 : !torch.int, !torch.int -> !torch.int
%32 = torch.aten._set_item.t %27, %10, %31 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
return %27 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch__.torch.jit._shape_functions.view(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: invalid shape"
%false = torch.constant.bool false
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: only one dimension can be inferred"
%int-1 = torch.constant.int -1
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%true = torch.constant.bool true
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.prim.Loop %0, %true, init(%int1) {
^bb0(%arg2: !torch.int, %arg3: !torch.int):
%12 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.mul.int %arg3, %12 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%13 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%2 = torch.prim.Uninitialized : !torch.int
%3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%4 = torch.derefine %none : !torch.none to !torch.optional<int>
%5:2 = torch.prim.Loop %3, %true, init(%int1, %4) {
^bb0(%arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.optional<int>):
%12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.eq.int %12, %int-1 : !torch.int, !torch.int -> !torch.bool
%14:2 = torch.prim.If %13 -> (!torch.int, !torch.optional<int>) {
%15 = torch.aten.__isnot__ %arg4, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%16 = torch.derefine %arg2 : !torch.int to !torch.optional<int>
torch.prim.If.yield %arg3, %16 : !torch.int, !torch.optional<int>
} else {
%15 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.ge.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
%18 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %arg3, %18 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %19 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %2 : !torch.int
}
torch.prim.If.yield %17, %arg4 : !torch.int, !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%14#0, %14#1 : !torch.int, !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.int, !torch.optional<int>) -> (!torch.int, !torch.optional<int>)
%6 = torch.aten.eq.int %1, %5#0 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%12 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.bool) {
%15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%16 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %16 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%14 = torch.prim.If %13 -> (!torch.bool) {
%15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%16 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %17 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %14 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%9 = torch.prim.ListConstruct : () -> !torch.list<int>
%10 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
torch.prim.Loop %10, %true, init() {
^bb0(%arg2: !torch.int):
%12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.append.t %9, %12 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%11 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %11 -> () {
%12 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%13 = torch.aten.floordiv.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten._set_item.t %9, %12, %13 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
return %9 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.view(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.view(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%str = torch.constant.str "AssertionError: invalid shape"
%false = torch.constant.bool false
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: only one dimension can be inferred"
%int-1 = torch.constant.int -1
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%true = torch.constant.bool true
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.prim.Loop %0, %true, init(%int1) {
^bb0(%arg2: !torch.int, %arg3: !torch.int):
%12 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.mul.int %arg3, %12 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%13 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%2 = torch.prim.Uninitialized : !torch.int
%3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%4 = torch.derefine %none : !torch.none to !torch.optional<int>
%5:2 = torch.prim.Loop %3, %true, init(%int1, %4) {
^bb0(%arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.optional<int>):
%12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.eq.int %12, %int-1 : !torch.int, !torch.int -> !torch.bool
%14:2 = torch.prim.If %13 -> (!torch.int, !torch.optional<int>) {
%15 = torch.aten.__isnot__ %arg4, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %15 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%16 = torch.derefine %arg2 : !torch.int to !torch.optional<int>
torch.prim.If.yield %arg3, %16 : !torch.int, !torch.optional<int>
} else {
%15 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.ge.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.int) {
%18 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %arg3, %18 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %19 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %2 : !torch.int
}
torch.prim.If.yield %17, %arg4 : !torch.int, !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%14#0, %14#1 : !torch.int, !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.int, !torch.optional<int>) -> (!torch.int, !torch.optional<int>)
%6 = torch.aten.eq.int %1, %5#0 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%12 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.bool) {
%15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%16 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %16 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%14 = torch.prim.If %13 -> (!torch.bool) {
%15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%16 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %17 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %14 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%9 = torch.prim.ListConstruct : () -> !torch.list<int>
%10 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
torch.prim.Loop %10, %true, init() {
^bb0(%arg2: !torch.int):
%12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.append.t %9, %12 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%11 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %11 -> () {
%12 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int
%13 = torch.aten.floordiv.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten._set_item.t %9, %12, %13 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
return %9 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch__.torch.jit._shape_functions.unsqueeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.add.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %1 : !torch.int
}
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%13 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %13 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%13 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
%11 = torch.prim.ListConstruct : () -> !torch.list<int>
%12 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %12, %true, init() {
^bb0(%arg2: !torch.int):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.append.t %11, %13 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %11, %10, %int1 : !torch.list<int>, !torch.int, !torch.int
return %11 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.unsqueeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.unsqueeze(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.add.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.prim.If %2 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %1 : !torch.int
}
%4 = torch.aten.neg.int %3 : !torch.int -> !torch.int
%5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
%7 = torch.prim.If %6 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%13 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %13 : !torch.bool
}
%8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
torch.prim.If %8 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.prim.If %9 -> (!torch.int) {
%13 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %arg1 : !torch.int
}
%11 = torch.prim.ListConstruct : () -> !torch.list<int>
%12 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %12, %true, init() {
^bb0(%arg2: !torch.int):
%13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int
%14 = torch.aten.append.t %11, %13 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %11, %10, %int1 : !torch.list<int>, !torch.int, !torch.int
return %11 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch__.torch.jit._shape_functions.unary(%arg0: !torch.list<int>) -> !torch.list<int> {
%true = torch.constant.bool true
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %1, %true, init() {
^bb0(%arg1: !torch.int):
%2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int
%3 = torch.aten.append.t %0, %2 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.to.dtype(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.to.dtype(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
%true = torch.constant.bool true
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
torch.prim.Loop %1, %true, init() {
^bb0(%arg5: !torch.int):
%2 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int
%3 = torch.aten.append.t %0, %2 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.logical_and(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func private @__torch_mlir_shape_fn.aten.logical_and(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%false = torch.constant.bool false
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %2, %true, init() {
^bb0(%arg2: !torch.int):
%4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
%5 = torch.aten.sub.int %4, %arg2 : !torch.int, !torch.int -> !torch.int
%6 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int
%7 = torch.aten.sub.int %6, %5 : !torch.int, !torch.int -> !torch.int
%8 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int
%9 = torch.aten.sub.int %8, %5 : !torch.int, !torch.int -> !torch.int
%10 = torch.aten.ge.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
%11 = torch.prim.If %10 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg0, %7 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%12 = torch.aten.ge.int %9, %int0 : !torch.int, !torch.int -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.int) {
%20 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %20 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%14 = torch.aten.ne.int %11, %13 : !torch.int, !torch.int -> !torch.bool
%15 = torch.prim.If %14 -> (!torch.bool) {
%20 = torch.aten.ne.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%16 = torch.prim.If %15 -> (!torch.bool) {
%20 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %20 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %16 -> () {
%20 = torch.aten.format(%str_0, %11, %13, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%21 = torch.aten.add.str %str, %20 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %21, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%17 = torch.aten.eq.int %11, %int1 : !torch.int, !torch.int -> !torch.bool
%18 = torch.prim.If %17 -> (!torch.int) {
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %11 : !torch.int
}
%19 = torch.aten.append.t %3, %18 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
return %3 : !torch.list<int>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// @main_graph after canonicalization, with shape functions still attached.
// Each `torch.shape.calculate` pairs one op (first region) with a `shapes`
// region that calls the matching `__torch_mlir_shape_fn.*` helper to compute
// that op's result shape.
// Visible dataflow:
//   1. Read the scalars arg4[0..3] and pad %arg1 with aten.constant_pad_nd
//      using pad list [arg4[1], arg4[3], arg4[0], arg4[2]] and fill value 0.
//   2. Build a 3-entry target shape cat([0], arg5[1] / 0, [8]); for entries 0
//      and 1, an entry equal to 0 is replaced by the padded tensor's own size
//      on that dim. View the padded tensor to that shape.
//   3. Unsqueeze the viewed tensor at dim 3 and %arg3 at dim 2, cast both to
//      bool (i1), AND them with broadcasting, then AND the result with a
//      constant all-false [1,1,128,384] tensor.
// %arg0 and %arg2 are not referenced in this function.
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
// Constants used below. %1 is an all-false [1,1,128,384] mask (see %52);
// %4 is the scalar 0 used as the divisor in %17.
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%float0.000000e00 = torch.constant.float 0.000000e+00
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
// %6, %8, %10, %12 = arg4[0], arg4[1], arg4[2], arg4[3], each obtained as a
// [i:i+1] slice of %arg4 followed by aten.item.
%5 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%54 = torch.derefine %int0 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%54 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%54 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int3 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%54 = torch.derefine %int3 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int4 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
// Pad list in aten.constant_pad_nd order (last dim first): dim 1 is padded
// by (arg4[1], arg4[3]) and dim 0 by (arg4[0], arg4[2]); fill value is 0.
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.shape.calculate {
%53 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?],si64>
} shapes {
%53 = torch.aten.size %arg1 : !torch.vtensor<[?,?],si64> -> !torch.list<int>
%54 = func.call @__torch_mlir_shape_fn.aten.constant_pad_nd(%53, %13, %float0.000000e00) : (!torch.list<int>, !torch.list<int>, !torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
// %16 = arg5[1] as a 0-d tensor (index_select with index [1], then squeeze).
%15 = torch.shape.calculate {
%53 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%55 = func.call @__torch_mlir_shape_fn.aten.index_select(%53, %int0, %54) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%16 = torch.shape.calculate {
%53 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%54 = func.call @__torch_mlir_shape_fn.aten.squeeze.dim(%53, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
// NOTE(review): %4 is the literal scalar 0, so this div.Tensor divides by
// zero, and the declared result type is si64. Presumably this is inherited
// from the exported model; confirm against the source ONNX graph.
%17 = torch.shape.calculate {
%53 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
%55 = func.call @__torch_mlir_shape_fn.aten.div.Tensor(%53, %54) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[],si64>
%18 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
%54 = func.call @__torch_mlir_shape_fn.aten.unsqueeze(%53, %int0) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
// %20 = cat([0], unsqueeze(%17), [8]): the 3-entry target shape for the
// view in %46.
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.shape.calculate {
%53 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[3],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<list<int>>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%55 = torch.aten.__getitem__.t %19, %arg6 : !torch.list<vtensor>, !torch.int -> !torch.vtensor
%56 = torch.aten.size %55 : !torch.vtensor -> !torch.list<int>
%57 = torch.aten.append.t %53, %56 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%54 = func.call @__torch_mlir_shape_fn.aten.cat(%53, %int0) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[3],si64>
// Entry 0 of the target shape: %30 = where(entry == 0, size(%14, 0), entry).
%21 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%54 = torch.derefine %int0 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %24 : !torch.int -> !torch.float
%54 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%53) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%27 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %25 : !torch.int -> !torch.float
%54 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%53) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%28 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %22 : !torch.int -> !torch.float
%54 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%53) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%29 = torch.shape.calculate {
%53 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
%55 = torch.prim.ListConstruct : () -> !torch.list<int>
%56 = func.call @__torch_mlir_shape_fn.aten.where.self(%53, %54, %55) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
// Entry 1 of the target shape: %40 = where(entry == 0, size(%14, 1), entry).
%31 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%54 = torch.derefine %int1 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %34 : !torch.int -> !torch.float
%54 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%53) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%37 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %35 : !torch.int -> !torch.float
%54 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%53) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%38 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %32 : !torch.int -> !torch.float
%54 = func.call @__torch_mlir_shape_fn.prim.NumToTensor.Scalar(%53) : (!torch.float) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%39 = torch.shape.calculate {
%53 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
%55 = torch.prim.ListConstruct : () -> !torch.list<int>
%56 = func.call @__torch_mlir_shape_fn.aten.where.self(%53, %54, %55) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
// Entry 2 of the target shape (%42) is used as-is. %43 and %44 are computed
// but not referenced anywhere below.
%41 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%54 = torch.derefine %int2 : !torch.int to !torch.optional<int>
%55 = torch.derefine %int3 : !torch.int to !torch.optional<int>
%56 = func.call @__torch_mlir_shape_fn.aten.slice.Tensor(%53, %int0, %54, %55, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
// View the padded tensor %14 as [%30, %40, %42].
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.shape.calculate {
%53 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?],si64>
} shapes {
%53 = torch.aten.size %14 : !torch.vtensor<[?,?],si64> -> !torch.list<int>
%54 = func.call @__torch_mlir_shape_fn.aten.view(%53, %45) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
// Unsqueeze the viewed tensor at dim 3 -> [?,?,?,1] and %arg3 at dim 2
// -> [?,?,1,?] so the two broadcast against each other in %51.
%47 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%53 = torch.aten.size %46 : !torch.vtensor<[?,?,?],si64> -> !torch.list<int>
%54 = func.call @__torch_mlir_shape_fn.aten.unsqueeze(%53, %int3) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
%48 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%53 = torch.aten.size %arg3 : !torch.vtensor<[?,?,?],si64> -> !torch.list<int>
%54 = func.call @__torch_mlir_shape_fn.aten.unsqueeze(%53, %int2) : (!torch.list<int>, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
// Cast both operands to bool: dtype code 11, result element type i1.
%49 = torch.shape.calculate {
%53 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%53 = torch.aten.size %47 : !torch.vtensor<[?,?,?,1],si64> -> !torch.list<int>
%54 = torch.derefine %none : !torch.none to !torch.optional<int>
%55 = func.call @__torch_mlir_shape_fn.aten.to.dtype(%53, %int11, %false, %false, %54) : (!torch.list<int>, !torch.int, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%50 = torch.shape.calculate {
%53 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%53 = torch.aten.size %48 : !torch.vtensor<[?,?,1,?],si64> -> !torch.list<int>
%54 = torch.derefine %none : !torch.none to !torch.optional<int>
%55 = func.call @__torch_mlir_shape_fn.aten.to.dtype(%53, %int11, %false, %false, %54) : (!torch.list<int>, !torch.int, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
// Broadcasting AND: [?,?,?,1] & [?,?,1,?] -> [?,?,?,?].
%51 = torch.shape.calculate {
%53 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%53 = torch.aten.size %49 : !torch.vtensor<[?,?,?,1],i1> -> !torch.list<int>
%54 = torch.aten.size %50 : !torch.vtensor<[?,?,1,?],i1> -> !torch.list<int>
%55 = func.call @__torch_mlir_shape_fn.aten.logical_and(%53, %54) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
// NOTE(review): %1 is dense<false>, so this logical_and yields an all-false
// tensor whatever the inputs are; only the result shape ([?,?,128,384])
// depends on them.
%52 = torch.shape.calculate {
%53 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%53 = torch.aten.size %51 : !torch.vtensor<[?,?,?,?],i1> -> !torch.list<int>
%54 = torch.prim.ListConstruct %int1, %int1, %int128, %int384 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%55 = func.call @__torch_mlir_shape_fn.aten.logical_and(%53, %54) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "AssertionError: Tensors must have same number of dimensions"
%str_3 = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension"
%str_4 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_5 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.shape.calculate {
%53 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?],si64>
} shapes {
%53 = torch.aten.size %arg1 : !torch.vtensor<[?,?],si64> -> !torch.list<int>
torch.prim.Loop %int2, %true, init() {
^bb0(%arg6: !torch.int):
%54 = torch.aten.add.int %arg6, %int1 : !torch.int, !torch.int -> !torch.int
%55 = torch.aten.neg.int %54 : !torch.int -> !torch.int
%56 = torch.aten.mul.int %int2, %arg6 : !torch.int, !torch.int -> !torch.int
%57 = torch.aten.__getitem__.t %13, %56 : !torch.list<int>, !torch.int -> !torch.int
%58 = torch.aten.mul.int %int2, %arg6 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.add.int %58, %int1 : !torch.int, !torch.int -> !torch.int
%60 = torch.aten.__getitem__.t %13, %59 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.add.int %57, %60 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.size.int %arg1, %55 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%63 = torch.aten.add.int %62, %61 : !torch.int, !torch.int -> !torch.int
%64 = torch.aten._set_item.t %53, %55, %63 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
%15 = torch.shape.calculate {
%53 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%55 = torch.prim.Loop %int1, %true, init(%int1) {
^bb0(%arg6: !torch.int, %arg7: !torch.int):
%57 = torch.aten.__getitem__.t %54, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%58 = torch.aten.mul.int %arg7, %57 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%58 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%56 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%57 = torch.aten.eq.int %int0, %arg6 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %57 -> () {
%58 = torch.aten.append.t %56, %55 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
%58 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%59 = torch.aten.append.t %56, %58 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%16 = torch.shape.calculate {
%53 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%55 = torch.aten.eq.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %55 -> () {
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.ne.int %56, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %57 -> () {
%58 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%59 = torch.aten.append.t %54, %58 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%17 = torch.shape.calculate {
%53 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
%55 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.sub.int %int-1, %arg6 : !torch.int, !torch.int -> !torch.int
%57 = torch.aten.sub.int %int-1, %56 : !torch.int, !torch.int -> !torch.int
%58 = torch.aten.sub.int %int-1, %56 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.ge.int %57, %int0 : !torch.int, !torch.int -> !torch.bool
%60 = torch.prim.If %59 -> (!torch.int) {
%69 = torch.aten.__getitem__.t %53, %57 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %69 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%61 = torch.aten.ge.int %58, %int0 : !torch.int, !torch.int -> !torch.bool
%62 = torch.prim.If %61 -> (!torch.int) {
%69 = torch.aten.__getitem__.t %54, %58 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %69 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%63 = torch.aten.ne.int %60, %62 : !torch.int, !torch.int -> !torch.bool
%64 = torch.prim.If %63 -> (!torch.bool) {
%69 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %69 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%65 = torch.prim.If %64 -> (!torch.bool) {
%69 = torch.aten.ne.int %62, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %69 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %65 -> () {
%69 = torch.aten.format(%str_4, %60, %62, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%70 = torch.aten.add.str %str_5, %69 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %70, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%66 = torch.aten.eq.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.int) {
torch.prim.If.yield %62 : !torch.int
} else {
torch.prim.If.yield %60 : !torch.int
}
%68 = torch.aten.append.t %55, %67 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[],si64>
%18 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%55 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%56 = torch.aten.append.t %54, %55 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.shape.calculate {
%53 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[3],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<list<int>>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%67 = torch.aten.__getitem__.t %19, %arg6 : !torch.list<vtensor>, !torch.int -> !torch.vtensor
%68 = torch.aten.size %67 : !torch.vtensor -> !torch.list<int>
%69 = torch.aten.append.t %53, %68 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%54 = torch.aten.len.t %53 : !torch.list<list<int>> -> !torch.int
torch.prim.Loop %54, %true, init() {
^bb0(%arg6: !torch.int):
%67 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%68 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
%69 = torch.aten.gt.int %68, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %69 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten.len.t %53 : !torch.list<list<int>> -> !torch.int
%56 = torch.derefine %none : !torch.none to !torch.optional<int>
%57 = torch.prim.Loop %55, %true, init(%56) {
^bb0(%arg6: !torch.int, %arg7: !torch.optional<int>):
%67 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%68 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
%69 = torch.aten.eq.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.bool) {
%73 = torch.aten.__getitem__.t %67, %int0 : !torch.list<int>, !torch.int -> !torch.int
%74 = torch.aten.eq.int %73, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%71 = torch.aten.__not__ %70 : !torch.bool -> !torch.bool
%72 = torch.prim.If %71 -> (!torch.optional<int>) {
%73 = torch.aten.__is__ %arg7, %none : !torch.optional<int>, !torch.none -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.int) {
%76 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
%77 = torch.aten.le.int %76, %int0 : !torch.int, !torch.int -> !torch.bool
%78 = torch.prim.If %77 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %76 : !torch.int
}
%79 = torch.aten.neg.int %78 : !torch.int -> !torch.int
%80 = torch.aten.sub.int %78, %int1 : !torch.int, !torch.int -> !torch.int
%81 = torch.aten.lt.int %int0, %79 : !torch.int, !torch.int -> !torch.bool
%82 = torch.prim.If %81 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%84 = torch.aten.gt.int %int0, %80 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %84 : !torch.bool
}
%83 = torch.aten.__not__ %82 : !torch.bool -> !torch.bool
torch.prim.If %83 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield %int0 : !torch.int
} else {
%76 = torch.prim.unchecked_cast %arg7 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %76 : !torch.int
}
%75 = torch.derefine %74 : !torch.int to !torch.optional<int>
torch.prim.If.yield %75 : !torch.optional<int>
} else {
torch.prim.If.yield %arg7 : !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%72 : !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>
%58 = torch.aten.__is__ %57, %none : !torch.optional<int>, !torch.none -> !torch.bool
%59 = torch.prim.If %58 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
%67 = torch.prim.unchecked_cast %57 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %67 : !torch.int
}
%60 = torch.aten.len.t %53 : !torch.list<list<int>> -> !torch.int
%61 = torch.aten.gt.int %60, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%62 = torch.aten.len.t %53 : !torch.list<list<int>> -> !torch.int
%63 = torch.derefine %none : !torch.none to !torch.optional<list<int>>
%64 = torch.prim.Loop %62, %true, init(%63) {
^bb0(%arg6: !torch.int, %arg7: !torch.optional<list<int>>):
%67 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%68 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
%69 = torch.prim.Loop %68, %true, init(%int1) {
^bb0(%arg8: !torch.int, %arg9: !torch.int):
%74 = torch.aten.__getitem__.t %67, %arg8 : !torch.list<int>, !torch.int -> !torch.int
%75 = torch.aten.mul.int %arg9, %74 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%75 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%70 = torch.aten.eq.int %69, %int0 : !torch.int, !torch.int -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.bool) {
%74 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
%75 = torch.aten.eq.int %74, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %75 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%72 = torch.aten.__not__ %71 : !torch.bool -> !torch.bool
%73 = torch.prim.If %72 -> (!torch.optional<list<int>>) {
%74 = torch.derefine %67 : !torch.list<int> to !torch.optional<list<int>>
torch.prim.If.yield %74 : !torch.optional<list<int>>
} else {
torch.prim.If.yield %arg7 : !torch.optional<list<int>>
}
torch.prim.Loop.condition %true, iter(%73 : !torch.optional<list<int>>)
} : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>
%65 = torch.aten.__is__ %64, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.list<int>) {
%67 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
torch.prim.If.yield %67 : !torch.list<int>
} else {
%67 = torch.prim.unchecked_cast %64 : !torch.optional<list<int>> -> !torch.list<int>
%68 = torch.aten.len.t %53 : !torch.list<list<int>> -> !torch.int
%69 = torch.prim.Loop %68, %true, init(%int0) {
^bb0(%arg6: !torch.int, %arg7: !torch.int):
%73 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%74 = torch.aten.len.t %73 : !torch.list<int> -> !torch.int
%75 = torch.prim.Loop %74, %true, init(%int1) {
^bb0(%arg8: !torch.int, %arg9: !torch.int):
%80 = torch.aten.__getitem__.t %73, %arg8 : !torch.list<int>, !torch.int -> !torch.int
%81 = torch.aten.mul.int %arg9, %80 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%81 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%76 = torch.aten.eq.int %75, %int0 : !torch.int, !torch.int -> !torch.bool
%77 = torch.prim.If %76 -> (!torch.bool) {
%80 = torch.aten.len.t %73 : !torch.list<int> -> !torch.int
%81 = torch.aten.eq.int %80, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %81 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%78 = torch.aten.__not__ %77 : !torch.bool -> !torch.bool
%79 = torch.prim.If %78 -> (!torch.int) {
%80 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
%81 = torch.aten.len.t %73 : !torch.list<int> -> !torch.int
%82 = torch.aten.eq.int %80, %81 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %82 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%83 = torch.aten.__range_length %int0, %80, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %83, %true, init() {
^bb0(%arg8: !torch.int):
%86 = torch.aten.__derive_index %arg8, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%87 = torch.aten.ne.int %86, %59 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %87 -> () {
%88 = torch.aten.__getitem__.t %67, %86 : !torch.list<int>, !torch.int -> !torch.int
%89 = torch.aten.__getitem__.t %73, %86 : !torch.list<int>, !torch.int -> !torch.int
%90 = torch.aten.eq.int %88, %89 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %90 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%84 = torch.aten.__getitem__.t %73, %59 : !torch.list<int>, !torch.int -> !torch.int
%85 = torch.aten.add.int %arg7, %84 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %85 : !torch.int
} else {
torch.prim.If.yield %arg7 : !torch.int
}
torch.prim.Loop.condition %true, iter(%79 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%70 = torch.prim.ListConstruct : () -> !torch.list<int>
%71 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
torch.prim.Loop %71, %true, init() {
^bb0(%arg6: !torch.int):
%73 = torch.aten.__getitem__.t %67, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%74 = torch.aten.append.t %70, %73 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%72 = torch.aten._set_item.t %70, %59, %69 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield %70 : !torch.list<int>
}
torch.shape.calculate.yield.shapes %66 : !torch.list<int>
} : !torch.vtensor<[3],si64>
%21 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %24 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%27 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %25 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%28 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %22 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%29 = torch.shape.calculate {
%53 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
%55 = torch.prim.ListConstruct : () -> !torch.list<int>
%56 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.sub.int %int-1, %arg6 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.sub.int %int-1, %60 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.sub.int %int-1, %60 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.ge.int %61, %int0 : !torch.int, !torch.int -> !torch.bool
%64 = torch.prim.If %63 -> (!torch.int) {
%73 = torch.aten.__getitem__.t %54, %61 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%65 = torch.aten.ge.int %62, %int0 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
%73 = torch.aten.__getitem__.t %55, %62 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%67 = torch.aten.ne.int %64, %66 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.bool) {
%73 = torch.aten.ne.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%69 = torch.prim.If %68 -> (!torch.bool) {
%73 = torch.aten.ne.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %69 -> () {
%73 = torch.aten.format(%str_4, %64, %66, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%74 = torch.aten.add.str %str_5, %73 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %74, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%70 = torch.aten.eq.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.int) {
torch.prim.If.yield %66 : !torch.int
} else {
torch.prim.If.yield %64 : !torch.int
}
%72 = torch.aten.append.t %56, %71 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%57 = torch.aten.len.t %56 : !torch.list<int> -> !torch.int
%58 = torch.prim.max.int %int0, %57 : !torch.int, !torch.int -> !torch.int
%59 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %58, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.sub.int %58, %int1 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.sub.int %60, %arg6 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.sub.int %int-1, %61 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.sub.int %57, %int1 : !torch.int, !torch.int -> !torch.int
%64 = torch.aten.sub.int %63, %61 : !torch.int, !torch.int -> !torch.int
%65 = torch.aten.ge.int %62, %int0 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
%75 = torch.aten.__getitem__.t %53, %62 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%67 = torch.aten.ge.int %64, %int0 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.int) {
%75 = torch.aten.__getitem__.t %56, %64 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%69 = torch.aten.ne.int %66, %68 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.bool) {
%75 = torch.aten.ne.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %75 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%71 = torch.prim.If %70 -> (!torch.bool) {
%75 = torch.aten.ne.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %75 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %71 -> () {
%75 = torch.aten.format(%str_4, %66, %68, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%76 = torch.aten.add.str %str_5, %75 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %76, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%72 = torch.aten.eq.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
%73 = torch.prim.If %72 -> (!torch.int) {
torch.prim.If.yield %68 : !torch.int
} else {
torch.prim.If.yield %66 : !torch.int
}
%74 = torch.aten.append.t %59, %73 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %34 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%37 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %35 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%38 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %32 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%39 = torch.shape.calculate {
%53 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
%55 = torch.prim.ListConstruct : () -> !torch.list<int>
%56 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.sub.int %int-1, %arg6 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.sub.int %int-1, %60 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.sub.int %int-1, %60 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.ge.int %61, %int0 : !torch.int, !torch.int -> !torch.bool
%64 = torch.prim.If %63 -> (!torch.int) {
%73 = torch.aten.__getitem__.t %54, %61 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%65 = torch.aten.ge.int %62, %int0 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
%73 = torch.aten.__getitem__.t %55, %62 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%67 = torch.aten.ne.int %64, %66 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.bool) {
%73 = torch.aten.ne.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%69 = torch.prim.If %68 -> (!torch.bool) {
%73 = torch.aten.ne.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %69 -> () {
%73 = torch.aten.format(%str_4, %64, %66, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%74 = torch.aten.add.str %str_5, %73 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %74, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%70 = torch.aten.eq.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.int) {
torch.prim.If.yield %66 : !torch.int
} else {
torch.prim.If.yield %64 : !torch.int
}
%72 = torch.aten.append.t %56, %71 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%57 = torch.aten.len.t %56 : !torch.list<int> -> !torch.int
%58 = torch.prim.max.int %int0, %57 : !torch.int, !torch.int -> !torch.int
%59 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %58, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.sub.int %58, %int1 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.sub.int %60, %arg6 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.sub.int %int-1, %61 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.sub.int %57, %int1 : !torch.int, !torch.int -> !torch.int
%64 = torch.aten.sub.int %63, %61 : !torch.int, !torch.int -> !torch.int
%65 = torch.aten.ge.int %62, %int0 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
%75 = torch.aten.__getitem__.t %53, %62 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%67 = torch.aten.ge.int %64, %int0 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.int) {
%75 = torch.aten.__getitem__.t %56, %64 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%69 = torch.aten.ne.int %66, %68 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.bool) {
%75 = torch.aten.ne.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %75 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%71 = torch.prim.If %70 -> (!torch.bool) {
%75 = torch.aten.ne.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %75 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %71 -> () {
%75 = torch.aten.format(%str_4, %66, %68, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%76 = torch.aten.add.str %str_5, %75 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %76, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%72 = torch.aten.eq.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
%73 = torch.prim.If %72 -> (!torch.int) {
torch.prim.If.yield %68 : !torch.int
} else {
torch.prim.If.yield %66 : !torch.int
}
%74 = torch.aten.append.t %59, %73 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.shape.calculate {
%53 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?],si64>
} shapes {
%53 = torch.prim.Loop %int2, %true, init(%int1) {
^bb0(%arg6: !torch.int, %arg7: !torch.int):
%62 = torch.aten.size.int %14, %arg6 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%63 = torch.aten.mul.int %arg7, %62 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%63 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%54 = torch.prim.Uninitialized : !torch.int
%55 = torch.derefine %none : !torch.none to !torch.optional<int>
%56:2 = torch.prim.Loop %int3, %true, init(%int1, %55) {
^bb0(%arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.optional<int>):
%62 = torch.aten.__getitem__.t %45, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.eq.int %62, %int-1 : !torch.int, !torch.int -> !torch.bool
%64:2 = torch.prim.If %63 -> (!torch.int, !torch.optional<int>) {
%65 = torch.aten.__isnot__ %arg8, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %65 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%66 = torch.derefine %arg6 : !torch.int to !torch.optional<int>
torch.prim.If.yield %arg7, %66 : !torch.int, !torch.optional<int>
} else {
%65 = torch.aten.__getitem__.t %45, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%66 = torch.aten.ge.int %65, %int0 : !torch.int, !torch.int -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.int) {
%68 = torch.aten.__getitem__.t %45, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%69 = torch.aten.mul.int %arg7, %68 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %69 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %54 : !torch.int
}
torch.prim.If.yield %67, %arg8 : !torch.int, !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%64#0, %64#1 : !torch.int, !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.int, !torch.optional<int>) -> (!torch.int, !torch.optional<int>)
%57 = torch.aten.eq.int %53, %56#0 : !torch.int, !torch.int -> !torch.bool
%58 = torch.prim.If %57 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%62 = torch.aten.__isnot__ %56#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.bool) {
%65 = torch.prim.unchecked_cast %56#1 : !torch.optional<int> -> !torch.int
%66 = torch.aten.gt.int %56#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %66 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%64 = torch.prim.If %63 -> (!torch.bool) {
%65 = torch.prim.unchecked_cast %56#1 : !torch.optional<int> -> !torch.int
%66 = torch.aten.remainder.int %53, %56#0 : !torch.int, !torch.int -> !torch.int
%67 = torch.aten.eq.int %66, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %67 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %64 : !torch.bool
}
%59 = torch.aten.__not__ %58 : !torch.bool -> !torch.bool
torch.prim.If %59 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%60 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%62 = torch.aten.__getitem__.t %45, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %60, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%61 = torch.aten.__isnot__ %56#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.prim.unchecked_cast %56#1 : !torch.optional<int> -> !torch.int
%63 = torch.aten.floordiv.int %53, %56#0 : !torch.int, !torch.int -> !torch.int
%64 = torch.aten._set_item.t %60, %62, %63 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
%47 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%54 = torch.aten.size.int %46, %arg6 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.append.t %53, %54 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %53, %int3, %int1 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
%48 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%54 = torch.aten.size.int %arg3, %arg6 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.append.t %53, %54 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %53, %int2, %int1 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
%49 = torch.shape.calculate {
%53 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%54 = torch.aten.size.int %47, %arg6 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%55 = torch.aten.append.t %53, %54 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%50 = torch.shape.calculate {
%53 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%54 = torch.aten.size.int %48, %arg6 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.append.t %53, %54 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
%51 = torch.shape.calculate {
%53 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%54 = torch.aten.sub.int %int3, %arg6 : !torch.int, !torch.int -> !torch.int
%55 = torch.aten.sub.int %int3, %54 : !torch.int, !torch.int -> !torch.int
%56 = torch.aten.sub.int %int3, %54 : !torch.int, !torch.int -> !torch.int
%57 = torch.aten.ge.int %55, %int0 : !torch.int, !torch.int -> !torch.bool
%58 = torch.prim.If %57 -> (!torch.int) {
%67 = torch.aten.size.int %49, %55 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
torch.prim.If.yield %67 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%59 = torch.aten.ge.int %56, %int0 : !torch.int, !torch.int -> !torch.bool
%60 = torch.prim.If %59 -> (!torch.int) {
%67 = torch.aten.size.int %50, %56 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
torch.prim.If.yield %67 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%61 = torch.aten.ne.int %58, %60 : !torch.int, !torch.int -> !torch.bool
%62 = torch.prim.If %61 -> (!torch.bool) {
%67 = torch.aten.ne.int %58, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %67 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%63 = torch.prim.If %62 -> (!torch.bool) {
%67 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %67 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %63 -> () {
%67 = torch.aten.format(%str_4, %58, %60, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%68 = torch.aten.add.str %str_5, %67 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %68, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%64 = torch.aten.eq.int %58, %int1 : !torch.int, !torch.int -> !torch.bool
%65 = torch.prim.If %64 -> (!torch.int) {
torch.prim.If.yield %60 : !torch.int
} else {
torch.prim.If.yield %58 : !torch.int
}
%66 = torch.aten.append.t %53, %65 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
%52 = torch.shape.calculate {
%53 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%53 = torch.prim.ListConstruct %int1, %int1, %int128, %int384 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%55 = torch.aten.sub.int %int3, %arg6 : !torch.int, !torch.int -> !torch.int
%56 = torch.aten.sub.int %int3, %55 : !torch.int, !torch.int -> !torch.int
%57 = torch.aten.sub.int %int3, %55 : !torch.int, !torch.int -> !torch.int
%58 = torch.aten.ge.int %56, %int0 : !torch.int, !torch.int -> !torch.bool
%59 = torch.prim.If %58 -> (!torch.int) {
%68 = torch.aten.size.int %51, %56 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
torch.prim.If.yield %68 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%60 = torch.aten.ge.int %57, %int0 : !torch.int, !torch.int -> !torch.bool
%61 = torch.prim.If %60 -> (!torch.int) {
%68 = torch.aten.__getitem__.t %53, %57 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %68 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%62 = torch.aten.ne.int %59, %61 : !torch.int, !torch.int -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.bool) {
%68 = torch.aten.ne.int %59, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %68 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%64 = torch.prim.If %63 -> (!torch.bool) {
%68 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %68 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %64 -> () {
%68 = torch.aten.format(%str_4, %59, %61, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%69 = torch.aten.add.str %str_5, %68 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %69, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%65 = torch.aten.eq.int %59, %int1 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
torch.prim.If.yield %61 : !torch.int
} else {
torch.prim.If.yield %59 : !torch.int
}
%67 = torch.aten.append.t %54, %66 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After Inliner (inline) //----- //
module {
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "AssertionError: Tensors must have same number of dimensions"
%str_3 = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension"
%str_4 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_5 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.shape.calculate {
%53 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?],si64>
} shapes {
%53 = torch.aten.size %arg1 : !torch.vtensor<[?,?],si64> -> !torch.list<int>
torch.prim.Loop %int2, %true, init() {
^bb0(%arg6: !torch.int):
%54 = torch.aten.add.int %arg6, %int1 : !torch.int, !torch.int -> !torch.int
%55 = torch.aten.neg.int %54 : !torch.int -> !torch.int
%56 = torch.aten.mul.int %int2, %arg6 : !torch.int, !torch.int -> !torch.int
%57 = torch.aten.__getitem__.t %13, %56 : !torch.list<int>, !torch.int -> !torch.int
%58 = torch.aten.mul.int %int2, %arg6 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.add.int %58, %int1 : !torch.int, !torch.int -> !torch.int
%60 = torch.aten.__getitem__.t %13, %59 : !torch.list<int>, !torch.int -> !torch.int
%61 = torch.aten.add.int %57, %60 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.size.int %arg1, %55 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%63 = torch.aten.add.int %62, %61 : !torch.int, !torch.int -> !torch.int
%64 = torch.aten._set_item.t %53, %55, %63 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
%15 = torch.shape.calculate {
%53 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%55 = torch.prim.Loop %int1, %true, init(%int1) {
^bb0(%arg6: !torch.int, %arg7: !torch.int):
%57 = torch.aten.__getitem__.t %54, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%58 = torch.aten.mul.int %arg7, %57 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%58 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%56 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%57 = torch.aten.eq.int %int0, %arg6 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %57 -> () {
%58 = torch.aten.append.t %56, %55 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
%58 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%59 = torch.aten.append.t %56, %58 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%16 = torch.shape.calculate {
%53 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%55 = torch.aten.eq.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %55 -> () {
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.ne.int %56, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %57 -> () {
%58 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%59 = torch.aten.append.t %54, %58 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.If.yield
} else {
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%17 = torch.shape.calculate {
%53 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
%55 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.sub.int %int-1, %arg6 : !torch.int, !torch.int -> !torch.int
%57 = torch.aten.sub.int %int-1, %56 : !torch.int, !torch.int -> !torch.int
%58 = torch.aten.sub.int %int-1, %56 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.ge.int %57, %int0 : !torch.int, !torch.int -> !torch.bool
%60 = torch.prim.If %59 -> (!torch.int) {
%69 = torch.aten.__getitem__.t %53, %57 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %69 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%61 = torch.aten.ge.int %58, %int0 : !torch.int, !torch.int -> !torch.bool
%62 = torch.prim.If %61 -> (!torch.int) {
%69 = torch.aten.__getitem__.t %54, %58 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %69 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%63 = torch.aten.ne.int %60, %62 : !torch.int, !torch.int -> !torch.bool
%64 = torch.prim.If %63 -> (!torch.bool) {
%69 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %69 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%65 = torch.prim.If %64 -> (!torch.bool) {
%69 = torch.aten.ne.int %62, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %69 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %65 -> () {
%69 = torch.aten.format(%str_4, %60, %62, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%70 = torch.aten.add.str %str_5, %69 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %70, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%66 = torch.aten.eq.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.int) {
torch.prim.If.yield %62 : !torch.int
} else {
torch.prim.If.yield %60 : !torch.int
}
%68 = torch.aten.append.t %55, %67 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %55 : !torch.list<int>
} : !torch.vtensor<[],si64>
%18 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%55 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%56 = torch.aten.append.t %54, %55 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.shape.calculate {
%53 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[3],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<list<int>>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%67 = torch.aten.__getitem__.t %19, %arg6 : !torch.list<vtensor>, !torch.int -> !torch.vtensor
%68 = torch.aten.size %67 : !torch.vtensor -> !torch.list<int>
%69 = torch.aten.append.t %53, %68 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%54 = torch.aten.len.t %53 : !torch.list<list<int>> -> !torch.int
torch.prim.Loop %54, %true, init() {
^bb0(%arg6: !torch.int):
%67 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%68 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
%69 = torch.aten.gt.int %68, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %69 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten.len.t %53 : !torch.list<list<int>> -> !torch.int
%56 = torch.derefine %none : !torch.none to !torch.optional<int>
%57 = torch.prim.Loop %55, %true, init(%56) {
^bb0(%arg6: !torch.int, %arg7: !torch.optional<int>):
%67 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%68 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
%69 = torch.aten.eq.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.bool) {
%73 = torch.aten.__getitem__.t %67, %int0 : !torch.list<int>, !torch.int -> !torch.int
%74 = torch.aten.eq.int %73, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%71 = torch.aten.__not__ %70 : !torch.bool -> !torch.bool
%72 = torch.prim.If %71 -> (!torch.optional<int>) {
%73 = torch.aten.__is__ %arg7, %none : !torch.optional<int>, !torch.none -> !torch.bool
%74 = torch.prim.If %73 -> (!torch.int) {
%76 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
%77 = torch.aten.le.int %76, %int0 : !torch.int, !torch.int -> !torch.bool
%78 = torch.prim.If %77 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %76 : !torch.int
}
%79 = torch.aten.neg.int %78 : !torch.int -> !torch.int
%80 = torch.aten.sub.int %78, %int1 : !torch.int, !torch.int -> !torch.int
%81 = torch.aten.lt.int %int0, %79 : !torch.int, !torch.int -> !torch.bool
%82 = torch.prim.If %81 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%84 = torch.aten.gt.int %int0, %80 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %84 : !torch.bool
}
%83 = torch.aten.__not__ %82 : !torch.bool -> !torch.bool
torch.prim.If %83 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield %int0 : !torch.int
} else {
%76 = torch.prim.unchecked_cast %arg7 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %76 : !torch.int
}
%75 = torch.derefine %74 : !torch.int to !torch.optional<int>
torch.prim.If.yield %75 : !torch.optional<int>
} else {
torch.prim.If.yield %arg7 : !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%72 : !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>
%58 = torch.aten.__is__ %57, %none : !torch.optional<int>, !torch.none -> !torch.bool
%59 = torch.prim.If %58 -> (!torch.int) {
torch.prim.If.yield %int0 : !torch.int
} else {
%67 = torch.prim.unchecked_cast %57 : !torch.optional<int> -> !torch.int
torch.prim.If.yield %67 : !torch.int
}
%60 = torch.aten.len.t %53 : !torch.list<list<int>> -> !torch.int
%61 = torch.aten.gt.int %60, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %61 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%62 = torch.aten.len.t %53 : !torch.list<list<int>> -> !torch.int
%63 = torch.derefine %none : !torch.none to !torch.optional<list<int>>
%64 = torch.prim.Loop %62, %true, init(%63) {
^bb0(%arg6: !torch.int, %arg7: !torch.optional<list<int>>):
%67 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%68 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
%69 = torch.prim.Loop %68, %true, init(%int1) {
^bb0(%arg8: !torch.int, %arg9: !torch.int):
%74 = torch.aten.__getitem__.t %67, %arg8 : !torch.list<int>, !torch.int -> !torch.int
%75 = torch.aten.mul.int %arg9, %74 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%75 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%70 = torch.aten.eq.int %69, %int0 : !torch.int, !torch.int -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.bool) {
%74 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
%75 = torch.aten.eq.int %74, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %75 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%72 = torch.aten.__not__ %71 : !torch.bool -> !torch.bool
%73 = torch.prim.If %72 -> (!torch.optional<list<int>>) {
%74 = torch.derefine %67 : !torch.list<int> to !torch.optional<list<int>>
torch.prim.If.yield %74 : !torch.optional<list<int>>
} else {
torch.prim.If.yield %arg7 : !torch.optional<list<int>>
}
torch.prim.Loop.condition %true, iter(%73 : !torch.optional<list<int>>)
} : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>
%65 = torch.aten.__is__ %64, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.list<int>) {
%67 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
torch.prim.If.yield %67 : !torch.list<int>
} else {
%67 = torch.prim.unchecked_cast %64 : !torch.optional<list<int>> -> !torch.list<int>
%68 = torch.aten.len.t %53 : !torch.list<list<int>> -> !torch.int
%69 = torch.prim.Loop %68, %true, init(%int0) {
^bb0(%arg6: !torch.int, %arg7: !torch.int):
%73 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%74 = torch.aten.len.t %73 : !torch.list<int> -> !torch.int
%75 = torch.prim.Loop %74, %true, init(%int1) {
^bb0(%arg8: !torch.int, %arg9: !torch.int):
%80 = torch.aten.__getitem__.t %73, %arg8 : !torch.list<int>, !torch.int -> !torch.int
%81 = torch.aten.mul.int %arg9, %80 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%81 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%76 = torch.aten.eq.int %75, %int0 : !torch.int, !torch.int -> !torch.bool
%77 = torch.prim.If %76 -> (!torch.bool) {
%80 = torch.aten.len.t %73 : !torch.list<int> -> !torch.int
%81 = torch.aten.eq.int %80, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %81 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%78 = torch.aten.__not__ %77 : !torch.bool -> !torch.bool
%79 = torch.prim.If %78 -> (!torch.int) {
%80 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
%81 = torch.aten.len.t %73 : !torch.list<int> -> !torch.int
%82 = torch.aten.eq.int %80, %81 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %82 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%83 = torch.aten.__range_length %int0, %80, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %83, %true, init() {
^bb0(%arg8: !torch.int):
%86 = torch.aten.__derive_index %arg8, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%87 = torch.aten.ne.int %86, %59 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %87 -> () {
%88 = torch.aten.__getitem__.t %67, %86 : !torch.list<int>, !torch.int -> !torch.int
%89 = torch.aten.__getitem__.t %73, %86 : !torch.list<int>, !torch.int -> !torch.int
%90 = torch.aten.eq.int %88, %89 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %90 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none
torch.prim.If.yield
}
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%84 = torch.aten.__getitem__.t %73, %59 : !torch.list<int>, !torch.int -> !torch.int
%85 = torch.aten.add.int %arg7, %84 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %85 : !torch.int
} else {
torch.prim.If.yield %arg7 : !torch.int
}
torch.prim.Loop.condition %true, iter(%79 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%70 = torch.prim.ListConstruct : () -> !torch.list<int>
%71 = torch.aten.len.t %67 : !torch.list<int> -> !torch.int
torch.prim.Loop %71, %true, init() {
^bb0(%arg6: !torch.int):
%73 = torch.aten.__getitem__.t %67, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%74 = torch.aten.append.t %70, %73 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%72 = torch.aten._set_item.t %70, %59, %69 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield %70 : !torch.list<int>
}
torch.shape.calculate.yield.shapes %66 : !torch.list<int>
} : !torch.vtensor<[3],si64>
%21 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %24 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%27 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %25 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%28 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %22 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%29 = torch.shape.calculate {
%53 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
%55 = torch.prim.ListConstruct : () -> !torch.list<int>
%56 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.sub.int %int-1, %arg6 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.sub.int %int-1, %60 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.sub.int %int-1, %60 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.ge.int %61, %int0 : !torch.int, !torch.int -> !torch.bool
%64 = torch.prim.If %63 -> (!torch.int) {
%73 = torch.aten.__getitem__.t %54, %61 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%65 = torch.aten.ge.int %62, %int0 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
%73 = torch.aten.__getitem__.t %55, %62 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%67 = torch.aten.ne.int %64, %66 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.bool) {
%73 = torch.aten.ne.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%69 = torch.prim.If %68 -> (!torch.bool) {
%73 = torch.aten.ne.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %69 -> () {
%73 = torch.aten.format(%str_4, %64, %66, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%74 = torch.aten.add.str %str_5, %73 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %74, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%70 = torch.aten.eq.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.int) {
torch.prim.If.yield %66 : !torch.int
} else {
torch.prim.If.yield %64 : !torch.int
}
%72 = torch.aten.append.t %56, %71 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%57 = torch.aten.len.t %56 : !torch.list<int> -> !torch.int
%58 = torch.prim.max.int %int0, %57 : !torch.int, !torch.int -> !torch.int
%59 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %58, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.sub.int %58, %int1 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.sub.int %60, %arg6 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.sub.int %int-1, %61 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.sub.int %57, %int1 : !torch.int, !torch.int -> !torch.int
%64 = torch.aten.sub.int %63, %61 : !torch.int, !torch.int -> !torch.int
%65 = torch.aten.ge.int %62, %int0 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
%75 = torch.aten.__getitem__.t %53, %62 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%67 = torch.aten.ge.int %64, %int0 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.int) {
%75 = torch.aten.__getitem__.t %56, %64 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%69 = torch.aten.ne.int %66, %68 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.bool) {
%75 = torch.aten.ne.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %75 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%71 = torch.prim.If %70 -> (!torch.bool) {
%75 = torch.aten.ne.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %75 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %71 -> () {
%75 = torch.aten.format(%str_4, %66, %68, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%76 = torch.aten.add.str %str_5, %75 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %76, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%72 = torch.aten.eq.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
%73 = torch.prim.If %72 -> (!torch.int) {
torch.prim.If.yield %68 : !torch.int
} else {
torch.prim.If.yield %66 : !torch.int
}
%74 = torch.aten.append.t %59, %73 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %34 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%37 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %35 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%38 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %32 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%39 = torch.shape.calculate {
%53 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
%55 = torch.prim.ListConstruct : () -> !torch.list<int>
%56 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int0, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.sub.int %int-1, %arg6 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.sub.int %int-1, %60 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.sub.int %int-1, %60 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.ge.int %61, %int0 : !torch.int, !torch.int -> !torch.bool
%64 = torch.prim.If %63 -> (!torch.int) {
%73 = torch.aten.__getitem__.t %54, %61 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%65 = torch.aten.ge.int %62, %int0 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
%73 = torch.aten.__getitem__.t %55, %62 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %73 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%67 = torch.aten.ne.int %64, %66 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.bool) {
%73 = torch.aten.ne.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%69 = torch.prim.If %68 -> (!torch.bool) {
%73 = torch.aten.ne.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %69 -> () {
%73 = torch.aten.format(%str_4, %64, %66, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%74 = torch.aten.add.str %str_5, %73 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %74, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%70 = torch.aten.eq.int %64, %int1 : !torch.int, !torch.int -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.int) {
torch.prim.If.yield %66 : !torch.int
} else {
torch.prim.If.yield %64 : !torch.int
}
%72 = torch.aten.append.t %56, %71 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%57 = torch.aten.len.t %56 : !torch.list<int> -> !torch.int
%58 = torch.prim.max.int %int0, %57 : !torch.int, !torch.int -> !torch.int
%59 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %58, %true, init() {
^bb0(%arg6: !torch.int):
%60 = torch.aten.sub.int %58, %int1 : !torch.int, !torch.int -> !torch.int
%61 = torch.aten.sub.int %60, %arg6 : !torch.int, !torch.int -> !torch.int
%62 = torch.aten.sub.int %int-1, %61 : !torch.int, !torch.int -> !torch.int
%63 = torch.aten.sub.int %57, %int1 : !torch.int, !torch.int -> !torch.int
%64 = torch.aten.sub.int %63, %61 : !torch.int, !torch.int -> !torch.int
%65 = torch.aten.ge.int %62, %int0 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
%75 = torch.aten.__getitem__.t %53, %62 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%67 = torch.aten.ge.int %64, %int0 : !torch.int, !torch.int -> !torch.bool
%68 = torch.prim.If %67 -> (!torch.int) {
%75 = torch.aten.__getitem__.t %56, %64 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %75 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%69 = torch.aten.ne.int %66, %68 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.bool) {
%75 = torch.aten.ne.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %75 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%71 = torch.prim.If %70 -> (!torch.bool) {
%75 = torch.aten.ne.int %68, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %75 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %71 -> () {
%75 = torch.aten.format(%str_4, %66, %68, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%76 = torch.aten.add.str %str_5, %75 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %76, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%72 = torch.aten.eq.int %66, %int1 : !torch.int, !torch.int -> !torch.bool
%73 = torch.prim.If %72 -> (!torch.int) {
torch.prim.If.yield %68 : !torch.int
} else {
torch.prim.If.yield %66 : !torch.int
}
%74 = torch.aten.append.t %59, %73 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %59 : !torch.list<int>
} : !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int1, %true, init() {
^bb0(%arg6: !torch.int):
%56 = torch.aten.__getitem__.t %53, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%57 = torch.aten.append.t %54, %56 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%55 = torch.aten._set_item.t %54, %int0, %int1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.shape.calculate {
%53 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?],si64>
} shapes {
%53 = torch.prim.Loop %int2, %true, init(%int1) {
^bb0(%arg6: !torch.int, %arg7: !torch.int):
%62 = torch.aten.size.int %14, %arg6 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%63 = torch.aten.mul.int %arg7, %62 : !torch.int, !torch.int -> !torch.int
torch.prim.Loop.condition %true, iter(%63 : !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> !torch.int
%54 = torch.prim.Uninitialized : !torch.int
%55 = torch.derefine %none : !torch.none to !torch.optional<int>
%56:2 = torch.prim.Loop %int3, %true, init(%int1, %55) {
^bb0(%arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.optional<int>):
%62 = torch.aten.__getitem__.t %45, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.eq.int %62, %int-1 : !torch.int, !torch.int -> !torch.bool
%64:2 = torch.prim.If %63 -> (!torch.int, !torch.optional<int>) {
%65 = torch.aten.__isnot__ %arg8, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %65 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%66 = torch.derefine %arg6 : !torch.int to !torch.optional<int>
torch.prim.If.yield %arg7, %66 : !torch.int, !torch.optional<int>
} else {
%65 = torch.aten.__getitem__.t %45, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%66 = torch.aten.ge.int %65, %int0 : !torch.int, !torch.int -> !torch.bool
%67 = torch.prim.If %66 -> (!torch.int) {
%68 = torch.aten.__getitem__.t %45, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%69 = torch.aten.mul.int %arg7, %68 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %69 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %54 : !torch.int
}
torch.prim.If.yield %67, %arg8 : !torch.int, !torch.optional<int>
}
torch.prim.Loop.condition %true, iter(%64#0, %64#1 : !torch.int, !torch.optional<int>)
} : (!torch.int, !torch.bool, !torch.int, !torch.optional<int>) -> (!torch.int, !torch.optional<int>)
%57 = torch.aten.eq.int %53, %56#0 : !torch.int, !torch.int -> !torch.bool
%58 = torch.prim.If %57 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%62 = torch.aten.__isnot__ %56#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.bool) {
%65 = torch.prim.unchecked_cast %56#1 : !torch.optional<int> -> !torch.int
%66 = torch.aten.gt.int %56#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %66 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%64 = torch.prim.If %63 -> (!torch.bool) {
%65 = torch.prim.unchecked_cast %56#1 : !torch.optional<int> -> !torch.int
%66 = torch.aten.remainder.int %53, %56#0 : !torch.int, !torch.int -> !torch.int
%67 = torch.aten.eq.int %66, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %67 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %64 : !torch.bool
}
%59 = torch.aten.__not__ %58 : !torch.bool -> !torch.bool
torch.prim.If %59 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%60 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%62 = torch.aten.__getitem__.t %45, %arg6 : !torch.list<int>, !torch.int -> !torch.int
%63 = torch.aten.append.t %60, %62 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%61 = torch.aten.__isnot__ %56#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %61 -> () {
%62 = torch.prim.unchecked_cast %56#1 : !torch.optional<int> -> !torch.int
%63 = torch.aten.floordiv.int %53, %56#0 : !torch.int, !torch.int -> !torch.int
%64 = torch.aten._set_item.t %60, %62, %63 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %60 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
%47 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%54 = torch.aten.size.int %46, %arg6 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.append.t %53, %54 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %53, %int3, %int1 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
%48 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int3, %true, init() {
^bb0(%arg6: !torch.int):
%54 = torch.aten.size.int %arg3, %arg6 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.append.t %53, %54 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.aten.insert.t %53, %int2, %int1 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
%49 = torch.shape.calculate {
%53 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%54 = torch.aten.size.int %47, %arg6 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%55 = torch.aten.append.t %53, %54 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%50 = torch.shape.calculate {
%53 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%54 = torch.aten.size.int %48, %arg6 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.append.t %53, %54 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
%51 = torch.shape.calculate {
%53 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%54 = torch.aten.sub.int %int3, %arg6 : !torch.int, !torch.int -> !torch.int
%55 = torch.aten.sub.int %int3, %54 : !torch.int, !torch.int -> !torch.int
%56 = torch.aten.sub.int %int3, %54 : !torch.int, !torch.int -> !torch.int
%57 = torch.aten.ge.int %55, %int0 : !torch.int, !torch.int -> !torch.bool
%58 = torch.prim.If %57 -> (!torch.int) {
%67 = torch.aten.size.int %49, %55 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
torch.prim.If.yield %67 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%59 = torch.aten.ge.int %56, %int0 : !torch.int, !torch.int -> !torch.bool
%60 = torch.prim.If %59 -> (!torch.int) {
%67 = torch.aten.size.int %50, %56 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
torch.prim.If.yield %67 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%61 = torch.aten.ne.int %58, %60 : !torch.int, !torch.int -> !torch.bool
%62 = torch.prim.If %61 -> (!torch.bool) {
%67 = torch.aten.ne.int %58, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %67 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%63 = torch.prim.If %62 -> (!torch.bool) {
%67 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %67 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %63 -> () {
%67 = torch.aten.format(%str_4, %58, %60, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%68 = torch.aten.add.str %str_5, %67 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %68, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%64 = torch.aten.eq.int %58, %int1 : !torch.int, !torch.int -> !torch.bool
%65 = torch.prim.If %64 -> (!torch.int) {
torch.prim.If.yield %60 : !torch.int
} else {
torch.prim.If.yield %58 : !torch.int
}
%66 = torch.aten.append.t %53, %65 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
%52 = torch.shape.calculate {
%53 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%53 = torch.prim.ListConstruct %int1, %int1, %int128, %int384 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %int4, %true, init() {
^bb0(%arg6: !torch.int):
%55 = torch.aten.sub.int %int3, %arg6 : !torch.int, !torch.int -> !torch.int
%56 = torch.aten.sub.int %int3, %55 : !torch.int, !torch.int -> !torch.int
%57 = torch.aten.sub.int %int3, %55 : !torch.int, !torch.int -> !torch.int
%58 = torch.aten.ge.int %56, %int0 : !torch.int, !torch.int -> !torch.bool
%59 = torch.prim.If %58 -> (!torch.int) {
%68 = torch.aten.size.int %51, %56 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
torch.prim.If.yield %68 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%60 = torch.aten.ge.int %57, %int0 : !torch.int, !torch.int -> !torch.bool
%61 = torch.prim.If %60 -> (!torch.int) {
%68 = torch.aten.__getitem__.t %53, %57 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %68 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%62 = torch.aten.ne.int %59, %61 : !torch.int, !torch.int -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.bool) {
%68 = torch.aten.ne.int %59, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %68 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%64 = torch.prim.If %63 -> (!torch.bool) {
%68 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %68 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %64 -> () {
%68 = torch.aten.format(%str_4, %59, %61, %arg6) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%69 = torch.aten.add.str %str_5, %68 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %69, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%65 = torch.aten.eq.int %59, %int1 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
torch.prim.If.yield %61 : !torch.int
} else {
torch.prim.If.yield %59 : !torch.int
}
%67 = torch.aten.append.t %54, %66 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
}
// -----// IR Dump After SimplifyShapeCalculations (torch-simplify-shape-calculations) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%int-2 = torch.constant.int -2
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_3 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.shape.calculate {
%53 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?],si64>
} shapes {
%53 = torch.aten.size.int %arg1, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %arg1, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.add.int %8, %12 : !torch.int, !torch.int -> !torch.int
%56 = torch.aten.size.int %arg1, %int-1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%57 = torch.aten.add.int %56, %55 : !torch.int, !torch.int -> !torch.int
%58 = torch.aten.add.int %6, %10 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.size.int %arg1, %int-2 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%60 = torch.aten.add.int %59, %58 : !torch.int, !torch.int -> !torch.int
%61 = torch.prim.ListConstruct %60, %57 : (!torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %61 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
%15 = torch.shape.calculate {
%53 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%16 = torch.shape.calculate {
%53 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[],si64>
%17 = torch.shape.calculate {
%53 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[],si64>
%18 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.shape.calculate {
%53 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[3],si64>
} shapes {
%53 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[3],si64>
%21 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %24 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%27 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %25 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%28 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %22 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%29 = torch.shape.calculate {
%53 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %34 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%37 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %35 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%38 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %32 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%39 = torch.shape.calculate {
%53 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.shape.calculate {
%53 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?],si64>
} shapes {
%53 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.mul.int %int1, %53 : !torch.int, !torch.int -> !torch.int
%55 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%56 = torch.aten.mul.int %54, %55 : !torch.int, !torch.int -> !torch.int
%57 = torch.prim.Uninitialized : !torch.int
%58 = torch.derefine %none : !torch.none to !torch.optional<int>
%59 = torch.aten.eq.int %30, %int-1 : !torch.int, !torch.int -> !torch.bool
%60:2 = torch.prim.If %59 -> (!torch.int, !torch.optional<int>) {
%70 = torch.derefine %int0 : !torch.int to !torch.optional<int>
torch.prim.If.yield %int1, %70 : !torch.int, !torch.optional<int>
} else {
%70 = torch.aten.ge.int %30, %int0 : !torch.int, !torch.int -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.int) {
%72 = torch.aten.mul.int %int1, %30 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %72 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %57 : !torch.int
}
torch.prim.If.yield %71, %58 : !torch.int, !torch.optional<int>
}
%61 = torch.aten.eq.int %40, %int-1 : !torch.int, !torch.int -> !torch.bool
%62:2 = torch.prim.If %61 -> (!torch.int, !torch.optional<int>) {
%70 = torch.aten.__isnot__ %60#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %70 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%71 = torch.derefine %int1 : !torch.int to !torch.optional<int>
torch.prim.If.yield %60#0, %71 : !torch.int, !torch.optional<int>
} else {
%70 = torch.aten.ge.int %40, %int0 : !torch.int, !torch.int -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.int) {
%72 = torch.aten.mul.int %60#0, %40 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %72 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %57 : !torch.int
}
torch.prim.If.yield %71, %60#1 : !torch.int, !torch.optional<int>
}
%63 = torch.aten.eq.int %42, %int-1 : !torch.int, !torch.int -> !torch.bool
%64:2 = torch.prim.If %63 -> (!torch.int, !torch.optional<int>) {
%70 = torch.aten.__isnot__ %62#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %70 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%71 = torch.derefine %int2 : !torch.int to !torch.optional<int>
torch.prim.If.yield %62#0, %71 : !torch.int, !torch.optional<int>
} else {
%70 = torch.aten.ge.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.int) {
%72 = torch.aten.mul.int %62#0, %42 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %72 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %57 : !torch.int
}
torch.prim.If.yield %71, %62#1 : !torch.int, !torch.optional<int>
}
%65 = torch.aten.eq.int %56, %64#0 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%70 = torch.aten.__isnot__ %64#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%71 = torch.prim.If %70 -> (!torch.bool) {
%73 = torch.prim.unchecked_cast %64#1 : !torch.optional<int> -> !torch.int
%74 = torch.aten.gt.int %64#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%72 = torch.prim.If %71 -> (!torch.bool) {
%73 = torch.prim.unchecked_cast %64#1 : !torch.optional<int> -> !torch.int
%74 = torch.aten.remainder.int %56, %64#0 : !torch.int, !torch.int -> !torch.int
%75 = torch.aten.eq.int %74, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %75 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %72 : !torch.bool
}
%67 = torch.aten.__not__ %66 : !torch.bool -> !torch.bool
torch.prim.If %67 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%68 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%69 = torch.aten.__isnot__ %64#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %69 -> () {
%70 = torch.prim.unchecked_cast %64#1 : !torch.optional<int> -> !torch.int
%71 = torch.aten.floordiv.int %56, %64#0 : !torch.int, !torch.int -> !torch.int
%72 = torch.aten._set_item.t %68, %70, %71 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %68 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
%47 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%53 = torch.aten.size.int %46, %int0 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %46, %int1 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.size.int %46, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%56 = torch.prim.ListConstruct %53, %54, %55, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
%48 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%53 = torch.aten.size.int %arg3, %int0 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %arg3, %int1 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.size.int %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%56 = torch.prim.ListConstruct %53, %54, %int1, %55 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
%49 = torch.shape.calculate {
%53 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%53 = torch.aten.size.int %47, %int0 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %47, %int1 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%55 = torch.aten.size.int %47, %int2 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%56 = torch.prim.ListConstruct %53, %54, %55, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%50 = torch.shape.calculate {
%53 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%53 = torch.aten.size.int %48, %int0 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %48, %int1 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.size.int %48, %int3 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%56 = torch.prim.ListConstruct %53, %54, %int1, %55 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
%51 = torch.shape.calculate {
%53 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%53 = torch.aten.size.int %49, %int0 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%54 = torch.aten.size.int %50, %int0 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%55 = torch.aten.ne.int %53, %54 : !torch.int, !torch.int -> !torch.bool
%56 = torch.prim.If %55 -> (!torch.bool) {
%74 = torch.aten.ne.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%57 = torch.prim.If %56 -> (!torch.bool) {
%74 = torch.aten.ne.int %54, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %57 -> () {
%74 = torch.aten.format(%str_2, %53, %54, %int0) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%75 = torch.aten.add.str %str_3, %74 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %75, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%58 = torch.aten.eq.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
%59 = torch.prim.If %58 -> (!torch.int) {
torch.prim.If.yield %54 : !torch.int
} else {
torch.prim.If.yield %53 : !torch.int
}
%60 = torch.aten.size.int %49, %int1 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%61 = torch.aten.size.int %50, %int1 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%62 = torch.aten.ne.int %60, %61 : !torch.int, !torch.int -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.bool) {
%74 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%64 = torch.prim.If %63 -> (!torch.bool) {
%74 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %64 -> () {
%74 = torch.aten.format(%str_2, %60, %61, %int1) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%75 = torch.aten.add.str %str_3, %74 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %75, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%65 = torch.aten.eq.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
torch.prim.If.yield %61 : !torch.int
} else {
torch.prim.If.yield %60 : !torch.int
}
%67 = torch.aten.size.int %49, %int2 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%68 = torch.aten.ne.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %68 -> () {
%74 = torch.aten.ne.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%69 = torch.aten.eq.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %67 : !torch.int
}
%71 = torch.aten.size.int %50, %int3 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%72 = torch.aten.ne.int %int1, %71 : !torch.int, !torch.int -> !torch.bool
%73 = torch.prim.ListConstruct %59, %66, %70, %71 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %73 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
%52 = torch.shape.calculate {
%53 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%53 = torch.aten.size.int %51, %int0 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%54 = torch.aten.ne.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %54 -> () {
%74 = torch.aten.ne.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%55 = torch.aten.eq.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
%56 = torch.prim.If %55 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %53 : !torch.int
}
%57 = torch.aten.size.int %51, %int1 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%58 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %58 -> () {
%74 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%59 = torch.aten.eq.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
%60 = torch.prim.If %59 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %57 : !torch.int
}
%61 = torch.aten.size.int %51, %int2 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%62 = torch.aten.ne.int %61, %int128 : !torch.int, !torch.int -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.bool) {
%74 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%64 = torch.prim.If %63 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %64 -> () {
%74 = torch.aten.format(%str_2, %61, %int128, %int2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%75 = torch.aten.add.str %str_3, %74 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %75, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%65 = torch.aten.eq.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
torch.prim.If.yield %int128 : !torch.int
} else {
torch.prim.If.yield %61 : !torch.int
}
%67 = torch.aten.size.int %51, %int3 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%68 = torch.aten.ne.int %67, %int384 : !torch.int, !torch.int -> !torch.bool
%69 = torch.prim.If %68 -> (!torch.bool) {
%74 = torch.aten.ne.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%70 = torch.prim.If %69 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %70 -> () {
%74 = torch.aten.format(%str_2, %67, %int384, %int3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%75 = torch.aten.add.str %str_3, %74 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %75, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%71 = torch.aten.eq.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
%72 = torch.prim.If %71 -> (!torch.int) {
torch.prim.If.yield %int384 : !torch.int
} else {
torch.prim.If.yield %67 : !torch.int
}
%73 = torch.prim.ListConstruct %56, %60, %66, %72 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %73 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After CSE (cse) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%int-2 = torch.constant.int -2
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_3 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.shape.calculate {
%53 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?],si64>
} shapes {
%53 = torch.aten.size.int %arg1, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %arg1, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.add.int %8, %12 : !torch.int, !torch.int -> !torch.int
%56 = torch.aten.size.int %arg1, %int-1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%57 = torch.aten.add.int %56, %55 : !torch.int, !torch.int -> !torch.int
%58 = torch.aten.add.int %6, %10 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.size.int %arg1, %int-2 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%60 = torch.aten.add.int %59, %58 : !torch.int, !torch.int -> !torch.int
%61 = torch.prim.ListConstruct %60, %57 : (!torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %61 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
%15 = torch.shape.calculate {
%53 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%16 = torch.shape.calculate {
%53 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[],si64>
%17 = torch.shape.calculate {
%53 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[],si64>
%18 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.shape.calculate {
%53 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[3],si64>
} shapes {
%53 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[3],si64>
%21 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %24 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%27 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %25 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%28 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %22 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%29 = torch.shape.calculate {
%53 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %34 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%37 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %35 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%38 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %32 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%39 = torch.shape.calculate {
%53 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.shape.calculate {
%53 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?],si64>
} shapes {
%53 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.mul.int %int1, %53 : !torch.int, !torch.int -> !torch.int
%55 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%56 = torch.aten.mul.int %54, %55 : !torch.int, !torch.int -> !torch.int
%57 = torch.prim.Uninitialized : !torch.int
%58 = torch.derefine %none : !torch.none to !torch.optional<int>
%59 = torch.aten.eq.int %30, %int-1 : !torch.int, !torch.int -> !torch.bool
%60:2 = torch.prim.If %59 -> (!torch.int, !torch.optional<int>) {
%69 = torch.derefine %int0 : !torch.int to !torch.optional<int>
torch.prim.If.yield %int1, %69 : !torch.int, !torch.optional<int>
} else {
%69 = torch.aten.ge.int %30, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%71 = torch.aten.mul.int %int1, %30 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %71 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %57 : !torch.int
}
torch.prim.If.yield %70, %58 : !torch.int, !torch.optional<int>
}
%61 = torch.aten.eq.int %40, %int-1 : !torch.int, !torch.int -> !torch.bool
%62:2 = torch.prim.If %61 -> (!torch.int, !torch.optional<int>) {
%69 = torch.aten.__isnot__ %60#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %69 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%70 = torch.derefine %int1 : !torch.int to !torch.optional<int>
torch.prim.If.yield %60#0, %70 : !torch.int, !torch.optional<int>
} else {
%69 = torch.aten.ge.int %40, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%71 = torch.aten.mul.int %60#0, %40 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %71 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %57 : !torch.int
}
torch.prim.If.yield %70, %60#1 : !torch.int, !torch.optional<int>
}
%63 = torch.aten.eq.int %42, %int-1 : !torch.int, !torch.int -> !torch.bool
%64:2 = torch.prim.If %63 -> (!torch.int, !torch.optional<int>) {
%69 = torch.aten.__isnot__ %62#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %69 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%70 = torch.derefine %int2 : !torch.int to !torch.optional<int>
torch.prim.If.yield %62#0, %70 : !torch.int, !torch.optional<int>
} else {
%69 = torch.aten.ge.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%71 = torch.aten.mul.int %62#0, %42 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %71 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %57 : !torch.int
}
torch.prim.If.yield %70, %62#1 : !torch.int, !torch.optional<int>
}
%65 = torch.aten.eq.int %56, %64#0 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%69 = torch.aten.__isnot__ %64#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.bool) {
%72 = torch.prim.unchecked_cast %64#1 : !torch.optional<int> -> !torch.int
%73 = torch.aten.gt.int %64#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%71 = torch.prim.If %70 -> (!torch.bool) {
%72 = torch.prim.unchecked_cast %64#1 : !torch.optional<int> -> !torch.int
%73 = torch.aten.remainder.int %56, %64#0 : !torch.int, !torch.int -> !torch.int
%74 = torch.aten.eq.int %73, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %71 : !torch.bool
}
%67 = torch.aten.__not__ %66 : !torch.bool -> !torch.bool
torch.prim.If %67 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%68 = torch.aten.__isnot__ %64#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %68 -> () {
%69 = torch.prim.unchecked_cast %64#1 : !torch.optional<int> -> !torch.int
%70 = torch.aten.floordiv.int %56, %64#0 : !torch.int, !torch.int -> !torch.int
%71 = torch.aten._set_item.t %45, %69, %70 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %45 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
%47 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%53 = torch.aten.size.int %46, %int0 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %46, %int1 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.size.int %46, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%56 = torch.prim.ListConstruct %53, %54, %55, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
%48 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%53 = torch.aten.size.int %arg3, %int0 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %arg3, %int1 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.size.int %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%56 = torch.prim.ListConstruct %53, %54, %int1, %55 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
%49 = torch.shape.calculate {
%53 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%53 = torch.aten.size.int %47, %int0 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %47, %int1 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%55 = torch.aten.size.int %47, %int2 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%56 = torch.prim.ListConstruct %53, %54, %55, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%50 = torch.shape.calculate {
%53 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%53 = torch.aten.size.int %48, %int0 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %48, %int1 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.size.int %48, %int3 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%56 = torch.prim.ListConstruct %53, %54, %int1, %55 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
%51 = torch.shape.calculate {
%53 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%53 = torch.aten.size.int %49, %int0 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%54 = torch.aten.size.int %50, %int0 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%55 = torch.aten.ne.int %53, %54 : !torch.int, !torch.int -> !torch.bool
%56 = torch.prim.If %55 -> (!torch.bool) {
%74 = torch.aten.ne.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%57 = torch.prim.If %56 -> (!torch.bool) {
%74 = torch.aten.ne.int %54, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %57 -> () {
%74 = torch.aten.format(%str_2, %53, %54, %int0) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%75 = torch.aten.add.str %str_3, %74 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %75, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%58 = torch.aten.eq.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
%59 = torch.prim.If %58 -> (!torch.int) {
torch.prim.If.yield %54 : !torch.int
} else {
torch.prim.If.yield %53 : !torch.int
}
%60 = torch.aten.size.int %49, %int1 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%61 = torch.aten.size.int %50, %int1 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%62 = torch.aten.ne.int %60, %61 : !torch.int, !torch.int -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.bool) {
%74 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%64 = torch.prim.If %63 -> (!torch.bool) {
%74 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %64 -> () {
%74 = torch.aten.format(%str_2, %60, %61, %int1) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%75 = torch.aten.add.str %str_3, %74 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %75, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%65 = torch.aten.eq.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
torch.prim.If.yield %61 : !torch.int
} else {
torch.prim.If.yield %60 : !torch.int
}
%67 = torch.aten.size.int %49, %int2 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%68 = torch.aten.ne.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %68 -> () {
%74 = torch.aten.ne.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%69 = torch.aten.eq.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %67 : !torch.int
}
%71 = torch.aten.size.int %50, %int3 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%72 = torch.aten.ne.int %int1, %71 : !torch.int, !torch.int -> !torch.bool
%73 = torch.prim.ListConstruct %59, %66, %70, %71 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %73 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
%52 = torch.shape.calculate {
%53 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%53 = torch.aten.size.int %51, %int0 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%54 = torch.aten.ne.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %54 -> () {
%74 = torch.aten.ne.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%55 = torch.aten.eq.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
%56 = torch.prim.If %55 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %53 : !torch.int
}
%57 = torch.aten.size.int %51, %int1 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%58 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %58 -> () {
%74 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%59 = torch.aten.eq.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
%60 = torch.prim.If %59 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %57 : !torch.int
}
%61 = torch.aten.size.int %51, %int2 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%62 = torch.aten.ne.int %61, %int128 : !torch.int, !torch.int -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.bool) {
%74 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%64 = torch.prim.If %63 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %64 -> () {
%74 = torch.aten.format(%str_2, %61, %int128, %int2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%75 = torch.aten.add.str %str_3, %74 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %75, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%65 = torch.aten.eq.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
torch.prim.If.yield %int128 : !torch.int
} else {
torch.prim.If.yield %61 : !torch.int
}
%67 = torch.aten.size.int %51, %int3 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%68 = torch.aten.ne.int %67, %int384 : !torch.int, !torch.int -> !torch.bool
%69 = torch.prim.If %68 -> (!torch.bool) {
%74 = torch.aten.ne.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%70 = torch.prim.If %69 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %70 -> () {
%74 = torch.aten.format(%str_2, %67, %int384, %int3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%75 = torch.aten.add.str %str_3, %74 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %75, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%71 = torch.aten.eq.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
%72 = torch.prim.If %71 -> (!torch.int) {
torch.prim.If.yield %int384 : !torch.int
} else {
torch.prim.If.yield %67 : !torch.int
}
%73 = torch.prim.ListConstruct %56, %60, %66, %72 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %73 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After SimplifyShapeCalculations (torch-simplify-shape-calculations) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%int-2 = torch.constant.int -2
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_3 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.shape.calculate {
%53 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?],si64>
} shapes {
%53 = torch.aten.size.int %arg1, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %arg1, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.add.int %8, %12 : !torch.int, !torch.int -> !torch.int
%56 = torch.aten.size.int %arg1, %int-1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%57 = torch.aten.add.int %56, %55 : !torch.int, !torch.int -> !torch.int
%58 = torch.aten.add.int %6, %10 : !torch.int, !torch.int -> !torch.int
%59 = torch.aten.size.int %arg1, %int-2 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%60 = torch.aten.add.int %59, %58 : !torch.int, !torch.int -> !torch.int
%61 = torch.prim.ListConstruct %60, %57 : (!torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %61 : !torch.list<int>
} : !torch.vtensor<[?,?],si64>
%15 = torch.shape.calculate {
%53 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%16 = torch.shape.calculate {
%53 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[],si64>
%17 = torch.shape.calculate {
%53 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[],si64>
%18 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.shape.calculate {
%53 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[3],si64>
} shapes {
%53 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[3],si64>
%21 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %24 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%27 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %25 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%28 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %22 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%29 = torch.shape.calculate {
%53 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[],i1>
} shapes {
%53 = torch.aten.Float.Scalar %34 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],i1>
%37 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %35 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%38 = torch.shape.calculate {
%53 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.aten.Float.Scalar %32 : !torch.int -> !torch.float
%54 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %54 : !torch.list<int>
} : !torch.vtensor<[],si64>
%39 = torch.shape.calculate {
%53 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[],si64>
} shapes {
%53 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.shape.calculate {
%53 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[1],si64>
} shapes {
%53 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %53 : !torch.list<int>
} : !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.shape.calculate {
%53 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?],si64>
} shapes {
%53 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.mul.int %int1, %53 : !torch.int, !torch.int -> !torch.int
%55 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%56 = torch.aten.mul.int %54, %55 : !torch.int, !torch.int -> !torch.int
%57 = torch.prim.Uninitialized : !torch.int
%58 = torch.derefine %none : !torch.none to !torch.optional<int>
%59 = torch.aten.eq.int %30, %int-1 : !torch.int, !torch.int -> !torch.bool
%60:2 = torch.prim.If %59 -> (!torch.int, !torch.optional<int>) {
%69 = torch.derefine %int0 : !torch.int to !torch.optional<int>
torch.prim.If.yield %int1, %69 : !torch.int, !torch.optional<int>
} else {
%69 = torch.aten.ge.int %30, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%71 = torch.aten.mul.int %int1, %30 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %71 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %57 : !torch.int
}
torch.prim.If.yield %70, %58 : !torch.int, !torch.optional<int>
}
%61 = torch.aten.eq.int %40, %int-1 : !torch.int, !torch.int -> !torch.bool
%62:2 = torch.prim.If %61 -> (!torch.int, !torch.optional<int>) {
%69 = torch.aten.__isnot__ %60#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %69 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%70 = torch.derefine %int1 : !torch.int to !torch.optional<int>
torch.prim.If.yield %60#0, %70 : !torch.int, !torch.optional<int>
} else {
%69 = torch.aten.ge.int %40, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%71 = torch.aten.mul.int %60#0, %40 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %71 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %57 : !torch.int
}
torch.prim.If.yield %70, %60#1 : !torch.int, !torch.optional<int>
}
%63 = torch.aten.eq.int %42, %int-1 : !torch.int, !torch.int -> !torch.bool
%64:2 = torch.prim.If %63 -> (!torch.int, !torch.optional<int>) {
%69 = torch.aten.__isnot__ %62#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %69 -> () {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%70 = torch.derefine %int2 : !torch.int to !torch.optional<int>
torch.prim.If.yield %62#0, %70 : !torch.int, !torch.optional<int>
} else {
%69 = torch.aten.ge.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
%71 = torch.aten.mul.int %62#0, %42 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %71 : !torch.int
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield %57 : !torch.int
}
torch.prim.If.yield %70, %62#1 : !torch.int, !torch.optional<int>
}
%65 = torch.aten.eq.int %56, %64#0 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%69 = torch.aten.__isnot__ %64#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.bool) {
%72 = torch.prim.unchecked_cast %64#1 : !torch.optional<int> -> !torch.int
%73 = torch.aten.gt.int %64#0, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %73 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%71 = torch.prim.If %70 -> (!torch.bool) {
%72 = torch.prim.unchecked_cast %64#1 : !torch.optional<int> -> !torch.int
%73 = torch.aten.remainder.int %56, %64#0 : !torch.int, !torch.int -> !torch.int
%74 = torch.aten.eq.int %73, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If.yield %71 : !torch.bool
}
%67 = torch.aten.__not__ %66 : !torch.bool -> !torch.bool
torch.prim.If %67 -> () {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%68 = torch.aten.__isnot__ %64#1, %none : !torch.optional<int>, !torch.none -> !torch.bool
torch.prim.If %68 -> () {
%69 = torch.prim.unchecked_cast %64#1 : !torch.optional<int> -> !torch.int
%70 = torch.aten.floordiv.int %56, %64#0 : !torch.int, !torch.int -> !torch.int
%71 = torch.aten._set_item.t %45, %69, %70 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %45 : !torch.list<int>
} : !torch.vtensor<[?,?,?],si64>
%47 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],si64>
} shapes {
%53 = torch.aten.size.int %46, %int0 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %46, %int1 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.size.int %46, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%56 = torch.prim.ListConstruct %53, %54, %55, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],si64>
%48 = torch.shape.calculate {
%53 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],si64>
} shapes {
%53 = torch.aten.size.int %arg3, %int0 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %arg3, %int1 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.size.int %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.int
%56 = torch.prim.ListConstruct %53, %54, %int1, %55 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],si64>
%49 = torch.shape.calculate {
%53 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,1],i1>
} shapes {
%53 = torch.aten.size.int %47, %int0 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %47, %int1 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%55 = torch.aten.size.int %47, %int2 : !torch.vtensor<[?,?,?,1],si64>, !torch.int -> !torch.int
%56 = torch.prim.ListConstruct %53, %54, %55, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[?,?,?,1],i1>
%50 = torch.shape.calculate {
%53 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,1,?],i1>
} shapes {
%53 = torch.aten.size.int %48, %int0 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%54 = torch.aten.size.int %48, %int1 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%55 = torch.aten.size.int %48, %int3 : !torch.vtensor<[?,?,1,?],si64>, !torch.int -> !torch.int
%56 = torch.prim.ListConstruct %53, %54, %int1, %55 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %56 : !torch.list<int>
} : !torch.vtensor<[?,?,1,?],i1>
%51 = torch.shape.calculate {
%53 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,?,?],i1>
} shapes {
%53 = torch.aten.size.int %49, %int0 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%54 = torch.aten.size.int %50, %int0 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%55 = torch.aten.ne.int %53, %54 : !torch.int, !torch.int -> !torch.bool
%56 = torch.prim.If %55 -> (!torch.bool) {
%74 = torch.aten.ne.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%57 = torch.prim.If %56 -> (!torch.bool) {
%74 = torch.aten.ne.int %54, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %57 -> () {
%74 = torch.aten.format(%str_2, %53, %54, %int0) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%75 = torch.aten.add.str %str_3, %74 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %75, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%58 = torch.aten.eq.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
%59 = torch.prim.If %58 -> (!torch.int) {
torch.prim.If.yield %54 : !torch.int
} else {
torch.prim.If.yield %53 : !torch.int
}
%60 = torch.aten.size.int %49, %int1 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%61 = torch.aten.size.int %50, %int1 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%62 = torch.aten.ne.int %60, %61 : !torch.int, !torch.int -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.bool) {
%74 = torch.aten.ne.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%64 = torch.prim.If %63 -> (!torch.bool) {
%74 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %64 -> () {
%74 = torch.aten.format(%str_2, %60, %61, %int1) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%75 = torch.aten.add.str %str_3, %74 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %75, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%65 = torch.aten.eq.int %60, %int1 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
torch.prim.If.yield %61 : !torch.int
} else {
torch.prim.If.yield %60 : !torch.int
}
%67 = torch.aten.size.int %49, %int2 : !torch.vtensor<[?,?,?,1],i1>, !torch.int -> !torch.int
%68 = torch.aten.ne.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %68 -> () {
%74 = torch.aten.ne.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%69 = torch.aten.eq.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
%70 = torch.prim.If %69 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %67 : !torch.int
}
%71 = torch.aten.size.int %50, %int3 : !torch.vtensor<[?,?,1,?],i1>, !torch.int -> !torch.int
%72 = torch.aten.ne.int %int1, %71 : !torch.int, !torch.int -> !torch.bool
%73 = torch.prim.ListConstruct %59, %66, %70, %71 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %73 : !torch.list<int>
} : !torch.vtensor<[?,?,?,?],i1>
%52 = torch.shape.calculate {
%53 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
torch.shape.calculate.yield %53 : !torch.vtensor<[?,?,128,384],i1>
} shapes {
%53 = torch.aten.size.int %51, %int0 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%54 = torch.aten.ne.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %54 -> () {
%74 = torch.aten.ne.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%55 = torch.aten.eq.int %53, %int1 : !torch.int, !torch.int -> !torch.bool
%56 = torch.prim.If %55 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %53 : !torch.int
}
%57 = torch.aten.size.int %51, %int1 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%58 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %58 -> () {
%74 = torch.aten.ne.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%59 = torch.aten.eq.int %57, %int1 : !torch.int, !torch.int -> !torch.bool
%60 = torch.prim.If %59 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %57 : !torch.int
}
%61 = torch.aten.size.int %51, %int2 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%62 = torch.aten.ne.int %61, %int128 : !torch.int, !torch.int -> !torch.bool
%63 = torch.prim.If %62 -> (!torch.bool) {
%74 = torch.aten.ne.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%64 = torch.prim.If %63 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %64 -> () {
%74 = torch.aten.format(%str_2, %61, %int128, %int2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%75 = torch.aten.add.str %str_3, %74 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %75, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%65 = torch.aten.eq.int %61, %int1 : !torch.int, !torch.int -> !torch.bool
%66 = torch.prim.If %65 -> (!torch.int) {
torch.prim.If.yield %int128 : !torch.int
} else {
torch.prim.If.yield %61 : !torch.int
}
%67 = torch.aten.size.int %51, %int3 : !torch.vtensor<[?,?,?,?],i1>, !torch.int -> !torch.int
%68 = torch.aten.ne.int %67, %int384 : !torch.int, !torch.int -> !torch.bool
%69 = torch.prim.If %68 -> (!torch.bool) {
%74 = torch.aten.ne.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %74 : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
%70 = torch.prim.If %69 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
torch.prim.If.yield %false : !torch.bool
}
torch.prim.If %70 -> () {
%74 = torch.aten.format(%str_2, %67, %int384, %int3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str
%75 = torch.aten.add.str %str_3, %74 : !torch.str, !torch.str -> !torch.str
torch.prim.RaiseException %75, %none : !torch.str, !torch.none
torch.prim.If.yield
} else {
torch.prim.If.yield
}
%71 = torch.aten.eq.int %67, %int1 : !torch.int, !torch.int -> !torch.bool
%72 = torch.prim.If %71 -> (!torch.int) {
torch.prim.If.yield %int384 : !torch.int
} else {
torch.prim.If.yield %67 : !torch.int
}
%73 = torch.prim.ListConstruct %56, %60, %66, %72 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %73 : !torch.list<int>
} : !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After DropAbstractInterpCalculations (torch-drop-abstract-interp-calculations) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%int-2 = torch.constant.int -2
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_3 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After RefinePublicReturn (torch-refine-public-return) //----- //
module {
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%int-2 = torch.constant.int -2
%int-1 = torch.constant.int -1
%str = torch.constant.str "AssertionError: only one dimension can be inferred"
%str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
%str_1 = torch.constant.str "AssertionError: invalid shape"
%str_2 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
%str_3 = torch.constant.str "AssertionError: "
%int384 = torch.constant.int 384
%int128 = torch.constant.int 128
%true = torch.constant.bool true
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After DecomposeComplexOps (torch-decompose-complex-ops) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After BindSymbolicShapesPass (torch-iree-bind-symbolic-shapes) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After SetStrictSymbolicShapesPass (torch-iree-set-strict-symbolic-shapes) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After BitCastQuantTensorPass (torch-iree-bitcast-quant-tensor) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After ReduceOpVariants (torch-reduce-op-variants) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After ConvertCustomQuantOp (torch-convert-custom-quant-op) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After DecomposeComplexOps (torch-decompose-complex-ops) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After FuseQuantizedOps (torch-fuse-quantized-ops) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After ScalarizeShapes (torch-scalarize-shapes) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After ConvertTorchToTMTensor (convert-torch-to-tmtensor) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14 = torch.aten.constant_pad_nd %arg1, %13, %int0 : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?],si64>
%15 = torch.aten.index_select %arg5, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%16 = torch.aten.squeeze.dim %15, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
%17 = torch.aten.div.Tensor %16, %4 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%18 = torch.aten.unsqueeze %17, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%19 = torch.prim.ListConstruct %3, %18, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%20 = torch.aten.cat %19, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%21 = torch.aten.slice.Tensor %20, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
%24 = torch.aten.Int.bool %23 : !torch.bool -> !torch.int
%25 = torch.aten.size.int %14, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%26 = torch.prim.NumToTensor.Scalar %24 : !torch.int -> !torch.vtensor<[],i1>
%27 = torch.prim.NumToTensor.Scalar %25 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %22 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.aten.where.self %26, %27, %28 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%30 = torch.aten.item %29 : !torch.vtensor<[],si64> -> !torch.int
%31 = torch.aten.slice.Tensor %20, %int0, %int1, %int2, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%32 = torch.aten.item %31 : !torch.vtensor<[1],si64> -> !torch.int
%33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool
%34 = torch.aten.Int.bool %33 : !torch.bool -> !torch.int
%35 = torch.aten.size.int %14, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%36 = torch.prim.NumToTensor.Scalar %34 : !torch.int -> !torch.vtensor<[],i1>
%37 = torch.prim.NumToTensor.Scalar %35 : !torch.int -> !torch.vtensor<[],si64>
%38 = torch.prim.NumToTensor.Scalar %32 : !torch.int -> !torch.vtensor<[],si64>
%39 = torch.aten.where.self %36, %37, %38 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%40 = torch.aten.item %39 : !torch.vtensor<[],si64> -> !torch.int
%41 = torch.aten.slice.Tensor %20, %int0, %int2, %int3, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%42 = torch.aten.item %41 : !torch.vtensor<[1],si64> -> !torch.int
%43 = torch.aten.eq.int %42, %int0 : !torch.int, !torch.int -> !torch.bool
%44 = torch.aten.Int.bool %43 : !torch.bool -> !torch.int
%45 = torch.prim.ListConstruct %30, %40, %42 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %14, %45 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?,?,?],si64>
%47 = torch.aten.unsqueeze %46, %int3 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,?,1],si64>
%48 = torch.aten.unsqueeze %arg3, %int2 : !torch.vtensor<[?,?,?],si64>, !torch.int -> !torch.vtensor<[?,?,1,?],si64>
%49 = torch.aten.to.dtype %47, %int11, %false, %false, %none : !torch.vtensor<[?,?,?,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,1],i1>
%50 = torch.aten.to.dtype %48, %int11, %false, %false, %none : !torch.vtensor<[?,?,1,?],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,1,?],i1>
%51 = torch.aten.logical_and %49, %50 : !torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1> -> !torch.vtensor<[?,?,?,?],i1>
%52 = torch.aten.logical_and %51, %1 : !torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1> -> !torch.vtensor<[?,?,128,384],i1>
return %52 : !torch.vtensor<[?,?,128,384],i1>
}
// -----// IR Dump After ConvertTMTensorToLinalgExtPass (torch-iree-tm-tensor-to-linalg-ext) //----- //
func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
%0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int4 = torch.constant.int 4
%1 = torch.vtensor.literal(dense<false> : tensor<1x1x128x384xi1>) : !torch.vtensor<[1,1,128,384],i1>
%false = torch.constant.bool false
%none = torch.constant.none
%int11 = torch.constant.int 11
%2 = torch.vtensor.literal(dense<8> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%3 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%4 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%int0 = torch.constant.int 0
%5 = torch.aten.slice.Tensor %arg4, %int0, %int0, %int1, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.aten.item %5 : !torch.vtensor<[1],si64> -> !torch.int
%7 = torch.aten.slice.Tensor %arg4, %int0, %int1, %int2, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %arg4, %int0, %int2, %int3, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %arg4, %int0, %int3, %int4, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.prim.ListConstruct %8, %12, %6, %10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment