Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created January 3, 2023 21:48
Show Gist options
Save AmosLewis/9bea591a8123bf86f3cfdca6d5189db0 to your computer and use it in GitHub Desktop.
module attributes {torch.debug_module_name = "_lambda"} {
func.func @forward(%arg0: tensor<1x128xi64>) -> tensor<1x2xf32> {
%0 = "tosa.const"() {value = dense_resource<__elided__> : tensor<2x768xf32>} : () -> tensor<2x768xf32>
%1 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768xf32>} : () -> tensor<768xf32>
%2 = "tosa.const"() {value = dense_resource<__elided__> : tensor<3072x768xf32>} : () -> tensor<3072x768xf32>
%3 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768x3072xf32>} : () -> tensor<768x3072xf32>
%4 = "tosa.const"() {value = dense_resource<__elided__> : tensor<3072xf32>} : () -> tensor<3072xf32>
%5 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768x768xf32>} : () -> tensor<768x768xf32>
%6 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768x2304xf32>} : () -> tensor<768x2304xf32>
%7 = "tosa.const"() {value = dense_resource<__elided__> : tensor<2304xf32>} : () -> tensor<2304xf32>
%8 = "tosa.const"() {value = dense<-3.40282347E+38> : tensor<f32>} : () -> tensor<f32>
%9 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x1x1024x1024xui8>} : () -> tensor<1x1x1024x1024xi8>
%10 = "tosa.const"() {value = dense<8.000000e+00> : tensor<f32>} : () -> tensor<f32>
%11 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1024x768xf32>} : () -> tensor<1024x768xf32>
%12 = "tosa.const"() {value = dense_resource<__elided__> : tensor<50257x768xf32>} : () -> tensor<50257x768xf32>
%13 = "tosa.const"() {value = dense<7.680000e+02> : tensor<1xf32>} : () -> tensor<1xf32>
%14 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x98304xi32>} : () -> tensor<1x98304xi32>
%15 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x98304xi32>} : () -> tensor<1x98304xi32>
%16 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x98304xi32>} : () -> tensor<1x98304xi32>
%17 = "tosa.const"() {value = dense<[0, 2, 1, 3]> : tensor<4xi64>} : () -> tensor<4xi64>
%18 = "tosa.const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%19 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x1048576xi32>} : () -> tensor<1x1048576xi32>
%20 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x131072xi32>} : () -> tensor<1x131072xi32>
%21 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x16384xi32>} : () -> tensor<1x16384xi32>
%22 = "tosa.const"() {value = dense<0> : tensor<1x1x128x128xi8>} : () -> tensor<1x1x128x128xi8>
%23 = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
%24 = "tosa.const"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64>
%25 = "tosa.const"() {value = dense<0.797884583> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%26 = "tosa.const"() {value = dense<4.471500e-02> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%27 = "tosa.const"() {value = dense<3.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%28 = "tosa.const"() {value = dense<5.000000e-01> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%29 = "tosa.const"() {value = dense<9.99999974E-6> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%30 = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
%31 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x128xi64>} : () -> tensor<1x128xi64>
%32 = "tosa.reshape"(%12) {new_shape = [1, 50257, 768]} : (tensor<50257x768xf32>) -> tensor<1x50257x768xf32>
%33 = "tosa.cast"(%arg0) : (tensor<1x128xi64>) -> tensor<1x128xi32>
%34 = "tosa.gather"(%32, %33) : (tensor<1x50257x768xf32>, tensor<1x128xi32>) -> tensor<1x128x768xf32>
%35 = "tosa.reshape"(%11) {new_shape = [1, 1024, 768]} : (tensor<1024x768xf32>) -> tensor<1x1024x768xf32>
%36 = "tosa.cast"(%31) : (tensor<1x128xi64>) -> tensor<1x128xi32>
%37 = "tosa.gather"(%35, %36) : (tensor<1x1024x768xf32>, tensor<1x128xi32>) -> tensor<1x128x768xf32>
%38 = "tosa.add"(%34, %37) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%39 = "tosa.reciprocal"(%13) : (tensor<1xf32>) -> tensor<1xf32>
%40 = "tosa.reduce_sum"(%38) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%41 = "tosa.reshape"(%39) {new_shape = [1, 1, 1]} : (tensor<1xf32>) -> tensor<1x1x1xf32>
%42 = "tosa.mul"(%40, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%43 = "tosa.sub"(%38, %42) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%44 = "tosa.mul"(%43, %43) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%45 = "tosa.reduce_sum"(%44) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%46 = "tosa.mul"(%45, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%47 = "tosa.reshape"(%1) {new_shape = [1, 1, 768]} : (tensor<768xf32>) -> tensor<1x1x768xf32>
%48 = "tosa.add"(%46, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%49 = "tosa.rsqrt"(%48) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%50 = "tosa.mul"(%43, %49) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%51 = "tosa.mul"(%50, %47) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%52 = "tosa.add"(%51, %47) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%53 = "tosa.reshape"(%6) {new_shape = [1, 768, 2304]} : (tensor<768x2304xf32>) -> tensor<1x768x2304xf32>
%54 = "tosa.matmul"(%52, %53) : (tensor<1x128x768xf32>, tensor<1x768x2304xf32>) -> tensor<1x128x2304xf32>
%55 = "tosa.reshape"(%54) {new_shape = [128, 2304]} : (tensor<1x128x2304xf32>) -> tensor<128x2304xf32>
%56 = "tosa.reshape"(%7) {new_shape = [1, 2304]} : (tensor<2304xf32>) -> tensor<1x2304xf32>
%57 = "tosa.add"(%56, %55) : (tensor<1x2304xf32>, tensor<128x2304xf32>) -> tensor<128x2304xf32>
%58 = "tosa.reshape"(%57) {new_shape = [1, 294912, 1]} : (tensor<128x2304xf32>) -> tensor<1x294912x1xf32>
%59 = "tosa.gather"(%58, %14) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%60 = "tosa.gather"(%58, %15) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%61 = "tosa.gather"(%58, %16) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%62 = "tosa.reshape"(%59) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%63 = "tosa.transpose"(%62, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%64 = "tosa.reshape"(%60) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%65 = "tosa.transpose"(%64, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%66 = "tosa.reshape"(%61) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%67 = "tosa.transpose"(%66, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%68 = "tosa.transpose"(%65, %18) : (tensor<1x12x128x64xf32>, tensor<4xi32>) -> tensor<1x12x64x128xf32>
%69 = "tosa.reshape"(%63) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%70 = "tosa.reshape"(%68) {new_shape = [12, 64, 128]} : (tensor<1x12x64x128xf32>) -> tensor<12x64x128xf32>
%71 = "tosa.matmul"(%69, %70) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%72 = "tosa.reshape"(%71) {new_shape = [1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%73 = "tosa.reciprocal"(%10) : (tensor<f32>) -> tensor<f32>
%74 = "tosa.reshape"(%73) {new_shape = [1, 1, 1, 1]} : (tensor<f32>) -> tensor<1x1x1x1xf32>
%75 = "tosa.mul"(%72, %74) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x128x128xf32>
%76 = "tosa.reshape"(%9) {new_shape = [1, 1048576, 1]} : (tensor<1x1x1024x1024xi8>) -> tensor<1x1048576x1xi8>
%77 = "tosa.gather"(%76, %19) : (tensor<1x1048576x1xi8>, tensor<1x1048576xi32>) -> tensor<1x1048576x1xi8>
%78 = "tosa.gather"(%77, %19) : (tensor<1x1048576x1xi8>, tensor<1x1048576xi32>) -> tensor<1x1048576x1xi8>
%79 = "tosa.gather"(%78, %20) : (tensor<1x1048576x1xi8>, tensor<1x131072xi32>) -> tensor<1x131072x1xi8>
%80 = "tosa.gather"(%79, %21) : (tensor<1x131072x1xi8>, tensor<1x16384xi32>) -> tensor<1x16384x1xi8>
%81 = "tosa.reshape"(%80) {new_shape = [1, 1, 128, 128]} : (tensor<1x16384x1xi8>) -> tensor<1x1x128x128xi8>
%82 = "tosa.equal"(%81, %22) : (tensor<1x1x128x128xi8>, tensor<1x1x128x128xi8>) -> tensor<1x1x128x128xi1>
%83 = "tosa.select"(%82, %8, %75) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%84 = "tosa.reduce_max"(%83) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%85 = "tosa.sub"(%83, %84) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%86 = "tosa.exp"(%85) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%87 = "tosa.reduce_sum"(%86) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%88 = "tosa.reciprocal"(%87) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%89 = "tosa.mul"(%86, %88) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%90 = "tosa.reshape"(%89) {new_shape = [12, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%91 = "tosa.reshape"(%67) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%92 = "tosa.matmul"(%90, %91) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%93 = "tosa.reshape"(%92) {new_shape = [1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%94 = "tosa.transpose"(%93, %17) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%95 = "tosa.reshape"(%94) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%96 = "tosa.reshape"(%5) {new_shape = [1, 768, 768]} : (tensor<768x768xf32>) -> tensor<1x768x768xf32>
%97 = "tosa.matmul"(%95, %96) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%98 = "tosa.reshape"(%97) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%99 = "tosa.reshape"(%1) {new_shape = [1, 768]} : (tensor<768xf32>) -> tensor<1x768xf32>
%100 = "tosa.add"(%99, %98) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%101 = "tosa.reshape"(%100) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%102 = "tosa.add"(%101, %38) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%103 = "tosa.reduce_sum"(%102) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%104 = "tosa.mul"(%103, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%105 = "tosa.sub"(%102, %104) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%106 = "tosa.mul"(%105, %105) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%107 = "tosa.reduce_sum"(%106) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%108 = "tosa.mul"(%107, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%109 = "tosa.add"(%108, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%110 = "tosa.rsqrt"(%109) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%111 = "tosa.mul"(%105, %110) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%112 = "tosa.mul"(%111, %47) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%113 = "tosa.add"(%112, %47) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%114 = "tosa.reshape"(%3) {new_shape = [1, 768, 3072]} : (tensor<768x3072xf32>) -> tensor<1x768x3072xf32>
%115 = "tosa.matmul"(%113, %114) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%116 = "tosa.reshape"(%115) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%117 = "tosa.reshape"(%4) {new_shape = [1, 3072]} : (tensor<3072xf32>) -> tensor<1x3072xf32>
%118 = "tosa.add"(%117, %116) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%119 = "tosa.reshape"(%118) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%120 = "tosa.mul"(%119, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%121 = "tosa.pow"(%119, %27) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%122 = "tosa.mul"(%121, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%123 = "tosa.add"(%119, %122) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%124 = "tosa.mul"(%123, %25) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%125 = "tosa.tanh"(%124) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%126 = "tosa.add"(%125, %30) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%127 = "tosa.mul"(%120, %126) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%128 = "tosa.reshape"(%2) {new_shape = [1, 3072, 768]} : (tensor<3072x768xf32>) -> tensor<1x3072x768xf32>
%129 = "tosa.matmul"(%127, %128) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%130 = "tosa.reshape"(%129) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%131 = "tosa.add"(%99, %130) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%132 = "tosa.reshape"(%131) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%133 = "tosa.add"(%102, %132) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%134 = "tosa.reduce_sum"(%133) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%135 = "tosa.mul"(%134, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%136 = "tosa.sub"(%133, %135) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%137 = "tosa.mul"(%136, %136) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%138 = "tosa.reduce_sum"(%137) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%139 = "tosa.mul"(%138, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%140 = "tosa.add"(%139, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%141 = "tosa.rsqrt"(%140) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%142 = "tosa.mul"(%136, %141) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%143 = "tosa.mul"(%142, %47) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%144 = "tosa.add"(%143, %47) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%145 = "tosa.matmul"(%144, %53) : (tensor<1x128x768xf32>, tensor<1x768x2304xf32>) -> tensor<1x128x2304xf32>
%146 = "tosa.reshape"(%145) {new_shape = [128, 2304]} : (tensor<1x128x2304xf32>) -> tensor<128x2304xf32>
%147 = "tosa.add"(%56, %146) : (tensor<1x2304xf32>, tensor<128x2304xf32>) -> tensor<128x2304xf32>
%148 = "tosa.reshape"(%147) {new_shape = [1, 294912, 1]} : (tensor<128x2304xf32>) -> tensor<1x294912x1xf32>
%149 = "tosa.gather"(%148, %14) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%150 = "tosa.gather"(%148, %15) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%151 = "tosa.gather"(%148, %16) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%152 = "tosa.reshape"(%149) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%153 = "tosa.transpose"(%152, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%154 = "tosa.reshape"(%150) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%155 = "tosa.transpose"(%154, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%156 = "tosa.reshape"(%151) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%157 = "tosa.transpose"(%156, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%158 = "tosa.transpose"(%155, %18) : (tensor<1x12x128x64xf32>, tensor<4xi32>) -> tensor<1x12x64x128xf32>
%159 = "tosa.reshape"(%153) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%160 = "tosa.reshape"(%158) {new_shape = [12, 64, 128]} : (tensor<1x12x64x128xf32>) -> tensor<12x64x128xf32>
%161 = "tosa.matmul"(%159, %160) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%162 = "tosa.reshape"(%161) {new_shape = [1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%163 = "tosa.mul"(%162, %74) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x128x128xf32>
%164 = "tosa.select"(%82, %8, %163) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%165 = "tosa.reduce_max"(%164) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%166 = "tosa.sub"(%164, %165) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%167 = "tosa.exp"(%166) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%168 = "tosa.reduce_sum"(%167) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%169 = "tosa.reciprocal"(%168) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%170 = "tosa.mul"(%167, %169) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%171 = "tosa.reshape"(%170) {new_shape = [12, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%172 = "tosa.reshape"(%157) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%173 = "tosa.matmul"(%171, %172) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%174 = "tosa.reshape"(%173) {new_shape = [1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%175 = "tosa.transpose"(%174, %17) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%176 = "tosa.reshape"(%175) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%177 = "tosa.matmul"(%176, %96) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%178 = "tosa.reshape"(%177) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%179 = "tosa.add"(%99, %178) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%180 = "tosa.reshape"(%179) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%181 = "tosa.add"(%180, %133) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%182 = "tosa.reduce_sum"(%181) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%183 = "tosa.mul"(%182, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%184 = "tosa.sub"(%181, %183) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%185 = "tosa.mul"(%184, %184) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%186 = "tosa.reduce_sum"(%185) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%187 = "tosa.mul"(%186, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%188 = "tosa.add"(%187, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%189 = "tosa.rsqrt"(%188) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%190 = "tosa.mul"(%184, %189) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%191 = "tosa.mul"(%190, %47) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%192 = "tosa.add"(%191, %47) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%193 = "tosa.matmul"(%192, %114) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%194 = "tosa.reshape"(%193) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%195 = "tosa.add"(%117, %194) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%196 = "tosa.reshape"(%195) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%197 = "tosa.mul"(%196, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%198 = "tosa.pow"(%196, %27) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%199 = "tosa.mul"(%198, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%200 = "tosa.add"(%196, %199) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%201 = "tosa.mul"(%200, %25) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%202 = "tosa.tanh"(%201) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%203 = "tosa.add"(%202, %30) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%204 = "tosa.mul"(%197, %203) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%205 = "tosa.matmul"(%204, %128) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%206 = "tosa.reshape"(%205) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%207 = "tosa.add"(%99, %206) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%208 = "tosa.reshape"(%207) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%209 = "tosa.add"(%181, %208) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%210 = "tosa.reduce_sum"(%209) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%211 = "tosa.mul"(%210, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%212 = "tosa.sub"(%209, %211) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%213 = "tosa.mul"(%212, %212) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%214 = "tosa.reduce_sum"(%213) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%215 = "tosa.mul"(%214, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%216 = "tosa.add"(%215, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%217 = "tosa.rsqrt"(%216) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%218 = "tosa.mul"(%212, %217) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%219 = "tosa.mul"(%218, %47) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%220 = "tosa.add"(%219, %47) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%221 = "tosa.matmul"(%220, %53) : (tensor<1x128x768xf32>, tensor<1x768x2304xf32>) -> tensor<1x128x2304xf32>
%222 = "tosa.reshape"(%221) {new_shape = [128, 2304]} : (tensor<1x128x2304xf32>) -> tensor<128x2304xf32>
%223 = "tosa.add"(%56, %222) : (tensor<1x2304xf32>, tensor<128x2304xf32>) -> tensor<128x2304xf32>
%224 = "tosa.reshape"(%223) {new_shape = [1, 294912, 1]} : (tensor<128x2304xf32>) -> tensor<1x294912x1xf32>
%225 = "tosa.gather"(%224, %14) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%226 = "tosa.gather"(%224, %15) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%227 = "tosa.gather"(%224, %16) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%228 = "tosa.reshape"(%225) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%229 = "tosa.transpose"(%228, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%230 = "tosa.reshape"(%226) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%231 = "tosa.transpose"(%230, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%232 = "tosa.reshape"(%227) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%233 = "tosa.transpose"(%232, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%234 = "tosa.transpose"(%231, %18) : (tensor<1x12x128x64xf32>, tensor<4xi32>) -> tensor<1x12x64x128xf32>
%235 = "tosa.reshape"(%229) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%236 = "tosa.reshape"(%234) {new_shape = [12, 64, 128]} : (tensor<1x12x64x128xf32>) -> tensor<12x64x128xf32>
%237 = "tosa.matmul"(%235, %236) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%238 = "tosa.reshape"(%237) {new_shape = [1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%239 = "tosa.mul"(%238, %74) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x128x128xf32>
%240 = "tosa.select"(%82, %8, %239) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%241 = "tosa.reduce_max"(%240) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%242 = "tosa.sub"(%240, %241) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%243 = "tosa.exp"(%242) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%244 = "tosa.reduce_sum"(%243) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%245 = "tosa.reciprocal"(%244) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%246 = "tosa.mul"(%243, %245) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%247 = "tosa.reshape"(%246) {new_shape = [12, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%248 = "tosa.reshape"(%233) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%249 = "tosa.matmul"(%247, %248) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%250 = "tosa.reshape"(%249) {new_shape = [1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%251 = "tosa.transpose"(%250, %17) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%252 = "tosa.reshape"(%251) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%253 = "tosa.matmul"(%252, %96) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%254 = "tosa.reshape"(%253) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%255 = "tosa.add"(%99, %254) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%256 = "tosa.reshape"(%255) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%257 = "tosa.add"(%256, %209) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%258 = "tosa.reduce_sum"(%257) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%259 = "tosa.mul"(%258, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%260 = "tosa.sub"(%257, %259) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%261 = "tosa.mul"(%260, %260) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%262 = "tosa.reduce_sum"(%261) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%263 = "tosa.mul"(%262, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%264 = "tosa.add"(%263, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%265 = "tosa.rsqrt"(%264) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%266 = "tosa.mul"(%260, %265) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%267 = "tosa.mul"(%266, %47) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%268 = "tosa.add"(%267, %47) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%269 = "tosa.matmul"(%268, %114) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%270 = "tosa.reshape"(%269) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%271 = "tosa.add"(%117, %270) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%272 = "tosa.reshape"(%271) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%273 = "tosa.mul"(%272, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%274 = "tosa.pow"(%272, %27) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%275 = "tosa.mul"(%274, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%276 = "tosa.add"(%272, %275) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%277 = "tosa.mul"(%276, %25) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%278 = "tosa.tanh"(%277) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%279 = "tosa.add"(%278, %30) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%280 = "tosa.mul"(%273, %279) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%281 = "tosa.matmul"(%280, %128) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%282 = "tosa.reshape"(%281) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%283 = "tosa.add"(%99, %282) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%284 = "tosa.reshape"(%283) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%285 = "tosa.add"(%257, %284) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%286 = "tosa.reduce_sum"(%285) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%287 = "tosa.mul"(%286, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%288 = "tosa.sub"(%285, %287) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%289 = "tosa.mul"(%288, %288) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%290 = "tosa.reduce_sum"(%289) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%291 = "tosa.mul"(%290, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%292 = "tosa.add"(%291, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%293 = "tosa.rsqrt"(%292) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%294 = "tosa.mul"(%288, %293) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%295 = "tosa.mul"(%294, %47) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%296 = "tosa.add"(%295, %47) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%297 = "tosa.matmul"(%296, %53) : (tensor<1x128x768xf32>, tensor<1x768x2304xf32>) -> tensor<1x128x2304xf32>
%298 = "tosa.reshape"(%297) {new_shape = [128, 2304]} : (tensor<1x128x2304xf32>) -> tensor<128x2304xf32>
%299 = "tosa.add"(%56, %298) : (tensor<1x2304xf32>, tensor<128x2304xf32>) -> tensor<128x2304xf32>
%300 = "tosa.reshape"(%299) {new_shape = [1, 294912, 1]} : (tensor<128x2304xf32>) -> tensor<1x294912x1xf32>
%301 = "tosa.gather"(%300, %14) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%302 = "tosa.gather"(%300, %15) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%303 = "tosa.gather"(%300, %16) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%304 = "tosa.reshape"(%301) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%305 = "tosa.transpose"(%304, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%306 = "tosa.reshape"(%302) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%307 = "tosa.transpose"(%306, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%308 = "tosa.reshape"(%303) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%309 = "tosa.transpose"(%308, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%310 = "tosa.transpose"(%307, %18) : (tensor<1x12x128x64xf32>, tensor<4xi32>) -> tensor<1x12x64x128xf32>
%311 = "tosa.reshape"(%305) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%312 = "tosa.reshape"(%310) {new_shape = [12, 64, 128]} : (tensor<1x12x64x128xf32>) -> tensor<12x64x128xf32>
%313 = "tosa.matmul"(%311, %312) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%314 = "tosa.reshape"(%313) {new_shape = [1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%315 = "tosa.mul"(%314, %74) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x128x128xf32>
%316 = "tosa.select"(%82, %8, %315) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%317 = "tosa.reduce_max"(%316) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%318 = "tosa.sub"(%316, %317) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%319 = "tosa.exp"(%318) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%320 = "tosa.reduce_sum"(%319) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%321 = "tosa.reciprocal"(%320) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%322 = "tosa.mul"(%319, %321) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%323 = "tosa.reshape"(%322) {new_shape = [12, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%324 = "tosa.reshape"(%309) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%325 = "tosa.matmul"(%323, %324) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%326 = "tosa.reshape"(%325) {new_shape = [1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%327 = "tosa.transpose"(%326, %17) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%328 = "tosa.reshape"(%327) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%329 = "tosa.matmul"(%328, %96) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%330 = "tosa.reshape"(%329) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%331 = "tosa.add"(%99, %330) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%332 = "tosa.reshape"(%331) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%333 = "tosa.add"(%332, %285) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%334 = "tosa.reduce_sum"(%333) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%335 = "tosa.mul"(%334, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%336 = "tosa.sub"(%333, %335) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%337 = "tosa.mul"(%336, %336) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%338 = "tosa.reduce_sum"(%337) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%339 = "tosa.mul"(%338, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%340 = "tosa.add"(%339, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%341 = "tosa.rsqrt"(%340) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%342 = "tosa.mul"(%336, %341) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%343 = "tosa.mul"(%342, %47) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%344 = "tosa.add"(%343, %47) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%345 = "tosa.matmul"(%344, %114) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%346 = "tosa.reshape"(%345) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%347 = "tosa.add"(%117, %346) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%348 = "tosa.reshape"(%347) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%349 = "tosa.mul"(%348, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%350 = "tosa.pow"(%348, %27) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%351 = "tosa.mul"(%350, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%352 = "tosa.add"(%348, %351) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%353 = "tosa.mul"(%352, %25) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%354 = "tosa.tanh"(%353) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%355 = "tosa.add"(%354, %30) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%356 = "tosa.mul"(%349, %355) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%357 = "tosa.matmul"(%356, %128) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%358 = "tosa.reshape"(%357) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%359 = "tosa.add"(%99, %358) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%360 = "tosa.reshape"(%359) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%361 = "tosa.add"(%333, %360) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%362 = "tosa.reduce_sum"(%361) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%363 = "tosa.mul"(%362, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%364 = "tosa.sub"(%361, %363) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%365 = "tosa.mul"(%364, %364) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%366 = "tosa.reduce_sum"(%365) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%367 = "tosa.mul"(%366, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%368 = "tosa.add"(%367, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%369 = "tosa.rsqrt"(%368) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%370 = "tosa.mul"(%364, %369) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%371 = "tosa.mul"(%370, %47) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%372 = "tosa.add"(%371, %47) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%373 = "tosa.matmul"(%372, %53) : (tensor<1x128x768xf32>, tensor<1x768x2304xf32>) -> tensor<1x128x2304xf32>
%374 = "tosa.reshape"(%373) {new_shape = [128, 2304]} : (tensor<1x128x2304xf32>) -> tensor<128x2304xf32>
%375 = "tosa.add"(%56, %374) : (tensor<1x2304xf32>, tensor<128x2304xf32>) -> tensor<128x2304xf32>
%376 = "tosa.reshape"(%375) {new_shape = [1, 294912, 1]} : (tensor<128x2304xf32>) -> tensor<1x294912x1xf32>
%377 = "tosa.gather"(%376, %14) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%378 = "tosa.gather"(%376, %15) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%379 = "tosa.gather"(%376, %16) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%380 = "tosa.reshape"(%377) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%381 = "tosa.transpose"(%380, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%382 = "tosa.reshape"(%378) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%383 = "tosa.transpose"(%382, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%384 = "tosa.reshape"(%379) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%385 = "tosa.transpose"(%384, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%386 = "tosa.transpose"(%383, %18) : (tensor<1x12x128x64xf32>, tensor<4xi32>) -> tensor<1x12x64x128xf32>
%387 = "tosa.reshape"(%381) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%388 = "tosa.reshape"(%386) {new_shape = [12, 64, 128]} : (tensor<1x12x64x128xf32>) -> tensor<12x64x128xf32>
%389 = "tosa.matmul"(%387, %388) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%390 = "tosa.reshape"(%389) {new_shape = [1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%391 = "tosa.mul"(%390, %74) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x128x128xf32>
%392 = "tosa.select"(%82, %8, %391) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%393 = "tosa.reduce_max"(%392) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%394 = "tosa.sub"(%392, %393) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%395 = "tosa.exp"(%394) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%396 = "tosa.reduce_sum"(%395) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%397 = "tosa.reciprocal"(%396) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%398 = "tosa.mul"(%395, %397) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%399 = "tosa.reshape"(%398) {new_shape = [12, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%400 = "tosa.reshape"(%385) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%401 = "tosa.matmul"(%399, %400) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%402 = "tosa.reshape"(%401) {new_shape = [1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%403 = "tosa.transpose"(%402, %17) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%404 = "tosa.reshape"(%403) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%405 = "tosa.matmul"(%404, %96) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%406 = "tosa.reshape"(%405) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%407 = "tosa.add"(%99, %406) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%408 = "tosa.reshape"(%407) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%409 = "tosa.add"(%408, %361) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%410 = "tosa.reduce_sum"(%409) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%411 = "tosa.mul"(%410, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%412 = "tosa.sub"(%409, %411) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%413 = "tosa.mul"(%412, %412) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%414 = "tosa.reduce_sum"(%413) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%415 = "tosa.mul"(%414, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%416 = "tosa.add"(%415, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%417 = "tosa.rsqrt"(%416) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%418 = "tosa.mul"(%412, %417) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%419 = "tosa.mul"(%418, %47) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%420 = "tosa.add"(%419, %47) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%421 = "tosa.matmul"(%420, %114) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%422 = "tosa.reshape"(%421) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%423 = "tosa.add"(%117, %422) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%424 = "tosa.reshape"(%423) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%425 = "tosa.mul"(%424, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%426 = "tosa.pow"(%424, %27) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%427 = "tosa.mul"(%426, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%428 = "tosa.add"(%424, %427) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%429 = "tosa.mul"(%428, %25) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%430 = "tosa.tanh"(%429) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%431 = "tosa.add"(%430, %30) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%432 = "tosa.mul"(%425, %431) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%433 = "tosa.matmul"(%432, %128) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%434 = "tosa.reshape"(%433) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%435 = "tosa.add"(%99, %434) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%436 = "tosa.reshape"(%435) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%437 = "tosa.add"(%409, %436) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%438 = "tosa.reduce_sum"(%437) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%439 = "tosa.mul"(%438, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%440 = "tosa.sub"(%437, %439) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%441 = "tosa.mul"(%440, %440) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%442 = "tosa.reduce_sum"(%441) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%443 = "tosa.mul"(%442, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%444 = "tosa.add"(%443, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%445 = "tosa.rsqrt"(%444) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%446 = "tosa.mul"(%440, %445) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%447 = "tosa.mul"(%446, %47) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%448 = "tosa.add"(%447, %47) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%449 = "tosa.matmul"(%448, %53) : (tensor<1x128x768xf32>, tensor<1x768x2304xf32>) -> tensor<1x128x2304xf32>
%450 = "tosa.reshape"(%449) {new_shape = [128, 2304]} : (tensor<1x128x2304xf32>) -> tensor<128x2304xf32>
%451 = "tosa.add"(%56, %450) : (tensor<1x2304xf32>, tensor<128x2304xf32>) -> tensor<128x2304xf32>
%452 = "tosa.reshape"(%451) {new_shape = [1, 294912, 1]} : (tensor<128x2304xf32>) -> tensor<1x294912x1xf32>
%453 = "tosa.gather"(%452, %14) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%454 = "tosa.gather"(%452, %15) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%455 = "tosa.gather"(%452, %16) : (tensor<1x294912x1xf32>, tensor<1x98304xi32>) -> tensor<1x98304x1xf32>
%456 = "tosa.reshape"(%453) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%457 = "tosa.transpose"(%456, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%458 = "tosa.reshape"(%454) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%459 = "tosa.transpose"(%458, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%460 = "tosa.reshape"(%455) {new_shape = [1, 128, 12, 64]} : (tensor<1x98304x1xf32>) -> tensor<1x128x12x64xf32>
%461 = "tosa.transpose"(%460, %17) : (tensor<1x128x12x64xf32>, tensor<4xi64>) -> tensor<1x12x128x64xf32>
%462 = "tosa.transpose"(%459, %18) : (tensor<1x12x128x64xf32>, tensor<4xi32>) -> tensor<1x12x64x128xf32>
%463 = "tosa.reshape"(%457) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%464 = "tosa.reshape"(%462) {new_shape = [12, 64, 128]} : (tensor<1x12x64x128xf32>) -> tensor<12x64x128xf32>
%465 = "tosa.matmul"(%463, %464) : (tensor<12x128x64xf32>, tensor<12x64x128xf32>) -> tensor<12x128x128xf32>
%466 = "tosa.reshape"(%465) {new_shape = [1, 12, 128, 128]} : (tensor<12x128x128xf32>) -> tensor<1x12x128x128xf32>
%467 = "tosa.mul"(%466, %74) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x128x128xf32>
%468 = "tosa.select"(%82, %8, %467) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%469 = "tosa.reduce_max"(%468) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%470 = "tosa.sub"(%468, %469) : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%471 = "tosa.exp"(%470) : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
%472 = "tosa.reduce_sum"(%471) {axis = 3 : i64} : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x1xf32>
%473 = "tosa.reciprocal"(%472) : (tensor<1x12x128x1xf32>) -> tensor<1x12x128x1xf32>
%474 = "tosa.mul"(%471, %473) {shift = 0 : i32} : (tensor<1x12x128x128xf32>, tensor<1x12x128x1xf32>) -> tensor<1x12x128x128xf32>
%475 = "tosa.reshape"(%474) {new_shape = [12, 128, 128]} : (tensor<1x12x128x128xf32>) -> tensor<12x128x128xf32>
%476 = "tosa.reshape"(%461) {new_shape = [12, 128, 64]} : (tensor<1x12x128x64xf32>) -> tensor<12x128x64xf32>
%477 = "tosa.matmul"(%475, %476) : (tensor<12x128x128xf32>, tensor<12x128x64xf32>) -> tensor<12x128x64xf32>
%478 = "tosa.reshape"(%477) {new_shape = [1, 12, 128, 64]} : (tensor<12x128x64xf32>) -> tensor<1x12x128x64xf32>
%479 = "tosa.transpose"(%478, %17) : (tensor<1x12x128x64xf32>, tensor<4xi64>) -> tensor<1x128x12x64xf32>
%480 = "tosa.reshape"(%479) {new_shape = [1, 128, 768]} : (tensor<1x128x12x64xf32>) -> tensor<1x128x768xf32>
%481 = "tosa.matmul"(%480, %96) : (tensor<1x128x768xf32>, tensor<1x768x768xf32>) -> tensor<1x128x768xf32>
%482 = "tosa.reshape"(%481) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%483 = "tosa.add"(%99, %482) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%484 = "tosa.reshape"(%483) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%485 = "tosa.add"(%484, %437) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%486 = "tosa.reduce_sum"(%485) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%487 = "tosa.mul"(%486, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%488 = "tosa.sub"(%485, %487) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%489 = "tosa.mul"(%488, %488) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%490 = "tosa.reduce_sum"(%489) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%491 = "tosa.mul"(%490, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%492 = "tosa.add"(%491, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%493 = "tosa.rsqrt"(%492) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%494 = "tosa.mul"(%488, %493) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%495 = "tosa.mul"(%494, %47) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%496 = "tosa.add"(%495, %47) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%497 = "tosa.matmul"(%496, %114) : (tensor<1x128x768xf32>, tensor<1x768x3072xf32>) -> tensor<1x128x3072xf32>
%498 = "tosa.reshape"(%497) {new_shape = [128, 3072]} : (tensor<1x128x3072xf32>) -> tensor<128x3072xf32>
%499 = "tosa.add"(%117, %498) : (tensor<1x3072xf32>, tensor<128x3072xf32>) -> tensor<128x3072xf32>
%500 = "tosa.reshape"(%499) {new_shape = [1, 128, 3072]} : (tensor<128x3072xf32>) -> tensor<1x128x3072xf32>
%501 = "tosa.mul"(%500, %28) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%502 = "tosa.pow"(%500, %27) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%503 = "tosa.mul"(%502, %26) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%504 = "tosa.add"(%500, %503) : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%505 = "tosa.mul"(%504, %25) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%506 = "tosa.tanh"(%505) : (tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%507 = "tosa.add"(%506, %30) : (tensor<1x128x3072xf32>, tensor<1x1x1xf32>) -> tensor<1x128x3072xf32>
%508 = "tosa.mul"(%501, %507) {shift = 0 : i32} : (tensor<1x128x3072xf32>, tensor<1x128x3072xf32>) -> tensor<1x128x3072xf32>
%509 = "tosa.matmul"(%508, %128) : (tensor<1x128x3072xf32>, tensor<1x3072x768xf32>) -> tensor<1x128x768xf32>
%510 = "tosa.reshape"(%509) {new_shape = [128, 768]} : (tensor<1x128x768xf32>) -> tensor<128x768xf32>
%511 = "tosa.add"(%99, %510) : (tensor<1x768xf32>, tensor<128x768xf32>) -> tensor<128x768xf32>
%512 = "tosa.reshape"(%511) {new_shape = [1, 128, 768]} : (tensor<128x768xf32>) -> tensor<1x128x768xf32>
%513 = "tosa.add"(%485, %512) : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%514 = "tosa.reduce_sum"(%513) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%515 = "tosa.mul"(%514, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%516 = "tosa.sub"(%513, %515) : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%517 = "tosa.mul"(%516, %516) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
%518 = "tosa.reduce_sum"(%517) {axis = 2 : i64} : (tensor<1x128x768xf32>) -> tensor<1x128x1xf32>
%519 = "tosa.mul"(%518, %41) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%520 = "tosa.add"(%519, %29) : (tensor<1x128x1xf32>, tensor<1x1x1xf32>) -> tensor<1x128x1xf32>
%521 = "tosa.rsqrt"(%520) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%522 = "tosa.mul"(%516, %521) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x128x1xf32>) -> tensor<1x128x768xf32>
%523 = "tosa.mul"(%522, %47) {shift = 0 : i32} : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%524 = "tosa.add"(%523, %47) : (tensor<1x128x768xf32>, tensor<1x1x768xf32>) -> tensor<1x128x768xf32>
%525 = "tosa.transpose"(%0, %23) : (tensor<2x768xf32>, tensor<2xi32>) -> tensor<768x2xf32>
%526 = "tosa.reshape"(%525) {new_shape = [1, 768, 2]} : (tensor<768x2xf32>) -> tensor<1x768x2xf32>
%527 = "tosa.matmul"(%524, %526) : (tensor<1x128x768xf32>, tensor<1x768x2xf32>) -> tensor<1x128x2xf32>
%528 = "tosa.slice"(%527) {size = [1, 1, 2], start = [0, 127, 0]} : (tensor<1x128x2xf32>) -> tensor<1x1x2xf32>
%529 = "tosa.cast"(%24) : (tensor<1xi64>) -> tensor<1xi32>
%530 = "tosa.reshape"(%529) {new_shape = [1, 1]} : (tensor<1xi32>) -> tensor<1x1xi32>
%531 = "tosa.gather"(%528, %530) : (tensor<1x1x2xf32>, tensor<1x1xi32>) -> tensor<1x1x2xf32>
%532 = "tosa.reshape"(%531) {new_shape = [1, 2]} : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
return %532 : tensor<1x2xf32>
}
}
@AmosLewis
Copy link
Author

torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' /tmp/_lambda.mlir -mlir-print-ir-after-failure -mlir-disable-threading -debug --mlir-elide-elementsattrs-if-larger=4

@liuyumoye
Copy link

你好，我用 MLIR 解析 `"tosa.const"() {value = dense_resource<__elided__> ...}` 这种类型的常量时，想获取里面的值，但是用 `auto attr = const_op.getValueAttr().cast<mlir::ElementsAttr>(); for (const auto v : attr.getValues<float>()) { div_constant.push_back(v); }` 会报如下错误：`ElementsAttr does not provide iteration facilities for type float, see attribute: dense_resource<__elided__> : tensor<2x768xf32>`，`invalid T for ElementsAttr::getValues`。请问你知道怎么解决这个问题吗？

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment