Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save AmosLewis/f2d8fb5d121d38a18577aea5234b8d45 to your computer and use it in GitHub Desktop.
Save AmosLewis/f2d8fb5d121d38a18577aea5234b8d45 to your computer and use it in GitHub Desktop.
func.func @forward(%arg0: tensor<1x128xi64>) -> tensor<1x2xf32> {
%0 = "tosa.const"() {value = dense_resource<__elided__> : tensor<2x768xf32>} : () -> tensor<2x768xf32>
%1 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768xf32>} : () -> tensor<768xf32>
%2 = "tosa.const"() {value = dense_resource<__elided__> : tensor<3072x768xf32>} : () -> tensor<3072x768xf32>
%3 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768x3072xf32>} : () -> tensor<768x3072xf32>
%4 = "tosa.const"() {value = dense_resource<__elided__> : tensor<3072xf32>} : () -> tensor<3072xf32>
%5 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768x768xf32>} : () -> tensor<768x768xf32>
%6 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768x2304xf32>} : () -> tensor<768x2304xf32>
%7 = "tosa.const"() {value = dense_resource<__elided__> : tensor<2304xf32>} : () -> tensor<2304xf32>
%8 = "tosa.const"() {value = dense<-3.40282347E+38> : tensor<f32>} : () -> tensor<f32>
%9 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x1x1024x1024xui8>} : () -> tensor<1x1x1024x1024xi8>
%10 = "tosa.const"() {value = dense<8.000000e+00> : tensor<f32>} : () -> tensor<f32>
%11 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1024x768xf32>} : () -> tensor<1024x768xf32>
%12 = "tosa.const"() {value = dense_resource<__elided__> : tensor<50257x768xf32>} : () -> tensor<50257x768xf32>
%13 = "tosa.const"() {value = dense<7.680000e+02> : tensor<1xf32>} : () -> tensor<1xf32>
%14 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x98304xi32>} : () -> tensor<1x98304xi32>
%15 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x98304xi32>} : () -> tensor<1x98304xi32>
%16 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x98304xi32>} : () -> tensor<1x98304xi32>
%17 = "tosa.const"() {value = dense<[0, 2, 1, 3]> : tensor<4xi64>} : () -> tensor<4xi64>
%18 = "tosa.const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%19 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x1048576xi32>} : () -> tensor<1x1048576xi32>
%20 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x131072xi32>} : () -> tensor<1x131072xi32>
%21 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x16384xi32>} : () -> tensor<1x16384xi32>
%22 = "tosa.const"() {value = dense<0> : tensor<1x1x128x128xi8>} : () -> tensor<1x1x128x128xi8>
%23 = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
%24 = "tosa.const"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64>
%25 = "tosa.const"() {value = dense<0.797884583> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%26 = "tosa.const"() {value = dense<4.471500e-02> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%27 = "tosa.const"() {value = dense<3.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%28 = "tosa.const"() {value = dense<5.000000e-01> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%29 = "tosa.const"() {value = dense<9.99999974E-6> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%30 = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%31 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x128xi64>} : () -> tensor<1x128xi64>
%32 = torch_c.from_builtin_tensor %arg0 : tensor<1x128xi64> -> !torch.vtensor<[1,128],si64>
%33 = torch_c.to_builtin_tensor %32 : !torch.vtensor<[1,128],si64> -> tensor<1x128xi64>
%34 = "tosa.reshape"(%12) {new_shape = [1, 50257, 768]} : (tensor<50257x768xf32>) -> tensor<1x50257x768xf32>
%35 = "tosa.cast"(%33) : (tensor<1x128xi64>) -> tensor<1x128xi32>
%36 = "tosa.gather"(%34, %35) : (tensor<1x50257x768xf32>, tensor<1x128xi32>) -> tensor<1x128x768xf32>
%37 = "tosa.reshape"(%11) {new_shape = [1, 1024, 768]} : (tensor<1024x768xf32>) -> tensor<1x1024x768xf32>
%38 = "tosa.cast"(%31) : (tensor<1x128xi64>) -> tensor<1x128xi32>
%39 = "tosa.gather"(%37, %38) : (tensor<1x1024x768xf32>, tensor<1x128xi32>) -> tensor<1x128x768xf32>
%40 = "tosa.add"(%36, %39) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%41 = "tosa.reciprocal"(%13) : (tensor<1xf32>) -> tensor<1xf32>
%42 = "tosa.reduce_sum"(%40) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%43 = "tosa.reshape"(%41) {new_shape = [1, 1, 1]} : (tensor<1xf32>) -> tensor<1x1x1xf32>
%44 = "tosa.mul"(%42, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%45 = "tosa.sub"(%40, %44) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%46 = "tosa.mul"(%45, %45) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%47 = "tosa.reduce_sum"(%46) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%48 = "tosa.mul"(%47, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%49 = "tosa.reshape"(%1) {new_shape = [1, 1, 768]} : (tensor<768xf32>) -> tensor<1x1x768xf32>
%50 = "tosa.add"(%48, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%51 = "tosa.rsqrt"(%50) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%52 = "tosa.mul"(%45, %51) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%53 = "tosa.mul"(%52, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%54 = "tosa.add"(%53, %49) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%55 = "tosa.reshape"(%6) {new_shape = [1, 768, 2304]} : (tensor<768x2304xf32>) -> tensor<1x768x2304xf32>
%56 = "tosa.matmul"(%54, %55) : (tensor<1x128x768xf32>, tensor<1x768x2304xf32>) -> tensor<1x128x2304xf32>
%57 = "tosa.reshape"(%56) {new_shape = [128, 2304]} : (tensor<1x128x2304xf32>) -> tensor<128x2304xf32>
%58 = "tosa.reshape"(%7) {new_shape = [1, 2304]} : (tensor<2304xf32>) -> tensor<1x2304xf32>
%59 = "tosa.add"(%58, %57) : (tensor<1x2304xf32>, tensor<128x2304xf32>) -> tensor<128x2304xf32>
%60 = "tosa.reshape"(%59) {new_shape = [1, 294912, 1]} : (tensor<128x2304xf32>) -> tensor<1x294912x1xf32>
%61 = "tosa.gather"(%60, %14) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%62 = "tosa.gather"(%60, %15) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%63 = "tosa.gather"(%60, %16) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%64 = "tosa.reshape"(%61) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%65 = "tosa.transpose"(%64, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%66 = "tosa.reshape"(%62) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%67 = "tosa.transpose"(%66, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%68 = "tosa.reshape"(%63) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%69 = "tosa.transpose"(%68, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%70 = "tosa.transpose"(%67, %18) : (tensor<1x12x128x64xf32>, tensor<4xi32>) -> tensor<1x12x64x128xf32>
%71 = "tosa.reshape"(%65) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%72 = "tosa.reshape"(%70) {new_shape = [12, 64, 128]} : (tensor<1x12x64x128xf32>) -> tensor<12x64x128xf32>
%73 = "tosa.matmul"(%71, %72) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%74 = "tosa.reshape"(%73) {new_shape = [1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%75 = "tosa.reciprocal"(%10) : (tensor<f32>) -> tensor<f32>
%76 = "tosa.reshape"(%75) {new_shape = [1, 1, 1, 1]} : (tensor<f32>) -> tensor<1x1x1x1xf32>
%77 = "tosa.mul"(%74, %76) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x128x128xf32>
%78 = "tosa.reshape"(%9) {new_shape = [1, 1048576, 1]} : (tensor<1x1x1024x1024xi8>) -> tensor<1x1048576x1xi8>
%79 = "tosa.gather"(%78, %19) : (tensor<1x1048576x1xi8>, tensor<1x1048576xi32>) -> tensor<1x1048576x1xi8>
%80 = "tosa.gather"(%79, %19) : (tensor<1x1048576x1xi8>, tensor<1x1048576xi32>) -> tensor<1x1048576x1xi8>
%81 = "tosa.gather"(%80, %20) : (tensor<1x1048576x1xi8>, tensor<1x131072xi32>) -> tensor<1x131072x1xi8>
%82 = "tosa.gather"(%81, %21) : (tensor<1x131072x1xi8>, tensor<1x16384xi32>) -> tensor<1x16384x1xi8>
%83 = "tosa.reshape"(%82) {new_shape = [1, 1, 128, 128]} : (tensor<1x16384x1xi8>) -> tensor<1x1x128x128xi8>
%84 = "tosa.equal"(%83, %22) : (tensor<1x1x128x128xi8>, tensor<1x1x128x128xi8>) -> tensor<1x1x128x128xi1>
%85 = "tosa.select"(%84, %8, %77) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%86 = "tosa.reduce_max"(%85) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%87 = "tosa.sub"(%85, %86) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%88 = "tosa.exp"(%87) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%89 = "tosa.reduce_sum"(%88) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%90 = "tosa.reciprocal"(%89) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%91 = "tosa.mul"(%88, %90) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%92 = "tosa.reshape"(%91) {new_shape = [12, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%93 = "tosa.reshape"(%69) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%94 = "tosa.matmul"(%92, %93) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%95 = "tosa.reshape"(%94) {new_shape = [1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%96 = "tosa.transpose"(%95, %17) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%97 = "tosa.reshape"(%96) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%98 = "tosa.reshape"(%5) {new_shape = [1, 768, 768]} : (tensor<768x768xf32>) -> tensor<1x768x768xf32>
%99 = "tosa.matmul"(%97, %98) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%100 = "tosa.reshape"(%99) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%101 = "tosa.reshape"(%1) {new_shape = [1, 768]} : (tensor<768xf32>) -> tensor<1x768xf32>
%102 = "tosa.add"(%101, %100) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%103 = "tosa.reshape"(%102) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%104 = "tosa.add"(%103, %40) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%105 = "tosa.reduce_sum"(%104) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%106 = "tosa.mul"(%105, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%107 = "tosa.sub"(%104, %106) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%108 = "tosa.mul"(%107, %107) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%109 = "tosa.reduce_sum"(%108) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%110 = "tosa.mul"(%109, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%111 = "tosa.add"(%110, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%112 = "tosa.rsqrt"(%111) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%113 = "tosa.mul"(%107, %112) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%114 = "tosa.mul"(%113, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%115 = "tosa.add"(%114, %49) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%116 = "tosa.reshape"(%3) {new_shape = [1, 768, 3072]} : (tensor<768x3072xf32>) -> tensor<1x768x3072xf32>
%117 = "tosa.matmul"(%115, %116) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%118 = "tosa.reshape"(%117) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%119 = "tosa.reshape"(%4) {new_shape = [1, 3072]} : (tensor<3072xf32>) -> tensor<1x3072xf32>
%120 = "tosa.add"(%119, %118) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%121 = "tosa.reshape"(%120) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%122 = "tosa.mul"(%121, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%123 = "tosa.pow"(%121, %27) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%124 = "tosa.mul"(%123, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%125 = "tosa.add"(%121, %124) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%126 = "tosa.mul"(%125, %25) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%127 = "tosa.tanh"(%126) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%128 = "tosa.add"(%127, %30) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%129 = "tosa.mul"(%122, %128) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%130 = "tosa.reshape"(%2) {new_shape = [1, 3072, 768]} : (tensor<3072x768xf32>) -> tensor<1x3072x768xf32>
%131 = "tosa.matmul"(%129, %130) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%132 = "tosa.reshape"(%131) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%133 = "tosa.add"(%101, %132) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%134 = "tosa.reshape"(%133) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%135 = "tosa.add"(%104, %134) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%136 = "tosa.reduce_sum"(%135) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%137 = "tosa.mul"(%136, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%138 = "tosa.sub"(%135, %137) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%139 = "tosa.mul"(%138, %138) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%140 = "tosa.reduce_sum"(%139) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%141 = "tosa.mul"(%140, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%142 = "tosa.add"(%141, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%143 = "tosa.rsqrt"(%142) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%144 = "tosa.mul"(%138, %143) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%145 = "tosa.mul"(%144, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%146 = "tosa.add"(%145, %49) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%147 = "tosa.matmul"(%146, %55) : (tensor<1x128x768xf32>, tensor<1x768x2304xf32>) -> tensor<1x128x2304xf32>
%148 = "tosa.reshape"(%147) {new_shape = [128, 2304]} : (tensor<1x128x2304xf32>) -> tensor<128x2304xf32>
%149 = "tosa.add"(%58, %148) : (tensor<1x2304xf32>, tensor<128x2304xf32>) -> tensor<128x2304xf32>
%150 = "tosa.reshape"(%149) {new_shape = [1, 294912, 1]} : (tensor<128x2304xf32>) -> tensor<1x294912x1xf32>
%151 = "tosa.gather"(%150, %14) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%152 = "tosa.gather"(%150, %15) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%153 = "tosa.gather"(%150, %16) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%154 = "tosa.reshape"(%151) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%155 = "tosa.transpose"(%154, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%156 = "tosa.reshape"(%152) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%157 = "tosa.transpose"(%156, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%158 = "tosa.reshape"(%153) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%159 = "tosa.transpose"(%158, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%160 = "tosa.transpose"(%157, %18) : (tensor<1x12x128x64xf32>, tensor<4xi32>) -> tensor<1x12x64x128xf32>
%161 = "tosa.reshape"(%155) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%162 = "tosa.reshape"(%160) {new_shape = [12, 64, 128]} : (tensor<1x12x64x128xf32>) -> tensor<12x64x128xf32>
%163 = "tosa.matmul"(%161, %162) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%164 = "tosa.reshape"(%163) {new_shape = [1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%165 = "tosa.mul"(%164, %76) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x128x128xf32>
%166 = "tosa.select"(%84, %8, %165) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%167 = "tosa.reduce_max"(%166) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%168 = "tosa.sub"(%166, %167) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%169 = "tosa.exp"(%168) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%170 = "tosa.reduce_sum"(%169) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%171 = "tosa.reciprocal"(%170) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%172 = "tosa.mul"(%169, %171) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%173 = "tosa.reshape"(%172) {new_shape = [12, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%174 = "tosa.reshape"(%159) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%175 = "tosa.matmul"(%173, %174) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%176 = "tosa.reshape"(%175) {new_shape = [1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%177 = "tosa.transpose"(%176, %17) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%178 = "tosa.reshape"(%177) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%179 = "tosa.matmul"(%178, %98) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%180 = "tosa.reshape"(%179) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%181 = "tosa.add"(%101, %180) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%182 = "tosa.reshape"(%181) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%183 = "tosa.add"(%182, %135) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%184 = "tosa.reduce_sum"(%183) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%185 = "tosa.mul"(%184, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%186 = "tosa.sub"(%183, %185) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%187 = "tosa.mul"(%186, %186) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%188 = "tosa.reduce_sum"(%187) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%189 = "tosa.mul"(%188, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%190 = "tosa.add"(%189, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%191 = "tosa.rsqrt"(%190) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%192 = "tosa.mul"(%186, %191) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%193 = "tosa.mul"(%192, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%194 = "tosa.add"(%193, %49) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%195 = "tosa.matmul"(%194, %116) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%196 = "tosa.reshape"(%195) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%197 = "tosa.add"(%119, %196) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%198 = "tosa.reshape"(%197) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%199 = "tosa.mul"(%198, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%200 = "tosa.pow"(%198, %27) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%201 = "tosa.mul"(%200, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%202 = "tosa.add"(%198, %201) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%203 = "tosa.mul"(%202, %25) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%204 = "tosa.tanh"(%203) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%205 = "tosa.add"(%204, %30) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%206 = "tosa.mul"(%199, %205) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%207 = "tosa.matmul"(%206, %130) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%208 = "tosa.reshape"(%207) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%209 = "tosa.add"(%101, %208) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%210 = "tosa.reshape"(%209) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%211 = "tosa.add"(%183, %210) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%212 = "tosa.reduce_sum"(%211) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%213 = "tosa.mul"(%212, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%214 = "tosa.sub"(%211, %213) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%215 = "tosa.mul"(%214, %214) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%216 = "tosa.reduce_sum"(%215) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%217 = "tosa.mul"(%216, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%218 = "tosa.add"(%217, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%219 = "tosa.rsqrt"(%218) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%220 = "tosa.mul"(%214, %219) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%221 = "tosa.mul"(%220, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%222 = "tosa.add"(%221, %49) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%223 = "tosa.matmul"(%222, %55) : (tensor<1x128x768xf32>, tensor<1x768x2304xf32>) -> tensor<1x128x2304xf32>
%224 = "tosa.reshape"(%223) {new_shape = [128, 2304]} : (tensor<1x128x2304xf32>) -> tensor<128x2304xf32>
%225 = "tosa.add"(%58, %224) : (tensor<1x2304xf32>, tensor<128x2304xf32>) -> tensor<128x2304xf32>
%226 = "tosa.reshape"(%225) {new_shape = [1, 294912, 1]} : (tensor<128x2304xf32>) -> tensor<1x294912x1xf32>
%227 = "tosa.gather"(%226, %14) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%228 = "tosa.gather"(%226, %15) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%229 = "tosa.gather"(%226, %16) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%230 = "tosa.reshape"(%227) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%231 = "tosa.transpose"(%230, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%232 = "tosa.reshape"(%228) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%233 = "tosa.transpose"(%232, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%234 = "tosa.reshape"(%229) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%235 = "tosa.transpose"(%234, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%236 = "tosa.transpose"(%233, %18) : (tensor<1x12x128x64xf32>, tensor<4xi32>) -> tensor<1x12x64x128xf32>
%237 = "tosa.reshape"(%231) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%238 = "tosa.reshape"(%236) {new_shape = [12, 64, 128]} : (tensor<1x12x64x128xf32>) -> tensor<12x64x128xf32>
%239 = "tosa.matmul"(%237, %238) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%240 = "tosa.reshape"(%239) {new_shape = [1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%241 = "tosa.mul"(%240, %76) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x128x128xf32>
%242 = "tosa.select"(%84, %8, %241) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%243 = "tosa.reduce_max"(%242) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%244 = "tosa.sub"(%242, %243) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%245 = "tosa.exp"(%244) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%246 = "tosa.reduce_sum"(%245) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%247 = "tosa.reciprocal"(%246) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%248 = "tosa.mul"(%245, %247) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%249 = "tosa.reshape"(%248) {new_shape = [12, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%250 = "tosa.reshape"(%235) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%251 = "tosa.matmul"(%249, %250) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%252 = "tosa.reshape"(%251) {new_shape = [1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%253 = "tosa.transpose"(%252, %17) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%254 = "tosa.reshape"(%253) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%255 = "tosa.matmul"(%254, %98) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%256 = "tosa.reshape"(%255) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%257 = "tosa.add"(%101, %256) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%258 = "tosa.reshape"(%257) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%259 = "tosa.add"(%258, %211) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%260 = "tosa.reduce_sum"(%259) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%261 = "tosa.mul"(%260, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%262 = "tosa.sub"(%259, %261) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%263 = "tosa.mul"(%262, %262) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%264 = "tosa.reduce_sum"(%263) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%265 = "tosa.mul"(%264, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%266 = "tosa.add"(%265, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%267 = "tosa.rsqrt"(%266) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%268 = "tosa.mul"(%262, %267) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%269 = "tosa.mul"(%268, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%270 = "tosa.add"(%269, %49) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%271 = "tosa.matmul"(%270, %116) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%272 = "tosa.reshape"(%271) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%273 = "tosa.add"(%119, %272) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%274 = "tosa.reshape"(%273) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%275 = "tosa.mul"(%274, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%276 = "tosa.pow"(%274, %27) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%277 = "tosa.mul"(%276, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%278 = "tosa.add"(%274, %277) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%279 = "tosa.mul"(%278, %25) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%280 = "tosa.tanh"(%279) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%281 = "tosa.add"(%280, %30) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%282 = "tosa.mul"(%275, %281) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%283 = "tosa.matmul"(%282, %130) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%284 = "tosa.reshape"(%283) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%285 = "tosa.add"(%101, %284) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%286 = "tosa.reshape"(%285) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%287 = "tosa.add"(%259, %286) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%288 = "tosa.reduce_sum"(%287) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%289 = "tosa.mul"(%288, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%290 = "tosa.sub"(%287, %289) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%291 = "tosa.mul"(%290, %290) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%292 = "tosa.reduce_sum"(%291) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%293 = "tosa.mul"(%292, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%294 = "tosa.add"(%293, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%295 = "tosa.rsqrt"(%294) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%296 = "tosa.mul"(%290, %295) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%297 = "tosa.mul"(%296, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%298 = "tosa.add"(%297, %49) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%299 = "tosa.matmul"(%298, %55) : (tensor<1x128x768xf32>, tensor<1x768x2304xf32>) -> tensor<1x128x2304xf32>
%300 = "tosa.reshape"(%299) {new_shape = [128, 2304]} : (tensor<1x128x2304xf32>) -> tensor<128x2304xf32>
%301 = "tosa.add"(%58, %300) : (tensor<1x2304xf32>, tensor<128x2304xf32>) -> tensor<128x2304xf32>
%302 = "tosa.reshape"(%301) {new_shape = [1, 294912, 1]} : (tensor<128x2304xf32>) -> tensor<1x294912x1xf32>
%303 = "tosa.gather"(%302, %14) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%304 = "tosa.gather"(%302, %15) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%305 = "tosa.gather"(%302, %16) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%306 = "tosa.reshape"(%303) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%307 = "tosa.transpose"(%306, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%308 = "tosa.reshape"(%304) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%309 = "tosa.transpose"(%308, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%310 = "tosa.reshape"(%305) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%311 = "tosa.transpose"(%310, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%312 = "tosa.transpose"(%309, %18) : (tensor<1x12x128x64xf32>, tensor<4xi32>) -> tensor<1x12x64x128xf32>
%313 = "tosa.reshape"(%307) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%314 = "tosa.reshape"(%312) {new_shape = [12, 64, 128]} : (tensor<1x12x64x128xf32>) -> tensor<12x64x128xf32>
%315 = "tosa.matmul"(%313, %314) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%316 = "tosa.reshape"(%315) {new_shape = [1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%317 = "tosa.mul"(%316, %76) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x128x128xf32>
%318 = "tosa.select"(%84, %8, %317) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%319 = "tosa.reduce_max"(%318) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%320 = "tosa.sub"(%318, %319) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%321 = "tosa.exp"(%320) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%322 = "tosa.reduce_sum"(%321) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%323 = "tosa.reciprocal"(%322) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%324 = "tosa.mul"(%321, %323) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%325 = "tosa.reshape"(%324) {new_shape = [12, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%326 = "tosa.reshape"(%311) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%327 = "tosa.matmul"(%325, %326) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%328 = "tosa.reshape"(%327) {new_shape = [1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%329 = "tosa.transpose"(%328, %17) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%330 = "tosa.reshape"(%329) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%331 = "tosa.matmul"(%330, %98) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%332 = "tosa.reshape"(%331) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%333 = "tosa.add"(%101, %332) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%334 = "tosa.reshape"(%333) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%335 = "tosa.add"(%334, %287) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%336 = "tosa.reduce_sum"(%335) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%337 = "tosa.mul"(%336, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%338 = "tosa.sub"(%335, %337) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%339 = "tosa.mul"(%338, %338) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%340 = "tosa.reduce_sum"(%339) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%341 = "tosa.mul"(%340, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%342 = "tosa.add"(%341, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%343 = "tosa.rsqrt"(%342) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%344 = "tosa.mul"(%338, %343) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%345 = "tosa.mul"(%344, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%346 = "tosa.add"(%345, %49) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%347 = "tosa.matmul"(%346, %116) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%348 = "tosa.reshape"(%347) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%349 = "tosa.add"(%119, %348) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%350 = "tosa.reshape"(%349) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%351 = "tosa.mul"(%350, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%352 = "tosa.pow"(%350, %27) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%353 = "tosa.mul"(%352, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%354 = "tosa.add"(%350, %353) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%355 = "tosa.mul"(%354, %25) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%356 = "tosa.tanh"(%355) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%357 = "tosa.add"(%356, %30) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%358 = "tosa.mul"(%351, %357) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%359 = "tosa.matmul"(%358, %130) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%360 = "tosa.reshape"(%359) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%361 = "tosa.add"(%101, %360) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%362 = "tosa.reshape"(%361) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%363 = "tosa.add"(%335, %362) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%364 = "tosa.reduce_sum"(%363) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%365 = "tosa.mul"(%364, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%366 = "tosa.sub"(%363, %365) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%367 = "tosa.mul"(%366, %366) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%368 = "tosa.reduce_sum"(%367) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%369 = "tosa.mul"(%368, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%370 = "tosa.add"(%369, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%371 = "tosa.rsqrt"(%370) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%372 = "tosa.mul"(%366, %371) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%373 = "tosa.mul"(%372, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%374 = "tosa.add"(%373, %49) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%375 = "tosa.matmul"(%374, %55) : (tensor<1x128x768xf32>, tensor<1x768x2304xf32>) -> tensor<1x128x2304xf32>
%376 = "tosa.reshape"(%375) {new_shape = [128, 2304]} : (tensor<1x128x2304xf32>) -> tensor<128x2304xf32>
%377 = "tosa.add"(%58, %376) : (tensor<1x2304xf32>, tensor<128x2304xf32>) -> tensor<128x2304xf32>
%378 = "tosa.reshape"(%377) {new_shape = [1, 294912, 1]} : (tensor<128x2304xf32>) -> tensor<1x294912x1xf32>
%379 = "tosa.gather"(%378, %14) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%380 = "tosa.gather"(%378, %15) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%381 = "tosa.gather"(%378, %16) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%382 = "tosa.reshape"(%379) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%383 = "tosa.transpose"(%382, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%384 = "tosa.reshape"(%380) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%385 = "tosa.transpose"(%384, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%386 = "tosa.reshape"(%381) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%387 = "tosa.transpose"(%386, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%388 = "tosa.transpose"(%385, %18) : (tensor<1x12x128x64xf32>, tensor<4xi32>) -> tensor<1x12x64x128xf32>
%389 = "tosa.reshape"(%383) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%390 = "tosa.reshape"(%388) {new_shape = [12, 64, 128]} : (tensor<1x12x64x128xf32>) -> tensor<12x64x128xf32>
%391 = "tosa.matmul"(%389, %390) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%392 = "tosa.reshape"(%391) {new_shape = [1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%393 = "tosa.mul"(%392, %76) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x128x128xf32>
%394 = "tosa.select"(%84, %8, %393) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%395 = "tosa.reduce_max"(%394) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%396 = "tosa.sub"(%394, %395) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%397 = "tosa.exp"(%396) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%398 = "tosa.reduce_sum"(%397) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%399 = "tosa.reciprocal"(%398) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%400 = "tosa.mul"(%397, %399) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%401 = "tosa.reshape"(%400) {new_shape = [12, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%402 = "tosa.reshape"(%387) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%403 = "tosa.matmul"(%401, %402) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%404 = "tosa.reshape"(%403) {new_shape = [1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%405 = "tosa.transpose"(%404, %17) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%406 = "tosa.reshape"(%405) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%407 = "tosa.matmul"(%406, %98) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%408 = "tosa.reshape"(%407) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%409 = "tosa.add"(%101, %408) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%410 = "tosa.reshape"(%409) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%411 = "tosa.add"(%410, %363) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%412 = "tosa.reduce_sum"(%411) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%413 = "tosa.mul"(%412, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%414 = "tosa.sub"(%411, %413) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%415 = "tosa.mul"(%414, %414) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%416 = "tosa.reduce_sum"(%415) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%417 = "tosa.mul"(%416, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%418 = "tosa.add"(%417, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%419 = "tosa.rsqrt"(%418) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%420 = "tosa.mul"(%414, %419) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%421 = "tosa.mul"(%420, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%422 = "tosa.add"(%421, %49) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%423 = "tosa.matmul"(%422, %116) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%424 = "tosa.reshape"(%423) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%425 = "tosa.add"(%119, %424) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%426 = "tosa.reshape"(%425) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%427 = "tosa.mul"(%426, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%428 = "tosa.pow"(%426, %27) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%429 = "tosa.mul"(%428, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%430 = "tosa.add"(%426, %429) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%431 = "tosa.mul"(%430, %25) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%432 = "tosa.tanh"(%431) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%433 = "tosa.add"(%432, %30) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%434 = "tosa.mul"(%427, %433) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%435 = "tosa.matmul"(%434, %130) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%436 = "tosa.reshape"(%435) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%437 = "tosa.add"(%101, %436) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%438 = "tosa.reshape"(%437) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%439 = "tosa.add"(%411, %438) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%440 = "tosa.reduce_sum"(%439) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%441 = "tosa.mul"(%440, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%442 = "tosa.sub"(%439, %441) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%443 = "tosa.mul"(%442, %442) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%444 = "tosa.reduce_sum"(%443) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%445 = "tosa.mul"(%444, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%446 = "tosa.add"(%445, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%447 = "tosa.rsqrt"(%446) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%448 = "tosa.mul"(%442, %447) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%449 = "tosa.mul"(%448, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%450 = "tosa.add"(%449, %49) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%451 = "tosa.matmul"(%450, %55) : (tensor<1x128x768xf32>, tensor<1x768x2304xf32>) -> tensor<1x128x2304xf32>
%452 = "tosa.reshape"(%451) {new_shape = [128, 2304]} : (tensor<1x128x2304xf32>) -> tensor<128x2304xf32>
%453 = "tosa.add"(%58, %452) : (tensor<1x2304xf32>, tensor<128x2304xf32>) -> tensor<128x2304xf32>
%454 = "tosa.reshape"(%453) {new_shape = [1, 294912, 1]} : (tensor<128x2304xf32>) -> tensor<1x294912x1xf32>
%455 = "tosa.gather"(%454, %14) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%456 = "tosa.gather"(%454, %15) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%457 = "tosa.gather"(%454, %16) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%458 = "tosa.reshape"(%455) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%459 = "tosa.transpose"(%458, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%460 = "tosa.reshape"(%456) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%461 = "tosa.transpose"(%460, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%462 = "tosa.reshape"(%457) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%463 = "tosa.transpose"(%462, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%464 = "tosa.transpose"(%461, %18) : (tensor<1x12x128x64xf32>, tensor<4xi32>) -> tensor<1x12x64x128xf32>
%465 = "tosa.reshape"(%459) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%466 = "tosa.reshape"(%464) {new_shape = [12, 64, 128]} : (tensor<1x12x64x128xf32>) -> tensor<12x64x128xf32>
%467 = "tosa.matmul"(%465, %466) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%468 = "tosa.reshape"(%467) {new_shape = [1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%469 = "tosa.mul"(%468, %76) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x128x128xf32>
%470 = "tosa.select"(%84, %8, %469) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%471 = "tosa.reduce_max"(%470) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%472 = "tosa.sub"(%470, %471) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%473 = "tosa.exp"(%472) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%474 = "tosa.reduce_sum"(%473) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%475 = "tosa.reciprocal"(%474) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%476 = "tosa.mul"(%473, %475) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%477 = "tosa.reshape"(%476) {new_shape = [12, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%478 = "tosa.reshape"(%463) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%479 = "tosa.matmul"(%477, %478) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%480 = "tosa.reshape"(%479) {new_shape = [1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%481 = "tosa.transpose"(%480, %17) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%482 = "tosa.reshape"(%481) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%483 = "tosa.matmul"(%482, %98) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%484 = "tosa.reshape"(%483) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%485 = "tosa.add"(%101, %484) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%486 = "tosa.reshape"(%485) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%487 = "tosa.add"(%486, %439) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%488 = "tosa.reduce_sum"(%487) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%489 = "tosa.mul"(%488, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%490 = "tosa.sub"(%487, %489) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%491 = "tosa.mul"(%490, %490) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%492 = "tosa.reduce_sum"(%491) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%493 = "tosa.mul"(%492, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%494 = "tosa.add"(%493, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%495 = "tosa.rsqrt"(%494) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%496 = "tosa.mul"(%490, %495) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%497 = "tosa.mul"(%496, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%498 = "tosa.add"(%497, %49) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%499 = "tosa.matmul"(%498, %116) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%500 = "tosa.reshape"(%499) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%501 = "tosa.add"(%119, %500) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%502 = "tosa.reshape"(%501) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%503 = "tosa.mul"(%502, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%504 = "tosa.pow"(%502, %27) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%505 = "tosa.mul"(%504, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%506 = "tosa.add"(%502, %505) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%507 = "tosa.mul"(%506, %25) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%508 = "tosa.tanh"(%507) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%509 = "tosa.add"(%508, %30) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%510 = "tosa.mul"(%503, %509) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%511 = "tosa.matmul"(%510, %130) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%512 = "tosa.reshape"(%511) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%513 = "tosa.add"(%101, %512) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%514 = "tosa.reshape"(%513) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%515 = "tosa.add"(%487, %514) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%516 = "tosa.reduce_sum"(%515) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%517 = "tosa.mul"(%516, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%518 = "tosa.sub"(%515, %517) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%519 = "tosa.mul"(%518, %518) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%520 = "tosa.reduce_sum"(%519) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%521 = "tosa.mul"(%520, %43) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%522 = "tosa.add"(%521, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%523 = "tosa.rsqrt"(%522) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%524 = "tosa.mul"(%518, %523) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%525 = "tosa.mul"(%524, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%526 = "tosa.add"(%525, %49) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%527 = "tosa.transpose"(%0, %23) : (tensor<2x768xf32>, tensor<2xi32>) -> tensor<768x2xf32>
%528 = "tosa.reshape"(%527) {new_shape = [1, 768, 2]} : (tensor<768x2xf32>) -> tensor<1x768x2xf32>
%529 = "tosa.matmul"(%526, %528) : (tensor<1x128x768xf32>, tensor<1x768x2xf32>) -> tensor<1x128x2xf32>
%530 = torch_c.from_builtin_tensor %24 : tensor<1xi64> -> !torch.vtensor<[1],si64>
%531 = "tosa.slice"(%529) {size = [1, 1, 2], start = [0, 127, 0]} : (tensor<1x128x2xf32>) -> tensor<1x1x2xf32>
%532 = "tosa.reshape"(%531) {new_shape = [1, 2]} : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
%533 = torch_c.from_builtin_tensor %532 : tensor<1x2xf32> -> !torch.vtensor<[1,2],f32>
%534 = torch.prim.ListConstruct %530 : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%535 = torch.aten.index.Tensor %533, %534 : !torch.vtensor<[1,2],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,2],f32>
%536 = torch_c.to_builtin_tensor %535 : !torch.vtensor<[1,2],f32> -> tensor<1x2xf32>
return %536 : tensor<1x2xf32>
}
@AmosLewis
Copy link
Author

//===-------------------------------------------===//
Legalizing operation : 'torch.prim.ListConstruct'(0x93a7830) {
  %534 = "torch.prim.ListConstruct"(%530) : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
<eval_with_key>.2:544:12: error: failed to legalize operation 'torch.prim.ListConstruct'
<eval_with_key>.2:544:12: note: see current operation: %534 = "torch.prim.ListConstruct"(%530) : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>
// -----// IR Dump After FinalizingBackendTypeConversion Failed (torch-finalizing-backend-type-conversion) //----- //
mlir-asm-printer: Verifying operation: func.func

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment