Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created December 5, 2023 05:16
Show Gist options
  • Save AmosLewis/259cd141333b33ec7df8a2f60eb9bf4b to your computer and use it in GitHub Desktop.
(turbine_venv) ➜ SHARK-Turbine git:(bump-iree) ✗ torch-mlir-opt tests/dynamo/llama_test.mlir -convert-torch-to-linalg
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map3 = affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map5 = affine_map<(d0, d1, d2) -> (d2)>
#map6 = affine_map<(d0, d1) -> (d0, d1)>
#map7 = affine_map<(d0, d1) -> (d1, d0)>
#map8 = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
#map9 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
#map10 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
#map11 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>
#map12 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>
module {
func.func @main(%arg0: !torch.vtensor<[64],f32>, %arg1: !torch.vtensor<[64],f32>, %arg2: !torch.vtensor<[64],f32>, %arg3: !torch.vtensor<[64],f32>, %arg4: !torch.vtensor<[64],f32>, %arg5: !torch.vtensor<[16,64],f32>, %arg6: !torch.vtensor<[64,64],f32>, %arg7: !torch.vtensor<[64,64],f32>, %arg8: !torch.vtensor<[64,64],f32>, %arg9: !torch.vtensor<[64,64],f32>, %arg10: !torch.vtensor<[256,64],f32>, %arg11: !torch.vtensor<[256,64],f32>, %arg12: !torch.vtensor<[64,256],f32>, %arg13: !torch.vtensor<[64,64],f32>, %arg14: !torch.vtensor<[64,64],f32>, %arg15: !torch.vtensor<[64,64],f32>, %arg16: !torch.vtensor<[64,64],f32>, %arg17: !torch.vtensor<[256,64],f32>, %arg18: !torch.vtensor<[256,64],f32>, %arg19: !torch.vtensor<[64,256],f32>, %arg20: !torch.vtensor<[16,64],f32>, %arg21: !torch.vtensor<[4096,8],complex<f32>>, %arg22: !torch.vtensor<[32,2048,4,16],f32>, %arg23: !torch.vtensor<[32,2048,4,16],f32>, %arg24: !torch.vtensor<[32,2048,4,16],f32>, %arg25: !torch.vtensor<[32,2048,4,16],f32>, %arg26: !torch.vtensor<[2,8],si64>) -> (!torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[2,8,16],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[2,8],si64>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[1,8,1,8],complex<f32>>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[8,16,8],f32>, !torch.vtensor<[2,4,8,8],f32>, !torch.vtensor<[8,8,8],f32>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,256],f32>, 
!torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[256,64],f32>, !torch.vtensor<[16,256],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[1,8,1,8],complex<f32>>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[8,16,8],f32>, !torch.vtensor<[2,4,8,8],f32>, !torch.vtensor<[8,8,8],f32>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[256,64],f32>, !torch.vtensor<[16,256],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,16],f32>, !torch.vtensor<[16,64],f32>) {
%0 = torch_c.to_builtin_tensor %arg22 : !torch.vtensor<[32,2048,4,16],f32> -> tensor<32x2048x4x16xf32>
%1 = torch_c.to_builtin_tensor %arg23 : !torch.vtensor<[32,2048,4,16],f32> -> tensor<32x2048x4x16xf32>
%2 = torch_c.to_builtin_tensor %arg24 : !torch.vtensor<[32,2048,4,16],f32> -> tensor<32x2048x4x16xf32>
%3 = torch_c.to_builtin_tensor %arg25 : !torch.vtensor<[32,2048,4,16],f32> -> tensor<32x2048x4x16xf32>
%4 = torch_c.to_builtin_tensor %arg5 : !torch.vtensor<[16,64],f32> -> tensor<16x64xf32>
%5 = torch_c.to_builtin_tensor %arg26 : !torch.vtensor<[2,8],si64> -> tensor<2x8xi64>
%6 = torch_c.to_builtin_tensor %arg21 : !torch.vtensor<[4096,8],complex<f32>> -> tensor<4096x8xcomplex<f32>>
%7 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[64],f32> -> tensor<64xf32>
%8 = torch_c.to_builtin_tensor %arg6 : !torch.vtensor<[64,64],f32> -> tensor<64x64xf32>
%9 = torch_c.to_builtin_tensor %arg7 : !torch.vtensor<[64,64],f32> -> tensor<64x64xf32>
%10 = torch_c.to_builtin_tensor %arg8 : !torch.vtensor<[64,64],f32> -> tensor<64x64xf32>
%11 = torch_c.to_builtin_tensor %arg9 : !torch.vtensor<[64,64],f32> -> tensor<64x64xf32>
%12 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[64],f32> -> tensor<64xf32>
%13 = torch_c.to_builtin_tensor %arg10 : !torch.vtensor<[256,64],f32> -> tensor<256x64xf32>
%14 = torch_c.to_builtin_tensor %arg11 : !torch.vtensor<[256,64],f32> -> tensor<256x64xf32>
%15 = torch_c.to_builtin_tensor %arg12 : !torch.vtensor<[64,256],f32> -> tensor<64x256xf32>
%16 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[64],f32> -> tensor<64xf32>
%17 = torch_c.to_builtin_tensor %arg13 : !torch.vtensor<[64,64],f32> -> tensor<64x64xf32>
%18 = torch_c.to_builtin_tensor %arg14 : !torch.vtensor<[64,64],f32> -> tensor<64x64xf32>
%19 = torch_c.to_builtin_tensor %arg15 : !torch.vtensor<[64,64],f32> -> tensor<64x64xf32>
%20 = torch_c.to_builtin_tensor %arg16 : !torch.vtensor<[64,64],f32> -> tensor<64x64xf32>
%21 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[64],f32> -> tensor<64xf32>
%22 = torch_c.to_builtin_tensor %arg17 : !torch.vtensor<[256,64],f32> -> tensor<256x64xf32>
%23 = torch_c.to_builtin_tensor %arg18 : !torch.vtensor<[256,64],f32> -> tensor<256x64xf32>
%24 = torch_c.to_builtin_tensor %arg19 : !torch.vtensor<[64,256],f32> -> tensor<64x256xf32>
%25 = torch_c.to_builtin_tensor %arg4 : !torch.vtensor<[64],f32> -> tensor<64xf32>
%26 = torch_c.to_builtin_tensor %arg20 : !torch.vtensor<[16,64],f32> -> tensor<16x64xf32>
%none = torch.constant.none
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1_0 = arith.constant 1 : index
%c2048 = arith.constant 2048 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c16 = arith.constant 16 : index
%27 = tensor.empty() : tensor<32x2048x4x16xf32>
%28 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%0 : tensor<32x2048x4x16xf32>) outs(%27 : tensor<32x2048x4x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x2048x4x16xf32>
%cast = tensor.cast %28 : tensor<32x2048x4x16xf32> to tensor<32x2048x4x16xf32>
%none_1 = torch.constant.none
%c1_2 = arith.constant 1 : index
%c0_3 = arith.constant 0 : index
%c32_4 = arith.constant 32 : index
%c1_5 = arith.constant 1 : index
%c2048_6 = arith.constant 2048 : index
%c2_7 = arith.constant 2 : index
%c4_8 = arith.constant 4 : index
%c3_9 = arith.constant 3 : index
%c16_10 = arith.constant 16 : index
%29 = tensor.empty() : tensor<32x2048x4x16xf32>
%30 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1 : tensor<32x2048x4x16xf32>) outs(%29 : tensor<32x2048x4x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x2048x4x16xf32>
%cast_11 = tensor.cast %30 : tensor<32x2048x4x16xf32> to tensor<32x2048x4x16xf32>
%none_12 = torch.constant.none
%c1_13 = arith.constant 1 : index
%c0_14 = arith.constant 0 : index
%c32_15 = arith.constant 32 : index
%c1_16 = arith.constant 1 : index
%c2048_17 = arith.constant 2048 : index
%c2_18 = arith.constant 2 : index
%c4_19 = arith.constant 4 : index
%c3_20 = arith.constant 3 : index
%c16_21 = arith.constant 16 : index
%31 = tensor.empty() : tensor<32x2048x4x16xf32>
%32 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<32x2048x4x16xf32>) outs(%31 : tensor<32x2048x4x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x2048x4x16xf32>
%cast_22 = tensor.cast %32 : tensor<32x2048x4x16xf32> to tensor<32x2048x4x16xf32>
%none_23 = torch.constant.none
%c1_24 = arith.constant 1 : index
%c0_25 = arith.constant 0 : index
%c32_26 = arith.constant 32 : index
%c1_27 = arith.constant 1 : index
%c2048_28 = arith.constant 2048 : index
%c2_29 = arith.constant 2 : index
%c4_30 = arith.constant 4 : index
%c3_31 = arith.constant 3 : index
%c16_32 = arith.constant 16 : index
%33 = tensor.empty() : tensor<32x2048x4x16xf32>
%34 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%3 : tensor<32x2048x4x16xf32>) outs(%33 : tensor<32x2048x4x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x2048x4x16xf32>
%cast_33 = tensor.cast %34 : tensor<32x2048x4x16xf32> to tensor<32x2048x4x16xf32>
%int-1 = torch.constant.int -1
%false = torch.constant.bool false
%false_34 = torch.constant.bool false
%c1_35 = arith.constant 1 : index
%c64 = arith.constant 64 : index
%c0_36 = arith.constant 0 : index
%c2_37 = arith.constant 2 : index
%c1_38 = arith.constant 1 : index
%c8 = arith.constant 8 : index
%35 = tensor.empty() : tensor<2x8x64xf32>
%36 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5 : tensor<2x8xi64>) outs(%35 : tensor<2x8x64xf32>) {
^bb0(%in: i64, %out: f32):
%1543 = arith.index_cast %in : i64 to index
%1544 = linalg.index 2 : index
%1545 = arith.index_cast %in : i64 to index
%c0_2549 = arith.constant 0 : index
%c16_2550 = arith.constant 16 : index
%1546 = arith.cmpi slt, %1545, %c16_2550 : index
cf.assert %1546, "index must be smaller than dim size"
%c0_i64_2551 = arith.constant 0 : i64
%1547 = arith.cmpi sge, %in, %c0_i64_2551 : i64
cf.assert %1547, "index must be larger or equal to 0"
%extracted = tensor.extract %4[%1543, %1544] : tensor<16x64xf32>
linalg.yield %extracted : f32
} -> tensor<2x8x64xf32>
%cast_39 = tensor.cast %36 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%37 = torch_c.from_builtin_tensor %cast_39 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%int0 = torch.constant.int 0
%int0_40 = torch.constant.int 0
%38 = torch_c.to_i64 %int0_40
%int8 = torch.constant.int 8
%39 = torch_c.to_i64 %int8
%int1 = torch.constant.int 1
%c0_41 = arith.constant 0 : index
%c1_42 = arith.constant 1 : index
%c0_43 = arith.constant 0 : index
%c4096 = arith.constant 4096 : index
%c1_44 = arith.constant 1 : index
%c8_45 = arith.constant 8 : index
%40 = arith.index_cast %c4096 : index to i64
%41 = arith.addi %38, %40 : i64
%c0_i64 = arith.constant 0 : i64
%42 = arith.cmpi sge, %38, %c0_i64 : i64
%43 = arith.select %42, %38, %41 : i64
%c0_i64_46 = arith.constant 0 : i64
%44 = arith.cmpi slt, %43, %c0_i64_46 : i64
%45 = arith.select %44, %c0_i64_46, %43 : i64
%46 = arith.cmpi sgt, %45, %40 : i64
%47 = arith.select %46, %40, %45 : i64
%48 = arith.index_cast %47 : i64 to index
%49 = arith.index_cast %c4096 : index to i64
%50 = arith.addi %39, %49 : i64
%c0_i64_47 = arith.constant 0 : i64
%51 = arith.cmpi sge, %39, %c0_i64_47 : i64
%52 = arith.select %51, %39, %50 : i64
%c0_i64_48 = arith.constant 0 : i64
%53 = arith.cmpi slt, %52, %c0_i64_48 : i64
%54 = arith.select %53, %c0_i64_48, %52 : i64
%55 = arith.cmpi sgt, %54, %49 : i64
%56 = arith.select %55, %49, %54 : i64
%57 = arith.index_cast %56 : i64 to index
%58 = arith.cmpi sge, %57, %48 : index
%59 = arith.select %58, %57, %48 : index
%c1_49 = arith.constant 1 : index
%c0_50 = arith.constant 0 : index
%c4096_51 = arith.constant 4096 : index
%c1_52 = arith.constant 1 : index
%c8_53 = arith.constant 8 : index
%60 = arith.subi %59, %48 : index
%61 = arith.addi %60, %c1_49 : index
%62 = arith.subi %61, %c1_42 : index
%63 = arith.floordivsi %62, %c1_49 : index
%64 = arith.muli %c1_42, %c1_49 : index
%extracted_slice = tensor.extract_slice %6[%48, %c0_41] [%63, %c8_53] [%64, %c1_42] : tensor<4096x8xcomplex<f32>> to tensor<?x?xcomplex<f32>>
%cast_54 = tensor.cast %extracted_slice : tensor<?x?xcomplex<f32>> to tensor<8x8xcomplex<f32>>
%int1_55 = torch.constant.int 1
%int1_56 = torch.constant.int 1
%int8_57 = torch.constant.int 8
%int8_58 = torch.constant.int 8
%65 = torch.prim.ListConstruct %int1_55, %int1_56, %int8_57, %int8_58 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%int64 = torch.constant.int 64
%int64_59 = torch.constant.int 64
%int8_60 = torch.constant.int 8
%int1_61 = torch.constant.int 1
%66 = torch.prim.ListConstruct %int64, %int64_59, %int8_60, %int1_61 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%int6 = torch.constant.int 6
%int0_62 = torch.constant.int 0
%cpu = torch.constant.device "cpu"
%false_63 = torch.constant.bool false
%67 = torch.aten.empty_strided %65, %66, %int6, %int0_62, %cpu, %false_63 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,1,8,8],f32>
%68 = torch_c.to_builtin_tensor %67 : !torch.vtensor<[1,1,8,8],f32> -> tensor<1x1x8x8xf32>
%float-Inf = torch.constant.float 0xFFF0000000000000
%69 = torch_c.to_f64 %float-Inf
%c1_64 = arith.constant 1 : index
%c2_65 = arith.constant 2 : index
%c8_66 = arith.constant 8 : index
%c3_67 = arith.constant 3 : index
%c8_68 = arith.constant 8 : index
%70 = tensor.empty() : tensor<1x1x8x8xf32>
%71 = linalg.generic {indexing_maps = [#map3, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%68 : tensor<1x1x8x8xf32>) outs(%70 : tensor<1x1x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = arith.truncf %69 : f64 to f32
linalg.yield %1543 : f32
} -> tensor<1x1x8x8xf32>
%cast_69 = tensor.cast %71 : tensor<1x1x8x8xf32> to tensor<1x1x8x8xf32>
%int1_70 = torch.constant.int 1
%72 = torch_c.to_i64 %int1_70
%c1_71 = arith.constant 1 : index
%c2_72 = arith.constant 2 : index
%c8_73 = arith.constant 8 : index
%c3_74 = arith.constant 3 : index
%c8_75 = arith.constant 8 : index
%73 = tensor.empty() : tensor<1x1x8x8xf32>
%74 = linalg.generic {indexing_maps = [#map3, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_69 : tensor<1x1x8x8xf32>) outs(%73 : tensor<1x1x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = linalg.index 2 : index
%1544 = arith.index_cast %1543 : index to i64
%1545 = linalg.index 3 : index
%1546 = arith.index_cast %1545 : index to i64
%1547 = arith.addi %1544, %72 : i64
%1548 = arith.cmpi sge, %1546, %1547 : i64
%cst_2549 = arith.constant 0.000000e+00 : f32
%1549 = arith.select %1548, %in, %cst_2549 : f32
linalg.yield %1549 : f32
} -> tensor<1x1x8x8xf32>
%cast_76 = tensor.cast %74 : tensor<1x1x8x8xf32> to tensor<1x1x8x8xf32>
%int2 = torch.constant.int 2
%75 = torch_c.to_i64 %int2
%c1_77 = arith.constant 1 : index
%c0_78 = arith.constant 0 : index
%c2_79 = arith.constant 2 : index
%c1_80 = arith.constant 1 : index
%c8_81 = arith.constant 8 : index
%c2_82 = arith.constant 2 : index
%c64_83 = arith.constant 64 : index
%76 = tensor.empty() : tensor<2x8x64xf32>
%77 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_39 : tensor<2x8x64xf32>) outs(%76 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = arith.sitofp %75 : i64 to f32
%1544 = math.powf %in, %1543 : f32
linalg.yield %1544 : f32
} -> tensor<2x8x64xf32>
%cast_84 = tensor.cast %77 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%78 = torch_c.from_builtin_tensor %cast_84 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%int-1_85 = torch.constant.int -1
%79 = torch.prim.ListConstruct %int-1_85 : (!torch.int) -> !torch.list<int>
%true = torch.constant.bool true
%none_86 = torch.constant.none
%80 = torch.aten.mean.dim %78, %79, %true, %none_86 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
%81 = torch_c.to_builtin_tensor %80 : !torch.vtensor<[2,8,1],f32> -> tensor<2x8x1xf32>
%float1.000000e-05 = torch.constant.float 1.000000e-05
%82 = torch_c.to_f64 %float1.000000e-05
%int1_87 = torch.constant.int 1
%83 = torch_c.to_i64 %int1_87
%c1_88 = arith.constant 1 : index
%c0_89 = arith.constant 0 : index
%c2_90 = arith.constant 2 : index
%c1_91 = arith.constant 1 : index
%c8_92 = arith.constant 8 : index
%84 = tensor.empty() : tensor<2x8x1xf32>
%85 = linalg.generic {indexing_maps = [#map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%81 : tensor<2x8x1xf32>) outs(%84 : tensor<2x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = arith.truncf %82 : f64 to f32
%1544 = arith.sitofp %83 : i64 to f32
%1545 = arith.mulf %1543, %1544 : f32
%1546 = arith.addf %in, %1545 : f32
linalg.yield %1546 : f32
} -> tensor<2x8x1xf32>
%cast_93 = tensor.cast %85 : tensor<2x8x1xf32> to tensor<2x8x1xf32>
%c1_94 = arith.constant 1 : index
%c0_95 = arith.constant 0 : index
%c2_96 = arith.constant 2 : index
%c1_97 = arith.constant 1 : index
%c8_98 = arith.constant 8 : index
%86 = tensor.empty() : tensor<2x8x1xf32>
%87 = linalg.generic {indexing_maps = [#map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_93 : tensor<2x8x1xf32>) outs(%86 : tensor<2x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = math.rsqrt %in : f32
linalg.yield %1543 : f32
} -> tensor<2x8x1xf32>
%cast_99 = tensor.cast %87 : tensor<2x8x1xf32> to tensor<2x8x1xf32>
%88 = torch_c.from_builtin_tensor %cast_99 : tensor<2x8x1xf32> -> !torch.vtensor<[2,8,1],f32>
%c1_100 = arith.constant 1 : index
%c0_101 = arith.constant 0 : index
%c2_102 = arith.constant 2 : index
%c1_103 = arith.constant 1 : index
%c8_104 = arith.constant 8 : index
%c2_105 = arith.constant 2 : index
%c64_106 = arith.constant 64 : index
%c0_107 = arith.constant 0 : index
%c2_108 = arith.constant 2 : index
%89 = arith.cmpi eq, %c2_102, %c2_108 : index
cf.assert %89, "mismatched size for broadcast"
%c1_109 = arith.constant 1 : index
%c8_110 = arith.constant 8 : index
%90 = arith.cmpi eq, %c8_104, %c8_110 : index
cf.assert %90, "mismatched size for broadcast"
%91 = tensor.empty() : tensor<2x8x64xf32>
%92 = linalg.generic {indexing_maps = [#map2, #map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_39, %cast_99 : tensor<2x8x64xf32>, tensor<2x8x1xf32>) outs(%91 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.mulf %in, %in_2549 : f32
linalg.yield %1543 : f32
} -> tensor<2x8x64xf32>
%cast_111 = tensor.cast %92 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%93 = torch_c.from_builtin_tensor %cast_111 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%c1_112 = arith.constant 1 : index
%c0_113 = arith.constant 0 : index
%c2_114 = arith.constant 2 : index
%c1_115 = arith.constant 1 : index
%c8_116 = arith.constant 8 : index
%c2_117 = arith.constant 2 : index
%c64_118 = arith.constant 64 : index
%c0_119 = arith.constant 0 : index
%c64_120 = arith.constant 64 : index
%94 = arith.cmpi eq, %c64_118, %c64_120 : index
cf.assert %94, "mismatched size for broadcast"
%95 = tensor.empty() : tensor<2x8x64xf32>
%96 = linalg.generic {indexing_maps = [#map2, #map5, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_111, %7 : tensor<2x8x64xf32>, tensor<64xf32>) outs(%95 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.mulf %in, %in_2549 : f32
linalg.yield %1543 : f32
} -> tensor<2x8x64xf32>
%cast_121 = tensor.cast %96 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%int0_122 = torch.constant.int 0
%int1_123 = torch.constant.int 1
%c0_124 = arith.constant 0 : index
%c64_125 = arith.constant 64 : index
%c1_126 = arith.constant 1 : index
%c64_127 = arith.constant 64 : index
%97 = tensor.empty() : tensor<64x64xf32>
%98 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%8 : tensor<64x64xf32>) outs(%97 : tensor<64x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x64xf32>
%cast_128 = tensor.cast %98 : tensor<64x64xf32> to tensor<64x64xf32>
%99 = torch_c.from_builtin_tensor %cast_128 : tensor<64x64xf32> -> !torch.vtensor<[64,64],f32>
%int16 = torch.constant.int 16
%int64_129 = torch.constant.int 64
%100 = torch.prim.ListConstruct %int16, %int64_129 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_130 = arith.constant 0 : index
%c2_131 = arith.constant 2 : index
%c1_132 = arith.constant 1 : index
%c8_133 = arith.constant 8 : index
%c2_134 = arith.constant 2 : index
%c64_135 = arith.constant 64 : index
%101 = torch_c.to_i64 %int16
%102 = torch_c.to_i64 %int64_129
%collapsed = tensor.collapse_shape %cast_121 [[0, 1], [2]] : tensor<2x8x64xf32> into tensor<16x64xf32>
%103 = torch_c.from_builtin_tensor %collapsed : tensor<16x64xf32> -> !torch.vtensor<[16,64],f32>
%c0_136 = arith.constant 0 : index
%dim = tensor.dim %collapsed, %c0_136 : tensor<16x64xf32>
%c1_137 = arith.constant 1 : index
%dim_138 = tensor.dim %cast_128, %c1_137 : tensor<64x64xf32>
%c1_139 = arith.constant 1 : index
%dim_140 = tensor.dim %collapsed, %c1_139 : tensor<16x64xf32>
%c0_141 = arith.constant 0 : index
%dim_142 = tensor.dim %cast_128, %c0_141 : tensor<64x64xf32>
%104 = arith.cmpi eq, %dim_140, %dim_142 : index
cf.assert %104, "mismatching contracting dimension for torch.aten.mm"
%105 = tensor.empty(%dim, %dim_138) : tensor<?x?xf32>
%cst = arith.constant 0.000000e+00 : f32
%106 = linalg.fill ins(%cst : f32) outs(%105 : tensor<?x?xf32>) -> tensor<?x?xf32>
%107 = linalg.matmul ins(%collapsed, %cast_128 : tensor<16x64xf32>, tensor<64x64xf32>) outs(%106 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_143 = tensor.cast %107 : tensor<?x?xf32> to tensor<16x64xf32>
%int2_144 = torch.constant.int 2
%int8_145 = torch.constant.int 8
%int64_146 = torch.constant.int 64
%108 = torch.prim.ListConstruct %int2_144, %int8_145, %int64_146 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_147 = arith.constant 0 : index
%c16_148 = arith.constant 16 : index
%c1_149 = arith.constant 1 : index
%c64_150 = arith.constant 64 : index
%109 = torch_c.to_i64 %int2_144
%110 = torch_c.to_i64 %int8_145
%111 = torch_c.to_i64 %int64_146
%expanded = tensor.expand_shape %cast_143 [[0, 1], [2]] : tensor<16x64xf32> into tensor<2x8x64xf32>
%int0_151 = torch.constant.int 0
%int1_152 = torch.constant.int 1
%c0_153 = arith.constant 0 : index
%c64_154 = arith.constant 64 : index
%c1_155 = arith.constant 1 : index
%c64_156 = arith.constant 64 : index
%112 = tensor.empty() : tensor<64x64xf32>
%113 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%9 : tensor<64x64xf32>) outs(%112 : tensor<64x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x64xf32>
%cast_157 = tensor.cast %113 : tensor<64x64xf32> to tensor<64x64xf32>
%114 = torch_c.from_builtin_tensor %cast_157 : tensor<64x64xf32> -> !torch.vtensor<[64,64],f32>
%int16_158 = torch.constant.int 16
%int64_159 = torch.constant.int 64
%115 = torch.prim.ListConstruct %int16_158, %int64_159 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_160 = arith.constant 0 : index
%c2_161 = arith.constant 2 : index
%c1_162 = arith.constant 1 : index
%c8_163 = arith.constant 8 : index
%c2_164 = arith.constant 2 : index
%c64_165 = arith.constant 64 : index
%116 = torch_c.to_i64 %int16_158
%117 = torch_c.to_i64 %int64_159
%collapsed_166 = tensor.collapse_shape %cast_121 [[0, 1], [2]] : tensor<2x8x64xf32> into tensor<16x64xf32>
%118 = torch_c.from_builtin_tensor %collapsed_166 : tensor<16x64xf32> -> !torch.vtensor<[16,64],f32>
%c0_167 = arith.constant 0 : index
%dim_168 = tensor.dim %collapsed_166, %c0_167 : tensor<16x64xf32>
%c1_169 = arith.constant 1 : index
%dim_170 = tensor.dim %cast_157, %c1_169 : tensor<64x64xf32>
%c1_171 = arith.constant 1 : index
%dim_172 = tensor.dim %collapsed_166, %c1_171 : tensor<16x64xf32>
%c0_173 = arith.constant 0 : index
%dim_174 = tensor.dim %cast_157, %c0_173 : tensor<64x64xf32>
%119 = arith.cmpi eq, %dim_172, %dim_174 : index
cf.assert %119, "mismatching contracting dimension for torch.aten.mm"
%120 = tensor.empty(%dim_168, %dim_170) : tensor<?x?xf32>
%cst_175 = arith.constant 0.000000e+00 : f32
%121 = linalg.fill ins(%cst_175 : f32) outs(%120 : tensor<?x?xf32>) -> tensor<?x?xf32>
%122 = linalg.matmul ins(%collapsed_166, %cast_157 : tensor<16x64xf32>, tensor<64x64xf32>) outs(%121 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_176 = tensor.cast %122 : tensor<?x?xf32> to tensor<16x64xf32>
%int2_177 = torch.constant.int 2
%int8_178 = torch.constant.int 8
%int64_179 = torch.constant.int 64
%123 = torch.prim.ListConstruct %int2_177, %int8_178, %int64_179 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_180 = arith.constant 0 : index
%c16_181 = arith.constant 16 : index
%c1_182 = arith.constant 1 : index
%c64_183 = arith.constant 64 : index
%124 = torch_c.to_i64 %int2_177
%125 = torch_c.to_i64 %int8_178
%126 = torch_c.to_i64 %int64_179
%expanded_184 = tensor.expand_shape %cast_176 [[0, 1], [2]] : tensor<16x64xf32> into tensor<2x8x64xf32>
%int0_185 = torch.constant.int 0
%int1_186 = torch.constant.int 1
%c0_187 = arith.constant 0 : index
%c64_188 = arith.constant 64 : index
%c1_189 = arith.constant 1 : index
%c64_190 = arith.constant 64 : index
%127 = tensor.empty() : tensor<64x64xf32>
%128 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%10 : tensor<64x64xf32>) outs(%127 : tensor<64x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x64xf32>
%cast_191 = tensor.cast %128 : tensor<64x64xf32> to tensor<64x64xf32>
%129 = torch_c.from_builtin_tensor %cast_191 : tensor<64x64xf32> -> !torch.vtensor<[64,64],f32>
%int16_192 = torch.constant.int 16
%int64_193 = torch.constant.int 64
%130 = torch.prim.ListConstruct %int16_192, %int64_193 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_194 = arith.constant 0 : index
%c2_195 = arith.constant 2 : index
%c1_196 = arith.constant 1 : index
%c8_197 = arith.constant 8 : index
%c2_198 = arith.constant 2 : index
%c64_199 = arith.constant 64 : index
%131 = torch_c.to_i64 %int16_192
%132 = torch_c.to_i64 %int64_193
%collapsed_200 = tensor.collapse_shape %cast_121 [[0, 1], [2]] : tensor<2x8x64xf32> into tensor<16x64xf32>
%133 = torch_c.from_builtin_tensor %collapsed_200 : tensor<16x64xf32> -> !torch.vtensor<[16,64],f32>
%c0_201 = arith.constant 0 : index
%dim_202 = tensor.dim %collapsed_200, %c0_201 : tensor<16x64xf32>
%c1_203 = arith.constant 1 : index
%dim_204 = tensor.dim %cast_191, %c1_203 : tensor<64x64xf32>
%c1_205 = arith.constant 1 : index
%dim_206 = tensor.dim %collapsed_200, %c1_205 : tensor<16x64xf32>
%c0_207 = arith.constant 0 : index
%dim_208 = tensor.dim %cast_191, %c0_207 : tensor<64x64xf32>
%134 = arith.cmpi eq, %dim_206, %dim_208 : index
cf.assert %134, "mismatching contracting dimension for torch.aten.mm"
%135 = tensor.empty(%dim_202, %dim_204) : tensor<?x?xf32>
%cst_209 = arith.constant 0.000000e+00 : f32
%136 = linalg.fill ins(%cst_209 : f32) outs(%135 : tensor<?x?xf32>) -> tensor<?x?xf32>
%137 = linalg.matmul ins(%collapsed_200, %cast_191 : tensor<16x64xf32>, tensor<64x64xf32>) outs(%136 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_210 = tensor.cast %137 : tensor<?x?xf32> to tensor<16x64xf32>
%int2_211 = torch.constant.int 2
%int8_212 = torch.constant.int 8
%int64_213 = torch.constant.int 64
%138 = torch.prim.ListConstruct %int2_211, %int8_212, %int64_213 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_214 = arith.constant 0 : index
%c16_215 = arith.constant 16 : index
%c1_216 = arith.constant 1 : index
%c64_217 = arith.constant 64 : index
%139 = torch_c.to_i64 %int2_211
%140 = torch_c.to_i64 %int8_212
%141 = torch_c.to_i64 %int64_213
%expanded_218 = tensor.expand_shape %cast_210 [[0, 1], [2]] : tensor<16x64xf32> into tensor<2x8x64xf32>
%int2_219 = torch.constant.int 2
%int8_220 = torch.constant.int 8
%int4 = torch.constant.int 4
%int16_221 = torch.constant.int 16
%142 = torch.prim.ListConstruct %int2_219, %int8_220, %int4, %int16_221 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_222 = arith.constant 0 : index
%c2_223 = arith.constant 2 : index
%c1_224 = arith.constant 1 : index
%c8_225 = arith.constant 8 : index
%c2_226 = arith.constant 2 : index
%c64_227 = arith.constant 64 : index
%143 = torch_c.to_i64 %int2_219
%144 = torch_c.to_i64 %int8_220
%145 = torch_c.to_i64 %int4
%146 = torch_c.to_i64 %int16_221
%expanded_228 = tensor.expand_shape %expanded [[0], [1], [2, 3]] : tensor<2x8x64xf32> into tensor<2x8x4x16xf32>
%int2_229 = torch.constant.int 2
%int8_230 = torch.constant.int 8
%int4_231 = torch.constant.int 4
%int16_232 = torch.constant.int 16
%147 = torch.prim.ListConstruct %int2_229, %int8_230, %int4_231, %int16_232 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_233 = arith.constant 0 : index
%c2_234 = arith.constant 2 : index
%c1_235 = arith.constant 1 : index
%c8_236 = arith.constant 8 : index
%c2_237 = arith.constant 2 : index
%c64_238 = arith.constant 64 : index
%148 = torch_c.to_i64 %int2_229
%149 = torch_c.to_i64 %int8_230
%150 = torch_c.to_i64 %int4_231
%151 = torch_c.to_i64 %int16_232
%expanded_239 = tensor.expand_shape %expanded_184 [[0], [1], [2, 3]] : tensor<2x8x64xf32> into tensor<2x8x4x16xf32>
%int2_240 = torch.constant.int 2
%int8_241 = torch.constant.int 8
%int4_242 = torch.constant.int 4
%int16_243 = torch.constant.int 16
%152 = torch.prim.ListConstruct %int2_240, %int8_241, %int4_242, %int16_243 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_244 = arith.constant 0 : index
%c2_245 = arith.constant 2 : index
%c1_246 = arith.constant 1 : index
%c8_247 = arith.constant 8 : index
%c2_248 = arith.constant 2 : index
%c64_249 = arith.constant 64 : index
%153 = torch_c.to_i64 %int2_240
%154 = torch_c.to_i64 %int8_241
%155 = torch_c.to_i64 %int4_242
%156 = torch_c.to_i64 %int16_243
%expanded_250 = tensor.expand_shape %expanded_218 [[0], [1], [2, 3]] : tensor<2x8x64xf32> into tensor<2x8x4x16xf32>
%int2_251 = torch.constant.int 2
%int8_252 = torch.constant.int 8
%int4_253 = torch.constant.int 4
%int-1_254 = torch.constant.int -1
%int2_255 = torch.constant.int 2
%157 = torch.prim.ListConstruct %int2_251, %int8_252, %int4_253, %int-1_254, %int2_255 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_256 = arith.constant 0 : index
%c2_257 = arith.constant 2 : index
%c1_258 = arith.constant 1 : index
%c8_259 = arith.constant 8 : index
%c2_260 = arith.constant 2 : index
%c4_261 = arith.constant 4 : index
%c3_262 = arith.constant 3 : index
%c16_263 = arith.constant 16 : index
%158 = torch_c.to_i64 %int2_251
%159 = torch_c.to_i64 %int8_252
%160 = torch_c.to_i64 %int4_253
%161 = torch_c.to_i64 %int-1_254
%162 = torch_c.to_i64 %int2_255
%expanded_264 = tensor.expand_shape %expanded_228 [[0], [1], [2], [3, 4]] : tensor<2x8x4x16xf32> into tensor<2x8x4x8x2xf32>
%c0_265 = arith.constant 0 : index
%dim_266 = tensor.dim %expanded_264, %c0_265 : tensor<2x8x4x8x2xf32>
%c1_267 = arith.constant 1 : index
%dim_268 = tensor.dim %expanded_264, %c1_267 : tensor<2x8x4x8x2xf32>
%c2_269 = arith.constant 2 : index
%dim_270 = tensor.dim %expanded_264, %c2_269 : tensor<2x8x4x8x2xf32>
%c3_271 = arith.constant 3 : index
%dim_272 = tensor.dim %expanded_264, %c3_271 : tensor<2x8x4x8x2xf32>
%163 = tensor.empty(%dim_266, %dim_268, %dim_270, %dim_272) : tensor<?x?x?x?xcomplex<f32>>
%c0_273 = arith.constant 0 : index
%c1_274 = arith.constant 1 : index
%164 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%163 : tensor<?x?x?x?xcomplex<f32>>) {
^bb0(%out: complex<f32>):
%1543 = linalg.index 0 : index
%1544 = linalg.index 0 : index
%1545 = linalg.index 1 : index
%1546 = linalg.index 1 : index
%1547 = linalg.index 2 : index
%1548 = linalg.index 2 : index
%1549 = linalg.index 3 : index
%1550 = linalg.index 3 : index
%extracted = tensor.extract %expanded_264[%1543, %1545, %1547, %1549, %c0_273] : tensor<2x8x4x8x2xf32>
%extracted_2549 = tensor.extract %expanded_264[%1544, %1546, %1548, %1550, %c1_274] : tensor<2x8x4x8x2xf32>
%1551 = complex.create %extracted, %extracted_2549 : complex<f32>
linalg.yield %1551 : complex<f32>
} -> tensor<?x?x?x?xcomplex<f32>>
%cast_275 = tensor.cast %164 : tensor<?x?x?x?xcomplex<f32>> to tensor<2x8x4x8xcomplex<f32>>
%int2_276 = torch.constant.int 2
%int8_277 = torch.constant.int 8
%int4_278 = torch.constant.int 4
%int-1_279 = torch.constant.int -1
%int2_280 = torch.constant.int 2
%165 = torch.prim.ListConstruct %int2_276, %int8_277, %int4_278, %int-1_279, %int2_280 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_281 = arith.constant 0 : index
%c2_282 = arith.constant 2 : index
%c1_283 = arith.constant 1 : index
%c8_284 = arith.constant 8 : index
%c2_285 = arith.constant 2 : index
%c4_286 = arith.constant 4 : index
%c3_287 = arith.constant 3 : index
%c16_288 = arith.constant 16 : index
%166 = torch_c.to_i64 %int2_276
%167 = torch_c.to_i64 %int8_277
%168 = torch_c.to_i64 %int4_278
%169 = torch_c.to_i64 %int-1_279
%170 = torch_c.to_i64 %int2_280
%expanded_289 = tensor.expand_shape %expanded_239 [[0], [1], [2], [3, 4]] : tensor<2x8x4x16xf32> into tensor<2x8x4x8x2xf32>
%c0_290 = arith.constant 0 : index
%dim_291 = tensor.dim %expanded_289, %c0_290 : tensor<2x8x4x8x2xf32>
%c1_292 = arith.constant 1 : index
%dim_293 = tensor.dim %expanded_289, %c1_292 : tensor<2x8x4x8x2xf32>
%c2_294 = arith.constant 2 : index
%dim_295 = tensor.dim %expanded_289, %c2_294 : tensor<2x8x4x8x2xf32>
%c3_296 = arith.constant 3 : index
%dim_297 = tensor.dim %expanded_289, %c3_296 : tensor<2x8x4x8x2xf32>
%171 = tensor.empty(%dim_291, %dim_293, %dim_295, %dim_297) : tensor<?x?x?x?xcomplex<f32>>
%c0_298 = arith.constant 0 : index
%c1_299 = arith.constant 1 : index
%172 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%171 : tensor<?x?x?x?xcomplex<f32>>) {
^bb0(%out: complex<f32>):
%1543 = linalg.index 0 : index
%1544 = linalg.index 0 : index
%1545 = linalg.index 1 : index
%1546 = linalg.index 1 : index
%1547 = linalg.index 2 : index
%1548 = linalg.index 2 : index
%1549 = linalg.index 3 : index
%1550 = linalg.index 3 : index
%extracted = tensor.extract %expanded_289[%1543, %1545, %1547, %1549, %c0_298] : tensor<2x8x4x8x2xf32>
%extracted_2549 = tensor.extract %expanded_289[%1544, %1546, %1548, %1550, %c1_299] : tensor<2x8x4x8x2xf32>
%1551 = complex.create %extracted, %extracted_2549 : complex<f32>
linalg.yield %1551 : complex<f32>
} -> tensor<?x?x?x?xcomplex<f32>>
%cast_300 = tensor.cast %172 : tensor<?x?x?x?xcomplex<f32>> to tensor<2x8x4x8xcomplex<f32>>
%int1_301 = torch.constant.int 1
%int8_302 = torch.constant.int 8
%int1_303 = torch.constant.int 1
%int8_304 = torch.constant.int 8
%173 = torch.prim.ListConstruct %int1_301, %int8_302, %int1_303, %int8_304 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_305 = arith.constant 0 : index
%c8_306 = arith.constant 8 : index
%c1_307 = arith.constant 1 : index
%c8_308 = arith.constant 8 : index
%174 = torch_c.to_i64 %int1_301
%175 = torch_c.to_i64 %int8_302
%176 = torch_c.to_i64 %int1_303
%177 = torch_c.to_i64 %int8_304
%expanded_309 = tensor.expand_shape %cast_54 [[0, 1], [2, 3]] : tensor<8x8xcomplex<f32>> into tensor<1x8x1x8xcomplex<f32>>
%178 = torch_c.from_builtin_tensor %expanded_309 : tensor<1x8x1x8xcomplex<f32>> -> !torch.vtensor<[1,8,1,8],complex<f32>>
%c1_310 = arith.constant 1 : index
%c0_311 = arith.constant 0 : index
%c2_312 = arith.constant 2 : index
%c1_313 = arith.constant 1 : index
%c8_314 = arith.constant 8 : index
%c2_315 = arith.constant 2 : index
%c4_316 = arith.constant 4 : index
%c3_317 = arith.constant 3 : index
%c8_318 = arith.constant 8 : index
%c1_319 = arith.constant 1 : index
%c8_320 = arith.constant 8 : index
%179 = arith.cmpi eq, %c8_314, %c8_320 : index
cf.assert %179, "mismatched size for broadcast"
%c3_321 = arith.constant 3 : index
%c8_322 = arith.constant 8 : index
%180 = arith.cmpi eq, %c8_318, %c8_322 : index
cf.assert %180, "mismatched size for broadcast"
%181 = tensor.empty() : tensor<2x8x4x8xcomplex<f32>>
%182 = linalg.generic {indexing_maps = [#map, #map8, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_275, %expanded_309 : tensor<2x8x4x8xcomplex<f32>>, tensor<1x8x1x8xcomplex<f32>>) outs(%181 : tensor<2x8x4x8xcomplex<f32>>) {
^bb0(%in: complex<f32>, %in_2549: complex<f32>, %out: complex<f32>):
%1543 = complex.mul %in, %in_2549 : complex<f32>
linalg.yield %1543 : complex<f32>
} -> tensor<2x8x4x8xcomplex<f32>>
%cast_323 = tensor.cast %182 : tensor<2x8x4x8xcomplex<f32>> to tensor<2x8x4x8xcomplex<f32>>
%c2_324 = arith.constant 2 : index
%183 = tensor.empty(%c2_324) : tensor<2x8x4x8x?xf32>
%c0_325 = arith.constant 0 : index
%184 = linalg.generic {indexing_maps = [#map9, #map10], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%cast_323 : tensor<2x8x4x8xcomplex<f32>>) outs(%183 : tensor<2x8x4x8x?xf32>) {
^bb0(%in: complex<f32>, %out: f32):
%1543 = complex.re %in : complex<f32>
%1544 = complex.im %in : complex<f32>
%1545 = linalg.index 4 : index
%1546 = arith.cmpi eq, %1545, %c0_325 : index
%1547 = arith.select %1546, %1543, %1544 : f32
linalg.yield %1547 : f32
} -> tensor<2x8x4x8x?xf32>
%cast_326 = tensor.cast %184 : tensor<2x8x4x8x?xf32> to tensor<2x8x4x8x2xf32>
%int2_327 = torch.constant.int 2
%int8_328 = torch.constant.int 8
%int4_329 = torch.constant.int 4
%int16_330 = torch.constant.int 16
%185 = torch.prim.ListConstruct %int2_327, %int8_328, %int4_329, %int16_330 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_331 = arith.constant 0 : index
%c2_332 = arith.constant 2 : index
%c1_333 = arith.constant 1 : index
%c8_334 = arith.constant 8 : index
%c2_335 = arith.constant 2 : index
%c4_336 = arith.constant 4 : index
%c3_337 = arith.constant 3 : index
%c8_338 = arith.constant 8 : index
%c4_339 = arith.constant 4 : index
%c2_340 = arith.constant 2 : index
%186 = torch_c.to_i64 %int2_327
%187 = torch_c.to_i64 %int8_328
%188 = torch_c.to_i64 %int4_329
%189 = torch_c.to_i64 %int16_330
%collapsed_341 = tensor.collapse_shape %cast_326 [[0], [1], [2], [3, 4]] : tensor<2x8x4x8x2xf32> into tensor<2x8x4x16xf32>
%c1_342 = arith.constant 1 : index
%c0_343 = arith.constant 0 : index
%c2_344 = arith.constant 2 : index
%c1_345 = arith.constant 1 : index
%c8_346 = arith.constant 8 : index
%c2_347 = arith.constant 2 : index
%c4_348 = arith.constant 4 : index
%c3_349 = arith.constant 3 : index
%c8_350 = arith.constant 8 : index
%c1_351 = arith.constant 1 : index
%c8_352 = arith.constant 8 : index
%190 = arith.cmpi eq, %c8_346, %c8_352 : index
cf.assert %190, "mismatched size for broadcast"
%c3_353 = arith.constant 3 : index
%c8_354 = arith.constant 8 : index
%191 = arith.cmpi eq, %c8_350, %c8_354 : index
cf.assert %191, "mismatched size for broadcast"
%192 = tensor.empty() : tensor<2x8x4x8xcomplex<f32>>
%193 = linalg.generic {indexing_maps = [#map, #map8, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_300, %expanded_309 : tensor<2x8x4x8xcomplex<f32>>, tensor<1x8x1x8xcomplex<f32>>) outs(%192 : tensor<2x8x4x8xcomplex<f32>>) {
^bb0(%in: complex<f32>, %in_2549: complex<f32>, %out: complex<f32>):
%1543 = complex.mul %in, %in_2549 : complex<f32>
linalg.yield %1543 : complex<f32>
} -> tensor<2x8x4x8xcomplex<f32>>
%cast_355 = tensor.cast %193 : tensor<2x8x4x8xcomplex<f32>> to tensor<2x8x4x8xcomplex<f32>>
%c2_356 = arith.constant 2 : index
%194 = tensor.empty(%c2_356) : tensor<2x8x4x8x?xf32>
%c0_357 = arith.constant 0 : index
%195 = linalg.generic {indexing_maps = [#map9, #map10], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%cast_355 : tensor<2x8x4x8xcomplex<f32>>) outs(%194 : tensor<2x8x4x8x?xf32>) {
^bb0(%in: complex<f32>, %out: f32):
%1543 = complex.re %in : complex<f32>
%1544 = complex.im %in : complex<f32>
%1545 = linalg.index 4 : index
%1546 = arith.cmpi eq, %1545, %c0_357 : index
%1547 = arith.select %1546, %1543, %1544 : f32
linalg.yield %1547 : f32
} -> tensor<2x8x4x8x?xf32>
%cast_358 = tensor.cast %195 : tensor<2x8x4x8x?xf32> to tensor<2x8x4x8x2xf32>
%int2_359 = torch.constant.int 2
%int8_360 = torch.constant.int 8
%int4_361 = torch.constant.int 4
%int16_362 = torch.constant.int 16
%196 = torch.prim.ListConstruct %int2_359, %int8_360, %int4_361, %int16_362 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_363 = arith.constant 0 : index
%c2_364 = arith.constant 2 : index
%c1_365 = arith.constant 1 : index
%c8_366 = arith.constant 8 : index
%c2_367 = arith.constant 2 : index
%c4_368 = arith.constant 4 : index
%c3_369 = arith.constant 3 : index
%c8_370 = arith.constant 8 : index
%c4_371 = arith.constant 4 : index
%c2_372 = arith.constant 2 : index
%197 = torch_c.to_i64 %int2_359
%198 = torch_c.to_i64 %int8_360
%199 = torch_c.to_i64 %int4_361
%200 = torch_c.to_i64 %int16_362
%collapsed_373 = tensor.collapse_shape %cast_358 [[0], [1], [2], [3, 4]] : tensor<2x8x4x8x2xf32> into tensor<2x8x4x16xf32>
%int0_374 = torch.constant.int 0
%int0_375 = torch.constant.int 0
%201 = torch_c.to_i64 %int0_375
%int2_376 = torch.constant.int 2
%202 = torch_c.to_i64 %int2_376
%int1_377 = torch.constant.int 1
%c0_378 = arith.constant 0 : index
%c1_379 = arith.constant 1 : index
%c0_380 = arith.constant 0 : index
%c32_381 = arith.constant 32 : index
%c1_382 = arith.constant 1 : index
%c2048_383 = arith.constant 2048 : index
%c2_384 = arith.constant 2 : index
%c4_385 = arith.constant 4 : index
%c3_386 = arith.constant 3 : index
%c16_387 = arith.constant 16 : index
%203 = arith.index_cast %c32_381 : index to i64
%204 = arith.addi %201, %203 : i64
%c0_i64_388 = arith.constant 0 : i64
%205 = arith.cmpi sge, %201, %c0_i64_388 : i64
%206 = arith.select %205, %201, %204 : i64
%c0_i64_389 = arith.constant 0 : i64
%207 = arith.cmpi slt, %206, %c0_i64_389 : i64
%208 = arith.select %207, %c0_i64_389, %206 : i64
%209 = arith.cmpi sgt, %208, %203 : i64
%210 = arith.select %209, %203, %208 : i64
%211 = arith.index_cast %210 : i64 to index
%212 = arith.index_cast %c32_381 : index to i64
%213 = arith.addi %202, %212 : i64
%c0_i64_390 = arith.constant 0 : i64
%214 = arith.cmpi sge, %202, %c0_i64_390 : i64
%215 = arith.select %214, %202, %213 : i64
%c0_i64_391 = arith.constant 0 : i64
%216 = arith.cmpi slt, %215, %c0_i64_391 : i64
%217 = arith.select %216, %c0_i64_391, %215 : i64
%218 = arith.cmpi sgt, %217, %212 : i64
%219 = arith.select %218, %212, %217 : i64
%220 = arith.index_cast %219 : i64 to index
%221 = arith.cmpi sge, %220, %211 : index
%222 = arith.select %221, %220, %211 : index
%c1_392 = arith.constant 1 : index
%c0_393 = arith.constant 0 : index
%c32_394 = arith.constant 32 : index
%c1_395 = arith.constant 1 : index
%c2048_396 = arith.constant 2048 : index
%c2_397 = arith.constant 2 : index
%c4_398 = arith.constant 4 : index
%c3_399 = arith.constant 3 : index
%c16_400 = arith.constant 16 : index
%223 = arith.subi %222, %211 : index
%224 = arith.addi %223, %c1_392 : index
%225 = arith.subi %224, %c1_379 : index
%226 = arith.floordivsi %225, %c1_392 : index
%227 = arith.muli %c1_379, %c1_392 : index
%extracted_slice_401 = tensor.extract_slice %cast[%211, %c0_378, %c0_378, %c0_378] [%226, %c2048_396, %c4_398, %c16_400] [%227, %c1_379, %c1_379, %c1_379] : tensor<32x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_402 = tensor.cast %extracted_slice_401 : tensor<?x?x?x?xf32> to tensor<2x2048x4x16xf32>
%int1_403 = torch.constant.int 1
%int0_404 = torch.constant.int 0
%228 = torch_c.to_i64 %int0_404
%int8_405 = torch.constant.int 8
%229 = torch_c.to_i64 %int8_405
%int1_406 = torch.constant.int 1
%c0_407 = arith.constant 0 : index
%c1_408 = arith.constant 1 : index
%c0_409 = arith.constant 0 : index
%c2_410 = arith.constant 2 : index
%c1_411 = arith.constant 1 : index
%c2048_412 = arith.constant 2048 : index
%c2_413 = arith.constant 2 : index
%c4_414 = arith.constant 4 : index
%c3_415 = arith.constant 3 : index
%c16_416 = arith.constant 16 : index
%230 = arith.index_cast %c2048_412 : index to i64
%231 = arith.addi %228, %230 : i64
%c0_i64_417 = arith.constant 0 : i64
%232 = arith.cmpi sge, %228, %c0_i64_417 : i64
%233 = arith.select %232, %228, %231 : i64
%c0_i64_418 = arith.constant 0 : i64
%234 = arith.cmpi slt, %233, %c0_i64_418 : i64
%235 = arith.select %234, %c0_i64_418, %233 : i64
%236 = arith.cmpi sgt, %235, %230 : i64
%237 = arith.select %236, %230, %235 : i64
%238 = arith.index_cast %237 : i64 to index
%239 = arith.index_cast %c2048_412 : index to i64
%240 = arith.addi %229, %239 : i64
%c0_i64_419 = arith.constant 0 : i64
%241 = arith.cmpi sge, %229, %c0_i64_419 : i64
%242 = arith.select %241, %229, %240 : i64
%c0_i64_420 = arith.constant 0 : i64
%243 = arith.cmpi slt, %242, %c0_i64_420 : i64
%244 = arith.select %243, %c0_i64_420, %242 : i64
%245 = arith.cmpi sgt, %244, %239 : i64
%246 = arith.select %245, %239, %244 : i64
%247 = arith.index_cast %246 : i64 to index
%248 = arith.cmpi sge, %247, %238 : index
%249 = arith.select %248, %247, %238 : index
%c1_421 = arith.constant 1 : index
%c0_422 = arith.constant 0 : index
%c2_423 = arith.constant 2 : index
%c1_424 = arith.constant 1 : index
%c2048_425 = arith.constant 2048 : index
%c2_426 = arith.constant 2 : index
%c4_427 = arith.constant 4 : index
%c3_428 = arith.constant 3 : index
%c16_429 = arith.constant 16 : index
%250 = arith.subi %249, %238 : index
%251 = arith.addi %250, %c1_421 : index
%252 = arith.subi %251, %c1_408 : index
%253 = arith.floordivsi %252, %c1_421 : index
%254 = arith.muli %c1_408, %c1_421 : index
%extracted_slice_430 = tensor.extract_slice %cast_402[%c0_407, %238, %c0_407, %c0_407] [%c2_423, %253, %c4_427, %c16_429] [%c1_408, %254, %c1_408, %c1_408] : tensor<2x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_431 = tensor.cast %extracted_slice_430 : tensor<?x?x?x?xf32> to tensor<2x8x4x16xf32>
%false_432 = torch.constant.bool false
%c0_433 = arith.constant 0 : index
%c2_434 = arith.constant 2 : index
%c1_435 = arith.constant 1 : index
%c8_436 = arith.constant 8 : index
%c2_437 = arith.constant 2 : index
%c4_438 = arith.constant 4 : index
%c3_439 = arith.constant 3 : index
%c16_440 = arith.constant 16 : index
%255 = arith.index_cast %c2_434 : index to i64
%256 = arith.index_cast %c8_436 : index to i64
%257 = arith.index_cast %c4_438 : index to i64
%258 = arith.index_cast %c16_440 : index to i64
%c0_i64_441 = arith.constant 0 : i64
%c0_442 = arith.constant 0 : index
%c1_443 = arith.constant 1 : index
%259 = tensor.empty() : tensor<2x8x4x16xf32>
%cast_444 = tensor.cast %collapsed_373 : tensor<2x8x4x16xf32> to tensor<2x8x4x16xf32>
%260 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_444 : tensor<2x8x4x16xf32>) outs(%cast_431 : tensor<2x8x4x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x8x4x16xf32>
%cast_445 = tensor.cast %260 : tensor<2x8x4x16xf32> to tensor<2x8x4x16xf32>
%int0_446 = torch.constant.int 0
%int0_447 = torch.constant.int 0
%261 = torch_c.to_i64 %int0_447
%int2_448 = torch.constant.int 2
%262 = torch_c.to_i64 %int2_448
%int1_449 = torch.constant.int 1
%c0_450 = arith.constant 0 : index
%c1_451 = arith.constant 1 : index
%c0_452 = arith.constant 0 : index
%c32_453 = arith.constant 32 : index
%c1_454 = arith.constant 1 : index
%c2048_455 = arith.constant 2048 : index
%c2_456 = arith.constant 2 : index
%c4_457 = arith.constant 4 : index
%c3_458 = arith.constant 3 : index
%c16_459 = arith.constant 16 : index
%263 = arith.index_cast %c32_453 : index to i64
%264 = arith.addi %261, %263 : i64
%c0_i64_460 = arith.constant 0 : i64
%265 = arith.cmpi sge, %261, %c0_i64_460 : i64
%266 = arith.select %265, %261, %264 : i64
%c0_i64_461 = arith.constant 0 : i64
%267 = arith.cmpi slt, %266, %c0_i64_461 : i64
%268 = arith.select %267, %c0_i64_461, %266 : i64
%269 = arith.cmpi sgt, %268, %263 : i64
%270 = arith.select %269, %263, %268 : i64
%271 = arith.index_cast %270 : i64 to index
%272 = arith.index_cast %c32_453 : index to i64
%273 = arith.addi %262, %272 : i64
%c0_i64_462 = arith.constant 0 : i64
%274 = arith.cmpi sge, %262, %c0_i64_462 : i64
%275 = arith.select %274, %262, %273 : i64
%c0_i64_463 = arith.constant 0 : i64
%276 = arith.cmpi slt, %275, %c0_i64_463 : i64
%277 = arith.select %276, %c0_i64_463, %275 : i64
%278 = arith.cmpi sgt, %277, %272 : i64
%279 = arith.select %278, %272, %277 : i64
%280 = arith.index_cast %279 : i64 to index
%281 = arith.cmpi sge, %280, %271 : index
%282 = arith.select %281, %280, %271 : index
%c1_464 = arith.constant 1 : index
%c0_465 = arith.constant 0 : index
%c32_466 = arith.constant 32 : index
%c1_467 = arith.constant 1 : index
%c2048_468 = arith.constant 2048 : index
%c2_469 = arith.constant 2 : index
%c4_470 = arith.constant 4 : index
%c3_471 = arith.constant 3 : index
%c16_472 = arith.constant 16 : index
%283 = arith.subi %282, %271 : index
%284 = arith.addi %283, %c1_464 : index
%285 = arith.subi %284, %c1_451 : index
%286 = arith.floordivsi %285, %c1_464 : index
%287 = arith.muli %c1_451, %c1_464 : index
%extracted_slice_473 = tensor.extract_slice %cast[%271, %c0_450, %c0_450, %c0_450] [%286, %c2048_468, %c4_470, %c16_472] [%287, %c1_451, %c1_451, %c1_451] : tensor<32x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_474 = tensor.cast %extracted_slice_473 : tensor<?x?x?x?xf32> to tensor<2x2048x4x16xf32>
%int1_475 = torch.constant.int 1
%int0_476 = torch.constant.int 0
%288 = torch_c.to_i64 %int0_476
%int8_477 = torch.constant.int 8
%289 = torch_c.to_i64 %int8_477
%int1_478 = torch.constant.int 1
%c0_479 = arith.constant 0 : index
%c1_480 = arith.constant 1 : index
%c0_481 = arith.constant 0 : index
%c2_482 = arith.constant 2 : index
%c1_483 = arith.constant 1 : index
%c2048_484 = arith.constant 2048 : index
%c2_485 = arith.constant 2 : index
%c4_486 = arith.constant 4 : index
%c3_487 = arith.constant 3 : index
%c16_488 = arith.constant 16 : index
%290 = arith.index_cast %c2048_484 : index to i64
%291 = arith.addi %288, %290 : i64
%c0_i64_489 = arith.constant 0 : i64
%292 = arith.cmpi sge, %288, %c0_i64_489 : i64
%293 = arith.select %292, %288, %291 : i64
%c0_i64_490 = arith.constant 0 : i64
%294 = arith.cmpi slt, %293, %c0_i64_490 : i64
%295 = arith.select %294, %c0_i64_490, %293 : i64
%296 = arith.cmpi sgt, %295, %290 : i64
%297 = arith.select %296, %290, %295 : i64
%298 = arith.index_cast %297 : i64 to index
%299 = arith.index_cast %c2048_484 : index to i64
%300 = arith.addi %289, %299 : i64
%c0_i64_491 = arith.constant 0 : i64
%301 = arith.cmpi sge, %289, %c0_i64_491 : i64
%302 = arith.select %301, %289, %300 : i64
%c0_i64_492 = arith.constant 0 : i64
%303 = arith.cmpi slt, %302, %c0_i64_492 : i64
%304 = arith.select %303, %c0_i64_492, %302 : i64
%305 = arith.cmpi sgt, %304, %299 : i64
%306 = arith.select %305, %299, %304 : i64
%307 = arith.index_cast %306 : i64 to index
%308 = arith.cmpi sge, %307, %298 : index
%309 = arith.select %308, %307, %298 : index
%c1_493 = arith.constant 1 : index
%c0_494 = arith.constant 0 : index
%c2_495 = arith.constant 2 : index
%c1_496 = arith.constant 1 : index
%c2048_497 = arith.constant 2048 : index
%c2_498 = arith.constant 2 : index
%c4_499 = arith.constant 4 : index
%c3_500 = arith.constant 3 : index
%c16_501 = arith.constant 16 : index
%310 = arith.subi %309, %298 : index
%311 = arith.addi %310, %c1_493 : index
%312 = arith.subi %311, %c1_480 : index
%313 = arith.floordivsi %312, %c1_493 : index
%314 = arith.muli %c1_480, %c1_493 : index
%cast_502 = tensor.cast %cast_445 : tensor<2x8x4x16xf32> to tensor<?x?x?x?xf32>
%inserted_slice = tensor.insert_slice %cast_502 into %cast_474[%c0_479, %298, %c0_479, %c0_479] [%c2_495, %313, %c4_499, %c16_501] [%c1_480, %314, %c1_480, %c1_480] : tensor<?x?x?x?xf32> into tensor<2x2048x4x16xf32>
%cast_503 = tensor.cast %inserted_slice : tensor<2x2048x4x16xf32> to tensor<2x2048x4x16xf32>
%int0_504 = torch.constant.int 0
%int0_505 = torch.constant.int 0
%315 = torch_c.to_i64 %int0_505
%int2_506 = torch.constant.int 2
%316 = torch_c.to_i64 %int2_506
%int1_507 = torch.constant.int 1
%c0_508 = arith.constant 0 : index
%c1_509 = arith.constant 1 : index
%c0_510 = arith.constant 0 : index
%c32_511 = arith.constant 32 : index
%c1_512 = arith.constant 1 : index
%c2048_513 = arith.constant 2048 : index
%c2_514 = arith.constant 2 : index
%c4_515 = arith.constant 4 : index
%c3_516 = arith.constant 3 : index
%c16_517 = arith.constant 16 : index
%317 = arith.index_cast %c32_511 : index to i64
%318 = arith.addi %315, %317 : i64
%c0_i64_518 = arith.constant 0 : i64
%319 = arith.cmpi sge, %315, %c0_i64_518 : i64
%320 = arith.select %319, %315, %318 : i64
%c0_i64_519 = arith.constant 0 : i64
%321 = arith.cmpi slt, %320, %c0_i64_519 : i64
%322 = arith.select %321, %c0_i64_519, %320 : i64
%323 = arith.cmpi sgt, %322, %317 : i64
%324 = arith.select %323, %317, %322 : i64
%325 = arith.index_cast %324 : i64 to index
%326 = arith.index_cast %c32_511 : index to i64
%327 = arith.addi %316, %326 : i64
%c0_i64_520 = arith.constant 0 : i64
%328 = arith.cmpi sge, %316, %c0_i64_520 : i64
%329 = arith.select %328, %316, %327 : i64
%c0_i64_521 = arith.constant 0 : i64
%330 = arith.cmpi slt, %329, %c0_i64_521 : i64
%331 = arith.select %330, %c0_i64_521, %329 : i64
%332 = arith.cmpi sgt, %331, %326 : i64
%333 = arith.select %332, %326, %331 : i64
%334 = arith.index_cast %333 : i64 to index
%335 = arith.cmpi sge, %334, %325 : index
%336 = arith.select %335, %334, %325 : index
%c1_522 = arith.constant 1 : index
%c0_523 = arith.constant 0 : index
%c32_524 = arith.constant 32 : index
%c1_525 = arith.constant 1 : index
%c2048_526 = arith.constant 2048 : index
%c2_527 = arith.constant 2 : index
%c4_528 = arith.constant 4 : index
%c3_529 = arith.constant 3 : index
%c16_530 = arith.constant 16 : index
%337 = arith.subi %336, %325 : index
%338 = arith.addi %337, %c1_522 : index
%339 = arith.subi %338, %c1_509 : index
%340 = arith.floordivsi %339, %c1_522 : index
%341 = arith.muli %c1_509, %c1_522 : index
%cast_531 = tensor.cast %cast_503 : tensor<2x2048x4x16xf32> to tensor<?x?x?x?xf32>
%inserted_slice_532 = tensor.insert_slice %cast_531 into %cast[%325, %c0_508, %c0_508, %c0_508] [%340, %c2048_526, %c4_528, %c16_530] [%341, %c1_509, %c1_509, %c1_509] : tensor<?x?x?x?xf32> into tensor<32x2048x4x16xf32>
%cast_533 = tensor.cast %inserted_slice_532 : tensor<32x2048x4x16xf32> to tensor<32x2048x4x16xf32>
%342 = torch_c.from_builtin_tensor %cast_533 : tensor<32x2048x4x16xf32> -> !torch.vtensor<[32,2048,4,16],f32>
%int0_534 = torch.constant.int 0
%int0_535 = torch.constant.int 0
%343 = torch_c.to_i64 %int0_535
%int2_536 = torch.constant.int 2
%344 = torch_c.to_i64 %int2_536
%int1_537 = torch.constant.int 1
%c0_538 = arith.constant 0 : index
%c1_539 = arith.constant 1 : index
%c0_540 = arith.constant 0 : index
%c32_541 = arith.constant 32 : index
%c1_542 = arith.constant 1 : index
%c2048_543 = arith.constant 2048 : index
%c2_544 = arith.constant 2 : index
%c4_545 = arith.constant 4 : index
%c3_546 = arith.constant 3 : index
%c16_547 = arith.constant 16 : index
%345 = arith.index_cast %c32_541 : index to i64
%346 = arith.addi %343, %345 : i64
%c0_i64_548 = arith.constant 0 : i64
%347 = arith.cmpi sge, %343, %c0_i64_548 : i64
%348 = arith.select %347, %343, %346 : i64
%c0_i64_549 = arith.constant 0 : i64
%349 = arith.cmpi slt, %348, %c0_i64_549 : i64
%350 = arith.select %349, %c0_i64_549, %348 : i64
%351 = arith.cmpi sgt, %350, %345 : i64
%352 = arith.select %351, %345, %350 : i64
%353 = arith.index_cast %352 : i64 to index
%354 = arith.index_cast %c32_541 : index to i64
%355 = arith.addi %344, %354 : i64
%c0_i64_550 = arith.constant 0 : i64
%356 = arith.cmpi sge, %344, %c0_i64_550 : i64
%357 = arith.select %356, %344, %355 : i64
%c0_i64_551 = arith.constant 0 : i64
%358 = arith.cmpi slt, %357, %c0_i64_551 : i64
%359 = arith.select %358, %c0_i64_551, %357 : i64
%360 = arith.cmpi sgt, %359, %354 : i64
%361 = arith.select %360, %354, %359 : i64
%362 = arith.index_cast %361 : i64 to index
%363 = arith.cmpi sge, %362, %353 : index
%364 = arith.select %363, %362, %353 : index
%c1_552 = arith.constant 1 : index
%c0_553 = arith.constant 0 : index
%c32_554 = arith.constant 32 : index
%c1_555 = arith.constant 1 : index
%c2048_556 = arith.constant 2048 : index
%c2_557 = arith.constant 2 : index
%c4_558 = arith.constant 4 : index
%c3_559 = arith.constant 3 : index
%c16_560 = arith.constant 16 : index
%365 = arith.subi %364, %353 : index
%366 = arith.addi %365, %c1_552 : index
%367 = arith.subi %366, %c1_539 : index
%368 = arith.floordivsi %367, %c1_552 : index
%369 = arith.muli %c1_539, %c1_552 : index
%extracted_slice_561 = tensor.extract_slice %cast_11[%353, %c0_538, %c0_538, %c0_538] [%368, %c2048_556, %c4_558, %c16_560] [%369, %c1_539, %c1_539, %c1_539] : tensor<32x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_562 = tensor.cast %extracted_slice_561 : tensor<?x?x?x?xf32> to tensor<2x2048x4x16xf32>
%int1_563 = torch.constant.int 1
%int0_564 = torch.constant.int 0
%370 = torch_c.to_i64 %int0_564
%int8_565 = torch.constant.int 8
%371 = torch_c.to_i64 %int8_565
%int1_566 = torch.constant.int 1
%c0_567 = arith.constant 0 : index
%c1_568 = arith.constant 1 : index
%c0_569 = arith.constant 0 : index
%c2_570 = arith.constant 2 : index
%c1_571 = arith.constant 1 : index
%c2048_572 = arith.constant 2048 : index
%c2_573 = arith.constant 2 : index
%c4_574 = arith.constant 4 : index
%c3_575 = arith.constant 3 : index
%c16_576 = arith.constant 16 : index
%372 = arith.index_cast %c2048_572 : index to i64
%373 = arith.addi %370, %372 : i64
%c0_i64_577 = arith.constant 0 : i64
%374 = arith.cmpi sge, %370, %c0_i64_577 : i64
%375 = arith.select %374, %370, %373 : i64
%c0_i64_578 = arith.constant 0 : i64
%376 = arith.cmpi slt, %375, %c0_i64_578 : i64
%377 = arith.select %376, %c0_i64_578, %375 : i64
%378 = arith.cmpi sgt, %377, %372 : i64
%379 = arith.select %378, %372, %377 : i64
%380 = arith.index_cast %379 : i64 to index
%381 = arith.index_cast %c2048_572 : index to i64
%382 = arith.addi %371, %381 : i64
%c0_i64_579 = arith.constant 0 : i64
%383 = arith.cmpi sge, %371, %c0_i64_579 : i64
%384 = arith.select %383, %371, %382 : i64
%c0_i64_580 = arith.constant 0 : i64
%385 = arith.cmpi slt, %384, %c0_i64_580 : i64
%386 = arith.select %385, %c0_i64_580, %384 : i64
%387 = arith.cmpi sgt, %386, %381 : i64
%388 = arith.select %387, %381, %386 : i64
%389 = arith.index_cast %388 : i64 to index
%390 = arith.cmpi sge, %389, %380 : index
%391 = arith.select %390, %389, %380 : index
%c1_581 = arith.constant 1 : index
%c0_582 = arith.constant 0 : index
%c2_583 = arith.constant 2 : index
%c1_584 = arith.constant 1 : index
%c2048_585 = arith.constant 2048 : index
%c2_586 = arith.constant 2 : index
%c4_587 = arith.constant 4 : index
%c3_588 = arith.constant 3 : index
%c16_589 = arith.constant 16 : index
%392 = arith.subi %391, %380 : index
%393 = arith.addi %392, %c1_581 : index
%394 = arith.subi %393, %c1_568 : index
%395 = arith.floordivsi %394, %c1_581 : index
%396 = arith.muli %c1_568, %c1_581 : index
%extracted_slice_590 = tensor.extract_slice %cast_562[%c0_567, %380, %c0_567, %c0_567] [%c2_583, %395, %c4_587, %c16_589] [%c1_568, %396, %c1_568, %c1_568] : tensor<2x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_591 = tensor.cast %extracted_slice_590 : tensor<?x?x?x?xf32> to tensor<2x8x4x16xf32>
%false_592 = torch.constant.bool false
%c0_593 = arith.constant 0 : index
%c2_594 = arith.constant 2 : index
%c1_595 = arith.constant 1 : index
%c8_596 = arith.constant 8 : index
%c2_597 = arith.constant 2 : index
%c4_598 = arith.constant 4 : index
%c3_599 = arith.constant 3 : index
%c16_600 = arith.constant 16 : index
%397 = arith.index_cast %c2_594 : index to i64
%398 = arith.index_cast %c8_596 : index to i64
%399 = arith.index_cast %c4_598 : index to i64
%400 = arith.index_cast %c16_600 : index to i64
%c0_i64_601 = arith.constant 0 : i64
%c0_602 = arith.constant 0 : index
%c1_603 = arith.constant 1 : index
%401 = tensor.empty() : tensor<2x8x4x16xf32>
%cast_604 = tensor.cast %expanded_250 : tensor<2x8x4x16xf32> to tensor<2x8x4x16xf32>
%402 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_604 : tensor<2x8x4x16xf32>) outs(%cast_591 : tensor<2x8x4x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x8x4x16xf32>
%cast_605 = tensor.cast %402 : tensor<2x8x4x16xf32> to tensor<2x8x4x16xf32>
%int0_606 = torch.constant.int 0
%int0_607 = torch.constant.int 0
%403 = torch_c.to_i64 %int0_607
%int2_608 = torch.constant.int 2
%404 = torch_c.to_i64 %int2_608
%int1_609 = torch.constant.int 1
%c0_610 = arith.constant 0 : index
%c1_611 = arith.constant 1 : index
%c0_612 = arith.constant 0 : index
%c32_613 = arith.constant 32 : index
%c1_614 = arith.constant 1 : index
%c2048_615 = arith.constant 2048 : index
%c2_616 = arith.constant 2 : index
%c4_617 = arith.constant 4 : index
%c3_618 = arith.constant 3 : index
%c16_619 = arith.constant 16 : index
%405 = arith.index_cast %c32_613 : index to i64
%406 = arith.addi %403, %405 : i64
%c0_i64_620 = arith.constant 0 : i64
%407 = arith.cmpi sge, %403, %c0_i64_620 : i64
%408 = arith.select %407, %403, %406 : i64
%c0_i64_621 = arith.constant 0 : i64
%409 = arith.cmpi slt, %408, %c0_i64_621 : i64
%410 = arith.select %409, %c0_i64_621, %408 : i64
%411 = arith.cmpi sgt, %410, %405 : i64
%412 = arith.select %411, %405, %410 : i64
%413 = arith.index_cast %412 : i64 to index
%414 = arith.index_cast %c32_613 : index to i64
%415 = arith.addi %404, %414 : i64
%c0_i64_622 = arith.constant 0 : i64
%416 = arith.cmpi sge, %404, %c0_i64_622 : i64
%417 = arith.select %416, %404, %415 : i64
%c0_i64_623 = arith.constant 0 : i64
%418 = arith.cmpi slt, %417, %c0_i64_623 : i64
%419 = arith.select %418, %c0_i64_623, %417 : i64
%420 = arith.cmpi sgt, %419, %414 : i64
%421 = arith.select %420, %414, %419 : i64
%422 = arith.index_cast %421 : i64 to index
%423 = arith.cmpi sge, %422, %413 : index
%424 = arith.select %423, %422, %413 : index
%c1_624 = arith.constant 1 : index
%c0_625 = arith.constant 0 : index
%c32_626 = arith.constant 32 : index
%c1_627 = arith.constant 1 : index
%c2048_628 = arith.constant 2048 : index
%c2_629 = arith.constant 2 : index
%c4_630 = arith.constant 4 : index
%c3_631 = arith.constant 3 : index
%c16_632 = arith.constant 16 : index
%425 = arith.subi %424, %413 : index
%426 = arith.addi %425, %c1_624 : index
%427 = arith.subi %426, %c1_611 : index
%428 = arith.floordivsi %427, %c1_624 : index
%429 = arith.muli %c1_611, %c1_624 : index
%extracted_slice_633 = tensor.extract_slice %cast_11[%413, %c0_610, %c0_610, %c0_610] [%428, %c2048_628, %c4_630, %c16_632] [%429, %c1_611, %c1_611, %c1_611] : tensor<32x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_634 = tensor.cast %extracted_slice_633 : tensor<?x?x?x?xf32> to tensor<2x2048x4x16xf32>
%int1_635 = torch.constant.int 1
%int0_636 = torch.constant.int 0
%430 = torch_c.to_i64 %int0_636
%int8_637 = torch.constant.int 8
%431 = torch_c.to_i64 %int8_637
%int1_638 = torch.constant.int 1
%c0_639 = arith.constant 0 : index
%c1_640 = arith.constant 1 : index
%c0_641 = arith.constant 0 : index
%c2_642 = arith.constant 2 : index
%c1_643 = arith.constant 1 : index
%c2048_644 = arith.constant 2048 : index
%c2_645 = arith.constant 2 : index
%c4_646 = arith.constant 4 : index
%c3_647 = arith.constant 3 : index
%c16_648 = arith.constant 16 : index
%432 = arith.index_cast %c2048_644 : index to i64
%433 = arith.addi %430, %432 : i64
%c0_i64_649 = arith.constant 0 : i64
%434 = arith.cmpi sge, %430, %c0_i64_649 : i64
%435 = arith.select %434, %430, %433 : i64
%c0_i64_650 = arith.constant 0 : i64
%436 = arith.cmpi slt, %435, %c0_i64_650 : i64
%437 = arith.select %436, %c0_i64_650, %435 : i64
%438 = arith.cmpi sgt, %437, %432 : i64
%439 = arith.select %438, %432, %437 : i64
%440 = arith.index_cast %439 : i64 to index
%441 = arith.index_cast %c2048_644 : index to i64
%442 = arith.addi %431, %441 : i64
%c0_i64_651 = arith.constant 0 : i64
%443 = arith.cmpi sge, %431, %c0_i64_651 : i64
%444 = arith.select %443, %431, %442 : i64
%c0_i64_652 = arith.constant 0 : i64
%445 = arith.cmpi slt, %444, %c0_i64_652 : i64
%446 = arith.select %445, %c0_i64_652, %444 : i64
%447 = arith.cmpi sgt, %446, %441 : i64
%448 = arith.select %447, %441, %446 : i64
%449 = arith.index_cast %448 : i64 to index
%450 = arith.cmpi sge, %449, %440 : index
%451 = arith.select %450, %449, %440 : index
%c1_653 = arith.constant 1 : index
%c0_654 = arith.constant 0 : index
%c2_655 = arith.constant 2 : index
%c1_656 = arith.constant 1 : index
%c2048_657 = arith.constant 2048 : index
%c2_658 = arith.constant 2 : index
%c4_659 = arith.constant 4 : index
%c3_660 = arith.constant 3 : index
%c16_661 = arith.constant 16 : index
%452 = arith.subi %451, %440 : index
%453 = arith.addi %452, %c1_653 : index
%454 = arith.subi %453, %c1_640 : index
%455 = arith.floordivsi %454, %c1_653 : index
%456 = arith.muli %c1_640, %c1_653 : index
%cast_662 = tensor.cast %cast_605 : tensor<2x8x4x16xf32> to tensor<?x?x?x?xf32>
%inserted_slice_663 = tensor.insert_slice %cast_662 into %cast_634[%c0_639, %440, %c0_639, %c0_639] [%c2_655, %455, %c4_659, %c16_661] [%c1_640, %456, %c1_640, %c1_640] : tensor<?x?x?x?xf32> into tensor<2x2048x4x16xf32>
%cast_664 = tensor.cast %inserted_slice_663 : tensor<2x2048x4x16xf32> to tensor<2x2048x4x16xf32>
%int0_665 = torch.constant.int 0
%int0_666 = torch.constant.int 0
%457 = torch_c.to_i64 %int0_666
%int2_667 = torch.constant.int 2
%458 = torch_c.to_i64 %int2_667
%int1_668 = torch.constant.int 1
%c0_669 = arith.constant 0 : index
%c1_670 = arith.constant 1 : index
%c0_671 = arith.constant 0 : index
%c32_672 = arith.constant 32 : index
%c1_673 = arith.constant 1 : index
%c2048_674 = arith.constant 2048 : index
%c2_675 = arith.constant 2 : index
%c4_676 = arith.constant 4 : index
%c3_677 = arith.constant 3 : index
%c16_678 = arith.constant 16 : index
%459 = arith.index_cast %c32_672 : index to i64
%460 = arith.addi %457, %459 : i64
%c0_i64_679 = arith.constant 0 : i64
%461 = arith.cmpi sge, %457, %c0_i64_679 : i64
%462 = arith.select %461, %457, %460 : i64
%c0_i64_680 = arith.constant 0 : i64
%463 = arith.cmpi slt, %462, %c0_i64_680 : i64
%464 = arith.select %463, %c0_i64_680, %462 : i64
%465 = arith.cmpi sgt, %464, %459 : i64
%466 = arith.select %465, %459, %464 : i64
%467 = arith.index_cast %466 : i64 to index
%468 = arith.index_cast %c32_672 : index to i64
%469 = arith.addi %458, %468 : i64
%c0_i64_681 = arith.constant 0 : i64
%470 = arith.cmpi sge, %458, %c0_i64_681 : i64
%471 = arith.select %470, %458, %469 : i64
%c0_i64_682 = arith.constant 0 : i64
%472 = arith.cmpi slt, %471, %c0_i64_682 : i64
%473 = arith.select %472, %c0_i64_682, %471 : i64
%474 = arith.cmpi sgt, %473, %468 : i64
%475 = arith.select %474, %468, %473 : i64
%476 = arith.index_cast %475 : i64 to index
%477 = arith.cmpi sge, %476, %467 : index
%478 = arith.select %477, %476, %467 : index
%c1_683 = arith.constant 1 : index
%c0_684 = arith.constant 0 : index
%c32_685 = arith.constant 32 : index
%c1_686 = arith.constant 1 : index
%c2048_687 = arith.constant 2048 : index
%c2_688 = arith.constant 2 : index
%c4_689 = arith.constant 4 : index
%c3_690 = arith.constant 3 : index
%c16_691 = arith.constant 16 : index
%479 = arith.subi %478, %467 : index
%480 = arith.addi %479, %c1_683 : index
%481 = arith.subi %480, %c1_670 : index
%482 = arith.floordivsi %481, %c1_683 : index
%483 = arith.muli %c1_670, %c1_683 : index
%cast_692 = tensor.cast %cast_664 : tensor<2x2048x4x16xf32> to tensor<?x?x?x?xf32>
%inserted_slice_693 = tensor.insert_slice %cast_692 into %cast_11[%467, %c0_669, %c0_669, %c0_669] [%482, %c2048_687, %c4_689, %c16_691] [%483, %c1_670, %c1_670, %c1_670] : tensor<?x?x?x?xf32> into tensor<32x2048x4x16xf32>
%cast_694 = tensor.cast %inserted_slice_693 : tensor<32x2048x4x16xf32> to tensor<32x2048x4x16xf32>
%484 = torch_c.from_builtin_tensor %cast_694 : tensor<32x2048x4x16xf32> -> !torch.vtensor<[32,2048,4,16],f32>
%int1_695 = torch.constant.int 1
%int2_696 = torch.constant.int 2
%c0_697 = arith.constant 0 : index
%c2_698 = arith.constant 2 : index
%c1_699 = arith.constant 1 : index
%c8_700 = arith.constant 8 : index
%c2_701 = arith.constant 2 : index
%c4_702 = arith.constant 4 : index
%c3_703 = arith.constant 3 : index
%c16_704 = arith.constant 16 : index
%485 = tensor.empty() : tensor<2x4x8x16xf32>
%486 = linalg.generic {indexing_maps = [#map, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_341 : tensor<2x8x4x16xf32>) outs(%485 : tensor<2x4x8x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x8x16xf32>
%cast_705 = tensor.cast %486 : tensor<2x4x8x16xf32> to tensor<2x4x8x16xf32>
%487 = torch_c.from_builtin_tensor %cast_705 : tensor<2x4x8x16xf32> -> !torch.vtensor<[2,4,8,16],f32>
%int2_706 = torch.constant.int 2
%int4_707 = torch.constant.int 4
%int8_708 = torch.constant.int 8
%int16_709 = torch.constant.int 16
%488 = torch.prim.ListConstruct %int2_706, %int4_707, %int8_708, %int16_709 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_710 = torch.constant.bool false
%489 = torch.aten.expand %487, %488, %false_710 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>
%490 = torch_c.to_builtin_tensor %489 : !torch.vtensor<[2,4,8,16],f32> -> tensor<2x4x8x16xf32>
%int0_711 = torch.constant.int 0
%c1_712 = arith.constant 1 : index
%c0_713 = arith.constant 0 : index
%c2_714 = arith.constant 2 : index
%c1_715 = arith.constant 1 : index
%c4_716 = arith.constant 4 : index
%c2_717 = arith.constant 2 : index
%c8_718 = arith.constant 8 : index
%c3_719 = arith.constant 3 : index
%c16_720 = arith.constant 16 : index
%491 = tensor.empty() : tensor<2x4x8x16xf32>
%492 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%490 : tensor<2x4x8x16xf32>) outs(%491 : tensor<2x4x8x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x8x16xf32>
%cast_721 = tensor.cast %492 : tensor<2x4x8x16xf32> to tensor<2x4x8x16xf32>
%493 = torch_c.from_builtin_tensor %cast_721 : tensor<2x4x8x16xf32> -> !torch.vtensor<[2,4,8,16],f32>
%int8_722 = torch.constant.int 8
%int8_723 = torch.constant.int 8
%int16_724 = torch.constant.int 16
%494 = torch.prim.ListConstruct %int8_722, %int8_723, %int16_724 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%495 = torch.aten._unsafe_view %493, %494 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int> -> !torch.vtensor<[8,8,16],f32>
%496 = torch_c.to_builtin_tensor %495 : !torch.vtensor<[8,8,16],f32> -> tensor<8x8x16xf32>
%int0_725 = torch.constant.int 0
%int0_726 = torch.constant.int 0
%497 = torch_c.to_i64 %int0_726
%int2_727 = torch.constant.int 2
%498 = torch_c.to_i64 %int2_727
%int1_728 = torch.constant.int 1
%c0_729 = arith.constant 0 : index
%c1_730 = arith.constant 1 : index
%c0_731 = arith.constant 0 : index
%c32_732 = arith.constant 32 : index
%c1_733 = arith.constant 1 : index
%c2048_734 = arith.constant 2048 : index
%c2_735 = arith.constant 2 : index
%c4_736 = arith.constant 4 : index
%c3_737 = arith.constant 3 : index
%c16_738 = arith.constant 16 : index
%499 = arith.index_cast %c32_732 : index to i64
%500 = arith.addi %497, %499 : i64
%c0_i64_739 = arith.constant 0 : i64
%501 = arith.cmpi sge, %497, %c0_i64_739 : i64
%502 = arith.select %501, %497, %500 : i64
%c0_i64_740 = arith.constant 0 : i64
%503 = arith.cmpi slt, %502, %c0_i64_740 : i64
%504 = arith.select %503, %c0_i64_740, %502 : i64
%505 = arith.cmpi sgt, %504, %499 : i64
%506 = arith.select %505, %499, %504 : i64
%507 = arith.index_cast %506 : i64 to index
%508 = arith.index_cast %c32_732 : index to i64
%509 = arith.addi %498, %508 : i64
%c0_i64_741 = arith.constant 0 : i64
%510 = arith.cmpi sge, %498, %c0_i64_741 : i64
%511 = arith.select %510, %498, %509 : i64
%c0_i64_742 = arith.constant 0 : i64
%512 = arith.cmpi slt, %511, %c0_i64_742 : i64
%513 = arith.select %512, %c0_i64_742, %511 : i64
%514 = arith.cmpi sgt, %513, %508 : i64
%515 = arith.select %514, %508, %513 : i64
%516 = arith.index_cast %515 : i64 to index
%517 = arith.cmpi sge, %516, %507 : index
%518 = arith.select %517, %516, %507 : index
%c1_743 = arith.constant 1 : index
%c0_744 = arith.constant 0 : index
%c32_745 = arith.constant 32 : index
%c1_746 = arith.constant 1 : index
%c2048_747 = arith.constant 2048 : index
%c2_748 = arith.constant 2 : index
%c4_749 = arith.constant 4 : index
%c3_750 = arith.constant 3 : index
%c16_751 = arith.constant 16 : index
%519 = arith.subi %518, %507 : index
%520 = arith.addi %519, %c1_743 : index
%521 = arith.subi %520, %c1_730 : index
%522 = arith.floordivsi %521, %c1_743 : index
%523 = arith.muli %c1_730, %c1_743 : index
%extracted_slice_752 = tensor.extract_slice %cast_533[%507, %c0_729, %c0_729, %c0_729] [%522, %c2048_747, %c4_749, %c16_751] [%523, %c1_730, %c1_730, %c1_730] : tensor<32x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_753 = tensor.cast %extracted_slice_752 : tensor<?x?x?x?xf32> to tensor<2x2048x4x16xf32>
%int1_754 = torch.constant.int 1
%int0_755 = torch.constant.int 0
%524 = torch_c.to_i64 %int0_755
%int8_756 = torch.constant.int 8
%525 = torch_c.to_i64 %int8_756
%int1_757 = torch.constant.int 1
%c0_758 = arith.constant 0 : index
%c1_759 = arith.constant 1 : index
%c0_760 = arith.constant 0 : index
%c2_761 = arith.constant 2 : index
%c1_762 = arith.constant 1 : index
%c2048_763 = arith.constant 2048 : index
%c2_764 = arith.constant 2 : index
%c4_765 = arith.constant 4 : index
%c3_766 = arith.constant 3 : index
%c16_767 = arith.constant 16 : index
%526 = arith.index_cast %c2048_763 : index to i64
%527 = arith.addi %524, %526 : i64
%c0_i64_768 = arith.constant 0 : i64
%528 = arith.cmpi sge, %524, %c0_i64_768 : i64
%529 = arith.select %528, %524, %527 : i64
%c0_i64_769 = arith.constant 0 : i64
%530 = arith.cmpi slt, %529, %c0_i64_769 : i64
%531 = arith.select %530, %c0_i64_769, %529 : i64
%532 = arith.cmpi sgt, %531, %526 : i64
%533 = arith.select %532, %526, %531 : i64
%534 = arith.index_cast %533 : i64 to index
%535 = arith.index_cast %c2048_763 : index to i64
%536 = arith.addi %525, %535 : i64
%c0_i64_770 = arith.constant 0 : i64
%537 = arith.cmpi sge, %525, %c0_i64_770 : i64
%538 = arith.select %537, %525, %536 : i64
%c0_i64_771 = arith.constant 0 : i64
%539 = arith.cmpi slt, %538, %c0_i64_771 : i64
%540 = arith.select %539, %c0_i64_771, %538 : i64
%541 = arith.cmpi sgt, %540, %535 : i64
%542 = arith.select %541, %535, %540 : i64
%543 = arith.index_cast %542 : i64 to index
%544 = arith.cmpi sge, %543, %534 : index
%545 = arith.select %544, %543, %534 : index
%c1_772 = arith.constant 1 : index
%c0_773 = arith.constant 0 : index
%c2_774 = arith.constant 2 : index
%c1_775 = arith.constant 1 : index
%c2048_776 = arith.constant 2048 : index
%c2_777 = arith.constant 2 : index
%c4_778 = arith.constant 4 : index
%c3_779 = arith.constant 3 : index
%c16_780 = arith.constant 16 : index
%546 = arith.subi %545, %534 : index
%547 = arith.addi %546, %c1_772 : index
%548 = arith.subi %547, %c1_759 : index
%549 = arith.floordivsi %548, %c1_772 : index
%550 = arith.muli %c1_759, %c1_772 : index
%extracted_slice_781 = tensor.extract_slice %cast_753[%c0_758, %534, %c0_758, %c0_758] [%c2_774, %549, %c4_778, %c16_780] [%c1_759, %550, %c1_759, %c1_759] : tensor<2x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_782 = tensor.cast %extracted_slice_781 : tensor<?x?x?x?xf32> to tensor<2x8x4x16xf32>
%int1_783 = torch.constant.int 1
%int2_784 = torch.constant.int 2
%c0_785 = arith.constant 0 : index
%c2_786 = arith.constant 2 : index
%c1_787 = arith.constant 1 : index
%c8_788 = arith.constant 8 : index
%c2_789 = arith.constant 2 : index
%c4_790 = arith.constant 4 : index
%c3_791 = arith.constant 3 : index
%c16_792 = arith.constant 16 : index
%551 = tensor.empty() : tensor<2x4x8x16xf32>
%552 = linalg.generic {indexing_maps = [#map, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_782 : tensor<2x8x4x16xf32>) outs(%551 : tensor<2x4x8x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x8x16xf32>
%cast_793 = tensor.cast %552 : tensor<2x4x8x16xf32> to tensor<2x4x8x16xf32>
%int2_794 = torch.constant.int 2
%int3 = torch.constant.int 3
%c0_795 = arith.constant 0 : index
%c2_796 = arith.constant 2 : index
%c1_797 = arith.constant 1 : index
%c4_798 = arith.constant 4 : index
%c2_799 = arith.constant 2 : index
%c8_800 = arith.constant 8 : index
%c3_801 = arith.constant 3 : index
%c16_802 = arith.constant 16 : index
%553 = tensor.empty() : tensor<2x4x16x8xf32>
%554 = linalg.generic {indexing_maps = [#map, #map12], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_793 : tensor<2x4x8x16xf32>) outs(%553 : tensor<2x4x16x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x16x8xf32>
%cast_803 = tensor.cast %554 : tensor<2x4x16x8xf32> to tensor<2x4x16x8xf32>
%555 = torch_c.from_builtin_tensor %cast_803 : tensor<2x4x16x8xf32> -> !torch.vtensor<[2,4,16,8],f32>
%int2_804 = torch.constant.int 2
%int4_805 = torch.constant.int 4
%int16_806 = torch.constant.int 16
%int8_807 = torch.constant.int 8
%556 = torch.prim.ListConstruct %int2_804, %int4_805, %int16_806, %int8_807 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_808 = torch.constant.bool false
%557 = torch.aten.expand %555, %556, %false_808 : !torch.vtensor<[2,4,16,8],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,16,8],f32>
%558 = torch_c.to_builtin_tensor %557 : !torch.vtensor<[2,4,16,8],f32> -> tensor<2x4x16x8xf32>
%int0_809 = torch.constant.int 0
%c1_810 = arith.constant 1 : index
%c0_811 = arith.constant 0 : index
%c2_812 = arith.constant 2 : index
%c1_813 = arith.constant 1 : index
%c4_814 = arith.constant 4 : index
%c2_815 = arith.constant 2 : index
%c16_816 = arith.constant 16 : index
%c3_817 = arith.constant 3 : index
%c8_818 = arith.constant 8 : index
%559 = tensor.empty() : tensor<2x4x16x8xf32>
%560 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%558 : tensor<2x4x16x8xf32>) outs(%559 : tensor<2x4x16x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x16x8xf32>
%cast_819 = tensor.cast %560 : tensor<2x4x16x8xf32> to tensor<2x4x16x8xf32>
%561 = torch_c.from_builtin_tensor %cast_819 : tensor<2x4x16x8xf32> -> !torch.vtensor<[2,4,16,8],f32>
%int8_820 = torch.constant.int 8
%int16_821 = torch.constant.int 16
%int8_822 = torch.constant.int 8
%562 = torch.prim.ListConstruct %int8_820, %int16_821, %int8_822 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%563 = torch.aten._unsafe_view %561, %562 : !torch.vtensor<[2,4,16,8],f32>, !torch.list<int> -> !torch.vtensor<[8,16,8],f32>
%564 = torch_c.to_builtin_tensor %563 : !torch.vtensor<[8,16,8],f32> -> tensor<8x16x8xf32>
%c0_823 = arith.constant 0 : index
%c8_824 = arith.constant 8 : index
%c1_825 = arith.constant 1 : index
%c8_826 = arith.constant 8 : index
%c2_827 = arith.constant 2 : index
%c16_828 = arith.constant 16 : index
%c0_829 = arith.constant 0 : index
%c8_830 = arith.constant 8 : index
%c1_831 = arith.constant 1 : index
%c16_832 = arith.constant 16 : index
%c2_833 = arith.constant 2 : index
%c8_834 = arith.constant 8 : index
%565 = arith.index_cast %c8_824 : index to i64
%566 = arith.index_cast %c8_830 : index to i64
%567 = arith.cmpi eq, %565, %566 : i64
cf.assert %567, "mismatching contracting dimension"
%568 = arith.index_cast %c16_828 : index to i64
%569 = arith.index_cast %c16_832 : index to i64
%570 = arith.cmpi eq, %568, %569 : i64
cf.assert %570, "mismatching contracting dimension"
%571 = tensor.empty() : tensor<8x8x8xf32>
%cst_835 = arith.constant 0.000000e+00 : f32
%572 = linalg.fill ins(%cst_835 : f32) outs(%571 : tensor<8x8x8xf32>) -> tensor<8x8x8xf32>
%573 = linalg.batch_matmul ins(%496, %564 : tensor<8x8x16xf32>, tensor<8x16x8xf32>) outs(%572 : tensor<8x8x8xf32>) -> tensor<8x8x8xf32>
%cast_836 = tensor.cast %573 : tensor<8x8x8xf32> to tensor<8x8x8xf32>
%int2_837 = torch.constant.int 2
%int4_838 = torch.constant.int 4
%int8_839 = torch.constant.int 8
%int8_840 = torch.constant.int 8
%574 = torch.prim.ListConstruct %int2_837, %int4_838, %int8_839, %int8_840 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_841 = arith.constant 0 : index
%c8_842 = arith.constant 8 : index
%c1_843 = arith.constant 1 : index
%c8_844 = arith.constant 8 : index
%c2_845 = arith.constant 2 : index
%c8_846 = arith.constant 8 : index
%575 = torch_c.to_i64 %int2_837
%576 = torch_c.to_i64 %int4_838
%577 = torch_c.to_i64 %int8_839
%578 = torch_c.to_i64 %int8_840
%expanded_847 = tensor.expand_shape %cast_836 [[0, 1], [2], [3]] : tensor<8x8x8xf32> into tensor<2x4x8x8xf32>
%float4.000000e00 = torch.constant.float 4.000000e+00
%579 = torch_c.to_f64 %float4.000000e00
%c1_848 = arith.constant 1 : index
%c0_849 = arith.constant 0 : index
%c2_850 = arith.constant 2 : index
%c1_851 = arith.constant 1 : index
%c4_852 = arith.constant 4 : index
%c2_853 = arith.constant 2 : index
%c8_854 = arith.constant 8 : index
%c3_855 = arith.constant 3 : index
%c8_856 = arith.constant 8 : index
%580 = tensor.empty() : tensor<2x4x8x8xf32>
%581 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_847 : tensor<2x4x8x8xf32>) outs(%580 : tensor<2x4x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = arith.truncf %579 : f64 to f32
%1544 = arith.divf %in, %1543 : f32
linalg.yield %1544 : f32
} -> tensor<2x4x8x8xf32>
%cast_857 = tensor.cast %581 : tensor<2x4x8x8xf32> to tensor<2x4x8x8xf32>
%int1_858 = torch.constant.int 1
%582 = torch_c.to_i64 %int1_858
%c1_859 = arith.constant 1 : index
%c0_860 = arith.constant 0 : index
%c2_861 = arith.constant 2 : index
%c1_862 = arith.constant 1 : index
%c4_863 = arith.constant 4 : index
%c2_864 = arith.constant 2 : index
%c8_865 = arith.constant 8 : index
%c3_866 = arith.constant 3 : index
%c8_867 = arith.constant 8 : index
%c2_868 = arith.constant 2 : index
%c8_869 = arith.constant 8 : index
%583 = arith.cmpi eq, %c8_865, %c8_869 : index
cf.assert %583, "mismatched size for broadcast"
%c3_870 = arith.constant 3 : index
%c8_871 = arith.constant 8 : index
%584 = arith.cmpi eq, %c8_867, %c8_871 : index
cf.assert %584, "mismatched size for broadcast"
%585 = tensor.empty() : tensor<2x4x8x8xf32>
%586 = linalg.generic {indexing_maps = [#map, #map3, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_857, %cast_76 : tensor<2x4x8x8xf32>, tensor<1x1x8x8xf32>) outs(%585 : tensor<2x4x8x8xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.sitofp %582 : i64 to f32
%1544 = arith.mulf %in_2549, %1543 : f32
%1545 = arith.addf %in, %1544 : f32
linalg.yield %1545 : f32
} -> tensor<2x4x8x8xf32>
%cast_872 = tensor.cast %586 : tensor<2x4x8x8xf32> to tensor<2x4x8x8xf32>
%587 = torch_c.from_builtin_tensor %cast_872 : tensor<2x4x8x8xf32> -> !torch.vtensor<[2,4,8,8],f32>
%int-1_873 = torch.constant.int -1
%false_874 = torch.constant.bool false
%588 = torch.aten._softmax %587, %int-1_873, %false_874 : !torch.vtensor<[2,4,8,8],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,4,8,8],f32>
%int2_875 = torch.constant.int 2
%int4_876 = torch.constant.int 4
%int8_877 = torch.constant.int 8
%int8_878 = torch.constant.int 8
%589 = torch.prim.ListConstruct %int2_875, %int4_876, %int8_877, %int8_878 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_879 = torch.constant.bool false
%590 = torch.aten.expand %588, %589, %false_879 : !torch.vtensor<[2,4,8,8],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,8],f32>
%591 = torch_c.to_builtin_tensor %590 : !torch.vtensor<[2,4,8,8],f32> -> tensor<2x4x8x8xf32>
%int8_880 = torch.constant.int 8
%int8_881 = torch.constant.int 8
%int8_882 = torch.constant.int 8
%592 = torch.prim.ListConstruct %int8_880, %int8_881, %int8_882 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_883 = arith.constant 0 : index
%c2_884 = arith.constant 2 : index
%c1_885 = arith.constant 1 : index
%c4_886 = arith.constant 4 : index
%c2_887 = arith.constant 2 : index
%c8_888 = arith.constant 8 : index
%c3_889 = arith.constant 3 : index
%c8_890 = arith.constant 8 : index
%593 = torch_c.to_i64 %int8_880
%594 = torch_c.to_i64 %int8_881
%595 = torch_c.to_i64 %int8_882
%collapsed_891 = tensor.collapse_shape %591 [[0, 1], [2], [3]] : tensor<2x4x8x8xf32> into tensor<8x8x8xf32>
%596 = torch_c.from_builtin_tensor %collapsed_891 : tensor<8x8x8xf32> -> !torch.vtensor<[8,8,8],f32>
%int0_892 = torch.constant.int 0
%int0_893 = torch.constant.int 0
%597 = torch_c.to_i64 %int0_893
%int2_894 = torch.constant.int 2
%598 = torch_c.to_i64 %int2_894
%int1_895 = torch.constant.int 1
%c0_896 = arith.constant 0 : index
%c1_897 = arith.constant 1 : index
%c0_898 = arith.constant 0 : index
%c32_899 = arith.constant 32 : index
%c1_900 = arith.constant 1 : index
%c2048_901 = arith.constant 2048 : index
%c2_902 = arith.constant 2 : index
%c4_903 = arith.constant 4 : index
%c3_904 = arith.constant 3 : index
%c16_905 = arith.constant 16 : index
%599 = arith.index_cast %c32_899 : index to i64
%600 = arith.addi %597, %599 : i64
%c0_i64_906 = arith.constant 0 : i64
%601 = arith.cmpi sge, %597, %c0_i64_906 : i64
%602 = arith.select %601, %597, %600 : i64
%c0_i64_907 = arith.constant 0 : i64
%603 = arith.cmpi slt, %602, %c0_i64_907 : i64
%604 = arith.select %603, %c0_i64_907, %602 : i64
%605 = arith.cmpi sgt, %604, %599 : i64
%606 = arith.select %605, %599, %604 : i64
%607 = arith.index_cast %606 : i64 to index
%608 = arith.index_cast %c32_899 : index to i64
%609 = arith.addi %598, %608 : i64
%c0_i64_908 = arith.constant 0 : i64
%610 = arith.cmpi sge, %598, %c0_i64_908 : i64
%611 = arith.select %610, %598, %609 : i64
%c0_i64_909 = arith.constant 0 : i64
%612 = arith.cmpi slt, %611, %c0_i64_909 : i64
%613 = arith.select %612, %c0_i64_909, %611 : i64
%614 = arith.cmpi sgt, %613, %608 : i64
%615 = arith.select %614, %608, %613 : i64
%616 = arith.index_cast %615 : i64 to index
%617 = arith.cmpi sge, %616, %607 : index
%618 = arith.select %617, %616, %607 : index
%c1_910 = arith.constant 1 : index
%c0_911 = arith.constant 0 : index
%c32_912 = arith.constant 32 : index
%c1_913 = arith.constant 1 : index
%c2048_914 = arith.constant 2048 : index
%c2_915 = arith.constant 2 : index
%c4_916 = arith.constant 4 : index
%c3_917 = arith.constant 3 : index
%c16_918 = arith.constant 16 : index
%619 = arith.subi %618, %607 : index
%620 = arith.addi %619, %c1_910 : index
%621 = arith.subi %620, %c1_897 : index
%622 = arith.floordivsi %621, %c1_910 : index
%623 = arith.muli %c1_897, %c1_910 : index
%extracted_slice_919 = tensor.extract_slice %cast_694[%607, %c0_896, %c0_896, %c0_896] [%622, %c2048_914, %c4_916, %c16_918] [%623, %c1_897, %c1_897, %c1_897] : tensor<32x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_920 = tensor.cast %extracted_slice_919 : tensor<?x?x?x?xf32> to tensor<2x2048x4x16xf32>
%int1_921 = torch.constant.int 1
%int0_922 = torch.constant.int 0
%624 = torch_c.to_i64 %int0_922
%int8_923 = torch.constant.int 8
%625 = torch_c.to_i64 %int8_923
%int1_924 = torch.constant.int 1
%c0_925 = arith.constant 0 : index
%c1_926 = arith.constant 1 : index
%c0_927 = arith.constant 0 : index
%c2_928 = arith.constant 2 : index
%c1_929 = arith.constant 1 : index
%c2048_930 = arith.constant 2048 : index
%c2_931 = arith.constant 2 : index
%c4_932 = arith.constant 4 : index
%c3_933 = arith.constant 3 : index
%c16_934 = arith.constant 16 : index
%626 = arith.index_cast %c2048_930 : index to i64
%627 = arith.addi %624, %626 : i64
%c0_i64_935 = arith.constant 0 : i64
%628 = arith.cmpi sge, %624, %c0_i64_935 : i64
%629 = arith.select %628, %624, %627 : i64
%c0_i64_936 = arith.constant 0 : i64
%630 = arith.cmpi slt, %629, %c0_i64_936 : i64
%631 = arith.select %630, %c0_i64_936, %629 : i64
%632 = arith.cmpi sgt, %631, %626 : i64
%633 = arith.select %632, %626, %631 : i64
%634 = arith.index_cast %633 : i64 to index
%635 = arith.index_cast %c2048_930 : index to i64
%636 = arith.addi %625, %635 : i64
%c0_i64_937 = arith.constant 0 : i64
%637 = arith.cmpi sge, %625, %c0_i64_937 : i64
%638 = arith.select %637, %625, %636 : i64
%c0_i64_938 = arith.constant 0 : i64
%639 = arith.cmpi slt, %638, %c0_i64_938 : i64
%640 = arith.select %639, %c0_i64_938, %638 : i64
%641 = arith.cmpi sgt, %640, %635 : i64
%642 = arith.select %641, %635, %640 : i64
%643 = arith.index_cast %642 : i64 to index
%644 = arith.cmpi sge, %643, %634 : index
%645 = arith.select %644, %643, %634 : index
%c1_939 = arith.constant 1 : index
%c0_940 = arith.constant 0 : index
%c2_941 = arith.constant 2 : index
%c1_942 = arith.constant 1 : index
%c2048_943 = arith.constant 2048 : index
%c2_944 = arith.constant 2 : index
%c4_945 = arith.constant 4 : index
%c3_946 = arith.constant 3 : index
%c16_947 = arith.constant 16 : index
%646 = arith.subi %645, %634 : index
%647 = arith.addi %646, %c1_939 : index
%648 = arith.subi %647, %c1_926 : index
%649 = arith.floordivsi %648, %c1_939 : index
%650 = arith.muli %c1_926, %c1_939 : index
%extracted_slice_948 = tensor.extract_slice %cast_920[%c0_925, %634, %c0_925, %c0_925] [%c2_941, %649, %c4_945, %c16_947] [%c1_926, %650, %c1_926, %c1_926] : tensor<2x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_949 = tensor.cast %extracted_slice_948 : tensor<?x?x?x?xf32> to tensor<2x8x4x16xf32>
%int1_950 = torch.constant.int 1
%int2_951 = torch.constant.int 2
%c0_952 = arith.constant 0 : index
%c2_953 = arith.constant 2 : index
%c1_954 = arith.constant 1 : index
%c8_955 = arith.constant 8 : index
%c2_956 = arith.constant 2 : index
%c4_957 = arith.constant 4 : index
%c3_958 = arith.constant 3 : index
%c16_959 = arith.constant 16 : index
%651 = tensor.empty() : tensor<2x4x8x16xf32>
%652 = linalg.generic {indexing_maps = [#map, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_949 : tensor<2x8x4x16xf32>) outs(%651 : tensor<2x4x8x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x8x16xf32>
%cast_960 = tensor.cast %652 : tensor<2x4x8x16xf32> to tensor<2x4x8x16xf32>
%653 = torch_c.from_builtin_tensor %cast_960 : tensor<2x4x8x16xf32> -> !torch.vtensor<[2,4,8,16],f32>
%int2_961 = torch.constant.int 2
%int4_962 = torch.constant.int 4
%int8_963 = torch.constant.int 8
%int16_964 = torch.constant.int 16
%654 = torch.prim.ListConstruct %int2_961, %int4_962, %int8_963, %int16_964 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_965 = torch.constant.bool false
%655 = torch.aten.expand %653, %654, %false_965 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>
%656 = torch_c.to_builtin_tensor %655 : !torch.vtensor<[2,4,8,16],f32> -> tensor<2x4x8x16xf32>
%int0_966 = torch.constant.int 0
%c1_967 = arith.constant 1 : index
%c0_968 = arith.constant 0 : index
%c2_969 = arith.constant 2 : index
%c1_970 = arith.constant 1 : index
%c4_971 = arith.constant 4 : index
%c2_972 = arith.constant 2 : index
%c8_973 = arith.constant 8 : index
%c3_974 = arith.constant 3 : index
%c16_975 = arith.constant 16 : index
%657 = tensor.empty() : tensor<2x4x8x16xf32>
%658 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%656 : tensor<2x4x8x16xf32>) outs(%657 : tensor<2x4x8x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x8x16xf32>
%cast_976 = tensor.cast %658 : tensor<2x4x8x16xf32> to tensor<2x4x8x16xf32>
%659 = torch_c.from_builtin_tensor %cast_976 : tensor<2x4x8x16xf32> -> !torch.vtensor<[2,4,8,16],f32>
%int8_977 = torch.constant.int 8
%int8_978 = torch.constant.int 8
%int16_979 = torch.constant.int 16
%660 = torch.prim.ListConstruct %int8_977, %int8_978, %int16_979 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%661 = torch.aten._unsafe_view %659, %660 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int> -> !torch.vtensor<[8,8,16],f32>
%662 = torch_c.to_builtin_tensor %661 : !torch.vtensor<[8,8,16],f32> -> tensor<8x8x16xf32>
%c0_980 = arith.constant 0 : index
%c8_981 = arith.constant 8 : index
%c1_982 = arith.constant 1 : index
%c8_983 = arith.constant 8 : index
%c2_984 = arith.constant 2 : index
%c8_985 = arith.constant 8 : index
%c0_986 = arith.constant 0 : index
%c8_987 = arith.constant 8 : index
%c1_988 = arith.constant 1 : index
%c8_989 = arith.constant 8 : index
%c2_990 = arith.constant 2 : index
%c16_991 = arith.constant 16 : index
%663 = arith.index_cast %c8_981 : index to i64
%664 = arith.index_cast %c8_987 : index to i64
%665 = arith.cmpi eq, %663, %664 : i64
cf.assert %665, "mismatching contracting dimension"
%666 = arith.index_cast %c8_985 : index to i64
%667 = arith.index_cast %c8_989 : index to i64
%668 = arith.cmpi eq, %666, %667 : i64
cf.assert %668, "mismatching contracting dimension"
%669 = tensor.empty() : tensor<8x8x16xf32>
%cst_992 = arith.constant 0.000000e+00 : f32
%670 = linalg.fill ins(%cst_992 : f32) outs(%669 : tensor<8x8x16xf32>) -> tensor<8x8x16xf32>
%671 = linalg.batch_matmul ins(%collapsed_891, %662 : tensor<8x8x8xf32>, tensor<8x8x16xf32>) outs(%670 : tensor<8x8x16xf32>) -> tensor<8x8x16xf32>
%cast_993 = tensor.cast %671 : tensor<8x8x16xf32> to tensor<8x8x16xf32>
%int2_994 = torch.constant.int 2
%int4_995 = torch.constant.int 4
%int8_996 = torch.constant.int 8
%int16_997 = torch.constant.int 16
%672 = torch.prim.ListConstruct %int2_994, %int4_995, %int8_996, %int16_997 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_998 = arith.constant 0 : index
%c8_999 = arith.constant 8 : index
%c1_1000 = arith.constant 1 : index
%c8_1001 = arith.constant 8 : index
%c2_1002 = arith.constant 2 : index
%c16_1003 = arith.constant 16 : index
%673 = torch_c.to_i64 %int2_994
%674 = torch_c.to_i64 %int4_995
%675 = torch_c.to_i64 %int8_996
%676 = torch_c.to_i64 %int16_997
%expanded_1004 = tensor.expand_shape %cast_993 [[0, 1], [2], [3]] : tensor<8x8x16xf32> into tensor<2x4x8x16xf32>
%int1_1005 = torch.constant.int 1
%int2_1006 = torch.constant.int 2
%c0_1007 = arith.constant 0 : index
%c2_1008 = arith.constant 2 : index
%c1_1009 = arith.constant 1 : index
%c4_1010 = arith.constant 4 : index
%c2_1011 = arith.constant 2 : index
%c8_1012 = arith.constant 8 : index
%c3_1013 = arith.constant 3 : index
%c16_1014 = arith.constant 16 : index
%677 = tensor.empty() : tensor<2x8x4x16xf32>
%678 = linalg.generic {indexing_maps = [#map, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_1004 : tensor<2x4x8x16xf32>) outs(%677 : tensor<2x8x4x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x8x4x16xf32>
%cast_1015 = tensor.cast %678 : tensor<2x8x4x16xf32> to tensor<2x8x4x16xf32>
%int0_1016 = torch.constant.int 0
%c1_1017 = arith.constant 1 : index
%c0_1018 = arith.constant 0 : index
%c2_1019 = arith.constant 2 : index
%c1_1020 = arith.constant 1 : index
%c8_1021 = arith.constant 8 : index
%c2_1022 = arith.constant 2 : index
%c4_1023 = arith.constant 4 : index
%c3_1024 = arith.constant 3 : index
%c16_1025 = arith.constant 16 : index
%679 = tensor.empty() : tensor<2x8x4x16xf32>
%680 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_1015 : tensor<2x8x4x16xf32>) outs(%679 : tensor<2x8x4x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x8x4x16xf32>
%cast_1026 = tensor.cast %680 : tensor<2x8x4x16xf32> to tensor<2x8x4x16xf32>
%int2_1027 = torch.constant.int 2
%int8_1028 = torch.constant.int 8
%int-1_1029 = torch.constant.int -1
%681 = torch.prim.ListConstruct %int2_1027, %int8_1028, %int-1_1029 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1030 = arith.constant 0 : index
%c2_1031 = arith.constant 2 : index
%c1_1032 = arith.constant 1 : index
%c8_1033 = arith.constant 8 : index
%c2_1034 = arith.constant 2 : index
%c4_1035 = arith.constant 4 : index
%c3_1036 = arith.constant 3 : index
%c16_1037 = arith.constant 16 : index
%682 = torch_c.to_i64 %int2_1027
%683 = torch_c.to_i64 %int8_1028
%684 = torch_c.to_i64 %int-1_1029
%collapsed_1038 = tensor.collapse_shape %cast_1026 [[0], [1], [2, 3]] : tensor<2x8x4x16xf32> into tensor<2x8x64xf32>
%int0_1039 = torch.constant.int 0
%int1_1040 = torch.constant.int 1
%c0_1041 = arith.constant 0 : index
%c64_1042 = arith.constant 64 : index
%c1_1043 = arith.constant 1 : index
%c64_1044 = arith.constant 64 : index
%685 = tensor.empty() : tensor<64x64xf32>
%686 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%11 : tensor<64x64xf32>) outs(%685 : tensor<64x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x64xf32>
%cast_1045 = tensor.cast %686 : tensor<64x64xf32> to tensor<64x64xf32>
%687 = torch_c.from_builtin_tensor %cast_1045 : tensor<64x64xf32> -> !torch.vtensor<[64,64],f32>
%int16_1046 = torch.constant.int 16
%int64_1047 = torch.constant.int 64
%688 = torch.prim.ListConstruct %int16_1046, %int64_1047 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_1048 = arith.constant 0 : index
%c2_1049 = arith.constant 2 : index
%c1_1050 = arith.constant 1 : index
%c8_1051 = arith.constant 8 : index
%c2_1052 = arith.constant 2 : index
%c64_1053 = arith.constant 64 : index
%689 = torch_c.to_i64 %int16_1046
%690 = torch_c.to_i64 %int64_1047
%collapsed_1054 = tensor.collapse_shape %collapsed_1038 [[0, 1], [2]] : tensor<2x8x64xf32> into tensor<16x64xf32>
%691 = torch_c.from_builtin_tensor %collapsed_1054 : tensor<16x64xf32> -> !torch.vtensor<[16,64],f32>
%c0_1055 = arith.constant 0 : index
%dim_1056 = tensor.dim %collapsed_1054, %c0_1055 : tensor<16x64xf32>
%c1_1057 = arith.constant 1 : index
%dim_1058 = tensor.dim %cast_1045, %c1_1057 : tensor<64x64xf32>
%c1_1059 = arith.constant 1 : index
%dim_1060 = tensor.dim %collapsed_1054, %c1_1059 : tensor<16x64xf32>
%c0_1061 = arith.constant 0 : index
%dim_1062 = tensor.dim %cast_1045, %c0_1061 : tensor<64x64xf32>
%692 = arith.cmpi eq, %dim_1060, %dim_1062 : index
cf.assert %692, "mismatching contracting dimension for torch.aten.mm"
%693 = tensor.empty(%dim_1056, %dim_1058) : tensor<?x?xf32>
%cst_1063 = arith.constant 0.000000e+00 : f32
%694 = linalg.fill ins(%cst_1063 : f32) outs(%693 : tensor<?x?xf32>) -> tensor<?x?xf32>
%695 = linalg.matmul ins(%collapsed_1054, %cast_1045 : tensor<16x64xf32>, tensor<64x64xf32>) outs(%694 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_1064 = tensor.cast %695 : tensor<?x?xf32> to tensor<16x64xf32>
%int2_1065 = torch.constant.int 2
%int8_1066 = torch.constant.int 8
%int64_1067 = torch.constant.int 64
%696 = torch.prim.ListConstruct %int2_1065, %int8_1066, %int64_1067 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1068 = arith.constant 0 : index
%c16_1069 = arith.constant 16 : index
%c1_1070 = arith.constant 1 : index
%c64_1071 = arith.constant 64 : index
%697 = torch_c.to_i64 %int2_1065
%698 = torch_c.to_i64 %int8_1066
%699 = torch_c.to_i64 %int64_1067
%expanded_1072 = tensor.expand_shape %cast_1064 [[0, 1], [2]] : tensor<16x64xf32> into tensor<2x8x64xf32>
%int1_1073 = torch.constant.int 1
%700 = torch_c.to_i64 %int1_1073
%c1_1074 = arith.constant 1 : index
%c0_1075 = arith.constant 0 : index
%c2_1076 = arith.constant 2 : index
%c1_1077 = arith.constant 1 : index
%c8_1078 = arith.constant 8 : index
%c2_1079 = arith.constant 2 : index
%c64_1080 = arith.constant 64 : index
%c0_1081 = arith.constant 0 : index
%c2_1082 = arith.constant 2 : index
%701 = arith.cmpi eq, %c2_1076, %c2_1082 : index
cf.assert %701, "mismatched size for broadcast"
%c1_1083 = arith.constant 1 : index
%c8_1084 = arith.constant 8 : index
%702 = arith.cmpi eq, %c8_1078, %c8_1084 : index
cf.assert %702, "mismatched size for broadcast"
%c2_1085 = arith.constant 2 : index
%c64_1086 = arith.constant 64 : index
%703 = arith.cmpi eq, %c64_1080, %c64_1086 : index
cf.assert %703, "mismatched size for broadcast"
%704 = tensor.empty() : tensor<2x8x64xf32>
%705 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_39, %expanded_1072 : tensor<2x8x64xf32>, tensor<2x8x64xf32>) outs(%704 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.sitofp %700 : i64 to f32
%1544 = arith.mulf %in_2549, %1543 : f32
%1545 = arith.addf %in, %1544 : f32
linalg.yield %1545 : f32
} -> tensor<2x8x64xf32>
%cast_1087 = tensor.cast %705 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%706 = torch_c.from_builtin_tensor %cast_1087 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%int2_1088 = torch.constant.int 2
%707 = torch_c.to_i64 %int2_1088
%c1_1089 = arith.constant 1 : index
%c0_1090 = arith.constant 0 : index
%c2_1091 = arith.constant 2 : index
%c1_1092 = arith.constant 1 : index
%c8_1093 = arith.constant 8 : index
%c2_1094 = arith.constant 2 : index
%c64_1095 = arith.constant 64 : index
%708 = tensor.empty() : tensor<2x8x64xf32>
%709 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_1087 : tensor<2x8x64xf32>) outs(%708 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = arith.sitofp %707 : i64 to f32
%1544 = math.powf %in, %1543 : f32
linalg.yield %1544 : f32
} -> tensor<2x8x64xf32>
%cast_1096 = tensor.cast %709 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%710 = torch_c.from_builtin_tensor %cast_1096 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%int-1_1097 = torch.constant.int -1
%711 = torch.prim.ListConstruct %int-1_1097 : (!torch.int) -> !torch.list<int>
%true_1098 = torch.constant.bool true
%none_1099 = torch.constant.none
%712 = torch.aten.mean.dim %710, %711, %true_1098, %none_1099 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
%713 = torch_c.to_builtin_tensor %712 : !torch.vtensor<[2,8,1],f32> -> tensor<2x8x1xf32>
%float1.000000e-05_1100 = torch.constant.float 1.000000e-05
%714 = torch_c.to_f64 %float1.000000e-05_1100
%int1_1101 = torch.constant.int 1
%715 = torch_c.to_i64 %int1_1101
%c1_1102 = arith.constant 1 : index
%c0_1103 = arith.constant 0 : index
%c2_1104 = arith.constant 2 : index
%c1_1105 = arith.constant 1 : index
%c8_1106 = arith.constant 8 : index
%716 = tensor.empty() : tensor<2x8x1xf32>
%717 = linalg.generic {indexing_maps = [#map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%713 : tensor<2x8x1xf32>) outs(%716 : tensor<2x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = arith.truncf %714 : f64 to f32
%1544 = arith.sitofp %715 : i64 to f32
%1545 = arith.mulf %1543, %1544 : f32
%1546 = arith.addf %in, %1545 : f32
linalg.yield %1546 : f32
} -> tensor<2x8x1xf32>
%cast_1107 = tensor.cast %717 : tensor<2x8x1xf32> to tensor<2x8x1xf32>
%c1_1108 = arith.constant 1 : index
%c0_1109 = arith.constant 0 : index
%c2_1110 = arith.constant 2 : index
%c1_1111 = arith.constant 1 : index
%c8_1112 = arith.constant 8 : index
%718 = tensor.empty() : tensor<2x8x1xf32>
%719 = linalg.generic {indexing_maps = [#map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_1107 : tensor<2x8x1xf32>) outs(%718 : tensor<2x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = math.rsqrt %in : f32
linalg.yield %1543 : f32
} -> tensor<2x8x1xf32>
%cast_1113 = tensor.cast %719 : tensor<2x8x1xf32> to tensor<2x8x1xf32>
%720 = torch_c.from_builtin_tensor %cast_1113 : tensor<2x8x1xf32> -> !torch.vtensor<[2,8,1],f32>
%c1_1114 = arith.constant 1 : index
%c0_1115 = arith.constant 0 : index
%c2_1116 = arith.constant 2 : index
%c1_1117 = arith.constant 1 : index
%c8_1118 = arith.constant 8 : index
%c2_1119 = arith.constant 2 : index
%c64_1120 = arith.constant 64 : index
%c0_1121 = arith.constant 0 : index
%c2_1122 = arith.constant 2 : index
%721 = arith.cmpi eq, %c2_1116, %c2_1122 : index
cf.assert %721, "mismatched size for broadcast"
%c1_1123 = arith.constant 1 : index
%c8_1124 = arith.constant 8 : index
%722 = arith.cmpi eq, %c8_1118, %c8_1124 : index
cf.assert %722, "mismatched size for broadcast"
%723 = tensor.empty() : tensor<2x8x64xf32>
%724 = linalg.generic {indexing_maps = [#map2, #map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_1087, %cast_1113 : tensor<2x8x64xf32>, tensor<2x8x1xf32>) outs(%723 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.mulf %in, %in_2549 : f32
linalg.yield %1543 : f32
} -> tensor<2x8x64xf32>
%cast_1125 = tensor.cast %724 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%725 = torch_c.from_builtin_tensor %cast_1125 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%c1_1126 = arith.constant 1 : index
%c0_1127 = arith.constant 0 : index
%c2_1128 = arith.constant 2 : index
%c1_1129 = arith.constant 1 : index
%c8_1130 = arith.constant 8 : index
%c2_1131 = arith.constant 2 : index
%c64_1132 = arith.constant 64 : index
%c0_1133 = arith.constant 0 : index
%c64_1134 = arith.constant 64 : index
%726 = arith.cmpi eq, %c64_1132, %c64_1134 : index
cf.assert %726, "mismatched size for broadcast"
%727 = tensor.empty() : tensor<2x8x64xf32>
%728 = linalg.generic {indexing_maps = [#map2, #map5, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_1125, %12 : tensor<2x8x64xf32>, tensor<64xf32>) outs(%727 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.mulf %in, %in_2549 : f32
linalg.yield %1543 : f32
} -> tensor<2x8x64xf32>
%cast_1135 = tensor.cast %728 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%int0_1136 = torch.constant.int 0
%int1_1137 = torch.constant.int 1
%c0_1138 = arith.constant 0 : index
%c256 = arith.constant 256 : index
%c1_1139 = arith.constant 1 : index
%c64_1140 = arith.constant 64 : index
%729 = tensor.empty() : tensor<64x256xf32>
%730 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%13 : tensor<256x64xf32>) outs(%729 : tensor<64x256xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x256xf32>
%cast_1141 = tensor.cast %730 : tensor<64x256xf32> to tensor<64x256xf32>
%731 = torch_c.from_builtin_tensor %cast_1141 : tensor<64x256xf32> -> !torch.vtensor<[64,256],f32>
%int16_1142 = torch.constant.int 16
%int64_1143 = torch.constant.int 64
%732 = torch.prim.ListConstruct %int16_1142, %int64_1143 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_1144 = arith.constant 0 : index
%c2_1145 = arith.constant 2 : index
%c1_1146 = arith.constant 1 : index
%c8_1147 = arith.constant 8 : index
%c2_1148 = arith.constant 2 : index
%c64_1149 = arith.constant 64 : index
%733 = torch_c.to_i64 %int16_1142
%734 = torch_c.to_i64 %int64_1143
%collapsed_1150 = tensor.collapse_shape %cast_1135 [[0, 1], [2]] : tensor<2x8x64xf32> into tensor<16x64xf32>
%735 = torch_c.from_builtin_tensor %collapsed_1150 : tensor<16x64xf32> -> !torch.vtensor<[16,64],f32>
%c0_1151 = arith.constant 0 : index
%dim_1152 = tensor.dim %collapsed_1150, %c0_1151 : tensor<16x64xf32>
%c1_1153 = arith.constant 1 : index
%dim_1154 = tensor.dim %cast_1141, %c1_1153 : tensor<64x256xf32>
%c1_1155 = arith.constant 1 : index
%dim_1156 = tensor.dim %collapsed_1150, %c1_1155 : tensor<16x64xf32>
%c0_1157 = arith.constant 0 : index
%dim_1158 = tensor.dim %cast_1141, %c0_1157 : tensor<64x256xf32>
%736 = arith.cmpi eq, %dim_1156, %dim_1158 : index
cf.assert %736, "mismatching contracting dimension for torch.aten.mm"
%737 = tensor.empty(%dim_1152, %dim_1154) : tensor<?x?xf32>
%cst_1159 = arith.constant 0.000000e+00 : f32
%738 = linalg.fill ins(%cst_1159 : f32) outs(%737 : tensor<?x?xf32>) -> tensor<?x?xf32>
%739 = linalg.matmul ins(%collapsed_1150, %cast_1141 : tensor<16x64xf32>, tensor<64x256xf32>) outs(%738 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_1160 = tensor.cast %739 : tensor<?x?xf32> to tensor<16x256xf32>
%int2_1161 = torch.constant.int 2
%int8_1162 = torch.constant.int 8
%int256 = torch.constant.int 256
%740 = torch.prim.ListConstruct %int2_1161, %int8_1162, %int256 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1163 = arith.constant 0 : index
%c16_1164 = arith.constant 16 : index
%c1_1165 = arith.constant 1 : index
%c256_1166 = arith.constant 256 : index
%741 = torch_c.to_i64 %int2_1161
%742 = torch_c.to_i64 %int8_1162
%743 = torch_c.to_i64 %int256
%expanded_1167 = tensor.expand_shape %cast_1160 [[0, 1], [2]] : tensor<16x256xf32> into tensor<2x8x256xf32>
%744 = torch_c.from_builtin_tensor %expanded_1167 : tensor<2x8x256xf32> -> !torch.vtensor<[2,8,256],f32>
%745 = torch.aten.silu %744 : !torch.vtensor<[2,8,256],f32> -> !torch.vtensor<[2,8,256],f32>
%746 = torch_c.to_builtin_tensor %745 : !torch.vtensor<[2,8,256],f32> -> tensor<2x8x256xf32>
%int0_1168 = torch.constant.int 0
%int1_1169 = torch.constant.int 1
%c0_1170 = arith.constant 0 : index
%c256_1171 = arith.constant 256 : index
%c1_1172 = arith.constant 1 : index
%c64_1173 = arith.constant 64 : index
%747 = tensor.empty() : tensor<64x256xf32>
%748 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%14 : tensor<256x64xf32>) outs(%747 : tensor<64x256xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x256xf32>
%cast_1174 = tensor.cast %748 : tensor<64x256xf32> to tensor<64x256xf32>
%749 = torch_c.from_builtin_tensor %cast_1174 : tensor<64x256xf32> -> !torch.vtensor<[64,256],f32>
%int16_1175 = torch.constant.int 16
%int64_1176 = torch.constant.int 64
%750 = torch.prim.ListConstruct %int16_1175, %int64_1176 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_1177 = arith.constant 0 : index
%c2_1178 = arith.constant 2 : index
%c1_1179 = arith.constant 1 : index
%c8_1180 = arith.constant 8 : index
%c2_1181 = arith.constant 2 : index
%c64_1182 = arith.constant 64 : index
%751 = torch_c.to_i64 %int16_1175
%752 = torch_c.to_i64 %int64_1176
%collapsed_1183 = tensor.collapse_shape %cast_1135 [[0, 1], [2]] : tensor<2x8x64xf32> into tensor<16x64xf32>
%753 = torch_c.from_builtin_tensor %collapsed_1183 : tensor<16x64xf32> -> !torch.vtensor<[16,64],f32>
%c0_1184 = arith.constant 0 : index
%dim_1185 = tensor.dim %collapsed_1183, %c0_1184 : tensor<16x64xf32>
%c1_1186 = arith.constant 1 : index
%dim_1187 = tensor.dim %cast_1174, %c1_1186 : tensor<64x256xf32>
%c1_1188 = arith.constant 1 : index
%dim_1189 = tensor.dim %collapsed_1183, %c1_1188 : tensor<16x64xf32>
%c0_1190 = arith.constant 0 : index
%dim_1191 = tensor.dim %cast_1174, %c0_1190 : tensor<64x256xf32>
%754 = arith.cmpi eq, %dim_1189, %dim_1191 : index
cf.assert %754, "mismatching contracting dimension for torch.aten.mm"
%755 = tensor.empty(%dim_1185, %dim_1187) : tensor<?x?xf32>
%cst_1192 = arith.constant 0.000000e+00 : f32
%756 = linalg.fill ins(%cst_1192 : f32) outs(%755 : tensor<?x?xf32>) -> tensor<?x?xf32>
%757 = linalg.matmul ins(%collapsed_1183, %cast_1174 : tensor<16x64xf32>, tensor<64x256xf32>) outs(%756 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_1193 = tensor.cast %757 : tensor<?x?xf32> to tensor<16x256xf32>
%int2_1194 = torch.constant.int 2
%int8_1195 = torch.constant.int 8
%int256_1196 = torch.constant.int 256
%758 = torch.prim.ListConstruct %int2_1194, %int8_1195, %int256_1196 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1197 = arith.constant 0 : index
%c16_1198 = arith.constant 16 : index
%c1_1199 = arith.constant 1 : index
%c256_1200 = arith.constant 256 : index
%759 = torch_c.to_i64 %int2_1194
%760 = torch_c.to_i64 %int8_1195
%761 = torch_c.to_i64 %int256_1196
%expanded_1201 = tensor.expand_shape %cast_1193 [[0, 1], [2]] : tensor<16x256xf32> into tensor<2x8x256xf32>
%762 = torch_c.from_builtin_tensor %expanded_1201 : tensor<2x8x256xf32> -> !torch.vtensor<[2,8,256],f32>
%c1_1202 = arith.constant 1 : index
%c0_1203 = arith.constant 0 : index
%c2_1204 = arith.constant 2 : index
%c1_1205 = arith.constant 1 : index
%c8_1206 = arith.constant 8 : index
%c2_1207 = arith.constant 2 : index
%c256_1208 = arith.constant 256 : index
%c0_1209 = arith.constant 0 : index
%c2_1210 = arith.constant 2 : index
%763 = arith.cmpi eq, %c2_1204, %c2_1210 : index
cf.assert %763, "mismatched size for broadcast"
%c1_1211 = arith.constant 1 : index
%c8_1212 = arith.constant 8 : index
%764 = arith.cmpi eq, %c8_1206, %c8_1212 : index
cf.assert %764, "mismatched size for broadcast"
%c2_1213 = arith.constant 2 : index
%c256_1214 = arith.constant 256 : index
%765 = arith.cmpi eq, %c256_1208, %c256_1214 : index
cf.assert %765, "mismatched size for broadcast"
%766 = tensor.empty() : tensor<2x8x256xf32>
%767 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%746, %expanded_1201 : tensor<2x8x256xf32>, tensor<2x8x256xf32>) outs(%766 : tensor<2x8x256xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.mulf %in, %in_2549 : f32
linalg.yield %1543 : f32
} -> tensor<2x8x256xf32>
%cast_1215 = tensor.cast %767 : tensor<2x8x256xf32> to tensor<2x8x256xf32>
%int0_1216 = torch.constant.int 0
%int1_1217 = torch.constant.int 1
%c0_1218 = arith.constant 0 : index
%c64_1219 = arith.constant 64 : index
%c1_1220 = arith.constant 1 : index
%c256_1221 = arith.constant 256 : index
%768 = tensor.empty() : tensor<256x64xf32>
%769 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%15 : tensor<64x256xf32>) outs(%768 : tensor<256x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<256x64xf32>
%cast_1222 = tensor.cast %769 : tensor<256x64xf32> to tensor<256x64xf32>
%770 = torch_c.from_builtin_tensor %cast_1222 : tensor<256x64xf32> -> !torch.vtensor<[256,64],f32>
%int16_1223 = torch.constant.int 16
%int256_1224 = torch.constant.int 256
%771 = torch.prim.ListConstruct %int16_1223, %int256_1224 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_1225 = arith.constant 0 : index
%c2_1226 = arith.constant 2 : index
%c1_1227 = arith.constant 1 : index
%c8_1228 = arith.constant 8 : index
%c2_1229 = arith.constant 2 : index
%c256_1230 = arith.constant 256 : index
%772 = torch_c.to_i64 %int16_1223
%773 = torch_c.to_i64 %int256_1224
%collapsed_1231 = tensor.collapse_shape %cast_1215 [[0, 1], [2]] : tensor<2x8x256xf32> into tensor<16x256xf32>
%774 = torch_c.from_builtin_tensor %collapsed_1231 : tensor<16x256xf32> -> !torch.vtensor<[16,256],f32>
%c0_1232 = arith.constant 0 : index
%dim_1233 = tensor.dim %collapsed_1231, %c0_1232 : tensor<16x256xf32>
%c1_1234 = arith.constant 1 : index
%dim_1235 = tensor.dim %cast_1222, %c1_1234 : tensor<256x64xf32>
%c1_1236 = arith.constant 1 : index
%dim_1237 = tensor.dim %collapsed_1231, %c1_1236 : tensor<16x256xf32>
%c0_1238 = arith.constant 0 : index
%dim_1239 = tensor.dim %cast_1222, %c0_1238 : tensor<256x64xf32>
%775 = arith.cmpi eq, %dim_1237, %dim_1239 : index
cf.assert %775, "mismatching contracting dimension for torch.aten.mm"
%776 = tensor.empty(%dim_1233, %dim_1235) : tensor<?x?xf32>
%cst_1240 = arith.constant 0.000000e+00 : f32
%777 = linalg.fill ins(%cst_1240 : f32) outs(%776 : tensor<?x?xf32>) -> tensor<?x?xf32>
%778 = linalg.matmul ins(%collapsed_1231, %cast_1222 : tensor<16x256xf32>, tensor<256x64xf32>) outs(%777 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_1241 = tensor.cast %778 : tensor<?x?xf32> to tensor<16x64xf32>
%int2_1242 = torch.constant.int 2
%int8_1243 = torch.constant.int 8
%int64_1244 = torch.constant.int 64
%779 = torch.prim.ListConstruct %int2_1242, %int8_1243, %int64_1244 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1245 = arith.constant 0 : index
%c16_1246 = arith.constant 16 : index
%c1_1247 = arith.constant 1 : index
%c64_1248 = arith.constant 64 : index
%780 = torch_c.to_i64 %int2_1242
%781 = torch_c.to_i64 %int8_1243
%782 = torch_c.to_i64 %int64_1244
%expanded_1249 = tensor.expand_shape %cast_1241 [[0, 1], [2]] : tensor<16x64xf32> into tensor<2x8x64xf32>
%int1_1250 = torch.constant.int 1
%783 = torch_c.to_i64 %int1_1250
%c1_1251 = arith.constant 1 : index
%c0_1252 = arith.constant 0 : index
%c2_1253 = arith.constant 2 : index
%c1_1254 = arith.constant 1 : index
%c8_1255 = arith.constant 8 : index
%c2_1256 = arith.constant 2 : index
%c64_1257 = arith.constant 64 : index
%c0_1258 = arith.constant 0 : index
%c2_1259 = arith.constant 2 : index
%784 = arith.cmpi eq, %c2_1253, %c2_1259 : index
cf.assert %784, "mismatched size for broadcast"
%c1_1260 = arith.constant 1 : index
%c8_1261 = arith.constant 8 : index
%785 = arith.cmpi eq, %c8_1255, %c8_1261 : index
cf.assert %785, "mismatched size for broadcast"
%c2_1262 = arith.constant 2 : index
%c64_1263 = arith.constant 64 : index
%786 = arith.cmpi eq, %c64_1257, %c64_1263 : index
cf.assert %786, "mismatched size for broadcast"
%787 = tensor.empty() : tensor<2x8x64xf32>
%788 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_1087, %expanded_1249 : tensor<2x8x64xf32>, tensor<2x8x64xf32>) outs(%787 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.sitofp %783 : i64 to f32
%1544 = arith.mulf %in_2549, %1543 : f32
%1545 = arith.addf %in, %1544 : f32
linalg.yield %1545 : f32
} -> tensor<2x8x64xf32>
%cast_1264 = tensor.cast %788 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%789 = torch_c.from_builtin_tensor %cast_1264 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%int2_1265 = torch.constant.int 2
%790 = torch_c.to_i64 %int2_1265
%c1_1266 = arith.constant 1 : index
%c0_1267 = arith.constant 0 : index
%c2_1268 = arith.constant 2 : index
%c1_1269 = arith.constant 1 : index
%c8_1270 = arith.constant 8 : index
%c2_1271 = arith.constant 2 : index
%c64_1272 = arith.constant 64 : index
%791 = tensor.empty() : tensor<2x8x64xf32>
%792 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_1264 : tensor<2x8x64xf32>) outs(%791 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = arith.sitofp %790 : i64 to f32
%1544 = math.powf %in, %1543 : f32
linalg.yield %1544 : f32
} -> tensor<2x8x64xf32>
%cast_1273 = tensor.cast %792 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%793 = torch_c.from_builtin_tensor %cast_1273 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%int-1_1274 = torch.constant.int -1
%794 = torch.prim.ListConstruct %int-1_1274 : (!torch.int) -> !torch.list<int>
%true_1275 = torch.constant.bool true
%none_1276 = torch.constant.none
%795 = torch.aten.mean.dim %793, %794, %true_1275, %none_1276 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
%796 = torch_c.to_builtin_tensor %795 : !torch.vtensor<[2,8,1],f32> -> tensor<2x8x1xf32>
%float1.000000e-05_1277 = torch.constant.float 1.000000e-05
%797 = torch_c.to_f64 %float1.000000e-05_1277
%int1_1278 = torch.constant.int 1
%798 = torch_c.to_i64 %int1_1278
%c1_1279 = arith.constant 1 : index
%c0_1280 = arith.constant 0 : index
%c2_1281 = arith.constant 2 : index
%c1_1282 = arith.constant 1 : index
%c8_1283 = arith.constant 8 : index
%799 = tensor.empty() : tensor<2x8x1xf32>
%800 = linalg.generic {indexing_maps = [#map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%796 : tensor<2x8x1xf32>) outs(%799 : tensor<2x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = arith.truncf %797 : f64 to f32
%1544 = arith.sitofp %798 : i64 to f32
%1545 = arith.mulf %1543, %1544 : f32
%1546 = arith.addf %in, %1545 : f32
linalg.yield %1546 : f32
} -> tensor<2x8x1xf32>
%cast_1284 = tensor.cast %800 : tensor<2x8x1xf32> to tensor<2x8x1xf32>
%c1_1285 = arith.constant 1 : index
%c0_1286 = arith.constant 0 : index
%c2_1287 = arith.constant 2 : index
%c1_1288 = arith.constant 1 : index
%c8_1289 = arith.constant 8 : index
%801 = tensor.empty() : tensor<2x8x1xf32>
%802 = linalg.generic {indexing_maps = [#map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_1284 : tensor<2x8x1xf32>) outs(%801 : tensor<2x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = math.rsqrt %in : f32
linalg.yield %1543 : f32
} -> tensor<2x8x1xf32>
%cast_1290 = tensor.cast %802 : tensor<2x8x1xf32> to tensor<2x8x1xf32>
%803 = torch_c.from_builtin_tensor %cast_1290 : tensor<2x8x1xf32> -> !torch.vtensor<[2,8,1],f32>
%c1_1291 = arith.constant 1 : index
%c0_1292 = arith.constant 0 : index
%c2_1293 = arith.constant 2 : index
%c1_1294 = arith.constant 1 : index
%c8_1295 = arith.constant 8 : index
%c2_1296 = arith.constant 2 : index
%c64_1297 = arith.constant 64 : index
%c0_1298 = arith.constant 0 : index
%c2_1299 = arith.constant 2 : index
%804 = arith.cmpi eq, %c2_1293, %c2_1299 : index
cf.assert %804, "mismatched size for broadcast"
%c1_1300 = arith.constant 1 : index
%c8_1301 = arith.constant 8 : index
%805 = arith.cmpi eq, %c8_1295, %c8_1301 : index
cf.assert %805, "mismatched size for broadcast"
%806 = tensor.empty() : tensor<2x8x64xf32>
%807 = linalg.generic {indexing_maps = [#map2, #map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_1264, %cast_1290 : tensor<2x8x64xf32>, tensor<2x8x1xf32>) outs(%806 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.mulf %in, %in_2549 : f32
linalg.yield %1543 : f32
} -> tensor<2x8x64xf32>
%cast_1302 = tensor.cast %807 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%808 = torch_c.from_builtin_tensor %cast_1302 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%c1_1303 = arith.constant 1 : index
%c0_1304 = arith.constant 0 : index
%c2_1305 = arith.constant 2 : index
%c1_1306 = arith.constant 1 : index
%c8_1307 = arith.constant 8 : index
%c2_1308 = arith.constant 2 : index
%c64_1309 = arith.constant 64 : index
%c0_1310 = arith.constant 0 : index
%c64_1311 = arith.constant 64 : index
%809 = arith.cmpi eq, %c64_1309, %c64_1311 : index
cf.assert %809, "mismatched size for broadcast"
%810 = tensor.empty() : tensor<2x8x64xf32>
%811 = linalg.generic {indexing_maps = [#map2, #map5, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_1302, %16 : tensor<2x8x64xf32>, tensor<64xf32>) outs(%810 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.mulf %in, %in_2549 : f32
linalg.yield %1543 : f32
} -> tensor<2x8x64xf32>
%cast_1312 = tensor.cast %811 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%int0_1313 = torch.constant.int 0
%int1_1314 = torch.constant.int 1
%c0_1315 = arith.constant 0 : index
%c64_1316 = arith.constant 64 : index
%c1_1317 = arith.constant 1 : index
%c64_1318 = arith.constant 64 : index
%812 = tensor.empty() : tensor<64x64xf32>
%813 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%17 : tensor<64x64xf32>) outs(%812 : tensor<64x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x64xf32>
%cast_1319 = tensor.cast %813 : tensor<64x64xf32> to tensor<64x64xf32>
%814 = torch_c.from_builtin_tensor %cast_1319 : tensor<64x64xf32> -> !torch.vtensor<[64,64],f32>
%int16_1320 = torch.constant.int 16
%int64_1321 = torch.constant.int 64
%815 = torch.prim.ListConstruct %int16_1320, %int64_1321 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_1322 = arith.constant 0 : index
%c2_1323 = arith.constant 2 : index
%c1_1324 = arith.constant 1 : index
%c8_1325 = arith.constant 8 : index
%c2_1326 = arith.constant 2 : index
%c64_1327 = arith.constant 64 : index
%816 = torch_c.to_i64 %int16_1320
%817 = torch_c.to_i64 %int64_1321
%collapsed_1328 = tensor.collapse_shape %cast_1312 [[0, 1], [2]] : tensor<2x8x64xf32> into tensor<16x64xf32>
%818 = torch_c.from_builtin_tensor %collapsed_1328 : tensor<16x64xf32> -> !torch.vtensor<[16,64],f32>
%c0_1329 = arith.constant 0 : index
%dim_1330 = tensor.dim %collapsed_1328, %c0_1329 : tensor<16x64xf32>
%c1_1331 = arith.constant 1 : index
%dim_1332 = tensor.dim %cast_1319, %c1_1331 : tensor<64x64xf32>
%c1_1333 = arith.constant 1 : index
%dim_1334 = tensor.dim %collapsed_1328, %c1_1333 : tensor<16x64xf32>
%c0_1335 = arith.constant 0 : index
%dim_1336 = tensor.dim %cast_1319, %c0_1335 : tensor<64x64xf32>
%819 = arith.cmpi eq, %dim_1334, %dim_1336 : index
cf.assert %819, "mismatching contracting dimension for torch.aten.mm"
%820 = tensor.empty(%dim_1330, %dim_1332) : tensor<?x?xf32>
%cst_1337 = arith.constant 0.000000e+00 : f32
%821 = linalg.fill ins(%cst_1337 : f32) outs(%820 : tensor<?x?xf32>) -> tensor<?x?xf32>
%822 = linalg.matmul ins(%collapsed_1328, %cast_1319 : tensor<16x64xf32>, tensor<64x64xf32>) outs(%821 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_1338 = tensor.cast %822 : tensor<?x?xf32> to tensor<16x64xf32>
%int2_1339 = torch.constant.int 2
%int8_1340 = torch.constant.int 8
%int64_1341 = torch.constant.int 64
%823 = torch.prim.ListConstruct %int2_1339, %int8_1340, %int64_1341 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1342 = arith.constant 0 : index
%c16_1343 = arith.constant 16 : index
%c1_1344 = arith.constant 1 : index
%c64_1345 = arith.constant 64 : index
%824 = torch_c.to_i64 %int2_1339
%825 = torch_c.to_i64 %int8_1340
%826 = torch_c.to_i64 %int64_1341
%expanded_1346 = tensor.expand_shape %cast_1338 [[0, 1], [2]] : tensor<16x64xf32> into tensor<2x8x64xf32>
%int0_1347 = torch.constant.int 0
%int1_1348 = torch.constant.int 1
%c0_1349 = arith.constant 0 : index
%c64_1350 = arith.constant 64 : index
%c1_1351 = arith.constant 1 : index
%c64_1352 = arith.constant 64 : index
%827 = tensor.empty() : tensor<64x64xf32>
%828 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%18 : tensor<64x64xf32>) outs(%827 : tensor<64x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x64xf32>
%cast_1353 = tensor.cast %828 : tensor<64x64xf32> to tensor<64x64xf32>
%829 = torch_c.from_builtin_tensor %cast_1353 : tensor<64x64xf32> -> !torch.vtensor<[64,64],f32>
%int16_1354 = torch.constant.int 16
%int64_1355 = torch.constant.int 64
%830 = torch.prim.ListConstruct %int16_1354, %int64_1355 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_1356 = arith.constant 0 : index
%c2_1357 = arith.constant 2 : index
%c1_1358 = arith.constant 1 : index
%c8_1359 = arith.constant 8 : index
%c2_1360 = arith.constant 2 : index
%c64_1361 = arith.constant 64 : index
%831 = torch_c.to_i64 %int16_1354
%832 = torch_c.to_i64 %int64_1355
%collapsed_1362 = tensor.collapse_shape %cast_1312 [[0, 1], [2]] : tensor<2x8x64xf32> into tensor<16x64xf32>
%833 = torch_c.from_builtin_tensor %collapsed_1362 : tensor<16x64xf32> -> !torch.vtensor<[16,64],f32>
%c0_1363 = arith.constant 0 : index
%dim_1364 = tensor.dim %collapsed_1362, %c0_1363 : tensor<16x64xf32>
%c1_1365 = arith.constant 1 : index
%dim_1366 = tensor.dim %cast_1353, %c1_1365 : tensor<64x64xf32>
%c1_1367 = arith.constant 1 : index
%dim_1368 = tensor.dim %collapsed_1362, %c1_1367 : tensor<16x64xf32>
%c0_1369 = arith.constant 0 : index
%dim_1370 = tensor.dim %cast_1353, %c0_1369 : tensor<64x64xf32>
%834 = arith.cmpi eq, %dim_1368, %dim_1370 : index
cf.assert %834, "mismatching contracting dimension for torch.aten.mm"
%835 = tensor.empty(%dim_1364, %dim_1366) : tensor<?x?xf32>
%cst_1371 = arith.constant 0.000000e+00 : f32
%836 = linalg.fill ins(%cst_1371 : f32) outs(%835 : tensor<?x?xf32>) -> tensor<?x?xf32>
%837 = linalg.matmul ins(%collapsed_1362, %cast_1353 : tensor<16x64xf32>, tensor<64x64xf32>) outs(%836 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_1372 = tensor.cast %837 : tensor<?x?xf32> to tensor<16x64xf32>
%int2_1373 = torch.constant.int 2
%int8_1374 = torch.constant.int 8
%int64_1375 = torch.constant.int 64
%838 = torch.prim.ListConstruct %int2_1373, %int8_1374, %int64_1375 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1376 = arith.constant 0 : index
%c16_1377 = arith.constant 16 : index
%c1_1378 = arith.constant 1 : index
%c64_1379 = arith.constant 64 : index
%839 = torch_c.to_i64 %int2_1373
%840 = torch_c.to_i64 %int8_1374
%841 = torch_c.to_i64 %int64_1375
%expanded_1380 = tensor.expand_shape %cast_1372 [[0, 1], [2]] : tensor<16x64xf32> into tensor<2x8x64xf32>
%int0_1381 = torch.constant.int 0
%int1_1382 = torch.constant.int 1
%c0_1383 = arith.constant 0 : index
%c64_1384 = arith.constant 64 : index
%c1_1385 = arith.constant 1 : index
%c64_1386 = arith.constant 64 : index
%842 = tensor.empty() : tensor<64x64xf32>
%843 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%19 : tensor<64x64xf32>) outs(%842 : tensor<64x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x64xf32>
%cast_1387 = tensor.cast %843 : tensor<64x64xf32> to tensor<64x64xf32>
%844 = torch_c.from_builtin_tensor %cast_1387 : tensor<64x64xf32> -> !torch.vtensor<[64,64],f32>
%int16_1388 = torch.constant.int 16
%int64_1389 = torch.constant.int 64
%845 = torch.prim.ListConstruct %int16_1388, %int64_1389 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_1390 = arith.constant 0 : index
%c2_1391 = arith.constant 2 : index
%c1_1392 = arith.constant 1 : index
%c8_1393 = arith.constant 8 : index
%c2_1394 = arith.constant 2 : index
%c64_1395 = arith.constant 64 : index
%846 = torch_c.to_i64 %int16_1388
%847 = torch_c.to_i64 %int64_1389
%collapsed_1396 = tensor.collapse_shape %cast_1312 [[0, 1], [2]] : tensor<2x8x64xf32> into tensor<16x64xf32>
%848 = torch_c.from_builtin_tensor %collapsed_1396 : tensor<16x64xf32> -> !torch.vtensor<[16,64],f32>
%c0_1397 = arith.constant 0 : index
%dim_1398 = tensor.dim %collapsed_1396, %c0_1397 : tensor<16x64xf32>
%c1_1399 = arith.constant 1 : index
%dim_1400 = tensor.dim %cast_1387, %c1_1399 : tensor<64x64xf32>
%c1_1401 = arith.constant 1 : index
%dim_1402 = tensor.dim %collapsed_1396, %c1_1401 : tensor<16x64xf32>
%c0_1403 = arith.constant 0 : index
%dim_1404 = tensor.dim %cast_1387, %c0_1403 : tensor<64x64xf32>
%849 = arith.cmpi eq, %dim_1402, %dim_1404 : index
cf.assert %849, "mismatching contracting dimension for torch.aten.mm"
%850 = tensor.empty(%dim_1398, %dim_1400) : tensor<?x?xf32>
%cst_1405 = arith.constant 0.000000e+00 : f32
%851 = linalg.fill ins(%cst_1405 : f32) outs(%850 : tensor<?x?xf32>) -> tensor<?x?xf32>
%852 = linalg.matmul ins(%collapsed_1396, %cast_1387 : tensor<16x64xf32>, tensor<64x64xf32>) outs(%851 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_1406 = tensor.cast %852 : tensor<?x?xf32> to tensor<16x64xf32>
%int2_1407 = torch.constant.int 2
%int8_1408 = torch.constant.int 8
%int64_1409 = torch.constant.int 64
%853 = torch.prim.ListConstruct %int2_1407, %int8_1408, %int64_1409 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1410 = arith.constant 0 : index
%c16_1411 = arith.constant 16 : index
%c1_1412 = arith.constant 1 : index
%c64_1413 = arith.constant 64 : index
%854 = torch_c.to_i64 %int2_1407
%855 = torch_c.to_i64 %int8_1408
%856 = torch_c.to_i64 %int64_1409
%expanded_1414 = tensor.expand_shape %cast_1406 [[0, 1], [2]] : tensor<16x64xf32> into tensor<2x8x64xf32>
%int2_1415 = torch.constant.int 2
%int8_1416 = torch.constant.int 8
%int4_1417 = torch.constant.int 4
%int16_1418 = torch.constant.int 16
%857 = torch.prim.ListConstruct %int2_1415, %int8_1416, %int4_1417, %int16_1418 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1419 = arith.constant 0 : index
%c2_1420 = arith.constant 2 : index
%c1_1421 = arith.constant 1 : index
%c8_1422 = arith.constant 8 : index
%c2_1423 = arith.constant 2 : index
%c64_1424 = arith.constant 64 : index
%858 = torch_c.to_i64 %int2_1415
%859 = torch_c.to_i64 %int8_1416
%860 = torch_c.to_i64 %int4_1417
%861 = torch_c.to_i64 %int16_1418
%expanded_1425 = tensor.expand_shape %expanded_1346 [[0], [1], [2, 3]] : tensor<2x8x64xf32> into tensor<2x8x4x16xf32>
%int2_1426 = torch.constant.int 2
%int8_1427 = torch.constant.int 8
%int4_1428 = torch.constant.int 4
%int16_1429 = torch.constant.int 16
%862 = torch.prim.ListConstruct %int2_1426, %int8_1427, %int4_1428, %int16_1429 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1430 = arith.constant 0 : index
%c2_1431 = arith.constant 2 : index
%c1_1432 = arith.constant 1 : index
%c8_1433 = arith.constant 8 : index
%c2_1434 = arith.constant 2 : index
%c64_1435 = arith.constant 64 : index
%863 = torch_c.to_i64 %int2_1426
%864 = torch_c.to_i64 %int8_1427
%865 = torch_c.to_i64 %int4_1428
%866 = torch_c.to_i64 %int16_1429
%expanded_1436 = tensor.expand_shape %expanded_1380 [[0], [1], [2, 3]] : tensor<2x8x64xf32> into tensor<2x8x4x16xf32>
%int2_1437 = torch.constant.int 2
%int8_1438 = torch.constant.int 8
%int4_1439 = torch.constant.int 4
%int16_1440 = torch.constant.int 16
%867 = torch.prim.ListConstruct %int2_1437, %int8_1438, %int4_1439, %int16_1440 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1441 = arith.constant 0 : index
%c2_1442 = arith.constant 2 : index
%c1_1443 = arith.constant 1 : index
%c8_1444 = arith.constant 8 : index
%c2_1445 = arith.constant 2 : index
%c64_1446 = arith.constant 64 : index
%868 = torch_c.to_i64 %int2_1437
%869 = torch_c.to_i64 %int8_1438
%870 = torch_c.to_i64 %int4_1439
%871 = torch_c.to_i64 %int16_1440
%expanded_1447 = tensor.expand_shape %expanded_1414 [[0], [1], [2, 3]] : tensor<2x8x64xf32> into tensor<2x8x4x16xf32>
%int2_1448 = torch.constant.int 2
%int8_1449 = torch.constant.int 8
%int4_1450 = torch.constant.int 4
%int-1_1451 = torch.constant.int -1
%int2_1452 = torch.constant.int 2
%872 = torch.prim.ListConstruct %int2_1448, %int8_1449, %int4_1450, %int-1_1451, %int2_1452 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1453 = arith.constant 0 : index
%c2_1454 = arith.constant 2 : index
%c1_1455 = arith.constant 1 : index
%c8_1456 = arith.constant 8 : index
%c2_1457 = arith.constant 2 : index
%c4_1458 = arith.constant 4 : index
%c3_1459 = arith.constant 3 : index
%c16_1460 = arith.constant 16 : index
%873 = torch_c.to_i64 %int2_1448
%874 = torch_c.to_i64 %int8_1449
%875 = torch_c.to_i64 %int4_1450
%876 = torch_c.to_i64 %int-1_1451
%877 = torch_c.to_i64 %int2_1452
%expanded_1461 = tensor.expand_shape %expanded_1425 [[0], [1], [2], [3, 4]] : tensor<2x8x4x16xf32> into tensor<2x8x4x8x2xf32>
%c0_1462 = arith.constant 0 : index
%dim_1463 = tensor.dim %expanded_1461, %c0_1462 : tensor<2x8x4x8x2xf32>
%c1_1464 = arith.constant 1 : index
%dim_1465 = tensor.dim %expanded_1461, %c1_1464 : tensor<2x8x4x8x2xf32>
%c2_1466 = arith.constant 2 : index
%dim_1467 = tensor.dim %expanded_1461, %c2_1466 : tensor<2x8x4x8x2xf32>
%c3_1468 = arith.constant 3 : index
%dim_1469 = tensor.dim %expanded_1461, %c3_1468 : tensor<2x8x4x8x2xf32>
%878 = tensor.empty(%dim_1463, %dim_1465, %dim_1467, %dim_1469) : tensor<?x?x?x?xcomplex<f32>>
%c0_1470 = arith.constant 0 : index
%c1_1471 = arith.constant 1 : index
%879 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%878 : tensor<?x?x?x?xcomplex<f32>>) {
^bb0(%out: complex<f32>):
%1543 = linalg.index 0 : index
%1544 = linalg.index 0 : index
%1545 = linalg.index 1 : index
%1546 = linalg.index 1 : index
%1547 = linalg.index 2 : index
%1548 = linalg.index 2 : index
%1549 = linalg.index 3 : index
%1550 = linalg.index 3 : index
%extracted = tensor.extract %expanded_1461[%1543, %1545, %1547, %1549, %c0_1470] : tensor<2x8x4x8x2xf32>
%extracted_2549 = tensor.extract %expanded_1461[%1544, %1546, %1548, %1550, %c1_1471] : tensor<2x8x4x8x2xf32>
%1551 = complex.create %extracted, %extracted_2549 : complex<f32>
linalg.yield %1551 : complex<f32>
} -> tensor<?x?x?x?xcomplex<f32>>
%cast_1472 = tensor.cast %879 : tensor<?x?x?x?xcomplex<f32>> to tensor<2x8x4x8xcomplex<f32>>
%int2_1473 = torch.constant.int 2
%int8_1474 = torch.constant.int 8
%int4_1475 = torch.constant.int 4
%int-1_1476 = torch.constant.int -1
%int2_1477 = torch.constant.int 2
%880 = torch.prim.ListConstruct %int2_1473, %int8_1474, %int4_1475, %int-1_1476, %int2_1477 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1478 = arith.constant 0 : index
%c2_1479 = arith.constant 2 : index
%c1_1480 = arith.constant 1 : index
%c8_1481 = arith.constant 8 : index
%c2_1482 = arith.constant 2 : index
%c4_1483 = arith.constant 4 : index
%c3_1484 = arith.constant 3 : index
%c16_1485 = arith.constant 16 : index
%881 = torch_c.to_i64 %int2_1473
%882 = torch_c.to_i64 %int8_1474
%883 = torch_c.to_i64 %int4_1475
%884 = torch_c.to_i64 %int-1_1476
%885 = torch_c.to_i64 %int2_1477
%expanded_1486 = tensor.expand_shape %expanded_1436 [[0], [1], [2], [3, 4]] : tensor<2x8x4x16xf32> into tensor<2x8x4x8x2xf32>
%c0_1487 = arith.constant 0 : index
%dim_1488 = tensor.dim %expanded_1486, %c0_1487 : tensor<2x8x4x8x2xf32>
%c1_1489 = arith.constant 1 : index
%dim_1490 = tensor.dim %expanded_1486, %c1_1489 : tensor<2x8x4x8x2xf32>
%c2_1491 = arith.constant 2 : index
%dim_1492 = tensor.dim %expanded_1486, %c2_1491 : tensor<2x8x4x8x2xf32>
%c3_1493 = arith.constant 3 : index
%dim_1494 = tensor.dim %expanded_1486, %c3_1493 : tensor<2x8x4x8x2xf32>
%886 = tensor.empty(%dim_1488, %dim_1490, %dim_1492, %dim_1494) : tensor<?x?x?x?xcomplex<f32>>
%c0_1495 = arith.constant 0 : index
%c1_1496 = arith.constant 1 : index
%887 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%886 : tensor<?x?x?x?xcomplex<f32>>) {
^bb0(%out: complex<f32>):
%1543 = linalg.index 0 : index
%1544 = linalg.index 0 : index
%1545 = linalg.index 1 : index
%1546 = linalg.index 1 : index
%1547 = linalg.index 2 : index
%1548 = linalg.index 2 : index
%1549 = linalg.index 3 : index
%1550 = linalg.index 3 : index
%extracted = tensor.extract %expanded_1486[%1543, %1545, %1547, %1549, %c0_1495] : tensor<2x8x4x8x2xf32>
%extracted_2549 = tensor.extract %expanded_1486[%1544, %1546, %1548, %1550, %c1_1496] : tensor<2x8x4x8x2xf32>
%1551 = complex.create %extracted, %extracted_2549 : complex<f32>
linalg.yield %1551 : complex<f32>
} -> tensor<?x?x?x?xcomplex<f32>>
%cast_1497 = tensor.cast %887 : tensor<?x?x?x?xcomplex<f32>> to tensor<2x8x4x8xcomplex<f32>>
%int1_1498 = torch.constant.int 1
%int8_1499 = torch.constant.int 8
%int1_1500 = torch.constant.int 1
%int8_1501 = torch.constant.int 8
%888 = torch.prim.ListConstruct %int1_1498, %int8_1499, %int1_1500, %int8_1501 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1502 = arith.constant 0 : index
%c8_1503 = arith.constant 8 : index
%c1_1504 = arith.constant 1 : index
%c8_1505 = arith.constant 8 : index
%889 = torch_c.to_i64 %int1_1498
%890 = torch_c.to_i64 %int8_1499
%891 = torch_c.to_i64 %int1_1500
%892 = torch_c.to_i64 %int8_1501
%expanded_1506 = tensor.expand_shape %cast_54 [[0, 1], [2, 3]] : tensor<8x8xcomplex<f32>> into tensor<1x8x1x8xcomplex<f32>>
%893 = torch_c.from_builtin_tensor %expanded_1506 : tensor<1x8x1x8xcomplex<f32>> -> !torch.vtensor<[1,8,1,8],complex<f32>>
%c1_1507 = arith.constant 1 : index
%c0_1508 = arith.constant 0 : index
%c2_1509 = arith.constant 2 : index
%c1_1510 = arith.constant 1 : index
%c8_1511 = arith.constant 8 : index
%c2_1512 = arith.constant 2 : index
%c4_1513 = arith.constant 4 : index
%c3_1514 = arith.constant 3 : index
%c8_1515 = arith.constant 8 : index
%c1_1516 = arith.constant 1 : index
%c8_1517 = arith.constant 8 : index
%894 = arith.cmpi eq, %c8_1511, %c8_1517 : index
cf.assert %894, "mismatched size for broadcast"
%c3_1518 = arith.constant 3 : index
%c8_1519 = arith.constant 8 : index
%895 = arith.cmpi eq, %c8_1515, %c8_1519 : index
cf.assert %895, "mismatched size for broadcast"
%896 = tensor.empty() : tensor<2x8x4x8xcomplex<f32>>
%897 = linalg.generic {indexing_maps = [#map, #map8, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_1472, %expanded_1506 : tensor<2x8x4x8xcomplex<f32>>, tensor<1x8x1x8xcomplex<f32>>) outs(%896 : tensor<2x8x4x8xcomplex<f32>>) {
^bb0(%in: complex<f32>, %in_2549: complex<f32>, %out: complex<f32>):
%1543 = complex.mul %in, %in_2549 : complex<f32>
linalg.yield %1543 : complex<f32>
} -> tensor<2x8x4x8xcomplex<f32>>
%cast_1520 = tensor.cast %897 : tensor<2x8x4x8xcomplex<f32>> to tensor<2x8x4x8xcomplex<f32>>
%c2_1521 = arith.constant 2 : index
%898 = tensor.empty(%c2_1521) : tensor<2x8x4x8x?xf32>
%c0_1522 = arith.constant 0 : index
%899 = linalg.generic {indexing_maps = [#map9, #map10], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%cast_1520 : tensor<2x8x4x8xcomplex<f32>>) outs(%898 : tensor<2x8x4x8x?xf32>) {
^bb0(%in: complex<f32>, %out: f32):
%1543 = complex.re %in : complex<f32>
%1544 = complex.im %in : complex<f32>
%1545 = linalg.index 4 : index
%1546 = arith.cmpi eq, %1545, %c0_1522 : index
%1547 = arith.select %1546, %1543, %1544 : f32
linalg.yield %1547 : f32
} -> tensor<2x8x4x8x?xf32>
%cast_1523 = tensor.cast %899 : tensor<2x8x4x8x?xf32> to tensor<2x8x4x8x2xf32>
%int2_1524 = torch.constant.int 2
%int8_1525 = torch.constant.int 8
%int4_1526 = torch.constant.int 4
%int16_1527 = torch.constant.int 16
%900 = torch.prim.ListConstruct %int2_1524, %int8_1525, %int4_1526, %int16_1527 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1528 = arith.constant 0 : index
%c2_1529 = arith.constant 2 : index
%c1_1530 = arith.constant 1 : index
%c8_1531 = arith.constant 8 : index
%c2_1532 = arith.constant 2 : index
%c4_1533 = arith.constant 4 : index
%c3_1534 = arith.constant 3 : index
%c8_1535 = arith.constant 8 : index
%c4_1536 = arith.constant 4 : index
%c2_1537 = arith.constant 2 : index
%901 = torch_c.to_i64 %int2_1524
%902 = torch_c.to_i64 %int8_1525
%903 = torch_c.to_i64 %int4_1526
%904 = torch_c.to_i64 %int16_1527
%collapsed_1538 = tensor.collapse_shape %cast_1523 [[0], [1], [2], [3, 4]] : tensor<2x8x4x8x2xf32> into tensor<2x8x4x16xf32>
%c1_1539 = arith.constant 1 : index
%c0_1540 = arith.constant 0 : index
%c2_1541 = arith.constant 2 : index
%c1_1542 = arith.constant 1 : index
%c8_1543 = arith.constant 8 : index
%c2_1544 = arith.constant 2 : index
%c4_1545 = arith.constant 4 : index
%c3_1546 = arith.constant 3 : index
%c8_1547 = arith.constant 8 : index
%c1_1548 = arith.constant 1 : index
%c8_1549 = arith.constant 8 : index
%905 = arith.cmpi eq, %c8_1543, %c8_1549 : index
cf.assert %905, "mismatched size for broadcast"
%c3_1550 = arith.constant 3 : index
%c8_1551 = arith.constant 8 : index
%906 = arith.cmpi eq, %c8_1547, %c8_1551 : index
cf.assert %906, "mismatched size for broadcast"
%907 = tensor.empty() : tensor<2x8x4x8xcomplex<f32>>
%908 = linalg.generic {indexing_maps = [#map, #map8, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_1497, %expanded_1506 : tensor<2x8x4x8xcomplex<f32>>, tensor<1x8x1x8xcomplex<f32>>) outs(%907 : tensor<2x8x4x8xcomplex<f32>>) {
^bb0(%in: complex<f32>, %in_2549: complex<f32>, %out: complex<f32>):
%1543 = complex.mul %in, %in_2549 : complex<f32>
linalg.yield %1543 : complex<f32>
} -> tensor<2x8x4x8xcomplex<f32>>
%cast_1552 = tensor.cast %908 : tensor<2x8x4x8xcomplex<f32>> to tensor<2x8x4x8xcomplex<f32>>
%c2_1553 = arith.constant 2 : index
%909 = tensor.empty(%c2_1553) : tensor<2x8x4x8x?xf32>
%c0_1554 = arith.constant 0 : index
%910 = linalg.generic {indexing_maps = [#map9, #map10], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%cast_1552 : tensor<2x8x4x8xcomplex<f32>>) outs(%909 : tensor<2x8x4x8x?xf32>) {
^bb0(%in: complex<f32>, %out: f32):
%1543 = complex.re %in : complex<f32>
%1544 = complex.im %in : complex<f32>
%1545 = linalg.index 4 : index
%1546 = arith.cmpi eq, %1545, %c0_1554 : index
%1547 = arith.select %1546, %1543, %1544 : f32
linalg.yield %1547 : f32
} -> tensor<2x8x4x8x?xf32>
%cast_1555 = tensor.cast %910 : tensor<2x8x4x8x?xf32> to tensor<2x8x4x8x2xf32>
%int2_1556 = torch.constant.int 2
%int8_1557 = torch.constant.int 8
%int4_1558 = torch.constant.int 4
%int16_1559 = torch.constant.int 16
%911 = torch.prim.ListConstruct %int2_1556, %int8_1557, %int4_1558, %int16_1559 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_1560 = arith.constant 0 : index
%c2_1561 = arith.constant 2 : index
%c1_1562 = arith.constant 1 : index
%c8_1563 = arith.constant 8 : index
%c2_1564 = arith.constant 2 : index
%c4_1565 = arith.constant 4 : index
%c3_1566 = arith.constant 3 : index
%c8_1567 = arith.constant 8 : index
%c4_1568 = arith.constant 4 : index
%c2_1569 = arith.constant 2 : index
%912 = torch_c.to_i64 %int2_1556
%913 = torch_c.to_i64 %int8_1557
%914 = torch_c.to_i64 %int4_1558
%915 = torch_c.to_i64 %int16_1559
%collapsed_1570 = tensor.collapse_shape %cast_1555 [[0], [1], [2], [3, 4]] : tensor<2x8x4x8x2xf32> into tensor<2x8x4x16xf32>
%int0_1571 = torch.constant.int 0
%int0_1572 = torch.constant.int 0
%916 = torch_c.to_i64 %int0_1572
%int2_1573 = torch.constant.int 2
%917 = torch_c.to_i64 %int2_1573
%int1_1574 = torch.constant.int 1
%c0_1575 = arith.constant 0 : index
%c1_1576 = arith.constant 1 : index
%c0_1577 = arith.constant 0 : index
%c32_1578 = arith.constant 32 : index
%c1_1579 = arith.constant 1 : index
%c2048_1580 = arith.constant 2048 : index
%c2_1581 = arith.constant 2 : index
%c4_1582 = arith.constant 4 : index
%c3_1583 = arith.constant 3 : index
%c16_1584 = arith.constant 16 : index
%918 = arith.index_cast %c32_1578 : index to i64
%919 = arith.addi %916, %918 : i64
%c0_i64_1585 = arith.constant 0 : i64
%920 = arith.cmpi sge, %916, %c0_i64_1585 : i64
%921 = arith.select %920, %916, %919 : i64
%c0_i64_1586 = arith.constant 0 : i64
%922 = arith.cmpi slt, %921, %c0_i64_1586 : i64
%923 = arith.select %922, %c0_i64_1586, %921 : i64
%924 = arith.cmpi sgt, %923, %918 : i64
%925 = arith.select %924, %918, %923 : i64
%926 = arith.index_cast %925 : i64 to index
%927 = arith.index_cast %c32_1578 : index to i64
%928 = arith.addi %917, %927 : i64
%c0_i64_1587 = arith.constant 0 : i64
%929 = arith.cmpi sge, %917, %c0_i64_1587 : i64
%930 = arith.select %929, %917, %928 : i64
%c0_i64_1588 = arith.constant 0 : i64
%931 = arith.cmpi slt, %930, %c0_i64_1588 : i64
%932 = arith.select %931, %c0_i64_1588, %930 : i64
%933 = arith.cmpi sgt, %932, %927 : i64
%934 = arith.select %933, %927, %932 : i64
%935 = arith.index_cast %934 : i64 to index
%936 = arith.cmpi sge, %935, %926 : index
%937 = arith.select %936, %935, %926 : index
%c1_1589 = arith.constant 1 : index
%c0_1590 = arith.constant 0 : index
%c32_1591 = arith.constant 32 : index
%c1_1592 = arith.constant 1 : index
%c2048_1593 = arith.constant 2048 : index
%c2_1594 = arith.constant 2 : index
%c4_1595 = arith.constant 4 : index
%c3_1596 = arith.constant 3 : index
%c16_1597 = arith.constant 16 : index
%938 = arith.subi %937, %926 : index
%939 = arith.addi %938, %c1_1589 : index
%940 = arith.subi %939, %c1_1576 : index
%941 = arith.floordivsi %940, %c1_1589 : index
%942 = arith.muli %c1_1576, %c1_1589 : index
%extracted_slice_1598 = tensor.extract_slice %cast_22[%926, %c0_1575, %c0_1575, %c0_1575] [%941, %c2048_1593, %c4_1595, %c16_1597] [%942, %c1_1576, %c1_1576, %c1_1576] : tensor<32x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_1599 = tensor.cast %extracted_slice_1598 : tensor<?x?x?x?xf32> to tensor<2x2048x4x16xf32>
%int1_1600 = torch.constant.int 1
%int0_1601 = torch.constant.int 0
%943 = torch_c.to_i64 %int0_1601
%int8_1602 = torch.constant.int 8
%944 = torch_c.to_i64 %int8_1602
%int1_1603 = torch.constant.int 1
%c0_1604 = arith.constant 0 : index
%c1_1605 = arith.constant 1 : index
%c0_1606 = arith.constant 0 : index
%c2_1607 = arith.constant 2 : index
%c1_1608 = arith.constant 1 : index
%c2048_1609 = arith.constant 2048 : index
%c2_1610 = arith.constant 2 : index
%c4_1611 = arith.constant 4 : index
%c3_1612 = arith.constant 3 : index
%c16_1613 = arith.constant 16 : index
%945 = arith.index_cast %c2048_1609 : index to i64
%946 = arith.addi %943, %945 : i64
%c0_i64_1614 = arith.constant 0 : i64
%947 = arith.cmpi sge, %943, %c0_i64_1614 : i64
%948 = arith.select %947, %943, %946 : i64
%c0_i64_1615 = arith.constant 0 : i64
%949 = arith.cmpi slt, %948, %c0_i64_1615 : i64
%950 = arith.select %949, %c0_i64_1615, %948 : i64
%951 = arith.cmpi sgt, %950, %945 : i64
%952 = arith.select %951, %945, %950 : i64
%953 = arith.index_cast %952 : i64 to index
%954 = arith.index_cast %c2048_1609 : index to i64
%955 = arith.addi %944, %954 : i64
%c0_i64_1616 = arith.constant 0 : i64
%956 = arith.cmpi sge, %944, %c0_i64_1616 : i64
%957 = arith.select %956, %944, %955 : i64
%c0_i64_1617 = arith.constant 0 : i64
%958 = arith.cmpi slt, %957, %c0_i64_1617 : i64
%959 = arith.select %958, %c0_i64_1617, %957 : i64
%960 = arith.cmpi sgt, %959, %954 : i64
%961 = arith.select %960, %954, %959 : i64
%962 = arith.index_cast %961 : i64 to index
%963 = arith.cmpi sge, %962, %953 : index
%964 = arith.select %963, %962, %953 : index
%c1_1618 = arith.constant 1 : index
%c0_1619 = arith.constant 0 : index
%c2_1620 = arith.constant 2 : index
%c1_1621 = arith.constant 1 : index
%c2048_1622 = arith.constant 2048 : index
%c2_1623 = arith.constant 2 : index
%c4_1624 = arith.constant 4 : index
%c3_1625 = arith.constant 3 : index
%c16_1626 = arith.constant 16 : index
%965 = arith.subi %964, %953 : index
%966 = arith.addi %965, %c1_1618 : index
%967 = arith.subi %966, %c1_1605 : index
%968 = arith.floordivsi %967, %c1_1618 : index
%969 = arith.muli %c1_1605, %c1_1618 : index
%extracted_slice_1627 = tensor.extract_slice %cast_1599[%c0_1604, %953, %c0_1604, %c0_1604] [%c2_1620, %968, %c4_1624, %c16_1626] [%c1_1605, %969, %c1_1605, %c1_1605] : tensor<2x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_1628 = tensor.cast %extracted_slice_1627 : tensor<?x?x?x?xf32> to tensor<2x8x4x16xf32>
%false_1629 = torch.constant.bool false
%c0_1630 = arith.constant 0 : index
%c2_1631 = arith.constant 2 : index
%c1_1632 = arith.constant 1 : index
%c8_1633 = arith.constant 8 : index
%c2_1634 = arith.constant 2 : index
%c4_1635 = arith.constant 4 : index
%c3_1636 = arith.constant 3 : index
%c16_1637 = arith.constant 16 : index
%970 = arith.index_cast %c2_1631 : index to i64
%971 = arith.index_cast %c8_1633 : index to i64
%972 = arith.index_cast %c4_1635 : index to i64
%973 = arith.index_cast %c16_1637 : index to i64
%c0_i64_1638 = arith.constant 0 : i64
%c0_1639 = arith.constant 0 : index
%c1_1640 = arith.constant 1 : index
%974 = tensor.empty() : tensor<2x8x4x16xf32>
%cast_1641 = tensor.cast %collapsed_1570 : tensor<2x8x4x16xf32> to tensor<2x8x4x16xf32>
%975 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_1641 : tensor<2x8x4x16xf32>) outs(%cast_1628 : tensor<2x8x4x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x8x4x16xf32>
%cast_1642 = tensor.cast %975 : tensor<2x8x4x16xf32> to tensor<2x8x4x16xf32>
%int0_1643 = torch.constant.int 0
%int0_1644 = torch.constant.int 0
%976 = torch_c.to_i64 %int0_1644
%int2_1645 = torch.constant.int 2
%977 = torch_c.to_i64 %int2_1645
%int1_1646 = torch.constant.int 1
%c0_1647 = arith.constant 0 : index
%c1_1648 = arith.constant 1 : index
%c0_1649 = arith.constant 0 : index
%c32_1650 = arith.constant 32 : index
%c1_1651 = arith.constant 1 : index
%c2048_1652 = arith.constant 2048 : index
%c2_1653 = arith.constant 2 : index
%c4_1654 = arith.constant 4 : index
%c3_1655 = arith.constant 3 : index
%c16_1656 = arith.constant 16 : index
%978 = arith.index_cast %c32_1650 : index to i64
%979 = arith.addi %976, %978 : i64
%c0_i64_1657 = arith.constant 0 : i64
%980 = arith.cmpi sge, %976, %c0_i64_1657 : i64
%981 = arith.select %980, %976, %979 : i64
%c0_i64_1658 = arith.constant 0 : i64
%982 = arith.cmpi slt, %981, %c0_i64_1658 : i64
%983 = arith.select %982, %c0_i64_1658, %981 : i64
%984 = arith.cmpi sgt, %983, %978 : i64
%985 = arith.select %984, %978, %983 : i64
%986 = arith.index_cast %985 : i64 to index
%987 = arith.index_cast %c32_1650 : index to i64
%988 = arith.addi %977, %987 : i64
%c0_i64_1659 = arith.constant 0 : i64
%989 = arith.cmpi sge, %977, %c0_i64_1659 : i64
%990 = arith.select %989, %977, %988 : i64
%c0_i64_1660 = arith.constant 0 : i64
%991 = arith.cmpi slt, %990, %c0_i64_1660 : i64
%992 = arith.select %991, %c0_i64_1660, %990 : i64
%993 = arith.cmpi sgt, %992, %987 : i64
%994 = arith.select %993, %987, %992 : i64
%995 = arith.index_cast %994 : i64 to index
%996 = arith.cmpi sge, %995, %986 : index
%997 = arith.select %996, %995, %986 : index
%c1_1661 = arith.constant 1 : index
%c0_1662 = arith.constant 0 : index
%c32_1663 = arith.constant 32 : index
%c1_1664 = arith.constant 1 : index
%c2048_1665 = arith.constant 2048 : index
%c2_1666 = arith.constant 2 : index
%c4_1667 = arith.constant 4 : index
%c3_1668 = arith.constant 3 : index
%c16_1669 = arith.constant 16 : index
%998 = arith.subi %997, %986 : index
%999 = arith.addi %998, %c1_1661 : index
%1000 = arith.subi %999, %c1_1648 : index
%1001 = arith.floordivsi %1000, %c1_1661 : index
%1002 = arith.muli %c1_1648, %c1_1661 : index
%extracted_slice_1670 = tensor.extract_slice %cast_22[%986, %c0_1647, %c0_1647, %c0_1647] [%1001, %c2048_1665, %c4_1667, %c16_1669] [%1002, %c1_1648, %c1_1648, %c1_1648] : tensor<32x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_1671 = tensor.cast %extracted_slice_1670 : tensor<?x?x?x?xf32> to tensor<2x2048x4x16xf32>
%int1_1672 = torch.constant.int 1
%int0_1673 = torch.constant.int 0
%1003 = torch_c.to_i64 %int0_1673
%int8_1674 = torch.constant.int 8
%1004 = torch_c.to_i64 %int8_1674
%int1_1675 = torch.constant.int 1
%c0_1676 = arith.constant 0 : index
%c1_1677 = arith.constant 1 : index
%c0_1678 = arith.constant 0 : index
%c2_1679 = arith.constant 2 : index
%c1_1680 = arith.constant 1 : index
%c2048_1681 = arith.constant 2048 : index
%c2_1682 = arith.constant 2 : index
%c4_1683 = arith.constant 4 : index
%c3_1684 = arith.constant 3 : index
%c16_1685 = arith.constant 16 : index
%1005 = arith.index_cast %c2048_1681 : index to i64
%1006 = arith.addi %1003, %1005 : i64
%c0_i64_1686 = arith.constant 0 : i64
%1007 = arith.cmpi sge, %1003, %c0_i64_1686 : i64
%1008 = arith.select %1007, %1003, %1006 : i64
%c0_i64_1687 = arith.constant 0 : i64
%1009 = arith.cmpi slt, %1008, %c0_i64_1687 : i64
%1010 = arith.select %1009, %c0_i64_1687, %1008 : i64
%1011 = arith.cmpi sgt, %1010, %1005 : i64
%1012 = arith.select %1011, %1005, %1010 : i64
%1013 = arith.index_cast %1012 : i64 to index
%1014 = arith.index_cast %c2048_1681 : index to i64
%1015 = arith.addi %1004, %1014 : i64
%c0_i64_1688 = arith.constant 0 : i64
%1016 = arith.cmpi sge, %1004, %c0_i64_1688 : i64
%1017 = arith.select %1016, %1004, %1015 : i64
%c0_i64_1689 = arith.constant 0 : i64
%1018 = arith.cmpi slt, %1017, %c0_i64_1689 : i64
%1019 = arith.select %1018, %c0_i64_1689, %1017 : i64
%1020 = arith.cmpi sgt, %1019, %1014 : i64
%1021 = arith.select %1020, %1014, %1019 : i64
%1022 = arith.index_cast %1021 : i64 to index
%1023 = arith.cmpi sge, %1022, %1013 : index
%1024 = arith.select %1023, %1022, %1013 : index
%c1_1690 = arith.constant 1 : index
%c0_1691 = arith.constant 0 : index
%c2_1692 = arith.constant 2 : index
%c1_1693 = arith.constant 1 : index
%c2048_1694 = arith.constant 2048 : index
%c2_1695 = arith.constant 2 : index
%c4_1696 = arith.constant 4 : index
%c3_1697 = arith.constant 3 : index
%c16_1698 = arith.constant 16 : index
%1025 = arith.subi %1024, %1013 : index
%1026 = arith.addi %1025, %c1_1690 : index
%1027 = arith.subi %1026, %c1_1677 : index
%1028 = arith.floordivsi %1027, %c1_1690 : index
%1029 = arith.muli %c1_1677, %c1_1690 : index
%cast_1699 = tensor.cast %cast_1642 : tensor<2x8x4x16xf32> to tensor<?x?x?x?xf32>
%inserted_slice_1700 = tensor.insert_slice %cast_1699 into %cast_1671[%c0_1676, %1013, %c0_1676, %c0_1676] [%c2_1692, %1028, %c4_1696, %c16_1698] [%c1_1677, %1029, %c1_1677, %c1_1677] : tensor<?x?x?x?xf32> into tensor<2x2048x4x16xf32>
%cast_1701 = tensor.cast %inserted_slice_1700 : tensor<2x2048x4x16xf32> to tensor<2x2048x4x16xf32>
%int0_1702 = torch.constant.int 0
%int0_1703 = torch.constant.int 0
%1030 = torch_c.to_i64 %int0_1703
%int2_1704 = torch.constant.int 2
%1031 = torch_c.to_i64 %int2_1704
%int1_1705 = torch.constant.int 1
%c0_1706 = arith.constant 0 : index
%c1_1707 = arith.constant 1 : index
%c0_1708 = arith.constant 0 : index
%c32_1709 = arith.constant 32 : index
%c1_1710 = arith.constant 1 : index
%c2048_1711 = arith.constant 2048 : index
%c2_1712 = arith.constant 2 : index
%c4_1713 = arith.constant 4 : index
%c3_1714 = arith.constant 3 : index
%c16_1715 = arith.constant 16 : index
%1032 = arith.index_cast %c32_1709 : index to i64
%1033 = arith.addi %1030, %1032 : i64
%c0_i64_1716 = arith.constant 0 : i64
%1034 = arith.cmpi sge, %1030, %c0_i64_1716 : i64
%1035 = arith.select %1034, %1030, %1033 : i64
%c0_i64_1717 = arith.constant 0 : i64
%1036 = arith.cmpi slt, %1035, %c0_i64_1717 : i64
%1037 = arith.select %1036, %c0_i64_1717, %1035 : i64
%1038 = arith.cmpi sgt, %1037, %1032 : i64
%1039 = arith.select %1038, %1032, %1037 : i64
%1040 = arith.index_cast %1039 : i64 to index
%1041 = arith.index_cast %c32_1709 : index to i64
%1042 = arith.addi %1031, %1041 : i64
%c0_i64_1718 = arith.constant 0 : i64
%1043 = arith.cmpi sge, %1031, %c0_i64_1718 : i64
%1044 = arith.select %1043, %1031, %1042 : i64
%c0_i64_1719 = arith.constant 0 : i64
%1045 = arith.cmpi slt, %1044, %c0_i64_1719 : i64
%1046 = arith.select %1045, %c0_i64_1719, %1044 : i64
%1047 = arith.cmpi sgt, %1046, %1041 : i64
%1048 = arith.select %1047, %1041, %1046 : i64
%1049 = arith.index_cast %1048 : i64 to index
%1050 = arith.cmpi sge, %1049, %1040 : index
%1051 = arith.select %1050, %1049, %1040 : index
%c1_1720 = arith.constant 1 : index
%c0_1721 = arith.constant 0 : index
%c32_1722 = arith.constant 32 : index
%c1_1723 = arith.constant 1 : index
%c2048_1724 = arith.constant 2048 : index
%c2_1725 = arith.constant 2 : index
%c4_1726 = arith.constant 4 : index
%c3_1727 = arith.constant 3 : index
%c16_1728 = arith.constant 16 : index
%1052 = arith.subi %1051, %1040 : index
%1053 = arith.addi %1052, %c1_1720 : index
%1054 = arith.subi %1053, %c1_1707 : index
%1055 = arith.floordivsi %1054, %c1_1720 : index
%1056 = arith.muli %c1_1707, %c1_1720 : index
%cast_1729 = tensor.cast %cast_1701 : tensor<2x2048x4x16xf32> to tensor<?x?x?x?xf32>
%inserted_slice_1730 = tensor.insert_slice %cast_1729 into %cast_22[%1040, %c0_1706, %c0_1706, %c0_1706] [%1055, %c2048_1724, %c4_1726, %c16_1728] [%1056, %c1_1707, %c1_1707, %c1_1707] : tensor<?x?x?x?xf32> into tensor<32x2048x4x16xf32>
%cast_1731 = tensor.cast %inserted_slice_1730 : tensor<32x2048x4x16xf32> to tensor<32x2048x4x16xf32>
%1057 = torch_c.from_builtin_tensor %cast_1731 : tensor<32x2048x4x16xf32> -> !torch.vtensor<[32,2048,4,16],f32>
%int0_1732 = torch.constant.int 0
%int0_1733 = torch.constant.int 0
%1058 = torch_c.to_i64 %int0_1733
%int2_1734 = torch.constant.int 2
%1059 = torch_c.to_i64 %int2_1734
%int1_1735 = torch.constant.int 1
%c0_1736 = arith.constant 0 : index
%c1_1737 = arith.constant 1 : index
%c0_1738 = arith.constant 0 : index
%c32_1739 = arith.constant 32 : index
%c1_1740 = arith.constant 1 : index
%c2048_1741 = arith.constant 2048 : index
%c2_1742 = arith.constant 2 : index
%c4_1743 = arith.constant 4 : index
%c3_1744 = arith.constant 3 : index
%c16_1745 = arith.constant 16 : index
%1060 = arith.index_cast %c32_1739 : index to i64
%1061 = arith.addi %1058, %1060 : i64
%c0_i64_1746 = arith.constant 0 : i64
%1062 = arith.cmpi sge, %1058, %c0_i64_1746 : i64
%1063 = arith.select %1062, %1058, %1061 : i64
%c0_i64_1747 = arith.constant 0 : i64
%1064 = arith.cmpi slt, %1063, %c0_i64_1747 : i64
%1065 = arith.select %1064, %c0_i64_1747, %1063 : i64
%1066 = arith.cmpi sgt, %1065, %1060 : i64
%1067 = arith.select %1066, %1060, %1065 : i64
%1068 = arith.index_cast %1067 : i64 to index
%1069 = arith.index_cast %c32_1739 : index to i64
%1070 = arith.addi %1059, %1069 : i64
%c0_i64_1748 = arith.constant 0 : i64
%1071 = arith.cmpi sge, %1059, %c0_i64_1748 : i64
%1072 = arith.select %1071, %1059, %1070 : i64
%c0_i64_1749 = arith.constant 0 : i64
%1073 = arith.cmpi slt, %1072, %c0_i64_1749 : i64
%1074 = arith.select %1073, %c0_i64_1749, %1072 : i64
%1075 = arith.cmpi sgt, %1074, %1069 : i64
%1076 = arith.select %1075, %1069, %1074 : i64
%1077 = arith.index_cast %1076 : i64 to index
%1078 = arith.cmpi sge, %1077, %1068 : index
%1079 = arith.select %1078, %1077, %1068 : index
%c1_1750 = arith.constant 1 : index
%c0_1751 = arith.constant 0 : index
%c32_1752 = arith.constant 32 : index
%c1_1753 = arith.constant 1 : index
%c2048_1754 = arith.constant 2048 : index
%c2_1755 = arith.constant 2 : index
%c4_1756 = arith.constant 4 : index
%c3_1757 = arith.constant 3 : index
%c16_1758 = arith.constant 16 : index
%1080 = arith.subi %1079, %1068 : index
%1081 = arith.addi %1080, %c1_1750 : index
%1082 = arith.subi %1081, %c1_1737 : index
%1083 = arith.floordivsi %1082, %c1_1750 : index
%1084 = arith.muli %c1_1737, %c1_1750 : index
%extracted_slice_1759 = tensor.extract_slice %cast_33[%1068, %c0_1736, %c0_1736, %c0_1736] [%1083, %c2048_1754, %c4_1756, %c16_1758] [%1084, %c1_1737, %c1_1737, %c1_1737] : tensor<32x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_1760 = tensor.cast %extracted_slice_1759 : tensor<?x?x?x?xf32> to tensor<2x2048x4x16xf32>
%int1_1761 = torch.constant.int 1
%int0_1762 = torch.constant.int 0
%1085 = torch_c.to_i64 %int0_1762
%int8_1763 = torch.constant.int 8
%1086 = torch_c.to_i64 %int8_1763
%int1_1764 = torch.constant.int 1
%c0_1765 = arith.constant 0 : index
%c1_1766 = arith.constant 1 : index
%c0_1767 = arith.constant 0 : index
%c2_1768 = arith.constant 2 : index
%c1_1769 = arith.constant 1 : index
%c2048_1770 = arith.constant 2048 : index
%c2_1771 = arith.constant 2 : index
%c4_1772 = arith.constant 4 : index
%c3_1773 = arith.constant 3 : index
%c16_1774 = arith.constant 16 : index
%1087 = arith.index_cast %c2048_1770 : index to i64
%1088 = arith.addi %1085, %1087 : i64
%c0_i64_1775 = arith.constant 0 : i64
%1089 = arith.cmpi sge, %1085, %c0_i64_1775 : i64
%1090 = arith.select %1089, %1085, %1088 : i64
%c0_i64_1776 = arith.constant 0 : i64
%1091 = arith.cmpi slt, %1090, %c0_i64_1776 : i64
%1092 = arith.select %1091, %c0_i64_1776, %1090 : i64
%1093 = arith.cmpi sgt, %1092, %1087 : i64
%1094 = arith.select %1093, %1087, %1092 : i64
%1095 = arith.index_cast %1094 : i64 to index
%1096 = arith.index_cast %c2048_1770 : index to i64
%1097 = arith.addi %1086, %1096 : i64
%c0_i64_1777 = arith.constant 0 : i64
%1098 = arith.cmpi sge, %1086, %c0_i64_1777 : i64
%1099 = arith.select %1098, %1086, %1097 : i64
%c0_i64_1778 = arith.constant 0 : i64
%1100 = arith.cmpi slt, %1099, %c0_i64_1778 : i64
%1101 = arith.select %1100, %c0_i64_1778, %1099 : i64
%1102 = arith.cmpi sgt, %1101, %1096 : i64
%1103 = arith.select %1102, %1096, %1101 : i64
%1104 = arith.index_cast %1103 : i64 to index
%1105 = arith.cmpi sge, %1104, %1095 : index
%1106 = arith.select %1105, %1104, %1095 : index
%c1_1779 = arith.constant 1 : index
%c0_1780 = arith.constant 0 : index
%c2_1781 = arith.constant 2 : index
%c1_1782 = arith.constant 1 : index
%c2048_1783 = arith.constant 2048 : index
%c2_1784 = arith.constant 2 : index
%c4_1785 = arith.constant 4 : index
%c3_1786 = arith.constant 3 : index
%c16_1787 = arith.constant 16 : index
%1107 = arith.subi %1106, %1095 : index
%1108 = arith.addi %1107, %c1_1779 : index
%1109 = arith.subi %1108, %c1_1766 : index
%1110 = arith.floordivsi %1109, %c1_1779 : index
%1111 = arith.muli %c1_1766, %c1_1779 : index
%extracted_slice_1788 = tensor.extract_slice %cast_1760[%c0_1765, %1095, %c0_1765, %c0_1765] [%c2_1781, %1110, %c4_1785, %c16_1787] [%c1_1766, %1111, %c1_1766, %c1_1766] : tensor<2x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_1789 = tensor.cast %extracted_slice_1788 : tensor<?x?x?x?xf32> to tensor<2x8x4x16xf32>
%false_1790 = torch.constant.bool false
%c0_1791 = arith.constant 0 : index
%c2_1792 = arith.constant 2 : index
%c1_1793 = arith.constant 1 : index
%c8_1794 = arith.constant 8 : index
%c2_1795 = arith.constant 2 : index
%c4_1796 = arith.constant 4 : index
%c3_1797 = arith.constant 3 : index
%c16_1798 = arith.constant 16 : index
%1112 = arith.index_cast %c2_1792 : index to i64
%1113 = arith.index_cast %c8_1794 : index to i64
%1114 = arith.index_cast %c4_1796 : index to i64
%1115 = arith.index_cast %c16_1798 : index to i64
%c0_i64_1799 = arith.constant 0 : i64
%c0_1800 = arith.constant 0 : index
%c1_1801 = arith.constant 1 : index
%1116 = tensor.empty() : tensor<2x8x4x16xf32>
%cast_1802 = tensor.cast %expanded_1447 : tensor<2x8x4x16xf32> to tensor<2x8x4x16xf32>
%1117 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_1802 : tensor<2x8x4x16xf32>) outs(%cast_1789 : tensor<2x8x4x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x8x4x16xf32>
%cast_1803 = tensor.cast %1117 : tensor<2x8x4x16xf32> to tensor<2x8x4x16xf32>
%int0_1804 = torch.constant.int 0
%int0_1805 = torch.constant.int 0
%1118 = torch_c.to_i64 %int0_1805
%int2_1806 = torch.constant.int 2
%1119 = torch_c.to_i64 %int2_1806
%int1_1807 = torch.constant.int 1
%c0_1808 = arith.constant 0 : index
%c1_1809 = arith.constant 1 : index
%c0_1810 = arith.constant 0 : index
%c32_1811 = arith.constant 32 : index
%c1_1812 = arith.constant 1 : index
%c2048_1813 = arith.constant 2048 : index
%c2_1814 = arith.constant 2 : index
%c4_1815 = arith.constant 4 : index
%c3_1816 = arith.constant 3 : index
%c16_1817 = arith.constant 16 : index
%1120 = arith.index_cast %c32_1811 : index to i64
%1121 = arith.addi %1118, %1120 : i64
%c0_i64_1818 = arith.constant 0 : i64
%1122 = arith.cmpi sge, %1118, %c0_i64_1818 : i64
%1123 = arith.select %1122, %1118, %1121 : i64
%c0_i64_1819 = arith.constant 0 : i64
%1124 = arith.cmpi slt, %1123, %c0_i64_1819 : i64
%1125 = arith.select %1124, %c0_i64_1819, %1123 : i64
%1126 = arith.cmpi sgt, %1125, %1120 : i64
%1127 = arith.select %1126, %1120, %1125 : i64
%1128 = arith.index_cast %1127 : i64 to index
%1129 = arith.index_cast %c32_1811 : index to i64
%1130 = arith.addi %1119, %1129 : i64
%c0_i64_1820 = arith.constant 0 : i64
%1131 = arith.cmpi sge, %1119, %c0_i64_1820 : i64
%1132 = arith.select %1131, %1119, %1130 : i64
%c0_i64_1821 = arith.constant 0 : i64
%1133 = arith.cmpi slt, %1132, %c0_i64_1821 : i64
%1134 = arith.select %1133, %c0_i64_1821, %1132 : i64
%1135 = arith.cmpi sgt, %1134, %1129 : i64
%1136 = arith.select %1135, %1129, %1134 : i64
%1137 = arith.index_cast %1136 : i64 to index
%1138 = arith.cmpi sge, %1137, %1128 : index
%1139 = arith.select %1138, %1137, %1128 : index
%c1_1822 = arith.constant 1 : index
%c0_1823 = arith.constant 0 : index
%c32_1824 = arith.constant 32 : index
%c1_1825 = arith.constant 1 : index
%c2048_1826 = arith.constant 2048 : index
%c2_1827 = arith.constant 2 : index
%c4_1828 = arith.constant 4 : index
%c3_1829 = arith.constant 3 : index
%c16_1830 = arith.constant 16 : index
%1140 = arith.subi %1139, %1128 : index
%1141 = arith.addi %1140, %c1_1822 : index
%1142 = arith.subi %1141, %c1_1809 : index
%1143 = arith.floordivsi %1142, %c1_1822 : index
%1144 = arith.muli %c1_1809, %c1_1822 : index
%extracted_slice_1831 = tensor.extract_slice %cast_33[%1128, %c0_1808, %c0_1808, %c0_1808] [%1143, %c2048_1826, %c4_1828, %c16_1830] [%1144, %c1_1809, %c1_1809, %c1_1809] : tensor<32x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_1832 = tensor.cast %extracted_slice_1831 : tensor<?x?x?x?xf32> to tensor<2x2048x4x16xf32>
%int1_1833 = torch.constant.int 1
%int0_1834 = torch.constant.int 0
%1145 = torch_c.to_i64 %int0_1834
%int8_1835 = torch.constant.int 8
%1146 = torch_c.to_i64 %int8_1835
%int1_1836 = torch.constant.int 1
%c0_1837 = arith.constant 0 : index
%c1_1838 = arith.constant 1 : index
%c0_1839 = arith.constant 0 : index
%c2_1840 = arith.constant 2 : index
%c1_1841 = arith.constant 1 : index
%c2048_1842 = arith.constant 2048 : index
%c2_1843 = arith.constant 2 : index
%c4_1844 = arith.constant 4 : index
%c3_1845 = arith.constant 3 : index
%c16_1846 = arith.constant 16 : index
%1147 = arith.index_cast %c2048_1842 : index to i64
%1148 = arith.addi %1145, %1147 : i64
%c0_i64_1847 = arith.constant 0 : i64
%1149 = arith.cmpi sge, %1145, %c0_i64_1847 : i64
%1150 = arith.select %1149, %1145, %1148 : i64
%c0_i64_1848 = arith.constant 0 : i64
%1151 = arith.cmpi slt, %1150, %c0_i64_1848 : i64
%1152 = arith.select %1151, %c0_i64_1848, %1150 : i64
%1153 = arith.cmpi sgt, %1152, %1147 : i64
%1154 = arith.select %1153, %1147, %1152 : i64
%1155 = arith.index_cast %1154 : i64 to index
%1156 = arith.index_cast %c2048_1842 : index to i64
%1157 = arith.addi %1146, %1156 : i64
%c0_i64_1849 = arith.constant 0 : i64
%1158 = arith.cmpi sge, %1146, %c0_i64_1849 : i64
%1159 = arith.select %1158, %1146, %1157 : i64
%c0_i64_1850 = arith.constant 0 : i64
%1160 = arith.cmpi slt, %1159, %c0_i64_1850 : i64
%1161 = arith.select %1160, %c0_i64_1850, %1159 : i64
%1162 = arith.cmpi sgt, %1161, %1156 : i64
%1163 = arith.select %1162, %1156, %1161 : i64
%1164 = arith.index_cast %1163 : i64 to index
%1165 = arith.cmpi sge, %1164, %1155 : index
%1166 = arith.select %1165, %1164, %1155 : index
%c1_1851 = arith.constant 1 : index
%c0_1852 = arith.constant 0 : index
%c2_1853 = arith.constant 2 : index
%c1_1854 = arith.constant 1 : index
%c2048_1855 = arith.constant 2048 : index
%c2_1856 = arith.constant 2 : index
%c4_1857 = arith.constant 4 : index
%c3_1858 = arith.constant 3 : index
%c16_1859 = arith.constant 16 : index
%1167 = arith.subi %1166, %1155 : index
%1168 = arith.addi %1167, %c1_1851 : index
%1169 = arith.subi %1168, %c1_1838 : index
%1170 = arith.floordivsi %1169, %c1_1851 : index
%1171 = arith.muli %c1_1838, %c1_1851 : index
%cast_1860 = tensor.cast %cast_1803 : tensor<2x8x4x16xf32> to tensor<?x?x?x?xf32>
%inserted_slice_1861 = tensor.insert_slice %cast_1860 into %cast_1832[%c0_1837, %1155, %c0_1837, %c0_1837] [%c2_1853, %1170, %c4_1857, %c16_1859] [%c1_1838, %1171, %c1_1838, %c1_1838] : tensor<?x?x?x?xf32> into tensor<2x2048x4x16xf32>
%cast_1862 = tensor.cast %inserted_slice_1861 : tensor<2x2048x4x16xf32> to tensor<2x2048x4x16xf32>
%int0_1863 = torch.constant.int 0
%int0_1864 = torch.constant.int 0
%1172 = torch_c.to_i64 %int0_1864
%int2_1865 = torch.constant.int 2
%1173 = torch_c.to_i64 %int2_1865
%int1_1866 = torch.constant.int 1
%c0_1867 = arith.constant 0 : index
%c1_1868 = arith.constant 1 : index
%c0_1869 = arith.constant 0 : index
%c32_1870 = arith.constant 32 : index
%c1_1871 = arith.constant 1 : index
%c2048_1872 = arith.constant 2048 : index
%c2_1873 = arith.constant 2 : index
%c4_1874 = arith.constant 4 : index
%c3_1875 = arith.constant 3 : index
%c16_1876 = arith.constant 16 : index
%1174 = arith.index_cast %c32_1870 : index to i64
%1175 = arith.addi %1172, %1174 : i64
%c0_i64_1877 = arith.constant 0 : i64
%1176 = arith.cmpi sge, %1172, %c0_i64_1877 : i64
%1177 = arith.select %1176, %1172, %1175 : i64
%c0_i64_1878 = arith.constant 0 : i64
%1178 = arith.cmpi slt, %1177, %c0_i64_1878 : i64
%1179 = arith.select %1178, %c0_i64_1878, %1177 : i64
%1180 = arith.cmpi sgt, %1179, %1174 : i64
%1181 = arith.select %1180, %1174, %1179 : i64
%1182 = arith.index_cast %1181 : i64 to index
%1183 = arith.index_cast %c32_1870 : index to i64
%1184 = arith.addi %1173, %1183 : i64
%c0_i64_1879 = arith.constant 0 : i64
%1185 = arith.cmpi sge, %1173, %c0_i64_1879 : i64
%1186 = arith.select %1185, %1173, %1184 : i64
%c0_i64_1880 = arith.constant 0 : i64
%1187 = arith.cmpi slt, %1186, %c0_i64_1880 : i64
%1188 = arith.select %1187, %c0_i64_1880, %1186 : i64
%1189 = arith.cmpi sgt, %1188, %1183 : i64
%1190 = arith.select %1189, %1183, %1188 : i64
%1191 = arith.index_cast %1190 : i64 to index
%1192 = arith.cmpi sge, %1191, %1182 : index
%1193 = arith.select %1192, %1191, %1182 : index
%c1_1881 = arith.constant 1 : index
%c0_1882 = arith.constant 0 : index
%c32_1883 = arith.constant 32 : index
%c1_1884 = arith.constant 1 : index
%c2048_1885 = arith.constant 2048 : index
%c2_1886 = arith.constant 2 : index
%c4_1887 = arith.constant 4 : index
%c3_1888 = arith.constant 3 : index
%c16_1889 = arith.constant 16 : index
%1194 = arith.subi %1193, %1182 : index
%1195 = arith.addi %1194, %c1_1881 : index
%1196 = arith.subi %1195, %c1_1868 : index
%1197 = arith.floordivsi %1196, %c1_1881 : index
%1198 = arith.muli %c1_1868, %c1_1881 : index
%cast_1890 = tensor.cast %cast_1862 : tensor<2x2048x4x16xf32> to tensor<?x?x?x?xf32>
%inserted_slice_1891 = tensor.insert_slice %cast_1890 into %cast_33[%1182, %c0_1867, %c0_1867, %c0_1867] [%1197, %c2048_1885, %c4_1887, %c16_1889] [%1198, %c1_1868, %c1_1868, %c1_1868] : tensor<?x?x?x?xf32> into tensor<32x2048x4x16xf32>
%cast_1892 = tensor.cast %inserted_slice_1891 : tensor<32x2048x4x16xf32> to tensor<32x2048x4x16xf32>
%1199 = torch_c.from_builtin_tensor %cast_1892 : tensor<32x2048x4x16xf32> -> !torch.vtensor<[32,2048,4,16],f32>
%int1_1893 = torch.constant.int 1
%int2_1894 = torch.constant.int 2
%c0_1895 = arith.constant 0 : index
%c2_1896 = arith.constant 2 : index
%c1_1897 = arith.constant 1 : index
%c8_1898 = arith.constant 8 : index
%c2_1899 = arith.constant 2 : index
%c4_1900 = arith.constant 4 : index
%c3_1901 = arith.constant 3 : index
%c16_1902 = arith.constant 16 : index
%1200 = tensor.empty() : tensor<2x4x8x16xf32>
%1201 = linalg.generic {indexing_maps = [#map, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_1538 : tensor<2x8x4x16xf32>) outs(%1200 : tensor<2x4x8x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x8x16xf32>
%cast_1903 = tensor.cast %1201 : tensor<2x4x8x16xf32> to tensor<2x4x8x16xf32>
%1202 = torch_c.from_builtin_tensor %cast_1903 : tensor<2x4x8x16xf32> -> !torch.vtensor<[2,4,8,16],f32>
%int2_1904 = torch.constant.int 2
%int4_1905 = torch.constant.int 4
%int8_1906 = torch.constant.int 8
%int16_1907 = torch.constant.int 16
%1203 = torch.prim.ListConstruct %int2_1904, %int4_1905, %int8_1906, %int16_1907 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_1908 = torch.constant.bool false
%1204 = torch.aten.expand %1202, %1203, %false_1908 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>
%1205 = torch_c.to_builtin_tensor %1204 : !torch.vtensor<[2,4,8,16],f32> -> tensor<2x4x8x16xf32>
%int0_1909 = torch.constant.int 0
%c1_1910 = arith.constant 1 : index
%c0_1911 = arith.constant 0 : index
%c2_1912 = arith.constant 2 : index
%c1_1913 = arith.constant 1 : index
%c4_1914 = arith.constant 4 : index
%c2_1915 = arith.constant 2 : index
%c8_1916 = arith.constant 8 : index
%c3_1917 = arith.constant 3 : index
%c16_1918 = arith.constant 16 : index
%1206 = tensor.empty() : tensor<2x4x8x16xf32>
%1207 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1205 : tensor<2x4x8x16xf32>) outs(%1206 : tensor<2x4x8x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x8x16xf32>
%cast_1919 = tensor.cast %1207 : tensor<2x4x8x16xf32> to tensor<2x4x8x16xf32>
%1208 = torch_c.from_builtin_tensor %cast_1919 : tensor<2x4x8x16xf32> -> !torch.vtensor<[2,4,8,16],f32>
%int8_1920 = torch.constant.int 8
%int8_1921 = torch.constant.int 8
%int16_1922 = torch.constant.int 16
%1209 = torch.prim.ListConstruct %int8_1920, %int8_1921, %int16_1922 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1210 = torch.aten._unsafe_view %1208, %1209 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int> -> !torch.vtensor<[8,8,16],f32>
%1211 = torch_c.to_builtin_tensor %1210 : !torch.vtensor<[8,8,16],f32> -> tensor<8x8x16xf32>
%int0_1923 = torch.constant.int 0
%int0_1924 = torch.constant.int 0
%1212 = torch_c.to_i64 %int0_1924
%int2_1925 = torch.constant.int 2
%1213 = torch_c.to_i64 %int2_1925
%int1_1926 = torch.constant.int 1
%c0_1927 = arith.constant 0 : index
%c1_1928 = arith.constant 1 : index
%c0_1929 = arith.constant 0 : index
%c32_1930 = arith.constant 32 : index
%c1_1931 = arith.constant 1 : index
%c2048_1932 = arith.constant 2048 : index
%c2_1933 = arith.constant 2 : index
%c4_1934 = arith.constant 4 : index
%c3_1935 = arith.constant 3 : index
%c16_1936 = arith.constant 16 : index
%1214 = arith.index_cast %c32_1930 : index to i64
%1215 = arith.addi %1212, %1214 : i64
%c0_i64_1937 = arith.constant 0 : i64
%1216 = arith.cmpi sge, %1212, %c0_i64_1937 : i64
%1217 = arith.select %1216, %1212, %1215 : i64
%c0_i64_1938 = arith.constant 0 : i64
%1218 = arith.cmpi slt, %1217, %c0_i64_1938 : i64
%1219 = arith.select %1218, %c0_i64_1938, %1217 : i64
%1220 = arith.cmpi sgt, %1219, %1214 : i64
%1221 = arith.select %1220, %1214, %1219 : i64
%1222 = arith.index_cast %1221 : i64 to index
%1223 = arith.index_cast %c32_1930 : index to i64
%1224 = arith.addi %1213, %1223 : i64
%c0_i64_1939 = arith.constant 0 : i64
%1225 = arith.cmpi sge, %1213, %c0_i64_1939 : i64
%1226 = arith.select %1225, %1213, %1224 : i64
%c0_i64_1940 = arith.constant 0 : i64
%1227 = arith.cmpi slt, %1226, %c0_i64_1940 : i64
%1228 = arith.select %1227, %c0_i64_1940, %1226 : i64
%1229 = arith.cmpi sgt, %1228, %1223 : i64
%1230 = arith.select %1229, %1223, %1228 : i64
%1231 = arith.index_cast %1230 : i64 to index
%1232 = arith.cmpi sge, %1231, %1222 : index
%1233 = arith.select %1232, %1231, %1222 : index
%c1_1941 = arith.constant 1 : index
%c0_1942 = arith.constant 0 : index
%c32_1943 = arith.constant 32 : index
%c1_1944 = arith.constant 1 : index
%c2048_1945 = arith.constant 2048 : index
%c2_1946 = arith.constant 2 : index
%c4_1947 = arith.constant 4 : index
%c3_1948 = arith.constant 3 : index
%c16_1949 = arith.constant 16 : index
%1234 = arith.subi %1233, %1222 : index
%1235 = arith.addi %1234, %c1_1941 : index
%1236 = arith.subi %1235, %c1_1928 : index
%1237 = arith.floordivsi %1236, %c1_1941 : index
%1238 = arith.muli %c1_1928, %c1_1941 : index
%extracted_slice_1950 = tensor.extract_slice %cast_1731[%1222, %c0_1927, %c0_1927, %c0_1927] [%1237, %c2048_1945, %c4_1947, %c16_1949] [%1238, %c1_1928, %c1_1928, %c1_1928] : tensor<32x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_1951 = tensor.cast %extracted_slice_1950 : tensor<?x?x?x?xf32> to tensor<2x2048x4x16xf32>
%int1_1952 = torch.constant.int 1
%int0_1953 = torch.constant.int 0
%1239 = torch_c.to_i64 %int0_1953
%int8_1954 = torch.constant.int 8
%1240 = torch_c.to_i64 %int8_1954
%int1_1955 = torch.constant.int 1
%c0_1956 = arith.constant 0 : index
%c1_1957 = arith.constant 1 : index
%c0_1958 = arith.constant 0 : index
%c2_1959 = arith.constant 2 : index
%c1_1960 = arith.constant 1 : index
%c2048_1961 = arith.constant 2048 : index
%c2_1962 = arith.constant 2 : index
%c4_1963 = arith.constant 4 : index
%c3_1964 = arith.constant 3 : index
%c16_1965 = arith.constant 16 : index
%1241 = arith.index_cast %c2048_1961 : index to i64
%1242 = arith.addi %1239, %1241 : i64
%c0_i64_1966 = arith.constant 0 : i64
%1243 = arith.cmpi sge, %1239, %c0_i64_1966 : i64
%1244 = arith.select %1243, %1239, %1242 : i64
%c0_i64_1967 = arith.constant 0 : i64
%1245 = arith.cmpi slt, %1244, %c0_i64_1967 : i64
%1246 = arith.select %1245, %c0_i64_1967, %1244 : i64
%1247 = arith.cmpi sgt, %1246, %1241 : i64
%1248 = arith.select %1247, %1241, %1246 : i64
%1249 = arith.index_cast %1248 : i64 to index
%1250 = arith.index_cast %c2048_1961 : index to i64
%1251 = arith.addi %1240, %1250 : i64
%c0_i64_1968 = arith.constant 0 : i64
%1252 = arith.cmpi sge, %1240, %c0_i64_1968 : i64
%1253 = arith.select %1252, %1240, %1251 : i64
%c0_i64_1969 = arith.constant 0 : i64
%1254 = arith.cmpi slt, %1253, %c0_i64_1969 : i64
%1255 = arith.select %1254, %c0_i64_1969, %1253 : i64
%1256 = arith.cmpi sgt, %1255, %1250 : i64
%1257 = arith.select %1256, %1250, %1255 : i64
%1258 = arith.index_cast %1257 : i64 to index
%1259 = arith.cmpi sge, %1258, %1249 : index
%1260 = arith.select %1259, %1258, %1249 : index
%c1_1970 = arith.constant 1 : index
%c0_1971 = arith.constant 0 : index
%c2_1972 = arith.constant 2 : index
%c1_1973 = arith.constant 1 : index
%c2048_1974 = arith.constant 2048 : index
%c2_1975 = arith.constant 2 : index
%c4_1976 = arith.constant 4 : index
%c3_1977 = arith.constant 3 : index
%c16_1978 = arith.constant 16 : index
%1261 = arith.subi %1260, %1249 : index
%1262 = arith.addi %1261, %c1_1970 : index
%1263 = arith.subi %1262, %c1_1957 : index
%1264 = arith.floordivsi %1263, %c1_1970 : index
%1265 = arith.muli %c1_1957, %c1_1970 : index
%extracted_slice_1979 = tensor.extract_slice %cast_1951[%c0_1956, %1249, %c0_1956, %c0_1956] [%c2_1972, %1264, %c4_1976, %c16_1978] [%c1_1957, %1265, %c1_1957, %c1_1957] : tensor<2x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_1980 = tensor.cast %extracted_slice_1979 : tensor<?x?x?x?xf32> to tensor<2x8x4x16xf32>
%int1_1981 = torch.constant.int 1
%int2_1982 = torch.constant.int 2
%c0_1983 = arith.constant 0 : index
%c2_1984 = arith.constant 2 : index
%c1_1985 = arith.constant 1 : index
%c8_1986 = arith.constant 8 : index
%c2_1987 = arith.constant 2 : index
%c4_1988 = arith.constant 4 : index
%c3_1989 = arith.constant 3 : index
%c16_1990 = arith.constant 16 : index
%1266 = tensor.empty() : tensor<2x4x8x16xf32>
%1267 = linalg.generic {indexing_maps = [#map, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_1980 : tensor<2x8x4x16xf32>) outs(%1266 : tensor<2x4x8x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x8x16xf32>
%cast_1991 = tensor.cast %1267 : tensor<2x4x8x16xf32> to tensor<2x4x8x16xf32>
%int2_1992 = torch.constant.int 2
%int3_1993 = torch.constant.int 3
%c0_1994 = arith.constant 0 : index
%c2_1995 = arith.constant 2 : index
%c1_1996 = arith.constant 1 : index
%c4_1997 = arith.constant 4 : index
%c2_1998 = arith.constant 2 : index
%c8_1999 = arith.constant 8 : index
%c3_2000 = arith.constant 3 : index
%c16_2001 = arith.constant 16 : index
%1268 = tensor.empty() : tensor<2x4x16x8xf32>
%1269 = linalg.generic {indexing_maps = [#map, #map12], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_1991 : tensor<2x4x8x16xf32>) outs(%1268 : tensor<2x4x16x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x16x8xf32>
%cast_2002 = tensor.cast %1269 : tensor<2x4x16x8xf32> to tensor<2x4x16x8xf32>
%1270 = torch_c.from_builtin_tensor %cast_2002 : tensor<2x4x16x8xf32> -> !torch.vtensor<[2,4,16,8],f32>
%int2_2003 = torch.constant.int 2
%int4_2004 = torch.constant.int 4
%int16_2005 = torch.constant.int 16
%int8_2006 = torch.constant.int 8
%1271 = torch.prim.ListConstruct %int2_2003, %int4_2004, %int16_2005, %int8_2006 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_2007 = torch.constant.bool false
%1272 = torch.aten.expand %1270, %1271, %false_2007 : !torch.vtensor<[2,4,16,8],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,16,8],f32>
%1273 = torch_c.to_builtin_tensor %1272 : !torch.vtensor<[2,4,16,8],f32> -> tensor<2x4x16x8xf32>
%int0_2008 = torch.constant.int 0
%c1_2009 = arith.constant 1 : index
%c0_2010 = arith.constant 0 : index
%c2_2011 = arith.constant 2 : index
%c1_2012 = arith.constant 1 : index
%c4_2013 = arith.constant 4 : index
%c2_2014 = arith.constant 2 : index
%c16_2015 = arith.constant 16 : index
%c3_2016 = arith.constant 3 : index
%c8_2017 = arith.constant 8 : index
%1274 = tensor.empty() : tensor<2x4x16x8xf32>
%1275 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1273 : tensor<2x4x16x8xf32>) outs(%1274 : tensor<2x4x16x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x16x8xf32>
%cast_2018 = tensor.cast %1275 : tensor<2x4x16x8xf32> to tensor<2x4x16x8xf32>
%1276 = torch_c.from_builtin_tensor %cast_2018 : tensor<2x4x16x8xf32> -> !torch.vtensor<[2,4,16,8],f32>
%int8_2019 = torch.constant.int 8
%int16_2020 = torch.constant.int 16
%int8_2021 = torch.constant.int 8
%1277 = torch.prim.ListConstruct %int8_2019, %int16_2020, %int8_2021 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1278 = torch.aten._unsafe_view %1276, %1277 : !torch.vtensor<[2,4,16,8],f32>, !torch.list<int> -> !torch.vtensor<[8,16,8],f32>
%1279 = torch_c.to_builtin_tensor %1278 : !torch.vtensor<[8,16,8],f32> -> tensor<8x16x8xf32>
%c0_2022 = arith.constant 0 : index
%c8_2023 = arith.constant 8 : index
%c1_2024 = arith.constant 1 : index
%c8_2025 = arith.constant 8 : index
%c2_2026 = arith.constant 2 : index
%c16_2027 = arith.constant 16 : index
%c0_2028 = arith.constant 0 : index
%c8_2029 = arith.constant 8 : index
%c1_2030 = arith.constant 1 : index
%c16_2031 = arith.constant 16 : index
%c2_2032 = arith.constant 2 : index
%c8_2033 = arith.constant 8 : index
%1280 = arith.index_cast %c8_2023 : index to i64
%1281 = arith.index_cast %c8_2029 : index to i64
%1282 = arith.cmpi eq, %1280, %1281 : i64
cf.assert %1282, "mismatching contracting dimension"
%1283 = arith.index_cast %c16_2027 : index to i64
%1284 = arith.index_cast %c16_2031 : index to i64
%1285 = arith.cmpi eq, %1283, %1284 : i64
cf.assert %1285, "mismatching contracting dimension"
%1286 = tensor.empty() : tensor<8x8x8xf32>
%cst_2034 = arith.constant 0.000000e+00 : f32
%1287 = linalg.fill ins(%cst_2034 : f32) outs(%1286 : tensor<8x8x8xf32>) -> tensor<8x8x8xf32>
%1288 = linalg.batch_matmul ins(%1211, %1279 : tensor<8x8x16xf32>, tensor<8x16x8xf32>) outs(%1287 : tensor<8x8x8xf32>) -> tensor<8x8x8xf32>
%cast_2035 = tensor.cast %1288 : tensor<8x8x8xf32> to tensor<8x8x8xf32>
%int2_2036 = torch.constant.int 2
%int4_2037 = torch.constant.int 4
%int8_2038 = torch.constant.int 8
%int8_2039 = torch.constant.int 8
%1289 = torch.prim.ListConstruct %int2_2036, %int4_2037, %int8_2038, %int8_2039 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_2040 = arith.constant 0 : index
%c8_2041 = arith.constant 8 : index
%c1_2042 = arith.constant 1 : index
%c8_2043 = arith.constant 8 : index
%c2_2044 = arith.constant 2 : index
%c8_2045 = arith.constant 8 : index
%1290 = torch_c.to_i64 %int2_2036
%1291 = torch_c.to_i64 %int4_2037
%1292 = torch_c.to_i64 %int8_2038
%1293 = torch_c.to_i64 %int8_2039
%expanded_2046 = tensor.expand_shape %cast_2035 [[0, 1], [2], [3]] : tensor<8x8x8xf32> into tensor<2x4x8x8xf32>
%float4.000000e00_2047 = torch.constant.float 4.000000e+00
%1294 = torch_c.to_f64 %float4.000000e00_2047
%c1_2048 = arith.constant 1 : index
%c0_2049 = arith.constant 0 : index
%c2_2050 = arith.constant 2 : index
%c1_2051 = arith.constant 1 : index
%c4_2052 = arith.constant 4 : index
%c2_2053 = arith.constant 2 : index
%c8_2054 = arith.constant 8 : index
%c3_2055 = arith.constant 3 : index
%c8_2056 = arith.constant 8 : index
%1295 = tensor.empty() : tensor<2x4x8x8xf32>
%1296 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_2046 : tensor<2x4x8x8xf32>) outs(%1295 : tensor<2x4x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = arith.truncf %1294 : f64 to f32
%1544 = arith.divf %in, %1543 : f32
linalg.yield %1544 : f32
} -> tensor<2x4x8x8xf32>
%cast_2057 = tensor.cast %1296 : tensor<2x4x8x8xf32> to tensor<2x4x8x8xf32>
%int1_2058 = torch.constant.int 1
%1297 = torch_c.to_i64 %int1_2058
%c1_2059 = arith.constant 1 : index
%c0_2060 = arith.constant 0 : index
%c2_2061 = arith.constant 2 : index
%c1_2062 = arith.constant 1 : index
%c4_2063 = arith.constant 4 : index
%c2_2064 = arith.constant 2 : index
%c8_2065 = arith.constant 8 : index
%c3_2066 = arith.constant 3 : index
%c8_2067 = arith.constant 8 : index
%c2_2068 = arith.constant 2 : index
%c8_2069 = arith.constant 8 : index
%1298 = arith.cmpi eq, %c8_2065, %c8_2069 : index
cf.assert %1298, "mismatched size for broadcast"
%c3_2070 = arith.constant 3 : index
%c8_2071 = arith.constant 8 : index
%1299 = arith.cmpi eq, %c8_2067, %c8_2071 : index
cf.assert %1299, "mismatched size for broadcast"
%1300 = tensor.empty() : tensor<2x4x8x8xf32>
%1301 = linalg.generic {indexing_maps = [#map, #map3, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_2057, %cast_76 : tensor<2x4x8x8xf32>, tensor<1x1x8x8xf32>) outs(%1300 : tensor<2x4x8x8xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.sitofp %1297 : i64 to f32
%1544 = arith.mulf %in_2549, %1543 : f32
%1545 = arith.addf %in, %1544 : f32
linalg.yield %1545 : f32
} -> tensor<2x4x8x8xf32>
%cast_2072 = tensor.cast %1301 : tensor<2x4x8x8xf32> to tensor<2x4x8x8xf32>
%1302 = torch_c.from_builtin_tensor %cast_2072 : tensor<2x4x8x8xf32> -> !torch.vtensor<[2,4,8,8],f32>
%int-1_2073 = torch.constant.int -1
%false_2074 = torch.constant.bool false
%1303 = torch.aten._softmax %1302, %int-1_2073, %false_2074 : !torch.vtensor<[2,4,8,8],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,4,8,8],f32>
%int2_2075 = torch.constant.int 2
%int4_2076 = torch.constant.int 4
%int8_2077 = torch.constant.int 8
%int8_2078 = torch.constant.int 8
%1304 = torch.prim.ListConstruct %int2_2075, %int4_2076, %int8_2077, %int8_2078 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_2079 = torch.constant.bool false
%1305 = torch.aten.expand %1303, %1304, %false_2079 : !torch.vtensor<[2,4,8,8],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,8],f32>
%1306 = torch_c.to_builtin_tensor %1305 : !torch.vtensor<[2,4,8,8],f32> -> tensor<2x4x8x8xf32>
%int8_2080 = torch.constant.int 8
%int8_2081 = torch.constant.int 8
%int8_2082 = torch.constant.int 8
%1307 = torch.prim.ListConstruct %int8_2080, %int8_2081, %int8_2082 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_2083 = arith.constant 0 : index
%c2_2084 = arith.constant 2 : index
%c1_2085 = arith.constant 1 : index
%c4_2086 = arith.constant 4 : index
%c2_2087 = arith.constant 2 : index
%c8_2088 = arith.constant 8 : index
%c3_2089 = arith.constant 3 : index
%c8_2090 = arith.constant 8 : index
%1308 = torch_c.to_i64 %int8_2080
%1309 = torch_c.to_i64 %int8_2081
%1310 = torch_c.to_i64 %int8_2082
%collapsed_2091 = tensor.collapse_shape %1306 [[0, 1], [2], [3]] : tensor<2x4x8x8xf32> into tensor<8x8x8xf32>
%1311 = torch_c.from_builtin_tensor %collapsed_2091 : tensor<8x8x8xf32> -> !torch.vtensor<[8,8,8],f32>
%int0_2092 = torch.constant.int 0
%int0_2093 = torch.constant.int 0
%1312 = torch_c.to_i64 %int0_2093
%int2_2094 = torch.constant.int 2
%1313 = torch_c.to_i64 %int2_2094
%int1_2095 = torch.constant.int 1
%c0_2096 = arith.constant 0 : index
%c1_2097 = arith.constant 1 : index
%c0_2098 = arith.constant 0 : index
%c32_2099 = arith.constant 32 : index
%c1_2100 = arith.constant 1 : index
%c2048_2101 = arith.constant 2048 : index
%c2_2102 = arith.constant 2 : index
%c4_2103 = arith.constant 4 : index
%c3_2104 = arith.constant 3 : index
%c16_2105 = arith.constant 16 : index
%1314 = arith.index_cast %c32_2099 : index to i64
%1315 = arith.addi %1312, %1314 : i64
%c0_i64_2106 = arith.constant 0 : i64
%1316 = arith.cmpi sge, %1312, %c0_i64_2106 : i64
%1317 = arith.select %1316, %1312, %1315 : i64
%c0_i64_2107 = arith.constant 0 : i64
%1318 = arith.cmpi slt, %1317, %c0_i64_2107 : i64
%1319 = arith.select %1318, %c0_i64_2107, %1317 : i64
%1320 = arith.cmpi sgt, %1319, %1314 : i64
%1321 = arith.select %1320, %1314, %1319 : i64
%1322 = arith.index_cast %1321 : i64 to index
%1323 = arith.index_cast %c32_2099 : index to i64
%1324 = arith.addi %1313, %1323 : i64
%c0_i64_2108 = arith.constant 0 : i64
%1325 = arith.cmpi sge, %1313, %c0_i64_2108 : i64
%1326 = arith.select %1325, %1313, %1324 : i64
%c0_i64_2109 = arith.constant 0 : i64
%1327 = arith.cmpi slt, %1326, %c0_i64_2109 : i64
%1328 = arith.select %1327, %c0_i64_2109, %1326 : i64
%1329 = arith.cmpi sgt, %1328, %1323 : i64
%1330 = arith.select %1329, %1323, %1328 : i64
%1331 = arith.index_cast %1330 : i64 to index
%1332 = arith.cmpi sge, %1331, %1322 : index
%1333 = arith.select %1332, %1331, %1322 : index
%c1_2110 = arith.constant 1 : index
%c0_2111 = arith.constant 0 : index
%c32_2112 = arith.constant 32 : index
%c1_2113 = arith.constant 1 : index
%c2048_2114 = arith.constant 2048 : index
%c2_2115 = arith.constant 2 : index
%c4_2116 = arith.constant 4 : index
%c3_2117 = arith.constant 3 : index
%c16_2118 = arith.constant 16 : index
%1334 = arith.subi %1333, %1322 : index
%1335 = arith.addi %1334, %c1_2110 : index
%1336 = arith.subi %1335, %c1_2097 : index
%1337 = arith.floordivsi %1336, %c1_2110 : index
%1338 = arith.muli %c1_2097, %c1_2110 : index
%extracted_slice_2119 = tensor.extract_slice %cast_1892[%1322, %c0_2096, %c0_2096, %c0_2096] [%1337, %c2048_2114, %c4_2116, %c16_2118] [%1338, %c1_2097, %c1_2097, %c1_2097] : tensor<32x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_2120 = tensor.cast %extracted_slice_2119 : tensor<?x?x?x?xf32> to tensor<2x2048x4x16xf32>
%int1_2121 = torch.constant.int 1
%int0_2122 = torch.constant.int 0
%1339 = torch_c.to_i64 %int0_2122
%int8_2123 = torch.constant.int 8
%1340 = torch_c.to_i64 %int8_2123
%int1_2124 = torch.constant.int 1
%c0_2125 = arith.constant 0 : index
%c1_2126 = arith.constant 1 : index
%c0_2127 = arith.constant 0 : index
%c2_2128 = arith.constant 2 : index
%c1_2129 = arith.constant 1 : index
%c2048_2130 = arith.constant 2048 : index
%c2_2131 = arith.constant 2 : index
%c4_2132 = arith.constant 4 : index
%c3_2133 = arith.constant 3 : index
%c16_2134 = arith.constant 16 : index
%1341 = arith.index_cast %c2048_2130 : index to i64
%1342 = arith.addi %1339, %1341 : i64
%c0_i64_2135 = arith.constant 0 : i64
%1343 = arith.cmpi sge, %1339, %c0_i64_2135 : i64
%1344 = arith.select %1343, %1339, %1342 : i64
%c0_i64_2136 = arith.constant 0 : i64
%1345 = arith.cmpi slt, %1344, %c0_i64_2136 : i64
%1346 = arith.select %1345, %c0_i64_2136, %1344 : i64
%1347 = arith.cmpi sgt, %1346, %1341 : i64
%1348 = arith.select %1347, %1341, %1346 : i64
%1349 = arith.index_cast %1348 : i64 to index
%1350 = arith.index_cast %c2048_2130 : index to i64
%1351 = arith.addi %1340, %1350 : i64
%c0_i64_2137 = arith.constant 0 : i64
%1352 = arith.cmpi sge, %1340, %c0_i64_2137 : i64
%1353 = arith.select %1352, %1340, %1351 : i64
%c0_i64_2138 = arith.constant 0 : i64
%1354 = arith.cmpi slt, %1353, %c0_i64_2138 : i64
%1355 = arith.select %1354, %c0_i64_2138, %1353 : i64
%1356 = arith.cmpi sgt, %1355, %1350 : i64
%1357 = arith.select %1356, %1350, %1355 : i64
%1358 = arith.index_cast %1357 : i64 to index
%1359 = arith.cmpi sge, %1358, %1349 : index
%1360 = arith.select %1359, %1358, %1349 : index
%c1_2139 = arith.constant 1 : index
%c0_2140 = arith.constant 0 : index
%c2_2141 = arith.constant 2 : index
%c1_2142 = arith.constant 1 : index
%c2048_2143 = arith.constant 2048 : index
%c2_2144 = arith.constant 2 : index
%c4_2145 = arith.constant 4 : index
%c3_2146 = arith.constant 3 : index
%c16_2147 = arith.constant 16 : index
%1361 = arith.subi %1360, %1349 : index
%1362 = arith.addi %1361, %c1_2139 : index
%1363 = arith.subi %1362, %c1_2126 : index
%1364 = arith.floordivsi %1363, %c1_2139 : index
%1365 = arith.muli %c1_2126, %c1_2139 : index
%extracted_slice_2148 = tensor.extract_slice %cast_2120[%c0_2125, %1349, %c0_2125, %c0_2125] [%c2_2141, %1364, %c4_2145, %c16_2147] [%c1_2126, %1365, %c1_2126, %c1_2126] : tensor<2x2048x4x16xf32> to tensor<?x?x?x?xf32>
%cast_2149 = tensor.cast %extracted_slice_2148 : tensor<?x?x?x?xf32> to tensor<2x8x4x16xf32>
%int1_2150 = torch.constant.int 1
%int2_2151 = torch.constant.int 2
%c0_2152 = arith.constant 0 : index
%c2_2153 = arith.constant 2 : index
%c1_2154 = arith.constant 1 : index
%c8_2155 = arith.constant 8 : index
%c2_2156 = arith.constant 2 : index
%c4_2157 = arith.constant 4 : index
%c3_2158 = arith.constant 3 : index
%c16_2159 = arith.constant 16 : index
%1366 = tensor.empty() : tensor<2x4x8x16xf32>
%1367 = linalg.generic {indexing_maps = [#map, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_2149 : tensor<2x8x4x16xf32>) outs(%1366 : tensor<2x4x8x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x8x16xf32>
%cast_2160 = tensor.cast %1367 : tensor<2x4x8x16xf32> to tensor<2x4x8x16xf32>
%1368 = torch_c.from_builtin_tensor %cast_2160 : tensor<2x4x8x16xf32> -> !torch.vtensor<[2,4,8,16],f32>
%int2_2161 = torch.constant.int 2
%int4_2162 = torch.constant.int 4
%int8_2163 = torch.constant.int 8
%int16_2164 = torch.constant.int 16
%1369 = torch.prim.ListConstruct %int2_2161, %int4_2162, %int8_2163, %int16_2164 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_2165 = torch.constant.bool false
%1370 = torch.aten.expand %1368, %1369, %false_2165 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>
%1371 = torch_c.to_builtin_tensor %1370 : !torch.vtensor<[2,4,8,16],f32> -> tensor<2x4x8x16xf32>
%int0_2166 = torch.constant.int 0
%c1_2167 = arith.constant 1 : index
%c0_2168 = arith.constant 0 : index
%c2_2169 = arith.constant 2 : index
%c1_2170 = arith.constant 1 : index
%c4_2171 = arith.constant 4 : index
%c2_2172 = arith.constant 2 : index
%c8_2173 = arith.constant 8 : index
%c3_2174 = arith.constant 3 : index
%c16_2175 = arith.constant 16 : index
%1372 = tensor.empty() : tensor<2x4x8x16xf32>
%1373 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1371 : tensor<2x4x8x16xf32>) outs(%1372 : tensor<2x4x8x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x4x8x16xf32>
%cast_2176 = tensor.cast %1373 : tensor<2x4x8x16xf32> to tensor<2x4x8x16xf32>
%1374 = torch_c.from_builtin_tensor %cast_2176 : tensor<2x4x8x16xf32> -> !torch.vtensor<[2,4,8,16],f32>
%int8_2177 = torch.constant.int 8
%int8_2178 = torch.constant.int 8
%int16_2179 = torch.constant.int 16
%1375 = torch.prim.ListConstruct %int8_2177, %int8_2178, %int16_2179 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1376 = torch.aten._unsafe_view %1374, %1375 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int> -> !torch.vtensor<[8,8,16],f32>
%1377 = torch_c.to_builtin_tensor %1376 : !torch.vtensor<[8,8,16],f32> -> tensor<8x8x16xf32>
%c0_2180 = arith.constant 0 : index
%c8_2181 = arith.constant 8 : index
%c1_2182 = arith.constant 1 : index
%c8_2183 = arith.constant 8 : index
%c2_2184 = arith.constant 2 : index
%c8_2185 = arith.constant 8 : index
%c0_2186 = arith.constant 0 : index
%c8_2187 = arith.constant 8 : index
%c1_2188 = arith.constant 1 : index
%c8_2189 = arith.constant 8 : index
%c2_2190 = arith.constant 2 : index
%c16_2191 = arith.constant 16 : index
%1378 = arith.index_cast %c8_2181 : index to i64
%1379 = arith.index_cast %c8_2187 : index to i64
%1380 = arith.cmpi eq, %1378, %1379 : i64
cf.assert %1380, "mismatching contracting dimension"
%1381 = arith.index_cast %c8_2185 : index to i64
%1382 = arith.index_cast %c8_2189 : index to i64
%1383 = arith.cmpi eq, %1381, %1382 : i64
cf.assert %1383, "mismatching contracting dimension"
%1384 = tensor.empty() : tensor<8x8x16xf32>
%cst_2192 = arith.constant 0.000000e+00 : f32
%1385 = linalg.fill ins(%cst_2192 : f32) outs(%1384 : tensor<8x8x16xf32>) -> tensor<8x8x16xf32>
%1386 = linalg.batch_matmul ins(%collapsed_2091, %1377 : tensor<8x8x8xf32>, tensor<8x8x16xf32>) outs(%1385 : tensor<8x8x16xf32>) -> tensor<8x8x16xf32>
%cast_2193 = tensor.cast %1386 : tensor<8x8x16xf32> to tensor<8x8x16xf32>
%int2_2194 = torch.constant.int 2
%int4_2195 = torch.constant.int 4
%int8_2196 = torch.constant.int 8
%int16_2197 = torch.constant.int 16
%1387 = torch.prim.ListConstruct %int2_2194, %int4_2195, %int8_2196, %int16_2197 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_2198 = arith.constant 0 : index
%c8_2199 = arith.constant 8 : index
%c1_2200 = arith.constant 1 : index
%c8_2201 = arith.constant 8 : index
%c2_2202 = arith.constant 2 : index
%c16_2203 = arith.constant 16 : index
%1388 = torch_c.to_i64 %int2_2194
%1389 = torch_c.to_i64 %int4_2195
%1390 = torch_c.to_i64 %int8_2196
%1391 = torch_c.to_i64 %int16_2197
%expanded_2204 = tensor.expand_shape %cast_2193 [[0, 1], [2], [3]] : tensor<8x8x16xf32> into tensor<2x4x8x16xf32>
%int1_2205 = torch.constant.int 1
%int2_2206 = torch.constant.int 2
%c0_2207 = arith.constant 0 : index
%c2_2208 = arith.constant 2 : index
%c1_2209 = arith.constant 1 : index
%c4_2210 = arith.constant 4 : index
%c2_2211 = arith.constant 2 : index
%c8_2212 = arith.constant 8 : index
%c3_2213 = arith.constant 3 : index
%c16_2214 = arith.constant 16 : index
%1392 = tensor.empty() : tensor<2x8x4x16xf32>
%1393 = linalg.generic {indexing_maps = [#map, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_2204 : tensor<2x4x8x16xf32>) outs(%1392 : tensor<2x8x4x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x8x4x16xf32>
%cast_2215 = tensor.cast %1393 : tensor<2x8x4x16xf32> to tensor<2x8x4x16xf32>
%int0_2216 = torch.constant.int 0
%c1_2217 = arith.constant 1 : index
%c0_2218 = arith.constant 0 : index
%c2_2219 = arith.constant 2 : index
%c1_2220 = arith.constant 1 : index
%c8_2221 = arith.constant 8 : index
%c2_2222 = arith.constant 2 : index
%c4_2223 = arith.constant 4 : index
%c3_2224 = arith.constant 3 : index
%c16_2225 = arith.constant 16 : index
%1394 = tensor.empty() : tensor<2x8x4x16xf32>
%1395 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast_2215 : tensor<2x8x4x16xf32>) outs(%1394 : tensor<2x8x4x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x8x4x16xf32>
%cast_2226 = tensor.cast %1395 : tensor<2x8x4x16xf32> to tensor<2x8x4x16xf32>
%int2_2227 = torch.constant.int 2
%int8_2228 = torch.constant.int 8
%int-1_2229 = torch.constant.int -1
%1396 = torch.prim.ListConstruct %int2_2227, %int8_2228, %int-1_2229 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_2230 = arith.constant 0 : index
%c2_2231 = arith.constant 2 : index
%c1_2232 = arith.constant 1 : index
%c8_2233 = arith.constant 8 : index
%c2_2234 = arith.constant 2 : index
%c4_2235 = arith.constant 4 : index
%c3_2236 = arith.constant 3 : index
%c16_2237 = arith.constant 16 : index
%1397 = torch_c.to_i64 %int2_2227
%1398 = torch_c.to_i64 %int8_2228
%1399 = torch_c.to_i64 %int-1_2229
%collapsed_2238 = tensor.collapse_shape %cast_2226 [[0], [1], [2, 3]] : tensor<2x8x4x16xf32> into tensor<2x8x64xf32>
%int0_2239 = torch.constant.int 0
%int1_2240 = torch.constant.int 1
%c0_2241 = arith.constant 0 : index
%c64_2242 = arith.constant 64 : index
%c1_2243 = arith.constant 1 : index
%c64_2244 = arith.constant 64 : index
%1400 = tensor.empty() : tensor<64x64xf32>
%1401 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%20 : tensor<64x64xf32>) outs(%1400 : tensor<64x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x64xf32>
%cast_2245 = tensor.cast %1401 : tensor<64x64xf32> to tensor<64x64xf32>
%1402 = torch_c.from_builtin_tensor %cast_2245 : tensor<64x64xf32> -> !torch.vtensor<[64,64],f32>
%int16_2246 = torch.constant.int 16
%int64_2247 = torch.constant.int 64
%1403 = torch.prim.ListConstruct %int16_2246, %int64_2247 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_2248 = arith.constant 0 : index
%c2_2249 = arith.constant 2 : index
%c1_2250 = arith.constant 1 : index
%c8_2251 = arith.constant 8 : index
%c2_2252 = arith.constant 2 : index
%c64_2253 = arith.constant 64 : index
%1404 = torch_c.to_i64 %int16_2246
%1405 = torch_c.to_i64 %int64_2247
%collapsed_2254 = tensor.collapse_shape %collapsed_2238 [[0, 1], [2]] : tensor<2x8x64xf32> into tensor<16x64xf32>
%1406 = torch_c.from_builtin_tensor %collapsed_2254 : tensor<16x64xf32> -> !torch.vtensor<[16,64],f32>
%c0_2255 = arith.constant 0 : index
%dim_2256 = tensor.dim %collapsed_2254, %c0_2255 : tensor<16x64xf32>
%c1_2257 = arith.constant 1 : index
%dim_2258 = tensor.dim %cast_2245, %c1_2257 : tensor<64x64xf32>
%c1_2259 = arith.constant 1 : index
%dim_2260 = tensor.dim %collapsed_2254, %c1_2259 : tensor<16x64xf32>
%c0_2261 = arith.constant 0 : index
%dim_2262 = tensor.dim %cast_2245, %c0_2261 : tensor<64x64xf32>
%1407 = arith.cmpi eq, %dim_2260, %dim_2262 : index
cf.assert %1407, "mismatching contracting dimension for torch.aten.mm"
%1408 = tensor.empty(%dim_2256, %dim_2258) : tensor<?x?xf32>
%cst_2263 = arith.constant 0.000000e+00 : f32
%1409 = linalg.fill ins(%cst_2263 : f32) outs(%1408 : tensor<?x?xf32>) -> tensor<?x?xf32>
%1410 = linalg.matmul ins(%collapsed_2254, %cast_2245 : tensor<16x64xf32>, tensor<64x64xf32>) outs(%1409 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_2264 = tensor.cast %1410 : tensor<?x?xf32> to tensor<16x64xf32>
%int2_2265 = torch.constant.int 2
%int8_2266 = torch.constant.int 8
%int64_2267 = torch.constant.int 64
%1411 = torch.prim.ListConstruct %int2_2265, %int8_2266, %int64_2267 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_2268 = arith.constant 0 : index
%c16_2269 = arith.constant 16 : index
%c1_2270 = arith.constant 1 : index
%c64_2271 = arith.constant 64 : index
%1412 = torch_c.to_i64 %int2_2265
%1413 = torch_c.to_i64 %int8_2266
%1414 = torch_c.to_i64 %int64_2267
%expanded_2272 = tensor.expand_shape %cast_2264 [[0, 1], [2]] : tensor<16x64xf32> into tensor<2x8x64xf32>
%int1_2273 = torch.constant.int 1
%1415 = torch_c.to_i64 %int1_2273
%c1_2274 = arith.constant 1 : index
%c0_2275 = arith.constant 0 : index
%c2_2276 = arith.constant 2 : index
%c1_2277 = arith.constant 1 : index
%c8_2278 = arith.constant 8 : index
%c2_2279 = arith.constant 2 : index
%c64_2280 = arith.constant 64 : index
%c0_2281 = arith.constant 0 : index
%c2_2282 = arith.constant 2 : index
%1416 = arith.cmpi eq, %c2_2276, %c2_2282 : index
cf.assert %1416, "mismatched size for broadcast"
%c1_2283 = arith.constant 1 : index
%c8_2284 = arith.constant 8 : index
%1417 = arith.cmpi eq, %c8_2278, %c8_2284 : index
cf.assert %1417, "mismatched size for broadcast"
%c2_2285 = arith.constant 2 : index
%c64_2286 = arith.constant 64 : index
%1418 = arith.cmpi eq, %c64_2280, %c64_2286 : index
cf.assert %1418, "mismatched size for broadcast"
%1419 = tensor.empty() : tensor<2x8x64xf32>
%1420 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_1264, %expanded_2272 : tensor<2x8x64xf32>, tensor<2x8x64xf32>) outs(%1419 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.sitofp %1415 : i64 to f32
%1544 = arith.mulf %in_2549, %1543 : f32
%1545 = arith.addf %in, %1544 : f32
linalg.yield %1545 : f32
} -> tensor<2x8x64xf32>
%cast_2287 = tensor.cast %1420 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%1421 = torch_c.from_builtin_tensor %cast_2287 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%int2_2288 = torch.constant.int 2
%1422 = torch_c.to_i64 %int2_2288
%c1_2289 = arith.constant 1 : index
%c0_2290 = arith.constant 0 : index
%c2_2291 = arith.constant 2 : index
%c1_2292 = arith.constant 1 : index
%c8_2293 = arith.constant 8 : index
%c2_2294 = arith.constant 2 : index
%c64_2295 = arith.constant 64 : index
%1423 = tensor.empty() : tensor<2x8x64xf32>
%1424 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_2287 : tensor<2x8x64xf32>) outs(%1423 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = arith.sitofp %1422 : i64 to f32
%1544 = math.powf %in, %1543 : f32
linalg.yield %1544 : f32
} -> tensor<2x8x64xf32>
%cast_2296 = tensor.cast %1424 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%1425 = torch_c.from_builtin_tensor %cast_2296 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%int-1_2297 = torch.constant.int -1
%1426 = torch.prim.ListConstruct %int-1_2297 : (!torch.int) -> !torch.list<int>
%true_2298 = torch.constant.bool true
%none_2299 = torch.constant.none
%1427 = torch.aten.mean.dim %1425, %1426, %true_2298, %none_2299 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
%1428 = torch_c.to_builtin_tensor %1427 : !torch.vtensor<[2,8,1],f32> -> tensor<2x8x1xf32>
%float1.000000e-05_2300 = torch.constant.float 1.000000e-05
%1429 = torch_c.to_f64 %float1.000000e-05_2300
%int1_2301 = torch.constant.int 1
%1430 = torch_c.to_i64 %int1_2301
%c1_2302 = arith.constant 1 : index
%c0_2303 = arith.constant 0 : index
%c2_2304 = arith.constant 2 : index
%c1_2305 = arith.constant 1 : index
%c8_2306 = arith.constant 8 : index
%1431 = tensor.empty() : tensor<2x8x1xf32>
%1432 = linalg.generic {indexing_maps = [#map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1428 : tensor<2x8x1xf32>) outs(%1431 : tensor<2x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = arith.truncf %1429 : f64 to f32
%1544 = arith.sitofp %1430 : i64 to f32
%1545 = arith.mulf %1543, %1544 : f32
%1546 = arith.addf %in, %1545 : f32
linalg.yield %1546 : f32
} -> tensor<2x8x1xf32>
%cast_2307 = tensor.cast %1432 : tensor<2x8x1xf32> to tensor<2x8x1xf32>
%c1_2308 = arith.constant 1 : index
%c0_2309 = arith.constant 0 : index
%c2_2310 = arith.constant 2 : index
%c1_2311 = arith.constant 1 : index
%c8_2312 = arith.constant 8 : index
%1433 = tensor.empty() : tensor<2x8x1xf32>
%1434 = linalg.generic {indexing_maps = [#map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_2307 : tensor<2x8x1xf32>) outs(%1433 : tensor<2x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = math.rsqrt %in : f32
linalg.yield %1543 : f32
} -> tensor<2x8x1xf32>
%cast_2313 = tensor.cast %1434 : tensor<2x8x1xf32> to tensor<2x8x1xf32>
%1435 = torch_c.from_builtin_tensor %cast_2313 : tensor<2x8x1xf32> -> !torch.vtensor<[2,8,1],f32>
%c1_2314 = arith.constant 1 : index
%c0_2315 = arith.constant 0 : index
%c2_2316 = arith.constant 2 : index
%c1_2317 = arith.constant 1 : index
%c8_2318 = arith.constant 8 : index
%c2_2319 = arith.constant 2 : index
%c64_2320 = arith.constant 64 : index
%c0_2321 = arith.constant 0 : index
%c2_2322 = arith.constant 2 : index
%1436 = arith.cmpi eq, %c2_2316, %c2_2322 : index
cf.assert %1436, "mismatched size for broadcast"
%c1_2323 = arith.constant 1 : index
%c8_2324 = arith.constant 8 : index
%1437 = arith.cmpi eq, %c8_2318, %c8_2324 : index
cf.assert %1437, "mismatched size for broadcast"
%1438 = tensor.empty() : tensor<2x8x64xf32>
%1439 = linalg.generic {indexing_maps = [#map2, #map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_2287, %cast_2313 : tensor<2x8x64xf32>, tensor<2x8x1xf32>) outs(%1438 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.mulf %in, %in_2549 : f32
linalg.yield %1543 : f32
} -> tensor<2x8x64xf32>
%cast_2325 = tensor.cast %1439 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%1440 = torch_c.from_builtin_tensor %cast_2325 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%c1_2326 = arith.constant 1 : index
%c0_2327 = arith.constant 0 : index
%c2_2328 = arith.constant 2 : index
%c1_2329 = arith.constant 1 : index
%c8_2330 = arith.constant 8 : index
%c2_2331 = arith.constant 2 : index
%c64_2332 = arith.constant 64 : index
%c0_2333 = arith.constant 0 : index
%c64_2334 = arith.constant 64 : index
%1441 = arith.cmpi eq, %c64_2332, %c64_2334 : index
cf.assert %1441, "mismatched size for broadcast"
%1442 = tensor.empty() : tensor<2x8x64xf32>
%1443 = linalg.generic {indexing_maps = [#map2, #map5, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_2325, %21 : tensor<2x8x64xf32>, tensor<64xf32>) outs(%1442 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.mulf %in, %in_2549 : f32
linalg.yield %1543 : f32
} -> tensor<2x8x64xf32>
%cast_2335 = tensor.cast %1443 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%int0_2336 = torch.constant.int 0
%int1_2337 = torch.constant.int 1
%c0_2338 = arith.constant 0 : index
%c256_2339 = arith.constant 256 : index
%c1_2340 = arith.constant 1 : index
%c64_2341 = arith.constant 64 : index
%1444 = tensor.empty() : tensor<64x256xf32>
%1445 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%22 : tensor<256x64xf32>) outs(%1444 : tensor<64x256xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x256xf32>
%cast_2342 = tensor.cast %1445 : tensor<64x256xf32> to tensor<64x256xf32>
%1446 = torch_c.from_builtin_tensor %cast_2342 : tensor<64x256xf32> -> !torch.vtensor<[64,256],f32>
%int16_2343 = torch.constant.int 16
%int64_2344 = torch.constant.int 64
%1447 = torch.prim.ListConstruct %int16_2343, %int64_2344 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_2345 = arith.constant 0 : index
%c2_2346 = arith.constant 2 : index
%c1_2347 = arith.constant 1 : index
%c8_2348 = arith.constant 8 : index
%c2_2349 = arith.constant 2 : index
%c64_2350 = arith.constant 64 : index
%1448 = torch_c.to_i64 %int16_2343
%1449 = torch_c.to_i64 %int64_2344
%collapsed_2351 = tensor.collapse_shape %cast_2335 [[0, 1], [2]] : tensor<2x8x64xf32> into tensor<16x64xf32>
%1450 = torch_c.from_builtin_tensor %collapsed_2351 : tensor<16x64xf32> -> !torch.vtensor<[16,64],f32>
%c0_2352 = arith.constant 0 : index
%dim_2353 = tensor.dim %collapsed_2351, %c0_2352 : tensor<16x64xf32>
%c1_2354 = arith.constant 1 : index
%dim_2355 = tensor.dim %cast_2342, %c1_2354 : tensor<64x256xf32>
%c1_2356 = arith.constant 1 : index
%dim_2357 = tensor.dim %collapsed_2351, %c1_2356 : tensor<16x64xf32>
%c0_2358 = arith.constant 0 : index
%dim_2359 = tensor.dim %cast_2342, %c0_2358 : tensor<64x256xf32>
%1451 = arith.cmpi eq, %dim_2357, %dim_2359 : index
cf.assert %1451, "mismatching contracting dimension for torch.aten.mm"
%1452 = tensor.empty(%dim_2353, %dim_2355) : tensor<?x?xf32>
%cst_2360 = arith.constant 0.000000e+00 : f32
%1453 = linalg.fill ins(%cst_2360 : f32) outs(%1452 : tensor<?x?xf32>) -> tensor<?x?xf32>
%1454 = linalg.matmul ins(%collapsed_2351, %cast_2342 : tensor<16x64xf32>, tensor<64x256xf32>) outs(%1453 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_2361 = tensor.cast %1454 : tensor<?x?xf32> to tensor<16x256xf32>
%int2_2362 = torch.constant.int 2
%int8_2363 = torch.constant.int 8
%int256_2364 = torch.constant.int 256
%1455 = torch.prim.ListConstruct %int2_2362, %int8_2363, %int256_2364 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_2365 = arith.constant 0 : index
%c16_2366 = arith.constant 16 : index
%c1_2367 = arith.constant 1 : index
%c256_2368 = arith.constant 256 : index
%1456 = torch_c.to_i64 %int2_2362
%1457 = torch_c.to_i64 %int8_2363
%1458 = torch_c.to_i64 %int256_2364
%expanded_2369 = tensor.expand_shape %cast_2361 [[0, 1], [2]] : tensor<16x256xf32> into tensor<2x8x256xf32>
%1459 = torch_c.from_builtin_tensor %expanded_2369 : tensor<2x8x256xf32> -> !torch.vtensor<[2,8,256],f32>
%1460 = torch.aten.silu %1459 : !torch.vtensor<[2,8,256],f32> -> !torch.vtensor<[2,8,256],f32>
%1461 = torch_c.to_builtin_tensor %1460 : !torch.vtensor<[2,8,256],f32> -> tensor<2x8x256xf32>
%int0_2370 = torch.constant.int 0
%int1_2371 = torch.constant.int 1
%c0_2372 = arith.constant 0 : index
%c256_2373 = arith.constant 256 : index
%c1_2374 = arith.constant 1 : index
%c64_2375 = arith.constant 64 : index
%1462 = tensor.empty() : tensor<64x256xf32>
%1463 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%23 : tensor<256x64xf32>) outs(%1462 : tensor<64x256xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x256xf32>
%cast_2376 = tensor.cast %1463 : tensor<64x256xf32> to tensor<64x256xf32>
%1464 = torch_c.from_builtin_tensor %cast_2376 : tensor<64x256xf32> -> !torch.vtensor<[64,256],f32>
%int16_2377 = torch.constant.int 16
%int64_2378 = torch.constant.int 64
%1465 = torch.prim.ListConstruct %int16_2377, %int64_2378 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_2379 = arith.constant 0 : index
%c2_2380 = arith.constant 2 : index
%c1_2381 = arith.constant 1 : index
%c8_2382 = arith.constant 8 : index
%c2_2383 = arith.constant 2 : index
%c64_2384 = arith.constant 64 : index
%1466 = torch_c.to_i64 %int16_2377
%1467 = torch_c.to_i64 %int64_2378
%collapsed_2385 = tensor.collapse_shape %cast_2335 [[0, 1], [2]] : tensor<2x8x64xf32> into tensor<16x64xf32>
%1468 = torch_c.from_builtin_tensor %collapsed_2385 : tensor<16x64xf32> -> !torch.vtensor<[16,64],f32>
%c0_2386 = arith.constant 0 : index
%dim_2387 = tensor.dim %collapsed_2385, %c0_2386 : tensor<16x64xf32>
%c1_2388 = arith.constant 1 : index
%dim_2389 = tensor.dim %cast_2376, %c1_2388 : tensor<64x256xf32>
%c1_2390 = arith.constant 1 : index
%dim_2391 = tensor.dim %collapsed_2385, %c1_2390 : tensor<16x64xf32>
%c0_2392 = arith.constant 0 : index
%dim_2393 = tensor.dim %cast_2376, %c0_2392 : tensor<64x256xf32>
%1469 = arith.cmpi eq, %dim_2391, %dim_2393 : index
cf.assert %1469, "mismatching contracting dimension for torch.aten.mm"
%1470 = tensor.empty(%dim_2387, %dim_2389) : tensor<?x?xf32>
%cst_2394 = arith.constant 0.000000e+00 : f32
%1471 = linalg.fill ins(%cst_2394 : f32) outs(%1470 : tensor<?x?xf32>) -> tensor<?x?xf32>
%1472 = linalg.matmul ins(%collapsed_2385, %cast_2376 : tensor<16x64xf32>, tensor<64x256xf32>) outs(%1471 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_2395 = tensor.cast %1472 : tensor<?x?xf32> to tensor<16x256xf32>
%int2_2396 = torch.constant.int 2
%int8_2397 = torch.constant.int 8
%int256_2398 = torch.constant.int 256
%1473 = torch.prim.ListConstruct %int2_2396, %int8_2397, %int256_2398 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_2399 = arith.constant 0 : index
%c16_2400 = arith.constant 16 : index
%c1_2401 = arith.constant 1 : index
%c256_2402 = arith.constant 256 : index
%1474 = torch_c.to_i64 %int2_2396
%1475 = torch_c.to_i64 %int8_2397
%1476 = torch_c.to_i64 %int256_2398
%expanded_2403 = tensor.expand_shape %cast_2395 [[0, 1], [2]] : tensor<16x256xf32> into tensor<2x8x256xf32>
%1477 = torch_c.from_builtin_tensor %expanded_2403 : tensor<2x8x256xf32> -> !torch.vtensor<[2,8,256],f32>
%c1_2404 = arith.constant 1 : index
%c0_2405 = arith.constant 0 : index
%c2_2406 = arith.constant 2 : index
%c1_2407 = arith.constant 1 : index
%c8_2408 = arith.constant 8 : index
%c2_2409 = arith.constant 2 : index
%c256_2410 = arith.constant 256 : index
%c0_2411 = arith.constant 0 : index
%c2_2412 = arith.constant 2 : index
%1478 = arith.cmpi eq, %c2_2406, %c2_2412 : index
cf.assert %1478, "mismatched size for broadcast"
%c1_2413 = arith.constant 1 : index
%c8_2414 = arith.constant 8 : index
%1479 = arith.cmpi eq, %c8_2408, %c8_2414 : index
cf.assert %1479, "mismatched size for broadcast"
%c2_2415 = arith.constant 2 : index
%c256_2416 = arith.constant 256 : index
%1480 = arith.cmpi eq, %c256_2410, %c256_2416 : index
cf.assert %1480, "mismatched size for broadcast"
%1481 = tensor.empty() : tensor<2x8x256xf32>
%1482 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1461, %expanded_2403 : tensor<2x8x256xf32>, tensor<2x8x256xf32>) outs(%1481 : tensor<2x8x256xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.mulf %in, %in_2549 : f32
linalg.yield %1543 : f32
} -> tensor<2x8x256xf32>
%cast_2417 = tensor.cast %1482 : tensor<2x8x256xf32> to tensor<2x8x256xf32>
%int0_2418 = torch.constant.int 0
%int1_2419 = torch.constant.int 1
%c0_2420 = arith.constant 0 : index
%c64_2421 = arith.constant 64 : index
%c1_2422 = arith.constant 1 : index
%c256_2423 = arith.constant 256 : index
%1483 = tensor.empty() : tensor<256x64xf32>
%1484 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%24 : tensor<64x256xf32>) outs(%1483 : tensor<256x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<256x64xf32>
%cast_2424 = tensor.cast %1484 : tensor<256x64xf32> to tensor<256x64xf32>
%1485 = torch_c.from_builtin_tensor %cast_2424 : tensor<256x64xf32> -> !torch.vtensor<[256,64],f32>
%int16_2425 = torch.constant.int 16
%int256_2426 = torch.constant.int 256
%1486 = torch.prim.ListConstruct %int16_2425, %int256_2426 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_2427 = arith.constant 0 : index
%c2_2428 = arith.constant 2 : index
%c1_2429 = arith.constant 1 : index
%c8_2430 = arith.constant 8 : index
%c2_2431 = arith.constant 2 : index
%c256_2432 = arith.constant 256 : index
%1487 = torch_c.to_i64 %int16_2425
%1488 = torch_c.to_i64 %int256_2426
%collapsed_2433 = tensor.collapse_shape %cast_2417 [[0, 1], [2]] : tensor<2x8x256xf32> into tensor<16x256xf32>
%1489 = torch_c.from_builtin_tensor %collapsed_2433 : tensor<16x256xf32> -> !torch.vtensor<[16,256],f32>
%c0_2434 = arith.constant 0 : index
%dim_2435 = tensor.dim %collapsed_2433, %c0_2434 : tensor<16x256xf32>
%c1_2436 = arith.constant 1 : index
%dim_2437 = tensor.dim %cast_2424, %c1_2436 : tensor<256x64xf32>
%c1_2438 = arith.constant 1 : index
%dim_2439 = tensor.dim %collapsed_2433, %c1_2438 : tensor<16x256xf32>
%c0_2440 = arith.constant 0 : index
%dim_2441 = tensor.dim %cast_2424, %c0_2440 : tensor<256x64xf32>
%1490 = arith.cmpi eq, %dim_2439, %dim_2441 : index
cf.assert %1490, "mismatching contracting dimension for torch.aten.mm"
%1491 = tensor.empty(%dim_2435, %dim_2437) : tensor<?x?xf32>
%cst_2442 = arith.constant 0.000000e+00 : f32
%1492 = linalg.fill ins(%cst_2442 : f32) outs(%1491 : tensor<?x?xf32>) -> tensor<?x?xf32>
%1493 = linalg.matmul ins(%collapsed_2433, %cast_2424 : tensor<16x256xf32>, tensor<256x64xf32>) outs(%1492 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_2443 = tensor.cast %1493 : tensor<?x?xf32> to tensor<16x64xf32>
%int2_2444 = torch.constant.int 2
%int8_2445 = torch.constant.int 8
%int64_2446 = torch.constant.int 64
%1494 = torch.prim.ListConstruct %int2_2444, %int8_2445, %int64_2446 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_2447 = arith.constant 0 : index
%c16_2448 = arith.constant 16 : index
%c1_2449 = arith.constant 1 : index
%c64_2450 = arith.constant 64 : index
%1495 = torch_c.to_i64 %int2_2444
%1496 = torch_c.to_i64 %int8_2445
%1497 = torch_c.to_i64 %int64_2446
%expanded_2451 = tensor.expand_shape %cast_2443 [[0, 1], [2]] : tensor<16x64xf32> into tensor<2x8x64xf32>
%int1_2452 = torch.constant.int 1
%1498 = torch_c.to_i64 %int1_2452
%c1_2453 = arith.constant 1 : index
%c0_2454 = arith.constant 0 : index
%c2_2455 = arith.constant 2 : index
%c1_2456 = arith.constant 1 : index
%c8_2457 = arith.constant 8 : index
%c2_2458 = arith.constant 2 : index
%c64_2459 = arith.constant 64 : index
%c0_2460 = arith.constant 0 : index
%c2_2461 = arith.constant 2 : index
%1499 = arith.cmpi eq, %c2_2455, %c2_2461 : index
cf.assert %1499, "mismatched size for broadcast"
%c1_2462 = arith.constant 1 : index
%c8_2463 = arith.constant 8 : index
%1500 = arith.cmpi eq, %c8_2457, %c8_2463 : index
cf.assert %1500, "mismatched size for broadcast"
%c2_2464 = arith.constant 2 : index
%c64_2465 = arith.constant 64 : index
%1501 = arith.cmpi eq, %c64_2459, %c64_2465 : index
cf.assert %1501, "mismatched size for broadcast"
%1502 = tensor.empty() : tensor<2x8x64xf32>
%1503 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_2287, %expanded_2451 : tensor<2x8x64xf32>, tensor<2x8x64xf32>) outs(%1502 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.sitofp %1498 : i64 to f32
%1544 = arith.mulf %in_2549, %1543 : f32
%1545 = arith.addf %in, %1544 : f32
linalg.yield %1545 : f32
} -> tensor<2x8x64xf32>
%cast_2466 = tensor.cast %1503 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%1504 = torch_c.from_builtin_tensor %cast_2466 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%int2_2467 = torch.constant.int 2
%1505 = torch_c.to_i64 %int2_2467
%c1_2468 = arith.constant 1 : index
%c0_2469 = arith.constant 0 : index
%c2_2470 = arith.constant 2 : index
%c1_2471 = arith.constant 1 : index
%c8_2472 = arith.constant 8 : index
%c2_2473 = arith.constant 2 : index
%c64_2474 = arith.constant 64 : index
%1506 = tensor.empty() : tensor<2x8x64xf32>
%1507 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_2466 : tensor<2x8x64xf32>) outs(%1506 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = arith.sitofp %1505 : i64 to f32
%1544 = math.powf %in, %1543 : f32
linalg.yield %1544 : f32
} -> tensor<2x8x64xf32>
%cast_2475 = tensor.cast %1507 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%1508 = torch_c.from_builtin_tensor %cast_2475 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%int-1_2476 = torch.constant.int -1
%1509 = torch.prim.ListConstruct %int-1_2476 : (!torch.int) -> !torch.list<int>
%true_2477 = torch.constant.bool true
%none_2478 = torch.constant.none
%1510 = torch.aten.mean.dim %1508, %1509, %true_2477, %none_2478 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
%1511 = torch_c.to_builtin_tensor %1510 : !torch.vtensor<[2,8,1],f32> -> tensor<2x8x1xf32>
%float1.000000e-05_2479 = torch.constant.float 1.000000e-05
%1512 = torch_c.to_f64 %float1.000000e-05_2479
%int1_2480 = torch.constant.int 1
%1513 = torch_c.to_i64 %int1_2480
%c1_2481 = arith.constant 1 : index
%c0_2482 = arith.constant 0 : index
%c2_2483 = arith.constant 2 : index
%c1_2484 = arith.constant 1 : index
%c8_2485 = arith.constant 8 : index
%1514 = tensor.empty() : tensor<2x8x1xf32>
%1515 = linalg.generic {indexing_maps = [#map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1511 : tensor<2x8x1xf32>) outs(%1514 : tensor<2x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = arith.truncf %1512 : f64 to f32
%1544 = arith.sitofp %1513 : i64 to f32
%1545 = arith.mulf %1543, %1544 : f32
%1546 = arith.addf %in, %1545 : f32
linalg.yield %1546 : f32
} -> tensor<2x8x1xf32>
%cast_2486 = tensor.cast %1515 : tensor<2x8x1xf32> to tensor<2x8x1xf32>
%c1_2487 = arith.constant 1 : index
%c0_2488 = arith.constant 0 : index
%c2_2489 = arith.constant 2 : index
%c1_2490 = arith.constant 1 : index
%c8_2491 = arith.constant 8 : index
%1516 = tensor.empty() : tensor<2x8x1xf32>
%1517 = linalg.generic {indexing_maps = [#map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_2486 : tensor<2x8x1xf32>) outs(%1516 : tensor<2x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1543 = math.rsqrt %in : f32
linalg.yield %1543 : f32
} -> tensor<2x8x1xf32>
%cast_2492 = tensor.cast %1517 : tensor<2x8x1xf32> to tensor<2x8x1xf32>
%1518 = torch_c.from_builtin_tensor %cast_2492 : tensor<2x8x1xf32> -> !torch.vtensor<[2,8,1],f32>
%c1_2493 = arith.constant 1 : index
%c0_2494 = arith.constant 0 : index
%c2_2495 = arith.constant 2 : index
%c1_2496 = arith.constant 1 : index
%c8_2497 = arith.constant 8 : index
%c2_2498 = arith.constant 2 : index
%c64_2499 = arith.constant 64 : index
%c0_2500 = arith.constant 0 : index
%c2_2501 = arith.constant 2 : index
%1519 = arith.cmpi eq, %c2_2495, %c2_2501 : index
cf.assert %1519, "mismatched size for broadcast"
%c1_2502 = arith.constant 1 : index
%c8_2503 = arith.constant 8 : index
%1520 = arith.cmpi eq, %c8_2497, %c8_2503 : index
cf.assert %1520, "mismatched size for broadcast"
%1521 = tensor.empty() : tensor<2x8x64xf32>
%1522 = linalg.generic {indexing_maps = [#map2, #map4, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_2466, %cast_2492 : tensor<2x8x64xf32>, tensor<2x8x1xf32>) outs(%1521 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.mulf %in, %in_2549 : f32
linalg.yield %1543 : f32
} -> tensor<2x8x64xf32>
%cast_2504 = tensor.cast %1522 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%1523 = torch_c.from_builtin_tensor %cast_2504 : tensor<2x8x64xf32> -> !torch.vtensor<[2,8,64],f32>
%c1_2505 = arith.constant 1 : index
%c0_2506 = arith.constant 0 : index
%c2_2507 = arith.constant 2 : index
%c1_2508 = arith.constant 1 : index
%c8_2509 = arith.constant 8 : index
%c2_2510 = arith.constant 2 : index
%c64_2511 = arith.constant 64 : index
%c0_2512 = arith.constant 0 : index
%c64_2513 = arith.constant 64 : index
%1524 = arith.cmpi eq, %c64_2511, %c64_2513 : index
cf.assert %1524, "mismatched size for broadcast"
%1525 = tensor.empty() : tensor<2x8x64xf32>
%1526 = linalg.generic {indexing_maps = [#map2, #map5, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_2504, %25 : tensor<2x8x64xf32>, tensor<64xf32>) outs(%1525 : tensor<2x8x64xf32>) {
^bb0(%in: f32, %in_2549: f32, %out: f32):
%1543 = arith.mulf %in, %in_2549 : f32
linalg.yield %1543 : f32
} -> tensor<2x8x64xf32>
%cast_2514 = tensor.cast %1526 : tensor<2x8x64xf32> to tensor<2x8x64xf32>
%int0_2515 = torch.constant.int 0
%int1_2516 = torch.constant.int 1
%c0_2517 = arith.constant 0 : index
%c16_2518 = arith.constant 16 : index
%c1_2519 = arith.constant 1 : index
%c64_2520 = arith.constant 64 : index
%1527 = tensor.empty() : tensor<64x16xf32>
%1528 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel"]} ins(%26 : tensor<16x64xf32>) outs(%1527 : tensor<64x16xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x16xf32>
%cast_2521 = tensor.cast %1528 : tensor<64x16xf32> to tensor<64x16xf32>
%1529 = torch_c.from_builtin_tensor %cast_2521 : tensor<64x16xf32> -> !torch.vtensor<[64,16],f32>
%int16_2522 = torch.constant.int 16
%int64_2523 = torch.constant.int 64
%1530 = torch.prim.ListConstruct %int16_2522, %int64_2523 : (!torch.int, !torch.int) -> !torch.list<int>
%c0_2524 = arith.constant 0 : index
%c2_2525 = arith.constant 2 : index
%c1_2526 = arith.constant 1 : index
%c8_2527 = arith.constant 8 : index
%c2_2528 = arith.constant 2 : index
%c64_2529 = arith.constant 64 : index
%1531 = torch_c.to_i64 %int16_2522
%1532 = torch_c.to_i64 %int64_2523
%collapsed_2530 = tensor.collapse_shape %cast_2514 [[0, 1], [2]] : tensor<2x8x64xf32> into tensor<16x64xf32>
%1533 = torch_c.from_builtin_tensor %collapsed_2530 : tensor<16x64xf32> -> !torch.vtensor<[16,64],f32>
%c0_2531 = arith.constant 0 : index
%dim_2532 = tensor.dim %collapsed_2530, %c0_2531 : tensor<16x64xf32>
%c1_2533 = arith.constant 1 : index
%dim_2534 = tensor.dim %cast_2521, %c1_2533 : tensor<64x16xf32>
%c1_2535 = arith.constant 1 : index
%dim_2536 = tensor.dim %collapsed_2530, %c1_2535 : tensor<16x64xf32>
%c0_2537 = arith.constant 0 : index
%dim_2538 = tensor.dim %cast_2521, %c0_2537 : tensor<64x16xf32>
%1534 = arith.cmpi eq, %dim_2536, %dim_2538 : index
cf.assert %1534, "mismatching contracting dimension for torch.aten.mm"
%1535 = tensor.empty(%dim_2532, %dim_2534) : tensor<?x?xf32>
%cst_2539 = arith.constant 0.000000e+00 : f32
%1536 = linalg.fill ins(%cst_2539 : f32) outs(%1535 : tensor<?x?xf32>) -> tensor<?x?xf32>
%1537 = linalg.matmul ins(%collapsed_2530, %cast_2521 : tensor<16x64xf32>, tensor<64x16xf32>) outs(%1536 : tensor<?x?xf32>) -> tensor<?x?xf32>
%cast_2540 = tensor.cast %1537 : tensor<?x?xf32> to tensor<16x16xf32>
%int2_2541 = torch.constant.int 2
%int8_2542 = torch.constant.int 8
%int16_2543 = torch.constant.int 16
%1538 = torch.prim.ListConstruct %int2_2541, %int8_2542, %int16_2543 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%c0_2544 = arith.constant 0 : index
%c16_2545 = arith.constant 16 : index
%c1_2546 = arith.constant 1 : index
%c16_2547 = arith.constant 16 : index
%1539 = torch_c.to_i64 %int2_2541
%1540 = torch_c.to_i64 %int8_2542
%1541 = torch_c.to_i64 %int16_2543
%expanded_2548 = tensor.expand_shape %cast_2540 [[0, 1], [2]] : tensor<16x16xf32> into tensor<2x8x16xf32>
%1542 = torch_c.from_builtin_tensor %expanded_2548 : tensor<2x8x16xf32> -> !torch.vtensor<[2,8,16],f32>
return %342, %484, %1057, %1199, %1542, %arg0, %arg1, %arg2, %arg3, %arg4, %arg26, %37, %88, %93, %99, %103, %114, %118, %129, %133, %178, %495, %563, %588, %596, %661, %687, %691, %706, %720, %725, %731, %735, %744, %745, %749, %753, %762, %770, %774, %789, %803, %808, %814, %818, %829, %833, %844, %848, %893, %1210, %1278, %1303, %1311, %1376, %1402, %1406, %1421, %1435, %1440, %1446, %1450, %1459, %1460, %1464, %1468, %1477, %1485, %1489, %1504, %1518, %1523, %1529, %1533 : !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[2,8,16],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[2,8],si64>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[1,8,1,8],complex<f32>>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[8,16,8],f32>, !torch.vtensor<[2,4,8,8],f32>, !torch.vtensor<[8,8,8],f32>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[256,64],f32>, !torch.vtensor<[16,256],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[1,8,1,8],complex<f32>>, 
!torch.vtensor<[8,8,16],f32>, !torch.vtensor<[8,16,8],f32>, !torch.vtensor<[2,4,8,8],f32>, !torch.vtensor<[8,8,8],f32>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[256,64],f32>, !torch.vtensor<[16,256],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,16],f32>, !torch.vtensor<[16,64],f32>
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment