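// Torch-MLIR (torch dialect) dump of a torch.fx GraphModule named "_lambda", with all
// transformer blocks inlined. The shapes and constants below (12 heads x 64 = 768 hidden,
// 3072 MLP, 1024-position causal mask, tanh-GELU) are consistent with a GPT-2-small-style decoder.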
#loc = loc(unknown)
module attributes {torch.debug_module_name = "_lambda"} {
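  // Generated accessor returning the GraphModule's "_code" attribute (its Python source string).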
func.func private @__torch__.torch.fx.graph_module._lambda.__code_getter(%arg0: !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> loc(unknown)) -> !torch.str {
%96 = torch.prim.GetAttr %arg0["_code"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.str loc(#loc)
return %96 : !torch.str loc(#loc)
} loc(#loc)
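  // forward: takes the module plus a [1,128] int64 token-id tensor; every scalar constant
  // used by the inlined blocks is materialized up front before the first op.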
func.func private @__torch__.torch.fx.graph_module._lambda.forward(%arg0: !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> loc(unknown), %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,128],si64>} loc(unknown)) -> !torch.tensor {
%true_0 = torch.constant.bool true loc(#loc1)
%int11 = torch.constant.int 11 loc(#loc2)
%int-2 = torch.constant.int -2 loc(#loc3)
%none_1 = torch.constant.none loc(#loc)
%false = torch.constant.bool false loc(#loc4)
%cpu = torch.constant.device "cpu" loc(#loc)
%int4 = torch.constant.int 4 loc(#loc5)
%int-1 = torch.constant.int -1 loc(#loc6)
%int1 = torch.constant.int 1 loc(#loc7)
%int128 = torch.constant.int 128 loc(#loc8)
%int0 = torch.constant.int 0 loc(#loc9)
%int768 = torch.constant.int 768 loc(#loc10)
%float1.000000e-05 = torch.constant.float 1.000000e-05 loc(#loc11)
%int2 = torch.constant.int 2 loc(#loc12)
%int2304 = torch.constant.int 2304 loc(#loc13)
%int294912 = torch.constant.int 294912 loc(#loc14)
%int1536 = torch.constant.int 1536 loc(#loc15)
%int12 = torch.constant.int 12 loc(#loc16)
%int64 = torch.constant.int 64 loc(#loc17)
%int3 = torch.constant.int 3 loc(#loc18)
%int1024 = torch.constant.int 1024 loc(#loc19)
%int1048576 = torch.constant.int 1048576 loc(#loc20)
%int3072 = torch.constant.int 3072 loc(#loc21)
%float5.000000e-01 = torch.constant.float 5.000000e-01 loc(#loc22)
%float3.000000e00 = torch.constant.float 3.000000e+00 loc(#loc23)
%float4.471500e-02 = torch.constant.float 4.471500e-02 loc(#loc24)
%float7.978850e-01 = torch.constant.float 0.79788456080286541 loc(#loc25)
%float1.000000e00 = torch.constant.float 1.000000e+00 loc(#loc26)
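    // Embeddings: view the [1,128] ids to [-1,128], build int64 position ids 0..127 via
    // arange + unsqueeze, look up token (_param_constant0) and position (_param_constant1)
    // embeddings, and add them.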
%96 = torch.prim.ListConstruct %int-1, %int128 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%97 = torch.aten.view %arg1, %96 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc27)
%98 = torch.aten.arange.start %int0, %int128, %int4, %none_1, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.tensor loc(#loc28)
%99 = torch.aten.unsqueeze %98, %int0 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc29)
%100 = torch.prim.ListConstruct %int-1, %int128 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%101 = torch.aten.view %99, %100 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc30)
%102 = torch.prim.GetAttr %arg0["_param_constant0"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%103 = torch.aten.embedding %102, %97, %int-1, %false, %false : !torch.tensor, !torch.tensor, !torch.int, !torch.bool, !torch.bool -> !torch.tensor loc(#loc31)
%104 = torch.prim.GetAttr %arg0["_param_constant1"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%105 = torch.aten.embedding %104, %101, %int-1, %false, %false : !torch.tensor, !torch.tensor, !torch.int, !torch.bool, !torch.bool -> !torch.tensor loc(#loc32)
%106 = torch.aten.add.Tensor %103, %105, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc33)
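    // Block 0: ln_1 over the 768-dim features (eps 1e-5), then a fused QKV addmm to
    // [1,128,2304]; as_strided splits Q/K/V at element offsets 0, 768, 1536 (row stride 2304).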
%107 = torch.prim.GetAttr %arg0["_param_constant2"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%108 = torch.prim.GetAttr %arg0["_param_constant3"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%109 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list<int> loc(#loc)
%result0, %result1, %result2 = torch.aten.native_layer_norm %106, %109, %107, %108, %float1.000000e-05 : !torch.tensor, !torch.list<int>, !torch.tensor, !torch.tensor, !torch.float -> !torch.tensor, !torch.tensor, !torch.tensor loc(#loc34)
%110 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%111 = torch.aten.view %result0, %110 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc35)
%112 = torch.prim.GetAttr %arg0["_param_constant4"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%113 = torch.prim.GetAttr %arg0["_param_constant5"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%114 = torch.aten.addmm %112, %111, %113, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc36)
%115 = torch.prim.ListConstruct %int1, %int128, %int2304 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%116 = torch.aten.view %114, %115 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc37)
%117 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%118 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%119 = torch.aten.as_strided %116, %117, %118, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc38)
%120 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%121 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%122 = torch.aten.as_strided %116, %120, %121, %int768 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc39)
%123 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%124 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%125 = torch.aten.as_strided %116, %123, %124, %int1536 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc40)
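    // Reshape each of Q/K/V to [1,128,12,64] and permute to [1,12,128,64] (batch, heads, seq, head_dim).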
%126 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%127 = torch.aten.view %119, %126 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc41)
%128 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%129 = torch.aten.permute %127, %128 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc42)
%130 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%131 = torch.aten.view %122, %130 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc43)
%132 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%133 = torch.aten.permute %131, %132 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc44)
%134 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%135 = torch.aten.view %125, %134 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc45)
%136 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%137 = torch.aten.permute %135, %136 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc46)
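    // Scores: K is transposed to [1,12,64,128]; expand/view flatten batch and head dims so
    // bmm computes Q @ K^T -> [12,128,128], un-flattened back to [1,12,128,128].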
%138 = torch.aten.transpose.int %133, %int-1, %int-2 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc47)
%139 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%140 = torch.aten.expand %129, %139, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc48)
%141 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%142 = torch.aten.view %140, %141 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc49)
%143 = torch.prim.ListConstruct %int1, %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%144 = torch.aten.expand %138, %143, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc50)
%145 = torch.prim.ListConstruct %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%146 = torch.aten.view %144, %145 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc51)
%147 = torch.aten.bmm %142, %146 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc52)
%148 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%149 = torch.aten._unsafe_view %147, %148 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc53)
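    // Scale scores by _tensor_constant0 (presumably the sqrt(head_dim) = 8 divisor), slice the
    // 1024x1024 causal bias down to [1,1,128,128] via as_strided, cast it to bool (dtype 11),
    // and where-select between the scores and the _tensor_constant2 masked-fill value.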
%150 = torch.prim.GetAttr %arg0["_tensor_constant0"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%151 = torch.aten.lift_fresh_copy %150 : !torch.tensor -> !torch.tensor loc(#loc54)
%152 = torch.aten.div.Tensor %149, %151 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc55)
%153 = torch.prim.GetAttr %arg0["_tensor_constant1"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%154 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%155 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%156 = torch.aten.as_strided %153, %154, %155, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc56)
%157 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%158 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%159 = torch.aten.as_strided %156, %157, %158, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc57)
%160 = torch.prim.ListConstruct %int1, %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%161 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%162 = torch.aten.as_strided %159, %160, %161, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc58)
%163 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%164 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%165 = torch.aten.as_strided %162, %163, %164, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc59)
%166 = torch.prims.convert_element_type %165, %int11 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc60)
%167 = torch.prim.GetAttr %arg0["_tensor_constant2"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%168 = torch.aten.lift_fresh_copy %167 : !torch.tensor -> !torch.tensor loc(#loc61)
%169 = torch.aten.where.self %166, %152, %168 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc62)
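    // Numerically stable softmax over the last dim: max-subtract, exp, sum, divide.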
%170 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> loc(#loc)
%171 = torch.aten.amax %169, %170, %true_0 : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc63)
%172 = torch.aten.sub.Tensor %169, %171, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc64)
%173 = torch.aten.exp %172 : !torch.tensor -> !torch.tensor loc(#loc65)
%174 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> loc(#loc)
%175 = torch.aten.sum.dim_IntList %173, %174, %true_0, %none_1 : !torch.tensor, !torch.list<int>, !torch.bool, !torch.none -> !torch.tensor loc(#loc66)
%176 = torch.aten.div.Tensor %173, %175 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc67)
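    // Weighted sum: probs @ V -> [1,12,128,64]; permute back to [1,128,12,64], clone
    // contiguous (memory_format 0), then flatten to [1,128,768] and [-1,768] for the projection.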
%177 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%178 = torch.aten.expand %176, %177, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc68)
%179 = torch.prim.ListConstruct %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%180 = torch.aten.view %178, %179 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc69)
%181 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%182 = torch.aten.expand %137, %181, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc70)
%183 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%184 = torch.aten.view %182, %183 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc71)
%185 = torch.aten.bmm %180, %184 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc72)
%186 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%187 = torch.aten._unsafe_view %185, %186 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc73)
%188 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%189 = torch.aten.permute %187, %188 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc74)
%190 = torch.aten.clone %189, %int0 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc75)
%191 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%192 = torch.aten.view %190, %191 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc76)
%193 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%194 = torch.aten.view %192, %193 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc77)
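    // Attention output projection (addmm with _param_constant6/7) and residual add back onto
    // the embedding output %106.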
%195 = torch.prim.GetAttr %arg0["_param_constant6"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%196 = torch.prim.GetAttr %arg0["_param_constant7"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%197 = torch.aten.addmm %195, %194, %196, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc78)
%198 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%199 = torch.aten.view %197, %198 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc79)
%200 = torch.aten.add.Tensor %199, %106, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc80)
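    // ln_2 (_param_constant8/9), then the MLP up-projection to [1,128,3072].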
%201 = torch.prim.GetAttr %arg0["_param_constant8"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%202 = torch.prim.GetAttr %arg0["_param_constant9"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%203 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list<int> loc(#loc)
%result0_2, %result1_3, %result2_4 = torch.aten.native_layer_norm %200, %203, %201, %202, %float1.000000e-05 : !torch.tensor, !torch.list<int>, !torch.tensor, !torch.tensor, !torch.float -> !torch.tensor, !torch.tensor, !torch.tensor loc(#loc81)
%204 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%205 = torch.aten.view %result0_2, %204 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc82)
%206 = torch.prim.GetAttr %arg0["_param_constant10"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%207 = torch.prim.GetAttr %arg0["_param_constant11"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%208 = torch.aten.addmm %206, %205, %207, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc83)
%209 = torch.prim.ListConstruct %int1, %int128, %int3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%210 = torch.aten.view %208, %209 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc84)
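    // tanh-approximate GELU: 0.5 * x * (1 + tanh(0.79788456... * (x + 0.044715 * x^3))),
    // where 0.79788456... ~= sqrt(2/pi).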
%211 = torch.aten.mul.Scalar %210, %float5.000000e-01 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc85)
%212 = torch.aten.pow.Tensor_Scalar %210, %float3.000000e00 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc86)
%213 = torch.aten.mul.Scalar %212, %float4.471500e-02 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc87)
%214 = torch.aten.add.Tensor %210, %213, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc88)
%215 = torch.aten.mul.Scalar %214, %float7.978850e-01 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc89)
%216 = torch.aten.tanh %215 : !torch.tensor -> !torch.tensor loc(#loc90)
%217 = torch.aten.add.Scalar %216, %float1.000000e00, %int1 : !torch.tensor, !torch.float, !torch.int -> !torch.tensor loc(#loc91)
%218 = torch.aten.mul.Tensor %211, %217 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc92)
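    // MLP down-projection back to [1,128,768] and the block's second residual add.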
%219 = torch.prim.ListConstruct %int-1, %int3072 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%220 = torch.aten.view %218, %219 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc93)
%221 = torch.prim.GetAttr %arg0["_param_constant12"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%222 = torch.prim.GetAttr %arg0["_param_constant13"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%223 = torch.aten.addmm %221, %220, %222, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc94)
%224 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%225 = torch.aten.view %223, %224 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc95)
%226 = torch.aten.add.Tensor %200, %225, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc96)
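    // Block 1: the same attention + MLP pattern, inlined again with _param_constant14..25
    // and _tensor_constant3..5.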
%227 = torch.prim.GetAttr %arg0["_param_constant14"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%228 = torch.prim.GetAttr %arg0["_param_constant15"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%229 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list<int> loc(#loc)
%result0_5, %result1_6, %result2_7 = torch.aten.native_layer_norm %226, %229, %227, %228, %float1.000000e-05 : !torch.tensor, !torch.list<int>, !torch.tensor, !torch.tensor, !torch.float -> !torch.tensor, !torch.tensor, !torch.tensor loc(#loc97)
%230 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%231 = torch.aten.view %result0_5, %230 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc98)
%232 = torch.prim.GetAttr %arg0["_param_constant16"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%233 = torch.prim.GetAttr %arg0["_param_constant17"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%234 = torch.aten.addmm %232, %231, %233, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc99)
%235 = torch.prim.ListConstruct %int1, %int128, %int2304 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%236 = torch.aten.view %234, %235 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc100)
%237 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%238 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%239 = torch.aten.as_strided %236, %237, %238, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc101)
%240 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%241 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%242 = torch.aten.as_strided %236, %240, %241, %int768 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc102)
%243 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%244 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%245 = torch.aten.as_strided %236, %243, %244, %int1536 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc103)
%246 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%247 = torch.aten.view %239, %246 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc104)
%248 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%249 = torch.aten.permute %247, %248 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc105)
%250 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%251 = torch.aten.view %242, %250 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc106)
%252 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%253 = torch.aten.permute %251, %252 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc107)
%254 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%255 = torch.aten.view %245, %254 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc108)
%256 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%257 = torch.aten.permute %255, %256 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc109)
%258 = torch.aten.transpose.int %253, %int-1, %int-2 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc110)
%259 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%260 = torch.aten.expand %249, %259, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc111)
%261 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%262 = torch.aten.view %260, %261 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc112)
%263 = torch.prim.ListConstruct %int1, %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%264 = torch.aten.expand %258, %263, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc113)
%265 = torch.prim.ListConstruct %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%266 = torch.aten.view %264, %265 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc114)
%267 = torch.aten.bmm %262, %266 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc115)
%268 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%269 = torch.aten._unsafe_view %267, %268 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc116)
%270 = torch.prim.GetAttr %arg0["_tensor_constant3"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%271 = torch.aten.lift_fresh_copy %270 : !torch.tensor -> !torch.tensor loc(#loc117)
%272 = torch.aten.div.Tensor %269, %271 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc118)
%273 = torch.prim.GetAttr %arg0["_tensor_constant4"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%274 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%275 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%276 = torch.aten.as_strided %273, %274, %275, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc119)
%277 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%278 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%279 = torch.aten.as_strided %276, %277, %278, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc120)
%280 = torch.prim.ListConstruct %int1, %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%281 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%282 = torch.aten.as_strided %279, %280, %281, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc121)
%283 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%284 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%285 = torch.aten.as_strided %282, %283, %284, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc122)
%286 = torch.prims.convert_element_type %285, %int11 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc123)
%287 = torch.prim.GetAttr %arg0["_tensor_constant5"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%288 = torch.aten.lift_fresh_copy %287 : !torch.tensor -> !torch.tensor loc(#loc124)
%289 = torch.aten.where.self %286, %272, %288 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc125)
%290 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> loc(#loc)
%291 = torch.aten.amax %289, %290, %true_0 : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc126)
%292 = torch.aten.sub.Tensor %289, %291, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc127)
%293 = torch.aten.exp %292 : !torch.tensor -> !torch.tensor loc(#loc128)
%294 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> loc(#loc)
%295 = torch.aten.sum.dim_IntList %293, %294, %true_0, %none_1 : !torch.tensor, !torch.list<int>, !torch.bool, !torch.none -> !torch.tensor loc(#loc129)
%296 = torch.aten.div.Tensor %293, %295 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc130)
%297 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%298 = torch.aten.expand %296, %297, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc131)
%299 = torch.prim.ListConstruct %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%300 = torch.aten.view %298, %299 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc132)
%301 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%302 = torch.aten.expand %257, %301, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc133)
%303 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%304 = torch.aten.view %302, %303 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc134)
%305 = torch.aten.bmm %300, %304 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc135)
%306 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%307 = torch.aten._unsafe_view %305, %306 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc136)
%308 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%309 = torch.aten.permute %307, %308 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc137)
%310 = torch.aten.clone %309, %int0 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc138)
%311 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%312 = torch.aten.view %310, %311 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc139)
%313 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%314 = torch.aten.view %312, %313 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc140)
%315 = torch.prim.GetAttr %arg0["_param_constant18"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%316 = torch.prim.GetAttr %arg0["_param_constant19"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%317 = torch.aten.addmm %315, %314, %316, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc141)
%318 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%319 = torch.aten.view %317, %318 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc142)
%320 = torch.aten.add.Tensor %319, %226, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc143)
%321 = torch.prim.GetAttr %arg0["_param_constant20"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%322 = torch.prim.GetAttr %arg0["_param_constant21"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%323 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list<int> loc(#loc)
%result0_8, %result1_9, %result2_10 = torch.aten.native_layer_norm %320, %323, %321, %322, %float1.000000e-05 : !torch.tensor, !torch.list<int>, !torch.tensor, !torch.tensor, !torch.float -> !torch.tensor, !torch.tensor, !torch.tensor loc(#loc144)
%324 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%325 = torch.aten.view %result0_8, %324 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc145)
%326 = torch.prim.GetAttr %arg0["_param_constant22"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%327 = torch.prim.GetAttr %arg0["_param_constant23"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%328 = torch.aten.addmm %326, %325, %327, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc146)
%329 = torch.prim.ListConstruct %int1, %int128, %int3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%330 = torch.aten.view %328, %329 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc147)
%331 = torch.aten.mul.Scalar %330, %float5.000000e-01 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc148)
%332 = torch.aten.pow.Tensor_Scalar %330, %float3.000000e00 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc149)
%333 = torch.aten.mul.Scalar %332, %float4.471500e-02 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc150)
%334 = torch.aten.add.Tensor %330, %333, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc151)
%335 = torch.aten.mul.Scalar %334, %float7.978850e-01 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc152)
%336 = torch.aten.tanh %335 : !torch.tensor -> !torch.tensor loc(#loc153)
%337 = torch.aten.add.Scalar %336, %float1.000000e00, %int1 : !torch.tensor, !torch.float, !torch.int -> !torch.tensor loc(#loc154)
%338 = torch.aten.mul.Tensor %331, %337 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc155)
%339 = torch.prim.ListConstruct %int-1, %int3072 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%340 = torch.aten.view %338, %339 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc156)
%341 = torch.prim.GetAttr %arg0["_param_constant24"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%342 = torch.prim.GetAttr %arg0["_param_constant25"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%343 = torch.aten.addmm %341, %340, %342, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc157)
%344 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%345 = torch.aten.view %343, %344 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc158)
%346 = torch.aten.add.Tensor %320, %345, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc159)
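    // Block 2: same pattern with _param_constant26..37 and _tensor_constant6..8.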
%347 = torch.prim.GetAttr %arg0["_param_constant26"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%348 = torch.prim.GetAttr %arg0["_param_constant27"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%349 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list<int> loc(#loc)
%result0_11, %result1_12, %result2_13 = torch.aten.native_layer_norm %346, %349, %347, %348, %float1.000000e-05 : !torch.tensor, !torch.list<int>, !torch.tensor, !torch.tensor, !torch.float -> !torch.tensor, !torch.tensor, !torch.tensor loc(#loc160)
%350 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%351 = torch.aten.view %result0_11, %350 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc161)
%352 = torch.prim.GetAttr %arg0["_param_constant28"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%353 = torch.prim.GetAttr %arg0["_param_constant29"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%354 = torch.aten.addmm %352, %351, %353, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc162)
%355 = torch.prim.ListConstruct %int1, %int128, %int2304 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%356 = torch.aten.view %354, %355 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc163)
%357 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%358 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%359 = torch.aten.as_strided %356, %357, %358, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc164)
%360 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%361 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%362 = torch.aten.as_strided %356, %360, %361, %int768 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc165)
%363 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%364 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%365 = torch.aten.as_strided %356, %363, %364, %int1536 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc166)
%366 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%367 = torch.aten.view %359, %366 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc167)
%368 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%369 = torch.aten.permute %367, %368 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc168)
%370 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%371 = torch.aten.view %362, %370 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc169)
%372 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%373 = torch.aten.permute %371, %372 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc170)
%374 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%375 = torch.aten.view %365, %374 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc171)
%376 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%377 = torch.aten.permute %375, %376 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc172)
%378 = torch.aten.transpose.int %373, %int-1, %int-2 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc173)
%379 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%380 = torch.aten.expand %369, %379, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc174)
%381 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%382 = torch.aten.view %380, %381 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc175)
%383 = torch.prim.ListConstruct %int1, %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%384 = torch.aten.expand %378, %383, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc176)
%385 = torch.prim.ListConstruct %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%386 = torch.aten.view %384, %385 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc177)
%387 = torch.aten.bmm %382, %386 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc178)
%388 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%389 = torch.aten._unsafe_view %387, %388 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc179)
%390 = torch.prim.GetAttr %arg0["_tensor_constant6"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%391 = torch.aten.lift_fresh_copy %390 : !torch.tensor -> !torch.tensor loc(#loc180)
%392 = torch.aten.div.Tensor %389, %391 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc181)
%393 = torch.prim.GetAttr %arg0["_tensor_constant7"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%394 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%395 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%396 = torch.aten.as_strided %393, %394, %395, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc182)
%397 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%398 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%399 = torch.aten.as_strided %396, %397, %398, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc183)
%400 = torch.prim.ListConstruct %int1, %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%401 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%402 = torch.aten.as_strided %399, %400, %401, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc184)
%403 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%404 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%405 = torch.aten.as_strided %402, %403, %404, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc185)
%406 = torch.prims.convert_element_type %405, %int11 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc186)
%407 = torch.prim.GetAttr %arg0["_tensor_constant8"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%408 = torch.aten.lift_fresh_copy %407 : !torch.tensor -> !torch.tensor loc(#loc187)
%409 = torch.aten.where.self %406, %392, %408 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc188)
%410 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> loc(#loc)
%411 = torch.aten.amax %409, %410, %true_0 : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc189)
%412 = torch.aten.sub.Tensor %409, %411, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc190)
%413 = torch.aten.exp %412 : !torch.tensor -> !torch.tensor loc(#loc191)
%414 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> loc(#loc)
%415 = torch.aten.sum.dim_IntList %413, %414, %true_0, %none_1 : !torch.tensor, !torch.list<int>, !torch.bool, !torch.none -> !torch.tensor loc(#loc192)
%416 = torch.aten.div.Tensor %413, %415 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc193)
%417 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%418 = torch.aten.expand %416, %417, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc194)
%419 = torch.prim.ListConstruct %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%420 = torch.aten.view %418, %419 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc195)
%421 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%422 = torch.aten.expand %377, %421, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc196)
%423 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%424 = torch.aten.view %422, %423 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc197)
%425 = torch.aten.bmm %420, %424 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc198)
%426 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%427 = torch.aten._unsafe_view %425, %426 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc199)
%428 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%429 = torch.aten.permute %427, %428 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc200)
%430 = torch.aten.clone %429, %int0 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc201)
%431 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%432 = torch.aten.view %430, %431 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc202)
%433 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%434 = torch.aten.view %432, %433 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc203)
%435 = torch.prim.GetAttr %arg0["_param_constant30"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%436 = torch.prim.GetAttr %arg0["_param_constant31"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%437 = torch.aten.addmm %435, %434, %436, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc204)
%438 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%439 = torch.aten.view %437, %438 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc205)
%440 = torch.aten.add.Tensor %439, %346, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc206)
%441 = torch.prim.GetAttr %arg0["_param_constant32"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%442 = torch.prim.GetAttr %arg0["_param_constant33"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%443 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list<int> loc(#loc)
%result0_14, %result1_15, %result2_16 = torch.aten.native_layer_norm %440, %443, %441, %442, %float1.000000e-05 : !torch.tensor, !torch.list<int>, !torch.tensor, !torch.tensor, !torch.float -> !torch.tensor, !torch.tensor, !torch.tensor loc(#loc207)
%444 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%445 = torch.aten.view %result0_14, %444 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc208)
%446 = torch.prim.GetAttr %arg0["_param_constant34"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%447 = torch.prim.GetAttr %arg0["_param_constant35"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%448 = torch.aten.addmm %446, %445, %447, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc209)
%449 = torch.prim.ListConstruct %int1, %int128, %int3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%450 = torch.aten.view %448, %449 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc210)
%451 = torch.aten.mul.Scalar %450, %float5.000000e-01 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc211)
%452 = torch.aten.pow.Tensor_Scalar %450, %float3.000000e00 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc212)
%453 = torch.aten.mul.Scalar %452, %float4.471500e-02 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc213)
%454 = torch.aten.add.Tensor %450, %453, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc214)
%455 = torch.aten.mul.Scalar %454, %float7.978850e-01 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc215)
%456 = torch.aten.tanh %455 : !torch.tensor -> !torch.tensor loc(#loc216)
%457 = torch.aten.add.Scalar %456, %float1.000000e00, %int1 : !torch.tensor, !torch.float, !torch.int -> !torch.tensor loc(#loc217)
%458 = torch.aten.mul.Tensor %451, %457 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc218)
%459 = torch.prim.ListConstruct %int-1, %int3072 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%460 = torch.aten.view %458, %459 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc219)
%461 = torch.prim.GetAttr %arg0["_param_constant36"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%462 = torch.prim.GetAttr %arg0["_param_constant37"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%463 = torch.aten.addmm %461, %460, %462, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc220)
%464 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%465 = torch.aten.view %463, %464 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc221)
%466 = torch.aten.add.Tensor %440, %465, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc222)
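    // Block 3: same pattern with _param_constant38.. and _tensor_constant9..11; this excerpt
    // ends mid-attention, just before the block's output projection.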
%467 = torch.prim.GetAttr %arg0["_param_constant38"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%468 = torch.prim.GetAttr %arg0["_param_constant39"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%469 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list<int> loc(#loc)
%result0_17, %result1_18, %result2_19 = torch.aten.native_layer_norm %466, %469, %467, %468, %float1.000000e-05 : !torch.tensor, !torch.list<int>, !torch.tensor, !torch.tensor, !torch.float -> !torch.tensor, !torch.tensor, !torch.tensor loc(#loc223)
%470 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%471 = torch.aten.view %result0_17, %470 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc224)
%472 = torch.prim.GetAttr %arg0["_param_constant40"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%473 = torch.prim.GetAttr %arg0["_param_constant41"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%474 = torch.aten.addmm %472, %471, %473, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc225)
%475 = torch.prim.ListConstruct %int1, %int128, %int2304 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%476 = torch.aten.view %474, %475 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc226)
%477 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%478 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%479 = torch.aten.as_strided %476, %477, %478, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc227)
%480 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%481 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%482 = torch.aten.as_strided %476, %480, %481, %int768 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc228)
%483 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%484 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%485 = torch.aten.as_strided %476, %483, %484, %int1536 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc229)
%486 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%487 = torch.aten.view %479, %486 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc230)
%488 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%489 = torch.aten.permute %487, %488 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc231)
%490 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%491 = torch.aten.view %482, %490 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc232)
%492 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%493 = torch.aten.permute %491, %492 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc233)
%494 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%495 = torch.aten.view %485, %494 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc234)
%496 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%497 = torch.aten.permute %495, %496 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc235)
%498 = torch.aten.transpose.int %493, %int-1, %int-2 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc236)
%499 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%500 = torch.aten.expand %489, %499, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc237)
%501 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%502 = torch.aten.view %500, %501 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc238)
%503 = torch.prim.ListConstruct %int1, %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%504 = torch.aten.expand %498, %503, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc239)
%505 = torch.prim.ListConstruct %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%506 = torch.aten.view %504, %505 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc240)
%507 = torch.aten.bmm %502, %506 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc241)
%508 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%509 = torch.aten._unsafe_view %507, %508 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc242)
%510 = torch.prim.GetAttr %arg0["_tensor_constant9"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%511 = torch.aten.lift_fresh_copy %510 : !torch.tensor -> !torch.tensor loc(#loc243)
%512 = torch.aten.div.Tensor %509, %511 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc244)
%513 = torch.prim.GetAttr %arg0["_tensor_constant10"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%514 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%515 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%516 = torch.aten.as_strided %513, %514, %515, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc245)
%517 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%518 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%519 = torch.aten.as_strided %516, %517, %518, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc246)
%520 = torch.prim.ListConstruct %int1, %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%521 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%522 = torch.aten.as_strided %519, %520, %521, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc247)
%523 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%524 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%525 = torch.aten.as_strided %522, %523, %524, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc248)
%526 = torch.prims.convert_element_type %525, %int11 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc249)
%527 = torch.prim.GetAttr %arg0["_tensor_constant11"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%528 = torch.aten.lift_fresh_copy %527 : !torch.tensor -> !torch.tensor loc(#loc250)
%529 = torch.aten.where.self %526, %512, %528 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc251)
%530 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> loc(#loc)
%531 = torch.aten.amax %529, %530, %true_0 : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc252)
%532 = torch.aten.sub.Tensor %529, %531, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc253)
%533 = torch.aten.exp %532 : !torch.tensor -> !torch.tensor loc(#loc254)
%534 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> loc(#loc)
%535 = torch.aten.sum.dim_IntList %533, %534, %true_0, %none_1 : !torch.tensor, !torch.list<int>, !torch.bool, !torch.none -> !torch.tensor loc(#loc255)
%536 = torch.aten.div.Tensor %533, %535 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc256)
%537 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%538 = torch.aten.expand %536, %537, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc257)
%539 = torch.prim.ListConstruct %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%540 = torch.aten.view %538, %539 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc258)
%541 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%542 = torch.aten.expand %497, %541, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc259)
%543 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%544 = torch.aten.view %542, %543 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc260)
%545 = torch.aten.bmm %540, %544 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc261)
%546 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%547 = torch.aten._unsafe_view %545, %546 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc262)
%548 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%549 = torch.aten.permute %547, %548 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc263)
%550 = torch.aten.clone %549, %int0 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc264)
%551 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%552 = torch.aten.view %550, %551 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc265)
%553 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%554 = torch.aten.view %552, %553 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc266)
%555 = torch.prim.GetAttr %arg0["_param_constant42"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%556 = torch.prim.GetAttr %arg0["_param_constant43"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%557 = torch.aten.addmm %555, %554, %556, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc267)
%558 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%559 = torch.aten.view %557, %558 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc268)
%560 = torch.aten.add.Tensor %559, %466, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc269)
%561 = torch.prim.GetAttr %arg0["_param_constant44"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%562 = torch.prim.GetAttr %arg0["_param_constant45"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%563 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list<int> loc(#loc)
%result0_20, %result1_21, %result2_22 = torch.aten.native_layer_norm %560, %563, %561, %562, %float1.000000e-05 : !torch.tensor, !torch.list<int>, !torch.tensor, !torch.tensor, !torch.float -> !torch.tensor, !torch.tensor, !torch.tensor loc(#loc270)
%564 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%565 = torch.aten.view %result0_20, %564 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc271)
%566 = torch.prim.GetAttr %arg0["_param_constant46"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%567 = torch.prim.GetAttr %arg0["_param_constant47"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%568 = torch.aten.addmm %566, %565, %567, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc272)
%569 = torch.prim.ListConstruct %int1, %int128, %int3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%570 = torch.aten.view %568, %569 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc273)
%571 = torch.aten.mul.Scalar %570, %float5.000000e-01 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc274)
%572 = torch.aten.pow.Tensor_Scalar %570, %float3.000000e00 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc275)
%573 = torch.aten.mul.Scalar %572, %float4.471500e-02 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc276)
%574 = torch.aten.add.Tensor %570, %573, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc277)
%575 = torch.aten.mul.Scalar %574, %float7.978850e-01 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc278)
%576 = torch.aten.tanh %575 : !torch.tensor -> !torch.tensor loc(#loc279)
%577 = torch.aten.add.Scalar %576, %float1.000000e00, %int1 : !torch.tensor, !torch.float, !torch.int -> !torch.tensor loc(#loc280)
%578 = torch.aten.mul.Tensor %571, %577 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc281)
%579 = torch.prim.ListConstruct %int-1, %int3072 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%580 = torch.aten.view %578, %579 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc282)
%581 = torch.prim.GetAttr %arg0["_param_constant48"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%582 = torch.prim.GetAttr %arg0["_param_constant49"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%583 = torch.aten.addmm %581, %580, %582, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc283)
%584 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%585 = torch.aten.view %583, %584 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc284)
%586 = torch.aten.add.Tensor %560, %585, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc285)
%587 = torch.prim.GetAttr %arg0["_param_constant50"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%588 = torch.prim.GetAttr %arg0["_param_constant51"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%589 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list<int> loc(#loc)
%result0_23, %result1_24, %result2_25 = torch.aten.native_layer_norm %586, %589, %587, %588, %float1.000000e-05 : !torch.tensor, !torch.list<int>, !torch.tensor, !torch.tensor, !torch.float -> !torch.tensor, !torch.tensor, !torch.tensor loc(#loc286)
%590 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%591 = torch.aten.view %result0_23, %590 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc287)
%592 = torch.prim.GetAttr %arg0["_param_constant52"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%593 = torch.prim.GetAttr %arg0["_param_constant53"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%594 = torch.aten.addmm %592, %591, %593, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc288)
%595 = torch.prim.ListConstruct %int1, %int128, %int2304 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%596 = torch.aten.view %594, %595 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc289)
%597 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%598 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%599 = torch.aten.as_strided %596, %597, %598, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc290)
%600 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%601 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%602 = torch.aten.as_strided %596, %600, %601, %int768 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc291)
%603 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%604 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%605 = torch.aten.as_strided %596, %603, %604, %int1536 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc292)
%606 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%607 = torch.aten.view %599, %606 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc293)
%608 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%609 = torch.aten.permute %607, %608 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc294)
%610 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%611 = torch.aten.view %602, %610 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc295)
%612 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%613 = torch.aten.permute %611, %612 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc296)
%614 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%615 = torch.aten.view %605, %614 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc297)
%616 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%617 = torch.aten.permute %615, %616 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc298)
%618 = torch.aten.transpose.int %613, %int-1, %int-2 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc299)
%619 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%620 = torch.aten.expand %609, %619, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc300)
%621 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%622 = torch.aten.view %620, %621 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc301)
%623 = torch.prim.ListConstruct %int1, %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%624 = torch.aten.expand %618, %623, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc302)
%625 = torch.prim.ListConstruct %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%626 = torch.aten.view %624, %625 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc303)
%627 = torch.aten.bmm %622, %626 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc304)
%628 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%629 = torch.aten._unsafe_view %627, %628 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc305)
%630 = torch.prim.GetAttr %arg0["_tensor_constant12"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%631 = torch.aten.lift_fresh_copy %630 : !torch.tensor -> !torch.tensor loc(#loc306)
%632 = torch.aten.div.Tensor %629, %631 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc307)
%633 = torch.prim.GetAttr %arg0["_tensor_constant13"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%634 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%635 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%636 = torch.aten.as_strided %633, %634, %635, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc308)
%637 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%638 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%639 = torch.aten.as_strided %636, %637, %638, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc309)
%640 = torch.prim.ListConstruct %int1, %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%641 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%642 = torch.aten.as_strided %639, %640, %641, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc310)
%643 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%644 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%645 = torch.aten.as_strided %642, %643, %644, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc311)
%646 = torch.prims.convert_element_type %645, %int11 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc312)
%647 = torch.prim.GetAttr %arg0["_tensor_constant14"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%648 = torch.aten.lift_fresh_copy %647 : !torch.tensor -> !torch.tensor loc(#loc313)
%649 = torch.aten.where.self %646, %632, %648 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc314)
%650 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> loc(#loc)
%651 = torch.aten.amax %649, %650, %true_0 : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc315)
%652 = torch.aten.sub.Tensor %649, %651, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc316)
%653 = torch.aten.exp %652 : !torch.tensor -> !torch.tensor loc(#loc317)
%654 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> loc(#loc)
%655 = torch.aten.sum.dim_IntList %653, %654, %true_0, %none_1 : !torch.tensor, !torch.list<int>, !torch.bool, !torch.none -> !torch.tensor loc(#loc318)
%656 = torch.aten.div.Tensor %653, %655 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc319)
%657 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%658 = torch.aten.expand %656, %657, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc320)
%659 = torch.prim.ListConstruct %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%660 = torch.aten.view %658, %659 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc321)
%661 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%662 = torch.aten.expand %617, %661, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc322)
%663 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%664 = torch.aten.view %662, %663 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc323)
%665 = torch.aten.bmm %660, %664 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc324)
%666 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%667 = torch.aten._unsafe_view %665, %666 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc325)
%668 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%669 = torch.aten.permute %667, %668 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc326)
%670 = torch.aten.clone %669, %int0 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc327)
%671 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%672 = torch.aten.view %670, %671 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc328)
%673 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%674 = torch.aten.view %672, %673 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc329)
%675 = torch.prim.GetAttr %arg0["_param_constant54"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%676 = torch.prim.GetAttr %arg0["_param_constant55"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%677 = torch.aten.addmm %675, %674, %676, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc330)
%678 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%679 = torch.aten.view %677, %678 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc331)
%680 = torch.aten.add.Tensor %679, %586, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc332)
%681 = torch.prim.GetAttr %arg0["_param_constant56"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%682 = torch.prim.GetAttr %arg0["_param_constant57"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%683 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list<int> loc(#loc)
%result0_26, %result1_27, %result2_28 = torch.aten.native_layer_norm %680, %683, %681, %682, %float1.000000e-05 : !torch.tensor, !torch.list<int>, !torch.tensor, !torch.tensor, !torch.float -> !torch.tensor, !torch.tensor, !torch.tensor loc(#loc333)
%684 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%685 = torch.aten.view %result0_26, %684 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc334)
%686 = torch.prim.GetAttr %arg0["_param_constant58"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%687 = torch.prim.GetAttr %arg0["_param_constant59"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%688 = torch.aten.addmm %686, %685, %687, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc335)
%689 = torch.prim.ListConstruct %int1, %int128, %int3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%690 = torch.aten.view %688, %689 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc336)
%691 = torch.aten.mul.Scalar %690, %float5.000000e-01 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc337)
%692 = torch.aten.pow.Tensor_Scalar %690, %float3.000000e00 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc338)
%693 = torch.aten.mul.Scalar %692, %float4.471500e-02 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc339)
%694 = torch.aten.add.Tensor %690, %693, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc340)
%695 = torch.aten.mul.Scalar %694, %float7.978850e-01 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc341)
%696 = torch.aten.tanh %695 : !torch.tensor -> !torch.tensor loc(#loc342)
%697 = torch.aten.add.Scalar %696, %float1.000000e00, %int1 : !torch.tensor, !torch.float, !torch.int -> !torch.tensor loc(#loc343)
%698 = torch.aten.mul.Tensor %691, %697 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc344)
%699 = torch.prim.ListConstruct %int-1, %int3072 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%700 = torch.aten.view %698, %699 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc345)
%701 = torch.prim.GetAttr %arg0["_param_constant60"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%702 = torch.prim.GetAttr %arg0["_param_constant61"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%703 = torch.aten.addmm %701, %700, %702, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc346)
%704 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%705 = torch.aten.view %703, %704 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc347)
%706 = torch.aten.add.Tensor %680, %705, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc348)
%707 = torch.prim.GetAttr %arg0["_param_constant62"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%708 = torch.prim.GetAttr %arg0["_param_constant63"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%709 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list<int> loc(#loc)
%result0_29, %result1_30, %result2_31 = torch.aten.native_layer_norm %706, %709, %707, %708, %float1.000000e-05 : !torch.tensor, !torch.list<int>, !torch.tensor, !torch.tensor, !torch.float -> !torch.tensor, !torch.tensor, !torch.tensor loc(#loc349)
%710 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%711 = torch.aten.view %result0_29, %710 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc350)
%712 = torch.prim.GetAttr %arg0["_param_constant64"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%713 = torch.prim.GetAttr %arg0["_param_constant65"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%714 = torch.aten.addmm %712, %711, %713, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc351)
%715 = torch.prim.ListConstruct %int1, %int128, %int2304 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%716 = torch.aten.view %714, %715 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc352)
%717 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%718 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%719 = torch.aten.as_strided %716, %717, %718, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc353)
%720 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%721 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%722 = torch.aten.as_strided %716, %720, %721, %int768 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc354)
%723 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%724 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%725 = torch.aten.as_strided %716, %723, %724, %int1536 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc355)
%726 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%727 = torch.aten.view %719, %726 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc356)
%728 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%729 = torch.aten.permute %727, %728 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc357)
%730 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%731 = torch.aten.view %722, %730 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc358)
%732 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%733 = torch.aten.permute %731, %732 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc359)
%734 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%735 = torch.aten.view %725, %734 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc360)
%736 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%737 = torch.aten.permute %735, %736 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc361)
%738 = torch.aten.transpose.int %733, %int-1, %int-2 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc362)
%739 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%740 = torch.aten.expand %729, %739, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc363)
%741 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%742 = torch.aten.view %740, %741 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc364)
%743 = torch.prim.ListConstruct %int1, %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%744 = torch.aten.expand %738, %743, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc365)
%745 = torch.prim.ListConstruct %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%746 = torch.aten.view %744, %745 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc366)
%747 = torch.aten.bmm %742, %746 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc367)
%748 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%749 = torch.aten._unsafe_view %747, %748 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc368)
%750 = torch.prim.GetAttr %arg0["_tensor_constant15"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%751 = torch.aten.lift_fresh_copy %750 : !torch.tensor -> !torch.tensor loc(#loc369)
%752 = torch.aten.div.Tensor %749, %751 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc370)
%753 = torch.prim.GetAttr %arg0["_tensor_constant16"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%754 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%755 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%756 = torch.aten.as_strided %753, %754, %755, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc371)
%757 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%758 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%759 = torch.aten.as_strided %756, %757, %758, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc372)
%760 = torch.prim.ListConstruct %int1, %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%761 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%762 = torch.aten.as_strided %759, %760, %761, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc373)
%763 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%764 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%765 = torch.aten.as_strided %762, %763, %764, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc374)
%766 = torch.prims.convert_element_type %765, %int11 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc375)
%767 = torch.prim.GetAttr %arg0["_tensor_constant17"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%768 = torch.aten.lift_fresh_copy %767 : !torch.tensor -> !torch.tensor loc(#loc376)
%769 = torch.aten.where.self %766, %752, %768 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc377)
%770 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> loc(#loc)
%771 = torch.aten.amax %769, %770, %true_0 : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc378)
%772 = torch.aten.sub.Tensor %769, %771, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc379)
%773 = torch.aten.exp %772 : !torch.tensor -> !torch.tensor loc(#loc380)
%774 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> loc(#loc)
%775 = torch.aten.sum.dim_IntList %773, %774, %true_0, %none_1 : !torch.tensor, !torch.list<int>, !torch.bool, !torch.none -> !torch.tensor loc(#loc381)
%776 = torch.aten.div.Tensor %773, %775 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc382)
%777 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%778 = torch.aten.expand %776, %777, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc383)
%779 = torch.prim.ListConstruct %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%780 = torch.aten.view %778, %779 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc384)
%781 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%782 = torch.aten.expand %737, %781, %false : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc385)
%783 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%784 = torch.aten.view %782, %783 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc386)
%785 = torch.aten.bmm %780, %784 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc387)
%786 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%787 = torch.aten._unsafe_view %785, %786 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc388)
%788 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%789 = torch.aten.permute %787, %788 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc389)
%790 = torch.aten.clone %789, %int0 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc390)
%791 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%792 = torch.aten.view %790, %791 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc391)
%793 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%794 = torch.aten.view %792, %793 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc392)
%795 = torch.prim.GetAttr %arg0["_param_constant66"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%796 = torch.prim.GetAttr %arg0["_param_constant67"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%797 = torch.aten.addmm %795, %794, %796, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc393)
%798 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%799 = torch.aten.view %797, %798 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc394)
%800 = torch.aten.add.Tensor %799, %706, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc395)
%801 = torch.prim.GetAttr %arg0["_param_constant68"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%802 = torch.prim.GetAttr %arg0["_param_constant69"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%803 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list<int> loc(#loc)
%result0_32, %result1_33, %result2_34 = torch.aten.native_layer_norm %800, %803, %801, %802, %float1.000000e-05 : !torch.tensor, !torch.list<int>, !torch.tensor, !torch.tensor, !torch.float -> !torch.tensor, !torch.tensor, !torch.tensor loc(#loc396)
%804 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%805 = torch.aten.view %result0_32, %804 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc397)
%806 = torch.prim.GetAttr %arg0["_param_constant70"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%807 = torch.prim.GetAttr %arg0["_param_constant71"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%808 = torch.aten.addmm %806, %805, %807, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc398)
%809 = torch.prim.ListConstruct %int1, %int128, %int3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%810 = torch.aten.view %808, %809 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc399)
%811 = torch.aten.mul.Scalar %810, %float5.000000e-01 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc400)
%812 = torch.aten.pow.Tensor_Scalar %810, %float3.000000e00 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc401)
%813 = torch.aten.mul.Scalar %812, %float4.471500e-02 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc402)
%814 = torch.aten.add.Tensor %810, %813, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc403)
%815 = torch.aten.mul.Scalar %814, %float7.978850e-01 : !torch.tensor, !torch.float -> !torch.tensor loc(#loc404)
%816 = torch.aten.tanh %815 : !torch.tensor -> !torch.tensor loc(#loc405)
%817 = torch.aten.add.Scalar %816, %float1.000000e00, %int1 : !torch.tensor, !torch.float, !torch.int -> !torch.tensor loc(#loc406)
%818 = torch.aten.mul.Tensor %811, %817 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc407)
%819 = torch.prim.ListConstruct %int-1, %int3072 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%820 = torch.aten.view %818, %819 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc408)
%821 = torch.prim.GetAttr %arg0["_param_constant72"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%822 = torch.prim.GetAttr %arg0["_param_constant73"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%823 = torch.aten.addmm %821, %820, %822, %int1, %int1 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc409)
%824 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%825 = torch.aten.view %823, %824 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc410)
%826 = torch.aten.add.Tensor %800, %825, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor loc(#loc411)
%827 = torch.prim.GetAttr %arg0["_param_constant74"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%828 = torch.prim.GetAttr %arg0["_param_constant75"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%829 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list<int> loc(#loc)
%result0_35, %result1_36, %result2_37 = torch.aten.native_layer_norm %826, %829, %827, %828, %float1.000000e-05 : !torch.tensor, !torch.list<int>, !torch.tensor, !torch.tensor, !torch.float -> !torch.tensor, !torch.tensor, !torch.tensor loc(#loc412)
%830 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%831 = torch.aten.view %result0_35, %830 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc413)
%832 = torch.prim.GetAttr %arg0["_param_constant76"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
%833 = torch.aten.t %832 : !torch.tensor -> !torch.tensor loc(#loc414)
%834 = torch.prim.ListConstruct %int128, %int768 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%835 = torch.aten.view %831, %834 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc415)
%836 = torch.aten.mm %835, %833 : !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc416)
%837 = torch.prim.ListConstruct %int1, %int128, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
%838 = torch.aten._unsafe_view %836, %837 : !torch.tensor, !torch.list<int> -> !torch.tensor loc(#loc417)
%839 = torch.aten.arange %int1, %none_1, %none_1, %cpu, %false : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.tensor loc(#loc418)
%840 = torch.aten.select.int %838, %int1, %int-1 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc419)
%841 = torch.prim.ListConstruct %839 : (!torch.tensor) -> !torch.list<tensor> loc(#loc)
%842 = torch.aten.index.Tensor_hacked_twin %840, %841 : !torch.tensor, !torch.list<tensor> -> !torch.tensor loc(#loc420)
return %842 : !torch.tensor loc(#loc)
} loc(#loc)
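// Class-type declaration for the traced FX graph module: every parameter and
// buffer is exposed as a private !torch.tensor attribute, alongside the
// training flag, the backward-hook marker, and the _code string returned by
// __code_getter.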
torch.class_type @__torch__.torch.fx.graph_module._lambda {
torch.attr private "_param_constant0" : !torch.tensor loc(#loc)
torch.attr private "_param_constant1" : !torch.tensor loc(#loc)
torch.attr private "_param_constant2" : !torch.tensor loc(#loc)
torch.attr private "_param_constant3" : !torch.tensor loc(#loc)
torch.attr private "_param_constant4" : !torch.tensor loc(#loc)
torch.attr private "_param_constant5" : !torch.tensor loc(#loc)
torch.attr private "_param_constant6" : !torch.tensor loc(#loc)
torch.attr private "_param_constant7" : !torch.tensor loc(#loc)
torch.attr private "_param_constant8" : !torch.tensor loc(#loc)
torch.attr private "_param_constant9" : !torch.tensor loc(#loc)
torch.attr private "_param_constant10" : !torch.tensor loc(#loc)
torch.attr private "_param_constant11" : !torch.tensor loc(#loc)
torch.attr private "_param_constant12" : !torch.tensor loc(#loc)
torch.attr private "_param_constant13" : !torch.tensor loc(#loc)
torch.attr private "_param_constant14" : !torch.tensor loc(#loc)
torch.attr private "_param_constant15" : !torch.tensor loc(#loc)
torch.attr private "_param_constant16" : !torch.tensor loc(#loc)
torch.attr private "_param_constant17" : !torch.tensor loc(#loc)
torch.attr private "_param_constant18" : !torch.tensor loc(#loc)
torch.attr private "_param_constant19" : !torch.tensor loc(#loc)
torch.attr private "_param_constant20" : !torch.tensor loc(#loc)
torch.attr private "_param_constant21" : !torch.tensor loc(#loc)
torch.attr private "_param_constant22" : !torch.tensor loc(#loc)
torch.attr private "_param_constant23" : !torch.tensor loc(#loc)
torch.attr private "_param_constant24" : !torch.tensor loc(#loc)
torch.attr private "_param_constant25" : !torch.tensor loc(#loc)
torch.attr private "_param_constant26" : !torch.tensor loc(#loc)
torch.attr private "_param_constant27" : !torch.tensor loc(#loc)
torch.attr private "_param_constant28" : !torch.tensor loc(#loc)
torch.attr private "_param_constant29" : !torch.tensor loc(#loc)
torch.attr private "_param_constant30" : !torch.tensor loc(#loc)
torch.attr private "_param_constant31" : !torch.tensor loc(#loc)
torch.attr private "_param_constant32" : !torch.tensor loc(#loc)
torch.attr private "_param_constant33" : !torch.tensor loc(#loc)
torch.attr private "_param_constant34" : !torch.tensor loc(#loc)
torch.attr private "_param_constant35" : !torch.tensor loc(#loc)
torch.attr private "_param_constant36" : !torch.tensor loc(#loc)
torch.attr private "_param_constant37" : !torch.tensor loc(#loc)
torch.attr private "_param_constant38" : !torch.tensor loc(#loc)
torch.attr private "_param_constant39" : !torch.tensor loc(#loc)
torch.attr private "_param_constant40" : !torch.tensor loc(#loc)
torch.attr private "_param_constant41" : !torch.tensor loc(#loc)
torch.attr private "_param_constant42" : !torch.tensor loc(#loc)
torch.attr private "_param_constant43" : !torch.tensor loc(#loc)
torch.attr private "_param_constant44" : !torch.tensor loc(#loc)
torch.attr private "_param_constant45" : !torch.tensor loc(#loc)
torch.attr private "_param_constant46" : !torch.tensor loc(#loc)
torch.attr private "_param_constant47" : !torch.tensor loc(#loc)
torch.attr private "_param_constant48" : !torch.tensor loc(#loc)
torch.attr private "_param_constant49" : !torch.tensor loc(#loc)
torch.attr private "_param_constant50" : !torch.tensor loc(#loc)
torch.attr private "_param_constant51" : !torch.tensor loc(#loc)
torch.attr private "_param_constant52" : !torch.tensor loc(#loc)
torch.attr private "_param_constant53" : !torch.tensor loc(#loc)
torch.attr private "_param_constant54" : !torch.tensor loc(#loc)
torch.attr private "_param_constant55" : !torch.tensor loc(#loc)
torch.attr private "_param_constant56" : !torch.tensor loc(#loc)
torch.attr private "_param_constant57" : !torch.tensor loc(#loc)
torch.attr private "_param_constant58" : !torch.tensor loc(#loc)
torch.attr private "_param_constant59" : !torch.tensor loc(#loc)
torch.attr private "_param_constant60" : !torch.tensor loc(#loc)
torch.attr private "_param_constant61" : !torch.tensor loc(#loc)
torch.attr private "_param_constant62" : !torch.tensor loc(#loc)
torch.attr private "_param_constant63" : !torch.tensor loc(#loc)
torch.attr private "_param_constant64" : !torch.tensor loc(#loc)
torch.attr private "_param_constant65" : !torch.tensor loc(#loc)
torch.attr private "_param_constant66" : !torch.tensor loc(#loc)
torch.attr private "_param_constant67" : !torch.tensor loc(#loc)
torch.attr private "_param_constant68" : !torch.tensor loc(#loc)
torch.attr private "_param_constant69" : !torch.tensor loc(#loc)
torch.attr private "_param_constant70" : !torch.tensor loc(#loc)
torch.attr private "_param_constant71" : !torch.tensor loc(#loc)
torch.attr private "_param_constant72" : !torch.tensor loc(#loc)
torch.attr private "_param_constant73" : !torch.tensor loc(#loc)
torch.attr private "_param_constant74" : !torch.tensor loc(#loc)
torch.attr private "_param_constant75" : !torch.tensor loc(#loc)
torch.attr private "_param_constant76" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant0" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant1" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant2" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant3" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant4" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant5" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant6" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant7" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant8" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant9" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant10" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant11" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant12" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant13" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant14" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant15" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant16" : !torch.tensor loc(#loc)
torch.attr private "_tensor_constant17" : !torch.tensor loc(#loc)
torch.attr private "training" : !torch.bool loc(#loc)
torch.attr private "_is_full_backward_hook" : !torch.optional<bool> loc(#loc)
torch.attr private "_code" : !torch.str loc(#loc)
torch.method private "__code_getter", @__torch__.torch.fx.graph_module._lambda.__code_getter loc(#loc)
torch.method "forward", @__torch__.torch.fx.graph_module._lambda.forward loc(#loc)
} loc(#loc)
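// Backing literals for the attributes above, with the weight payloads elided
// as dense resources. The shapes follow the GPT-2 layout: a 50257x768 token
// embedding, a 1024x768 position embedding, and per block a fused 768x2304 QKV
// weight with 2304 bias, a 768x768 attention projection, 3072-wide MLP
// weights, and LayerNorm scale/shift vectors; the final 2x768 literal backs
// the classification head. Six such blocks of parameters are consistent with
// a 6-layer GPT-2 variant such as distilgpt2.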
%0 = torch.tensor.literal(dense_resource<__elided__> : tensor<50257x768xf32>) : !torch.tensor<[50257,768],f32> loc(#loc)
%1 = torch.tensor.literal(dense_resource<__elided__> : tensor<1024x768xf32>) : !torch.tensor<[1024,768],f32> loc(#loc)
%2 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%3 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%4 = torch.tensor.literal(dense_resource<__elided__> : tensor<2304xf32>) : !torch.tensor<[2304],f32> loc(#loc)
%5 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x2304xf32>) : !torch.tensor<[768,2304],f32> loc(#loc)
%6 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%7 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x768xf32>) : !torch.tensor<[768,768],f32> loc(#loc)
%8 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%9 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%10 = torch.tensor.literal(dense_resource<__elided__> : tensor<3072xf32>) : !torch.tensor<[3072],f32> loc(#loc)
%11 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x3072xf32>) : !torch.tensor<[768,3072],f32> loc(#loc)
%12 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%13 = torch.tensor.literal(dense_resource<__elided__> : tensor<3072x768xf32>) : !torch.tensor<[3072,768],f32> loc(#loc)
%14 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%15 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%16 = torch.tensor.literal(dense_resource<__elided__> : tensor<2304xf32>) : !torch.tensor<[2304],f32> loc(#loc)
%17 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x2304xf32>) : !torch.tensor<[768,2304],f32> loc(#loc)
%18 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%19 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x768xf32>) : !torch.tensor<[768,768],f32> loc(#loc)
%20 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%21 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%22 = torch.tensor.literal(dense_resource<__elided__> : tensor<3072xf32>) : !torch.tensor<[3072],f32> loc(#loc)
%23 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x3072xf32>) : !torch.tensor<[768,3072],f32> loc(#loc)
%24 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%25 = torch.tensor.literal(dense_resource<__elided__> : tensor<3072x768xf32>) : !torch.tensor<[3072,768],f32> loc(#loc)
%26 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%27 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%28 = torch.tensor.literal(dense_resource<__elided__> : tensor<2304xf32>) : !torch.tensor<[2304],f32> loc(#loc)
%29 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x2304xf32>) : !torch.tensor<[768,2304],f32> loc(#loc)
%30 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%31 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x768xf32>) : !torch.tensor<[768,768],f32> loc(#loc)
%32 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%33 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%34 = torch.tensor.literal(dense_resource<__elided__> : tensor<3072xf32>) : !torch.tensor<[3072],f32> loc(#loc)
%35 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x3072xf32>) : !torch.tensor<[768,3072],f32> loc(#loc)
%36 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%37 = torch.tensor.literal(dense_resource<__elided__> : tensor<3072x768xf32>) : !torch.tensor<[3072,768],f32> loc(#loc)
%38 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%39 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%40 = torch.tensor.literal(dense_resource<__elided__> : tensor<2304xf32>) : !torch.tensor<[2304],f32> loc(#loc)
%41 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x2304xf32>) : !torch.tensor<[768,2304],f32> loc(#loc)
%42 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%43 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x768xf32>) : !torch.tensor<[768,768],f32> loc(#loc)
%44 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%45 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%46 = torch.tensor.literal(dense_resource<__elided__> : tensor<3072xf32>) : !torch.tensor<[3072],f32> loc(#loc)
%47 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x3072xf32>) : !torch.tensor<[768,3072],f32> loc(#loc)
%48 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%49 = torch.tensor.literal(dense_resource<__elided__> : tensor<3072x768xf32>) : !torch.tensor<[3072,768],f32> loc(#loc)
%50 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%51 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%52 = torch.tensor.literal(dense_resource<__elided__> : tensor<2304xf32>) : !torch.tensor<[2304],f32> loc(#loc)
%53 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x2304xf32>) : !torch.tensor<[768,2304],f32> loc(#loc)
%54 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%55 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x768xf32>) : !torch.tensor<[768,768],f32> loc(#loc)
%56 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%57 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%58 = torch.tensor.literal(dense_resource<__elided__> : tensor<3072xf32>) : !torch.tensor<[3072],f32> loc(#loc)
%59 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x3072xf32>) : !torch.tensor<[768,3072],f32> loc(#loc)
%60 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%61 = torch.tensor.literal(dense_resource<__elided__> : tensor<3072x768xf32>) : !torch.tensor<[3072,768],f32> loc(#loc)
%62 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%63 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%64 = torch.tensor.literal(dense_resource<__elided__> : tensor<2304xf32>) : !torch.tensor<[2304],f32> loc(#loc)
%65 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x2304xf32>) : !torch.tensor<[768,2304],f32> loc(#loc)
%66 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%67 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x768xf32>) : !torch.tensor<[768,768],f32> loc(#loc)
%68 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%69 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%70 = torch.tensor.literal(dense_resource<__elided__> : tensor<3072xf32>) : !torch.tensor<[3072],f32> loc(#loc)
%71 = torch.tensor.literal(dense_resource<__elided__> : tensor<768x3072xf32>) : !torch.tensor<[768,3072],f32> loc(#loc)
%72 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%73 = torch.tensor.literal(dense_resource<__elided__> : tensor<3072x768xf32>) : !torch.tensor<[3072,768],f32> loc(#loc)
%74 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%75 = torch.tensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.tensor<[768],f32> loc(#loc)
%76 = torch.tensor.literal(dense_resource<__elided__> : tensor<2x768xf32>) : !torch.tensor<[2,768],f32> loc(#loc)
%77 = torch.tensor.literal(dense<8.000000e+00> : tensor<f32>) : !torch.tensor<[],f32> loc(#loc)
%78 = torch.tensor.literal(dense_resource<__elided__> : tensor<1x1x1024x1024xui8>) : !torch.tensor<[1,1,1024,1024],ui8> loc(#loc)
%79 = torch.tensor.literal(dense<-3.40282347E+38> : tensor<f32>) : !torch.tensor<[],f32> loc(#loc)
%80 = torch.tensor.literal(dense<8.000000e+00> : tensor<f32>) : !torch.tensor<[],f32> loc(#loc)
%81 = torch.tensor.literal(dense_resource<__elided__> : tensor<1x1x1024x1024xui8>) : !torch.tensor<[1,1,1024,1024],ui8> loc(#loc)
%82 = torch.tensor.literal(dense<-3.40282347E+38> : tensor<f32>) : !torch.tensor<[],f32> loc(#loc)
%83 = torch.tensor.literal(dense<8.000000e+00> : tensor<f32>) : !torch.tensor<[],f32> loc(#loc)
%84 = torch.tensor.literal(dense_resource<__elided__> : tensor<1x1x1024x1024xui8>) : !torch.tensor<[1,1,1024,1024],ui8> loc(#loc)
%85 = torch.tensor.literal(dense<-3.40282347E+38> : tensor<f32>) : !torch.tensor<[],f32> loc(#loc)
%86 = torch.tensor.literal(dense<8.000000e+00> : tensor<f32>) : !torch.tensor<[],f32> loc(#loc)
%87 = torch.tensor.literal(dense_resource<__elided__> : tensor<1x1x1024x1024xui8>) : !torch.tensor<[1,1,1024,1024],ui8> loc(#loc)
%88 = torch.tensor.literal(dense<-3.40282347E+38> : tensor<f32>) : !torch.tensor<[],f32> loc(#loc)
%89 = torch.tensor.literal(dense<8.000000e+00> : tensor<f32>) : !torch.tensor<[],f32> loc(#loc)
%90 = torch.tensor.literal(dense_resource<__elided__> : tensor<1x1x1024x1024xui8>) : !torch.tensor<[1,1,1024,1024],ui8> loc(#loc)
%91 = torch.tensor.literal(dense<-3.40282347E+38> : tensor<f32>) : !torch.tensor<[],f32> loc(#loc)
%92 = torch.tensor.literal(dense<8.000000e+00> : tensor<f32>) : !torch.tensor<[],f32> loc(#loc)
%93 = torch.tensor.literal(dense_resource<__elided__> : tensor<1x1x1024x1024xui8>) : !torch.tensor<[1,1,1024,1024],ui8> loc(#loc)
%94 = torch.tensor.literal(dense<-3.40282347E+38> : tensor<f32>) : !torch.tensor<[],f32> loc(#loc)
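    // Non-tensor slot values for the nn.Module object constructed below: the
    // training flag, the unset _is_full_backward_hook, and the _code source string.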
%true = torch.constant.bool true loc(#loc)
%none = torch.constant.none loc(#loc)
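    // %str captures the Python source of the traced torch.fx GraphModule verbatim
    // (newlines escaped as \0A); it is bound to the module's _code slot below and
    // returned by the __code_getter function.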
%str = torch.constant.str "\0A\0A\0Adef forward(self, arg0_1):\0A view = torch.ops.aten.view(arg0_1, [-1, 128]); arg0_1 = None\0A arange = torch.ops.aten.arange(0, 128, dtype = torch.int64, device = device(type='cpu'), pin_memory = False)\0A unsqueeze = torch.ops.aten.unsqueeze(arange, 0); arange = None\0A view_1 = torch.ops.aten.view(unsqueeze, [-1, 128]); unsqueeze = None\0A _param_constant0 = self._param_constant0\0A embedding = torch.ops.aten.embedding(_param_constant0, view); _param_constant0 = view = None\0A _param_constant1 = self._param_constant1\0A embedding_1 = torch.ops.aten.embedding(_param_constant1, view_1); _param_constant1 = view_1 = None\0A add = torch.ops.aten.add(embedding, embedding_1); embedding = embedding_1 = None\0A _param_constant2 = self._param_constant2\0A _param_constant3 = self._param_constant3\0A native_layer_norm = torch.ops.aten.native_layer_norm(add, [768], _param_constant2, _param_constant3, 1e-05); _param_constant2 = _param_constant3 = None\0A getitem = native_layer_norm[0]\0A getitem_1 = native_layer_norm[1]\0A getitem_2 = native_layer_norm[2]; native_layer_norm = None\0A view_2 = torch.ops.aten.view(getitem, [-1, 768]); getitem = None\0A _param_constant4 = self._param_constant4\0A _param_constant5 = self._param_constant5\0A addmm = torch.ops.aten.addmm(_param_constant4, view_2, _param_constant5); _param_constant4 = view_2 = _param_constant5 = None\0A view_3 = torch.ops.aten.view(addmm, [1, 128, 2304]); addmm = None\0A as_strided = torch.ops.aten.as_strided(view_3, [1, 128, 768], [294912, 2304, 1], 0)\0A as_strided_1 = torch.ops.aten.as_strided(view_3, [1, 128, 768], [294912, 2304, 1], 768)\0A as_strided_2 = torch.ops.aten.as_strided(view_3, [1, 128, 768], [294912, 2304, 1], 1536); view_3 = None\0A view_4 = torch.ops.aten.view(as_strided, [1, 128, 12, 64]); as_strided = None\0A permute = torch.ops.aten.permute(view_4, [0, 2, 1, 3]); view_4 = None\0A view_5 = torch.ops.aten.view(as_strided_1, [1, 128, 12, 64]); as_strided_1 = None\0A permute_1 = torch.ops.aten.permute(view_5, [0, 2, 1, 3]); view_5 = None\0A view_6 = torch.ops.aten.view(as_strided_2, [1, 128, 12, 64]); as_strided_2 = None\0A permute_2 = torch.ops.aten.permute(view_6, [0, 2, 1, 3]); view_6 = None\0A transpose = torch.ops.aten.transpose(permute_1, -1, -2); permute_1 = None\0A expand = torch.ops.aten.expand(permute, [1, 12, 128, 64]); permute = None\0A view_7 = torch.ops.aten.view(expand, [12, 128, 64]); expand = None\0A expand_1 = torch.ops.aten.expand(transpose, [1, 12, 64, 128]); transpose = None\0A view_8 = torch.ops.aten.view(expand_1, [12, 64, 128]); expand_1 = None\0A bmm = torch.ops.aten.bmm(view_7, view_8); view_7 = view_8 = None\0A _unsafe_view = torch.ops.aten._unsafe_view(bmm, [1, 12, 128, 128]); bmm = None\0A _tensor_constant0 = self._tensor_constant0\0A lift_fresh_copy = torch.ops.aten.lift_fresh_copy(_tensor_constant0); _tensor_constant0 = None\0A div = torch.ops.aten.div(_unsafe_view, lift_fresh_copy); _unsafe_view = lift_fresh_copy = None\0A _tensor_constant1 = self._tensor_constant1\0A as_strided_3 = torch.ops.aten.as_strided(_tensor_constant1, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0); _tensor_constant1 = None\0A as_strided_4 = torch.ops.aten.as_strided(as_strided_3, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0); as_strided_3 = None\0A as_strided_5 = torch.ops.aten.as_strided(as_strided_4, [1, 1, 128, 1024], [1048576, 1048576, 1024, 1], 0); as_strided_4 = None\0A as_strided_6 = torch.ops.aten.as_strided(as_strided_5, [1, 1, 128, 128], [1048576, 1048576, 
1024, 1], 0); as_strided_5 = None\0A convert_element_type = torch.ops.prims.convert_element_type(as_strided_6, torch.bool); as_strided_6 = None\0A _tensor_constant2 = self._tensor_constant2\0A lift_fresh_copy_1 = torch.ops.aten.lift_fresh_copy(_tensor_constant2); _tensor_constant2 = None\0A where = torch.ops.aten.where(convert_element_type, div, lift_fresh_copy_1); convert_element_type = div = lift_fresh_copy_1 = None\0A amax = torch.ops.aten.amax(where, [-1], True)\0A sub = torch.ops.aten.sub(where, amax); where = amax = None\0A exp = torch.ops.aten.exp(sub); sub = None\0A sum_1 = torch.ops.aten.sum(exp, [-1], True)\0A div_1 = torch.ops.aten.div(exp, sum_1); exp = sum_1 = None\0A detach = torch.ops.aten.detach(div_1)\0A expand_2 = torch.ops.aten.expand(div_1, [1, 12, 128, 128]); div_1 = None\0A view_9 = torch.ops.aten.view(expand_2, [12, 128, 128]); expand_2 = None\0A expand_3 = torch.ops.aten.expand(permute_2, [1, 12, 128, 64]); permute_2 = None\0A view_10 = torch.ops.aten.view(expand_3, [12, 128, 64]); expand_3 = None\0A bmm_1 = torch.ops.aten.bmm(view_9, view_10); view_9 = view_10 = None\0A _unsafe_view_1 = torch.ops.aten._unsafe_view(bmm_1, [1, 12, 128, 64]); bmm_1 = None\0A permute_3 = torch.ops.aten.permute(_unsafe_view_1, [0, 2, 1, 3]); _unsafe_view_1 = None\0A clone = torch.ops.aten.clone(permute_3, memory_format = torch.contiguous_format); permute_3 = None\0A view_11 = torch.ops.aten.view(clone, [1, 128, 768]); clone = None\0A view_12 = torch.ops.aten.view(view_11, [-1, 768]); view_11 = None\0A _param_constant6 = self._param_constant6\0A _param_constant7 = self._param_constant7\0A addmm_1 = torch.ops.aten.addmm(_param_constant6, view_12, _param_constant7); _param_constant6 = view_12 = _param_constant7 = None\0A view_13 = torch.ops.aten.view(addmm_1, [1, 128, 768]); addmm_1 = None\0A add_1 = torch.ops.aten.add(view_13, add); view_13 = add = None\0A _param_constant8 = self._param_constant8\0A _param_constant9 = self._param_constant9\0A native_layer_norm_1 = torch.ops.aten.native_layer_norm(add_1, [768], _param_constant8, _param_constant9, 1e-05); _param_constant8 = _param_constant9 = None\0A getitem_3 = native_layer_norm_1[0]\0A getitem_4 = native_layer_norm_1[1]\0A getitem_5 = native_layer_norm_1[2]; native_layer_norm_1 = None\0A view_14 = torch.ops.aten.view(getitem_3, [-1, 768]); getitem_3 = None\0A _param_constant10 = self._param_constant10\0A _param_constant11 = self._param_constant11\0A addmm_2 = torch.ops.aten.addmm(_param_constant10, view_14, _param_constant11); _param_constant10 = view_14 = _param_constant11 = None\0A view_15 = torch.ops.aten.view(addmm_2, [1, 128, 3072]); addmm_2 = None\0A mul = torch.ops.aten.mul(view_15, 0.5)\0A pow_1 = torch.ops.aten.pow(view_15, 3.0)\0A mul_1 = torch.ops.aten.mul(pow_1, 0.044715); pow_1 = None\0A add_2 = torch.ops.aten.add(view_15, mul_1); view_15 = mul_1 = None\0A mul_2 = torch.ops.aten.mul(add_2, 0.7978845608028654); add_2 = None\0A tanh = torch.ops.aten.tanh(mul_2); mul_2 = None\0A detach_1 = torch.ops.aten.detach(tanh)\0A add_3 = torch.ops.aten.add(tanh, 1.0); tanh = None\0A mul_3 = torch.ops.aten.mul(mul, add_3); mul = add_3 = None\0A view_16 = torch.ops.aten.view(mul_3, [-1, 3072]); mul_3 = None\0A _param_constant12 = self._param_constant12\0A _param_constant13 = self._param_constant13\0A addmm_3 = torch.ops.aten.addmm(_param_constant12, view_16, _param_constant13); _param_constant12 = view_16 = _param_constant13 = None\0A view_17 = torch.ops.aten.view(addmm_3, [1, 128, 768]); addmm_3 = None\0A add_4 = torch.ops.aten.add(add_1, 
view_17); add_1 = view_17 = None\0A _param_constant14 = self._param_constant14\0A _param_constant15 = self._param_constant15\0A native_layer_norm_2 = torch.ops.aten.native_layer_norm(add_4, [768], _param_constant14, _param_constant15, 1e-05); _param_constant14 = _param_constant15 = None\0A getitem_6 = native_layer_norm_2[0]\0A getitem_7 = native_layer_norm_2[1]\0A getitem_8 = native_layer_norm_2[2]; native_layer_norm_2 = None\0A view_18 = torch.ops.aten.view(getitem_6, [-1, 768]); getitem_6 = None\0A _param_constant16 = self._param_constant16\0A _param_constant17 = self._param_constant17\0A addmm_4 = torch.ops.aten.addmm(_param_constant16, view_18, _param_constant17); _param_constant16 = view_18 = _param_constant17 = None\0A view_19 = torch.ops.aten.view(addmm_4, [1, 128, 2304]); addmm_4 = None\0A as_strided_7 = torch.ops.aten.as_strided(view_19, [1, 128, 768], [294912, 2304, 1], 0)\0A as_strided_8 = torch.ops.aten.as_strided(view_19, [1, 128, 768], [294912, 2304, 1], 768)\0A as_strided_9 = torch.ops.aten.as_strided(view_19, [1, 128, 768], [294912, 2304, 1], 1536); view_19 = None\0A view_20 = torch.ops.aten.view(as_strided_7, [1, 128, 12, 64]); as_strided_7 = None\0A permute_4 = torch.ops.aten.permute(view_20, [0, 2, 1, 3]); view_20 = None\0A view_21 = torch.ops.aten.view(as_strided_8, [1, 128, 12, 64]); as_strided_8 = None\0A permute_5 = torch.ops.aten.permute(view_21, [0, 2, 1, 3]); view_21 = None\0A view_22 = torch.ops.aten.view(as_strided_9, [1, 128, 12, 64]); as_strided_9 = None\0A permute_6 = torch.ops.aten.permute(view_22, [0, 2, 1, 3]); view_22 = None\0A transpose_1 = torch.ops.aten.transpose(permute_5, -1, -2); permute_5 = None\0A expand_4 = torch.ops.aten.expand(permute_4, [1, 12, 128, 64]); permute_4 = None\0A view_23 = torch.ops.aten.view(expand_4, [12, 128, 64]); expand_4 = None\0A expand_5 = torch.ops.aten.expand(transpose_1, [1, 12, 64, 128]); transpose_1 = None\0A view_24 = torch.ops.aten.view(expand_5, [12, 64, 128]); expand_5 = None\0A bmm_2 = torch.ops.aten.bmm(view_23, view_24); view_23 = view_24 = None\0A _unsafe_view_2 = torch.ops.aten._unsafe_view(bmm_2, [1, 12, 128, 128]); bmm_2 = None\0A _tensor_constant3 = self._tensor_constant3\0A lift_fresh_copy_2 = torch.ops.aten.lift_fresh_copy(_tensor_constant3); _tensor_constant3 = None\0A div_2 = torch.ops.aten.div(_unsafe_view_2, lift_fresh_copy_2); _unsafe_view_2 = lift_fresh_copy_2 = None\0A _tensor_constant4 = self._tensor_constant4\0A as_strided_10 = torch.ops.aten.as_strided(_tensor_constant4, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0); _tensor_constant4 = None\0A as_strided_11 = torch.ops.aten.as_strided(as_strided_10, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0); as_strided_10 = None\0A as_strided_12 = torch.ops.aten.as_strided(as_strided_11, [1, 1, 128, 1024], [1048576, 1048576, 1024, 1], 0); as_strided_11 = None\0A as_strided_13 = torch.ops.aten.as_strided(as_strided_12, [1, 1, 128, 128], [1048576, 1048576, 1024, 1], 0); as_strided_12 = None\0A convert_element_type_1 = torch.ops.prims.convert_element_type(as_strided_13, torch.bool); as_strided_13 = None\0A _tensor_constant5 = self._tensor_constant5\0A lift_fresh_copy_3 = torch.ops.aten.lift_fresh_copy(_tensor_constant5); _tensor_constant5 = None\0A where_1 = torch.ops.aten.where(convert_element_type_1, div_2, lift_fresh_copy_3); convert_element_type_1 = div_2 = lift_fresh_copy_3 = None\0A amax_1 = torch.ops.aten.amax(where_1, [-1], True)\0A sub_1 = torch.ops.aten.sub(where_1, amax_1); where_1 = amax_1 = None\0A exp_1 = torch.ops.aten.exp(sub_1); 
sub_1 = None\0A sum_2 = torch.ops.aten.sum(exp_1, [-1], True)\0A div_3 = torch.ops.aten.div(exp_1, sum_2); exp_1 = sum_2 = None\0A detach_2 = torch.ops.aten.detach(div_3)\0A expand_6 = torch.ops.aten.expand(div_3, [1, 12, 128, 128]); div_3 = None\0A view_25 = torch.ops.aten.view(expand_6, [12, 128, 128]); expand_6 = None\0A expand_7 = torch.ops.aten.expand(permute_6, [1, 12, 128, 64]); permute_6 = None\0A view_26 = torch.ops.aten.view(expand_7, [12, 128, 64]); expand_7 = None\0A bmm_3 = torch.ops.aten.bmm(view_25, view_26); view_25 = view_26 = None\0A _unsafe_view_3 = torch.ops.aten._unsafe_view(bmm_3, [1, 12, 128, 64]); bmm_3 = None\0A permute_7 = torch.ops.aten.permute(_unsafe_view_3, [0, 2, 1, 3]); _unsafe_view_3 = None\0A clone_1 = torch.ops.aten.clone(permute_7, memory_format = torch.contiguous_format); permute_7 = None\0A view_27 = torch.ops.aten.view(clone_1, [1, 128, 768]); clone_1 = None\0A view_28 = torch.ops.aten.view(view_27, [-1, 768]); view_27 = None\0A _param_constant18 = self._param_constant18\0A _param_constant19 = self._param_constant19\0A addmm_5 = torch.ops.aten.addmm(_param_constant18, view_28, _param_constant19); _param_constant18 = view_28 = _param_constant19 = None\0A view_29 = torch.ops.aten.view(addmm_5, [1, 128, 768]); addmm_5 = None\0A add_5 = torch.ops.aten.add(view_29, add_4); view_29 = add_4 = None\0A _param_constant20 = self._param_constant20\0A _param_constant21 = self._param_constant21\0A native_layer_norm_3 = torch.ops.aten.native_layer_norm(add_5, [768], _param_constant20, _param_constant21, 1e-05); _param_constant20 = _param_constant21 = None\0A getitem_9 = native_layer_norm_3[0]\0A getitem_10 = native_layer_norm_3[1]\0A getitem_11 = native_layer_norm_3[2]; native_layer_norm_3 = None\0A view_30 = torch.ops.aten.view(getitem_9, [-1, 768]); getitem_9 = None\0A _param_constant22 = self._param_constant22\0A _param_constant23 = self._param_constant23\0A addmm_6 = torch.ops.aten.addmm(_param_constant22, view_30, _param_constant23); _param_constant22 = view_30 = _param_constant23 = None\0A view_31 = torch.ops.aten.view(addmm_6, [1, 128, 3072]); addmm_6 = None\0A mul_4 = torch.ops.aten.mul(view_31, 0.5)\0A pow_2 = torch.ops.aten.pow(view_31, 3.0)\0A mul_5 = torch.ops.aten.mul(pow_2, 0.044715); pow_2 = None\0A add_6 = torch.ops.aten.add(view_31, mul_5); view_31 = mul_5 = None\0A mul_6 = torch.ops.aten.mul(add_6, 0.7978845608028654); add_6 = None\0A tanh_1 = torch.ops.aten.tanh(mul_6); mul_6 = None\0A detach_3 = torch.ops.aten.detach(tanh_1)\0A add_7 = torch.ops.aten.add(tanh_1, 1.0); tanh_1 = None\0A mul_7 = torch.ops.aten.mul(mul_4, add_7); mul_4 = add_7 = None\0A view_32 = torch.ops.aten.view(mul_7, [-1, 3072]); mul_7 = None\0A _param_constant24 = self._param_constant24\0A _param_constant25 = self._param_constant25\0A addmm_7 = torch.ops.aten.addmm(_param_constant24, view_32, _param_constant25); _param_constant24 = view_32 = _param_constant25 = None\0A view_33 = torch.ops.aten.view(addmm_7, [1, 128, 768]); addmm_7 = None\0A add_8 = torch.ops.aten.add(add_5, view_33); add_5 = view_33 = None\0A _param_constant26 = self._param_constant26\0A _param_constant27 = self._param_constant27\0A native_layer_norm_4 = torch.ops.aten.native_layer_norm(add_8, [768], _param_constant26, _param_constant27, 1e-05); _param_constant26 = _param_constant27 = None\0A getitem_12 = native_layer_norm_4[0]\0A getitem_13 = native_layer_norm_4[1]\0A getitem_14 = native_layer_norm_4[2]; native_layer_norm_4 = None\0A view_34 = torch.ops.aten.view(getitem_12, [-1, 768]); getitem_12 = None\0A 
_param_constant28 = self._param_constant28\0A _param_constant29 = self._param_constant29\0A addmm_8 = torch.ops.aten.addmm(_param_constant28, view_34, _param_constant29); _param_constant28 = view_34 = _param_constant29 = None\0A view_35 = torch.ops.aten.view(addmm_8, [1, 128, 2304]); addmm_8 = None\0A as_strided_14 = torch.ops.aten.as_strided(view_35, [1, 128, 768], [294912, 2304, 1], 0)\0A as_strided_15 = torch.ops.aten.as_strided(view_35, [1, 128, 768], [294912, 2304, 1], 768)\0A as_strided_16 = torch.ops.aten.as_strided(view_35, [1, 128, 768], [294912, 2304, 1], 1536); view_35 = None\0A view_36 = torch.ops.aten.view(as_strided_14, [1, 128, 12, 64]); as_strided_14 = None\0A permute_8 = torch.ops.aten.permute(view_36, [0, 2, 1, 3]); view_36 = None\0A view_37 = torch.ops.aten.view(as_strided_15, [1, 128, 12, 64]); as_strided_15 = None\0A permute_9 = torch.ops.aten.permute(view_37, [0, 2, 1, 3]); view_37 = None\0A view_38 = torch.ops.aten.view(as_strided_16, [1, 128, 12, 64]); as_strided_16 = None\0A permute_10 = torch.ops.aten.permute(view_38, [0, 2, 1, 3]); view_38 = None\0A transpose_2 = torch.ops.aten.transpose(permute_9, -1, -2); permute_9 = None\0A expand_8 = torch.ops.aten.expand(permute_8, [1, 12, 128, 64]); permute_8 = None\0A view_39 = torch.ops.aten.view(expand_8, [12, 128, 64]); expand_8 = None\0A expand_9 = torch.ops.aten.expand(transpose_2, [1, 12, 64, 128]); transpose_2 = None\0A view_40 = torch.ops.aten.view(expand_9, [12, 64, 128]); expand_9 = None\0A bmm_4 = torch.ops.aten.bmm(view_39, view_40); view_39 = view_40 = None\0A _unsafe_view_4 = torch.ops.aten._unsafe_view(bmm_4, [1, 12, 128, 128]); bmm_4 = None\0A _tensor_constant6 = self._tensor_constant6\0A lift_fresh_copy_4 = torch.ops.aten.lift_fresh_copy(_tensor_constant6); _tensor_constant6 = None\0A div_4 = torch.ops.aten.div(_unsafe_view_4, lift_fresh_copy_4); _unsafe_view_4 = lift_fresh_copy_4 = None\0A _tensor_constant7 = self._tensor_constant7\0A as_strided_17 = torch.ops.aten.as_strided(_tensor_constant7, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0); _tensor_constant7 = None\0A as_strided_18 = torch.ops.aten.as_strided(as_strided_17, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0); as_strided_17 = None\0A as_strided_19 = torch.ops.aten.as_strided(as_strided_18, [1, 1, 128, 1024], [1048576, 1048576, 1024, 1], 0); as_strided_18 = None\0A as_strided_20 = torch.ops.aten.as_strided(as_strided_19, [1, 1, 128, 128], [1048576, 1048576, 1024, 1], 0); as_strided_19 = None\0A convert_element_type_2 = torch.ops.prims.convert_element_type(as_strided_20, torch.bool); as_strided_20 = None\0A _tensor_constant8 = self._tensor_constant8\0A lift_fresh_copy_5 = torch.ops.aten.lift_fresh_copy(_tensor_constant8); _tensor_constant8 = None\0A where_2 = torch.ops.aten.where(convert_element_type_2, div_4, lift_fresh_copy_5); convert_element_type_2 = div_4 = lift_fresh_copy_5 = None\0A amax_2 = torch.ops.aten.amax(where_2, [-1], True)\0A sub_2 = torch.ops.aten.sub(where_2, amax_2); where_2 = amax_2 = None\0A exp_2 = torch.ops.aten.exp(sub_2); sub_2 = None\0A sum_3 = torch.ops.aten.sum(exp_2, [-1], True)\0A div_5 = torch.ops.aten.div(exp_2, sum_3); exp_2 = sum_3 = None\0A detach_4 = torch.ops.aten.detach(div_5)\0A expand_10 = torch.ops.aten.expand(div_5, [1, 12, 128, 128]); div_5 = None\0A view_41 = torch.ops.aten.view(expand_10, [12, 128, 128]); expand_10 = None\0A expand_11 = torch.ops.aten.expand(permute_10, [1, 12, 128, 64]); permute_10 = None\0A view_42 = torch.ops.aten.view(expand_11, [12, 128, 64]); expand_11 = None\0A bmm_5 
= torch.ops.aten.bmm(view_41, view_42); view_41 = view_42 = None\0A _unsafe_view_5 = torch.ops.aten._unsafe_view(bmm_5, [1, 12, 128, 64]); bmm_5 = None\0A permute_11 = torch.ops.aten.permute(_unsafe_view_5, [0, 2, 1, 3]); _unsafe_view_5 = None\0A clone_2 = torch.ops.aten.clone(permute_11, memory_format = torch.contiguous_format); permute_11 = None\0A view_43 = torch.ops.aten.view(clone_2, [1, 128, 768]); clone_2 = None\0A view_44 = torch.ops.aten.view(view_43, [-1, 768]); view_43 = None\0A _param_constant30 = self._param_constant30\0A _param_constant31 = self._param_constant31\0A addmm_9 = torch.ops.aten.addmm(_param_constant30, view_44, _param_constant31); _param_constant30 = view_44 = _param_constant31 = None\0A view_45 = torch.ops.aten.view(addmm_9, [1, 128, 768]); addmm_9 = None\0A add_9 = torch.ops.aten.add(view_45, add_8); view_45 = add_8 = None\0A _param_constant32 = self._param_constant32\0A _param_constant33 = self._param_constant33\0A native_layer_norm_5 = torch.ops.aten.native_layer_norm(add_9, [768], _param_constant32, _param_constant33, 1e-05); _param_constant32 = _param_constant33 = None\0A getitem_15 = native_layer_norm_5[0]\0A getitem_16 = native_layer_norm_5[1]\0A getitem_17 = native_layer_norm_5[2]; native_layer_norm_5 = None\0A view_46 = torch.ops.aten.view(getitem_15, [-1, 768]); getitem_15 = None\0A _param_constant34 = self._param_constant34\0A _param_constant35 = self._param_constant35\0A addmm_10 = torch.ops.aten.addmm(_param_constant34, view_46, _param_constant35); _param_constant34 = view_46 = _param_constant35 = None\0A view_47 = torch.ops.aten.view(addmm_10, [1, 128, 3072]); addmm_10 = None\0A mul_8 = torch.ops.aten.mul(view_47, 0.5)\0A pow_3 = torch.ops.aten.pow(view_47, 3.0)\0A mul_9 = torch.ops.aten.mul(pow_3, 0.044715); pow_3 = None\0A add_10 = torch.ops.aten.add(view_47, mul_9); view_47 = mul_9 = None\0A mul_10 = torch.ops.aten.mul(add_10, 0.7978845608028654); add_10 = None\0A tanh_2 = torch.ops.aten.tanh(mul_10); mul_10 = None\0A detach_5 = torch.ops.aten.detach(tanh_2)\0A add_11 = torch.ops.aten.add(tanh_2, 1.0); tanh_2 = None\0A mul_11 = torch.ops.aten.mul(mul_8, add_11); mul_8 = add_11 = None\0A view_48 = torch.ops.aten.view(mul_11, [-1, 3072]); mul_11 = None\0A _param_constant36 = self._param_constant36\0A _param_constant37 = self._param_constant37\0A addmm_11 = torch.ops.aten.addmm(_param_constant36, view_48, _param_constant37); _param_constant36 = view_48 = _param_constant37 = None\0A view_49 = torch.ops.aten.view(addmm_11, [1, 128, 768]); addmm_11 = None\0A add_12 = torch.ops.aten.add(add_9, view_49); add_9 = view_49 = None\0A _param_constant38 = self._param_constant38\0A _param_constant39 = self._param_constant39\0A native_layer_norm_6 = torch.ops.aten.native_layer_norm(add_12, [768], _param_constant38, _param_constant39, 1e-05); _param_constant38 = _param_constant39 = None\0A getitem_18 = native_layer_norm_6[0]\0A getitem_19 = native_layer_norm_6[1]\0A getitem_20 = native_layer_norm_6[2]; native_layer_norm_6 = None\0A view_50 = torch.ops.aten.view(getitem_18, [-1, 768]); getitem_18 = None\0A _param_constant40 = self._param_constant40\0A _param_constant41 = self._param_constant41\0A addmm_12 = torch.ops.aten.addmm(_param_constant40, view_50, _param_constant41); _param_constant40 = view_50 = _param_constant41 = None\0A view_51 = torch.ops.aten.view(addmm_12, [1, 128, 2304]); addmm_12 = None\0A as_strided_21 = torch.ops.aten.as_strided(view_51, [1, 128, 768], [294912, 2304, 1], 0)\0A as_strided_22 = torch.ops.aten.as_strided(view_51, [1, 128, 768], 
[294912, 2304, 1], 768)\0A as_strided_23 = torch.ops.aten.as_strided(view_51, [1, 128, 768], [294912, 2304, 1], 1536); view_51 = None\0A view_52 = torch.ops.aten.view(as_strided_21, [1, 128, 12, 64]); as_strided_21 = None\0A permute_12 = torch.ops.aten.permute(view_52, [0, 2, 1, 3]); view_52 = None\0A view_53 = torch.ops.aten.view(as_strided_22, [1, 128, 12, 64]); as_strided_22 = None\0A permute_13 = torch.ops.aten.permute(view_53, [0, 2, 1, 3]); view_53 = None\0A view_54 = torch.ops.aten.view(as_strided_23, [1, 128, 12, 64]); as_strided_23 = None\0A permute_14 = torch.ops.aten.permute(view_54, [0, 2, 1, 3]); view_54 = None\0A transpose_3 = torch.ops.aten.transpose(permute_13, -1, -2); permute_13 = None\0A expand_12 = torch.ops.aten.expand(permute_12, [1, 12, 128, 64]); permute_12 = None\0A view_55 = torch.ops.aten.view(expand_12, [12, 128, 64]); expand_12 = None\0A expand_13 = torch.ops.aten.expand(transpose_3, [1, 12, 64, 128]); transpose_3 = None\0A view_56 = torch.ops.aten.view(expand_13, [12, 64, 128]); expand_13 = None\0A bmm_6 = torch.ops.aten.bmm(view_55, view_56); view_55 = view_56 = None\0A _unsafe_view_6 = torch.ops.aten._unsafe_view(bmm_6, [1, 12, 128, 128]); bmm_6 = None\0A _tensor_constant9 = self._tensor_constant9\0A lift_fresh_copy_6 = torch.ops.aten.lift_fresh_copy(_tensor_constant9); _tensor_constant9 = None\0A div_6 = torch.ops.aten.div(_unsafe_view_6, lift_fresh_copy_6); _unsafe_view_6 = lift_fresh_copy_6 = None\0A _tensor_constant10 = self._tensor_constant10\0A as_strided_24 = torch.ops.aten.as_strided(_tensor_constant10, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0); _tensor_constant10 = None\0A as_strided_25 = torch.ops.aten.as_strided(as_strided_24, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0); as_strided_24 = None\0A as_strided_26 = torch.ops.aten.as_strided(as_strided_25, [1, 1, 128, 1024], [1048576, 1048576, 1024, 1], 0); as_strided_25 = None\0A as_strided_27 = torch.ops.aten.as_strided(as_strided_26, [1, 1, 128, 128], [1048576, 1048576, 1024, 1], 0); as_strided_26 = None\0A convert_element_type_3 = torch.ops.prims.convert_element_type(as_strided_27, torch.bool); as_strided_27 = None\0A _tensor_constant11 = self._tensor_constant11\0A lift_fresh_copy_7 = torch.ops.aten.lift_fresh_copy(_tensor_constant11); _tensor_constant11 = None\0A where_3 = torch.ops.aten.where(convert_element_type_3, div_6, lift_fresh_copy_7); convert_element_type_3 = div_6 = lift_fresh_copy_7 = None\0A amax_3 = torch.ops.aten.amax(where_3, [-1], True)\0A sub_3 = torch.ops.aten.sub(where_3, amax_3); where_3 = amax_3 = None\0A exp_3 = torch.ops.aten.exp(sub_3); sub_3 = None\0A sum_4 = torch.ops.aten.sum(exp_3, [-1], True)\0A div_7 = torch.ops.aten.div(exp_3, sum_4); exp_3 = sum_4 = None\0A detach_6 = torch.ops.aten.detach(div_7)\0A expand_14 = torch.ops.aten.expand(div_7, [1, 12, 128, 128]); div_7 = None\0A view_57 = torch.ops.aten.view(expand_14, [12, 128, 128]); expand_14 = None\0A expand_15 = torch.ops.aten.expand(permute_14, [1, 12, 128, 64]); permute_14 = None\0A view_58 = torch.ops.aten.view(expand_15, [12, 128, 64]); expand_15 = None\0A bmm_7 = torch.ops.aten.bmm(view_57, view_58); view_57 = view_58 = None\0A _unsafe_view_7 = torch.ops.aten._unsafe_view(bmm_7, [1, 12, 128, 64]); bmm_7 = None\0A permute_15 = torch.ops.aten.permute(_unsafe_view_7, [0, 2, 1, 3]); _unsafe_view_7 = None\0A clone_3 = torch.ops.aten.clone(permute_15, memory_format = torch.contiguous_format); permute_15 = None\0A view_59 = torch.ops.aten.view(clone_3, [1, 128, 768]); clone_3 = None\0A view_60 = 
torch.ops.aten.view(view_59, [-1, 768]); view_59 = None\0A _param_constant42 = self._param_constant42\0A _param_constant43 = self._param_constant43\0A addmm_13 = torch.ops.aten.addmm(_param_constant42, view_60, _param_constant43); _param_constant42 = view_60 = _param_constant43 = None\0A view_61 = torch.ops.aten.view(addmm_13, [1, 128, 768]); addmm_13 = None\0A add_13 = torch.ops.aten.add(view_61, add_12); view_61 = add_12 = None\0A _param_constant44 = self._param_constant44\0A _param_constant45 = self._param_constant45\0A native_layer_norm_7 = torch.ops.aten.native_layer_norm(add_13, [768], _param_constant44, _param_constant45, 1e-05); _param_constant44 = _param_constant45 = None\0A getitem_21 = native_layer_norm_7[0]\0A getitem_22 = native_layer_norm_7[1]\0A getitem_23 = native_layer_norm_7[2]; native_layer_norm_7 = None\0A view_62 = torch.ops.aten.view(getitem_21, [-1, 768]); getitem_21 = None\0A _param_constant46 = self._param_constant46\0A _param_constant47 = self._param_constant47\0A addmm_14 = torch.ops.aten.addmm(_param_constant46, view_62, _param_constant47); _param_constant46 = view_62 = _param_constant47 = None\0A view_63 = torch.ops.aten.view(addmm_14, [1, 128, 3072]); addmm_14 = None\0A mul_12 = torch.ops.aten.mul(view_63, 0.5)\0A pow_4 = torch.ops.aten.pow(view_63, 3.0)\0A mul_13 = torch.ops.aten.mul(pow_4, 0.044715); pow_4 = None\0A add_14 = torch.ops.aten.add(view_63, mul_13); view_63 = mul_13 = None\0A mul_14 = torch.ops.aten.mul(add_14, 0.7978845608028654); add_14 = None\0A tanh_3 = torch.ops.aten.tanh(mul_14); mul_14 = None\0A detach_7 = torch.ops.aten.detach(tanh_3)\0A add_15 = torch.ops.aten.add(tanh_3, 1.0); tanh_3 = None\0A mul_15 = torch.ops.aten.mul(mul_12, add_15); mul_12 = add_15 = None\0A view_64 = torch.ops.aten.view(mul_15, [-1, 3072]); mul_15 = None\0A _param_constant48 = self._param_constant48\0A _param_constant49 = self._param_constant49\0A addmm_15 = torch.ops.aten.addmm(_param_constant48, view_64, _param_constant49); _param_constant48 = view_64 = _param_constant49 = None\0A view_65 = torch.ops.aten.view(addmm_15, [1, 128, 768]); addmm_15 = None\0A add_16 = torch.ops.aten.add(add_13, view_65); add_13 = view_65 = None\0A _param_constant50 = self._param_constant50\0A _param_constant51 = self._param_constant51\0A native_layer_norm_8 = torch.ops.aten.native_layer_norm(add_16, [768], _param_constant50, _param_constant51, 1e-05); _param_constant50 = _param_constant51 = None\0A getitem_24 = native_layer_norm_8[0]\0A getitem_25 = native_layer_norm_8[1]\0A getitem_26 = native_layer_norm_8[2]; native_layer_norm_8 = None\0A view_66 = torch.ops.aten.view(getitem_24, [-1, 768]); getitem_24 = None\0A _param_constant52 = self._param_constant52\0A _param_constant53 = self._param_constant53\0A addmm_16 = torch.ops.aten.addmm(_param_constant52, view_66, _param_constant53); _param_constant52 = view_66 = _param_constant53 = None\0A view_67 = torch.ops.aten.view(addmm_16, [1, 128, 2304]); addmm_16 = None\0A as_strided_28 = torch.ops.aten.as_strided(view_67, [1, 128, 768], [294912, 2304, 1], 0)\0A as_strided_29 = torch.ops.aten.as_strided(view_67, [1, 128, 768], [294912, 2304, 1], 768)\0A as_strided_30 = torch.ops.aten.as_strided(view_67, [1, 128, 768], [294912, 2304, 1], 1536); view_67 = None\0A view_68 = torch.ops.aten.view(as_strided_28, [1, 128, 12, 64]); as_strided_28 = None\0A permute_16 = torch.ops.aten.permute(view_68, [0, 2, 1, 3]); view_68 = None\0A view_69 = torch.ops.aten.view(as_strided_29, [1, 128, 12, 64]); as_strided_29 = None\0A permute_17 = 
torch.ops.aten.permute(view_69, [0, 2, 1, 3]); view_69 = None\0A view_70 = torch.ops.aten.view(as_strided_30, [1, 128, 12, 64]); as_strided_30 = None\0A permute_18 = torch.ops.aten.permute(view_70, [0, 2, 1, 3]); view_70 = None\0A transpose_4 = torch.ops.aten.transpose(permute_17, -1, -2); permute_17 = None\0A expand_16 = torch.ops.aten.expand(permute_16, [1, 12, 128, 64]); permute_16 = None\0A view_71 = torch.ops.aten.view(expand_16, [12, 128, 64]); expand_16 = None\0A expand_17 = torch.ops.aten.expand(transpose_4, [1, 12, 64, 128]); transpose_4 = None\0A view_72 = torch.ops.aten.view(expand_17, [12, 64, 128]); expand_17 = None\0A bmm_8 = torch.ops.aten.bmm(view_71, view_72); view_71 = view_72 = None\0A _unsafe_view_8 = torch.ops.aten._unsafe_view(bmm_8, [1, 12, 128, 128]); bmm_8 = None\0A _tensor_constant12 = self._tensor_constant12\0A lift_fresh_copy_8 = torch.ops.aten.lift_fresh_copy(_tensor_constant12); _tensor_constant12 = None\0A div_8 = torch.ops.aten.div(_unsafe_view_8, lift_fresh_copy_8); _unsafe_view_8 = lift_fresh_copy_8 = None\0A _tensor_constant13 = self._tensor_constant13\0A as_strided_31 = torch.ops.aten.as_strided(_tensor_constant13, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0); _tensor_constant13 = None\0A as_strided_32 = torch.ops.aten.as_strided(as_strided_31, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0); as_strided_31 = None\0A as_strided_33 = torch.ops.aten.as_strided(as_strided_32, [1, 1, 128, 1024], [1048576, 1048576, 1024, 1], 0); as_strided_32 = None\0A as_strided_34 = torch.ops.aten.as_strided(as_strided_33, [1, 1, 128, 128], [1048576, 1048576, 1024, 1], 0); as_strided_33 = None\0A convert_element_type_4 = torch.ops.prims.convert_element_type(as_strided_34, torch.bool); as_strided_34 = None\0A _tensor_constant14 = self._tensor_constant14\0A lift_fresh_copy_9 = torch.ops.aten.lift_fresh_copy(_tensor_constant14); _tensor_constant14 = None\0A where_4 = torch.ops.aten.where(convert_element_type_4, div_8, lift_fresh_copy_9); convert_element_type_4 = div_8 = lift_fresh_copy_9 = None\0A amax_4 = torch.ops.aten.amax(where_4, [-1], True)\0A sub_4 = torch.ops.aten.sub(where_4, amax_4); where_4 = amax_4 = None\0A exp_4 = torch.ops.aten.exp(sub_4); sub_4 = None\0A sum_5 = torch.ops.aten.sum(exp_4, [-1], True)\0A div_9 = torch.ops.aten.div(exp_4, sum_5); exp_4 = sum_5 = None\0A detach_8 = torch.ops.aten.detach(div_9)\0A expand_18 = torch.ops.aten.expand(div_9, [1, 12, 128, 128]); div_9 = None\0A view_73 = torch.ops.aten.view(expand_18, [12, 128, 128]); expand_18 = None\0A expand_19 = torch.ops.aten.expand(permute_18, [1, 12, 128, 64]); permute_18 = None\0A view_74 = torch.ops.aten.view(expand_19, [12, 128, 64]); expand_19 = None\0A bmm_9 = torch.ops.aten.bmm(view_73, view_74); view_73 = view_74 = None\0A _unsafe_view_9 = torch.ops.aten._unsafe_view(bmm_9, [1, 12, 128, 64]); bmm_9 = None\0A permute_19 = torch.ops.aten.permute(_unsafe_view_9, [0, 2, 1, 3]); _unsafe_view_9 = None\0A clone_4 = torch.ops.aten.clone(permute_19, memory_format = torch.contiguous_format); permute_19 = None\0A view_75 = torch.ops.aten.view(clone_4, [1, 128, 768]); clone_4 = None\0A view_76 = torch.ops.aten.view(view_75, [-1, 768]); view_75 = None\0A _param_constant54 = self._param_constant54\0A _param_constant55 = self._param_constant55\0A addmm_17 = torch.ops.aten.addmm(_param_constant54, view_76, _param_constant55); _param_constant54 = view_76 = _param_constant55 = None\0A view_77 = torch.ops.aten.view(addmm_17, [1, 128, 768]); addmm_17 = None\0A add_17 = torch.ops.aten.add(view_77, 
add_16); view_77 = add_16 = None\0A _param_constant56 = self._param_constant56\0A _param_constant57 = self._param_constant57\0A native_layer_norm_9 = torch.ops.aten.native_layer_norm(add_17, [768], _param_constant56, _param_constant57, 1e-05); _param_constant56 = _param_constant57 = None\0A getitem_27 = native_layer_norm_9[0]\0A getitem_28 = native_layer_norm_9[1]\0A getitem_29 = native_layer_norm_9[2]; native_layer_norm_9 = None\0A view_78 = torch.ops.aten.view(getitem_27, [-1, 768]); getitem_27 = None\0A _param_constant58 = self._param_constant58\0A _param_constant59 = self._param_constant59\0A addmm_18 = torch.ops.aten.addmm(_param_constant58, view_78, _param_constant59); _param_constant58 = view_78 = _param_constant59 = None\0A view_79 = torch.ops.aten.view(addmm_18, [1, 128, 3072]); addmm_18 = None\0A mul_16 = torch.ops.aten.mul(view_79, 0.5)\0A pow_5 = torch.ops.aten.pow(view_79, 3.0)\0A mul_17 = torch.ops.aten.mul(pow_5, 0.044715); pow_5 = None\0A add_18 = torch.ops.aten.add(view_79, mul_17); view_79 = mul_17 = None\0A mul_18 = torch.ops.aten.mul(add_18, 0.7978845608028654); add_18 = None\0A tanh_4 = torch.ops.aten.tanh(mul_18); mul_18 = None\0A detach_9 = torch.ops.aten.detach(tanh_4)\0A add_19 = torch.ops.aten.add(tanh_4, 1.0); tanh_4 = None\0A mul_19 = torch.ops.aten.mul(mul_16, add_19); mul_16 = add_19 = None\0A view_80 = torch.ops.aten.view(mul_19, [-1, 3072]); mul_19 = None\0A _param_constant60 = self._param_constant60\0A _param_constant61 = self._param_constant61\0A addmm_19 = torch.ops.aten.addmm(_param_constant60, view_80, _param_constant61); _param_constant60 = view_80 = _param_constant61 = None\0A view_81 = torch.ops.aten.view(addmm_19, [1, 128, 768]); addmm_19 = None\0A add_20 = torch.ops.aten.add(add_17, view_81); add_17 = view_81 = None\0A _param_constant62 = self._param_constant62\0A _param_constant63 = self._param_constant63\0A native_layer_norm_10 = torch.ops.aten.native_layer_norm(add_20, [768], _param_constant62, _param_constant63, 1e-05); _param_constant62 = _param_constant63 = None\0A getitem_30 = native_layer_norm_10[0]\0A getitem_31 = native_layer_norm_10[1]\0A getitem_32 = native_layer_norm_10[2]; native_layer_norm_10 = None\0A view_82 = torch.ops.aten.view(getitem_30, [-1, 768]); getitem_30 = None\0A _param_constant64 = self._param_constant64\0A _param_constant65 = self._param_constant65\0A addmm_20 = torch.ops.aten.addmm(_param_constant64, view_82, _param_constant65); _param_constant64 = view_82 = _param_constant65 = None\0A view_83 = torch.ops.aten.view(addmm_20, [1, 128, 2304]); addmm_20 = None\0A as_strided_35 = torch.ops.aten.as_strided(view_83, [1, 128, 768], [294912, 2304, 1], 0)\0A as_strided_36 = torch.ops.aten.as_strided(view_83, [1, 128, 768], [294912, 2304, 1], 768)\0A as_strided_37 = torch.ops.aten.as_strided(view_83, [1, 128, 768], [294912, 2304, 1], 1536); view_83 = None\0A view_84 = torch.ops.aten.view(as_strided_35, [1, 128, 12, 64]); as_strided_35 = None\0A permute_20 = torch.ops.aten.permute(view_84, [0, 2, 1, 3]); view_84 = None\0A view_85 = torch.ops.aten.view(as_strided_36, [1, 128, 12, 64]); as_strided_36 = None\0A permute_21 = torch.ops.aten.permute(view_85, [0, 2, 1, 3]); view_85 = None\0A view_86 = torch.ops.aten.view(as_strided_37, [1, 128, 12, 64]); as_strided_37 = None\0A permute_22 = torch.ops.aten.permute(view_86, [0, 2, 1, 3]); view_86 = None\0A transpose_5 = torch.ops.aten.transpose(permute_21, -1, -2); permute_21 = None\0A expand_20 = torch.ops.aten.expand(permute_20, [1, 12, 128, 64]); permute_20 = None\0A view_87 = 
torch.ops.aten.view(expand_20, [12, 128, 64]); expand_20 = None\0A expand_21 = torch.ops.aten.expand(transpose_5, [1, 12, 64, 128]); transpose_5 = None\0A view_88 = torch.ops.aten.view(expand_21, [12, 64, 128]); expand_21 = None\0A bmm_10 = torch.ops.aten.bmm(view_87, view_88); view_87 = view_88 = None\0A _unsafe_view_10 = torch.ops.aten._unsafe_view(bmm_10, [1, 12, 128, 128]); bmm_10 = None\0A _tensor_constant15 = self._tensor_constant15\0A lift_fresh_copy_10 = torch.ops.aten.lift_fresh_copy(_tensor_constant15); _tensor_constant15 = None\0A div_10 = torch.ops.aten.div(_unsafe_view_10, lift_fresh_copy_10); _unsafe_view_10 = lift_fresh_copy_10 = None\0A _tensor_constant16 = self._tensor_constant16\0A as_strided_38 = torch.ops.aten.as_strided(_tensor_constant16, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0); _tensor_constant16 = None\0A as_strided_39 = torch.ops.aten.as_strided(as_strided_38, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0); as_strided_38 = None\0A as_strided_40 = torch.ops.aten.as_strided(as_strided_39, [1, 1, 128, 1024], [1048576, 1048576, 1024, 1], 0); as_strided_39 = None\0A as_strided_41 = torch.ops.aten.as_strided(as_strided_40, [1, 1, 128, 128], [1048576, 1048576, 1024, 1], 0); as_strided_40 = None\0A convert_element_type_5 = torch.ops.prims.convert_element_type(as_strided_41, torch.bool); as_strided_41 = None\0A _tensor_constant17 = self._tensor_constant17\0A lift_fresh_copy_11 = torch.ops.aten.lift_fresh_copy(_tensor_constant17); _tensor_constant17 = None\0A where_5 = torch.ops.aten.where(convert_element_type_5, div_10, lift_fresh_copy_11); convert_element_type_5 = div_10 = lift_fresh_copy_11 = None\0A amax_5 = torch.ops.aten.amax(where_5, [-1], True)\0A sub_5 = torch.ops.aten.sub(where_5, amax_5); where_5 = amax_5 = None\0A exp_5 = torch.ops.aten.exp(sub_5); sub_5 = None\0A sum_6 = torch.ops.aten.sum(exp_5, [-1], True)\0A div_11 = torch.ops.aten.div(exp_5, sum_6); exp_5 = sum_6 = None\0A detach_10 = torch.ops.aten.detach(div_11)\0A expand_22 = torch.ops.aten.expand(div_11, [1, 12, 128, 128]); div_11 = None\0A view_89 = torch.ops.aten.view(expand_22, [12, 128, 128]); expand_22 = None\0A expand_23 = torch.ops.aten.expand(permute_22, [1, 12, 128, 64]); permute_22 = None\0A view_90 = torch.ops.aten.view(expand_23, [12, 128, 64]); expand_23 = None\0A bmm_11 = torch.ops.aten.bmm(view_89, view_90); view_89 = view_90 = None\0A _unsafe_view_11 = torch.ops.aten._unsafe_view(bmm_11, [1, 12, 128, 64]); bmm_11 = None\0A permute_23 = torch.ops.aten.permute(_unsafe_view_11, [0, 2, 1, 3]); _unsafe_view_11 = None\0A clone_5 = torch.ops.aten.clone(permute_23, memory_format = torch.contiguous_format); permute_23 = None\0A view_91 = torch.ops.aten.view(clone_5, [1, 128, 768]); clone_5 = None\0A view_92 = torch.ops.aten.view(view_91, [-1, 768]); view_91 = None\0A _param_constant66 = self._param_constant66\0A _param_constant67 = self._param_constant67\0A addmm_21 = torch.ops.aten.addmm(_param_constant66, view_92, _param_constant67); _param_constant66 = view_92 = _param_constant67 = None\0A view_93 = torch.ops.aten.view(addmm_21, [1, 128, 768]); addmm_21 = None\0A add_21 = torch.ops.aten.add(view_93, add_20); view_93 = add_20 = None\0A _param_constant68 = self._param_constant68\0A _param_constant69 = self._param_constant69\0A native_layer_norm_11 = torch.ops.aten.native_layer_norm(add_21, [768], _param_constant68, _param_constant69, 1e-05); _param_constant68 = _param_constant69 = None\0A getitem_33 = native_layer_norm_11[0]\0A getitem_34 = native_layer_norm_11[1]\0A getitem_35 = 
native_layer_norm_11[2]; native_layer_norm_11 = None\0A view_94 = torch.ops.aten.view(getitem_33, [-1, 768]); getitem_33 = None\0A _param_constant70 = self._param_constant70\0A _param_constant71 = self._param_constant71\0A addmm_22 = torch.ops.aten.addmm(_param_constant70, view_94, _param_constant71); _param_constant70 = view_94 = _param_constant71 = None\0A view_95 = torch.ops.aten.view(addmm_22, [1, 128, 3072]); addmm_22 = None\0A mul_20 = torch.ops.aten.mul(view_95, 0.5)\0A pow_6 = torch.ops.aten.pow(view_95, 3.0)\0A mul_21 = torch.ops.aten.mul(pow_6, 0.044715); pow_6 = None\0A add_22 = torch.ops.aten.add(view_95, mul_21); view_95 = mul_21 = None\0A mul_22 = torch.ops.aten.mul(add_22, 0.7978845608028654); add_22 = None\0A tanh_5 = torch.ops.aten.tanh(mul_22); mul_22 = None\0A detach_11 = torch.ops.aten.detach(tanh_5)\0A add_23 = torch.ops.aten.add(tanh_5, 1.0); tanh_5 = None\0A mul_23 = torch.ops.aten.mul(mul_20, add_23); mul_20 = add_23 = None\0A view_96 = torch.ops.aten.view(mul_23, [-1, 3072]); mul_23 = None\0A _param_constant72 = self._param_constant72\0A _param_constant73 = self._param_constant73\0A addmm_23 = torch.ops.aten.addmm(_param_constant72, view_96, _param_constant73); _param_constant72 = view_96 = _param_constant73 = None\0A view_97 = torch.ops.aten.view(addmm_23, [1, 128, 768]); addmm_23 = None\0A add_24 = torch.ops.aten.add(add_21, view_97); add_21 = view_97 = None\0A _param_constant74 = self._param_constant74\0A _param_constant75 = self._param_constant75\0A native_layer_norm_12 = torch.ops.aten.native_layer_norm(add_24, [768], _param_constant74, _param_constant75, 1e-05); add_24 = _param_constant74 = _param_constant75 = None\0A getitem_36 = native_layer_norm_12[0]\0A getitem_37 = native_layer_norm_12[1]\0A getitem_38 = native_layer_norm_12[2]; native_layer_norm_12 = None\0A view_98 = torch.ops.aten.view(getitem_36, [1, 128, 768]); getitem_36 = None\0A _param_constant76 = self._param_constant76\0A t = torch.ops.aten.t(_param_constant76); _param_constant76 = None\0A view_99 = torch.ops.aten.view(view_98, [128, 768]); view_98 = None\0A mm = torch.ops.aten.mm(view_99, t); view_99 = t = None\0A _unsafe_view_12 = torch.ops.aten._unsafe_view(mm, [1, 128, 2]); mm = None\0A arange_1 = torch.ops.aten.arange(1, device = device(type='cpu'), pin_memory = False)\0A select = torch.ops.aten.select(_unsafe_view_12, 1, -1); _unsafe_view_12 = None\0A index = torch.ops.aten.index(select, [arange_1]); select = arange_1 = None\0A return index\0A " loc(#loc)
%95 = torch.nn_module {
torch.slot "_param_constant0", %0 : !torch.tensor<[50257,768],f32> loc(#loc)
torch.slot "_param_constant1", %1 : !torch.tensor<[1024,768],f32> loc(#loc)
torch.slot "_param_constant2", %2 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant3", %3 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant4", %4 : !torch.tensor<[2304],f32> loc(#loc)
torch.slot "_param_constant5", %5 : !torch.tensor<[768,2304],f32> loc(#loc)
torch.slot "_param_constant6", %6 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant7", %7 : !torch.tensor<[768,768],f32> loc(#loc)
torch.slot "_param_constant8", %8 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant9", %9 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant10", %10 : !torch.tensor<[3072],f32> loc(#loc)
torch.slot "_param_constant11", %11 : !torch.tensor<[768,3072],f32> loc(#loc)
torch.slot "_param_constant12", %12 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant13", %13 : !torch.tensor<[3072,768],f32> loc(#loc)
torch.slot "_param_constant14", %14 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant15", %15 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant16", %16 : !torch.tensor<[2304],f32> loc(#loc)
torch.slot "_param_constant17", %17 : !torch.tensor<[768,2304],f32> loc(#loc)
torch.slot "_param_constant18", %18 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant19", %19 : !torch.tensor<[768,768],f32> loc(#loc)
torch.slot "_param_constant20", %20 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant21", %21 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant22", %22 : !torch.tensor<[3072],f32> loc(#loc)
torch.slot "_param_constant23", %23 : !torch.tensor<[768,3072],f32> loc(#loc)
torch.slot "_param_constant24", %24 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant25", %25 : !torch.tensor<[3072,768],f32> loc(#loc)
torch.slot "_param_constant26", %26 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant27", %27 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant28", %28 : !torch.tensor<[2304],f32> loc(#loc)
torch.slot "_param_constant29", %29 : !torch.tensor<[768,2304],f32> loc(#loc)
torch.slot "_param_constant30", %30 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant31", %31 : !torch.tensor<[768,768],f32> loc(#loc)
torch.slot "_param_constant32", %32 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant33", %33 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant34", %34 : !torch.tensor<[3072],f32> loc(#loc)
torch.slot "_param_constant35", %35 : !torch.tensor<[768,3072],f32> loc(#loc)
torch.slot "_param_constant36", %36 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant37", %37 : !torch.tensor<[3072,768],f32> loc(#loc)
torch.slot "_param_constant38", %38 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant39", %39 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant40", %40 : !torch.tensor<[2304],f32> loc(#loc)
torch.slot "_param_constant41", %41 : !torch.tensor<[768,2304],f32> loc(#loc)
torch.slot "_param_constant42", %42 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant43", %43 : !torch.tensor<[768,768],f32> loc(#loc)
torch.slot "_param_constant44", %44 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant45", %45 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant46", %46 : !torch.tensor<[3072],f32> loc(#loc)
torch.slot "_param_constant47", %47 : !torch.tensor<[768,3072],f32> loc(#loc)
torch.slot "_param_constant48", %48 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant49", %49 : !torch.tensor<[3072,768],f32> loc(#loc)
torch.slot "_param_constant50", %50 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant51", %51 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant52", %52 : !torch.tensor<[2304],f32> loc(#loc)
torch.slot "_param_constant53", %53 : !torch.tensor<[768,2304],f32> loc(#loc)
torch.slot "_param_constant54", %54 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant55", %55 : !torch.tensor<[768,768],f32> loc(#loc)
torch.slot "_param_constant56", %56 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant57", %57 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant58", %58 : !torch.tensor<[3072],f32> loc(#loc)
torch.slot "_param_constant59", %59 : !torch.tensor<[768,3072],f32> loc(#loc)
torch.slot "_param_constant60", %60 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant61", %61 : !torch.tensor<[3072,768],f32> loc(#loc)
torch.slot "_param_constant62", %62 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant63", %63 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant64", %64 : !torch.tensor<[2304],f32> loc(#loc)
torch.slot "_param_constant65", %65 : !torch.tensor<[768,2304],f32> loc(#loc)
torch.slot "_param_constant66", %66 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant67", %67 : !torch.tensor<[768,768],f32> loc(#loc)
torch.slot "_param_constant68", %68 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant69", %69 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant70", %70 : !torch.tensor<[3072],f32> loc(#loc)
torch.slot "_param_constant71", %71 : !torch.tensor<[768,3072],f32> loc(#loc)
torch.slot "_param_constant72", %72 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant73", %73 : !torch.tensor<[3072,768],f32> loc(#loc)
torch.slot "_param_constant74", %74 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant75", %75 : !torch.tensor<[768],f32> loc(#loc)
torch.slot "_param_constant76", %76 : !torch.tensor<[2,768],f32> loc(#loc)
torch.slot "_tensor_constant0", %77 : !torch.tensor<[],f32> loc(#loc)
torch.slot "_tensor_constant1", %78 : !torch.tensor<[1,1,1024,1024],ui8> loc(#loc)
torch.slot "_tensor_constant2", %79 : !torch.tensor<[],f32> loc(#loc)
torch.slot "_tensor_constant3", %80 : !torch.tensor<[],f32> loc(#loc)
torch.slot "_tensor_constant4", %81 : !torch.tensor<[1,1,1024,1024],ui8> loc(#loc)
torch.slot "_tensor_constant5", %82 : !torch.tensor<[],f32> loc(#loc)
torch.slot "_tensor_constant6", %83 : !torch.tensor<[],f32> loc(#loc)
torch.slot "_tensor_constant7", %84 : !torch.tensor<[1,1,1024,1024],ui8> loc(#loc)
torch.slot "_tensor_constant8", %85 : !torch.tensor<[],f32> loc(#loc)
torch.slot "_tensor_constant9", %86 : !torch.tensor<[],f32> loc(#loc)
torch.slot "_tensor_constant10", %87 : !torch.tensor<[1,1,1024,1024],ui8> loc(#loc)
torch.slot "_tensor_constant11", %88 : !torch.tensor<[],f32> loc(#loc)
torch.slot "_tensor_constant12", %89 : !torch.tensor<[],f32> loc(#loc)
torch.slot "_tensor_constant13", %90 : !torch.tensor<[1,1,1024,1024],ui8> loc(#loc)
torch.slot "_tensor_constant14", %91 : !torch.tensor<[],f32> loc(#loc)
torch.slot "_tensor_constant15", %92 : !torch.tensor<[],f32> loc(#loc)
torch.slot "_tensor_constant16", %93 : !torch.tensor<[1,1,1024,1024],ui8> loc(#loc)
torch.slot "_tensor_constant17", %94 : !torch.tensor<[],f32> loc(#loc)
torch.slot "training", %true : !torch.bool loc(#loc)
torch.slot "_is_full_backward_hook", %none : !torch.none loc(#loc)
torch.slot "_code", %str : !torch.str loc(#loc)
} : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> loc(#loc)
} loc(#loc)
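// Source-location table: each #locN names a line:column position in
// "<eval_with_key>.2", the Python file torch.fx synthesizes for the traced
// forward; ops in the module reference these positions via loc(#locN).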
#loc1 = loc("<eval_with_key>.2":53:44)
#loc2 = loc("<eval_with_key>.2":49:78)
#loc3 = loc("<eval_with_key>.2":34:56)
#loc4 = loc("<eval_with_key>.2":6:106)
#loc5 = loc("<eval_with_key>.2":6:51)
#loc6 = loc("<eval_with_key>.2":5:40)
#loc7 = loc("<eval_with_key>.2":5:41)
#loc8 = loc("<eval_with_key>.2":5:44)
#loc9 = loc("<eval_with_key>.2":6:35)
#loc10 = loc("<eval_with_key>.2":16:63)
#loc11 = loc("<eval_with_key>.2":16:105)
#loc12 = loc("<eval_with_key>.2":19:34)
#loc13 = loc("<eval_with_key>.2":24:49)
#loc14 = loc("<eval_with_key>.2":25:67)
#loc15 = loc("<eval_with_key>.2":27:87)
#loc16 = loc("<eval_with_key>.2":28:54)
#loc17 = loc("<eval_with_key>.2":28:58)
#loc18 = loc("<eval_with_key>.2":29:55)
#loc19 = loc("<eval_with_key>.2":45:71)
#loc20 = loc("<eval_with_key>.2":45:85)
#loc21 = loc("<eval_with_key>.2":84:52)
#loc22 = loc("<eval_with_key>.2":85:38)
#loc23 = loc("<eval_with_key>.2":86:40)
#loc24 = loc("<eval_with_key>.2":87:38)
#loc25 = loc("<eval_with_key>.2":89:38)
#loc26 = loc("<eval_with_key>.2":92:37)
#loc27 = loc("<eval_with_key>.2":5:11)
#loc28 = loc("<eval_with_key>.2":6:13)
#loc29 = loc("<eval_with_key>.2":7:16)
#loc30 = loc("<eval_with_key>.2":8:13)
#loc31 = loc("<eval_with_key>.2":10:16)
#loc32 = loc("<eval_with_key>.2":12:18)
#loc33 = loc("<eval_with_key>.2":13:10)
#loc34 = loc("<eval_with_key>.2":16:24)
#loc35 = loc("<eval_with_key>.2":20:13)
#loc36 = loc("<eval_with_key>.2":23:12)
#loc37 = loc("<eval_with_key>.2":24:13)
#loc38 = loc("<eval_with_key>.2":25:17)
#loc39 = loc("<eval_with_key>.2":26:19)
#loc40 = loc("<eval_with_key>.2":27:19)
#loc41 = loc("<eval_with_key>.2":28:13)
#loc42 = loc("<eval_with_key>.2":29:14)
#loc43 = loc("<eval_with_key>.2":30:13)
#loc44 = loc("<eval_with_key>.2":31:16)
#loc45 = loc("<eval_with_key>.2":32:13)
#loc46 = loc("<eval_with_key>.2":33:16)
#loc47 = loc("<eval_with_key>.2":34:16)
#loc48 = loc("<eval_with_key>.2":35:13)
#loc49 = loc("<eval_with_key>.2":36:13)
#loc50 = loc("<eval_with_key>.2":37:15)
#loc51 = loc("<eval_with_key>.2":38:13)
#loc52 = loc("<eval_with_key>.2":39:10)
#loc53 = loc("<eval_with_key>.2":40:19)
#loc54 = loc("<eval_with_key>.2":42:22)
#loc55 = loc("<eval_with_key>.2":43:10)
#loc56 = loc("<eval_with_key>.2":45:19)
#loc57 = loc("<eval_with_key>.2":46:19)
#loc58 = loc("<eval_with_key>.2":47:19)
#loc59 = loc("<eval_with_key>.2":48:19)
#loc60 = loc("<eval_with_key>.2":49:27)
#loc61 = loc("<eval_with_key>.2":51:24)
#loc62 = loc("<eval_with_key>.2":52:12)
#loc63 = loc("<eval_with_key>.2":53:11)
#loc64 = loc("<eval_with_key>.2":54:10)
#loc65 = loc("<eval_with_key>.2":55:10)
#loc66 = loc("<eval_with_key>.2":56:12)
#loc67 = loc("<eval_with_key>.2":57:12)
#loc68 = loc("<eval_with_key>.2":59:15)
#loc69 = loc("<eval_with_key>.2":60:13)
#loc70 = loc("<eval_with_key>.2":61:15)
#loc71 = loc("<eval_with_key>.2":62:14)
#loc72 = loc("<eval_with_key>.2":63:12)
#loc73 = loc("<eval_with_key>.2":64:21)
#loc74 = loc("<eval_with_key>.2":65:16)
#loc75 = loc("<eval_with_key>.2":66:12)
#loc76 = loc("<eval_with_key>.2":67:14)
#loc77 = loc("<eval_with_key>.2":68:14)
#loc78 = loc("<eval_with_key>.2":71:14)
#loc79 = loc("<eval_with_key>.2":72:14)
#loc80 = loc("<eval_with_key>.2":73:12)
#loc81 = loc("<eval_with_key>.2":76:26)
#loc82 = loc("<eval_with_key>.2":80:14)
#loc83 = loc("<eval_with_key>.2":83:14)
#loc84 = loc("<eval_with_key>.2":84:14)
#loc85 = loc("<eval_with_key>.2":85:10)
#loc86 = loc("<eval_with_key>.2":86:12)
#loc87 = loc("<eval_with_key>.2":87:12)
#loc88 = loc("<eval_with_key>.2":88:12)
#loc89 = loc("<eval_with_key>.2":89:12)
#loc90 = loc("<eval_with_key>.2":90:11)
#loc91 = loc("<eval_with_key>.2":92:12)
#loc92 = loc("<eval_with_key>.2":93:12)
#loc93 = loc("<eval_with_key>.2":94:14)
#loc94 = loc("<eval_with_key>.2":97:14)
#loc95 = loc("<eval_with_key>.2":98:14)
#loc96 = loc("<eval_with_key>.2":99:12)
#loc97 = loc("<eval_with_key>.2":102:26)
#loc98 = loc("<eval_with_key>.2":106:14)
#loc99 = loc("<eval_with_key>.2":109:14)
#loc100 = loc("<eval_with_key>.2":110:14)
#loc101 = loc("<eval_with_key>.2":111:19)
#loc102 = loc("<eval_with_key>.2":112:19)
#loc103 = loc("<eval_with_key>.2":113:19)
#loc104 = loc("<eval_with_key>.2":114:14)
#loc105 = loc("<eval_with_key>.2":115:16)
#loc106 = loc("<eval_with_key>.2":116:14)
#loc107 = loc("<eval_with_key>.2":117:16)
#loc108 = loc("<eval_with_key>.2":118:14)
#loc109 = loc("<eval_with_key>.2":119:16)
#loc110 = loc("<eval_with_key>.2":120:18)
#loc111 = loc("<eval_with_key>.2":121:15)
#loc112 = loc("<eval_with_key>.2":122:14)
#loc113 = loc("<eval_with_key>.2":123:15)
#loc114 = loc("<eval_with_key>.2":124:14)
#loc115 = loc("<eval_with_key>.2":125:12)
#loc116 = loc("<eval_with_key>.2":126:21)
#loc117 = loc("<eval_with_key>.2":128:24)
#loc118 = loc("<eval_with_key>.2":129:12)
#loc119 = loc("<eval_with_key>.2":131:20)
#loc120 = loc("<eval_with_key>.2":132:20)
#loc121 = loc("<eval_with_key>.2":133:20)
#loc122 = loc("<eval_with_key>.2":134:20)
#loc123 = loc("<eval_with_key>.2":135:29)
#loc124 = loc("<eval_with_key>.2":137:24)
#loc125 = loc("<eval_with_key>.2":138:14)
#loc126 = loc("<eval_with_key>.2":139:13)
#loc127 = loc("<eval_with_key>.2":140:12)
#loc128 = loc("<eval_with_key>.2":141:12)
#loc129 = loc("<eval_with_key>.2":142:12)
#loc130 = loc("<eval_with_key>.2":143:12)
#loc131 = loc("<eval_with_key>.2":145:15)
#loc132 = loc("<eval_with_key>.2":146:14)
#loc133 = loc("<eval_with_key>.2":147:15)
#loc134 = loc("<eval_with_key>.2":148:14)
#loc135 = loc("<eval_with_key>.2":149:12)
#loc136 = loc("<eval_with_key>.2":150:21)
#loc137 = loc("<eval_with_key>.2":151:16)
#loc138 = loc("<eval_with_key>.2":152:14)
#loc139 = loc("<eval_with_key>.2":153:14)
#loc140 = loc("<eval_with_key>.2":154:14)
#loc141 = loc("<eval_with_key>.2":157:14)
#loc142 = loc("<eval_with_key>.2":158:14)
#loc143 = loc("<eval_with_key>.2":159:12)
#loc144 = loc("<eval_with_key>.2":162:26)
#loc145 = loc("<eval_with_key>.2":166:14)
#loc146 = loc("<eval_with_key>.2":169:14)
#loc147 = loc("<eval_with_key>.2":170:14)
#loc148 = loc("<eval_with_key>.2":171:12)
#loc149 = loc("<eval_with_key>.2":172:12)
#loc150 = loc("<eval_with_key>.2":173:12)
#loc151 = loc("<eval_with_key>.2":174:12)
#loc152 = loc("<eval_with_key>.2":175:12)
#loc153 = loc("<eval_with_key>.2":176:13)
#loc154 = loc("<eval_with_key>.2":178:12)
#loc155 = loc("<eval_with_key>.2":179:12)
#loc156 = loc("<eval_with_key>.2":180:14)
#loc157 = loc("<eval_with_key>.2":183:14)
#loc158 = loc("<eval_with_key>.2":184:14)
#loc159 = loc("<eval_with_key>.2":185:12)
#loc160 = loc("<eval_with_key>.2":188:26)
#loc161 = loc("<eval_with_key>.2":192:14)
#loc162 = loc("<eval_with_key>.2":195:14)
#loc163 = loc("<eval_with_key>.2":196:14)
#loc164 = loc("<eval_with_key>.2":197:20)
#loc165 = loc("<eval_with_key>.2":198:20)
#loc166 = loc("<eval_with_key>.2":199:20)
#loc167 = loc("<eval_with_key>.2":200:14)
#loc168 = loc("<eval_with_key>.2":201:16)
#loc169 = loc("<eval_with_key>.2":202:14)
#loc170 = loc("<eval_with_key>.2":203:16)
#loc171 = loc("<eval_with_key>.2":204:14)
#loc172 = loc("<eval_with_key>.2":205:17)
#loc173 = loc("<eval_with_key>.2":206:18)
#loc174 = loc("<eval_with_key>.2":207:15)
#loc175 = loc("<eval_with_key>.2":208:14)
#loc176 = loc("<eval_with_key>.2":209:15)
#loc177 = loc("<eval_with_key>.2":210:14)
#loc178 = loc("<eval_with_key>.2":211:12)
#loc179 = loc("<eval_with_key>.2":212:21)
#loc180 = loc("<eval_with_key>.2":214:24)
#loc181 = loc("<eval_with_key>.2":215:12)
#loc182 = loc("<eval_with_key>.2":217:20)
#loc183 = loc("<eval_with_key>.2":218:20)
#loc184 = loc("<eval_with_key>.2":219:20)
#loc185 = loc("<eval_with_key>.2":220:20)
#loc186 = loc("<eval_with_key>.2":221:29)
#loc187 = loc("<eval_with_key>.2":223:24)
#loc188 = loc("<eval_with_key>.2":224:14)
#loc189 = loc("<eval_with_key>.2":225:13)
#loc190 = loc("<eval_with_key>.2":226:12)
#loc191 = loc("<eval_with_key>.2":227:12)
#loc192 = loc("<eval_with_key>.2":228:12)
#loc193 = loc("<eval_with_key>.2":229:12)
#loc194 = loc("<eval_with_key>.2":231:16)
#loc195 = loc("<eval_with_key>.2":232:14)
#loc196 = loc("<eval_with_key>.2":233:16)
#loc197 = loc("<eval_with_key>.2":234:14)
#loc198 = loc("<eval_with_key>.2":235:12)
#loc199 = loc("<eval_with_key>.2":236:21)
#loc200 = loc("<eval_with_key>.2":237:17)
#loc201 = loc("<eval_with_key>.2":238:14)
#loc202 = loc("<eval_with_key>.2":239:14)
#loc203 = loc("<eval_with_key>.2":240:14)
#loc204 = loc("<eval_with_key>.2":243:14)
#loc205 = loc("<eval_with_key>.2":244:14)
#loc206 = loc("<eval_with_key>.2":245:12)
#loc207 = loc("<eval_with_key>.2":248:26)
#loc208 = loc("<eval_with_key>.2":252:14)
#loc209 = loc("<eval_with_key>.2":255:15)
#loc210 = loc("<eval_with_key>.2":256:14)
#loc211 = loc("<eval_with_key>.2":257:12)
#loc212 = loc("<eval_with_key>.2":258:12)
#loc213 = loc("<eval_with_key>.2":259:12)
#loc214 = loc("<eval_with_key>.2":260:13)
#loc215 = loc("<eval_with_key>.2":261:13)
#loc216 = loc("<eval_with_key>.2":262:13)
#loc217 = loc("<eval_with_key>.2":264:13)
#loc218 = loc("<eval_with_key>.2":265:13)
#loc219 = loc("<eval_with_key>.2":266:14)
#loc220 = loc("<eval_with_key>.2":269:15)
#loc221 = loc("<eval_with_key>.2":270:14)
#loc222 = loc("<eval_with_key>.2":271:13)
#loc223 = loc("<eval_with_key>.2":274:26)
#loc224 = loc("<eval_with_key>.2":278:14)
#loc225 = loc("<eval_with_key>.2":281:15)
#loc226 = loc("<eval_with_key>.2":282:14)
#loc227 = loc("<eval_with_key>.2":283:20)
#loc228 = loc("<eval_with_key>.2":284:20)
#loc229 = loc("<eval_with_key>.2":285:20)
#loc230 = loc("<eval_with_key>.2":286:14)
#loc231 = loc("<eval_with_key>.2":287:17)
#loc232 = loc("<eval_with_key>.2":288:14)
#loc233 = loc("<eval_with_key>.2":289:17)
#loc234 = loc("<eval_with_key>.2":290:14)
#loc235 = loc("<eval_with_key>.2":291:17)
#loc236 = loc("<eval_with_key>.2":292:18)
#loc237 = loc("<eval_with_key>.2":293:16)
#loc238 = loc("<eval_with_key>.2":294:14)
#loc239 = loc("<eval_with_key>.2":295:16)
#loc240 = loc("<eval_with_key>.2":296:14)
#loc241 = loc("<eval_with_key>.2":297:12)
#loc242 = loc("<eval_with_key>.2":298:21)
#loc243 = loc("<eval_with_key>.2":300:24)
#loc244 = loc("<eval_with_key>.2":301:12)
#loc245 = loc("<eval_with_key>.2":303:20)
#loc246 = loc("<eval_with_key>.2":304:20)
#loc247 = loc("<eval_with_key>.2":305:20)
#loc248 = loc("<eval_with_key>.2":306:20)
#loc249 = loc("<eval_with_key>.2":307:29)
#loc250 = loc("<eval_with_key>.2":309:24)
#loc251 = loc("<eval_with_key>.2":310:14)
#loc252 = loc("<eval_with_key>.2":311:13)
#loc253 = loc("<eval_with_key>.2":312:12)
#loc254 = loc("<eval_with_key>.2":313:12)
#loc255 = loc("<eval_with_key>.2":314:12)
#loc256 = loc("<eval_with_key>.2":315:12)
#loc257 = loc("<eval_with_key>.2":317:16)
#loc258 = loc("<eval_with_key>.2":318:14)
#loc259 = loc("<eval_with_key>.2":319:16)
#loc260 = loc("<eval_with_key>.2":320:14)
#loc261 = loc("<eval_with_key>.2":321:12)
#loc262 = loc("<eval_with_key>.2":322:21)
#loc263 = loc("<eval_with_key>.2":323:17)
#loc264 = loc("<eval_with_key>.2":324:14)
#loc265 = loc("<eval_with_key>.2":325:14)
#loc266 = loc("<eval_with_key>.2":326:14)
#loc267 = loc("<eval_with_key>.2":329:15)
#loc268 = loc("<eval_with_key>.2":330:14)
#loc269 = loc("<eval_with_key>.2":331:13)
#loc270 = loc("<eval_with_key>.2":334:26)
#loc271 = loc("<eval_with_key>.2":338:14)
#loc272 = loc("<eval_with_key>.2":341:15)
#loc273 = loc("<eval_with_key>.2":342:14)
#loc274 = loc("<eval_with_key>.2":343:13)
#loc275 = loc("<eval_with_key>.2":344:12)
#loc276 = loc("<eval_with_key>.2":345:13)
#loc277 = loc("<eval_with_key>.2":346:13)
#loc278 = loc("<eval_with_key>.2":347:13)
#loc279 = loc("<eval_with_key>.2":348:13)
#loc280 = loc("<eval_with_key>.2":350:13)
#loc281 = loc("<eval_with_key>.2":351:13)
#loc282 = loc("<eval_with_key>.2":352:14)
#loc283 = loc("<eval_with_key>.2":355:15)
#loc284 = loc("<eval_with_key>.2":356:14)
#loc285 = loc("<eval_with_key>.2":357:13)
#loc286 = loc("<eval_with_key>.2":360:26)
#loc287 = loc("<eval_with_key>.2":364:14)
#loc288 = loc("<eval_with_key>.2":367:15)
#loc289 = loc("<eval_with_key>.2":368:14)
#loc290 = loc("<eval_with_key>.2":369:20)
#loc291 = loc("<eval_with_key>.2":370:20)
#loc292 = loc("<eval_with_key>.2":371:20)
#loc293 = loc("<eval_with_key>.2":372:14)
#loc294 = loc("<eval_with_key>.2":373:17)
#loc295 = loc("<eval_with_key>.2":374:14)
#loc296 = loc("<eval_with_key>.2":375:17)
#loc297 = loc("<eval_with_key>.2":376:14)
#loc298 = loc("<eval_with_key>.2":377:17)
#loc299 = loc("<eval_with_key>.2":378:18)
#loc300 = loc("<eval_with_key>.2":379:16)
#loc301 = loc("<eval_with_key>.2":380:14)
#loc302 = loc("<eval_with_key>.2":381:16)
#loc303 = loc("<eval_with_key>.2":382:14)
#loc304 = loc("<eval_with_key>.2":383:12)
#loc305 = loc("<eval_with_key>.2":384:21)
#loc306 = loc("<eval_with_key>.2":386:24)
#loc307 = loc("<eval_with_key>.2":387:12)
#loc308 = loc("<eval_with_key>.2":389:20)
#loc309 = loc("<eval_with_key>.2":390:20)
#loc310 = loc("<eval_with_key>.2":391:20)
#loc311 = loc("<eval_with_key>.2":392:20)
#loc312 = loc("<eval_with_key>.2":393:29)
#loc313 = loc("<eval_with_key>.2":395:24)
#loc314 = loc("<eval_with_key>.2":396:14)
#loc315 = loc("<eval_with_key>.2":397:13)
#loc316 = loc("<eval_with_key>.2":398:12)
#loc317 = loc("<eval_with_key>.2":399:12)
#loc318 = loc("<eval_with_key>.2":400:12)
#loc319 = loc("<eval_with_key>.2":401:12)
#loc320 = loc("<eval_with_key>.2":403:16)
#loc321 = loc("<eval_with_key>.2":404:14)
#loc322 = loc("<eval_with_key>.2":405:16)
#loc323 = loc("<eval_with_key>.2":406:14)
#loc324 = loc("<eval_with_key>.2":407:12)
#loc325 = loc("<eval_with_key>.2":408:21)
#loc326 = loc("<eval_with_key>.2":409:17)
#loc327 = loc("<eval_with_key>.2":410:14)
#loc328 = loc("<eval_with_key>.2":411:14)
#loc329 = loc("<eval_with_key>.2":412:14)
#loc330 = loc("<eval_with_key>.2":415:15)
#loc331 = loc("<eval_with_key>.2":416:14)
#loc332 = loc("<eval_with_key>.2":417:13)
#loc333 = loc("<eval_with_key>.2":420:26)
#loc334 = loc("<eval_with_key>.2":424:14)
#loc335 = loc("<eval_with_key>.2":427:15)
#loc336 = loc("<eval_with_key>.2":428:14)
#loc337 = loc("<eval_with_key>.2":429:13)
#loc338 = loc("<eval_with_key>.2":430:12)
#loc339 = loc("<eval_with_key>.2":431:13)
#loc340 = loc("<eval_with_key>.2":432:13)
#loc341 = loc("<eval_with_key>.2":433:13)
#loc342 = loc("<eval_with_key>.2":434:13)
#loc343 = loc("<eval_with_key>.2":436:13)
#loc344 = loc("<eval_with_key>.2":437:13)
#loc345 = loc("<eval_with_key>.2":438:14)
#loc346 = loc("<eval_with_key>.2":441:15)
#loc347 = loc("<eval_with_key>.2":442:14)
#loc348 = loc("<eval_with_key>.2":443:13)
#loc349 = loc("<eval_with_key>.2":446:27)
#loc350 = loc("<eval_with_key>.2":450:14)
#loc351 = loc("<eval_with_key>.2":453:15)
#loc352 = loc("<eval_with_key>.2":454:14)
#loc353 = loc("<eval_with_key>.2":455:20)
#loc354 = loc("<eval_with_key>.2":456:20)
#loc355 = loc("<eval_with_key>.2":457:20)
#loc356 = loc("<eval_with_key>.2":458:14)
#loc357 = loc("<eval_with_key>.2":459:17)
#loc358 = loc("<eval_with_key>.2":460:14)
#loc359 = loc("<eval_with_key>.2":461:17)
#loc360 = loc("<eval_with_key>.2":462:14)
#loc361 = loc("<eval_with_key>.2":463:17)
#loc362 = loc("<eval_with_key>.2":464:18)
#loc363 = loc("<eval_with_key>.2":465:16)
#loc364 = loc("<eval_with_key>.2":466:14)
#loc365 = loc("<eval_with_key>.2":467:16)
#loc366 = loc("<eval_with_key>.2":468:14)
#loc367 = loc("<eval_with_key>.2":469:13)
#loc368 = loc("<eval_with_key>.2":470:22)
#loc369 = loc("<eval_with_key>.2":472:25)
#loc370 = loc("<eval_with_key>.2":473:13)
#loc371 = loc("<eval_with_key>.2":475:20)
#loc372 = loc("<eval_with_key>.2":476:20)
#loc373 = loc("<eval_with_key>.2":477:20)
#loc374 = loc("<eval_with_key>.2":478:20)
#loc375 = loc("<eval_with_key>.2":479:29)
#loc376 = loc("<eval_with_key>.2":481:25)
#loc377 = loc("<eval_with_key>.2":482:14)
#loc378 = loc("<eval_with_key>.2":483:13)
#loc379 = loc("<eval_with_key>.2":484:12)
#loc380 = loc("<eval_with_key>.2":485:12)
#loc381 = loc("<eval_with_key>.2":486:12)
#loc382 = loc("<eval_with_key>.2":487:13)
#loc383 = loc("<eval_with_key>.2":489:16)
#loc384 = loc("<eval_with_key>.2":490:14)
#loc385 = loc("<eval_with_key>.2":491:16)
#loc386 = loc("<eval_with_key>.2":492:14)
#loc387 = loc("<eval_with_key>.2":493:13)
#loc388 = loc("<eval_with_key>.2":494:22)
#loc389 = loc("<eval_with_key>.2":495:17)
#loc390 = loc("<eval_with_key>.2":496:14)
#loc391 = loc("<eval_with_key>.2":497:14)
#loc392 = loc("<eval_with_key>.2":498:14)
#loc393 = loc("<eval_with_key>.2":501:15)
#loc394 = loc("<eval_with_key>.2":502:14)
#loc395 = loc("<eval_with_key>.2":503:13)
#loc396 = loc("<eval_with_key>.2":506:27)
#loc397 = loc("<eval_with_key>.2":510:14)
#loc398 = loc("<eval_with_key>.2":513:15)
#loc399 = loc("<eval_with_key>.2":514:14)
#loc400 = loc("<eval_with_key>.2":515:13)
#loc401 = loc("<eval_with_key>.2":516:12)
#loc402 = loc("<eval_with_key>.2":517:13)
#loc403 = loc("<eval_with_key>.2":518:13)
#loc404 = loc("<eval_with_key>.2":519:13)
#loc405 = loc("<eval_with_key>.2":520:13)
#loc406 = loc("<eval_with_key>.2":522:13)
#loc407 = loc("<eval_with_key>.2":523:13)
#loc408 = loc("<eval_with_key>.2":524:14)
#loc409 = loc("<eval_with_key>.2":527:15)
#loc410 = loc("<eval_with_key>.2":528:14)
#loc411 = loc("<eval_with_key>.2":529:13)
#loc412 = loc("<eval_with_key>.2":532:27)
#loc413 = loc("<eval_with_key>.2":536:14)
#loc414 = loc("<eval_with_key>.2":538:8)
#loc415 = loc("<eval_with_key>.2":539:14)
#loc416 = loc("<eval_with_key>.2":540:9)
#loc417 = loc("<eval_with_key>.2":541:22)
#loc418 = loc("<eval_with_key>.2":542:15)
#loc419 = loc("<eval_with_key>.2":543:13)
#loc420 = loc("<eval_with_key>.2":544:12)
AmosLewis commented Dec 22, 2022

With [MLIR][TORCH] Add e2e support for aten.as_stride #1742

and after manually disabling the amax and select.int decomposition patterns in my local lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp:

-    addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
+//    addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenAmaxOp>(patterns);
+//    addPatternIfTargetOpIsIllegal<DecomposeAtenAmaxOp>(patterns);

I got aten.as_strided lowered to torch-mlir successfully:
%119 = torch.operator "aten.as_strided"(%116, %117, %118, %int0) : (!torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.tensor loc(#loc38)
--->
%119 = torch.aten.as_strided %116, %117, %118, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc38)
This distillgpt2_torch_delete_decompose_amax_selectint.mlir is the resulting file.
But:

➜  torch-mlir git:(as_stride) ✗ python distillGPT2/distillgpt2.py
Some weights of the model checkpoint at distilgpt2 were not used when initializing GPT2ForSequenceClassification: ['lm_head.weight']
- This IS expected if you are initializing GPT2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
model(test_input): 
tensor([[-1.4607,  1.7842]], grad_fn=<IndexBackward0>)
/home/chi/src/ubuntu20/shark/SHARK/shark.venv/lib/python3.10/site-packages/torch/jit/_check.py:181: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn("The TorchScript type system doesn't support "
Traceback (most recent call last):
  File "/home/chi/src/ubuntu20/shark/torch-mlir/distillGPT2/distillgpt2.py", line 92, in <module>
    module = torch_mlir.compile(
  File "/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 371, in compile
    run_pipeline_with_repro_report(
  File "/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 73, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
error: found an op that was marked as backend illegal
note: see current operation: %164 = "torch.aten.amax"(%162, %163, %88) : (!torch.vtensor<[1,12,128,128],f32>, !torch.list<int>, !torch.bool) -> !torch.vtensor<[1,12,128,1],f32>
note: this is likely due to DecomposeComplexOps being unable to decompose this op


For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints,torch.aten.native_layer_norm,torch.aten.linear})' /tmp/_lambda.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
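
A side note (my own sketch, not something this gist actually ran): since the failure is just DecomposeComplexOps refusing to decompose aten.amax, one possible workaround, assuming the target backend can handle amax directly, is to mark the op backend-legal from the Python driver; this populates the same backend-legal-ops pipeline option shown in the command above (check that your torch-mlir version exposes the backend_legal_ops parameter):

import torch_mlir

# Hypothetical workaround: ask the pipeline to treat torch.aten.amax as
# backend-legal so DecomposeComplexOps leaves it intact, instead of
# commenting out the decomposition pattern in C++.
module = torch_mlir.compile(
    ts_g,
    (test_input),
    torch_mlir.OutputType.TORCH,
    use_tracing=True,
    backend_legal_ops=["torch.aten.amax"],
)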

In this MLIR file (distillgpt2_torch_delete_decompose_amax_selectint.mlir), the key piece is:

    %165 = torch.aten.as_strided %162, %163, %164, %int0 : !torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.tensor loc(#loc59)
    %166 = torch.prims.convert_element_type %165, %int11 : !torch.tensor, !torch.int -> !torch.tensor loc(#loc60)
    %167 = torch.prim.GetAttr %arg0["_tensor_constant2"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc)
    %168 = torch.aten.lift_fresh_copy %167 : !torch.tensor -> !torch.tensor loc(#loc61)
    %169 = torch.aten.where.self %166, %152, %168 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tensor loc(#loc62)
    %170 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> loc(#loc)
    %171 = torch.aten.amax %169, %170, %true_0 : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc63)

If I re-enable the decompositions, I hit the previous getDtype() bug: computeReductionType calls getDtype() on a tensor type that does not yet carry a dtype, so the assertion fires:

python: /home/chi/src/ubuntu20/shark/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h:77: mlir::Type mlir::torch::Torch::BaseTensorType::getDtype() const: Assertion `hasDtype() && "must have a dtype"' failed.

#0  __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:50
#1  0x00007ffff7ddd859 in __GI_abort () at abort.c:79
#2  0x00007ffff7ddd729 in __assert_fail_base (fmt=0x7ffff7f73588 "%s%s%s:%u: %s%sAssertion `%s' failed.\n%n", 
    assertion=0x7fffb49431d3 "hasDtype() && \"must have a dtype\"", 
    file=0x7fffb49430f1 "/home/chi/src/ubuntu20/shark/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h", line=77, 
    function=<optimized out>) at assert.c:92
#3  0x00007ffff7deefd6 in __GI___assert_fail (assertion=0x7fffb49431d3 "hasDtype() && \"must have a dtype\"", 
    file=0x7fffb49430f1 "/home/chi/src/ubuntu20/shark/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h", line=77, 
    function=0x7fffb49431f5 "mlir::Type mlir::torch::Torch::BaseTensorType::getDtype() const") at assert.c:101
#4  0x00007fffafaf2c5c in mlir::torch::Torch::BaseTensorType::getDtype (this=0x7fffffff9cd0)
    at /home/chi/src/ubuntu20/shark/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h:77
#5  0x00007fffafd8b7ec in computeReductionType (rewriter=..., op=0x7b17ee0, tensorType=..., dim=..., keepDim=true)
    at /home/chi/src/ubuntu20/shark/torch-mlir/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp:74
#6  0x00007fffafd8b0ef in createMaxAlongDimension (rewriter=..., loc=..., op=0x7b17ee0, input=..., dim=..., keepDim=true)
    at /home/chi/src/ubuntu20/shark/torch-mlir/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp:100
#7  0x00007fffafde37f0 in (anonymous namespace)::DecomposeAtenAmaxOp::matchAndRewrite (this=0x60167d0, op=..., rewriter=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp:214

@AmosLewis

(gdb) bt
#0  __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:50
#1  0x00007ffff7ddd859 in __GI_abort () at abort.c:79
#2  0x00007ffff7ddd729 in __assert_fail_base (fmt=0x7ffff7f73588 "%s%s%s:%u: %s%sAssertion `%s' failed.\n%n", 
    assertion=0x7fffb49431d3 "hasDtype() && \"must have a dtype\"", 
    file=0x7fffb49430f1 "/home/chi/src/ubuntu20/shark/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h", line=77, 
    function=<optimized out>) at assert.c:92
#3  0x00007ffff7deefd6 in __GI___assert_fail (assertion=0x7fffb49431d3 "hasDtype() && \"must have a dtype\"", 
    file=0x7fffb49430f1 "/home/chi/src/ubuntu20/shark/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h", line=77, 
    function=0x7fffb49431f5 "mlir::Type mlir::torch::Torch::BaseTensorType::getDtype() const") at assert.c:101
#4  0x00007fffafaf2c5c in mlir::torch::Torch::BaseTensorType::getDtype (this=0x7fffffff9cd0)
    at /home/chi/src/ubuntu20/shark/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h:77
#5  0x00007fffafd8b7ec in computeReductionType (rewriter=..., op=0x7b17ee0, tensorType=..., dim=..., keepDim=true)
    at /home/chi/src/ubuntu20/shark/torch-mlir/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp:74
#6  0x00007fffafd8b0ef in createMaxAlongDimension (rewriter=..., loc=..., op=0x7b17ee0, input=..., dim=..., keepDim=true)
    at /home/chi/src/ubuntu20/shark/torch-mlir/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp:100
#7  0x00007fffafde37f0 in (anonymous namespace)::DecomposeAtenAmaxOp::matchAndRewrite (this=0x60167d0, op=..., rewriter=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp:214
#8  0x00007fffafe2b8fe in mlir::detail::OpOrInterfaceRewritePatternBase<mlir::torch::Torch::AtenAmaxOp>::matchAndRewrite (
    this=0x60167d0, op=0x7b17ee0, rewriter=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/PatternMatch.h:329
#9  0x00007fffb336969c in mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) (this=0x7fffffffa820, op=0x7b17ee0, rewriter=..., canApply=..., onFailure=..., onSuccess=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp:200
#10 0x00007fffb3346ea1 in (anonymous namespace)::GreedyPatternRewriteDriver::simplify (this=0x7fffffffa7f8, regions=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:274
#11 0x00007fffb3346448 in mlir::applyPatternsAndFoldGreedily (regions=..., patterns=..., config=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:417
#12 0x00007fffac17e582 in mlir::applyPatternsAndFoldGreedily (op=0x6c7e1a0, patterns=..., config=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h:71
#13 0x00007fffafd8105d in (anonymous namespace)::DecomposeComplexOpsPass::runOnOperation (this=0xa10a690)
    at /home/chi/src/ubuntu20/shark/torch-mlir/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp:3513
#14 0x00007fffac1e5740 in mlir::detail::OpToOpPassAdaptor::run (pass=0xa10a690, op=0x6c7e1a0, am=..., verifyPasses=true, 
    parentInitGeneration=1) at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:465
#15 0x00007fffac1e5d4a in mlir::detail::OpToOpPassAdaptor::runPipeline (pm=..., op=0x6c7e1a0, am=..., verifyPasses=true, 
    parentInitGeneration=1, instrumentor=0x0, parentInfo=0x7fffffffb190)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:529
#16 0x00007fffac1eb556 in mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_14::operator()(mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo&) const (this=0x7fffffffb128, opInfo=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:749
#17 0x00007fffac1eb19c in mlir::failableParallelForEach<__gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_14&>(mlir::MLIRContext*, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_14&) (context=0x6c51420, begin={passManagerIdx = 0, op = 0x6c7e1a0, am = {impl = 0x601f370}}, 
    end={passManagerIdx = 257, op = 0x5b94dc0, am = {impl = 0x991010}}, func=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/include/mlir/IR/Threading.h:46
#18 0x00007fffac1e7096 in mlir::failableParallelForEach<std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> >&, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_14&>(mlir::MLIRContext*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> >&, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_14&) (context=0x6c51420, 
    range=std::vector of length 1, capacity 1 = {...}, func=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/include/mlir/IR/Threading.h:92
#19 0x00007fffac1e6916 in mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl (this=0x6064990, verifyPasses=true)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:759
#20 0x00007fffac1e59fd in mlir::detail::OpToOpPassAdaptor::runOnOperation (this=0x6064990, verifyPasses=true)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:650
#21 0x00007fffac1e572e in mlir::detail::OpToOpPassAdaptor::run (pass=0x6064990, op=0x5454010, am=..., verifyPasses=true, 
    parentInitGeneration=1) at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:463
#22 0x00007fffac1e5d4a in mlir::detail::OpToOpPassAdaptor::runPipeline (pm=..., op=0x5454010, am=..., verifyPasses=true, 
    parentInitGeneration=1, instrumentor=0x0, parentInfo=0x7fffffffbe70)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:529
#23 0x00007fffac1ea283 in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_6::operator()(mlir::OpPassManager&, mlir::Operation*) const (this=0x7fffffffbe40, pipeline=..., root=0x5454010)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:451
#24 0x00007fffac1e9fd5 in llvm::function_ref<mlir::LogicalResult (mlir::OpPassManager&, mlir::Operation*)>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_6>(long, mlir::OpPassManager&, mlir::Operation*) (callable=140737488338496, params=0x5454010, params=0x5454010)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45
#25 0x00007fffadaab2a1 in llvm::function_ref<mlir::LogicalResult (mlir::OpPassManager&, mlir::Operation*)>::operator()(mlir::OpPassManager&, mlir::Operation*) const (this=0x5f71b50, params=0x5454010, params=0x5454010)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68
#26 0x00007fffadaab0dc in mlir::Pass::runPipeline (this=0x5f71ae0, pipeline=..., op=0x5454010)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/include/mlir/Pass/Pass.h:195
#27 0x00007fffafe6511d in (anonymous namespace)::LowerToBackendContractPass::runOnOperation (this=0x5f71ae0)
    at /home/chi/src/ubuntu20/shark/torch-mlir/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp:277
#28 0x00007fffac1e5740 in mlir::detail::OpToOpPassAdaptor::run (pass=0x5f71ae0, op=0x5454010, am=..., verifyPasses=true, 
    parentInitGeneration=1) at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:465
#29 0x00007fffac1e5d4a in mlir::detail::OpToOpPassAdaptor::runPipeline (pm=..., op=0x5454010, am=..., verifyPasses=true, 
    parentInitGeneration=1, instrumentor=0x0, parentInfo=0x0)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:529
#30 0x00007fffac1e77c3 in mlir::PassManager::runPasses (this=0x5466ce0, op=0x5454010, am=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:834
#31 0x00007fffac1e76da in mlir::PassManager::run (this=0x5466ce0, op=0x5454010)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:814
#32 0x00007fffac077a62 in mlirPassManagerRun (passManager=..., module=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/CAPI/IR/Pass.cpp:43
#33 0x00007fffb8a7b24e in mlir::python::populatePassManagerSubmodule(pybind11::module_&)::$_5::operator()((anonymous namespace)::PyPassManager&, mlir::python::PyModule&) const (this=0x502e1e8, passManager=..., module=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/externals/llvm-project/mlir/lib/Bindings/Python/Pass.cpp:121
#34 0x00007fffb8a7b1ef in pybind11::detail::argument_loader<(anonymous namespace)::PyPassManager&, mlir::python::PyModule&>::call_impl<void, mlir::python::populatePassManagerSubmodule(pybind11::module_&)::$_5&, 0ul, 1ul, pybind11::detail::void_type>(mlir::python::populatePassManagerSubmodule(pybind11::module_&)::$_5&, std::integer_sequence<unsigned long, 0ul, 1ul>, pybind11::detail::void_type&&) && (this=0x7fffffffc678, f=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/mlir_venv/lib/python3.10/site-packages/pybind11/include/pybind11/cast.h:1441
#35 0x00007fffb8a7af56 in pybind11::detail::argument_loader<(anonymous namespace)::PyPassManager&, mlir::python::PyModule&>::call<void, pybind11::detail::void_type, mlir::python::populatePassManagerSubmodule(pybind11::module_&)::$_5&> (this=0x7fffffffc678, 
    f=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/mlir_venv/lib/python3.10/site-packages/pybind11/include/pybind11/cast.h:1415
#36 0x00007fffb8a7ae6e in pybind11::cpp_function::initialize<mlir::python::populatePassManagerSubmodule(pybind11::module_&)::$_5,
 void, (anonymous namespace)::PyPassManager&, mlir::python::PyModule&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg, char [78]>(mlir::python::populatePassManagerSubmodule(pybind11::module_&)::$_5&&, void (*)((anonymous namespace)::PyPassManager&, mlir::python::PyModule&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg const&, char const (&) [78])::{lambda(pybind11::detail::function_call&)#1}::operator()(pybind11::detail::function_call&) const
    (this=0x7fffffffce88, call=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/mlir_venv/lib/python3.10/site-packages/pybind11/include/pybind11/pybind11.h:249
#37 0x00007fffb8a7adc5 in pybind11::cpp_function::initialize<mlir::python::populatePassManagerSubmodule(pybind11::module_&)::$_5, void, (anonymous namespace)::PyPassManager&, mlir::python::PyModule&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg, char [78]>(mlir::python::populatePassManagerSubmodule(pybind11::module_&)::$_5&&, void (*)((anonymous namespace)::PyPassManager&, mlir::python::PyModule&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg const&, char const (&) [78])::{lambda(pybind11::detail::function_call&)#1}::__invoke(pybind11::detail::function_call&) (
    call=...)
    at /home/chi/src/ubuntu20/shark/torch-mlir/mlir_venv/lib/python3.10/site-packages/pybind11/include/pybind11/pybind11.h:224
#38 0x00007fffb8807e45 in pybind11::cpp_function::dispatcher (self=0x7fffb8c54840, args_in=0x7fffb8e81b00, kwargs_in=0x0)
    at /home/chi/src/ubuntu20/shark/torch-mlir/mlir_venv/lib/python3.10/site-packages/pybind11/include/pybind11/pybind11.h:929

AmosLewis commented Dec 22, 2022

Since it was crashing, I first got the raw (TorchScript) IR by:

import torch_mlir
from contextlib import redirect_stdout

# Emit the raw TorchScript-level IR (OutputType.RAW skips the lowering
# passes, so the dump succeeds even though the full pipeline crashes).
module = torch_mlir.compile(
    ts_g,
    (test_input),
    torch_mlir.OutputType.RAW,
    use_tracing=True,
    verbose=False,
)

# Write the module's textual IR to a file.
with open('distilgpt2_raw_ir.mlir', 'w') as f:
    with redirect_stdout(f):
        print(module.operation.get_asm())

Then I got the elided IR by running this:
torch-mlir-opt --mlir-elide-elementsattrs-if-larger=4 distilgpt2_raw_ir.mlir > distilgpt2_raw_ir_elided.mlir
After this I ran the following command to get the IR after all the passes:
torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints})' distilgpt2_raw_ir_elided.mlir --mlir-print-ir-after-all > distilgpt2_debug.mlir
Since the issue was type-related, I searched for the IR dump after the RefineTypes pass. In that dump, the very first op with an unk dtype was AtenTanhOp; I fixed that and it worked.
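
To locate that op mechanically, a small script along these lines works (my own sketch, not from the original gist; it simply scans the dump for the first !torch.vtensor type whose dtype still prints as unk):

import re

# Scan the post-RefineTypes IR dump for the first tensor type whose
# dtype is still unknown ("unk" in torch-mlir's type printing).
with open('distilgpt2_debug.mlir') as f:
    for lineno, line in enumerate(f, 1):
        if re.search(r'!torch\.vtensor<\[[^\]]*\],\s*unk>', line):
            print(lineno, line.rstrip())
            break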

In general, this approach lets you debug such issues faster and more accurately.
