Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created December 16, 2022 21:37
Show Gist options
  • Save AmosLewis/e419fe0583482c10715b6e1dd974f557 to your computer and use it in GitHub Desktop.
➜ deberta git:(main) ✗ torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' /tmp/_lambda.mlir --mlir-print-ir-after-failure -mlir-disable-threading
<eval_with_key>.2:8:54: error: failed to legalize operation 'torch.constant.int'
<eval_with_key>.2:8:54: note: see current operation: %1 = "torch.constant.int"() {value = 0 : i64} : () -> !torch.int
// -----// IR Dump After FinalizingBackendTypeConversion Failed (torch-finalizing-backend-type-conversion) //----- //
func.func @forward(%arg0: tensor<1x128xi64>) -> tensor<1x2xf32> {
%0 = "tosa.const"() {value = dense<[[65536, 512, 1]]> : tensor<1x3xi32>} : () -> tensor<1x3xi32>
%int0 = torch.constant.int 0
%1 = "tosa.const"() {value = dense_resource<__elided__> : tensor<2x768xf32>} : () -> tensor<2x768xf32>
%2 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768x768xf32>} : () -> tensor<768x768xf32>
%3 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768xf32>} : () -> tensor<768xf32>
%4 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768x3072xf32>} : () -> tensor<768x3072xf32>
%5 = "tosa.const"() {value = dense_resource<__elided__> : tensor<3072xf32>} : () -> tensor<3072xf32>
%6 = "tosa.const"() {value = dense_resource<__elided__> : tensor<3072x768xf32>} : () -> tensor<3072x768xf32>
%7 = "tosa.const"() {value = dense<-3.40282347E+38> : tensor<f32>} : () -> tensor<f32>
%8 = "tosa.const"() {value = dense_resource<__elided__> : tensor<512x768xf32>} : () -> tensor<512x768xf32>
%9 = "tosa.const"() {value = dense_resource<__elided__> : tensor<128x128xsi64>} : () -> tensor<128x128xi64>
%10 = "tosa.const"() {value = dense_resource<__elided__> : tensor<128100x768xf32>} : () -> tensor<128100x768xf32>
%11 = "tosa.const"() {value = dense<7.680000e+02> : tensor<1xf32>} : () -> tensor<1xf32>
%12 = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
%13 = "tosa.const"() {value = dense<[0, 2, 1, 3]> : tensor<4xi64>} : () -> tensor<4xi64>
%14 = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%15 = "tosa.const"() {value = dense<13.8564062> : tensor<f32>} : () -> tensor<f32>
%16 = "tosa.const"() {value = dense<0> : tensor<1x128x128x1xi32>} : () -> tensor<1x128x128x1xi32>
%17 = "tosa.const"() {value = dense<"tensor<1x128x128x1xi32>} : () -> tensor<1x128x128x1xi32>
%18 = "tosa.const"() {value = dense<0> : tensor<1x1x128x128xi8>} : () -> tensor<1x1x128x128xi8>
%19 = "tosa.const"() {value = dense<5.000000e-01> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%20 = "tosa.const"() {value = dense<7.810800e-02> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%21 = "tosa.const"() {value = dense<9.720000e-04> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%22 = "tosa.const"() {value = dense<2.303890e-01> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%23 = "tosa.const"() {value = dense<2.783930e-01> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%24 = "tosa.const"() {value = dense<0.707106769> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%25 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%26 = "tosa.const"() {value = dense<5.000000e-01> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%27 = "tosa.const"() {value = dense<7.810800e-02> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%28 = "tosa.const"() {value = dense<9.720000e-04> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%29 = "tosa.const"() {value = dense<2.303890e-01> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%30 = "tosa.const"() {value = dense<2.783930e-01> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%31 = "tosa.const"() {value = dense<0.707106769> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%32 = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%33 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%34 = "tosa.const"() {value = dense<256> : tensor<1x1x1x1xi32>} : () -> tensor<1x1x1x1xi32>
%35 = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%36 = "tosa.const"() {value = dense<1.000000e-07> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%37 = "tosa.const"() {value = dense<1.000000e-07> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%38 = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1x1x128xf32>} : () -> tensor<1x1x1x128xf32>
%39 = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1x128x1xf32>} : () -> tensor<1x1x128x1xf32>
%40 = torch_c.from_builtin_tensor %arg0 : tensor<1x128xi64> -> !torch.vtensor<[1,128],si64>
%41 = torch_c.to_builtin_tensor %40 : !torch.vtensor<[1,128],si64> -> tensor<1x128xi64>
%42 = "tosa.reshape"(%10) {new_shape = [1, 128100, 768]} : (tensor<128100x768xf32>) -> tensor<1x128100x768xf32>
%43 = "tosa.cast"(%41) : (tensor<1x128xi64>) -> tensor<1x128xi32>
%44 = "tosa.gather"(%42, %43) : (tensor<1x128100x768xf32>, tensor<1x128xi32>) -> tensor<1x128x768xf32>
%45 = "tosa.reciprocal"(%11) : (tensor<1xf32>) -> tensor<1xf32>
%46 = "tosa.reduce_sum"(%44) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%47 = "tosa.reshape"(%45) {new_shape = [1, 1, 1]} : (tensor<1xf32>) -> tensor<1x1x1xf32>
%48 = "tosa.mul"(%46, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%49 = "tosa.sub"(%44, %48) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%50 = "tosa.mul"(%49, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%51 = "tosa.reduce_sum"(%50) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%52 = "tosa.mul"(%51, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%53 = "tosa.reshape"(%3) {new_shape = [1, 1, 768]} : (tensor<768xf32>) -> tensor<1x1x768xf32>
%54 = "tosa.add"(%52, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%55 = "tosa.rsqrt"(%54) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%56 = "tosa.mul"(%49, %55) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%57 = "tosa.mul"(%56, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%58 = "tosa.add"(%57, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%59 = "tosa.mul"(%38, %39) {shift = 0 : i32} : (tensor<1x1x1x128xf32>, tensor<1x1x128x1xf32>) -> tensor<1x1x128x128xf32>
%60 = "tosa.cast"(%59) : (tensor<1x1x128x128xf32>) -> tensor<1x1x128x128xi8>
%61 = "tosa.reduce_sum"(%8) {axis = 1 : i64} : (tensor<512x768xf32>) -> tensor<512x1xf32>
%62 = "tosa.reshape"(%45) {new_shape = [1, 1]} : (tensor<1xf32>) -> tensor<1x1xf32>
%63 = "tosa.mul"(%61, %62) {shift = 0 : i32} : (tensor<512x1xf32>, tensor<1x1xf32>) -> tensor<512x1xf32>
%64 = "tosa.sub"(%8, %63) : (tensor<512x768xf32>, tensor<512x1xf32>) -> tensor<512x768xf32>
%65 = "tosa.mul"(%64, %64) {shift = 0 : i32} : (tensor<512x768xf32>, tensor<512x768xf32>) -> tensor<512x768xf32>
%66 = "tosa.reduce_sum"(%65) {axis = 1 : i64} : (tensor<512x768xf32>) -> tensor<512x1xf32>
%67 = "tosa.mul"(%66, %62) {shift = 0 : i32} : (tensor<512x1xf32>, tensor<1x1xf32>) -> tensor<512x1xf32>
%68 = "tosa.reshape"(%3) {new_shape = [1, 768]} : (tensor<768xf32>) -> tensor<1x768xf32>
%69 = "tosa.add"(%67, %36) : (tensor<512x1xf32>, tensor<1x1xf32>) -> tensor<512x1xf32>
%70 = "tosa.rsqrt"(%69) : (tensor<512x1xf32>) -> tensor<512x1xf32>
%71 = "tosa.mul"(%64, %70) {shift = 0 : i32} : (tensor<512x768xf32>, tensor<512x1xf32>) -> tensor<512x768xf32>
%72 = "tosa.mul"(%71, %68) {shift = 0 : i32} : (tensor<512x768xf32>, tensor<1x768xf32>) -> tensor<512x768xf32>
%73 = "tosa.add"(%72, %68) : (tensor<512x768xf32>, tensor<1x768xf32>) -> tensor<512x768xf32>
%74 = "tosa.transpose"(%2, %12) : (tensor<768x768xf32>, tensor<2xi32>) -> tensor<768x768xf32>
%75 = "tosa.reshape"(%74) {new_shape = [1, 768, 768]} : (tensor<768x768xf32>) -> tensor<1x768x768xf32>
%76 = "tosa.matmul"(%58, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%77 = "tosa.reshape"(%76) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%78 = "tosa.add"(%68, %77) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%79 = "tosa.reshape"(%78) {new_shape = [1, 128, 12, -1]} : (tensor<128x768xf32>) -> tensor<1x128x12x64xf32>
%80 = "tosa.transpose"(%79, %13) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%81 = "tosa.reshape"(%80) {new_shape = [-1, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%82 = "tosa.transpose"(%81, %14) : (tensor<12x128x64xf32>, tensor<3xi32>) -> tensor<12x64x128xf32>
%83 = "tosa.matmul"(%81, %82) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%84 = "tosa.reciprocal"(%15) : (tensor<f32>) -> tensor<f32>
%85 = "tosa.reshape"(%84) {new_shape = [1, 1, 1]} : (tensor<f32>) -> tensor<1x1x1xf32>
%86 = "tosa.mul"(%83, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%87 = "tosa.reshape"(%9) {new_shape = [1, 1, 128, 128]} : (tensor<128x128xi64>) -> tensor<1x1x128x128xi64>
%88 = "tosa.reshape"(%73) {new_shape = [1, 512, 768]} : (tensor<512x768xf32>) -> tensor<1x512x768xf32>
%89 = "tosa.matmul"(%88, %75) : (tensor<1x512x768xf32>, tensor<1x768x768xf32>) -> tensor<1x512x768xf32>
%90 = "tosa.reshape"(%89) {new_shape = [512, 768]} : (tensor<1x512x768xf32>) -> tensor<512x768xf32>
%91 = "tosa.add"(%68, %90) : (tensor<1x768xf32>, tensor<512x768xf32>) -> tensor<512x768xf32>
%92 = "tosa.reshape"(%91) {new_shape = [1, 512, 12, -1]} : (tensor<512x768xf32>) -> tensor<1x512x12x64xf32>
%93 = "tosa.transpose"(%92, %13) : (tensor<1x512x12x64xf32>, tensor<4xi64>) -> tensor<1x12x512x64xf32>
%94 = "tosa.reshape"(%93) {new_shape = [12, 512, 64]} : (tensor<1x12x512x64xf32>) -> tensor<12x512x64xf32>
%95 = "tosa.transpose"(%94, %14) : (tensor<12x512x64xf32>, tensor<3xi32>) -> tensor<12x64x512xf32>
%96 = "tosa.matmul"(%81, %95) : (tensor<12x128x64xf32>, tensor<12x64x512xf32>) -> tensor<12x128x512xf32>
%97 = "tosa.cast"(%87) : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32>
%98 = "tosa.add"(%97, %34) : (tensor<1x1x128x128xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x128x128xi32>
%99 = "tosa.cast"(%98) : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64>
%100 = "tosa.clamp"(%99) {max_fp = 5.110000e+02 : f32, max_int = 511 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64>
%101 = "tosa.reshape"(%100) {new_shape = [1, 128, 128, 1]} : (tensor<1x1x128x128xi64>) -> tensor<1x128x128x1xi64>
%102 = "tosa.cast"(%101) : (tensor<1x128x128x1xi64>) -> tensor<1x128x128x1xi32>
%103 = "tosa.concat"(%16, %17, %102) {axis = 3 : i64} : (tensor<1x128x128x1xi32>, tensor<1x128x128x1xi32>, tensor<1x128x128x1xi32>) -> tensor<1x128x128x3xi32>
%104 = "tosa.reshape"(%96) {new_shape = [1, 786432, 1]} : (tensor<12x128x512xf32>) -> tensor<1x786432x1xf32>
%105 = "tosa.reshape"(%103) {new_shape = [16384, 3]} : (tensor<1x128x128x3xi32>) -> tensor<16384x3xi32>
%106 = "tosa.mul"(%105, %0) {shift = 0 : i32} : (tensor<16384x3xi32>, tensor<1x3xi32>) -> tensor<16384x3xi32>
%107 = "tosa.reduce_sum"(%106) {axis = 1 : i64} : (tensor<16384x3xi32>) -> tensor<16384x1xi32>
%108 = "tosa.reshape"(%107) {new_shape = [1, 16384]} : (tensor<16384x1xi32>) -> tensor<1x16384xi32>
%109 = "tosa.gather"(%104, %108) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%110 = "tosa.reshape"(%109) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%111 = "tosa.mul"(%110, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%112 = "tosa.add"(%111, %33) : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%113 = "tosa.negate"(%87) : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64>
%114 = "tosa.cast"(%113) : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32>
%115 = "tosa.add"(%114, %34) : (tensor<1x1x128x128xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x128x128xi32>
%116 = "tosa.cast"(%115) : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64>
%117 = "tosa.clamp"(%116) {max_fp = 5.110000e+02 : f32, max_int = 511 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64>
%118 = "tosa.reshape"(%117) {new_shape = [1, 128, 128, 1]} : (tensor<1x1x128x128xi64>) -> tensor<1x128x128x1xi64>
%119 = "tosa.cast"(%118) : (tensor<1x128x128x1xi64>) -> tensor<1x128x128x1xi32>
%120 = "tosa.concat"(%16, %17, %119) {axis = 3 : i64} : (tensor<1x128x128x1xi32>, tensor<1x128x128x1xi32>, tensor<1x128x128x1xi32>) -> tensor<1x128x128x3xi32>
%121 = "tosa.reshape"(%120) {new_shape = [16384, 3]} : (tensor<1x128x128x3xi32>) -> tensor<16384x3xi32>
%122 = "tosa.mul"(%121, %0) {shift = 0 : i32} : (tensor<16384x3xi32>, tensor<1x3xi32>) -> tensor<16384x3xi32>
%123 = "tosa.reduce_sum"(%122) {axis = 1 : i64} : (tensor<16384x3xi32>) -> tensor<16384x1xi32>
%124 = "tosa.reshape"(%123) {new_shape = [1, 16384]} : (tensor<16384x1xi32>) -> tensor<1x16384xi32>
%125 = "tosa.gather"(%104, %124) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%126 = "tosa.reshape"(%125) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%127 = "tosa.transpose"(%126, %14) : (tensor<12x128x128xf32>, tensor<3xi32>) -> tensor<12x128x128xf32>
%128 = "tosa.mul"(%127, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%129 = "tosa.add"(%112, %128) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%130 = "tosa.add"(%86, %129) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%131 = "tosa.reshape"(%130) {new_shape = [-1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%132 = torch_c.from_builtin_tensor %131 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%133 = "tosa.equal"(%60, %18) : (tensor<1x1x128x128xi8>, tensor<1x1x128x128xi8>) -> tensor<1x1x128x128xi1>
%134 = "tosa.logical_not"(%133) : (tensor<1x1x128x128xi1>) -> tensor<1x1x128x128xi1>
%135 = "tosa.bitwise_not"(%134) : (tensor<1x1x128x128xi1>) -> tensor<1x1x128x128xi1>
%136 = torch_c.from_builtin_tensor %135 : tensor<1x1x128x128xi1> -> !torch.vtensor<[1,1,128,128],i1>
%137 = torch_c.from_builtin_tensor %7 : tensor<f32> -> !torch.vtensor<[],f32>
%138 = torch.aten.masked_fill.Tensor %132, %136, %137 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%139 = torch_c.to_builtin_tensor %138 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%140 = "tosa.reduce_max"(%139) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%141 = "tosa.sub"(%139, %140) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%142 = "tosa.exp"(%141) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%143 = "tosa.reduce_sum"(%142) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%144 = "tosa.reciprocal"(%143) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%145 = "tosa.mul"(%142, %144) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%146 = torch_c.from_builtin_tensor %145 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%147 = torch.aten.masked_fill.Scalar %146, %136, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%148 = torch_c.to_builtin_tensor %147 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%149 = "tosa.reshape"(%148) {new_shape = [-1, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%150 = "tosa.matmul"(%149, %81) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%151 = "tosa.reshape"(%150) {new_shape = [-1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%152 = "tosa.transpose"(%151, %13) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%153 = "tosa.reshape"(%152) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%154 = "tosa.matmul"(%153, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%155 = "tosa.reshape"(%154) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%156 = "tosa.add"(%68, %155) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%157 = "tosa.reshape"(%156) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%158 = "tosa.add"(%157, %58) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%159 = "tosa.reduce_sum"(%158) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%160 = "tosa.mul"(%159, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%161 = "tosa.sub"(%158, %160) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%162 = "tosa.mul"(%161, %161) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%163 = "tosa.reduce_sum"(%162) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%164 = "tosa.mul"(%163, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%165 = "tosa.add"(%164, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%166 = "tosa.rsqrt"(%165) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%167 = "tosa.mul"(%161, %166) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%168 = "tosa.mul"(%167, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%169 = "tosa.add"(%168, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%170 = "tosa.transpose"(%6, %12) : (tensor<3072x768xf32>, tensor<2xi32>) -> tensor<768x3072xf32>
%171 = "tosa.reshape"(%170) {new_shape = [1, 768, 3072]} : (tensor<768x3072xf32>) -> tensor<1x768x3072xf32>
%172 = "tosa.matmul"(%169, %171) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%173 = "tosa.reshape"(%172) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%174 = "tosa.reshape"(%5) {new_shape = [1, 3072]} : (tensor<3072xf32>) -> tensor<1x3072xf32>
%175 = "tosa.add"(%174, %173) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%176 = "tosa.reshape"(%175) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%177 = "tosa.sub"(%176, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%178 = "tosa.mul"(%177, %31) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%179 = "tosa.abs"(%178) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%180 = "tosa.mul"(%179, %30) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%181 = "tosa.add"(%180, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%182 = "tosa.mul"(%179, %179) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%183 = "tosa.mul"(%182, %29) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%184 = "tosa.add"(%181, %183) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%185 = "tosa.mul"(%182, %179) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%186 = "tosa.mul"(%185, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%187 = "tosa.add"(%184, %186) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%188 = "tosa.mul"(%185, %179) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%189 = "tosa.mul"(%188, %27) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%190 = "tosa.add"(%187, %189) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%191 = "tosa.reciprocal"(%190) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%192 = "tosa.mul"(%191, %191) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%193 = "tosa.mul"(%192, %192) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%194 = "tosa.sub"(%32, %193) : (tensor<1x1x1xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%195 = "tosa.greater_equal"(%178, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xi1>
%196 = "tosa.negate"(%194) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%197 = "tosa.select"(%195, %194, %196) : (tensor<1x128x3072xi1>, tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%198 = "tosa.add"(%197, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%199 = "tosa.mul"(%198, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%200 = "tosa.mul"(%176, %199) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%201 = "tosa.transpose"(%4, %12) : (tensor<768x3072xf32>, tensor<2xi32>) -> tensor<3072x768xf32>
%202 = "tosa.reshape"(%201) {new_shape = [1, 3072, 768]} : (tensor<3072x768xf32>) -> tensor<1x3072x768xf32>
%203 = "tosa.matmul"(%200, %202) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%204 = "tosa.reshape"(%203) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%205 = "tosa.add"(%68, %204) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%206 = "tosa.reshape"(%205) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%207 = "tosa.add"(%206, %169) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%208 = "tosa.reduce_sum"(%207) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%209 = "tosa.mul"(%208, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%210 = "tosa.sub"(%207, %209) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%211 = "tosa.mul"(%210, %210) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%212 = "tosa.reduce_sum"(%211) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%213 = "tosa.mul"(%212, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%214 = "tosa.add"(%213, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%215 = "tosa.rsqrt"(%214) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%216 = "tosa.mul"(%210, %215) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%217 = "tosa.mul"(%216, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%218 = "tosa.add"(%217, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%219 = "tosa.matmul"(%218, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%220 = "tosa.reshape"(%219) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%221 = "tosa.add"(%68, %220) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%222 = "tosa.reshape"(%221) {new_shape = [1, 128, 12, -1]} : (tensor<128x768xf32>) -> tensor<1x128x12x64xf32>
%223 = "tosa.transpose"(%222, %13) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%224 = "tosa.reshape"(%223) {new_shape = [-1, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%225 = "tosa.transpose"(%224, %14) : (tensor<12x128x64xf32>, tensor<3xi32>) -> tensor<12x64x128xf32>
%226 = "tosa.matmul"(%224, %225) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%227 = "tosa.mul"(%226, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%228 = "tosa.matmul"(%224, %95) : (tensor<12x128x64xf32>, tensor<12x64x512xf32>) -> tensor<12x128x512xf32>
%229 = "tosa.reshape"(%228) {new_shape = [1, 786432, 1]} : (tensor<12x128x512xf32>) -> tensor<1x786432x1xf32>
%230 = "tosa.gather"(%229, %108) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%231 = "tosa.reshape"(%230) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%232 = "tosa.mul"(%231, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%233 = "tosa.add"(%232, %33) : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%234 = "tosa.gather"(%229, %124) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%235 = "tosa.reshape"(%234) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%236 = "tosa.transpose"(%235, %14) : (tensor<12x128x128xf32>, tensor<3xi32>) -> tensor<12x128x128xf32>
%237 = "tosa.mul"(%236, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%238 = "tosa.add"(%233, %237) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%239 = "tosa.add"(%227, %238) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%240 = "tosa.reshape"(%239) {new_shape = [-1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%241 = torch_c.from_builtin_tensor %240 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%242 = torch.aten.masked_fill.Tensor %241, %136, %137 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%243 = torch_c.to_builtin_tensor %242 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%244 = "tosa.reduce_max"(%243) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%245 = "tosa.sub"(%243, %244) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%246 = "tosa.exp"(%245) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%247 = "tosa.reduce_sum"(%246) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%248 = "tosa.reciprocal"(%247) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%249 = "tosa.mul"(%246, %248) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%250 = torch_c.from_builtin_tensor %249 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%251 = torch.aten.masked_fill.Scalar %250, %136, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%252 = torch_c.to_builtin_tensor %251 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%253 = "tosa.reshape"(%252) {new_shape = [-1, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%254 = "tosa.matmul"(%253, %224) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%255 = "tosa.reshape"(%254) {new_shape = [-1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%256 = "tosa.transpose"(%255, %13) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%257 = "tosa.reshape"(%256) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%258 = "tosa.matmul"(%257, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%259 = "tosa.reshape"(%258) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%260 = "tosa.add"(%68, %259) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%261 = "tosa.reshape"(%260) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%262 = "tosa.add"(%261, %218) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%263 = "tosa.reduce_sum"(%262) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%264 = "tosa.mul"(%263, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%265 = "tosa.sub"(%262, %264) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%266 = "tosa.mul"(%265, %265) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%267 = "tosa.reduce_sum"(%266) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%268 = "tosa.mul"(%267, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%269 = "tosa.add"(%268, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%270 = "tosa.rsqrt"(%269) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%271 = "tosa.mul"(%265, %270) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%272 = "tosa.mul"(%271, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%273 = "tosa.add"(%272, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%274 = "tosa.matmul"(%273, %171) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%275 = "tosa.reshape"(%274) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%276 = "tosa.add"(%174, %275) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%277 = "tosa.reshape"(%276) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%278 = "tosa.sub"(%277, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%279 = "tosa.mul"(%278, %31) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%280 = "tosa.abs"(%279) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%281 = "tosa.mul"(%280, %30) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%282 = "tosa.add"(%281, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%283 = "tosa.mul"(%280, %280) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%284 = "tosa.mul"(%283, %29) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%285 = "tosa.add"(%282, %284) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%286 = "tosa.mul"(%283, %280) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%287 = "tosa.mul"(%286, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%288 = "tosa.add"(%285, %287) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%289 = "tosa.mul"(%286, %280) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%290 = "tosa.mul"(%289, %27) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%291 = "tosa.add"(%288, %290) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%292 = "tosa.reciprocal"(%291) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%293 = "tosa.mul"(%292, %292) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%294 = "tosa.mul"(%293, %293) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%295 = "tosa.sub"(%32, %294) : (tensor<1x1x1xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%296 = "tosa.greater_equal"(%279, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xi1>
%297 = "tosa.negate"(%295) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%298 = "tosa.select"(%296, %295, %297) : (tensor<1x128x3072xi1>, tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%299 = "tosa.add"(%298, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%300 = "tosa.mul"(%299, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%301 = "tosa.mul"(%277, %300) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%302 = "tosa.matmul"(%301, %202) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%303 = "tosa.reshape"(%302) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%304 = "tosa.add"(%68, %303) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%305 = "tosa.reshape"(%304) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%306 = "tosa.add"(%305, %273) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%307 = "tosa.reduce_sum"(%306) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%308 = "tosa.mul"(%307, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%309 = "tosa.sub"(%306, %308) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%310 = "tosa.mul"(%309, %309) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%311 = "tosa.reduce_sum"(%310) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%312 = "tosa.mul"(%311, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%313 = "tosa.add"(%312, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%314 = "tosa.rsqrt"(%313) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%315 = "tosa.mul"(%309, %314) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%316 = "tosa.mul"(%315, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%317 = "tosa.add"(%316, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%318 = "tosa.matmul"(%317, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%319 = "tosa.reshape"(%318) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%320 = "tosa.add"(%68, %319) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%321 = "tosa.reshape"(%320) {new_shape = [1, 128, 12, -1]} : (tensor<128x768xf32>) -> tensor<1x128x12x64xf32>
%322 = "tosa.transpose"(%321, %13) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%323 = "tosa.reshape"(%322) {new_shape = [-1, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%324 = "tosa.transpose"(%323, %14) : (tensor<12x128x64xf32>, tensor<3xi32>) -> tensor<12x64x128xf32>
%325 = "tosa.matmul"(%323, %324) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%326 = "tosa.mul"(%325, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%327 = "tosa.matmul"(%323, %95) : (tensor<12x128x64xf32>, tensor<12x64x512xf32>) -> tensor<12x128x512xf32>
%328 = "tosa.reshape"(%327) {new_shape = [1, 786432, 1]} : (tensor<12x128x512xf32>) -> tensor<1x786432x1xf32>
%329 = "tosa.gather"(%328, %108) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%330 = "tosa.reshape"(%329) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%331 = "tosa.mul"(%330, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%332 = "tosa.add"(%331, %33) : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%333 = "tosa.gather"(%328, %124) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%334 = "tosa.reshape"(%333) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%335 = "tosa.transpose"(%334, %14) : (tensor<12x128x128xf32>, tensor<3xi32>) -> tensor<12x128x128xf32>
%336 = "tosa.mul"(%335, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%337 = "tosa.add"(%332, %336) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%338 = "tosa.add"(%326, %337) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%339 = "tosa.reshape"(%338) {new_shape = [-1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%340 = torch_c.from_builtin_tensor %339 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%341 = torch.aten.masked_fill.Tensor %340, %136, %137 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%342 = torch_c.to_builtin_tensor %341 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%343 = "tosa.reduce_max"(%342) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%344 = "tosa.sub"(%342, %343) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%345 = "tosa.exp"(%344) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%346 = "tosa.reduce_sum"(%345) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%347 = "tosa.reciprocal"(%346) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%348 = "tosa.mul"(%345, %347) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%349 = torch_c.from_builtin_tensor %348 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%350 = torch.aten.masked_fill.Scalar %349, %136, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%351 = torch_c.to_builtin_tensor %350 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%352 = "tosa.reshape"(%351) {new_shape = [-1, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%353 = "tosa.matmul"(%352, %323) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%354 = "tosa.reshape"(%353) {new_shape = [-1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%355 = "tosa.transpose"(%354, %13) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%356 = "tosa.reshape"(%355) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%357 = "tosa.matmul"(%356, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%358 = "tosa.reshape"(%357) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%359 = "tosa.add"(%68, %358) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%360 = "tosa.reshape"(%359) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%361 = "tosa.add"(%360, %317) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%362 = "tosa.reduce_sum"(%361) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%363 = "tosa.mul"(%362, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%364 = "tosa.sub"(%361, %363) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%365 = "tosa.mul"(%364, %364) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%366 = "tosa.reduce_sum"(%365) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%367 = "tosa.mul"(%366, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%368 = "tosa.add"(%367, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%369 = "tosa.rsqrt"(%368) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%370 = "tosa.mul"(%364, %369) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%371 = "tosa.mul"(%370, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%372 = "tosa.add"(%371, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%373 = "tosa.matmul"(%372, %171) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%374 = "tosa.reshape"(%373) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%375 = "tosa.add"(%174, %374) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%376 = "tosa.reshape"(%375) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%377 = "tosa.sub"(%376, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%378 = "tosa.mul"(%377, %31) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%379 = "tosa.abs"(%378) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%380 = "tosa.mul"(%379, %30) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%381 = "tosa.add"(%380, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%382 = "tosa.mul"(%379, %379) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%383 = "tosa.mul"(%382, %29) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%384 = "tosa.add"(%381, %383) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%385 = "tosa.mul"(%382, %379) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%386 = "tosa.mul"(%385, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%387 = "tosa.add"(%384, %386) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%388 = "tosa.mul"(%385, %379) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%389 = "tosa.mul"(%388, %27) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%390 = "tosa.add"(%387, %389) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%391 = "tosa.reciprocal"(%390) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%392 = "tosa.mul"(%391, %391) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%393 = "tosa.mul"(%392, %392) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%394 = "tosa.sub"(%32, %393) : (tensor<1x1x1xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%395 = "tosa.greater_equal"(%378, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xi1>
%396 = "tosa.negate"(%394) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%397 = "tosa.select"(%395, %394, %396) : (tensor<1x128x3072xi1>, tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%398 = "tosa.add"(%397, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%399 = "tosa.mul"(%398, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%400 = "tosa.mul"(%376, %399) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%401 = "tosa.matmul"(%400, %202) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%402 = "tosa.reshape"(%401) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%403 = "tosa.add"(%68, %402) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%404 = "tosa.reshape"(%403) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%405 = "tosa.add"(%404, %372) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%406 = "tosa.reduce_sum"(%405) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%407 = "tosa.mul"(%406, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%408 = "tosa.sub"(%405, %407) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%409 = "tosa.mul"(%408, %408) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%410 = "tosa.reduce_sum"(%409) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%411 = "tosa.mul"(%410, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%412 = "tosa.add"(%411, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%413 = "tosa.rsqrt"(%412) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%414 = "tosa.mul"(%408, %413) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%415 = "tosa.mul"(%414, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%416 = "tosa.add"(%415, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%417 = "tosa.matmul"(%416, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%418 = "tosa.reshape"(%417) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%419 = "tosa.add"(%68, %418) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%420 = "tosa.reshape"(%419) {new_shape = [1, 128, 12, -1]} : (tensor<128x768xf32>) -> tensor<1x128x12x64xf32>
%421 = "tosa.transpose"(%420, %13) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%422 = "tosa.reshape"(%421) {new_shape = [-1, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%423 = "tosa.transpose"(%422, %14) : (tensor<12x128x64xf32>, tensor<3xi32>) -> tensor<12x64x128xf32>
%424 = "tosa.matmul"(%422, %423) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%425 = "tosa.mul"(%424, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%426 = "tosa.matmul"(%422, %95) : (tensor<12x128x64xf32>, tensor<12x64x512xf32>) -> tensor<12x128x512xf32>
%427 = "tosa.reshape"(%426) {new_shape = [1, 786432, 1]} : (tensor<12x128x512xf32>) -> tensor<1x786432x1xf32>
%428 = "tosa.gather"(%427, %108) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%429 = "tosa.reshape"(%428) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%430 = "tosa.mul"(%429, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%431 = "tosa.add"(%430, %33) : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%432 = "tosa.gather"(%427, %124) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%433 = "tosa.reshape"(%432) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%434 = "tosa.transpose"(%433, %14) : (tensor<12x128x128xf32>, tensor<3xi32>) -> tensor<12x128x128xf32>
%435 = "tosa.mul"(%434, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%436 = "tosa.add"(%431, %435) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%437 = "tosa.add"(%425, %436) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%438 = "tosa.reshape"(%437) {new_shape = [-1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%439 = torch_c.from_builtin_tensor %438 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%440 = torch.aten.masked_fill.Tensor %439, %136, %137 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%441 = torch_c.to_builtin_tensor %440 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%442 = "tosa.reduce_max"(%441) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%443 = "tosa.sub"(%441, %442) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%444 = "tosa.exp"(%443) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%445 = "tosa.reduce_sum"(%444) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%446 = "tosa.reciprocal"(%445) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%447 = "tosa.mul"(%444, %446) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%448 = torch_c.from_builtin_tensor %447 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%449 = torch.aten.masked_fill.Scalar %448, %136, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%450 = torch_c.to_builtin_tensor %449 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%451 = "tosa.reshape"(%450) {new_shape = [-1, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%452 = "tosa.matmul"(%451, %422) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%453 = "tosa.reshape"(%452) {new_shape = [-1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%454 = "tosa.transpose"(%453, %13) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%455 = "tosa.reshape"(%454) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%456 = "tosa.matmul"(%455, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%457 = "tosa.reshape"(%456) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%458 = "tosa.add"(%68, %457) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%459 = "tosa.reshape"(%458) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%460 = "tosa.add"(%459, %416) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%461 = "tosa.reduce_sum"(%460) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%462 = "tosa.mul"(%461, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%463 = "tosa.sub"(%460, %462) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%464 = "tosa.mul"(%463, %463) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%465 = "tosa.reduce_sum"(%464) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%466 = "tosa.mul"(%465, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%467 = "tosa.add"(%466, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%468 = "tosa.rsqrt"(%467) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%469 = "tosa.mul"(%463, %468) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%470 = "tosa.mul"(%469, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%471 = "tosa.add"(%470, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%472 = "tosa.matmul"(%471, %171) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%473 = "tosa.reshape"(%472) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%474 = "tosa.add"(%174, %473) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%475 = "tosa.reshape"(%474) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%476 = "tosa.sub"(%475, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%477 = "tosa.mul"(%476, %31) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%478 = "tosa.abs"(%477) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%479 = "tosa.mul"(%478, %30) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%480 = "tosa.add"(%479, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%481 = "tosa.mul"(%478, %478) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%482 = "tosa.mul"(%481, %29) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%483 = "tosa.add"(%480, %482) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%484 = "tosa.mul"(%481, %478) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%485 = "tosa.mul"(%484, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%486 = "tosa.add"(%483, %485) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%487 = "tosa.mul"(%484, %478) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%488 = "tosa.mul"(%487, %27) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%489 = "tosa.add"(%486, %488) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%490 = "tosa.reciprocal"(%489) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%491 = "tosa.mul"(%490, %490) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%492 = "tosa.mul"(%491, %491) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%493 = "tosa.sub"(%32, %492) : (tensor<1x1x1xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%494 = "tosa.greater_equal"(%477, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xi1>
%495 = "tosa.negate"(%493) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%496 = "tosa.select"(%494, %493, %495) : (tensor<1x128x3072xi1>, tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%497 = "tosa.add"(%496, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%498 = "tosa.mul"(%497, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%499 = "tosa.mul"(%475, %498) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%500 = "tosa.matmul"(%499, %202) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%501 = "tosa.reshape"(%500) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%502 = "tosa.add"(%68, %501) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%503 = "tosa.reshape"(%502) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%504 = "tosa.add"(%503, %471) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%505 = "tosa.reduce_sum"(%504) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%506 = "tosa.mul"(%505, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%507 = "tosa.sub"(%504, %506) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%508 = "tosa.mul"(%507, %507) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%509 = "tosa.reduce_sum"(%508) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%510 = "tosa.mul"(%509, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%511 = "tosa.add"(%510, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%512 = "tosa.rsqrt"(%511) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%513 = "tosa.mul"(%507, %512) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%514 = "tosa.mul"(%513, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%515 = "tosa.add"(%514, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%516 = "tosa.matmul"(%515, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%517 = "tosa.reshape"(%516) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%518 = "tosa.add"(%68, %517) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%519 = "tosa.reshape"(%518) {new_shape = [1, 128, 12, -1]} : (tensor<128x768xf32>) -> tensor<1x128x12x64xf32>
%520 = "tosa.transpose"(%519, %13) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%521 = "tosa.reshape"(%520) {new_shape = [-1, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%522 = "tosa.transpose"(%521, %14) : (tensor<12x128x64xf32>, tensor<3xi32>) -> tensor<12x64x128xf32>
%523 = "tosa.matmul"(%521, %522) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%524 = "tosa.mul"(%523, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%525 = "tosa.matmul"(%521, %95) : (tensor<12x128x64xf32>, tensor<12x64x512xf32>) -> tensor<12x128x512xf32>
%526 = "tosa.reshape"(%525) {new_shape = [1, 786432, 1]} : (tensor<12x128x512xf32>) -> tensor<1x786432x1xf32>
%527 = "tosa.gather"(%526, %108) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%528 = "tosa.reshape"(%527) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%529 = "tosa.mul"(%528, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%530 = "tosa.add"(%529, %33) : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%531 = "tosa.gather"(%526, %124) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%532 = "tosa.reshape"(%531) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%533 = "tosa.transpose"(%532, %14) : (tensor<12x128x128xf32>, tensor<3xi32>) -> tensor<12x128x128xf32>
%534 = "tosa.mul"(%533, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%535 = "tosa.add"(%530, %534) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%536 = "tosa.add"(%524, %535) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%537 = "tosa.reshape"(%536) {new_shape = [-1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%538 = torch_c.from_builtin_tensor %537 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%539 = torch.aten.masked_fill.Tensor %538, %136, %137 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%540 = torch_c.to_builtin_tensor %539 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%541 = "tosa.reduce_max"(%540) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%542 = "tosa.sub"(%540, %541) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%543 = "tosa.exp"(%542) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%544 = "tosa.reduce_sum"(%543) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%545 = "tosa.reciprocal"(%544) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%546 = "tosa.mul"(%543, %545) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%547 = torch_c.from_builtin_tensor %546 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%548 = torch.aten.masked_fill.Scalar %547, %136, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%549 = torch_c.to_builtin_tensor %548 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%550 = "tosa.reshape"(%549) {new_shape = [-1, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%551 = "tosa.matmul"(%550, %521) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%552 = "tosa.reshape"(%551) {new_shape = [-1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%553 = "tosa.transpose"(%552, %13) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%554 = "tosa.reshape"(%553) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%555 = "tosa.matmul"(%554, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%556 = "tosa.reshape"(%555) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%557 = "tosa.add"(%68, %556) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%558 = "tosa.reshape"(%557) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%559 = "tosa.add"(%558, %515) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%560 = "tosa.reduce_sum"(%559) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%561 = "tosa.mul"(%560, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%562 = "tosa.sub"(%559, %561) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%563 = "tosa.mul"(%562, %562) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%564 = "tosa.reduce_sum"(%563) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%565 = "tosa.mul"(%564, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%566 = "tosa.add"(%565, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%567 = "tosa.rsqrt"(%566) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%568 = "tosa.mul"(%562, %567) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%569 = "tosa.mul"(%568, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%570 = "tosa.add"(%569, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%571 = "tosa.matmul"(%570, %171) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%572 = "tosa.reshape"(%571) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%573 = "tosa.add"(%174, %572) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%574 = "tosa.reshape"(%573) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%575 = "tosa.sub"(%574, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%576 = "tosa.mul"(%575, %31) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%577 = "tosa.abs"(%576) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%578 = "tosa.mul"(%577, %30) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%579 = "tosa.add"(%578, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%580 = "tosa.mul"(%577, %577) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%581 = "tosa.mul"(%580, %29) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%582 = "tosa.add"(%579, %581) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%583 = "tosa.mul"(%580, %577) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%584 = "tosa.mul"(%583, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%585 = "tosa.add"(%582, %584) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%586 = "tosa.mul"(%583, %577) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%587 = "tosa.mul"(%586, %27) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%588 = "tosa.add"(%585, %587) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%589 = "tosa.reciprocal"(%588) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%590 = "tosa.mul"(%589, %589) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%591 = "tosa.mul"(%590, %590) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%592 = "tosa.sub"(%32, %591) : (tensor<1x1x1xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%593 = "tosa.greater_equal"(%576, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xi1>
%594 = "tosa.negate"(%592) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%595 = "tosa.select"(%593, %592, %594) : (tensor<1x128x3072xi1>, tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%596 = "tosa.add"(%595, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%597 = "tosa.mul"(%596, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%598 = "tosa.mul"(%574, %597) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%599 = "tosa.matmul"(%598, %202) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%600 = "tosa.reshape"(%599) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%601 = "tosa.add"(%68, %600) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%602 = "tosa.reshape"(%601) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%603 = "tosa.add"(%602, %570) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%604 = "tosa.reduce_sum"(%603) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%605 = "tosa.mul"(%604, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%606 = "tosa.sub"(%603, %605) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%607 = "tosa.mul"(%606, %606) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%608 = "tosa.reduce_sum"(%607) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%609 = "tosa.mul"(%608, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%610 = "tosa.add"(%609, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%611 = "tosa.rsqrt"(%610) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%612 = "tosa.mul"(%606, %611) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%613 = "tosa.mul"(%612, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%614 = "tosa.add"(%613, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%615 = "tosa.matmul"(%614, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%616 = "tosa.reshape"(%615) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%617 = "tosa.add"(%68, %616) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%618 = "tosa.reshape"(%617) {new_shape = [1, 128, 12, -1]} : (tensor<128x768xf32>) -> tensor<1x128x12x64xf32>
%619 = "tosa.transpose"(%618, %13) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%620 = "tosa.reshape"(%619) {new_shape = [-1, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%621 = "tosa.transpose"(%620, %14) : (tensor<12x128x64xf32>, tensor<3xi32>) -> tensor<12x64x128xf32>
%622 = "tosa.matmul"(%620, %621) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%623 = "tosa.mul"(%622, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%624 = "tosa.matmul"(%620, %95) : (tensor<12x128x64xf32>, tensor<12x64x512xf32>) -> tensor<12x128x512xf32>
%625 = "tosa.reshape"(%624) {new_shape = [1, 786432, 1]} : (tensor<12x128x512xf32>) -> tensor<1x786432x1xf32>
%626 = "tosa.gather"(%625, %108) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%627 = "tosa.reshape"(%626) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%628 = "tosa.mul"(%627, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%629 = "tosa.add"(%628, %33) : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%630 = "tosa.gather"(%625, %124) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%631 = "tosa.reshape"(%630) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%632 = "tosa.transpose"(%631, %14) : (tensor<12x128x128xf32>, tensor<3xi32>) -> tensor<12x128x128xf32>
%633 = "tosa.mul"(%632, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%634 = "tosa.add"(%629, %633) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%635 = "tosa.add"(%623, %634) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%636 = "tosa.reshape"(%635) {new_shape = [-1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%637 = torch_c.from_builtin_tensor %636 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%638 = torch.aten.masked_fill.Tensor %637, %136, %137 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%639 = torch_c.to_builtin_tensor %638 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%640 = "tosa.reduce_max"(%639) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%641 = "tosa.sub"(%639, %640) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%642 = "tosa.exp"(%641) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%643 = "tosa.reduce_sum"(%642) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%644 = "tosa.reciprocal"(%643) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%645 = "tosa.mul"(%642, %644) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%646 = torch_c.from_builtin_tensor %645 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%647 = torch.aten.masked_fill.Scalar %646, %136, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%648 = torch_c.to_builtin_tensor %647 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%649 = "tosa.reshape"(%648) {new_shape = [-1, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%650 = "tosa.matmul"(%649, %620) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%651 = "tosa.reshape"(%650) {new_shape = [-1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%652 = "tosa.transpose"(%651, %13) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%653 = "tosa.reshape"(%652) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%654 = "tosa.matmul"(%653, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%655 = "tosa.reshape"(%654) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%656 = "tosa.add"(%68, %655) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%657 = "tosa.reshape"(%656) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%658 = "tosa.add"(%657, %614) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%659 = "tosa.reduce_sum"(%658) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%660 = "tosa.mul"(%659, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%661 = "tosa.sub"(%658, %660) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%662 = "tosa.mul"(%661, %661) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%663 = "tosa.reduce_sum"(%662) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%664 = "tosa.mul"(%663, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%665 = "tosa.add"(%664, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%666 = "tosa.rsqrt"(%665) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%667 = "tosa.mul"(%661, %666) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%668 = "tosa.mul"(%667, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%669 = "tosa.add"(%668, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%670 = "tosa.matmul"(%669, %171) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%671 = "tosa.reshape"(%670) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%672 = "tosa.add"(%174, %671) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%673 = "tosa.reshape"(%672) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%674 = "tosa.sub"(%673, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%675 = "tosa.mul"(%674, %31) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%676 = "tosa.abs"(%675) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%677 = "tosa.mul"(%676, %30) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%678 = "tosa.add"(%677, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%679 = "tosa.mul"(%676, %676) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%680 = "tosa.mul"(%679, %29) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%681 = "tosa.add"(%678, %680) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%682 = "tosa.mul"(%679, %676) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%683 = "tosa.mul"(%682, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%684 = "tosa.add"(%681, %683) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%685 = "tosa.mul"(%682, %676) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%686 = "tosa.mul"(%685, %27) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%687 = "tosa.add"(%684, %686) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%688 = "tosa.reciprocal"(%687) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%689 = "tosa.mul"(%688, %688) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%690 = "tosa.mul"(%689, %689) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%691 = "tosa.sub"(%32, %690) : (tensor<1x1x1xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%692 = "tosa.greater_equal"(%675, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xi1>
%693 = "tosa.negate"(%691) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%694 = "tosa.select"(%692, %691, %693) : (tensor<1x128x3072xi1>, tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%695 = "tosa.add"(%694, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%696 = "tosa.mul"(%695, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%697 = "tosa.mul"(%673, %696) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%698 = "tosa.matmul"(%697, %202) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%699 = "tosa.reshape"(%698) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%700 = "tosa.add"(%68, %699) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%701 = "tosa.reshape"(%700) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%702 = "tosa.add"(%701, %669) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%703 = "tosa.reduce_sum"(%702) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%704 = "tosa.mul"(%703, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%705 = "tosa.sub"(%702, %704) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%706 = "tosa.mul"(%705, %705) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%707 = "tosa.reduce_sum"(%706) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%708 = "tosa.mul"(%707, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%709 = "tosa.add"(%708, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%710 = "tosa.rsqrt"(%709) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%711 = "tosa.mul"(%705, %710) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%712 = "tosa.mul"(%711, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%713 = "tosa.add"(%712, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%714 = "tosa.matmul"(%713, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%715 = "tosa.reshape"(%714) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%716 = "tosa.add"(%68, %715) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%717 = "tosa.reshape"(%716) {new_shape = [1, 128, 12, -1]} : (tensor<128x768xf32>) -> tensor<1x128x12x64xf32>
%718 = "tosa.transpose"(%717, %13) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%719 = "tosa.reshape"(%718) {new_shape = [-1, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%720 = "tosa.transpose"(%719, %14) : (tensor<12x128x64xf32>, tensor<3xi32>) -> tensor<12x64x128xf32>
%721 = "tosa.matmul"(%719, %720) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%722 = "tosa.mul"(%721, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%723 = "tosa.matmul"(%719, %95) : (tensor<12x128x64xf32>, tensor<12x64x512xf32>) -> tensor<12x128x512xf32>
%724 = "tosa.reshape"(%723) {new_shape = [1, 786432, 1]} : (tensor<12x128x512xf32>) -> tensor<1x786432x1xf32>
%725 = "tosa.gather"(%724, %108) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%726 = "tosa.reshape"(%725) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%727 = "tosa.mul"(%726, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%728 = "tosa.add"(%727, %33) : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%729 = "tosa.gather"(%724, %124) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%730 = "tosa.reshape"(%729) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%731 = "tosa.transpose"(%730, %14) : (tensor<12x128x128xf32>, tensor<3xi32>) -> tensor<12x128x128xf32>
%732 = "tosa.mul"(%731, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%733 = "tosa.add"(%728, %732) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%734 = "tosa.add"(%722, %733) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%735 = "tosa.reshape"(%734) {new_shape = [-1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%736 = torch_c.from_builtin_tensor %735 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%737 = torch.aten.masked_fill.Tensor %736, %136, %137 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%738 = torch_c.to_builtin_tensor %737 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%739 = "tosa.reduce_max"(%738) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%740 = "tosa.sub"(%738, %739) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%741 = "tosa.exp"(%740) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%742 = "tosa.reduce_sum"(%741) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%743 = "tosa.reciprocal"(%742) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%744 = "tosa.mul"(%741, %743) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%745 = torch_c.from_builtin_tensor %744 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%746 = torch.aten.masked_fill.Scalar %745, %136, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%747 = torch_c.to_builtin_tensor %746 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%748 = "tosa.reshape"(%747) {new_shape = [-1, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%749 = "tosa.matmul"(%748, %719) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%750 = "tosa.reshape"(%749) {new_shape = [-1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%751 = "tosa.transpose"(%750, %13) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%752 = "tosa.reshape"(%751) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%753 = "tosa.matmul"(%752, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%754 = "tosa.reshape"(%753) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%755 = "tosa.add"(%68, %754) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%756 = "tosa.reshape"(%755) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%757 = "tosa.add"(%756, %713) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%758 = "tosa.reduce_sum"(%757) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%759 = "tosa.mul"(%758, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%760 = "tosa.sub"(%757, %759) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%761 = "tosa.mul"(%760, %760) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%762 = "tosa.reduce_sum"(%761) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%763 = "tosa.mul"(%762, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%764 = "tosa.add"(%763, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%765 = "tosa.rsqrt"(%764) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%766 = "tosa.mul"(%760, %765) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%767 = "tosa.mul"(%766, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%768 = "tosa.add"(%767, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%769 = "tosa.matmul"(%768, %171) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%770 = "tosa.reshape"(%769) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%771 = "tosa.add"(%174, %770) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%772 = "tosa.reshape"(%771) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%773 = "tosa.sub"(%772, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%774 = "tosa.mul"(%773, %31) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%775 = "tosa.abs"(%774) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%776 = "tosa.mul"(%775, %30) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%777 = "tosa.add"(%776, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%778 = "tosa.mul"(%775, %775) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%779 = "tosa.mul"(%778, %29) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%780 = "tosa.add"(%777, %779) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%781 = "tosa.mul"(%778, %775) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%782 = "tosa.mul"(%781, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%783 = "tosa.add"(%780, %782) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%784 = "tosa.mul"(%781, %775) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%785 = "tosa.mul"(%784, %27) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%786 = "tosa.add"(%783, %785) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%787 = "tosa.reciprocal"(%786) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%788 = "tosa.mul"(%787, %787) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%789 = "tosa.mul"(%788, %788) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%790 = "tosa.sub"(%32, %789) : (tensor<1x1x1xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%791 = "tosa.greater_equal"(%774, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xi1>
%792 = "tosa.negate"(%790) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%793 = "tosa.select"(%791, %790, %792) : (tensor<1x128x3072xi1>, tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%794 = "tosa.add"(%793, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%795 = "tosa.mul"(%794, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%796 = "tosa.mul"(%772, %795) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%797 = "tosa.matmul"(%796, %202) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%798 = "tosa.reshape"(%797) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%799 = "tosa.add"(%68, %798) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%800 = "tosa.reshape"(%799) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%801 = "tosa.add"(%800, %768) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%802 = "tosa.reduce_sum"(%801) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%803 = "tosa.mul"(%802, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%804 = "tosa.sub"(%801, %803) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%805 = "tosa.mul"(%804, %804) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%806 = "tosa.reduce_sum"(%805) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%807 = "tosa.mul"(%806, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%808 = "tosa.add"(%807, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%809 = "tosa.rsqrt"(%808) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%810 = "tosa.mul"(%804, %809) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%811 = "tosa.mul"(%810, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%812 = "tosa.add"(%811, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%813 = "tosa.matmul"(%812, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%814 = "tosa.reshape"(%813) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%815 = "tosa.add"(%68, %814) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%816 = "tosa.reshape"(%815) {new_shape = [1, 128, 12, -1]} : (tensor<128x768xf32>) -> tensor<1x128x12x64xf32>
%817 = "tosa.transpose"(%816, %13) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%818 = "tosa.reshape"(%817) {new_shape = [-1, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%819 = "tosa.transpose"(%818, %14) : (tensor<12x128x64xf32>, tensor<3xi32>) -> tensor<12x64x128xf32>
%820 = "tosa.matmul"(%818, %819) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%821 = "tosa.mul"(%820, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%822 = "tosa.matmul"(%818, %95) : (tensor<12x128x64xf32>, tensor<12x64x512xf32>) -> tensor<12x128x512xf32>
%823 = "tosa.reshape"(%822) {new_shape = [1, 786432, 1]} : (tensor<12x128x512xf32>) -> tensor<1x786432x1xf32>
%824 = "tosa.gather"(%823, %108) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%825 = "tosa.reshape"(%824) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%826 = "tosa.mul"(%825, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%827 = "tosa.add"(%826, %33) : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%828 = "tosa.gather"(%823, %124) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%829 = "tosa.reshape"(%828) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%830 = "tosa.transpose"(%829, %14) : (tensor<12x128x128xf32>, tensor<3xi32>) -> tensor<12x128x128xf32>
%831 = "tosa.mul"(%830, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%832 = "tosa.add"(%827, %831) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%833 = "tosa.add"(%821, %832) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%834 = "tosa.reshape"(%833) {new_shape = [-1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%835 = torch_c.from_builtin_tensor %834 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%836 = torch.aten.masked_fill.Tensor %835, %136, %137 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%837 = torch_c.to_builtin_tensor %836 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%838 = "tosa.reduce_max"(%837) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%839 = "tosa.sub"(%837, %838) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%840 = "tosa.exp"(%839) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%841 = "tosa.reduce_sum"(%840) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%842 = "tosa.reciprocal"(%841) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%843 = "tosa.mul"(%840, %842) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%844 = torch_c.from_builtin_tensor %843 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%845 = torch.aten.masked_fill.Scalar %844, %136, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%846 = torch_c.to_builtin_tensor %845 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%847 = "tosa.reshape"(%846) {new_shape = [-1, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%848 = "tosa.matmul"(%847, %818) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%849 = "tosa.reshape"(%848) {new_shape = [-1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%850 = "tosa.transpose"(%849, %13) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%851 = "tosa.reshape"(%850) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%852 = "tosa.matmul"(%851, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%853 = "tosa.reshape"(%852) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%854 = "tosa.add"(%68, %853) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%855 = "tosa.reshape"(%854) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%856 = "tosa.add"(%855, %812) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%857 = "tosa.reduce_sum"(%856) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%858 = "tosa.mul"(%857, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%859 = "tosa.sub"(%856, %858) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%860 = "tosa.mul"(%859, %859) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%861 = "tosa.reduce_sum"(%860) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%862 = "tosa.mul"(%861, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%863 = "tosa.add"(%862, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%864 = "tosa.rsqrt"(%863) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%865 = "tosa.mul"(%859, %864) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%866 = "tosa.mul"(%865, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%867 = "tosa.add"(%866, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%868 = "tosa.matmul"(%867, %171) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%869 = "tosa.reshape"(%868) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%870 = "tosa.add"(%174, %869) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%871 = "tosa.reshape"(%870) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%872 = "tosa.sub"(%871, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%873 = "tosa.mul"(%872, %31) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%874 = "tosa.abs"(%873) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%875 = "tosa.mul"(%874, %30) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%876 = "tosa.add"(%875, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%877 = "tosa.mul"(%874, %874) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%878 = "tosa.mul"(%877, %29) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%879 = "tosa.add"(%876, %878) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%880 = "tosa.mul"(%877, %874) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%881 = "tosa.mul"(%880, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%882 = "tosa.add"(%879, %881) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%883 = "tosa.mul"(%880, %874) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%884 = "tosa.mul"(%883, %27) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%885 = "tosa.add"(%882, %884) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%886 = "tosa.reciprocal"(%885) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%887 = "tosa.mul"(%886, %886) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%888 = "tosa.mul"(%887, %887) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%889 = "tosa.sub"(%32, %888) : (tensor<1x1x1xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%890 = "tosa.greater_equal"(%873, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xi1>
%891 = "tosa.negate"(%889) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%892 = "tosa.select"(%890, %889, %891) : (tensor<1x128x3072xi1>, tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%893 = "tosa.add"(%892, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%894 = "tosa.mul"(%893, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%895 = "tosa.mul"(%871, %894) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%896 = "tosa.matmul"(%895, %202) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%897 = "tosa.reshape"(%896) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%898 = "tosa.add"(%68, %897) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%899 = "tosa.reshape"(%898) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%900 = "tosa.add"(%899, %867) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%901 = "tosa.reduce_sum"(%900) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%902 = "tosa.mul"(%901, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%903 = "tosa.sub"(%900, %902) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%904 = "tosa.mul"(%903, %903) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%905 = "tosa.reduce_sum"(%904) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%906 = "tosa.mul"(%905, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%907 = "tosa.add"(%906, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%908 = "tosa.rsqrt"(%907) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%909 = "tosa.mul"(%903, %908) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%910 = "tosa.mul"(%909, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%911 = "tosa.add"(%910, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%912 = "tosa.matmul"(%911, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%913 = "tosa.reshape"(%912) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%914 = "tosa.add"(%68, %913) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%915 = "tosa.reshape"(%914) {new_shape = [1, 128, 12, -1]} : (tensor<128x768xf32>) -> tensor<1x128x12x64xf32>
%916 = "tosa.transpose"(%915, %13) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%917 = "tosa.reshape"(%916) {new_shape = [-1, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%918 = "tosa.transpose"(%917, %14) : (tensor<12x128x64xf32>, tensor<3xi32>) -> tensor<12x64x128xf32>
%919 = "tosa.matmul"(%917, %918) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%920 = "tosa.mul"(%919, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%921 = "tosa.matmul"(%917, %95) : (tensor<12x128x64xf32>, tensor<12x64x512xf32>) -> tensor<12x128x512xf32>
%922 = "tosa.reshape"(%921) {new_shape = [1, 786432, 1]} : (tensor<12x128x512xf32>) -> tensor<1x786432x1xf32>
%923 = "tosa.gather"(%922, %108) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%924 = "tosa.reshape"(%923) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%925 = "tosa.mul"(%924, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%926 = "tosa.add"(%925, %33) : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%927 = "tosa.gather"(%922, %124) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%928 = "tosa.reshape"(%927) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%929 = "tosa.transpose"(%928, %14) : (tensor<12x128x128xf32>, tensor<3xi32>) -> tensor<12x128x128xf32>
%930 = "tosa.mul"(%929, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%931 = "tosa.add"(%926, %930) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%932 = "tosa.add"(%920, %931) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%933 = "tosa.reshape"(%932) {new_shape = [-1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%934 = torch_c.from_builtin_tensor %933 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%935 = torch.aten.masked_fill.Tensor %934, %136, %137 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%936 = torch_c.to_builtin_tensor %935 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%937 = "tosa.reduce_max"(%936) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%938 = "tosa.sub"(%936, %937) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%939 = "tosa.exp"(%938) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%940 = "tosa.reduce_sum"(%939) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%941 = "tosa.reciprocal"(%940) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%942 = "tosa.mul"(%939, %941) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%943 = torch_c.from_builtin_tensor %942 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%944 = torch.aten.masked_fill.Scalar %943, %136, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%945 = torch_c.to_builtin_tensor %944 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%946 = "tosa.reshape"(%945) {new_shape = [-1, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%947 = "tosa.matmul"(%946, %917) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%948 = "tosa.reshape"(%947) {new_shape = [-1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%949 = "tosa.transpose"(%948, %13) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%950 = "tosa.reshape"(%949) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%951 = "tosa.matmul"(%950, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%952 = "tosa.reshape"(%951) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%953 = "tosa.add"(%68, %952) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%954 = "tosa.reshape"(%953) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%955 = "tosa.add"(%954, %911) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%956 = "tosa.reduce_sum"(%955) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%957 = "tosa.mul"(%956, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%958 = "tosa.sub"(%955, %957) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%959 = "tosa.mul"(%958, %958) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%960 = "tosa.reduce_sum"(%959) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%961 = "tosa.mul"(%960, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%962 = "tosa.add"(%961, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%963 = "tosa.rsqrt"(%962) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%964 = "tosa.mul"(%958, %963) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%965 = "tosa.mul"(%964, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%966 = "tosa.add"(%965, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%967 = "tosa.matmul"(%966, %171) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%968 = "tosa.reshape"(%967) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%969 = "tosa.add"(%174, %968) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%970 = "tosa.reshape"(%969) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%971 = "tosa.sub"(%970, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%972 = "tosa.mul"(%971, %31) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%973 = "tosa.abs"(%972) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%974 = "tosa.mul"(%973, %30) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%975 = "tosa.add"(%974, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%976 = "tosa.mul"(%973, %973) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%977 = "tosa.mul"(%976, %29) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%978 = "tosa.add"(%975, %977) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%979 = "tosa.mul"(%976, %973) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%980 = "tosa.mul"(%979, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%981 = "tosa.add"(%978, %980) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%982 = "tosa.mul"(%979, %973) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%983 = "tosa.mul"(%982, %27) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%984 = "tosa.add"(%981, %983) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%985 = "tosa.reciprocal"(%984) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%986 = "tosa.mul"(%985, %985) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%987 = "tosa.mul"(%986, %986) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%988 = "tosa.sub"(%32, %987) : (tensor<1x1x1xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%989 = "tosa.greater_equal"(%972, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xi1>
%990 = "tosa.negate"(%988) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%991 = "tosa.select"(%989, %988, %990) : (tensor<1x128x3072xi1>, tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%992 = "tosa.add"(%991, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%993 = "tosa.mul"(%992, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%994 = "tosa.mul"(%970, %993) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%995 = "tosa.matmul"(%994, %202) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%996 = "tosa.reshape"(%995) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%997 = "tosa.add"(%68, %996) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%998 = "tosa.reshape"(%997) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%999 = "tosa.add"(%998, %966) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1000 = "tosa.reduce_sum"(%999) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1001 = "tosa.mul"(%1000, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1002 = "tosa.sub"(%999, %1001) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1003 = "tosa.mul"(%1002, %1002) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1004 = "tosa.reduce_sum"(%1003) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1005 = "tosa.mul"(%1004, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1006 = "tosa.add"(%1005, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1007 = "tosa.rsqrt"(%1006) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%1008 = "tosa.mul"(%1002, %1007) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1009 = "tosa.mul"(%1008, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1010 = "tosa.add"(%1009, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1011 = "tosa.matmul"(%1010, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%1012 = "tosa.reshape"(%1011) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%1013 = "tosa.add"(%68, %1012) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%1014 = "tosa.reshape"(%1013) {new_shape = [1, 128, 12, -1]} : (tensor<128x768xf32>) -> tensor<1x128x12x64xf32>
%1015 = "tosa.transpose"(%1014, %13) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%1016 = "tosa.reshape"(%1015) {new_shape = [-1, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%1017 = "tosa.transpose"(%1016, %14) : (tensor<12x128x64xf32>, tensor<3xi32>) -> tensor<12x64x128xf32>
%1018 = "tosa.matmul"(%1016, %1017) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%1019 = "tosa.mul"(%1018, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%1020 = "tosa.matmul"(%1016, %95) : (tensor<12x128x64xf32>, tensor<12x64x512xf32>) -> tensor<12x128x512xf32>
%1021 = "tosa.reshape"(%1020) {new_shape = [1, 786432, 1]} : (tensor<12x128x512xf32>) -> tensor<1x786432x1xf32>
%1022 = "tosa.gather"(%1021, %108) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%1023 = "tosa.reshape"(%1022) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%1024 = "tosa.mul"(%1023, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%1025 = "tosa.add"(%1024, %33) : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%1026 = "tosa.gather"(%1021, %124) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%1027 = "tosa.reshape"(%1026) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%1028 = "tosa.transpose"(%1027, %14) : (tensor<12x128x128xf32>, tensor<3xi32>) -> tensor<12x128x128xf32>
%1029 = "tosa.mul"(%1028, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%1030 = "tosa.add"(%1025, %1029) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%1031 = "tosa.add"(%1019, %1030) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%1032 = "tosa.reshape"(%1031) {new_shape = [-1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%1033 = torch_c.from_builtin_tensor %1032 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%1034 = torch.aten.masked_fill.Tensor %1033, %136, %137 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%1035 = torch_c.to_builtin_tensor %1034 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%1036 = "tosa.reduce_max"(%1035) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%1037 = "tosa.sub"(%1035, %1036) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%1038 = "tosa.exp"(%1037) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%1039 = "tosa.reduce_sum"(%1038) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%1040 = "tosa.reciprocal"(%1039) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%1041 = "tosa.mul"(%1038, %1040) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%1042 = torch_c.from_builtin_tensor %1041 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%1043 = torch.aten.masked_fill.Scalar %1042, %136, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%1044 = torch_c.to_builtin_tensor %1043 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%1045 = "tosa.reshape"(%1044) {new_shape = [-1, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%1046 = "tosa.matmul"(%1045, %1016) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%1047 = "tosa.reshape"(%1046) {new_shape = [-1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%1048 = "tosa.transpose"(%1047, %13) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%1049 = "tosa.reshape"(%1048) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%1050 = "tosa.matmul"(%1049, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%1051 = "tosa.reshape"(%1050) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%1052 = "tosa.add"(%68, %1051) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%1053 = "tosa.reshape"(%1052) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%1054 = "tosa.add"(%1053, %1010) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1055 = "tosa.reduce_sum"(%1054) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1056 = "tosa.mul"(%1055, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1057 = "tosa.sub"(%1054, %1056) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1058 = "tosa.mul"(%1057, %1057) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1059 = "tosa.reduce_sum"(%1058) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1060 = "tosa.mul"(%1059, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1061 = "tosa.add"(%1060, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1062 = "tosa.rsqrt"(%1061) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%1063 = "tosa.mul"(%1057, %1062) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1064 = "tosa.mul"(%1063, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1065 = "tosa.add"(%1064, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1066 = "tosa.matmul"(%1065, %171) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%1067 = "tosa.reshape"(%1066) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%1068 = "tosa.add"(%174, %1067) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%1069 = "tosa.reshape"(%1068) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%1070 = "tosa.sub"(%1069, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1071 = "tosa.mul"(%1070, %31) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1072 = "tosa.abs"(%1071) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1073 = "tosa.mul"(%1072, %30) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1074 = "tosa.add"(%1073, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1075 = "tosa.mul"(%1072, %1072) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1076 = "tosa.mul"(%1075, %29) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1077 = "tosa.add"(%1074, %1076) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1078 = "tosa.mul"(%1075, %1072) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1079 = "tosa.mul"(%1078, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1080 = "tosa.add"(%1077, %1079) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1081 = "tosa.mul"(%1078, %1072) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1082 = "tosa.mul"(%1081, %27) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1083 = "tosa.add"(%1080, %1082) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1084 = "tosa.reciprocal"(%1083) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1085 = "tosa.mul"(%1084, %1084) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1086 = "tosa.mul"(%1085, %1085) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1087 = "tosa.sub"(%32, %1086) : (tensor<1x1x1xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1088 = "tosa.greater_equal"(%1071, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xi1>
%1089 = "tosa.negate"(%1087) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1090 = "tosa.select"(%1088, %1087, %1089) : (tensor<1x128x3072xi1>, tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1091 = "tosa.add"(%1090, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1092 = "tosa.mul"(%1091, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1093 = "tosa.mul"(%1069, %1092) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1094 = "tosa.matmul"(%1093, %202) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%1095 = "tosa.reshape"(%1094) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%1096 = "tosa.add"(%68, %1095) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%1097 = "tosa.reshape"(%1096) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%1098 = "tosa.add"(%1097, %1065) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1099 = "tosa.reduce_sum"(%1098) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1100 = "tosa.mul"(%1099, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1101 = "tosa.sub"(%1098, %1100) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1102 = "tosa.mul"(%1101, %1101) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1103 = "tosa.reduce_sum"(%1102) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1104 = "tosa.mul"(%1103, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1105 = "tosa.add"(%1104, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1106 = "tosa.rsqrt"(%1105) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%1107 = "tosa.mul"(%1101, %1106) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1108 = "tosa.mul"(%1107, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1109 = "tosa.add"(%1108, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1110 = "tosa.matmul"(%1109, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%1111 = "tosa.reshape"(%1110) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%1112 = "tosa.add"(%68, %1111) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%1113 = "tosa.reshape"(%1112) {new_shape = [1, 128, 12, -1]} : (tensor<128x768xf32>) -> tensor<1x128x12x64xf32>
%1114 = "tosa.transpose"(%1113, %13) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%1115 = "tosa.reshape"(%1114) {new_shape = [-1, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%1116 = "tosa.transpose"(%1115, %14) : (tensor<12x128x64xf32>, tensor<3xi32>) -> tensor<12x64x128xf32>
%1117 = "tosa.matmul"(%1115, %1116) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%1118 = "tosa.mul"(%1117, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%1119 = "tosa.matmul"(%1115, %95) : (tensor<12x128x64xf32>, tensor<12x64x512xf32>) -> tensor<12x128x512xf32>
%1120 = "tosa.reshape"(%1119) {new_shape = [1, 786432, 1]} : (tensor<12x128x512xf32>) -> tensor<1x786432x1xf32>
%1121 = "tosa.gather"(%1120, %108) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%1122 = "tosa.reshape"(%1121) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%1123 = "tosa.mul"(%1122, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%1124 = "tosa.add"(%1123, %33) : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%1125 = "tosa.gather"(%1120, %124) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%1126 = "tosa.reshape"(%1125) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%1127 = "tosa.transpose"(%1126, %14) : (tensor<12x128x128xf32>, tensor<3xi32>) -> tensor<12x128x128xf32>
%1128 = "tosa.mul"(%1127, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%1129 = "tosa.add"(%1124, %1128) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%1130 = "tosa.add"(%1118, %1129) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%1131 = "tosa.reshape"(%1130) {new_shape = [-1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%1132 = torch_c.from_builtin_tensor %1131 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%1133 = torch.aten.masked_fill.Tensor %1132, %136, %137 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%1134 = torch_c.to_builtin_tensor %1133 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%1135 = "tosa.reduce_max"(%1134) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%1136 = "tosa.sub"(%1134, %1135) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%1137 = "tosa.exp"(%1136) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%1138 = "tosa.reduce_sum"(%1137) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%1139 = "tosa.reciprocal"(%1138) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%1140 = "tosa.mul"(%1137, %1139) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%1141 = torch_c.from_builtin_tensor %1140 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%1142 = torch.aten.masked_fill.Scalar %1141, %136, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%1143 = torch_c.to_builtin_tensor %1142 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%1144 = "tosa.reshape"(%1143) {new_shape = [-1, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%1145 = "tosa.matmul"(%1144, %1115) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%1146 = "tosa.reshape"(%1145) {new_shape = [-1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%1147 = "tosa.transpose"(%1146, %13) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%1148 = "tosa.reshape"(%1147) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%1149 = "tosa.matmul"(%1148, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%1150 = "tosa.reshape"(%1149) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%1151 = "tosa.add"(%68, %1150) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%1152 = "tosa.reshape"(%1151) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%1153 = "tosa.add"(%1152, %1109) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1154 = "tosa.reduce_sum"(%1153) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1155 = "tosa.mul"(%1154, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1156 = "tosa.sub"(%1153, %1155) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1157 = "tosa.mul"(%1156, %1156) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1158 = "tosa.reduce_sum"(%1157) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1159 = "tosa.mul"(%1158, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1160 = "tosa.add"(%1159, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1161 = "tosa.rsqrt"(%1160) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%1162 = "tosa.mul"(%1156, %1161) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1163 = "tosa.mul"(%1162, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1164 = "tosa.add"(%1163, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1165 = "tosa.matmul"(%1164, %171) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%1166 = "tosa.reshape"(%1165) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%1167 = "tosa.add"(%174, %1166) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%1168 = "tosa.reshape"(%1167) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%1169 = "tosa.sub"(%1168, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1170 = "tosa.mul"(%1169, %31) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1171 = "tosa.abs"(%1170) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1172 = "tosa.mul"(%1171, %30) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1173 = "tosa.add"(%1172, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1174 = "tosa.mul"(%1171, %1171) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1175 = "tosa.mul"(%1174, %29) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1176 = "tosa.add"(%1173, %1175) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1177 = "tosa.mul"(%1174, %1171) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1178 = "tosa.mul"(%1177, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1179 = "tosa.add"(%1176, %1178) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1180 = "tosa.mul"(%1177, %1171) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1181 = "tosa.mul"(%1180, %27) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1182 = "tosa.add"(%1179, %1181) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1183 = "tosa.reciprocal"(%1182) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1184 = "tosa.mul"(%1183, %1183) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1185 = "tosa.mul"(%1184, %1184) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1186 = "tosa.sub"(%32, %1185) : (tensor<1x1x1xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1187 = "tosa.greater_equal"(%1170, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xi1>
%1188 = "tosa.negate"(%1186) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1189 = "tosa.select"(%1187, %1186, %1188) : (tensor<1x128x3072xi1>, tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1190 = "tosa.add"(%1189, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1191 = "tosa.mul"(%1190, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1192 = "tosa.mul"(%1168, %1191) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1193 = "tosa.matmul"(%1192, %202) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%1194 = "tosa.reshape"(%1193) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%1195 = "tosa.add"(%68, %1194) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%1196 = "tosa.reshape"(%1195) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%1197 = "tosa.add"(%1196, %1164) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1198 = "tosa.reduce_sum"(%1197) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1199 = "tosa.mul"(%1198, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1200 = "tosa.sub"(%1197, %1199) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1201 = "tosa.mul"(%1200, %1200) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1202 = "tosa.reduce_sum"(%1201) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1203 = "tosa.mul"(%1202, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1204 = "tosa.add"(%1203, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1205 = "tosa.rsqrt"(%1204) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%1206 = "tosa.mul"(%1200, %1205) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1207 = "tosa.mul"(%1206, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1208 = "tosa.add"(%1207, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1209 = "tosa.matmul"(%1208, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%1210 = "tosa.reshape"(%1209) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%1211 = "tosa.add"(%68, %1210) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%1212 = "tosa.reshape"(%1211) {new_shape = [1, 128, 12, -1]} : (tensor<128x768xf32>) -> tensor<1x128x12x64xf32>
%1213 = "tosa.transpose"(%1212, %13) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%1214 = "tosa.reshape"(%1213) {new_shape = [-1, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%1215 = "tosa.transpose"(%1214, %14) : (tensor<12x128x64xf32>, tensor<3xi32>) -> tensor<12x64x128xf32>
%1216 = "tosa.matmul"(%1214, %1215) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%1217 = "tosa.mul"(%1216, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%1218 = "tosa.matmul"(%1214, %95) : (tensor<12x128x64xf32>, tensor<12x64x512xf32>) -> tensor<12x128x512xf32>
%1219 = "tosa.reshape"(%1218) {new_shape = [1, 786432, 1]} : (tensor<12x128x512xf32>) -> tensor<1x786432x1xf32>
%1220 = "tosa.gather"(%1219, %108) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%1221 = "tosa.reshape"(%1220) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%1222 = "tosa.mul"(%1221, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%1223 = "tosa.add"(%1222, %33) : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%1224 = "tosa.gather"(%1219, %124) : (tensor<1x786432x1xf32>, tensor<1x16384xi32>) -> tensor<1x16384x1xf32>
%1225 = "tosa.reshape"(%1224) {new_shape = [12, 128, 128]} : (tensor<1x16384x1xf32>) -> tensor<12x128x128xf32>
%1226 = "tosa.transpose"(%1225, %14) : (tensor<12x128x128xf32>, tensor<3xi32>) -> tensor<12x128x128xf32>
%1227 = "tosa.mul"(%1226, %85) {shift = 0 : i32} : (tensor<12x128x128xf32>, tensor<1x1x1xf32>) -> tensor<12x128x128xf32>
%1228 = "tosa.add"(%1223, %1227) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%1229 = "tosa.add"(%1217, %1228) : (tensor<12x128x128xf32>, tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
%1230 = "tosa.reshape"(%1229) {new_shape = [-1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%1231 = torch_c.from_builtin_tensor %1230 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%1232 = torch.aten.masked_fill.Tensor %1231, %136, %137 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%1233 = torch_c.to_builtin_tensor %1232 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%1234 = "tosa.reduce_max"(%1233) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%1235 = "tosa.sub"(%1233, %1234) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%1236 = "tosa.exp"(%1235) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%1237 = "tosa.reduce_sum"(%1236) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%1238 = "tosa.reciprocal"(%1237) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%1239 = "tosa.mul"(%1236, %1238) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%1240 = torch_c.from_builtin_tensor %1239 : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
%1241 = torch.aten.masked_fill.Scalar %1240, %136, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%1242 = torch_c.to_builtin_tensor %1241 : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
%1243 = "tosa.reshape"(%1242) {new_shape = [-1, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%1244 = "tosa.matmul"(%1243, %1214) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%1245 = "tosa.reshape"(%1244) {new_shape = [-1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%1246 = "tosa.transpose"(%1245, %13) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%1247 = "tosa.reshape"(%1246) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%1248 = "tosa.matmul"(%1247, %75) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%1249 = "tosa.reshape"(%1248) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%1250 = "tosa.add"(%68, %1249) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%1251 = "tosa.reshape"(%1250) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%1252 = "tosa.add"(%1251, %1208) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1253 = "tosa.reduce_sum"(%1252) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1254 = "tosa.mul"(%1253, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1255 = "tosa.sub"(%1252, %1254) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1256 = "tosa.mul"(%1255, %1255) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1257 = "tosa.reduce_sum"(%1256) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1258 = "tosa.mul"(%1257, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1259 = "tosa.add"(%1258, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1260 = "tosa.rsqrt"(%1259) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%1261 = "tosa.mul"(%1255, %1260) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1262 = "tosa.mul"(%1261, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1263 = "tosa.add"(%1262, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1264 = "tosa.matmul"(%1263, %171) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%1265 = "tosa.reshape"(%1264) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%1266 = "tosa.add"(%174, %1265) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%1267 = "tosa.reshape"(%1266) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%1268 = "tosa.sub"(%1267, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1269 = "tosa.mul"(%1268, %31) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1270 = "tosa.abs"(%1269) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1271 = "tosa.mul"(%1270, %30) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1272 = "tosa.add"(%1271, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1273 = "tosa.mul"(%1270, %1270) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1274 = "tosa.mul"(%1273, %29) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1275 = "tosa.add"(%1272, %1274) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1276 = "tosa.mul"(%1273, %1270) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1277 = "tosa.mul"(%1276, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1278 = "tosa.add"(%1275, %1277) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1279 = "tosa.mul"(%1276, %1270) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1280 = "tosa.mul"(%1279, %27) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1281 = "tosa.add"(%1278, %1280) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1282 = "tosa.reciprocal"(%1281) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1283 = "tosa.mul"(%1282, %1282) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1284 = "tosa.mul"(%1283, %1283) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1285 = "tosa.sub"(%32, %1284) : (tensor<1x1x1xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1286 = "tosa.greater_equal"(%1269, %33) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xi1>
%1287 = "tosa.negate"(%1285) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1288 = "tosa.select"(%1286, %1285, %1287) : (tensor<1x128x3072xi1>, tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1289 = "tosa.add"(%1288, %32) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1290 = "tosa.mul"(%1289, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%1291 = "tosa.mul"(%1267, %1290) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%1292 = "tosa.matmul"(%1291, %202) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%1293 = "tosa.reshape"(%1292) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%1294 = "tosa.add"(%68, %1293) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%1295 = "tosa.reshape"(%1294) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%1296 = "tosa.add"(%1295, %1263) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1297 = "tosa.reduce_sum"(%1296) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1298 = "tosa.mul"(%1297, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1299 = "tosa.sub"(%1296, %1298) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1300 = "tosa.mul"(%1299, %1299) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%1301 = "tosa.reduce_sum"(%1300) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%1302 = "tosa.mul"(%1301, %47) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1303 = "tosa.add"(%1302, %37) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%1304 = "tosa.rsqrt"(%1303) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%1305 = "tosa.mul"(%1299, %1304) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%1306 = "tosa.mul"(%1305, %53) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1307 = "tosa.add"(%1306, %53) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%1308 = "tosa.slice"(%1307) {size = [1, 1, 768], start = [0, 0, 0]} : (tensor<1x128x768xf32>) -> tensor<1x1x768xf32>
%1309 = "tosa.matmul"(%1308, %75) : (tensor<1x1x768xf32>, tensor<1x768x768xf32>) -> tensor<1x1x768xf32>
%1310 = "tosa.reshape"(%1309) {new_shape = [1, 768]} : (tensor<1x1x768xf32>) -> tensor<1x768xf32>
%1311 = "tosa.sub"(%1310, %25) : (tensor<1x768xf32>, tensor<1x1xf32>) -> tensor<1x768xf32>
%1312 = "tosa.mul"(%1311, %24) {shift = 0 : i32} : (tensor<1x768xf32>, tensor<1x1xf32>) -> tensor<1x768xf32>
%1313 = "tosa.abs"(%1312) : (tensor<1x768xf32>) -> tensor<1x768xf32>
%1314 = "tosa.mul"(%1313, %23) {shift = 0 : i32} : (tensor<1x768xf32>, tensor<1x1xf32>) -> tensor<1x768xf32>
%1315 = "tosa.add"(%1314, %35) : (tensor<1x768xf32>, tensor<1x1xf32>) -> tensor<1x768xf32>
%1316 = "tosa.mul"(%1313, %1313) {shift = 0 : i32} : (tensor<1x768xf32>, tensor<1x768xf32>) -> tensor<1x768xf32>
%1317 = "tosa.mul"(%1316, %22) {shift = 0 : i32} : (tensor<1x768xf32>, tensor<1x1xf32>) -> tensor<1x768xf32>
%1318 = "tosa.add"(%1315, %1317) : (tensor<1x768xf32>, tensor<1x768xf32>) -> tensor<1x768xf32>
%1319 = "tosa.mul"(%1316, %1313) {shift = 0 : i32} : (tensor<1x768xf32>, tensor<1x768xf32>) -> tensor<1x768xf32>
%1320 = "tosa.mul"(%1319, %21) {shift = 0 : i32} : (tensor<1x768xf32>, tensor<1x1xf32>) -> tensor<1x768xf32>
%1321 = "tosa.add"(%1318, %1320) : (tensor<1x768xf32>, tensor<1x768xf32>) -> tensor<1x768xf32>
%1322 = "tosa.mul"(%1319, %1313) {shift = 0 : i32} : (tensor<1x768xf32>, tensor<1x768xf32>) -> tensor<1x768xf32>
%1323 = "tosa.mul"(%1322, %20) {shift = 0 : i32} : (tensor<1x768xf32>, tensor<1x1xf32>) -> tensor<1x768xf32>
%1324 = "tosa.add"(%1321, %1323) : (tensor<1x768xf32>, tensor<1x768xf32>) -> tensor<1x768xf32>
%1325 = "tosa.reciprocal"(%1324) : (tensor<1x768xf32>) -> tensor<1x768xf32>
%1326 = "tosa.mul"(%1325, %1325) {shift = 0 : i32} : (tensor<1x768xf32>, tensor<1x768xf32>) -> tensor<1x768xf32>
%1327 = "tosa.mul"(%1326, %1326) {shift = 0 : i32} : (tensor<1x768xf32>, tensor<1x768xf32>) -> tensor<1x768xf32>
%1328 = "tosa.sub"(%35, %1327) : (tensor<1x1xf32>, tensor<1x768xf32>) -> tensor<1x768xf32>
%1329 = "tosa.greater_equal"(%1312, %25) : (tensor<1x768xf32>, tensor<1x1xf32>) -> tensor<1x768xi1>
%1330 = "tosa.negate"(%1328) : (tensor<1x768xf32>) -> tensor<1x768xf32>
%1331 = "tosa.select"(%1329, %1328, %1330) : (tensor<1x768xi1>, tensor<1x768xf32>, tensor<1x768xf32>) -> tensor<1x768xf32>
%1332 = "tosa.add"(%1331, %35) : (tensor<1x768xf32>, tensor<1x1xf32>) -> tensor<1x768xf32>
%1333 = "tosa.mul"(%1332, %19) {shift = 0 : i32} : (tensor<1x768xf32>, tensor<1x1xf32>) -> tensor<1x768xf32>
%1334 = "tosa.mul"(%1310, %1333) {shift = 0 : i32} : (tensor<1x768xf32>, tensor<1x768xf32>) -> tensor<1x768xf32>
%1335 = "tosa.transpose"(%1, %12) : (tensor<2x768xf32>, tensor<2xi32>) -> tensor<768x2xf32>
%1336 = "tosa.reshape"(%1334) {new_shape = [1, 1, 768]} : (tensor<1x768xf32>) -> tensor<1x1x768xf32>
%1337 = "tosa.reshape"(%1335) {new_shape = [1, 768, 2]} : (tensor<768x2xf32>) -> tensor<1x768x2xf32>
%1338 = "tosa.matmul"(%1336, %1337) : (tensor<1x1x768xf32>, tensor<1x768x2xf32>) -> tensor<1x1x2xf32>
%1339 = "tosa.reshape"(%1338) {new_shape = [1, 2]} : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
%1340 = torch_c.from_builtin_tensor %1339 : tensor<1x2xf32> -> !torch.vtensor<[1,2],f32>
%1341 = torch_c.to_builtin_tensor %1340 : !torch.vtensor<[1,2],f32> -> tensor<1x2xf32>
return %1341 : tensor<1x2xf32>
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment