Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created December 5, 2023 05:13
Show Gist options
  • Save AmosLewis/5eb8e66e899257e7456af738269b91bf to your computer and use it in GitHub Desktop.
Save AmosLewis/5eb8e66e899257e7456af738269b91bf to your computer and use it in GitHub Desktop.
module {
func.func @main(%arg0: !torch.vtensor<[64],f32>, %arg1: !torch.vtensor<[64],f32>, %arg2: !torch.vtensor<[64],f32>, %arg3: !torch.vtensor<[64],f32>, %arg4: !torch.vtensor<[64],f32>, %arg5: !torch.vtensor<[16,64],f32>, %arg6: !torch.vtensor<[64,64],f32>, %arg7: !torch.vtensor<[64,64],f32>, %arg8: !torch.vtensor<[64,64],f32>, %arg9: !torch.vtensor<[64,64],f32>, %arg10: !torch.vtensor<[256,64],f32>, %arg11: !torch.vtensor<[256,64],f32>, %arg12: !torch.vtensor<[64,256],f32>, %arg13: !torch.vtensor<[64,64],f32>, %arg14: !torch.vtensor<[64,64],f32>, %arg15: !torch.vtensor<[64,64],f32>, %arg16: !torch.vtensor<[64,64],f32>, %arg17: !torch.vtensor<[256,64],f32>, %arg18: !torch.vtensor<[256,64],f32>, %arg19: !torch.vtensor<[64,256],f32>, %arg20: !torch.vtensor<[16,64],f32>, %arg21: !torch.vtensor<[4096,8],complex<f32>>, %arg22: !torch.vtensor<[32,2048,4,16],f32>, %arg23: !torch.vtensor<[32,2048,4,16],f32>, %arg24: !torch.vtensor<[32,2048,4,16],f32>, %arg25: !torch.vtensor<[32,2048,4,16],f32>, %arg26: !torch.vtensor<[2,8],si64>) -> (!torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[2,8,16],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[2,8],si64>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[1,8,1,8],complex<f32>>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[8,16,8],f32>, !torch.vtensor<[2,4,8,8],f32>, !torch.vtensor<[8,8,8],f32>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,256],f32>, 
!torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[256,64],f32>, !torch.vtensor<[16,256],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[1,8,1,8],complex<f32>>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[8,16,8],f32>, !torch.vtensor<[2,4,8,8],f32>, !torch.vtensor<[8,8,8],f32>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[256,64],f32>, !torch.vtensor<[16,256],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,16],f32>, !torch.vtensor<[16,64],f32>) {
%none = torch.constant.none
%0 = torch.aten.clone %arg22, %none : !torch.vtensor<[32,2048,4,16],f32>, !torch.none -> !torch.vtensor<[32,2048,4,16],f32>
%none_0 = torch.constant.none
%1 = torch.aten.clone %arg23, %none_0 : !torch.vtensor<[32,2048,4,16],f32>, !torch.none -> !torch.vtensor<[32,2048,4,16],f32>
%none_1 = torch.constant.none
%2 = torch.aten.clone %arg24, %none_1 : !torch.vtensor<[32,2048,4,16],f32>, !torch.none -> !torch.vtensor<[32,2048,4,16],f32>
%none_2 = torch.constant.none
%3 = torch.aten.clone %arg25, %none_2 : !torch.vtensor<[32,2048,4,16],f32>, !torch.none -> !torch.vtensor<[32,2048,4,16],f32>
%int-1 = torch.constant.int -1
%false = torch.constant.bool false
%false_3 = torch.constant.bool false
%4 = torch.aten.embedding %arg5, %arg26, %int-1, %false, %false_3 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[2,8,64],f32>
%int0 = torch.constant.int 0
%int0_4 = torch.constant.int 0
%int8 = torch.constant.int 8
%int1 = torch.constant.int 1
%5 = torch.aten.slice.Tensor %arg21, %int0, %int0_4, %int8, %int1 : !torch.vtensor<[4096,8],complex<f32>>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[8,8],complex<f32>>
%int1_5 = torch.constant.int 1
%int1_6 = torch.constant.int 1
%int8_7 = torch.constant.int 8
%int8_8 = torch.constant.int 8
%6 = torch.prim.ListConstruct %int1_5, %int1_6, %int8_7, %int8_8 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%int64 = torch.constant.int 64
%int64_9 = torch.constant.int 64
%int8_10 = torch.constant.int 8
%int1_11 = torch.constant.int 1
%7 = torch.prim.ListConstruct %int64, %int64_9, %int8_10, %int1_11 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%int6 = torch.constant.int 6
%int0_12 = torch.constant.int 0
%cpu = torch.constant.device "cpu"
%false_13 = torch.constant.bool false
%8 = torch.aten.empty_strided %6, %7, %int6, %int0_12, %cpu, %false_13 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,1,8,8],f32>
%float-Inf = torch.constant.float 0xFFF0000000000000
%9 = torch.aten.fill.Scalar %8, %float-Inf : !torch.vtensor<[1,1,8,8],f32>, !torch.float -> !torch.vtensor<[1,1,8,8],f32>
%int1_14 = torch.constant.int 1
%10 = torch.aten.triu %9, %int1_14 : !torch.vtensor<[1,1,8,8],f32>, !torch.int -> !torch.vtensor<[1,1,8,8],f32>
%int2 = torch.constant.int 2
%11 = torch.aten.pow.Tensor_Scalar %4, %int2 : !torch.vtensor<[2,8,64],f32>, !torch.int -> !torch.vtensor<[2,8,64],f32>
%int-1_15 = torch.constant.int -1
%12 = torch.prim.ListConstruct %int-1_15 : (!torch.int) -> !torch.list<int>
%true = torch.constant.bool true
%none_16 = torch.constant.none
%13 = torch.aten.mean.dim %11, %12, %true, %none_16 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
%float1.000000e-05 = torch.constant.float 1.000000e-05
%int1_17 = torch.constant.int 1
%14 = torch.aten.add.Scalar %13, %float1.000000e-05, %int1_17 : !torch.vtensor<[2,8,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[2,8,1],f32>
%15 = torch.aten.rsqrt %14 : !torch.vtensor<[2,8,1],f32> -> !torch.vtensor<[2,8,1],f32>
%16 = torch.aten.mul.Tensor %4, %15 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32> -> !torch.vtensor<[2,8,64],f32>
%17 = torch.aten.mul.Tensor %16, %arg0 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[2,8,64],f32>
%int0_18 = torch.constant.int 0
%int1_19 = torch.constant.int 1
%18 = torch.aten.transpose.int %arg6, %int0_18, %int1_19 : !torch.vtensor<[64,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,64],f32>
%int16 = torch.constant.int 16
%int64_20 = torch.constant.int 64
%19 = torch.prim.ListConstruct %int16, %int64_20 : (!torch.int, !torch.int) -> !torch.list<int>
%20 = torch.aten.view %17, %19 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[16,64],f32>
%21 = torch.aten.mm %20, %18 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[16,64],f32>
%int2_21 = torch.constant.int 2
%int8_22 = torch.constant.int 8
%int64_23 = torch.constant.int 64
%22 = torch.prim.ListConstruct %int2_21, %int8_22, %int64_23 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%23 = torch.aten.view %21, %22 : !torch.vtensor<[16,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,64],f32>
%int0_24 = torch.constant.int 0
%int1_25 = torch.constant.int 1
%24 = torch.aten.transpose.int %arg7, %int0_24, %int1_25 : !torch.vtensor<[64,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,64],f32>
%int16_26 = torch.constant.int 16
%int64_27 = torch.constant.int 64
%25 = torch.prim.ListConstruct %int16_26, %int64_27 : (!torch.int, !torch.int) -> !torch.list<int>
%26 = torch.aten.view %17, %25 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[16,64],f32>
%27 = torch.aten.mm %26, %24 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[16,64],f32>
%int2_28 = torch.constant.int 2
%int8_29 = torch.constant.int 8
%int64_30 = torch.constant.int 64
%28 = torch.prim.ListConstruct %int2_28, %int8_29, %int64_30 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%29 = torch.aten.view %27, %28 : !torch.vtensor<[16,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,64],f32>
%int0_31 = torch.constant.int 0
%int1_32 = torch.constant.int 1
%30 = torch.aten.transpose.int %arg8, %int0_31, %int1_32 : !torch.vtensor<[64,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,64],f32>
%int16_33 = torch.constant.int 16
%int64_34 = torch.constant.int 64
%31 = torch.prim.ListConstruct %int16_33, %int64_34 : (!torch.int, !torch.int) -> !torch.list<int>
%32 = torch.aten.view %17, %31 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[16,64],f32>
%33 = torch.aten.mm %32, %30 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[16,64],f32>
%int2_35 = torch.constant.int 2
%int8_36 = torch.constant.int 8
%int64_37 = torch.constant.int 64
%34 = torch.prim.ListConstruct %int2_35, %int8_36, %int64_37 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%35 = torch.aten.view %33, %34 : !torch.vtensor<[16,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,64],f32>
%int2_38 = torch.constant.int 2
%int8_39 = torch.constant.int 8
%int4 = torch.constant.int 4
%int16_40 = torch.constant.int 16
%36 = torch.prim.ListConstruct %int2_38, %int8_39, %int4, %int16_40 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%37 = torch.aten.view %23, %36 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,16],f32>
%int2_41 = torch.constant.int 2
%int8_42 = torch.constant.int 8
%int4_43 = torch.constant.int 4
%int16_44 = torch.constant.int 16
%38 = torch.prim.ListConstruct %int2_41, %int8_42, %int4_43, %int16_44 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%39 = torch.aten.view %29, %38 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,16],f32>
%int2_45 = torch.constant.int 2
%int8_46 = torch.constant.int 8
%int4_47 = torch.constant.int 4
%int16_48 = torch.constant.int 16
%40 = torch.prim.ListConstruct %int2_45, %int8_46, %int4_47, %int16_48 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%41 = torch.aten.view %35, %40 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,16],f32>
%int2_49 = torch.constant.int 2
%int8_50 = torch.constant.int 8
%int4_51 = torch.constant.int 4
%int-1_52 = torch.constant.int -1
%int2_53 = torch.constant.int 2
%42 = torch.prim.ListConstruct %int2_49, %int8_50, %int4_51, %int-1_52, %int2_53 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%43 = torch.aten.view %37, %42 : !torch.vtensor<[2,8,4,16],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,8,2],f32>
%44 = torch.aten.view_as_complex %43 : !torch.vtensor<[2,8,4,8,2],f32> -> !torch.vtensor<[2,8,4,8],complex<f32>>
%int2_54 = torch.constant.int 2
%int8_55 = torch.constant.int 8
%int4_56 = torch.constant.int 4
%int-1_57 = torch.constant.int -1
%int2_58 = torch.constant.int 2
%45 = torch.prim.ListConstruct %int2_54, %int8_55, %int4_56, %int-1_57, %int2_58 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%46 = torch.aten.view %39, %45 : !torch.vtensor<[2,8,4,16],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,8,2],f32>
%47 = torch.aten.view_as_complex %46 : !torch.vtensor<[2,8,4,8,2],f32> -> !torch.vtensor<[2,8,4,8],complex<f32>>
%int1_59 = torch.constant.int 1
%int8_60 = torch.constant.int 8
%int1_61 = torch.constant.int 1
%int8_62 = torch.constant.int 8
%48 = torch.prim.ListConstruct %int1_59, %int8_60, %int1_61, %int8_62 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%49 = torch.aten.view %5, %48 : !torch.vtensor<[8,8],complex<f32>>, !torch.list<int> -> !torch.vtensor<[1,8,1,8],complex<f32>>
%50 = torch.aten.mul.Tensor %44, %49 : !torch.vtensor<[2,8,4,8],complex<f32>>, !torch.vtensor<[1,8,1,8],complex<f32>> -> !torch.vtensor<[2,8,4,8],complex<f32>>
%51 = torch.aten.view_as_real %50 : !torch.vtensor<[2,8,4,8],complex<f32>> -> !torch.vtensor<[2,8,4,8,2],f32>
%int2_63 = torch.constant.int 2
%int8_64 = torch.constant.int 8
%int4_65 = torch.constant.int 4
%int16_66 = torch.constant.int 16
%52 = torch.prim.ListConstruct %int2_63, %int8_64, %int4_65, %int16_66 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%53 = torch.aten.view %51, %52 : !torch.vtensor<[2,8,4,8,2],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,16],f32>
%54 = torch.aten.mul.Tensor %47, %49 : !torch.vtensor<[2,8,4,8],complex<f32>>, !torch.vtensor<[1,8,1,8],complex<f32>> -> !torch.vtensor<[2,8,4,8],complex<f32>>
%55 = torch.aten.view_as_real %54 : !torch.vtensor<[2,8,4,8],complex<f32>> -> !torch.vtensor<[2,8,4,8,2],f32>
%int2_67 = torch.constant.int 2
%int8_68 = torch.constant.int 8
%int4_69 = torch.constant.int 4
%int16_70 = torch.constant.int 16
%56 = torch.prim.ListConstruct %int2_67, %int8_68, %int4_69, %int16_70 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%57 = torch.aten.view %55, %56 : !torch.vtensor<[2,8,4,8,2],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,16],f32>
%int0_71 = torch.constant.int 0
%int0_72 = torch.constant.int 0
%int2_73 = torch.constant.int 2
%int1_74 = torch.constant.int 1
%58 = torch.aten.slice.Tensor %0, %int0_71, %int0_72, %int2_73, %int1_74 : !torch.vtensor<[32,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int1_75 = torch.constant.int 1
%int0_76 = torch.constant.int 0
%int8_77 = torch.constant.int 8
%int1_78 = torch.constant.int 1
%59 = torch.aten.slice.Tensor %58, %int1_75, %int0_76, %int8_77, %int1_78 : !torch.vtensor<[2,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,8,4,16],f32>
%false_79 = torch.constant.bool false
%60 = torch.aten.copy %59, %57, %false_79 : !torch.vtensor<[2,8,4,16],f32>, !torch.vtensor<[2,8,4,16],f32>, !torch.bool -> !torch.vtensor<[2,8,4,16],f32>
%int0_80 = torch.constant.int 0
%int0_81 = torch.constant.int 0
%int2_82 = torch.constant.int 2
%int1_83 = torch.constant.int 1
%61 = torch.aten.slice.Tensor %0, %int0_80, %int0_81, %int2_82, %int1_83 : !torch.vtensor<[32,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int1_84 = torch.constant.int 1
%int0_85 = torch.constant.int 0
%int8_86 = torch.constant.int 8
%int1_87 = torch.constant.int 1
%62 = torch.aten.slice_scatter %61, %60, %int1_84, %int0_85, %int8_86, %int1_87 : !torch.vtensor<[2,2048,4,16],f32>, !torch.vtensor<[2,8,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int0_88 = torch.constant.int 0
%int0_89 = torch.constant.int 0
%int2_90 = torch.constant.int 2
%int1_91 = torch.constant.int 1
%63 = torch.aten.slice_scatter %0, %62, %int0_88, %int0_89, %int2_90, %int1_91 : !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[2,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[32,2048,4,16],f32>
%int0_92 = torch.constant.int 0
%int0_93 = torch.constant.int 0
%int2_94 = torch.constant.int 2
%int1_95 = torch.constant.int 1
%64 = torch.aten.slice.Tensor %1, %int0_92, %int0_93, %int2_94, %int1_95 : !torch.vtensor<[32,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int1_96 = torch.constant.int 1
%int0_97 = torch.constant.int 0
%int8_98 = torch.constant.int 8
%int1_99 = torch.constant.int 1
%65 = torch.aten.slice.Tensor %64, %int1_96, %int0_97, %int8_98, %int1_99 : !torch.vtensor<[2,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,8,4,16],f32>
%false_100 = torch.constant.bool false
%66 = torch.aten.copy %65, %41, %false_100 : !torch.vtensor<[2,8,4,16],f32>, !torch.vtensor<[2,8,4,16],f32>, !torch.bool -> !torch.vtensor<[2,8,4,16],f32>
%int0_101 = torch.constant.int 0
%int0_102 = torch.constant.int 0
%int2_103 = torch.constant.int 2
%int1_104 = torch.constant.int 1
%67 = torch.aten.slice.Tensor %1, %int0_101, %int0_102, %int2_103, %int1_104 : !torch.vtensor<[32,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int1_105 = torch.constant.int 1
%int0_106 = torch.constant.int 0
%int8_107 = torch.constant.int 8
%int1_108 = torch.constant.int 1
%68 = torch.aten.slice_scatter %67, %66, %int1_105, %int0_106, %int8_107, %int1_108 : !torch.vtensor<[2,2048,4,16],f32>, !torch.vtensor<[2,8,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int0_109 = torch.constant.int 0
%int0_110 = torch.constant.int 0
%int2_111 = torch.constant.int 2
%int1_112 = torch.constant.int 1
%69 = torch.aten.slice_scatter %1, %68, %int0_109, %int0_110, %int2_111, %int1_112 : !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[2,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[32,2048,4,16],f32>
%int1_113 = torch.constant.int 1
%int2_114 = torch.constant.int 2
%70 = torch.aten.transpose.int %53, %int1_113, %int2_114 : !torch.vtensor<[2,8,4,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4,8,16],f32>
%int2_115 = torch.constant.int 2
%int4_116 = torch.constant.int 4
%int8_117 = torch.constant.int 8
%int16_118 = torch.constant.int 16
%71 = torch.prim.ListConstruct %int2_115, %int4_116, %int8_117, %int16_118 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_119 = torch.constant.bool false
%72 = torch.aten.expand %70, %71, %false_119 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>
%int0_120 = torch.constant.int 0
%73 = torch.aten.clone %72, %int0_120 : !torch.vtensor<[2,4,8,16],f32>, !torch.int -> !torch.vtensor<[2,4,8,16],f32>
%int8_121 = torch.constant.int 8
%int8_122 = torch.constant.int 8
%int16_123 = torch.constant.int 16
%74 = torch.prim.ListConstruct %int8_121, %int8_122, %int16_123 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%75 = torch.aten._unsafe_view %73, %74 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int> -> !torch.vtensor<[8,8,16],f32>
%int0_124 = torch.constant.int 0
%int0_125 = torch.constant.int 0
%int2_126 = torch.constant.int 2
%int1_127 = torch.constant.int 1
%76 = torch.aten.slice.Tensor %63, %int0_124, %int0_125, %int2_126, %int1_127 : !torch.vtensor<[32,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int1_128 = torch.constant.int 1
%int0_129 = torch.constant.int 0
%int8_130 = torch.constant.int 8
%int1_131 = torch.constant.int 1
%77 = torch.aten.slice.Tensor %76, %int1_128, %int0_129, %int8_130, %int1_131 : !torch.vtensor<[2,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,8,4,16],f32>
%int1_132 = torch.constant.int 1
%int2_133 = torch.constant.int 2
%78 = torch.aten.transpose.int %77, %int1_132, %int2_133 : !torch.vtensor<[2,8,4,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4,8,16],f32>
%int2_134 = torch.constant.int 2
%int3 = torch.constant.int 3
%79 = torch.aten.transpose.int %78, %int2_134, %int3 : !torch.vtensor<[2,4,8,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4,16,8],f32>
%int2_135 = torch.constant.int 2
%int4_136 = torch.constant.int 4
%int16_137 = torch.constant.int 16
%int8_138 = torch.constant.int 8
%80 = torch.prim.ListConstruct %int2_135, %int4_136, %int16_137, %int8_138 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_139 = torch.constant.bool false
%81 = torch.aten.expand %79, %80, %false_139 : !torch.vtensor<[2,4,16,8],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,16,8],f32>
%int0_140 = torch.constant.int 0
%82 = torch.aten.clone %81, %int0_140 : !torch.vtensor<[2,4,16,8],f32>, !torch.int -> !torch.vtensor<[2,4,16,8],f32>
%int8_141 = torch.constant.int 8
%int16_142 = torch.constant.int 16
%int8_143 = torch.constant.int 8
%83 = torch.prim.ListConstruct %int8_141, %int16_142, %int8_143 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%84 = torch.aten._unsafe_view %82, %83 : !torch.vtensor<[2,4,16,8],f32>, !torch.list<int> -> !torch.vtensor<[8,16,8],f32>
%85 = torch.aten.bmm %75, %84 : !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[8,16,8],f32> -> !torch.vtensor<[8,8,8],f32>
%int2_144 = torch.constant.int 2
%int4_145 = torch.constant.int 4
%int8_146 = torch.constant.int 8
%int8_147 = torch.constant.int 8
%86 = torch.prim.ListConstruct %int2_144, %int4_145, %int8_146, %int8_147 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%87 = torch.aten.view %85, %86 : !torch.vtensor<[8,8,8],f32>, !torch.list<int> -> !torch.vtensor<[2,4,8,8],f32>
%float4.000000e00 = torch.constant.float 4.000000e+00
%88 = torch.aten.div.Scalar %87, %float4.000000e00 : !torch.vtensor<[2,4,8,8],f32>, !torch.float -> !torch.vtensor<[2,4,8,8],f32>
%int1_148 = torch.constant.int 1
%89 = torch.aten.add.Tensor %88, %10, %int1_148 : !torch.vtensor<[2,4,8,8],f32>, !torch.vtensor<[1,1,8,8],f32>, !torch.int -> !torch.vtensor<[2,4,8,8],f32>
%int-1_149 = torch.constant.int -1
%false_150 = torch.constant.bool false
%90 = torch.aten._softmax %89, %int-1_149, %false_150 : !torch.vtensor<[2,4,8,8],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,4,8,8],f32>
%int2_151 = torch.constant.int 2
%int4_152 = torch.constant.int 4
%int8_153 = torch.constant.int 8
%int8_154 = torch.constant.int 8
%91 = torch.prim.ListConstruct %int2_151, %int4_152, %int8_153, %int8_154 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_155 = torch.constant.bool false
%92 = torch.aten.expand %90, %91, %false_155 : !torch.vtensor<[2,4,8,8],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,8],f32>
%int8_156 = torch.constant.int 8
%int8_157 = torch.constant.int 8
%int8_158 = torch.constant.int 8
%93 = torch.prim.ListConstruct %int8_156, %int8_157, %int8_158 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%94 = torch.aten.view %92, %93 : !torch.vtensor<[2,4,8,8],f32>, !torch.list<int> -> !torch.vtensor<[8,8,8],f32>
%int0_159 = torch.constant.int 0
%int0_160 = torch.constant.int 0
%int2_161 = torch.constant.int 2
%int1_162 = torch.constant.int 1
%95 = torch.aten.slice.Tensor %69, %int0_159, %int0_160, %int2_161, %int1_162 : !torch.vtensor<[32,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int1_163 = torch.constant.int 1
%int0_164 = torch.constant.int 0
%int8_165 = torch.constant.int 8
%int1_166 = torch.constant.int 1
%96 = torch.aten.slice.Tensor %95, %int1_163, %int0_164, %int8_165, %int1_166 : !torch.vtensor<[2,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,8,4,16],f32>
%int1_167 = torch.constant.int 1
%int2_168 = torch.constant.int 2
%97 = torch.aten.transpose.int %96, %int1_167, %int2_168 : !torch.vtensor<[2,8,4,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4,8,16],f32>
%int2_169 = torch.constant.int 2
%int4_170 = torch.constant.int 4
%int8_171 = torch.constant.int 8
%int16_172 = torch.constant.int 16
%98 = torch.prim.ListConstruct %int2_169, %int4_170, %int8_171, %int16_172 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_173 = torch.constant.bool false
%99 = torch.aten.expand %97, %98, %false_173 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>
%int0_174 = torch.constant.int 0
%100 = torch.aten.clone %99, %int0_174 : !torch.vtensor<[2,4,8,16],f32>, !torch.int -> !torch.vtensor<[2,4,8,16],f32>
%int8_175 = torch.constant.int 8
%int8_176 = torch.constant.int 8
%int16_177 = torch.constant.int 16
%101 = torch.prim.ListConstruct %int8_175, %int8_176, %int16_177 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%102 = torch.aten._unsafe_view %100, %101 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int> -> !torch.vtensor<[8,8,16],f32>
%103 = torch.aten.bmm %94, %102 : !torch.vtensor<[8,8,8],f32>, !torch.vtensor<[8,8,16],f32> -> !torch.vtensor<[8,8,16],f32>
%int2_178 = torch.constant.int 2
%int4_179 = torch.constant.int 4
%int8_180 = torch.constant.int 8
%int16_181 = torch.constant.int 16
%104 = torch.prim.ListConstruct %int2_178, %int4_179, %int8_180, %int16_181 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%105 = torch.aten.view %103, %104 : !torch.vtensor<[8,8,16],f32>, !torch.list<int> -> !torch.vtensor<[2,4,8,16],f32>
%int1_182 = torch.constant.int 1
%int2_183 = torch.constant.int 2
%106 = torch.aten.transpose.int %105, %int1_182, %int2_183 : !torch.vtensor<[2,4,8,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,8,4,16],f32>
%int0_184 = torch.constant.int 0
%107 = torch.aten.clone %106, %int0_184 : !torch.vtensor<[2,8,4,16],f32>, !torch.int -> !torch.vtensor<[2,8,4,16],f32>
%int2_185 = torch.constant.int 2
%int8_186 = torch.constant.int 8
%int-1_187 = torch.constant.int -1
%108 = torch.prim.ListConstruct %int2_185, %int8_186, %int-1_187 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%109 = torch.aten.view %107, %108 : !torch.vtensor<[2,8,4,16],f32>, !torch.list<int> -> !torch.vtensor<[2,8,64],f32>
%int0_188 = torch.constant.int 0
%int1_189 = torch.constant.int 1
%110 = torch.aten.transpose.int %arg9, %int0_188, %int1_189 : !torch.vtensor<[64,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,64],f32>
%int16_190 = torch.constant.int 16
%int64_191 = torch.constant.int 64
%111 = torch.prim.ListConstruct %int16_190, %int64_191 : (!torch.int, !torch.int) -> !torch.list<int>
%112 = torch.aten.view %109, %111 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[16,64],f32>
%113 = torch.aten.mm %112, %110 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[16,64],f32>
%int2_192 = torch.constant.int 2
%int8_193 = torch.constant.int 8
%int64_194 = torch.constant.int 64
%114 = torch.prim.ListConstruct %int2_192, %int8_193, %int64_194 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%115 = torch.aten.view %113, %114 : !torch.vtensor<[16,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,64],f32>
%int1_195 = torch.constant.int 1
%116 = torch.aten.add.Tensor %4, %115, %int1_195 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,64],f32>, !torch.int -> !torch.vtensor<[2,8,64],f32>
%int2_196 = torch.constant.int 2
%117 = torch.aten.pow.Tensor_Scalar %116, %int2_196 : !torch.vtensor<[2,8,64],f32>, !torch.int -> !torch.vtensor<[2,8,64],f32>
%int-1_197 = torch.constant.int -1
%118 = torch.prim.ListConstruct %int-1_197 : (!torch.int) -> !torch.list<int>
%true_198 = torch.constant.bool true
%none_199 = torch.constant.none
%119 = torch.aten.mean.dim %117, %118, %true_198, %none_199 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
%float1.000000e-05_200 = torch.constant.float 1.000000e-05
%int1_201 = torch.constant.int 1
%120 = torch.aten.add.Scalar %119, %float1.000000e-05_200, %int1_201 : !torch.vtensor<[2,8,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[2,8,1],f32>
%121 = torch.aten.rsqrt %120 : !torch.vtensor<[2,8,1],f32> -> !torch.vtensor<[2,8,1],f32>
%122 = torch.aten.mul.Tensor %116, %121 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32> -> !torch.vtensor<[2,8,64],f32>
%123 = torch.aten.mul.Tensor %122, %arg1 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[2,8,64],f32>
%int0_202 = torch.constant.int 0
%int1_203 = torch.constant.int 1
%124 = torch.aten.transpose.int %arg10, %int0_202, %int1_203 : !torch.vtensor<[256,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,256],f32>
%int16_204 = torch.constant.int 16
%int64_205 = torch.constant.int 64
%125 = torch.prim.ListConstruct %int16_204, %int64_205 : (!torch.int, !torch.int) -> !torch.list<int>
%126 = torch.aten.view %123, %125 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[16,64],f32>
%127 = torch.aten.mm %126, %124 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,256],f32> -> !torch.vtensor<[16,256],f32>
%int2_206 = torch.constant.int 2
%int8_207 = torch.constant.int 8
%int256 = torch.constant.int 256
%128 = torch.prim.ListConstruct %int2_206, %int8_207, %int256 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%129 = torch.aten.view %127, %128 : !torch.vtensor<[16,256],f32>, !torch.list<int> -> !torch.vtensor<[2,8,256],f32>
%130 = torch.aten.silu %129 : !torch.vtensor<[2,8,256],f32> -> !torch.vtensor<[2,8,256],f32>
%int0_208 = torch.constant.int 0
%int1_209 = torch.constant.int 1
%131 = torch.aten.transpose.int %arg11, %int0_208, %int1_209 : !torch.vtensor<[256,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,256],f32>
%int16_210 = torch.constant.int 16
%int64_211 = torch.constant.int 64
%132 = torch.prim.ListConstruct %int16_210, %int64_211 : (!torch.int, !torch.int) -> !torch.list<int>
%133 = torch.aten.view %123, %132 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[16,64],f32>
%134 = torch.aten.mm %133, %131 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,256],f32> -> !torch.vtensor<[16,256],f32>
%int2_212 = torch.constant.int 2
%int8_213 = torch.constant.int 8
%int256_214 = torch.constant.int 256
%135 = torch.prim.ListConstruct %int2_212, %int8_213, %int256_214 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%136 = torch.aten.view %134, %135 : !torch.vtensor<[16,256],f32>, !torch.list<int> -> !torch.vtensor<[2,8,256],f32>
%137 = torch.aten.mul.Tensor %130, %136 : !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,256],f32> -> !torch.vtensor<[2,8,256],f32>
%int0_215 = torch.constant.int 0
%int1_216 = torch.constant.int 1
%138 = torch.aten.transpose.int %arg12, %int0_215, %int1_216 : !torch.vtensor<[64,256],f32>, !torch.int, !torch.int -> !torch.vtensor<[256,64],f32>
%int16_217 = torch.constant.int 16
%int256_218 = torch.constant.int 256
%139 = torch.prim.ListConstruct %int16_217, %int256_218 : (!torch.int, !torch.int) -> !torch.list<int>
%140 = torch.aten.view %137, %139 : !torch.vtensor<[2,8,256],f32>, !torch.list<int> -> !torch.vtensor<[16,256],f32>
%141 = torch.aten.mm %140, %138 : !torch.vtensor<[16,256],f32>, !torch.vtensor<[256,64],f32> -> !torch.vtensor<[16,64],f32>
%int2_219 = torch.constant.int 2
%int8_220 = torch.constant.int 8
%int64_221 = torch.constant.int 64
%142 = torch.prim.ListConstruct %int2_219, %int8_220, %int64_221 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%143 = torch.aten.view %141, %142 : !torch.vtensor<[16,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,64],f32>
%int1_222 = torch.constant.int 1
%144 = torch.aten.add.Tensor %116, %143, %int1_222 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,64],f32>, !torch.int -> !torch.vtensor<[2,8,64],f32>
%int2_223 = torch.constant.int 2
%145 = torch.aten.pow.Tensor_Scalar %144, %int2_223 : !torch.vtensor<[2,8,64],f32>, !torch.int -> !torch.vtensor<[2,8,64],f32>
%int-1_224 = torch.constant.int -1
%146 = torch.prim.ListConstruct %int-1_224 : (!torch.int) -> !torch.list<int>
%true_225 = torch.constant.bool true
%none_226 = torch.constant.none
%147 = torch.aten.mean.dim %145, %146, %true_225, %none_226 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
%float1.000000e-05_227 = torch.constant.float 1.000000e-05
%int1_228 = torch.constant.int 1
%148 = torch.aten.add.Scalar %147, %float1.000000e-05_227, %int1_228 : !torch.vtensor<[2,8,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[2,8,1],f32>
%149 = torch.aten.rsqrt %148 : !torch.vtensor<[2,8,1],f32> -> !torch.vtensor<[2,8,1],f32>
%150 = torch.aten.mul.Tensor %144, %149 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32> -> !torch.vtensor<[2,8,64],f32>
%151 = torch.aten.mul.Tensor %150, %arg2 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[2,8,64],f32>
%int0_229 = torch.constant.int 0
%int1_230 = torch.constant.int 1
%152 = torch.aten.transpose.int %arg13, %int0_229, %int1_230 : !torch.vtensor<[64,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,64],f32>
%int16_231 = torch.constant.int 16
%int64_232 = torch.constant.int 64
%153 = torch.prim.ListConstruct %int16_231, %int64_232 : (!torch.int, !torch.int) -> !torch.list<int>
%154 = torch.aten.view %151, %153 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[16,64],f32>
%155 = torch.aten.mm %154, %152 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[16,64],f32>
%int2_233 = torch.constant.int 2
%int8_234 = torch.constant.int 8
%int64_235 = torch.constant.int 64
%156 = torch.prim.ListConstruct %int2_233, %int8_234, %int64_235 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%157 = torch.aten.view %155, %156 : !torch.vtensor<[16,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,64],f32>
%int0_236 = torch.constant.int 0
%int1_237 = torch.constant.int 1
%158 = torch.aten.transpose.int %arg14, %int0_236, %int1_237 : !torch.vtensor<[64,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,64],f32>
%int16_238 = torch.constant.int 16
%int64_239 = torch.constant.int 64
%159 = torch.prim.ListConstruct %int16_238, %int64_239 : (!torch.int, !torch.int) -> !torch.list<int>
%160 = torch.aten.view %151, %159 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[16,64],f32>
%161 = torch.aten.mm %160, %158 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[16,64],f32>
%int2_240 = torch.constant.int 2
%int8_241 = torch.constant.int 8
%int64_242 = torch.constant.int 64
%162 = torch.prim.ListConstruct %int2_240, %int8_241, %int64_242 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%163 = torch.aten.view %161, %162 : !torch.vtensor<[16,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,64],f32>
%int0_243 = torch.constant.int 0
%int1_244 = torch.constant.int 1
%164 = torch.aten.transpose.int %arg15, %int0_243, %int1_244 : !torch.vtensor<[64,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,64],f32>
%int16_245 = torch.constant.int 16
%int64_246 = torch.constant.int 64
%165 = torch.prim.ListConstruct %int16_245, %int64_246 : (!torch.int, !torch.int) -> !torch.list<int>
%166 = torch.aten.view %151, %165 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[16,64],f32>
%167 = torch.aten.mm %166, %164 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[16,64],f32>
%int2_247 = torch.constant.int 2
%int8_248 = torch.constant.int 8
%int64_249 = torch.constant.int 64
%168 = torch.prim.ListConstruct %int2_247, %int8_248, %int64_249 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%169 = torch.aten.view %167, %168 : !torch.vtensor<[16,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,64],f32>
%int2_250 = torch.constant.int 2
%int8_251 = torch.constant.int 8
%int4_252 = torch.constant.int 4
%int16_253 = torch.constant.int 16
%170 = torch.prim.ListConstruct %int2_250, %int8_251, %int4_252, %int16_253 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%171 = torch.aten.view %157, %170 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,16],f32>
%int2_254 = torch.constant.int 2
%int8_255 = torch.constant.int 8
%int4_256 = torch.constant.int 4
%int16_257 = torch.constant.int 16
%172 = torch.prim.ListConstruct %int2_254, %int8_255, %int4_256, %int16_257 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%173 = torch.aten.view %163, %172 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,16],f32>
%int2_258 = torch.constant.int 2
%int8_259 = torch.constant.int 8
%int4_260 = torch.constant.int 4
%int16_261 = torch.constant.int 16
%174 = torch.prim.ListConstruct %int2_258, %int8_259, %int4_260, %int16_261 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%175 = torch.aten.view %169, %174 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,16],f32>
%int2_262 = torch.constant.int 2
%int8_263 = torch.constant.int 8
%int4_264 = torch.constant.int 4
%int-1_265 = torch.constant.int -1
%int2_266 = torch.constant.int 2
%176 = torch.prim.ListConstruct %int2_262, %int8_263, %int4_264, %int-1_265, %int2_266 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%177 = torch.aten.view %171, %176 : !torch.vtensor<[2,8,4,16],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,8,2],f32>
%178 = torch.aten.view_as_complex %177 : !torch.vtensor<[2,8,4,8,2],f32> -> !torch.vtensor<[2,8,4,8],complex<f32>>
%int2_267 = torch.constant.int 2
%int8_268 = torch.constant.int 8
%int4_269 = torch.constant.int 4
%int-1_270 = torch.constant.int -1
%int2_271 = torch.constant.int 2
%179 = torch.prim.ListConstruct %int2_267, %int8_268, %int4_269, %int-1_270, %int2_271 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%180 = torch.aten.view %173, %179 : !torch.vtensor<[2,8,4,16],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,8,2],f32>
%181 = torch.aten.view_as_complex %180 : !torch.vtensor<[2,8,4,8,2],f32> -> !torch.vtensor<[2,8,4,8],complex<f32>>
%int1_272 = torch.constant.int 1
%int8_273 = torch.constant.int 8
%int1_274 = torch.constant.int 1
%int8_275 = torch.constant.int 8
%182 = torch.prim.ListConstruct %int1_272, %int8_273, %int1_274, %int8_275 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%183 = torch.aten.view %5, %182 : !torch.vtensor<[8,8],complex<f32>>, !torch.list<int> -> !torch.vtensor<[1,8,1,8],complex<f32>>
%184 = torch.aten.mul.Tensor %178, %183 : !torch.vtensor<[2,8,4,8],complex<f32>>, !torch.vtensor<[1,8,1,8],complex<f32>> -> !torch.vtensor<[2,8,4,8],complex<f32>>
%185 = torch.aten.view_as_real %184 : !torch.vtensor<[2,8,4,8],complex<f32>> -> !torch.vtensor<[2,8,4,8,2],f32>
%int2_276 = torch.constant.int 2
%int8_277 = torch.constant.int 8
%int4_278 = torch.constant.int 4
%int16_279 = torch.constant.int 16
%186 = torch.prim.ListConstruct %int2_276, %int8_277, %int4_278, %int16_279 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%187 = torch.aten.view %185, %186 : !torch.vtensor<[2,8,4,8,2],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,16],f32>
%188 = torch.aten.mul.Tensor %181, %183 : !torch.vtensor<[2,8,4,8],complex<f32>>, !torch.vtensor<[1,8,1,8],complex<f32>> -> !torch.vtensor<[2,8,4,8],complex<f32>>
%189 = torch.aten.view_as_real %188 : !torch.vtensor<[2,8,4,8],complex<f32>> -> !torch.vtensor<[2,8,4,8,2],f32>
%int2_280 = torch.constant.int 2
%int8_281 = torch.constant.int 8
%int4_282 = torch.constant.int 4
%int16_283 = torch.constant.int 16
%190 = torch.prim.ListConstruct %int2_280, %int8_281, %int4_282, %int16_283 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%191 = torch.aten.view %189, %190 : !torch.vtensor<[2,8,4,8,2],f32>, !torch.list<int> -> !torch.vtensor<[2,8,4,16],f32>
%int0_284 = torch.constant.int 0
%int0_285 = torch.constant.int 0
%int2_286 = torch.constant.int 2
%int1_287 = torch.constant.int 1
%192 = torch.aten.slice.Tensor %2, %int0_284, %int0_285, %int2_286, %int1_287 : !torch.vtensor<[32,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int1_288 = torch.constant.int 1
%int0_289 = torch.constant.int 0
%int8_290 = torch.constant.int 8
%int1_291 = torch.constant.int 1
%193 = torch.aten.slice.Tensor %192, %int1_288, %int0_289, %int8_290, %int1_291 : !torch.vtensor<[2,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,8,4,16],f32>
%false_292 = torch.constant.bool false
%194 = torch.aten.copy %193, %191, %false_292 : !torch.vtensor<[2,8,4,16],f32>, !torch.vtensor<[2,8,4,16],f32>, !torch.bool -> !torch.vtensor<[2,8,4,16],f32>
%int0_293 = torch.constant.int 0
%int0_294 = torch.constant.int 0
%int2_295 = torch.constant.int 2
%int1_296 = torch.constant.int 1
%195 = torch.aten.slice.Tensor %2, %int0_293, %int0_294, %int2_295, %int1_296 : !torch.vtensor<[32,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int1_297 = torch.constant.int 1
%int0_298 = torch.constant.int 0
%int8_299 = torch.constant.int 8
%int1_300 = torch.constant.int 1
%196 = torch.aten.slice_scatter %195, %194, %int1_297, %int0_298, %int8_299, %int1_300 : !torch.vtensor<[2,2048,4,16],f32>, !torch.vtensor<[2,8,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int0_301 = torch.constant.int 0
%int0_302 = torch.constant.int 0
%int2_303 = torch.constant.int 2
%int1_304 = torch.constant.int 1
%197 = torch.aten.slice_scatter %2, %196, %int0_301, %int0_302, %int2_303, %int1_304 : !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[2,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[32,2048,4,16],f32>
%int0_305 = torch.constant.int 0
%int0_306 = torch.constant.int 0
%int2_307 = torch.constant.int 2
%int1_308 = torch.constant.int 1
%198 = torch.aten.slice.Tensor %3, %int0_305, %int0_306, %int2_307, %int1_308 : !torch.vtensor<[32,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int1_309 = torch.constant.int 1
%int0_310 = torch.constant.int 0
%int8_311 = torch.constant.int 8
%int1_312 = torch.constant.int 1
%199 = torch.aten.slice.Tensor %198, %int1_309, %int0_310, %int8_311, %int1_312 : !torch.vtensor<[2,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,8,4,16],f32>
%false_313 = torch.constant.bool false
%200 = torch.aten.copy %199, %175, %false_313 : !torch.vtensor<[2,8,4,16],f32>, !torch.vtensor<[2,8,4,16],f32>, !torch.bool -> !torch.vtensor<[2,8,4,16],f32>
%int0_314 = torch.constant.int 0
%int0_315 = torch.constant.int 0
%int2_316 = torch.constant.int 2
%int1_317 = torch.constant.int 1
%201 = torch.aten.slice.Tensor %3, %int0_314, %int0_315, %int2_316, %int1_317 : !torch.vtensor<[32,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int1_318 = torch.constant.int 1
%int0_319 = torch.constant.int 0
%int8_320 = torch.constant.int 8
%int1_321 = torch.constant.int 1
%202 = torch.aten.slice_scatter %201, %200, %int1_318, %int0_319, %int8_320, %int1_321 : !torch.vtensor<[2,2048,4,16],f32>, !torch.vtensor<[2,8,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int0_322 = torch.constant.int 0
%int0_323 = torch.constant.int 0
%int2_324 = torch.constant.int 2
%int1_325 = torch.constant.int 1
%203 = torch.aten.slice_scatter %3, %202, %int0_322, %int0_323, %int2_324, %int1_325 : !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[2,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[32,2048,4,16],f32>
%int1_326 = torch.constant.int 1
%int2_327 = torch.constant.int 2
%204 = torch.aten.transpose.int %187, %int1_326, %int2_327 : !torch.vtensor<[2,8,4,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4,8,16],f32>
%int2_328 = torch.constant.int 2
%int4_329 = torch.constant.int 4
%int8_330 = torch.constant.int 8
%int16_331 = torch.constant.int 16
%205 = torch.prim.ListConstruct %int2_328, %int4_329, %int8_330, %int16_331 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_332 = torch.constant.bool false
%206 = torch.aten.expand %204, %205, %false_332 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>
%int0_333 = torch.constant.int 0
%207 = torch.aten.clone %206, %int0_333 : !torch.vtensor<[2,4,8,16],f32>, !torch.int -> !torch.vtensor<[2,4,8,16],f32>
%int8_334 = torch.constant.int 8
%int8_335 = torch.constant.int 8
%int16_336 = torch.constant.int 16
%208 = torch.prim.ListConstruct %int8_334, %int8_335, %int16_336 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%209 = torch.aten._unsafe_view %207, %208 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int> -> !torch.vtensor<[8,8,16],f32>
%int0_337 = torch.constant.int 0
%int0_338 = torch.constant.int 0
%int2_339 = torch.constant.int 2
%int1_340 = torch.constant.int 1
%210 = torch.aten.slice.Tensor %197, %int0_337, %int0_338, %int2_339, %int1_340 : !torch.vtensor<[32,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int1_341 = torch.constant.int 1
%int0_342 = torch.constant.int 0
%int8_343 = torch.constant.int 8
%int1_344 = torch.constant.int 1
%211 = torch.aten.slice.Tensor %210, %int1_341, %int0_342, %int8_343, %int1_344 : !torch.vtensor<[2,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,8,4,16],f32>
%int1_345 = torch.constant.int 1
%int2_346 = torch.constant.int 2
%212 = torch.aten.transpose.int %211, %int1_345, %int2_346 : !torch.vtensor<[2,8,4,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4,8,16],f32>
%int2_347 = torch.constant.int 2
%int3_348 = torch.constant.int 3
%213 = torch.aten.transpose.int %212, %int2_347, %int3_348 : !torch.vtensor<[2,4,8,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4,16,8],f32>
%int2_349 = torch.constant.int 2
%int4_350 = torch.constant.int 4
%int16_351 = torch.constant.int 16
%int8_352 = torch.constant.int 8
%214 = torch.prim.ListConstruct %int2_349, %int4_350, %int16_351, %int8_352 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_353 = torch.constant.bool false
%215 = torch.aten.expand %213, %214, %false_353 : !torch.vtensor<[2,4,16,8],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,16,8],f32>
%int0_354 = torch.constant.int 0
%216 = torch.aten.clone %215, %int0_354 : !torch.vtensor<[2,4,16,8],f32>, !torch.int -> !torch.vtensor<[2,4,16,8],f32>
%int8_355 = torch.constant.int 8
%int16_356 = torch.constant.int 16
%int8_357 = torch.constant.int 8
%217 = torch.prim.ListConstruct %int8_355, %int16_356, %int8_357 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%218 = torch.aten._unsafe_view %216, %217 : !torch.vtensor<[2,4,16,8],f32>, !torch.list<int> -> !torch.vtensor<[8,16,8],f32>
%219 = torch.aten.bmm %209, %218 : !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[8,16,8],f32> -> !torch.vtensor<[8,8,8],f32>
%int2_358 = torch.constant.int 2
%int4_359 = torch.constant.int 4
%int8_360 = torch.constant.int 8
%int8_361 = torch.constant.int 8
%220 = torch.prim.ListConstruct %int2_358, %int4_359, %int8_360, %int8_361 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%221 = torch.aten.view %219, %220 : !torch.vtensor<[8,8,8],f32>, !torch.list<int> -> !torch.vtensor<[2,4,8,8],f32>
%float4.000000e00_362 = torch.constant.float 4.000000e+00
%222 = torch.aten.div.Scalar %221, %float4.000000e00_362 : !torch.vtensor<[2,4,8,8],f32>, !torch.float -> !torch.vtensor<[2,4,8,8],f32>
%int1_363 = torch.constant.int 1
%223 = torch.aten.add.Tensor %222, %10, %int1_363 : !torch.vtensor<[2,4,8,8],f32>, !torch.vtensor<[1,1,8,8],f32>, !torch.int -> !torch.vtensor<[2,4,8,8],f32>
%int-1_364 = torch.constant.int -1
%false_365 = torch.constant.bool false
%224 = torch.aten._softmax %223, %int-1_364, %false_365 : !torch.vtensor<[2,4,8,8],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,4,8,8],f32>
%int2_366 = torch.constant.int 2
%int4_367 = torch.constant.int 4
%int8_368 = torch.constant.int 8
%int8_369 = torch.constant.int 8
%225 = torch.prim.ListConstruct %int2_366, %int4_367, %int8_368, %int8_369 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_370 = torch.constant.bool false
%226 = torch.aten.expand %224, %225, %false_370 : !torch.vtensor<[2,4,8,8],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,8],f32>
%int8_371 = torch.constant.int 8
%int8_372 = torch.constant.int 8
%int8_373 = torch.constant.int 8
%227 = torch.prim.ListConstruct %int8_371, %int8_372, %int8_373 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%228 = torch.aten.view %226, %227 : !torch.vtensor<[2,4,8,8],f32>, !torch.list<int> -> !torch.vtensor<[8,8,8],f32>
%int0_374 = torch.constant.int 0
%int0_375 = torch.constant.int 0
%int2_376 = torch.constant.int 2
%int1_377 = torch.constant.int 1
%229 = torch.aten.slice.Tensor %203, %int0_374, %int0_375, %int2_376, %int1_377 : !torch.vtensor<[32,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2048,4,16],f32>
%int1_378 = torch.constant.int 1
%int0_379 = torch.constant.int 0
%int8_380 = torch.constant.int 8
%int1_381 = torch.constant.int 1
%230 = torch.aten.slice.Tensor %229, %int1_378, %int0_379, %int8_380, %int1_381 : !torch.vtensor<[2,2048,4,16],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,8,4,16],f32>
%int1_382 = torch.constant.int 1
%int2_383 = torch.constant.int 2
%231 = torch.aten.transpose.int %230, %int1_382, %int2_383 : !torch.vtensor<[2,8,4,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4,8,16],f32>
%int2_384 = torch.constant.int 2
%int4_385 = torch.constant.int 4
%int8_386 = torch.constant.int 8
%int16_387 = torch.constant.int 16
%232 = torch.prim.ListConstruct %int2_384, %int4_385, %int8_386, %int16_387 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%false_388 = torch.constant.bool false
%233 = torch.aten.expand %231, %232, %false_388 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>
%int0_389 = torch.constant.int 0
%234 = torch.aten.clone %233, %int0_389 : !torch.vtensor<[2,4,8,16],f32>, !torch.int -> !torch.vtensor<[2,4,8,16],f32>
%int8_390 = torch.constant.int 8
%int8_391 = torch.constant.int 8
%int16_392 = torch.constant.int 16
%235 = torch.prim.ListConstruct %int8_390, %int8_391, %int16_392 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%236 = torch.aten._unsafe_view %234, %235 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int> -> !torch.vtensor<[8,8,16],f32>
%237 = torch.aten.bmm %228, %236 : !torch.vtensor<[8,8,8],f32>, !torch.vtensor<[8,8,16],f32> -> !torch.vtensor<[8,8,16],f32>
%int2_393 = torch.constant.int 2
%int4_394 = torch.constant.int 4
%int8_395 = torch.constant.int 8
%int16_396 = torch.constant.int 16
%238 = torch.prim.ListConstruct %int2_393, %int4_394, %int8_395, %int16_396 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%239 = torch.aten.view %237, %238 : !torch.vtensor<[8,8,16],f32>, !torch.list<int> -> !torch.vtensor<[2,4,8,16],f32>
%int1_397 = torch.constant.int 1
%int2_398 = torch.constant.int 2
%240 = torch.aten.transpose.int %239, %int1_397, %int2_398 : !torch.vtensor<[2,4,8,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,8,4,16],f32>
%int0_399 = torch.constant.int 0
%241 = torch.aten.clone %240, %int0_399 : !torch.vtensor<[2,8,4,16],f32>, !torch.int -> !torch.vtensor<[2,8,4,16],f32>
%int2_400 = torch.constant.int 2
%int8_401 = torch.constant.int 8
%int-1_402 = torch.constant.int -1
%242 = torch.prim.ListConstruct %int2_400, %int8_401, %int-1_402 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%243 = torch.aten.view %241, %242 : !torch.vtensor<[2,8,4,16],f32>, !torch.list<int> -> !torch.vtensor<[2,8,64],f32>
%int0_403 = torch.constant.int 0
%int1_404 = torch.constant.int 1
%244 = torch.aten.transpose.int %arg16, %int0_403, %int1_404 : !torch.vtensor<[64,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,64],f32>
%int16_405 = torch.constant.int 16
%int64_406 = torch.constant.int 64
%245 = torch.prim.ListConstruct %int16_405, %int64_406 : (!torch.int, !torch.int) -> !torch.list<int>
%246 = torch.aten.view %243, %245 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[16,64],f32>
%247 = torch.aten.mm %246, %244 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[16,64],f32>
%int2_407 = torch.constant.int 2
%int8_408 = torch.constant.int 8
%int64_409 = torch.constant.int 64
%248 = torch.prim.ListConstruct %int2_407, %int8_408, %int64_409 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%249 = torch.aten.view %247, %248 : !torch.vtensor<[16,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,64],f32>
%int1_410 = torch.constant.int 1
%250 = torch.aten.add.Tensor %144, %249, %int1_410 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,64],f32>, !torch.int -> !torch.vtensor<[2,8,64],f32>
%int2_411 = torch.constant.int 2
%251 = torch.aten.pow.Tensor_Scalar %250, %int2_411 : !torch.vtensor<[2,8,64],f32>, !torch.int -> !torch.vtensor<[2,8,64],f32>
%int-1_412 = torch.constant.int -1
%252 = torch.prim.ListConstruct %int-1_412 : (!torch.int) -> !torch.list<int>
%true_413 = torch.constant.bool true
%none_414 = torch.constant.none
%253 = torch.aten.mean.dim %251, %252, %true_413, %none_414 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
%float1.000000e-05_415 = torch.constant.float 1.000000e-05
%int1_416 = torch.constant.int 1
%254 = torch.aten.add.Scalar %253, %float1.000000e-05_415, %int1_416 : !torch.vtensor<[2,8,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[2,8,1],f32>
%255 = torch.aten.rsqrt %254 : !torch.vtensor<[2,8,1],f32> -> !torch.vtensor<[2,8,1],f32>
%256 = torch.aten.mul.Tensor %250, %255 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32> -> !torch.vtensor<[2,8,64],f32>
%257 = torch.aten.mul.Tensor %256, %arg3 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[2,8,64],f32>
%int0_417 = torch.constant.int 0
%int1_418 = torch.constant.int 1
%258 = torch.aten.transpose.int %arg17, %int0_417, %int1_418 : !torch.vtensor<[256,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,256],f32>
%int16_419 = torch.constant.int 16
%int64_420 = torch.constant.int 64
%259 = torch.prim.ListConstruct %int16_419, %int64_420 : (!torch.int, !torch.int) -> !torch.list<int>
%260 = torch.aten.view %257, %259 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[16,64],f32>
%261 = torch.aten.mm %260, %258 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,256],f32> -> !torch.vtensor<[16,256],f32>
%int2_421 = torch.constant.int 2
%int8_422 = torch.constant.int 8
%int256_423 = torch.constant.int 256
%262 = torch.prim.ListConstruct %int2_421, %int8_422, %int256_423 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%263 = torch.aten.view %261, %262 : !torch.vtensor<[16,256],f32>, !torch.list<int> -> !torch.vtensor<[2,8,256],f32>
%264 = torch.aten.silu %263 : !torch.vtensor<[2,8,256],f32> -> !torch.vtensor<[2,8,256],f32>
%int0_424 = torch.constant.int 0
%int1_425 = torch.constant.int 1
%265 = torch.aten.transpose.int %arg18, %int0_424, %int1_425 : !torch.vtensor<[256,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,256],f32>
%int16_426 = torch.constant.int 16
%int64_427 = torch.constant.int 64
%266 = torch.prim.ListConstruct %int16_426, %int64_427 : (!torch.int, !torch.int) -> !torch.list<int>
%267 = torch.aten.view %257, %266 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[16,64],f32>
%268 = torch.aten.mm %267, %265 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,256],f32> -> !torch.vtensor<[16,256],f32>
%int2_428 = torch.constant.int 2
%int8_429 = torch.constant.int 8
%int256_430 = torch.constant.int 256
%269 = torch.prim.ListConstruct %int2_428, %int8_429, %int256_430 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%270 = torch.aten.view %268, %269 : !torch.vtensor<[16,256],f32>, !torch.list<int> -> !torch.vtensor<[2,8,256],f32>
%271 = torch.aten.mul.Tensor %264, %270 : !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,256],f32> -> !torch.vtensor<[2,8,256],f32>
%int0_431 = torch.constant.int 0
%int1_432 = torch.constant.int 1
%272 = torch.aten.transpose.int %arg19, %int0_431, %int1_432 : !torch.vtensor<[64,256],f32>, !torch.int, !torch.int -> !torch.vtensor<[256,64],f32>
%int16_433 = torch.constant.int 16
%int256_434 = torch.constant.int 256
%273 = torch.prim.ListConstruct %int16_433, %int256_434 : (!torch.int, !torch.int) -> !torch.list<int>
%274 = torch.aten.view %271, %273 : !torch.vtensor<[2,8,256],f32>, !torch.list<int> -> !torch.vtensor<[16,256],f32>
%275 = torch.aten.mm %274, %272 : !torch.vtensor<[16,256],f32>, !torch.vtensor<[256,64],f32> -> !torch.vtensor<[16,64],f32>
%int2_435 = torch.constant.int 2
%int8_436 = torch.constant.int 8
%int64_437 = torch.constant.int 64
%276 = torch.prim.ListConstruct %int2_435, %int8_436, %int64_437 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%277 = torch.aten.view %275, %276 : !torch.vtensor<[16,64],f32>, !torch.list<int> -> !torch.vtensor<[2,8,64],f32>
%int1_438 = torch.constant.int 1
%278 = torch.aten.add.Tensor %250, %277, %int1_438 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,64],f32>, !torch.int -> !torch.vtensor<[2,8,64],f32>
%int2_439 = torch.constant.int 2
%279 = torch.aten.pow.Tensor_Scalar %278, %int2_439 : !torch.vtensor<[2,8,64],f32>, !torch.int -> !torch.vtensor<[2,8,64],f32>
%int-1_440 = torch.constant.int -1
%280 = torch.prim.ListConstruct %int-1_440 : (!torch.int) -> !torch.list<int>
%true_441 = torch.constant.bool true
%none_442 = torch.constant.none
%281 = torch.aten.mean.dim %279, %280, %true_441, %none_442 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
%float1.000000e-05_443 = torch.constant.float 1.000000e-05
%int1_444 = torch.constant.int 1
%282 = torch.aten.add.Scalar %281, %float1.000000e-05_443, %int1_444 : !torch.vtensor<[2,8,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[2,8,1],f32>
%283 = torch.aten.rsqrt %282 : !torch.vtensor<[2,8,1],f32> -> !torch.vtensor<[2,8,1],f32>
%284 = torch.aten.mul.Tensor %278, %283 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32> -> !torch.vtensor<[2,8,64],f32>
%285 = torch.aten.mul.Tensor %284, %arg4 : !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[2,8,64],f32>
%int0_445 = torch.constant.int 0
%int1_446 = torch.constant.int 1
%286 = torch.aten.transpose.int %arg20, %int0_445, %int1_446 : !torch.vtensor<[16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,16],f32>
%int16_447 = torch.constant.int 16
%int64_448 = torch.constant.int 64
%287 = torch.prim.ListConstruct %int16_447, %int64_448 : (!torch.int, !torch.int) -> !torch.list<int>
%288 = torch.aten.view %285, %287 : !torch.vtensor<[2,8,64],f32>, !torch.list<int> -> !torch.vtensor<[16,64],f32>
%289 = torch.aten.mm %288, %286 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,16],f32> -> !torch.vtensor<[16,16],f32>
%int2_449 = torch.constant.int 2
%int8_450 = torch.constant.int 8
%int16_451 = torch.constant.int 16
%290 = torch.prim.ListConstruct %int2_449, %int8_450, %int16_451 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%291 = torch.aten.view %289, %290 : !torch.vtensor<[16,16],f32>, !torch.list<int> -> !torch.vtensor<[2,8,16],f32>
return %63, %69, %197, %203, %291, %arg0, %arg1, %arg2, %arg3, %arg4, %arg26, %4, %15, %16, %18, %20, %24, %26, %30, %32, %49, %75, %84, %90, %94, %102, %110, %112, %116, %121, %122, %124, %126, %129, %130, %131, %133, %136, %138, %140, %144, %149, %150, %152, %154, %158, %160, %164, %166, %183, %209, %218, %224, %228, %236, %244, %246, %250, %255, %256, %258, %260, %263, %264, %265, %267, %270, %272, %274, %278, %283, %284, %286, %288 : !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[32,2048,4,16],f32>, !torch.vtensor<[2,8,16],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[2,8],si64>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[1,8,1,8],complex<f32>>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[8,16,8],f32>, !torch.vtensor<[2,4,8,8],f32>, !torch.vtensor<[8,8,8],f32>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[256,64],f32>, !torch.vtensor<[16,256],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[1,8,1,8],complex<f32>>, !torch.vtensor<[8,8,16],f32>, 
!torch.vtensor<[8,16,8],f32>, !torch.vtensor<[2,4,8,8],f32>, !torch.vtensor<[8,8,8],f32>, !torch.vtensor<[8,8,16],f32>, !torch.vtensor<[64,64],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[64,256],f32>, !torch.vtensor<[16,64],f32>, !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[256,64],f32>, !torch.vtensor<[16,256],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,64],f32>, !torch.vtensor<[64,16],f32>, !torch.vtensor<[16,64],f32>
}
}
@AmosLewis
Copy link
Author

AmosLewis commented Dec 5, 2023

Printed output of the failure: https://gist.github.com/AmosLewis/259cd141333b33ec7df8a2f60eb9bf4b. The failing ops, with the IR lines where they occur, are listed below.

  • torch.aten.empty_strided
    %67 = torch.aten.empty_strided %65, %66, %int6, %int0_62, %cpu, %false_63 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,1,8,8],f32>

  • torch.aten.mean.dim
    %80 = torch.aten.mean.dim %78, %79, %true, %none_86 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
    %712 = torch.aten.mean.dim %710, %711, %true_1098, %none_1099 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
    %795 = torch.aten.mean.dim %793, %794, %true_1275, %none_1276 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
    %1427 = torch.aten.mean.dim %1425, %1426, %true_2298, %none_2299 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>
    %1510 = torch.aten.mean.dim %1508, %1509, %true_2477, %none_2478 : !torch.vtensor<[2,8,64],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f32>

  • torch.aten.expand
    %489 = torch.aten.expand %487, %488, %false_710 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>
    %557 = torch.aten.expand %555, %556, %false_808 : !torch.vtensor<[2,4,16,8],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,16,8],f32>
    %590 = torch.aten.expand %588, %589, %false_879 : !torch.vtensor<[2,4,8,8],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,8],f32>
    %655 = torch.aten.expand %653, %654, %false_965 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>
    %1204 = torch.aten.expand %1202, %1203, %false_1908 : !torch.vtensor<[2,4,8,16],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>
    %1272 = torch.aten.expand %1270, %1271, %false_2007 : !torch.vtensor<[2,4,16,8],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,16,8],f32>
    %1305 = torch.aten.expand %1303, %1304, %false_2079 : !torch.vtensor<[2,4,8,8],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,8,8],f32>

  • torch.aten._softmax
    %588 = torch.aten._softmax %587, %int-1_873, %false_874 : !torch.vtensor<[2,4,8,8],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,4,8,8],f32>
    %1303 = torch.aten._softmax %1302, %int-1_2073, %false_2074 : !torch.vtensor<[2,4,8,8],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,4,8,8],f32>

  • torch.aten.silu
    %745 = torch.aten.silu %744 : !torch.vtensor<[2,8,256],f32> -> !torch.vtensor<[2,8,256],f32>
    %1460 = torch.aten.silu %1459 : !torch.vtensor<[2,8,256],f32> -> !torch.vtensor<[2,8,256],f32>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment