Created September 17, 2024 21:23
Save GleasonK/cf68c91196bc6beb4017d1a2e53ef8bf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // This is a dump from PyTorch/XLA of the following LLAMA2 model file: | |
| // https://github.com/pytorch/xla/blob/master/test/stablehlo/llama_model2.py | |
| // | |
| module @IrToHlo.7509 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { | |
| func.func @main(%arg0: tensor<32000x4096xf32>, %arg1: tensor<4096xf32>, %arg2: tensor<4096x11008xf32>, %arg3: tensor<11008x4096xf32>, %arg4: tensor<4096xf32>, %arg5: tensor<4096x4096xf32>, %arg6: tensor<4096x4096xf32>, %arg7: tensor<4096xf32>, %arg8: tensor<4096x11008xf32>, %arg9: tensor<11008x4096xf32>, %arg10: tensor<4096xf32>, %arg11: tensor<4096x4096xf32>, %arg12: tensor<4096x4096xf32>, %arg13: tensor<4096xf32>, %arg14: tensor<4096x11008xf32>, %arg15: tensor<11008x4096xf32>, %arg16: tensor<4096xf32>, %arg17: tensor<4096x4096xf32>, %arg18: tensor<4096x4096xf32>, %arg19: tensor<4096xf32>, %arg20: tensor<4096x11008xf32>, %arg21: tensor<11008x4096xf32>, %arg22: tensor<4096xf32>, %arg23: tensor<4096x4096xf32>, %arg24: tensor<4096x4096xf32>, %arg25: tensor<4096xf32>, %arg26: tensor<4096x11008xf32>, %arg27: tensor<11008x4096xf32>, %arg28: tensor<4096xf32>, %arg29: tensor<4096x4096xf32>, %arg30: tensor<4096x4096xf32>, %arg31: tensor<4096xf32>, %arg32: tensor<4096x11008xf32>, %arg33: tensor<11008x4096xf32>, %arg34: tensor<4096xf32>, %arg35: tensor<4096x4096xf32>, %arg36: tensor<4096x4096xf32>, %arg37: tensor<4096xf32>, %arg38: tensor<4096x11008xf32>, %arg39: tensor<11008x4096xf32>, %arg40: tensor<4096xf32>, %arg41: tensor<4096x4096xf32>, %arg42: tensor<4096x4096xf32>, %arg43: tensor<4096xf32>, %arg44: tensor<4096x11008xf32>, %arg45: tensor<11008x4096xf32>, %arg46: tensor<4096xf32>, %arg47: tensor<4096x4096xf32>, %arg48: tensor<4096x4096xf32>, %arg49: tensor<4096xf32>, %arg50: tensor<4096x11008xf32>, %arg51: tensor<11008x4096xf32>, %arg52: tensor<4096xf32>, %arg53: tensor<4096x4096xf32>, %arg54: tensor<4096x4096xf32>, %arg55: tensor<4096xf32>, %arg56: tensor<4096x11008xf32>, %arg57: tensor<11008x4096xf32>, %arg58: tensor<4096xf32>, %arg59: tensor<4096x4096xf32>, %arg60: tensor<4096x4096xf32>, %arg61: tensor<4096xf32>, %arg62: tensor<4096x11008xf32>, %arg63: tensor<11008x4096xf32>, %arg64: tensor<4096xf32>, %arg65: tensor<4096x4096xf32>, %arg66: tensor<4096x4096xf32>, 
%arg67: tensor<4096xf32>, %arg68: tensor<4096x11008xf32>, %arg69: tensor<11008x4096xf32>, %arg70: tensor<4096xf32>, %arg71: tensor<4096x4096xf32>, %arg72: tensor<4096x4096xf32>, %arg73: tensor<4096xf32>, %arg74: tensor<4096x11008xf32>, %arg75: tensor<11008x4096xf32>, %arg76: tensor<4096xf32>, %arg77: tensor<4096x4096xf32>, %arg78: tensor<4096x4096xf32>, %arg79: tensor<4096xf32>, %arg80: tensor<4096x11008xf32>, %arg81: tensor<11008x4096xf32>, %arg82: tensor<4096xf32>, %arg83: tensor<4096x4096xf32>, %arg84: tensor<4096x4096xf32>, %arg85: tensor<4096xf32>, %arg86: tensor<4096x11008xf32>, %arg87: tensor<11008x4096xf32>, %arg88: tensor<4096xf32>, %arg89: tensor<4096x4096xf32>, %arg90: tensor<4096x4096xf32>, %arg91: tensor<4096xf32>, %arg92: tensor<4096x11008xf32>, %arg93: tensor<11008x4096xf32>, %arg94: tensor<4096xf32>, %arg95: tensor<4096x4096xf32>, %arg96: tensor<4096x4096xf32>, %arg97: tensor<4096xf32>, %arg98: tensor<4096x11008xf32>, %arg99: tensor<11008x4096xf32>, %arg100: tensor<4096xf32>, %arg101: tensor<4096x4096xf32>, %arg102: tensor<4096x4096xf32>, %arg103: tensor<4096xf32>, %arg104: tensor<4096x11008xf32>, %arg105: tensor<11008x4096xf32>, %arg106: tensor<4096xf32>, %arg107: tensor<4096x4096xf32>, %arg108: tensor<4096x4096xf32>, %arg109: tensor<4096xf32>, %arg110: tensor<4096x11008xf32>, %arg111: tensor<11008x4096xf32>, %arg112: tensor<4096xf32>, %arg113: tensor<4096x4096xf32>, %arg114: tensor<4096x4096xf32>, %arg115: tensor<4096xf32>, %arg116: tensor<4096x11008xf32>, %arg117: tensor<11008x4096xf32>, %arg118: tensor<4096xf32>, %arg119: tensor<4096x4096xf32>, %arg120: tensor<4096x4096xf32>, %arg121: tensor<4096xf32>, %arg122: tensor<4096x11008xf32>, %arg123: tensor<11008x4096xf32>, %arg124: tensor<4096xf32>, %arg125: tensor<4096x4096xf32>, %arg126: tensor<4096x4096xf32>, %arg127: tensor<4096xf32>, %arg128: tensor<4096x11008xf32>, %arg129: tensor<11008x4096xf32>, %arg130: tensor<4096xf32>, %arg131: tensor<4096x4096xf32>, %arg132: tensor<4096x4096xf32>, %arg133: 
tensor<4096xf32>, %arg134: tensor<4096x11008xf32>, %arg135: tensor<11008x4096xf32>, %arg136: tensor<4096xf32>, %arg137: tensor<4096x4096xf32>, %arg138: tensor<4096x4096xf32>, %arg139: tensor<4096xf32>, %arg140: tensor<4096x11008xf32>, %arg141: tensor<11008x4096xf32>, %arg142: tensor<4096xf32>, %arg143: tensor<4096x4096xf32>, %arg144: tensor<4096x4096xf32>, %arg145: tensor<4096xf32>, %arg146: tensor<4096x11008xf32>, %arg147: tensor<11008x4096xf32>, %arg148: tensor<4096xf32>, %arg149: tensor<4096x4096xf32>, %arg150: tensor<4096x4096xf32>, %arg151: tensor<4096xf32>, %arg152: tensor<4096x11008xf32>, %arg153: tensor<11008x4096xf32>, %arg154: tensor<4096xf32>, %arg155: tensor<4096x4096xf32>, %arg156: tensor<4096x4096xf32>, %arg157: tensor<4096xf32>, %arg158: tensor<4096x11008xf32>, %arg159: tensor<11008x4096xf32>, %arg160: tensor<4096xf32>, %arg161: tensor<4096x4096xf32>, %arg162: tensor<4096x4096xf32>, %arg163: tensor<4096xf32>, %arg164: tensor<4096x11008xf32>, %arg165: tensor<11008x4096xf32>, %arg166: tensor<4096xf32>, %arg167: tensor<4096x4096xf32>, %arg168: tensor<4096x4096xf32>, %arg169: tensor<4096xf32>, %arg170: tensor<4096x11008xf32>, %arg171: tensor<11008x4096xf32>, %arg172: tensor<4096xf32>, %arg173: tensor<4096x4096xf32>, %arg174: tensor<4096x4096xf32>, %arg175: tensor<4096xf32>, %arg176: tensor<4096x11008xf32>, %arg177: tensor<11008x4096xf32>, %arg178: tensor<4096xf32>, %arg179: tensor<4096x4096xf32>, %arg180: tensor<4096x4096xf32>, %arg181: tensor<4096xf32>, %arg182: tensor<4096x11008xf32>, %arg183: tensor<11008x4096xf32>, %arg184: tensor<4096xf32>, %arg185: tensor<4096x4096xf32>, %arg186: tensor<4096x4096xf32>, %arg187: tensor<4096xf32>, %arg188: tensor<4096x11008xf32>, %arg189: tensor<11008x4096xf32>, %arg190: tensor<4096xf32>, %arg191: tensor<4096x4096xf32>, %arg192: tensor<4096x4096xf32>, %arg193: tensor<4096xf32>, %arg194: tensor<8x100xi64>, %arg195: tensor<32000x4096xf32>, %arg196: tensor<100xi64>, %arg197: tensor<8x1024x32x128xf32>, %arg198: 
tensor<1x1x1024x1024xf32>, %arg199: tensor<2048x64xcomplex<f32>>, %arg200: tensor<4096x4096xf32>, %arg201: tensor<8x1024x32x128xf32>, %arg202: tensor<4096x4096xf32>, %arg203: tensor<11008x4096xf32>, %arg204: tensor<8x1024x32x128xf32>, %arg205: tensor<4096x4096xf32>, %arg206: tensor<8x1024x32x128xf32>, %arg207: tensor<4096x4096xf32>, %arg208: tensor<11008x4096xf32>, %arg209: tensor<8x1024x32x128xf32>, %arg210: tensor<4096x4096xf32>, %arg211: tensor<8x1024x32x128xf32>, %arg212: tensor<4096x4096xf32>, %arg213: tensor<11008x4096xf32>, %arg214: tensor<8x1024x32x128xf32>, %arg215: tensor<4096x4096xf32>, %arg216: tensor<8x1024x32x128xf32>, %arg217: tensor<4096x4096xf32>, %arg218: tensor<11008x4096xf32>, %arg219: tensor<8x1024x32x128xf32>, %arg220: tensor<4096x4096xf32>, %arg221: tensor<8x1024x32x128xf32>, %arg222: tensor<4096x4096xf32>, %arg223: tensor<11008x4096xf32>, %arg224: tensor<8x1024x32x128xf32>, %arg225: tensor<4096x4096xf32>, %arg226: tensor<8x1024x32x128xf32>, %arg227: tensor<4096x4096xf32>, %arg228: tensor<11008x4096xf32>, %arg229: tensor<8x1024x32x128xf32>, %arg230: tensor<4096x4096xf32>, %arg231: tensor<8x1024x32x128xf32>, %arg232: tensor<4096x4096xf32>, %arg233: tensor<11008x4096xf32>, %arg234: tensor<8x1024x32x128xf32>, %arg235: tensor<4096x4096xf32>, %arg236: tensor<8x1024x32x128xf32>, %arg237: tensor<4096x4096xf32>, %arg238: tensor<11008x4096xf32>, %arg239: tensor<8x1024x32x128xf32>, %arg240: tensor<4096x4096xf32>, %arg241: tensor<8x1024x32x128xf32>, %arg242: tensor<4096x4096xf32>, %arg243: tensor<11008x4096xf32>, %arg244: tensor<8x1024x32x128xf32>, %arg245: tensor<4096x4096xf32>, %arg246: tensor<8x1024x32x128xf32>, %arg247: tensor<4096x4096xf32>, %arg248: tensor<11008x4096xf32>, %arg249: tensor<8x1024x32x128xf32>, %arg250: tensor<4096x4096xf32>, %arg251: tensor<8x1024x32x128xf32>, %arg252: tensor<4096x4096xf32>, %arg253: tensor<11008x4096xf32>, %arg254: tensor<8x1024x32x128xf32>, %arg255: tensor<4096x4096xf32>, %arg256: tensor<8x1024x32x128xf32>, 
%arg257: tensor<4096x4096xf32>, %arg258: tensor<11008x4096xf32>, %arg259: tensor<8x1024x32x128xf32>, %arg260: tensor<4096x4096xf32>, %arg261: tensor<8x1024x32x128xf32>, %arg262: tensor<4096x4096xf32>, %arg263: tensor<11008x4096xf32>, %arg264: tensor<8x1024x32x128xf32>, %arg265: tensor<4096x4096xf32>, %arg266: tensor<8x1024x32x128xf32>, %arg267: tensor<4096x4096xf32>, %arg268: tensor<11008x4096xf32>, %arg269: tensor<8x1024x32x128xf32>, %arg270: tensor<4096x4096xf32>, %arg271: tensor<8x1024x32x128xf32>, %arg272: tensor<4096x4096xf32>, %arg273: tensor<11008x4096xf32>, %arg274: tensor<8x1024x32x128xf32>, %arg275: tensor<4096x4096xf32>, %arg276: tensor<8x1024x32x128xf32>, %arg277: tensor<4096x4096xf32>, %arg278: tensor<11008x4096xf32>, %arg279: tensor<8x1024x32x128xf32>, %arg280: tensor<4096x4096xf32>, %arg281: tensor<8x1024x32x128xf32>, %arg282: tensor<4096x4096xf32>, %arg283: tensor<11008x4096xf32>, %arg284: tensor<8x1024x32x128xf32>, %arg285: tensor<4096x4096xf32>, %arg286: tensor<8x1024x32x128xf32>, %arg287: tensor<4096x4096xf32>, %arg288: tensor<11008x4096xf32>, %arg289: tensor<8x1024x32x128xf32>, %arg290: tensor<4096x4096xf32>, %arg291: tensor<8x1024x32x128xf32>, %arg292: tensor<4096x4096xf32>, %arg293: tensor<11008x4096xf32>, %arg294: tensor<8x1024x32x128xf32>, %arg295: tensor<4096x4096xf32>, %arg296: tensor<8x1024x32x128xf32>, %arg297: tensor<4096x4096xf32>, %arg298: tensor<11008x4096xf32>, %arg299: tensor<8x1024x32x128xf32>, %arg300: tensor<4096x4096xf32>, %arg301: tensor<8x1024x32x128xf32>, %arg302: tensor<4096x4096xf32>, %arg303: tensor<11008x4096xf32>, %arg304: tensor<8x1024x32x128xf32>, %arg305: tensor<4096x4096xf32>, %arg306: tensor<8x1024x32x128xf32>, %arg307: tensor<4096x4096xf32>, %arg308: tensor<11008x4096xf32>, %arg309: tensor<8x1024x32x128xf32>, %arg310: tensor<4096x4096xf32>, %arg311: tensor<8x1024x32x128xf32>, %arg312: tensor<4096x4096xf32>, %arg313: tensor<11008x4096xf32>, %arg314: tensor<8x1024x32x128xf32>, %arg315: tensor<4096x4096xf32>, 
%arg316: tensor<8x1024x32x128xf32>, %arg317: tensor<4096x4096xf32>, %arg318: tensor<11008x4096xf32>, %arg319: tensor<8x1024x32x128xf32>, %arg320: tensor<4096x4096xf32>, %arg321: tensor<8x1024x32x128xf32>, %arg322: tensor<4096x4096xf32>, %arg323: tensor<11008x4096xf32>, %arg324: tensor<8x1024x32x128xf32>, %arg325: tensor<4096x4096xf32>, %arg326: tensor<8x1024x32x128xf32>, %arg327: tensor<4096x4096xf32>, %arg328: tensor<11008x4096xf32>, %arg329: tensor<8x1024x32x128xf32>, %arg330: tensor<4096x4096xf32>, %arg331: tensor<8x1024x32x128xf32>, %arg332: tensor<4096x4096xf32>, %arg333: tensor<11008x4096xf32>, %arg334: tensor<8x1024x32x128xf32>, %arg335: tensor<4096x4096xf32>, %arg336: tensor<8x1024x32x128xf32>, %arg337: tensor<4096x4096xf32>, %arg338: tensor<11008x4096xf32>, %arg339: tensor<8x1024x32x128xf32>, %arg340: tensor<4096x4096xf32>, %arg341: tensor<8x1024x32x128xf32>, %arg342: tensor<4096x4096xf32>, %arg343: tensor<11008x4096xf32>, %arg344: tensor<8x1024x32x128xf32>, %arg345: tensor<4096x4096xf32>, %arg346: tensor<8x1024x32x128xf32>, %arg347: tensor<4096x4096xf32>, %arg348: tensor<11008x4096xf32>, %arg349: tensor<8x1024x32x128xf32>, %arg350: tensor<4096x4096xf32>, %arg351: tensor<8x1024x32x128xf32>, %arg352: tensor<4096x4096xf32>, %arg353: tensor<11008x4096xf32>, %arg354: tensor<8x1024x32x128xf32>, %arg355: tensor<4096x4096xf32>, %arg356: tensor<8x1024x32x128xf32>, %arg357: tensor<4096x4096xf32>, %arg358: tensor<11008x4096xf32>) -> tensor<8x100x32000xf32> { | |
| %cst = stablehlo.constant dense<11.3137083> : tensor<8x32x100x1024xf32> | |
| %c = stablehlo.constant dense<1024> : tensor<100xi64> | |
| %c_0 = stablehlo.constant dense<0> : tensor<100xi64> | |
| %cst_1 = stablehlo.constant dense<9.99999974E-6> : tensor<8x100x1xf32> | |
| %cst_2 = stablehlo.constant dense<2.44140625E-4> : tensor<8x100xf32> | |
| %cst_3 = stablehlo.constant dense<2.000000e+00> : tensor<8x100x4096xf32> | |
| %cst_4 = stablehlo.constant dense<0xFF800000> : tensor<f32> | |
| %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<f32> | |
| %0 = stablehlo.reshape %arg194 : (tensor<8x100xi64>) -> tensor<800xi64> | |
| %1 = stablehlo.convert %0 : (tensor<800xi64>) -> tensor<800xui32> | |
| %2 = "stablehlo.gather"(%arg195, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 4096>}> : (tensor<32000x4096xf32>, tensor<800xui32>) -> tensor<800x4096xf32> | |
| %3 = stablehlo.reshape %2 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %4 = stablehlo.power %3, %cst_3 : tensor<8x100x4096xf32> | |
| %5 = stablehlo.reduce(%4 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %6 = stablehlo.multiply %5, %cst_2 : tensor<8x100xf32> | |
| %7 = stablehlo.reshape %6 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %8 = stablehlo.add %7, %cst_1 : tensor<8x100x1xf32> | |
| %9 = stablehlo.rsqrt %8 : tensor<8x100x1xf32> | |
| %10 = stablehlo.reshape %9 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %11 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %12 = stablehlo.multiply %3, %11 : tensor<8x100x4096xf32> | |
| %13 = stablehlo.broadcast_in_dim %arg193, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %14 = stablehlo.multiply %12, %13 : tensor<8x100x4096xf32> | |
| %15 = stablehlo.reshape %14 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %16 = stablehlo.transpose %arg202, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %17 = stablehlo.dot_general %15, %16, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %18 = stablehlo.reshape %17 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %19 = stablehlo.transpose %18, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %20 = stablehlo.reshape %19 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %21 = stablehlo.slice %20 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %22 = stablehlo.reshape %21 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %23 = stablehlo.slice %20 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %24 = stablehlo.reshape %23 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %25 = stablehlo.complex %22, %24 : tensor<256x100x64xcomplex<f32>> | |
| %26 = stablehlo.convert %arg196 : (tensor<100xi64>) -> tensor<100xui32> | |
| %27 = "stablehlo.gather"(%arg199, %26) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 64>}> : (tensor<2048x64xcomplex<f32>>, tensor<100xui32>) -> tensor<100x64xcomplex<f32>> | |
| %28 = stablehlo.broadcast_in_dim %27, dims = [1, 2] : (tensor<100x64xcomplex<f32>>) -> tensor<256x100x64xcomplex<f32>> | |
| %29 = stablehlo.multiply %25, %28 : tensor<256x100x64xcomplex<f32>> | |
| %30 = stablehlo.real %29 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %31 = stablehlo.reshape %30 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %32 = stablehlo.imag %29 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %33 = stablehlo.reshape %32 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %34 = stablehlo.concatenate %31, %33, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %35 = stablehlo.reshape %34 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %36 = stablehlo.compare LT, %arg196, %c_0 : (tensor<100xi64>, tensor<100xi64>) -> tensor<100xi1> | |
| %37 = stablehlo.add %arg196, %c : tensor<100xi64> | |
| %38 = stablehlo.select %36, %37, %arg196 : tensor<100xi1>, tensor<100xi64> | |
| %39 = stablehlo.reshape %38 : (tensor<100xi64>) -> tensor<100x1xi64> | |
| %40 = stablehlo.transpose %arg200, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %41 = stablehlo.dot_general %15, %40, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %42 = stablehlo.reshape %41 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %43 = stablehlo.transpose %42, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %44 = stablehlo.reshape %43 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %45 = stablehlo.slice %44 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %46 = stablehlo.reshape %45 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %47 = stablehlo.slice %44 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %48 = stablehlo.reshape %47 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %49 = stablehlo.complex %46, %48 : tensor<256x100x64xcomplex<f32>> | |
| %50 = stablehlo.multiply %49, %28 : tensor<256x100x64xcomplex<f32>> | |
| %51 = stablehlo.real %50 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %52 = stablehlo.reshape %51 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %53 = stablehlo.imag %50 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %54 = stablehlo.reshape %53 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %55 = stablehlo.concatenate %52, %54, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %56 = stablehlo.reshape %55 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %57 = stablehlo.transpose %56, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %58 = "stablehlo.scatter"(%arg201, %39, %57) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %59 = stablehlo.transpose %58, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %60 = stablehlo.reshape %59 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %61 = stablehlo.dot_general %35, %60, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %62 = stablehlo.reshape %61 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %63 = stablehlo.divide %62, %cst : tensor<8x32x100x1024xf32> | |
| %64 = "stablehlo.gather"(%arg198, %26) <{dimension_numbers = #stablehlo.gather<offset_dims = [0, 1, 3], collapsed_slice_dims = [2], start_index_map = [2], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 1024>}> : (tensor<1x1x1024x1024xf32>, tensor<100xui32>) -> tensor<1x1x100x1024xf32> | |
| %65 = stablehlo.reshape %64 : (tensor<1x1x100x1024xf32>) -> tensor<100x1024xf32> | |
| %66 = stablehlo.broadcast_in_dim %65, dims = [2, 3] : (tensor<100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %67 = stablehlo.add %63, %66 : tensor<8x32x100x1024xf32> | |
| %68 = stablehlo.reduce(%67 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %69 = stablehlo.broadcast_in_dim %68, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %70 = stablehlo.subtract %67, %69 : tensor<8x32x100x1024xf32> | |
| %71 = stablehlo.exponential %70 : tensor<8x32x100x1024xf32> | |
| %72 = stablehlo.reduce(%71 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %73 = stablehlo.broadcast_in_dim %72, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %74 = stablehlo.divide %71, %73 : tensor<8x32x100x1024xf32> | |
| %75 = stablehlo.reshape %74 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %76 = stablehlo.transpose %arg192, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %77 = stablehlo.dot_general %15, %76, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %78 = stablehlo.reshape %77 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %79 = "stablehlo.scatter"(%arg197, %39, %78) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %80 = stablehlo.transpose %79, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %81 = stablehlo.reshape %80 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %82 = stablehlo.dot_general %75, %81, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %83 = stablehlo.reshape %82 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %84 = stablehlo.transpose %83, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %85 = stablehlo.reshape %84 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %86 = stablehlo.transpose %arg191, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %87 = stablehlo.dot_general %85, %86, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %88 = stablehlo.reshape %87 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %89 = stablehlo.add %3, %88 : tensor<8x100x4096xf32> | |
| %90 = stablehlo.power %89, %cst_3 : tensor<8x100x4096xf32> | |
| %91 = stablehlo.reduce(%90 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %92 = stablehlo.multiply %91, %cst_2 : tensor<8x100xf32> | |
| %93 = stablehlo.reshape %92 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %94 = stablehlo.add %93, %cst_1 : tensor<8x100x1xf32> | |
| %95 = stablehlo.rsqrt %94 : tensor<8x100x1xf32> | |
| %96 = stablehlo.reshape %95 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %97 = stablehlo.broadcast_in_dim %96, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %98 = stablehlo.multiply %89, %97 : tensor<8x100x4096xf32> | |
| %99 = stablehlo.broadcast_in_dim %arg190, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %100 = stablehlo.multiply %98, %99 : tensor<8x100x4096xf32> | |
| %101 = stablehlo.reshape %100 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %102 = stablehlo.transpose %arg203, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %103 = stablehlo.dot_general %101, %102, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %104 = stablehlo.reshape %103 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %105 = stablehlo.logistic %104 : tensor<8x100x11008xf32> | |
| %106 = stablehlo.multiply %104, %105 : tensor<8x100x11008xf32> | |
| %107 = stablehlo.transpose %arg189, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %108 = stablehlo.dot_general %101, %107, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %109 = stablehlo.reshape %108 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %110 = stablehlo.multiply %106, %109 : tensor<8x100x11008xf32> | |
| %111 = stablehlo.reshape %110 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %112 = stablehlo.transpose %arg188, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %113 = stablehlo.dot_general %111, %112, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %114 = stablehlo.reshape %113 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %115 = stablehlo.add %89, %114 : tensor<8x100x4096xf32> | |
| %116 = stablehlo.power %115, %cst_3 : tensor<8x100x4096xf32> | |
| %117 = stablehlo.reduce(%116 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %118 = stablehlo.multiply %117, %cst_2 : tensor<8x100xf32> | |
| %119 = stablehlo.reshape %118 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %120 = stablehlo.add %119, %cst_1 : tensor<8x100x1xf32> | |
| %121 = stablehlo.rsqrt %120 : tensor<8x100x1xf32> | |
| %122 = stablehlo.reshape %121 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %123 = stablehlo.broadcast_in_dim %122, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %124 = stablehlo.multiply %115, %123 : tensor<8x100x4096xf32> | |
| %125 = stablehlo.broadcast_in_dim %arg187, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %126 = stablehlo.multiply %124, %125 : tensor<8x100x4096xf32> | |
| %127 = stablehlo.reshape %126 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %128 = stablehlo.transpose %arg207, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %129 = stablehlo.dot_general %127, %128, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %130 = stablehlo.reshape %129 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %131 = stablehlo.transpose %130, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %132 = stablehlo.reshape %131 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %133 = stablehlo.slice %132 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %134 = stablehlo.reshape %133 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %135 = stablehlo.slice %132 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %136 = stablehlo.reshape %135 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %137 = stablehlo.complex %134, %136 : tensor<256x100x64xcomplex<f32>> | |
| %138 = stablehlo.multiply %137, %28 : tensor<256x100x64xcomplex<f32>> | |
| %139 = stablehlo.real %138 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %140 = stablehlo.reshape %139 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %141 = stablehlo.imag %138 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %142 = stablehlo.reshape %141 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %143 = stablehlo.concatenate %140, %142, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %144 = stablehlo.reshape %143 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %145 = stablehlo.transpose %arg205, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %146 = stablehlo.dot_general %127, %145, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %147 = stablehlo.reshape %146 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %148 = stablehlo.transpose %147, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %149 = stablehlo.reshape %148 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %150 = stablehlo.slice %149 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %151 = stablehlo.reshape %150 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %152 = stablehlo.slice %149 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %153 = stablehlo.reshape %152 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %154 = stablehlo.complex %151, %153 : tensor<256x100x64xcomplex<f32>> | |
| %155 = stablehlo.multiply %154, %28 : tensor<256x100x64xcomplex<f32>> | |
| %156 = stablehlo.real %155 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %157 = stablehlo.reshape %156 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %158 = stablehlo.imag %155 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %159 = stablehlo.reshape %158 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %160 = stablehlo.concatenate %157, %159, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %161 = stablehlo.reshape %160 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %162 = stablehlo.transpose %161, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %163 = "stablehlo.scatter"(%arg206, %39, %162) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %164 = stablehlo.transpose %163, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %165 = stablehlo.reshape %164 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %166 = stablehlo.dot_general %144, %165, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %167 = stablehlo.reshape %166 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %168 = stablehlo.divide %167, %cst : tensor<8x32x100x1024xf32> | |
| %169 = stablehlo.add %168, %66 : tensor<8x32x100x1024xf32> | |
| %170 = stablehlo.reduce(%169 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %171 = stablehlo.broadcast_in_dim %170, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %172 = stablehlo.subtract %169, %171 : tensor<8x32x100x1024xf32> | |
| %173 = stablehlo.exponential %172 : tensor<8x32x100x1024xf32> | |
| %174 = stablehlo.reduce(%173 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %175 = stablehlo.broadcast_in_dim %174, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %176 = stablehlo.divide %173, %175 : tensor<8x32x100x1024xf32> | |
| %177 = stablehlo.reshape %176 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %178 = stablehlo.transpose %arg186, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %179 = stablehlo.dot_general %127, %178, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %180 = stablehlo.reshape %179 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %181 = "stablehlo.scatter"(%arg204, %39, %180) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %182 = stablehlo.transpose %181, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %183 = stablehlo.reshape %182 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %184 = stablehlo.dot_general %177, %183, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %185 = stablehlo.reshape %184 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %186 = stablehlo.transpose %185, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %187 = stablehlo.reshape %186 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %188 = stablehlo.transpose %arg185, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %189 = stablehlo.dot_general %187, %188, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %190 = stablehlo.reshape %189 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %191 = stablehlo.add %115, %190 : tensor<8x100x4096xf32> | |
| %192 = stablehlo.power %191, %cst_3 : tensor<8x100x4096xf32> | |
| %193 = stablehlo.reduce(%192 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %194 = stablehlo.multiply %193, %cst_2 : tensor<8x100xf32> | |
| %195 = stablehlo.reshape %194 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %196 = stablehlo.add %195, %cst_1 : tensor<8x100x1xf32> | |
| %197 = stablehlo.rsqrt %196 : tensor<8x100x1xf32> | |
| %198 = stablehlo.reshape %197 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %199 = stablehlo.broadcast_in_dim %198, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %200 = stablehlo.multiply %191, %199 : tensor<8x100x4096xf32> | |
| %201 = stablehlo.broadcast_in_dim %arg184, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %202 = stablehlo.multiply %200, %201 : tensor<8x100x4096xf32> | |
| %203 = stablehlo.reshape %202 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %204 = stablehlo.transpose %arg208, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %205 = stablehlo.dot_general %203, %204, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %206 = stablehlo.reshape %205 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %207 = stablehlo.logistic %206 : tensor<8x100x11008xf32> | |
| %208 = stablehlo.multiply %206, %207 : tensor<8x100x11008xf32> | |
| %209 = stablehlo.transpose %arg183, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %210 = stablehlo.dot_general %203, %209, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %211 = stablehlo.reshape %210 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %212 = stablehlo.multiply %208, %211 : tensor<8x100x11008xf32> | |
| %213 = stablehlo.reshape %212 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %214 = stablehlo.transpose %arg182, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %215 = stablehlo.dot_general %213, %214, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %216 = stablehlo.reshape %215 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %217 = stablehlo.add %191, %216 : tensor<8x100x4096xf32> | |
| %218 = stablehlo.power %217, %cst_3 : tensor<8x100x4096xf32> | |
| %219 = stablehlo.reduce(%218 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %220 = stablehlo.multiply %219, %cst_2 : tensor<8x100xf32> | |
| %221 = stablehlo.reshape %220 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %222 = stablehlo.add %221, %cst_1 : tensor<8x100x1xf32> | |
| %223 = stablehlo.rsqrt %222 : tensor<8x100x1xf32> | |
| %224 = stablehlo.reshape %223 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %225 = stablehlo.broadcast_in_dim %224, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %226 = stablehlo.multiply %217, %225 : tensor<8x100x4096xf32> | |
| %227 = stablehlo.broadcast_in_dim %arg181, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %228 = stablehlo.multiply %226, %227 : tensor<8x100x4096xf32> | |
| %229 = stablehlo.reshape %228 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %230 = stablehlo.transpose %arg212, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %231 = stablehlo.dot_general %229, %230, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %232 = stablehlo.reshape %231 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %233 = stablehlo.transpose %232, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %234 = stablehlo.reshape %233 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %235 = stablehlo.slice %234 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %236 = stablehlo.reshape %235 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %237 = stablehlo.slice %234 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %238 = stablehlo.reshape %237 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %239 = stablehlo.complex %236, %238 : tensor<256x100x64xcomplex<f32>> | |
| %240 = stablehlo.multiply %239, %28 : tensor<256x100x64xcomplex<f32>> | |
| %241 = stablehlo.real %240 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %242 = stablehlo.reshape %241 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %243 = stablehlo.imag %240 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %244 = stablehlo.reshape %243 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %245 = stablehlo.concatenate %242, %244, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %246 = stablehlo.reshape %245 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %247 = stablehlo.transpose %arg210, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %248 = stablehlo.dot_general %229, %247, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %249 = stablehlo.reshape %248 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %250 = stablehlo.transpose %249, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %251 = stablehlo.reshape %250 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %252 = stablehlo.slice %251 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %253 = stablehlo.reshape %252 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %254 = stablehlo.slice %251 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %255 = stablehlo.reshape %254 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %256 = stablehlo.complex %253, %255 : tensor<256x100x64xcomplex<f32>> | |
| %257 = stablehlo.multiply %256, %28 : tensor<256x100x64xcomplex<f32>> | |
| %258 = stablehlo.real %257 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %259 = stablehlo.reshape %258 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %260 = stablehlo.imag %257 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %261 = stablehlo.reshape %260 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %262 = stablehlo.concatenate %259, %261, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %263 = stablehlo.reshape %262 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %264 = stablehlo.transpose %263, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %265 = "stablehlo.scatter"(%arg211, %39, %264) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %266 = stablehlo.transpose %265, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %267 = stablehlo.reshape %266 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %268 = stablehlo.dot_general %246, %267, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %269 = stablehlo.reshape %268 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %270 = stablehlo.divide %269, %cst : tensor<8x32x100x1024xf32> | |
| %271 = stablehlo.add %270, %66 : tensor<8x32x100x1024xf32> | |
| %272 = stablehlo.reduce(%271 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %273 = stablehlo.broadcast_in_dim %272, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %274 = stablehlo.subtract %271, %273 : tensor<8x32x100x1024xf32> | |
| %275 = stablehlo.exponential %274 : tensor<8x32x100x1024xf32> | |
| %276 = stablehlo.reduce(%275 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %277 = stablehlo.broadcast_in_dim %276, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %278 = stablehlo.divide %275, %277 : tensor<8x32x100x1024xf32> | |
| %279 = stablehlo.reshape %278 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %280 = stablehlo.transpose %arg180, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %281 = stablehlo.dot_general %229, %280, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %282 = stablehlo.reshape %281 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %283 = "stablehlo.scatter"(%arg209, %39, %282) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %284 = stablehlo.transpose %283, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %285 = stablehlo.reshape %284 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %286 = stablehlo.dot_general %279, %285, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %287 = stablehlo.reshape %286 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %288 = stablehlo.transpose %287, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %289 = stablehlo.reshape %288 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %290 = stablehlo.transpose %arg179, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %291 = stablehlo.dot_general %289, %290, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %292 = stablehlo.reshape %291 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %293 = stablehlo.add %217, %292 : tensor<8x100x4096xf32> | |
| %294 = stablehlo.power %293, %cst_3 : tensor<8x100x4096xf32> | |
| %295 = stablehlo.reduce(%294 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %296 = stablehlo.multiply %295, %cst_2 : tensor<8x100xf32> | |
| %297 = stablehlo.reshape %296 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %298 = stablehlo.add %297, %cst_1 : tensor<8x100x1xf32> | |
| %299 = stablehlo.rsqrt %298 : tensor<8x100x1xf32> | |
| %300 = stablehlo.reshape %299 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %301 = stablehlo.broadcast_in_dim %300, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %302 = stablehlo.multiply %293, %301 : tensor<8x100x4096xf32> | |
| %303 = stablehlo.broadcast_in_dim %arg178, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %304 = stablehlo.multiply %302, %303 : tensor<8x100x4096xf32> | |
| %305 = stablehlo.reshape %304 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %306 = stablehlo.transpose %arg213, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %307 = stablehlo.dot_general %305, %306, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %308 = stablehlo.reshape %307 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %309 = stablehlo.logistic %308 : tensor<8x100x11008xf32> | |
| %310 = stablehlo.multiply %308, %309 : tensor<8x100x11008xf32> | |
| %311 = stablehlo.transpose %arg177, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %312 = stablehlo.dot_general %305, %311, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %313 = stablehlo.reshape %312 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %314 = stablehlo.multiply %310, %313 : tensor<8x100x11008xf32> | |
| %315 = stablehlo.reshape %314 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %316 = stablehlo.transpose %arg176, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %317 = stablehlo.dot_general %315, %316, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %318 = stablehlo.reshape %317 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %319 = stablehlo.add %293, %318 : tensor<8x100x4096xf32> | |
| %320 = stablehlo.power %319, %cst_3 : tensor<8x100x4096xf32> | |
| %321 = stablehlo.reduce(%320 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %322 = stablehlo.multiply %321, %cst_2 : tensor<8x100xf32> | |
| %323 = stablehlo.reshape %322 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %324 = stablehlo.add %323, %cst_1 : tensor<8x100x1xf32> | |
| %325 = stablehlo.rsqrt %324 : tensor<8x100x1xf32> | |
| %326 = stablehlo.reshape %325 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %327 = stablehlo.broadcast_in_dim %326, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %328 = stablehlo.multiply %319, %327 : tensor<8x100x4096xf32> | |
| %329 = stablehlo.broadcast_in_dim %arg175, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %330 = stablehlo.multiply %328, %329 : tensor<8x100x4096xf32> | |
| %331 = stablehlo.reshape %330 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %332 = stablehlo.transpose %arg217, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %333 = stablehlo.dot_general %331, %332, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %334 = stablehlo.reshape %333 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %335 = stablehlo.transpose %334, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %336 = stablehlo.reshape %335 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %337 = stablehlo.slice %336 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %338 = stablehlo.reshape %337 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %339 = stablehlo.slice %336 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %340 = stablehlo.reshape %339 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %341 = stablehlo.complex %338, %340 : tensor<256x100x64xcomplex<f32>> | |
| %342 = stablehlo.multiply %341, %28 : tensor<256x100x64xcomplex<f32>> | |
| %343 = stablehlo.real %342 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %344 = stablehlo.reshape %343 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %345 = stablehlo.imag %342 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %346 = stablehlo.reshape %345 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %347 = stablehlo.concatenate %344, %346, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %348 = stablehlo.reshape %347 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %349 = stablehlo.transpose %arg215, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %350 = stablehlo.dot_general %331, %349, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %351 = stablehlo.reshape %350 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %352 = stablehlo.transpose %351, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %353 = stablehlo.reshape %352 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %354 = stablehlo.slice %353 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %355 = stablehlo.reshape %354 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %356 = stablehlo.slice %353 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %357 = stablehlo.reshape %356 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %358 = stablehlo.complex %355, %357 : tensor<256x100x64xcomplex<f32>> | |
| %359 = stablehlo.multiply %358, %28 : tensor<256x100x64xcomplex<f32>> | |
| %360 = stablehlo.real %359 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %361 = stablehlo.reshape %360 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %362 = stablehlo.imag %359 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %363 = stablehlo.reshape %362 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %364 = stablehlo.concatenate %361, %363, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %365 = stablehlo.reshape %364 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %366 = stablehlo.transpose %365, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %367 = "stablehlo.scatter"(%arg216, %39, %366) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %368 = stablehlo.transpose %367, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %369 = stablehlo.reshape %368 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %370 = stablehlo.dot_general %348, %369, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %371 = stablehlo.reshape %370 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %372 = stablehlo.divide %371, %cst : tensor<8x32x100x1024xf32> | |
| %373 = stablehlo.add %372, %66 : tensor<8x32x100x1024xf32> | |
| %374 = stablehlo.reduce(%373 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %375 = stablehlo.broadcast_in_dim %374, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %376 = stablehlo.subtract %373, %375 : tensor<8x32x100x1024xf32> | |
| %377 = stablehlo.exponential %376 : tensor<8x32x100x1024xf32> | |
| %378 = stablehlo.reduce(%377 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %379 = stablehlo.broadcast_in_dim %378, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %380 = stablehlo.divide %377, %379 : tensor<8x32x100x1024xf32> | |
| %381 = stablehlo.reshape %380 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %382 = stablehlo.transpose %arg174, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %383 = stablehlo.dot_general %331, %382, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %384 = stablehlo.reshape %383 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %385 = "stablehlo.scatter"(%arg214, %39, %384) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %386 = stablehlo.transpose %385, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %387 = stablehlo.reshape %386 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %388 = stablehlo.dot_general %381, %387, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %389 = stablehlo.reshape %388 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %390 = stablehlo.transpose %389, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %391 = stablehlo.reshape %390 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %392 = stablehlo.transpose %arg173, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %393 = stablehlo.dot_general %391, %392, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %394 = stablehlo.reshape %393 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %395 = stablehlo.add %319, %394 : tensor<8x100x4096xf32> | |
| %396 = stablehlo.power %395, %cst_3 : tensor<8x100x4096xf32> | |
| %397 = stablehlo.reduce(%396 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %398 = stablehlo.multiply %397, %cst_2 : tensor<8x100xf32> | |
| %399 = stablehlo.reshape %398 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %400 = stablehlo.add %399, %cst_1 : tensor<8x100x1xf32> | |
| %401 = stablehlo.rsqrt %400 : tensor<8x100x1xf32> | |
| %402 = stablehlo.reshape %401 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %403 = stablehlo.broadcast_in_dim %402, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %404 = stablehlo.multiply %395, %403 : tensor<8x100x4096xf32> | |
| %405 = stablehlo.broadcast_in_dim %arg172, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %406 = stablehlo.multiply %404, %405 : tensor<8x100x4096xf32> | |
| %407 = stablehlo.reshape %406 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %408 = stablehlo.transpose %arg218, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %409 = stablehlo.dot_general %407, %408, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %410 = stablehlo.reshape %409 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %411 = stablehlo.logistic %410 : tensor<8x100x11008xf32> | |
| %412 = stablehlo.multiply %410, %411 : tensor<8x100x11008xf32> | |
| %413 = stablehlo.transpose %arg171, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %414 = stablehlo.dot_general %407, %413, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %415 = stablehlo.reshape %414 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %416 = stablehlo.multiply %412, %415 : tensor<8x100x11008xf32> | |
| %417 = stablehlo.reshape %416 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %418 = stablehlo.transpose %arg170, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %419 = stablehlo.dot_general %417, %418, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %420 = stablehlo.reshape %419 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %421 = stablehlo.add %395, %420 : tensor<8x100x4096xf32> | |
| %422 = stablehlo.power %421, %cst_3 : tensor<8x100x4096xf32> | |
| %423 = stablehlo.reduce(%422 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %424 = stablehlo.multiply %423, %cst_2 : tensor<8x100xf32> | |
| %425 = stablehlo.reshape %424 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %426 = stablehlo.add %425, %cst_1 : tensor<8x100x1xf32> | |
| %427 = stablehlo.rsqrt %426 : tensor<8x100x1xf32> | |
| %428 = stablehlo.reshape %427 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %429 = stablehlo.broadcast_in_dim %428, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %430 = stablehlo.multiply %421, %429 : tensor<8x100x4096xf32> | |
| %431 = stablehlo.broadcast_in_dim %arg169, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %432 = stablehlo.multiply %430, %431 : tensor<8x100x4096xf32> | |
| %433 = stablehlo.reshape %432 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %434 = stablehlo.transpose %arg222, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %435 = stablehlo.dot_general %433, %434, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %436 = stablehlo.reshape %435 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %437 = stablehlo.transpose %436, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %438 = stablehlo.reshape %437 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %439 = stablehlo.slice %438 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %440 = stablehlo.reshape %439 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %441 = stablehlo.slice %438 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %442 = stablehlo.reshape %441 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %443 = stablehlo.complex %440, %442 : tensor<256x100x64xcomplex<f32>> | |
| %444 = stablehlo.multiply %443, %28 : tensor<256x100x64xcomplex<f32>> | |
| %445 = stablehlo.real %444 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %446 = stablehlo.reshape %445 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %447 = stablehlo.imag %444 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %448 = stablehlo.reshape %447 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %449 = stablehlo.concatenate %446, %448, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %450 = stablehlo.reshape %449 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %451 = stablehlo.transpose %arg220, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %452 = stablehlo.dot_general %433, %451, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %453 = stablehlo.reshape %452 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %454 = stablehlo.transpose %453, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %455 = stablehlo.reshape %454 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %456 = stablehlo.slice %455 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %457 = stablehlo.reshape %456 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %458 = stablehlo.slice %455 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %459 = stablehlo.reshape %458 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %460 = stablehlo.complex %457, %459 : tensor<256x100x64xcomplex<f32>> | |
| %461 = stablehlo.multiply %460, %28 : tensor<256x100x64xcomplex<f32>> | |
| %462 = stablehlo.real %461 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %463 = stablehlo.reshape %462 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %464 = stablehlo.imag %461 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %465 = stablehlo.reshape %464 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %466 = stablehlo.concatenate %463, %465, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %467 = stablehlo.reshape %466 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %468 = stablehlo.transpose %467, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %469 = "stablehlo.scatter"(%arg221, %39, %468) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %470 = stablehlo.transpose %469, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %471 = stablehlo.reshape %470 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %472 = stablehlo.dot_general %450, %471, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %473 = stablehlo.reshape %472 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %474 = stablehlo.divide %473, %cst : tensor<8x32x100x1024xf32> | |
| %475 = stablehlo.add %474, %66 : tensor<8x32x100x1024xf32> | |
| %476 = stablehlo.reduce(%475 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %477 = stablehlo.broadcast_in_dim %476, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %478 = stablehlo.subtract %475, %477 : tensor<8x32x100x1024xf32> | |
| %479 = stablehlo.exponential %478 : tensor<8x32x100x1024xf32> | |
| %480 = stablehlo.reduce(%479 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %481 = stablehlo.broadcast_in_dim %480, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %482 = stablehlo.divide %479, %481 : tensor<8x32x100x1024xf32> | |
| %483 = stablehlo.reshape %482 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %484 = stablehlo.transpose %arg168, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %485 = stablehlo.dot_general %433, %484, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %486 = stablehlo.reshape %485 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %487 = "stablehlo.scatter"(%arg219, %39, %486) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %488 = stablehlo.transpose %487, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %489 = stablehlo.reshape %488 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %490 = stablehlo.dot_general %483, %489, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %491 = stablehlo.reshape %490 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %492 = stablehlo.transpose %491, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %493 = stablehlo.reshape %492 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %494 = stablehlo.transpose %arg167, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %495 = stablehlo.dot_general %493, %494, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %496 = stablehlo.reshape %495 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %497 = stablehlo.add %421, %496 : tensor<8x100x4096xf32> | |
| %498 = stablehlo.power %497, %cst_3 : tensor<8x100x4096xf32> | |
| %499 = stablehlo.reduce(%498 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %500 = stablehlo.multiply %499, %cst_2 : tensor<8x100xf32> | |
| %501 = stablehlo.reshape %500 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %502 = stablehlo.add %501, %cst_1 : tensor<8x100x1xf32> | |
| %503 = stablehlo.rsqrt %502 : tensor<8x100x1xf32> | |
| %504 = stablehlo.reshape %503 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %505 = stablehlo.broadcast_in_dim %504, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %506 = stablehlo.multiply %497, %505 : tensor<8x100x4096xf32> | |
| %507 = stablehlo.broadcast_in_dim %arg166, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %508 = stablehlo.multiply %506, %507 : tensor<8x100x4096xf32> | |
| %509 = stablehlo.reshape %508 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %510 = stablehlo.transpose %arg223, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %511 = stablehlo.dot_general %509, %510, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %512 = stablehlo.reshape %511 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %513 = stablehlo.logistic %512 : tensor<8x100x11008xf32> | |
| %514 = stablehlo.multiply %512, %513 : tensor<8x100x11008xf32> | |
| %515 = stablehlo.transpose %arg165, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %516 = stablehlo.dot_general %509, %515, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %517 = stablehlo.reshape %516 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %518 = stablehlo.multiply %514, %517 : tensor<8x100x11008xf32> | |
| %519 = stablehlo.reshape %518 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %520 = stablehlo.transpose %arg164, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %521 = stablehlo.dot_general %519, %520, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %522 = stablehlo.reshape %521 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %523 = stablehlo.add %497, %522 : tensor<8x100x4096xf32> | |
| %524 = stablehlo.power %523, %cst_3 : tensor<8x100x4096xf32> | |
| %525 = stablehlo.reduce(%524 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %526 = stablehlo.multiply %525, %cst_2 : tensor<8x100xf32> | |
| %527 = stablehlo.reshape %526 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %528 = stablehlo.add %527, %cst_1 : tensor<8x100x1xf32> | |
| %529 = stablehlo.rsqrt %528 : tensor<8x100x1xf32> | |
| %530 = stablehlo.reshape %529 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %531 = stablehlo.broadcast_in_dim %530, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %532 = stablehlo.multiply %523, %531 : tensor<8x100x4096xf32> | |
| %533 = stablehlo.broadcast_in_dim %arg163, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %534 = stablehlo.multiply %532, %533 : tensor<8x100x4096xf32> | |
| %535 = stablehlo.reshape %534 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %536 = stablehlo.transpose %arg227, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %537 = stablehlo.dot_general %535, %536, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %538 = stablehlo.reshape %537 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %539 = stablehlo.transpose %538, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %540 = stablehlo.reshape %539 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %541 = stablehlo.slice %540 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %542 = stablehlo.reshape %541 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %543 = stablehlo.slice %540 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %544 = stablehlo.reshape %543 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %545 = stablehlo.complex %542, %544 : tensor<256x100x64xcomplex<f32>> | |
| %546 = stablehlo.multiply %545, %28 : tensor<256x100x64xcomplex<f32>> | |
| %547 = stablehlo.real %546 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %548 = stablehlo.reshape %547 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %549 = stablehlo.imag %546 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %550 = stablehlo.reshape %549 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %551 = stablehlo.concatenate %548, %550, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %552 = stablehlo.reshape %551 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %553 = stablehlo.transpose %arg225, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %554 = stablehlo.dot_general %535, %553, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %555 = stablehlo.reshape %554 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %556 = stablehlo.transpose %555, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %557 = stablehlo.reshape %556 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %558 = stablehlo.slice %557 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %559 = stablehlo.reshape %558 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %560 = stablehlo.slice %557 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %561 = stablehlo.reshape %560 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %562 = stablehlo.complex %559, %561 : tensor<256x100x64xcomplex<f32>> | |
| %563 = stablehlo.multiply %562, %28 : tensor<256x100x64xcomplex<f32>> | |
| %564 = stablehlo.real %563 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %565 = stablehlo.reshape %564 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %566 = stablehlo.imag %563 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %567 = stablehlo.reshape %566 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %568 = stablehlo.concatenate %565, %567, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %569 = stablehlo.reshape %568 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %570 = stablehlo.transpose %569, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %571 = "stablehlo.scatter"(%arg226, %39, %570) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %572 = stablehlo.transpose %571, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %573 = stablehlo.reshape %572 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %574 = stablehlo.dot_general %552, %573, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %575 = stablehlo.reshape %574 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %576 = stablehlo.divide %575, %cst : tensor<8x32x100x1024xf32> | |
| %577 = stablehlo.add %576, %66 : tensor<8x32x100x1024xf32> | |
| %578 = stablehlo.reduce(%577 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %579 = stablehlo.broadcast_in_dim %578, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %580 = stablehlo.subtract %577, %579 : tensor<8x32x100x1024xf32> | |
| %581 = stablehlo.exponential %580 : tensor<8x32x100x1024xf32> | |
| %582 = stablehlo.reduce(%581 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %583 = stablehlo.broadcast_in_dim %582, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %584 = stablehlo.divide %581, %583 : tensor<8x32x100x1024xf32> | |
| %585 = stablehlo.reshape %584 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %586 = stablehlo.transpose %arg162, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %587 = stablehlo.dot_general %535, %586, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %588 = stablehlo.reshape %587 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %589 = "stablehlo.scatter"(%arg224, %39, %588) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %590 = stablehlo.transpose %589, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %591 = stablehlo.reshape %590 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %592 = stablehlo.dot_general %585, %591, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %593 = stablehlo.reshape %592 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %594 = stablehlo.transpose %593, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %595 = stablehlo.reshape %594 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %596 = stablehlo.transpose %arg161, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %597 = stablehlo.dot_general %595, %596, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %598 = stablehlo.reshape %597 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %599 = stablehlo.add %523, %598 : tensor<8x100x4096xf32> | |
| %600 = stablehlo.power %599, %cst_3 : tensor<8x100x4096xf32> | |
| %601 = stablehlo.reduce(%600 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %602 = stablehlo.multiply %601, %cst_2 : tensor<8x100xf32> | |
| %603 = stablehlo.reshape %602 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %604 = stablehlo.add %603, %cst_1 : tensor<8x100x1xf32> | |
| %605 = stablehlo.rsqrt %604 : tensor<8x100x1xf32> | |
| %606 = stablehlo.reshape %605 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %607 = stablehlo.broadcast_in_dim %606, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %608 = stablehlo.multiply %599, %607 : tensor<8x100x4096xf32> | |
| %609 = stablehlo.broadcast_in_dim %arg160, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %610 = stablehlo.multiply %608, %609 : tensor<8x100x4096xf32> | |
| %611 = stablehlo.reshape %610 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %612 = stablehlo.transpose %arg228, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %613 = stablehlo.dot_general %611, %612, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %614 = stablehlo.reshape %613 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %615 = stablehlo.logistic %614 : tensor<8x100x11008xf32> | |
| %616 = stablehlo.multiply %614, %615 : tensor<8x100x11008xf32> | |
| %617 = stablehlo.transpose %arg159, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %618 = stablehlo.dot_general %611, %617, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %619 = stablehlo.reshape %618 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %620 = stablehlo.multiply %616, %619 : tensor<8x100x11008xf32> | |
| %621 = stablehlo.reshape %620 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %622 = stablehlo.transpose %arg158, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %623 = stablehlo.dot_general %621, %622, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %624 = stablehlo.reshape %623 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %625 = stablehlo.add %599, %624 : tensor<8x100x4096xf32> | |
| %626 = stablehlo.power %625, %cst_3 : tensor<8x100x4096xf32> | |
| %627 = stablehlo.reduce(%626 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %628 = stablehlo.multiply %627, %cst_2 : tensor<8x100xf32> | |
| %629 = stablehlo.reshape %628 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %630 = stablehlo.add %629, %cst_1 : tensor<8x100x1xf32> | |
| %631 = stablehlo.rsqrt %630 : tensor<8x100x1xf32> | |
| %632 = stablehlo.reshape %631 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %633 = stablehlo.broadcast_in_dim %632, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %634 = stablehlo.multiply %625, %633 : tensor<8x100x4096xf32> | |
| %635 = stablehlo.broadcast_in_dim %arg157, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %636 = stablehlo.multiply %634, %635 : tensor<8x100x4096xf32> | |
| %637 = stablehlo.reshape %636 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %638 = stablehlo.transpose %arg232, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %639 = stablehlo.dot_general %637, %638, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %640 = stablehlo.reshape %639 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %641 = stablehlo.transpose %640, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %642 = stablehlo.reshape %641 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %643 = stablehlo.slice %642 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %644 = stablehlo.reshape %643 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %645 = stablehlo.slice %642 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %646 = stablehlo.reshape %645 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %647 = stablehlo.complex %644, %646 : tensor<256x100x64xcomplex<f32>> | |
| %648 = stablehlo.multiply %647, %28 : tensor<256x100x64xcomplex<f32>> | |
| %649 = stablehlo.real %648 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %650 = stablehlo.reshape %649 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %651 = stablehlo.imag %648 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %652 = stablehlo.reshape %651 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %653 = stablehlo.concatenate %650, %652, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %654 = stablehlo.reshape %653 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %655 = stablehlo.transpose %arg230, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %656 = stablehlo.dot_general %637, %655, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %657 = stablehlo.reshape %656 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %658 = stablehlo.transpose %657, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %659 = stablehlo.reshape %658 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %660 = stablehlo.slice %659 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %661 = stablehlo.reshape %660 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %662 = stablehlo.slice %659 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %663 = stablehlo.reshape %662 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %664 = stablehlo.complex %661, %663 : tensor<256x100x64xcomplex<f32>> | |
| %665 = stablehlo.multiply %664, %28 : tensor<256x100x64xcomplex<f32>> | |
| %666 = stablehlo.real %665 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %667 = stablehlo.reshape %666 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %668 = stablehlo.imag %665 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %669 = stablehlo.reshape %668 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %670 = stablehlo.concatenate %667, %669, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %671 = stablehlo.reshape %670 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %672 = stablehlo.transpose %671, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %673 = "stablehlo.scatter"(%arg231, %39, %672) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %674 = stablehlo.transpose %673, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %675 = stablehlo.reshape %674 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %676 = stablehlo.dot_general %654, %675, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %677 = stablehlo.reshape %676 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %678 = stablehlo.divide %677, %cst : tensor<8x32x100x1024xf32> | |
| %679 = stablehlo.add %678, %66 : tensor<8x32x100x1024xf32> | |
| %680 = stablehlo.reduce(%679 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %681 = stablehlo.broadcast_in_dim %680, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %682 = stablehlo.subtract %679, %681 : tensor<8x32x100x1024xf32> | |
| %683 = stablehlo.exponential %682 : tensor<8x32x100x1024xf32> | |
| %684 = stablehlo.reduce(%683 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %685 = stablehlo.broadcast_in_dim %684, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %686 = stablehlo.divide %683, %685 : tensor<8x32x100x1024xf32> | |
| %687 = stablehlo.reshape %686 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %688 = stablehlo.transpose %arg156, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %689 = stablehlo.dot_general %637, %688, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %690 = stablehlo.reshape %689 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %691 = "stablehlo.scatter"(%arg229, %39, %690) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %692 = stablehlo.transpose %691, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %693 = stablehlo.reshape %692 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %694 = stablehlo.dot_general %687, %693, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %695 = stablehlo.reshape %694 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %696 = stablehlo.transpose %695, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %697 = stablehlo.reshape %696 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %698 = stablehlo.transpose %arg155, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %699 = stablehlo.dot_general %697, %698, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %700 = stablehlo.reshape %699 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %701 = stablehlo.add %625, %700 : tensor<8x100x4096xf32> | |
| %702 = stablehlo.power %701, %cst_3 : tensor<8x100x4096xf32> | |
| %703 = stablehlo.reduce(%702 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %704 = stablehlo.multiply %703, %cst_2 : tensor<8x100xf32> | |
| %705 = stablehlo.reshape %704 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %706 = stablehlo.add %705, %cst_1 : tensor<8x100x1xf32> | |
| %707 = stablehlo.rsqrt %706 : tensor<8x100x1xf32> | |
| %708 = stablehlo.reshape %707 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %709 = stablehlo.broadcast_in_dim %708, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %710 = stablehlo.multiply %701, %709 : tensor<8x100x4096xf32> | |
| %711 = stablehlo.broadcast_in_dim %arg154, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %712 = stablehlo.multiply %710, %711 : tensor<8x100x4096xf32> | |
| %713 = stablehlo.reshape %712 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %714 = stablehlo.transpose %arg233, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %715 = stablehlo.dot_general %713, %714, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %716 = stablehlo.reshape %715 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %717 = stablehlo.logistic %716 : tensor<8x100x11008xf32> | |
| %718 = stablehlo.multiply %716, %717 : tensor<8x100x11008xf32> | |
| %719 = stablehlo.transpose %arg153, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %720 = stablehlo.dot_general %713, %719, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %721 = stablehlo.reshape %720 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %722 = stablehlo.multiply %718, %721 : tensor<8x100x11008xf32> | |
| %723 = stablehlo.reshape %722 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %724 = stablehlo.transpose %arg152, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %725 = stablehlo.dot_general %723, %724, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %726 = stablehlo.reshape %725 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %727 = stablehlo.add %701, %726 : tensor<8x100x4096xf32> | |
| %728 = stablehlo.power %727, %cst_3 : tensor<8x100x4096xf32> | |
| %729 = stablehlo.reduce(%728 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %730 = stablehlo.multiply %729, %cst_2 : tensor<8x100xf32> | |
| %731 = stablehlo.reshape %730 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %732 = stablehlo.add %731, %cst_1 : tensor<8x100x1xf32> | |
| %733 = stablehlo.rsqrt %732 : tensor<8x100x1xf32> | |
| %734 = stablehlo.reshape %733 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %735 = stablehlo.broadcast_in_dim %734, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %736 = stablehlo.multiply %727, %735 : tensor<8x100x4096xf32> | |
| %737 = stablehlo.broadcast_in_dim %arg151, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %738 = stablehlo.multiply %736, %737 : tensor<8x100x4096xf32> | |
| %739 = stablehlo.reshape %738 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %740 = stablehlo.transpose %arg237, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %741 = stablehlo.dot_general %739, %740, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %742 = stablehlo.reshape %741 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %743 = stablehlo.transpose %742, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %744 = stablehlo.reshape %743 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %745 = stablehlo.slice %744 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %746 = stablehlo.reshape %745 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %747 = stablehlo.slice %744 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %748 = stablehlo.reshape %747 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %749 = stablehlo.complex %746, %748 : tensor<256x100x64xcomplex<f32>> | |
| %750 = stablehlo.multiply %749, %28 : tensor<256x100x64xcomplex<f32>> | |
| %751 = stablehlo.real %750 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %752 = stablehlo.reshape %751 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %753 = stablehlo.imag %750 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %754 = stablehlo.reshape %753 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %755 = stablehlo.concatenate %752, %754, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %756 = stablehlo.reshape %755 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %757 = stablehlo.transpose %arg235, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %758 = stablehlo.dot_general %739, %757, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %759 = stablehlo.reshape %758 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %760 = stablehlo.transpose %759, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %761 = stablehlo.reshape %760 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %762 = stablehlo.slice %761 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %763 = stablehlo.reshape %762 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %764 = stablehlo.slice %761 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %765 = stablehlo.reshape %764 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %766 = stablehlo.complex %763, %765 : tensor<256x100x64xcomplex<f32>> | |
| %767 = stablehlo.multiply %766, %28 : tensor<256x100x64xcomplex<f32>> | |
| %768 = stablehlo.real %767 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %769 = stablehlo.reshape %768 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %770 = stablehlo.imag %767 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %771 = stablehlo.reshape %770 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %772 = stablehlo.concatenate %769, %771, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %773 = stablehlo.reshape %772 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %774 = stablehlo.transpose %773, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %775 = "stablehlo.scatter"(%arg236, %39, %774) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %776 = stablehlo.transpose %775, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %777 = stablehlo.reshape %776 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %778 = stablehlo.dot_general %756, %777, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %779 = stablehlo.reshape %778 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %780 = stablehlo.divide %779, %cst : tensor<8x32x100x1024xf32> | |
| %781 = stablehlo.add %780, %66 : tensor<8x32x100x1024xf32> | |
| %782 = stablehlo.reduce(%781 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %783 = stablehlo.broadcast_in_dim %782, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %784 = stablehlo.subtract %781, %783 : tensor<8x32x100x1024xf32> | |
| %785 = stablehlo.exponential %784 : tensor<8x32x100x1024xf32> | |
| %786 = stablehlo.reduce(%785 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %787 = stablehlo.broadcast_in_dim %786, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %788 = stablehlo.divide %785, %787 : tensor<8x32x100x1024xf32> | |
| %789 = stablehlo.reshape %788 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %790 = stablehlo.transpose %arg150, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %791 = stablehlo.dot_general %739, %790, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %792 = stablehlo.reshape %791 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %793 = "stablehlo.scatter"(%arg234, %39, %792) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %794 = stablehlo.transpose %793, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %795 = stablehlo.reshape %794 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %796 = stablehlo.dot_general %789, %795, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %797 = stablehlo.reshape %796 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %798 = stablehlo.transpose %797, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %799 = stablehlo.reshape %798 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %800 = stablehlo.transpose %arg149, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %801 = stablehlo.dot_general %799, %800, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %802 = stablehlo.reshape %801 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %803 = stablehlo.add %727, %802 : tensor<8x100x4096xf32> | |
| %804 = stablehlo.power %803, %cst_3 : tensor<8x100x4096xf32> | |
| %805 = stablehlo.reduce(%804 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %806 = stablehlo.multiply %805, %cst_2 : tensor<8x100xf32> | |
| %807 = stablehlo.reshape %806 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %808 = stablehlo.add %807, %cst_1 : tensor<8x100x1xf32> | |
| %809 = stablehlo.rsqrt %808 : tensor<8x100x1xf32> | |
| %810 = stablehlo.reshape %809 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %811 = stablehlo.broadcast_in_dim %810, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %812 = stablehlo.multiply %803, %811 : tensor<8x100x4096xf32> | |
| %813 = stablehlo.broadcast_in_dim %arg148, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %814 = stablehlo.multiply %812, %813 : tensor<8x100x4096xf32> | |
| %815 = stablehlo.reshape %814 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %816 = stablehlo.transpose %arg238, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %817 = stablehlo.dot_general %815, %816, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %818 = stablehlo.reshape %817 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %819 = stablehlo.logistic %818 : tensor<8x100x11008xf32> | |
| %820 = stablehlo.multiply %818, %819 : tensor<8x100x11008xf32> | |
| %821 = stablehlo.transpose %arg147, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %822 = stablehlo.dot_general %815, %821, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %823 = stablehlo.reshape %822 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %824 = stablehlo.multiply %820, %823 : tensor<8x100x11008xf32> | |
| %825 = stablehlo.reshape %824 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %826 = stablehlo.transpose %arg146, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %827 = stablehlo.dot_general %825, %826, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %828 = stablehlo.reshape %827 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %829 = stablehlo.add %803, %828 : tensor<8x100x4096xf32> | |
| %830 = stablehlo.power %829, %cst_3 : tensor<8x100x4096xf32> | |
| %831 = stablehlo.reduce(%830 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %832 = stablehlo.multiply %831, %cst_2 : tensor<8x100xf32> | |
| %833 = stablehlo.reshape %832 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %834 = stablehlo.add %833, %cst_1 : tensor<8x100x1xf32> | |
| %835 = stablehlo.rsqrt %834 : tensor<8x100x1xf32> | |
| %836 = stablehlo.reshape %835 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %837 = stablehlo.broadcast_in_dim %836, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %838 = stablehlo.multiply %829, %837 : tensor<8x100x4096xf32> | |
| %839 = stablehlo.broadcast_in_dim %arg145, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %840 = stablehlo.multiply %838, %839 : tensor<8x100x4096xf32> | |
| %841 = stablehlo.reshape %840 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %842 = stablehlo.transpose %arg242, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %843 = stablehlo.dot_general %841, %842, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %844 = stablehlo.reshape %843 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %845 = stablehlo.transpose %844, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %846 = stablehlo.reshape %845 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %847 = stablehlo.slice %846 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %848 = stablehlo.reshape %847 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %849 = stablehlo.slice %846 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %850 = stablehlo.reshape %849 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %851 = stablehlo.complex %848, %850 : tensor<256x100x64xcomplex<f32>> | |
| %852 = stablehlo.multiply %851, %28 : tensor<256x100x64xcomplex<f32>> | |
| %853 = stablehlo.real %852 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %854 = stablehlo.reshape %853 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %855 = stablehlo.imag %852 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %856 = stablehlo.reshape %855 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %857 = stablehlo.concatenate %854, %856, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %858 = stablehlo.reshape %857 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %859 = stablehlo.transpose %arg240, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %860 = stablehlo.dot_general %841, %859, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %861 = stablehlo.reshape %860 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %862 = stablehlo.transpose %861, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %863 = stablehlo.reshape %862 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %864 = stablehlo.slice %863 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %865 = stablehlo.reshape %864 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %866 = stablehlo.slice %863 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %867 = stablehlo.reshape %866 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %868 = stablehlo.complex %865, %867 : tensor<256x100x64xcomplex<f32>> | |
| %869 = stablehlo.multiply %868, %28 : tensor<256x100x64xcomplex<f32>> | |
| %870 = stablehlo.real %869 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %871 = stablehlo.reshape %870 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %872 = stablehlo.imag %869 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %873 = stablehlo.reshape %872 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %874 = stablehlo.concatenate %871, %873, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %875 = stablehlo.reshape %874 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %876 = stablehlo.transpose %875, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %877 = "stablehlo.scatter"(%arg241, %39, %876) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %878 = stablehlo.transpose %877, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %879 = stablehlo.reshape %878 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %880 = stablehlo.dot_general %858, %879, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %881 = stablehlo.reshape %880 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %882 = stablehlo.divide %881, %cst : tensor<8x32x100x1024xf32> | |
| %883 = stablehlo.add %882, %66 : tensor<8x32x100x1024xf32> | |
| %884 = stablehlo.reduce(%883 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %885 = stablehlo.broadcast_in_dim %884, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %886 = stablehlo.subtract %883, %885 : tensor<8x32x100x1024xf32> | |
| %887 = stablehlo.exponential %886 : tensor<8x32x100x1024xf32> | |
| %888 = stablehlo.reduce(%887 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %889 = stablehlo.broadcast_in_dim %888, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %890 = stablehlo.divide %887, %889 : tensor<8x32x100x1024xf32> | |
| %891 = stablehlo.reshape %890 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %892 = stablehlo.transpose %arg144, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %893 = stablehlo.dot_general %841, %892, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %894 = stablehlo.reshape %893 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %895 = "stablehlo.scatter"(%arg239, %39, %894) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %896 = stablehlo.transpose %895, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %897 = stablehlo.reshape %896 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %898 = stablehlo.dot_general %891, %897, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %899 = stablehlo.reshape %898 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %900 = stablehlo.transpose %899, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %901 = stablehlo.reshape %900 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %902 = stablehlo.transpose %arg143, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %903 = stablehlo.dot_general %901, %902, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %904 = stablehlo.reshape %903 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %905 = stablehlo.add %829, %904 : tensor<8x100x4096xf32> | |
| %906 = stablehlo.power %905, %cst_3 : tensor<8x100x4096xf32> | |
| %907 = stablehlo.reduce(%906 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %908 = stablehlo.multiply %907, %cst_2 : tensor<8x100xf32> | |
| %909 = stablehlo.reshape %908 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %910 = stablehlo.add %909, %cst_1 : tensor<8x100x1xf32> | |
| %911 = stablehlo.rsqrt %910 : tensor<8x100x1xf32> | |
| %912 = stablehlo.reshape %911 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %913 = stablehlo.broadcast_in_dim %912, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %914 = stablehlo.multiply %905, %913 : tensor<8x100x4096xf32> | |
| %915 = stablehlo.broadcast_in_dim %arg142, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %916 = stablehlo.multiply %914, %915 : tensor<8x100x4096xf32> | |
| %917 = stablehlo.reshape %916 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %918 = stablehlo.transpose %arg243, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %919 = stablehlo.dot_general %917, %918, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %920 = stablehlo.reshape %919 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %921 = stablehlo.logistic %920 : tensor<8x100x11008xf32> | |
| %922 = stablehlo.multiply %920, %921 : tensor<8x100x11008xf32> | |
| %923 = stablehlo.transpose %arg141, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %924 = stablehlo.dot_general %917, %923, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %925 = stablehlo.reshape %924 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %926 = stablehlo.multiply %922, %925 : tensor<8x100x11008xf32> | |
| %927 = stablehlo.reshape %926 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %928 = stablehlo.transpose %arg140, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %929 = stablehlo.dot_general %927, %928, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %930 = stablehlo.reshape %929 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %931 = stablehlo.add %905, %930 : tensor<8x100x4096xf32> | |
| %932 = stablehlo.power %931, %cst_3 : tensor<8x100x4096xf32> | |
| %933 = stablehlo.reduce(%932 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %934 = stablehlo.multiply %933, %cst_2 : tensor<8x100xf32> | |
| %935 = stablehlo.reshape %934 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %936 = stablehlo.add %935, %cst_1 : tensor<8x100x1xf32> | |
| %937 = stablehlo.rsqrt %936 : tensor<8x100x1xf32> | |
| %938 = stablehlo.reshape %937 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %939 = stablehlo.broadcast_in_dim %938, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %940 = stablehlo.multiply %931, %939 : tensor<8x100x4096xf32> | |
| %941 = stablehlo.broadcast_in_dim %arg139, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %942 = stablehlo.multiply %940, %941 : tensor<8x100x4096xf32> | |
| %943 = stablehlo.reshape %942 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %944 = stablehlo.transpose %arg247, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %945 = stablehlo.dot_general %943, %944, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %946 = stablehlo.reshape %945 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %947 = stablehlo.transpose %946, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %948 = stablehlo.reshape %947 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %949 = stablehlo.slice %948 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %950 = stablehlo.reshape %949 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %951 = stablehlo.slice %948 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %952 = stablehlo.reshape %951 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %953 = stablehlo.complex %950, %952 : tensor<256x100x64xcomplex<f32>> | |
| %954 = stablehlo.multiply %953, %28 : tensor<256x100x64xcomplex<f32>> | |
| %955 = stablehlo.real %954 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %956 = stablehlo.reshape %955 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %957 = stablehlo.imag %954 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %958 = stablehlo.reshape %957 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %959 = stablehlo.concatenate %956, %958, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %960 = stablehlo.reshape %959 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %961 = stablehlo.transpose %arg245, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %962 = stablehlo.dot_general %943, %961, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %963 = stablehlo.reshape %962 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %964 = stablehlo.transpose %963, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %965 = stablehlo.reshape %964 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %966 = stablehlo.slice %965 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %967 = stablehlo.reshape %966 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %968 = stablehlo.slice %965 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %969 = stablehlo.reshape %968 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %970 = stablehlo.complex %967, %969 : tensor<256x100x64xcomplex<f32>> | |
| %971 = stablehlo.multiply %970, %28 : tensor<256x100x64xcomplex<f32>> | |
| %972 = stablehlo.real %971 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %973 = stablehlo.reshape %972 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %974 = stablehlo.imag %971 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %975 = stablehlo.reshape %974 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %976 = stablehlo.concatenate %973, %975, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %977 = stablehlo.reshape %976 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %978 = stablehlo.transpose %977, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %979 = "stablehlo.scatter"(%arg246, %39, %978) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %980 = stablehlo.transpose %979, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %981 = stablehlo.reshape %980 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %982 = stablehlo.dot_general %960, %981, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %983 = stablehlo.reshape %982 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %984 = stablehlo.divide %983, %cst : tensor<8x32x100x1024xf32> | |
| %985 = stablehlo.add %984, %66 : tensor<8x32x100x1024xf32> | |
| %986 = stablehlo.reduce(%985 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %987 = stablehlo.broadcast_in_dim %986, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %988 = stablehlo.subtract %985, %987 : tensor<8x32x100x1024xf32> | |
| %989 = stablehlo.exponential %988 : tensor<8x32x100x1024xf32> | |
| %990 = stablehlo.reduce(%989 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %991 = stablehlo.broadcast_in_dim %990, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %992 = stablehlo.divide %989, %991 : tensor<8x32x100x1024xf32> | |
| %993 = stablehlo.reshape %992 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %994 = stablehlo.transpose %arg138, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %995 = stablehlo.dot_general %943, %994, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %996 = stablehlo.reshape %995 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %997 = "stablehlo.scatter"(%arg244, %39, %996) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %998 = stablehlo.transpose %997, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %999 = stablehlo.reshape %998 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %1000 = stablehlo.dot_general %993, %999, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %1001 = stablehlo.reshape %1000 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1002 = stablehlo.transpose %1001, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1003 = stablehlo.reshape %1002 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %1004 = stablehlo.transpose %arg137, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1005 = stablehlo.dot_general %1003, %1004, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1006 = stablehlo.reshape %1005 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1007 = stablehlo.add %931, %1006 : tensor<8x100x4096xf32> | |
| %1008 = stablehlo.power %1007, %cst_3 : tensor<8x100x4096xf32> | |
| %1009 = stablehlo.reduce(%1008 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1010 = stablehlo.multiply %1009, %cst_2 : tensor<8x100xf32> | |
| %1011 = stablehlo.reshape %1010 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1012 = stablehlo.add %1011, %cst_1 : tensor<8x100x1xf32> | |
| %1013 = stablehlo.rsqrt %1012 : tensor<8x100x1xf32> | |
| %1014 = stablehlo.reshape %1013 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1015 = stablehlo.broadcast_in_dim %1014, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1016 = stablehlo.multiply %1007, %1015 : tensor<8x100x4096xf32> | |
| %1017 = stablehlo.broadcast_in_dim %arg136, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1018 = stablehlo.multiply %1016, %1017 : tensor<8x100x4096xf32> | |
| %1019 = stablehlo.reshape %1018 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1020 = stablehlo.transpose %arg248, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1021 = stablehlo.dot_general %1019, %1020, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1022 = stablehlo.reshape %1021 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1023 = stablehlo.logistic %1022 : tensor<8x100x11008xf32> | |
| %1024 = stablehlo.multiply %1022, %1023 : tensor<8x100x11008xf32> | |
| %1025 = stablehlo.transpose %arg135, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1026 = stablehlo.dot_general %1019, %1025, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1027 = stablehlo.reshape %1026 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1028 = stablehlo.multiply %1024, %1027 : tensor<8x100x11008xf32> | |
| %1029 = stablehlo.reshape %1028 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %1030 = stablehlo.transpose %arg134, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %1031 = stablehlo.dot_general %1029, %1030, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %1032 = stablehlo.reshape %1031 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1033 = stablehlo.add %1007, %1032 : tensor<8x100x4096xf32> | |
| %1034 = stablehlo.power %1033, %cst_3 : tensor<8x100x4096xf32> | |
| %1035 = stablehlo.reduce(%1034 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1036 = stablehlo.multiply %1035, %cst_2 : tensor<8x100xf32> | |
| %1037 = stablehlo.reshape %1036 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1038 = stablehlo.add %1037, %cst_1 : tensor<8x100x1xf32> | |
| %1039 = stablehlo.rsqrt %1038 : tensor<8x100x1xf32> | |
| %1040 = stablehlo.reshape %1039 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1041 = stablehlo.broadcast_in_dim %1040, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1042 = stablehlo.multiply %1033, %1041 : tensor<8x100x4096xf32> | |
| %1043 = stablehlo.broadcast_in_dim %arg133, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1044 = stablehlo.multiply %1042, %1043 : tensor<8x100x4096xf32> | |
| %1045 = stablehlo.reshape %1044 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1046 = stablehlo.transpose %arg252, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1047 = stablehlo.dot_general %1045, %1046, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1048 = stablehlo.reshape %1047 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1049 = stablehlo.transpose %1048, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1050 = stablehlo.reshape %1049 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1051 = stablehlo.slice %1050 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1052 = stablehlo.reshape %1051 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1053 = stablehlo.slice %1050 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1054 = stablehlo.reshape %1053 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1055 = stablehlo.complex %1052, %1054 : tensor<256x100x64xcomplex<f32>> | |
| %1056 = stablehlo.multiply %1055, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1057 = stablehlo.real %1056 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1058 = stablehlo.reshape %1057 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1059 = stablehlo.imag %1056 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1060 = stablehlo.reshape %1059 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1061 = stablehlo.concatenate %1058, %1060, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1062 = stablehlo.reshape %1061 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %1063 = stablehlo.transpose %arg250, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1064 = stablehlo.dot_general %1045, %1063, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1065 = stablehlo.reshape %1064 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1066 = stablehlo.transpose %1065, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1067 = stablehlo.reshape %1066 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1068 = stablehlo.slice %1067 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1069 = stablehlo.reshape %1068 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1070 = stablehlo.slice %1067 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1071 = stablehlo.reshape %1070 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1072 = stablehlo.complex %1069, %1071 : tensor<256x100x64xcomplex<f32>> | |
| %1073 = stablehlo.multiply %1072, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1074 = stablehlo.real %1073 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1075 = stablehlo.reshape %1074 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1076 = stablehlo.imag %1073 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1077 = stablehlo.reshape %1076 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1078 = stablehlo.concatenate %1075, %1077, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1079 = stablehlo.reshape %1078 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %1080 = stablehlo.transpose %1079, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1081 = "stablehlo.scatter"(%arg251, %39, %1080) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1082 = stablehlo.transpose %1081, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %1083 = stablehlo.reshape %1082 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %1084 = stablehlo.dot_general %1062, %1083, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1085 = stablehlo.reshape %1084 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %1086 = stablehlo.divide %1085, %cst : tensor<8x32x100x1024xf32> | |
| %1087 = stablehlo.add %1086, %66 : tensor<8x32x100x1024xf32> | |
| %1088 = stablehlo.reduce(%1087 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1089 = stablehlo.broadcast_in_dim %1088, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1090 = stablehlo.subtract %1087, %1089 : tensor<8x32x100x1024xf32> | |
| %1091 = stablehlo.exponential %1090 : tensor<8x32x100x1024xf32> | |
| %1092 = stablehlo.reduce(%1091 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1093 = stablehlo.broadcast_in_dim %1092, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1094 = stablehlo.divide %1091, %1093 : tensor<8x32x100x1024xf32> | |
| %1095 = stablehlo.reshape %1094 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1096 = stablehlo.transpose %arg132, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1097 = stablehlo.dot_general %1045, %1096, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1098 = stablehlo.reshape %1097 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1099 = "stablehlo.scatter"(%arg249, %39, %1098) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1100 = stablehlo.transpose %1099, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %1101 = stablehlo.reshape %1100 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %1102 = stablehlo.dot_general %1095, %1101, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %1103 = stablehlo.reshape %1102 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1104 = stablehlo.transpose %1103, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1105 = stablehlo.reshape %1104 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %1106 = stablehlo.transpose %arg131, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1107 = stablehlo.dot_general %1105, %1106, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1108 = stablehlo.reshape %1107 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1109 = stablehlo.add %1033, %1108 : tensor<8x100x4096xf32> | |
| %1110 = stablehlo.power %1109, %cst_3 : tensor<8x100x4096xf32> | |
| %1111 = stablehlo.reduce(%1110 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1112 = stablehlo.multiply %1111, %cst_2 : tensor<8x100xf32> | |
| %1113 = stablehlo.reshape %1112 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1114 = stablehlo.add %1113, %cst_1 : tensor<8x100x1xf32> | |
| %1115 = stablehlo.rsqrt %1114 : tensor<8x100x1xf32> | |
| %1116 = stablehlo.reshape %1115 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1117 = stablehlo.broadcast_in_dim %1116, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1118 = stablehlo.multiply %1109, %1117 : tensor<8x100x4096xf32> | |
| %1119 = stablehlo.broadcast_in_dim %arg130, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1120 = stablehlo.multiply %1118, %1119 : tensor<8x100x4096xf32> | |
| %1121 = stablehlo.reshape %1120 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1122 = stablehlo.transpose %arg253, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1123 = stablehlo.dot_general %1121, %1122, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1124 = stablehlo.reshape %1123 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1125 = stablehlo.logistic %1124 : tensor<8x100x11008xf32> | |
| %1126 = stablehlo.multiply %1124, %1125 : tensor<8x100x11008xf32> | |
| %1127 = stablehlo.transpose %arg129, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1128 = stablehlo.dot_general %1121, %1127, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1129 = stablehlo.reshape %1128 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1130 = stablehlo.multiply %1126, %1129 : tensor<8x100x11008xf32> | |
| %1131 = stablehlo.reshape %1130 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %1132 = stablehlo.transpose %arg128, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %1133 = stablehlo.dot_general %1131, %1132, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %1134 = stablehlo.reshape %1133 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1135 = stablehlo.add %1109, %1134 : tensor<8x100x4096xf32> | |
| %1136 = stablehlo.power %1135, %cst_3 : tensor<8x100x4096xf32> | |
| %1137 = stablehlo.reduce(%1136 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1138 = stablehlo.multiply %1137, %cst_2 : tensor<8x100xf32> | |
| %1139 = stablehlo.reshape %1138 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1140 = stablehlo.add %1139, %cst_1 : tensor<8x100x1xf32> | |
| %1141 = stablehlo.rsqrt %1140 : tensor<8x100x1xf32> | |
| %1142 = stablehlo.reshape %1141 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1143 = stablehlo.broadcast_in_dim %1142, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1144 = stablehlo.multiply %1135, %1143 : tensor<8x100x4096xf32> | |
| %1145 = stablehlo.broadcast_in_dim %arg127, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1146 = stablehlo.multiply %1144, %1145 : tensor<8x100x4096xf32> | |
| %1147 = stablehlo.reshape %1146 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1148 = stablehlo.transpose %arg257, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1149 = stablehlo.dot_general %1147, %1148, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1150 = stablehlo.reshape %1149 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1151 = stablehlo.transpose %1150, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1152 = stablehlo.reshape %1151 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1153 = stablehlo.slice %1152 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1154 = stablehlo.reshape %1153 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1155 = stablehlo.slice %1152 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1156 = stablehlo.reshape %1155 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1157 = stablehlo.complex %1154, %1156 : tensor<256x100x64xcomplex<f32>> | |
| %1158 = stablehlo.multiply %1157, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1159 = stablehlo.real %1158 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1160 = stablehlo.reshape %1159 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1161 = stablehlo.imag %1158 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1162 = stablehlo.reshape %1161 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1163 = stablehlo.concatenate %1160, %1162, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1164 = stablehlo.reshape %1163 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %1165 = stablehlo.transpose %arg255, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1166 = stablehlo.dot_general %1147, %1165, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1167 = stablehlo.reshape %1166 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1168 = stablehlo.transpose %1167, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1169 = stablehlo.reshape %1168 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1170 = stablehlo.slice %1169 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1171 = stablehlo.reshape %1170 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1172 = stablehlo.slice %1169 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1173 = stablehlo.reshape %1172 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1174 = stablehlo.complex %1171, %1173 : tensor<256x100x64xcomplex<f32>> | |
| %1175 = stablehlo.multiply %1174, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1176 = stablehlo.real %1175 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1177 = stablehlo.reshape %1176 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1178 = stablehlo.imag %1175 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1179 = stablehlo.reshape %1178 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1180 = stablehlo.concatenate %1177, %1179, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1181 = stablehlo.reshape %1180 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %1182 = stablehlo.transpose %1181, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1183 = "stablehlo.scatter"(%arg256, %39, %1182) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1184 = stablehlo.transpose %1183, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %1185 = stablehlo.reshape %1184 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %1186 = stablehlo.dot_general %1164, %1185, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1187 = stablehlo.reshape %1186 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %1188 = stablehlo.divide %1187, %cst : tensor<8x32x100x1024xf32> | |
| %1189 = stablehlo.add %1188, %66 : tensor<8x32x100x1024xf32> | |
| %1190 = stablehlo.reduce(%1189 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1191 = stablehlo.broadcast_in_dim %1190, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1192 = stablehlo.subtract %1189, %1191 : tensor<8x32x100x1024xf32> | |
| %1193 = stablehlo.exponential %1192 : tensor<8x32x100x1024xf32> | |
| %1194 = stablehlo.reduce(%1193 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1195 = stablehlo.broadcast_in_dim %1194, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1196 = stablehlo.divide %1193, %1195 : tensor<8x32x100x1024xf32> | |
| %1197 = stablehlo.reshape %1196 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1198 = stablehlo.transpose %arg126, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1199 = stablehlo.dot_general %1147, %1198, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1200 = stablehlo.reshape %1199 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1201 = "stablehlo.scatter"(%arg254, %39, %1200) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1202 = stablehlo.transpose %1201, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %1203 = stablehlo.reshape %1202 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %1204 = stablehlo.dot_general %1197, %1203, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %1205 = stablehlo.reshape %1204 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1206 = stablehlo.transpose %1205, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1207 = stablehlo.reshape %1206 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %1208 = stablehlo.transpose %arg125, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1209 = stablehlo.dot_general %1207, %1208, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1210 = stablehlo.reshape %1209 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1211 = stablehlo.add %1135, %1210 : tensor<8x100x4096xf32> | |
| %1212 = stablehlo.power %1211, %cst_3 : tensor<8x100x4096xf32> | |
| %1213 = stablehlo.reduce(%1212 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1214 = stablehlo.multiply %1213, %cst_2 : tensor<8x100xf32> | |
| %1215 = stablehlo.reshape %1214 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1216 = stablehlo.add %1215, %cst_1 : tensor<8x100x1xf32> | |
| %1217 = stablehlo.rsqrt %1216 : tensor<8x100x1xf32> | |
| %1218 = stablehlo.reshape %1217 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1219 = stablehlo.broadcast_in_dim %1218, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1220 = stablehlo.multiply %1211, %1219 : tensor<8x100x4096xf32> | |
| %1221 = stablehlo.broadcast_in_dim %arg124, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1222 = stablehlo.multiply %1220, %1221 : tensor<8x100x4096xf32> | |
| %1223 = stablehlo.reshape %1222 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1224 = stablehlo.transpose %arg258, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1225 = stablehlo.dot_general %1223, %1224, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1226 = stablehlo.reshape %1225 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1227 = stablehlo.logistic %1226 : tensor<8x100x11008xf32> | |
| %1228 = stablehlo.multiply %1226, %1227 : tensor<8x100x11008xf32> | |
| %1229 = stablehlo.transpose %arg123, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1230 = stablehlo.dot_general %1223, %1229, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1231 = stablehlo.reshape %1230 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1232 = stablehlo.multiply %1228, %1231 : tensor<8x100x11008xf32> | |
| %1233 = stablehlo.reshape %1232 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %1234 = stablehlo.transpose %arg122, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %1235 = stablehlo.dot_general %1233, %1234, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %1236 = stablehlo.reshape %1235 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1237 = stablehlo.add %1211, %1236 : tensor<8x100x4096xf32> | |
| %1238 = stablehlo.power %1237, %cst_3 : tensor<8x100x4096xf32> | |
| %1239 = stablehlo.reduce(%1238 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1240 = stablehlo.multiply %1239, %cst_2 : tensor<8x100xf32> | |
| %1241 = stablehlo.reshape %1240 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1242 = stablehlo.add %1241, %cst_1 : tensor<8x100x1xf32> | |
| %1243 = stablehlo.rsqrt %1242 : tensor<8x100x1xf32> | |
| %1244 = stablehlo.reshape %1243 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1245 = stablehlo.broadcast_in_dim %1244, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1246 = stablehlo.multiply %1237, %1245 : tensor<8x100x4096xf32> | |
| %1247 = stablehlo.broadcast_in_dim %arg121, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1248 = stablehlo.multiply %1246, %1247 : tensor<8x100x4096xf32> | |
| %1249 = stablehlo.reshape %1248 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1250 = stablehlo.transpose %arg262, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1251 = stablehlo.dot_general %1249, %1250, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1252 = stablehlo.reshape %1251 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1253 = stablehlo.transpose %1252, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1254 = stablehlo.reshape %1253 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1255 = stablehlo.slice %1254 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1256 = stablehlo.reshape %1255 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1257 = stablehlo.slice %1254 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1258 = stablehlo.reshape %1257 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1259 = stablehlo.complex %1256, %1258 : tensor<256x100x64xcomplex<f32>> | |
| %1260 = stablehlo.multiply %1259, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1261 = stablehlo.real %1260 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1262 = stablehlo.reshape %1261 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1263 = stablehlo.imag %1260 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1264 = stablehlo.reshape %1263 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1265 = stablehlo.concatenate %1262, %1264, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1266 = stablehlo.reshape %1265 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %1267 = stablehlo.transpose %arg260, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1268 = stablehlo.dot_general %1249, %1267, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1269 = stablehlo.reshape %1268 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1270 = stablehlo.transpose %1269, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1271 = stablehlo.reshape %1270 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1272 = stablehlo.slice %1271 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1273 = stablehlo.reshape %1272 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1274 = stablehlo.slice %1271 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1275 = stablehlo.reshape %1274 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1276 = stablehlo.complex %1273, %1275 : tensor<256x100x64xcomplex<f32>> | |
| %1277 = stablehlo.multiply %1276, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1278 = stablehlo.real %1277 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1279 = stablehlo.reshape %1278 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1280 = stablehlo.imag %1277 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1281 = stablehlo.reshape %1280 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1282 = stablehlo.concatenate %1279, %1281, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1283 = stablehlo.reshape %1282 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %1284 = stablehlo.transpose %1283, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1285 = "stablehlo.scatter"(%arg261, %39, %1284) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1286 = stablehlo.transpose %1285, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %1287 = stablehlo.reshape %1286 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %1288 = stablehlo.dot_general %1266, %1287, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1289 = stablehlo.reshape %1288 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %1290 = stablehlo.divide %1289, %cst : tensor<8x32x100x1024xf32> | |
| %1291 = stablehlo.add %1290, %66 : tensor<8x32x100x1024xf32> | |
| %1292 = stablehlo.reduce(%1291 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1293 = stablehlo.broadcast_in_dim %1292, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1294 = stablehlo.subtract %1291, %1293 : tensor<8x32x100x1024xf32> | |
| %1295 = stablehlo.exponential %1294 : tensor<8x32x100x1024xf32> | |
| %1296 = stablehlo.reduce(%1295 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1297 = stablehlo.broadcast_in_dim %1296, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1298 = stablehlo.divide %1295, %1297 : tensor<8x32x100x1024xf32> | |
| %1299 = stablehlo.reshape %1298 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1300 = stablehlo.transpose %arg120, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1301 = stablehlo.dot_general %1249, %1300, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1302 = stablehlo.reshape %1301 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1303 = "stablehlo.scatter"(%arg259, %39, %1302) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1304 = stablehlo.transpose %1303, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %1305 = stablehlo.reshape %1304 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %1306 = stablehlo.dot_general %1299, %1305, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %1307 = stablehlo.reshape %1306 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1308 = stablehlo.transpose %1307, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1309 = stablehlo.reshape %1308 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %1310 = stablehlo.transpose %arg119, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1311 = stablehlo.dot_general %1309, %1310, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1312 = stablehlo.reshape %1311 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1313 = stablehlo.add %1237, %1312 : tensor<8x100x4096xf32> | |
| %1314 = stablehlo.power %1313, %cst_3 : tensor<8x100x4096xf32> | |
| %1315 = stablehlo.reduce(%1314 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1316 = stablehlo.multiply %1315, %cst_2 : tensor<8x100xf32> | |
| %1317 = stablehlo.reshape %1316 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1318 = stablehlo.add %1317, %cst_1 : tensor<8x100x1xf32> | |
| %1319 = stablehlo.rsqrt %1318 : tensor<8x100x1xf32> | |
| %1320 = stablehlo.reshape %1319 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1321 = stablehlo.broadcast_in_dim %1320, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1322 = stablehlo.multiply %1313, %1321 : tensor<8x100x4096xf32> | |
| %1323 = stablehlo.broadcast_in_dim %arg118, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1324 = stablehlo.multiply %1322, %1323 : tensor<8x100x4096xf32> | |
| %1325 = stablehlo.reshape %1324 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1326 = stablehlo.transpose %arg263, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1327 = stablehlo.dot_general %1325, %1326, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1328 = stablehlo.reshape %1327 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1329 = stablehlo.logistic %1328 : tensor<8x100x11008xf32> | |
| %1330 = stablehlo.multiply %1328, %1329 : tensor<8x100x11008xf32> | |
| %1331 = stablehlo.transpose %arg117, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1332 = stablehlo.dot_general %1325, %1331, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1333 = stablehlo.reshape %1332 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1334 = stablehlo.multiply %1330, %1333 : tensor<8x100x11008xf32> | |
| %1335 = stablehlo.reshape %1334 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %1336 = stablehlo.transpose %arg116, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %1337 = stablehlo.dot_general %1335, %1336, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %1338 = stablehlo.reshape %1337 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1339 = stablehlo.add %1313, %1338 : tensor<8x100x4096xf32> | |
| %1340 = stablehlo.power %1339, %cst_3 : tensor<8x100x4096xf32> | |
| %1341 = stablehlo.reduce(%1340 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1342 = stablehlo.multiply %1341, %cst_2 : tensor<8x100xf32> | |
| %1343 = stablehlo.reshape %1342 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1344 = stablehlo.add %1343, %cst_1 : tensor<8x100x1xf32> | |
| %1345 = stablehlo.rsqrt %1344 : tensor<8x100x1xf32> | |
| %1346 = stablehlo.reshape %1345 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1347 = stablehlo.broadcast_in_dim %1346, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1348 = stablehlo.multiply %1339, %1347 : tensor<8x100x4096xf32> | |
| %1349 = stablehlo.broadcast_in_dim %arg115, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1350 = stablehlo.multiply %1348, %1349 : tensor<8x100x4096xf32> | |
| %1351 = stablehlo.reshape %1350 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1352 = stablehlo.transpose %arg267, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1353 = stablehlo.dot_general %1351, %1352, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1354 = stablehlo.reshape %1353 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1355 = stablehlo.transpose %1354, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1356 = stablehlo.reshape %1355 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1357 = stablehlo.slice %1356 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1358 = stablehlo.reshape %1357 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1359 = stablehlo.slice %1356 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1360 = stablehlo.reshape %1359 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1361 = stablehlo.complex %1358, %1360 : tensor<256x100x64xcomplex<f32>> | |
| %1362 = stablehlo.multiply %1361, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1363 = stablehlo.real %1362 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1364 = stablehlo.reshape %1363 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1365 = stablehlo.imag %1362 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1366 = stablehlo.reshape %1365 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1367 = stablehlo.concatenate %1364, %1366, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1368 = stablehlo.reshape %1367 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %1369 = stablehlo.transpose %arg265, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1370 = stablehlo.dot_general %1351, %1369, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1371 = stablehlo.reshape %1370 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1372 = stablehlo.transpose %1371, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1373 = stablehlo.reshape %1372 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1374 = stablehlo.slice %1373 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1375 = stablehlo.reshape %1374 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1376 = stablehlo.slice %1373 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1377 = stablehlo.reshape %1376 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1378 = stablehlo.complex %1375, %1377 : tensor<256x100x64xcomplex<f32>> | |
| %1379 = stablehlo.multiply %1378, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1380 = stablehlo.real %1379 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1381 = stablehlo.reshape %1380 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1382 = stablehlo.imag %1379 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1383 = stablehlo.reshape %1382 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1384 = stablehlo.concatenate %1381, %1383, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1385 = stablehlo.reshape %1384 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %1386 = stablehlo.transpose %1385, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1387 = "stablehlo.scatter"(%arg266, %39, %1386) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1388 = stablehlo.transpose %1387, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %1389 = stablehlo.reshape %1388 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %1390 = stablehlo.dot_general %1368, %1389, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1391 = stablehlo.reshape %1390 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %1392 = stablehlo.divide %1391, %cst : tensor<8x32x100x1024xf32> | |
| %1393 = stablehlo.add %1392, %66 : tensor<8x32x100x1024xf32> | |
| %1394 = stablehlo.reduce(%1393 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1395 = stablehlo.broadcast_in_dim %1394, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1396 = stablehlo.subtract %1393, %1395 : tensor<8x32x100x1024xf32> | |
| %1397 = stablehlo.exponential %1396 : tensor<8x32x100x1024xf32> | |
| %1398 = stablehlo.reduce(%1397 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1399 = stablehlo.broadcast_in_dim %1398, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1400 = stablehlo.divide %1397, %1399 : tensor<8x32x100x1024xf32> | |
| %1401 = stablehlo.reshape %1400 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1402 = stablehlo.transpose %arg114, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1403 = stablehlo.dot_general %1351, %1402, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1404 = stablehlo.reshape %1403 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1405 = "stablehlo.scatter"(%arg264, %39, %1404) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1406 = stablehlo.transpose %1405, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %1407 = stablehlo.reshape %1406 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %1408 = stablehlo.dot_general %1401, %1407, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %1409 = stablehlo.reshape %1408 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1410 = stablehlo.transpose %1409, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1411 = stablehlo.reshape %1410 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %1412 = stablehlo.transpose %arg113, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1413 = stablehlo.dot_general %1411, %1412, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1414 = stablehlo.reshape %1413 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1415 = stablehlo.add %1339, %1414 : tensor<8x100x4096xf32> | |
| %1416 = stablehlo.power %1415, %cst_3 : tensor<8x100x4096xf32> | |
| %1417 = stablehlo.reduce(%1416 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1418 = stablehlo.multiply %1417, %cst_2 : tensor<8x100xf32> | |
| %1419 = stablehlo.reshape %1418 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1420 = stablehlo.add %1419, %cst_1 : tensor<8x100x1xf32> | |
| %1421 = stablehlo.rsqrt %1420 : tensor<8x100x1xf32> | |
| %1422 = stablehlo.reshape %1421 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1423 = stablehlo.broadcast_in_dim %1422, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1424 = stablehlo.multiply %1415, %1423 : tensor<8x100x4096xf32> | |
| %1425 = stablehlo.broadcast_in_dim %arg112, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1426 = stablehlo.multiply %1424, %1425 : tensor<8x100x4096xf32> | |
| %1427 = stablehlo.reshape %1426 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1428 = stablehlo.transpose %arg268, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1429 = stablehlo.dot_general %1427, %1428, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1430 = stablehlo.reshape %1429 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1431 = stablehlo.logistic %1430 : tensor<8x100x11008xf32> | |
| %1432 = stablehlo.multiply %1430, %1431 : tensor<8x100x11008xf32> | |
| %1433 = stablehlo.transpose %arg111, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1434 = stablehlo.dot_general %1427, %1433, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1435 = stablehlo.reshape %1434 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1436 = stablehlo.multiply %1432, %1435 : tensor<8x100x11008xf32> | |
| %1437 = stablehlo.reshape %1436 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %1438 = stablehlo.transpose %arg110, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %1439 = stablehlo.dot_general %1437, %1438, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %1440 = stablehlo.reshape %1439 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1441 = stablehlo.add %1415, %1440 : tensor<8x100x4096xf32> | |
| %1442 = stablehlo.power %1441, %cst_3 : tensor<8x100x4096xf32> | |
| %1443 = stablehlo.reduce(%1442 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1444 = stablehlo.multiply %1443, %cst_2 : tensor<8x100xf32> | |
| %1445 = stablehlo.reshape %1444 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1446 = stablehlo.add %1445, %cst_1 : tensor<8x100x1xf32> | |
| %1447 = stablehlo.rsqrt %1446 : tensor<8x100x1xf32> | |
| %1448 = stablehlo.reshape %1447 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1449 = stablehlo.broadcast_in_dim %1448, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1450 = stablehlo.multiply %1441, %1449 : tensor<8x100x4096xf32> | |
| %1451 = stablehlo.broadcast_in_dim %arg109, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1452 = stablehlo.multiply %1450, %1451 : tensor<8x100x4096xf32> | |
| %1453 = stablehlo.reshape %1452 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1454 = stablehlo.transpose %arg272, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1455 = stablehlo.dot_general %1453, %1454, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1456 = stablehlo.reshape %1455 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1457 = stablehlo.transpose %1456, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1458 = stablehlo.reshape %1457 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1459 = stablehlo.slice %1458 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1460 = stablehlo.reshape %1459 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1461 = stablehlo.slice %1458 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1462 = stablehlo.reshape %1461 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1463 = stablehlo.complex %1460, %1462 : tensor<256x100x64xcomplex<f32>> | |
| %1464 = stablehlo.multiply %1463, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1465 = stablehlo.real %1464 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1466 = stablehlo.reshape %1465 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1467 = stablehlo.imag %1464 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1468 = stablehlo.reshape %1467 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1469 = stablehlo.concatenate %1466, %1468, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1470 = stablehlo.reshape %1469 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %1471 = stablehlo.transpose %arg270, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1472 = stablehlo.dot_general %1453, %1471, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1473 = stablehlo.reshape %1472 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1474 = stablehlo.transpose %1473, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1475 = stablehlo.reshape %1474 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1476 = stablehlo.slice %1475 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1477 = stablehlo.reshape %1476 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1478 = stablehlo.slice %1475 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1479 = stablehlo.reshape %1478 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1480 = stablehlo.complex %1477, %1479 : tensor<256x100x64xcomplex<f32>> | |
| %1481 = stablehlo.multiply %1480, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1482 = stablehlo.real %1481 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1483 = stablehlo.reshape %1482 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1484 = stablehlo.imag %1481 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1485 = stablehlo.reshape %1484 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1486 = stablehlo.concatenate %1483, %1485, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1487 = stablehlo.reshape %1486 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %1488 = stablehlo.transpose %1487, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1489 = "stablehlo.scatter"(%arg271, %39, %1488) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1490 = stablehlo.transpose %1489, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %1491 = stablehlo.reshape %1490 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %1492 = stablehlo.dot_general %1470, %1491, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1493 = stablehlo.reshape %1492 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %1494 = stablehlo.divide %1493, %cst : tensor<8x32x100x1024xf32> | |
| %1495 = stablehlo.add %1494, %66 : tensor<8x32x100x1024xf32> | |
| %1496 = stablehlo.reduce(%1495 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1497 = stablehlo.broadcast_in_dim %1496, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1498 = stablehlo.subtract %1495, %1497 : tensor<8x32x100x1024xf32> | |
| %1499 = stablehlo.exponential %1498 : tensor<8x32x100x1024xf32> | |
| %1500 = stablehlo.reduce(%1499 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1501 = stablehlo.broadcast_in_dim %1500, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1502 = stablehlo.divide %1499, %1501 : tensor<8x32x100x1024xf32> | |
| %1503 = stablehlo.reshape %1502 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1504 = stablehlo.transpose %arg108, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1505 = stablehlo.dot_general %1453, %1504, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1506 = stablehlo.reshape %1505 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1507 = "stablehlo.scatter"(%arg269, %39, %1506) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1508 = stablehlo.transpose %1507, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %1509 = stablehlo.reshape %1508 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %1510 = stablehlo.dot_general %1503, %1509, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %1511 = stablehlo.reshape %1510 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1512 = stablehlo.transpose %1511, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1513 = stablehlo.reshape %1512 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %1514 = stablehlo.transpose %arg107, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1515 = stablehlo.dot_general %1513, %1514, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1516 = stablehlo.reshape %1515 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1517 = stablehlo.add %1441, %1516 : tensor<8x100x4096xf32> | |
| %1518 = stablehlo.power %1517, %cst_3 : tensor<8x100x4096xf32> | |
| %1519 = stablehlo.reduce(%1518 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1520 = stablehlo.multiply %1519, %cst_2 : tensor<8x100xf32> | |
| %1521 = stablehlo.reshape %1520 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1522 = stablehlo.add %1521, %cst_1 : tensor<8x100x1xf32> | |
| %1523 = stablehlo.rsqrt %1522 : tensor<8x100x1xf32> | |
| %1524 = stablehlo.reshape %1523 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1525 = stablehlo.broadcast_in_dim %1524, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1526 = stablehlo.multiply %1517, %1525 : tensor<8x100x4096xf32> | |
| %1527 = stablehlo.broadcast_in_dim %arg106, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1528 = stablehlo.multiply %1526, %1527 : tensor<8x100x4096xf32> | |
| %1529 = stablehlo.reshape %1528 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1530 = stablehlo.transpose %arg273, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1531 = stablehlo.dot_general %1529, %1530, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1532 = stablehlo.reshape %1531 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1533 = stablehlo.logistic %1532 : tensor<8x100x11008xf32> | |
| %1534 = stablehlo.multiply %1532, %1533 : tensor<8x100x11008xf32> | |
| %1535 = stablehlo.transpose %arg105, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1536 = stablehlo.dot_general %1529, %1535, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1537 = stablehlo.reshape %1536 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1538 = stablehlo.multiply %1534, %1537 : tensor<8x100x11008xf32> | |
| %1539 = stablehlo.reshape %1538 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %1540 = stablehlo.transpose %arg104, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %1541 = stablehlo.dot_general %1539, %1540, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %1542 = stablehlo.reshape %1541 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1543 = stablehlo.add %1517, %1542 : tensor<8x100x4096xf32> | |
| %1544 = stablehlo.power %1543, %cst_3 : tensor<8x100x4096xf32> | |
| %1545 = stablehlo.reduce(%1544 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1546 = stablehlo.multiply %1545, %cst_2 : tensor<8x100xf32> | |
| %1547 = stablehlo.reshape %1546 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1548 = stablehlo.add %1547, %cst_1 : tensor<8x100x1xf32> | |
| %1549 = stablehlo.rsqrt %1548 : tensor<8x100x1xf32> | |
| %1550 = stablehlo.reshape %1549 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1551 = stablehlo.broadcast_in_dim %1550, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1552 = stablehlo.multiply %1543, %1551 : tensor<8x100x4096xf32> | |
| %1553 = stablehlo.broadcast_in_dim %arg103, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1554 = stablehlo.multiply %1552, %1553 : tensor<8x100x4096xf32> | |
| %1555 = stablehlo.reshape %1554 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1556 = stablehlo.transpose %arg277, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1557 = stablehlo.dot_general %1555, %1556, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1558 = stablehlo.reshape %1557 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1559 = stablehlo.transpose %1558, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1560 = stablehlo.reshape %1559 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1561 = stablehlo.slice %1560 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1562 = stablehlo.reshape %1561 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1563 = stablehlo.slice %1560 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1564 = stablehlo.reshape %1563 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1565 = stablehlo.complex %1562, %1564 : tensor<256x100x64xcomplex<f32>> | |
| %1566 = stablehlo.multiply %1565, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1567 = stablehlo.real %1566 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1568 = stablehlo.reshape %1567 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1569 = stablehlo.imag %1566 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1570 = stablehlo.reshape %1569 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1571 = stablehlo.concatenate %1568, %1570, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1572 = stablehlo.reshape %1571 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %1573 = stablehlo.transpose %arg275, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1574 = stablehlo.dot_general %1555, %1573, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1575 = stablehlo.reshape %1574 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1576 = stablehlo.transpose %1575, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1577 = stablehlo.reshape %1576 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1578 = stablehlo.slice %1577 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1579 = stablehlo.reshape %1578 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1580 = stablehlo.slice %1577 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1581 = stablehlo.reshape %1580 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1582 = stablehlo.complex %1579, %1581 : tensor<256x100x64xcomplex<f32>> | |
| %1583 = stablehlo.multiply %1582, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1584 = stablehlo.real %1583 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1585 = stablehlo.reshape %1584 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1586 = stablehlo.imag %1583 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1587 = stablehlo.reshape %1586 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1588 = stablehlo.concatenate %1585, %1587, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1589 = stablehlo.reshape %1588 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %1590 = stablehlo.transpose %1589, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1591 = "stablehlo.scatter"(%arg276, %39, %1590) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1592 = stablehlo.transpose %1591, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %1593 = stablehlo.reshape %1592 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %1594 = stablehlo.dot_general %1572, %1593, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1595 = stablehlo.reshape %1594 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %1596 = stablehlo.divide %1595, %cst : tensor<8x32x100x1024xf32> | |
| %1597 = stablehlo.add %1596, %66 : tensor<8x32x100x1024xf32> | |
| %1598 = stablehlo.reduce(%1597 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1599 = stablehlo.broadcast_in_dim %1598, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1600 = stablehlo.subtract %1597, %1599 : tensor<8x32x100x1024xf32> | |
| %1601 = stablehlo.exponential %1600 : tensor<8x32x100x1024xf32> | |
| %1602 = stablehlo.reduce(%1601 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1603 = stablehlo.broadcast_in_dim %1602, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1604 = stablehlo.divide %1601, %1603 : tensor<8x32x100x1024xf32> | |
| %1605 = stablehlo.reshape %1604 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1606 = stablehlo.transpose %arg102, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1607 = stablehlo.dot_general %1555, %1606, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1608 = stablehlo.reshape %1607 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1609 = "stablehlo.scatter"(%arg274, %39, %1608) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1610 = stablehlo.transpose %1609, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %1611 = stablehlo.reshape %1610 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %1612 = stablehlo.dot_general %1605, %1611, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %1613 = stablehlo.reshape %1612 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1614 = stablehlo.transpose %1613, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1615 = stablehlo.reshape %1614 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %1616 = stablehlo.transpose %arg101, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1617 = stablehlo.dot_general %1615, %1616, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1618 = stablehlo.reshape %1617 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1619 = stablehlo.add %1543, %1618 : tensor<8x100x4096xf32> | |
| %1620 = stablehlo.power %1619, %cst_3 : tensor<8x100x4096xf32> | |
| %1621 = stablehlo.reduce(%1620 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1622 = stablehlo.multiply %1621, %cst_2 : tensor<8x100xf32> | |
| %1623 = stablehlo.reshape %1622 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1624 = stablehlo.add %1623, %cst_1 : tensor<8x100x1xf32> | |
| %1625 = stablehlo.rsqrt %1624 : tensor<8x100x1xf32> | |
| %1626 = stablehlo.reshape %1625 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1627 = stablehlo.broadcast_in_dim %1626, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1628 = stablehlo.multiply %1619, %1627 : tensor<8x100x4096xf32> | |
| %1629 = stablehlo.broadcast_in_dim %arg100, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1630 = stablehlo.multiply %1628, %1629 : tensor<8x100x4096xf32> | |
| %1631 = stablehlo.reshape %1630 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1632 = stablehlo.transpose %arg278, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1633 = stablehlo.dot_general %1631, %1632, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1634 = stablehlo.reshape %1633 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1635 = stablehlo.logistic %1634 : tensor<8x100x11008xf32> | |
| %1636 = stablehlo.multiply %1634, %1635 : tensor<8x100x11008xf32> | |
| %1637 = stablehlo.transpose %arg99, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1638 = stablehlo.dot_general %1631, %1637, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1639 = stablehlo.reshape %1638 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1640 = stablehlo.multiply %1636, %1639 : tensor<8x100x11008xf32> | |
| %1641 = stablehlo.reshape %1640 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %1642 = stablehlo.transpose %arg98, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %1643 = stablehlo.dot_general %1641, %1642, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %1644 = stablehlo.reshape %1643 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1645 = stablehlo.add %1619, %1644 : tensor<8x100x4096xf32> | |
| %1646 = stablehlo.power %1645, %cst_3 : tensor<8x100x4096xf32> | |
| %1647 = stablehlo.reduce(%1646 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1648 = stablehlo.multiply %1647, %cst_2 : tensor<8x100xf32> | |
| %1649 = stablehlo.reshape %1648 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1650 = stablehlo.add %1649, %cst_1 : tensor<8x100x1xf32> | |
| %1651 = stablehlo.rsqrt %1650 : tensor<8x100x1xf32> | |
| %1652 = stablehlo.reshape %1651 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1653 = stablehlo.broadcast_in_dim %1652, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1654 = stablehlo.multiply %1645, %1653 : tensor<8x100x4096xf32> | |
| %1655 = stablehlo.broadcast_in_dim %arg97, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1656 = stablehlo.multiply %1654, %1655 : tensor<8x100x4096xf32> | |
| %1657 = stablehlo.reshape %1656 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1658 = stablehlo.transpose %arg282, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1659 = stablehlo.dot_general %1657, %1658, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1660 = stablehlo.reshape %1659 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1661 = stablehlo.transpose %1660, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1662 = stablehlo.reshape %1661 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1663 = stablehlo.slice %1662 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1664 = stablehlo.reshape %1663 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1665 = stablehlo.slice %1662 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1666 = stablehlo.reshape %1665 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1667 = stablehlo.complex %1664, %1666 : tensor<256x100x64xcomplex<f32>> | |
| %1668 = stablehlo.multiply %1667, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1669 = stablehlo.real %1668 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1670 = stablehlo.reshape %1669 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1671 = stablehlo.imag %1668 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1672 = stablehlo.reshape %1671 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1673 = stablehlo.concatenate %1670, %1672, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1674 = stablehlo.reshape %1673 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %1675 = stablehlo.transpose %arg280, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1676 = stablehlo.dot_general %1657, %1675, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1677 = stablehlo.reshape %1676 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1678 = stablehlo.transpose %1677, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1679 = stablehlo.reshape %1678 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1680 = stablehlo.slice %1679 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1681 = stablehlo.reshape %1680 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1682 = stablehlo.slice %1679 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1683 = stablehlo.reshape %1682 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1684 = stablehlo.complex %1681, %1683 : tensor<256x100x64xcomplex<f32>> | |
| %1685 = stablehlo.multiply %1684, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1686 = stablehlo.real %1685 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1687 = stablehlo.reshape %1686 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1688 = stablehlo.imag %1685 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1689 = stablehlo.reshape %1688 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1690 = stablehlo.concatenate %1687, %1689, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1691 = stablehlo.reshape %1690 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %1692 = stablehlo.transpose %1691, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1693 = "stablehlo.scatter"(%arg281, %39, %1692) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1694 = stablehlo.transpose %1693, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %1695 = stablehlo.reshape %1694 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %1696 = stablehlo.dot_general %1674, %1695, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1697 = stablehlo.reshape %1696 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %1698 = stablehlo.divide %1697, %cst : tensor<8x32x100x1024xf32> | |
| %1699 = stablehlo.add %1698, %66 : tensor<8x32x100x1024xf32> | |
| %1700 = stablehlo.reduce(%1699 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1701 = stablehlo.broadcast_in_dim %1700, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1702 = stablehlo.subtract %1699, %1701 : tensor<8x32x100x1024xf32> | |
| %1703 = stablehlo.exponential %1702 : tensor<8x32x100x1024xf32> | |
| %1704 = stablehlo.reduce(%1703 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1705 = stablehlo.broadcast_in_dim %1704, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1706 = stablehlo.divide %1703, %1705 : tensor<8x32x100x1024xf32> | |
| %1707 = stablehlo.reshape %1706 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1708 = stablehlo.transpose %arg96, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1709 = stablehlo.dot_general %1657, %1708, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1710 = stablehlo.reshape %1709 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1711 = "stablehlo.scatter"(%arg279, %39, %1710) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1712 = stablehlo.transpose %1711, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %1713 = stablehlo.reshape %1712 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %1714 = stablehlo.dot_general %1707, %1713, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %1715 = stablehlo.reshape %1714 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1716 = stablehlo.transpose %1715, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1717 = stablehlo.reshape %1716 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %1718 = stablehlo.transpose %arg95, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1719 = stablehlo.dot_general %1717, %1718, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1720 = stablehlo.reshape %1719 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1721 = stablehlo.add %1645, %1720 : tensor<8x100x4096xf32> | |
| %1722 = stablehlo.power %1721, %cst_3 : tensor<8x100x4096xf32> | |
| %1723 = stablehlo.reduce(%1722 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1724 = stablehlo.multiply %1723, %cst_2 : tensor<8x100xf32> | |
| %1725 = stablehlo.reshape %1724 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1726 = stablehlo.add %1725, %cst_1 : tensor<8x100x1xf32> | |
| %1727 = stablehlo.rsqrt %1726 : tensor<8x100x1xf32> | |
| %1728 = stablehlo.reshape %1727 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1729 = stablehlo.broadcast_in_dim %1728, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1730 = stablehlo.multiply %1721, %1729 : tensor<8x100x4096xf32> | |
| %1731 = stablehlo.broadcast_in_dim %arg94, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1732 = stablehlo.multiply %1730, %1731 : tensor<8x100x4096xf32> | |
| %1733 = stablehlo.reshape %1732 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1734 = stablehlo.transpose %arg283, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1735 = stablehlo.dot_general %1733, %1734, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1736 = stablehlo.reshape %1735 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1737 = stablehlo.logistic %1736 : tensor<8x100x11008xf32> | |
| %1738 = stablehlo.multiply %1736, %1737 : tensor<8x100x11008xf32> | |
| %1739 = stablehlo.transpose %arg93, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1740 = stablehlo.dot_general %1733, %1739, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1741 = stablehlo.reshape %1740 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1742 = stablehlo.multiply %1738, %1741 : tensor<8x100x11008xf32> | |
| %1743 = stablehlo.reshape %1742 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %1744 = stablehlo.transpose %arg92, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %1745 = stablehlo.dot_general %1743, %1744, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %1746 = stablehlo.reshape %1745 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1747 = stablehlo.add %1721, %1746 : tensor<8x100x4096xf32> | |
| %1748 = stablehlo.power %1747, %cst_3 : tensor<8x100x4096xf32> | |
| %1749 = stablehlo.reduce(%1748 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1750 = stablehlo.multiply %1749, %cst_2 : tensor<8x100xf32> | |
| %1751 = stablehlo.reshape %1750 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1752 = stablehlo.add %1751, %cst_1 : tensor<8x100x1xf32> | |
| %1753 = stablehlo.rsqrt %1752 : tensor<8x100x1xf32> | |
| %1754 = stablehlo.reshape %1753 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1755 = stablehlo.broadcast_in_dim %1754, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1756 = stablehlo.multiply %1747, %1755 : tensor<8x100x4096xf32> | |
| %1757 = stablehlo.broadcast_in_dim %arg91, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1758 = stablehlo.multiply %1756, %1757 : tensor<8x100x4096xf32> | |
| %1759 = stablehlo.reshape %1758 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1760 = stablehlo.transpose %arg287, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1761 = stablehlo.dot_general %1759, %1760, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1762 = stablehlo.reshape %1761 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1763 = stablehlo.transpose %1762, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1764 = stablehlo.reshape %1763 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1765 = stablehlo.slice %1764 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1766 = stablehlo.reshape %1765 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1767 = stablehlo.slice %1764 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1768 = stablehlo.reshape %1767 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1769 = stablehlo.complex %1766, %1768 : tensor<256x100x64xcomplex<f32>> | |
| %1770 = stablehlo.multiply %1769, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1771 = stablehlo.real %1770 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1772 = stablehlo.reshape %1771 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1773 = stablehlo.imag %1770 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1774 = stablehlo.reshape %1773 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1775 = stablehlo.concatenate %1772, %1774, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1776 = stablehlo.reshape %1775 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %1777 = stablehlo.transpose %arg285, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1778 = stablehlo.dot_general %1759, %1777, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1779 = stablehlo.reshape %1778 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1780 = stablehlo.transpose %1779, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1781 = stablehlo.reshape %1780 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1782 = stablehlo.slice %1781 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1783 = stablehlo.reshape %1782 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1784 = stablehlo.slice %1781 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1785 = stablehlo.reshape %1784 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1786 = stablehlo.complex %1783, %1785 : tensor<256x100x64xcomplex<f32>> | |
| %1787 = stablehlo.multiply %1786, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1788 = stablehlo.real %1787 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1789 = stablehlo.reshape %1788 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1790 = stablehlo.imag %1787 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1791 = stablehlo.reshape %1790 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1792 = stablehlo.concatenate %1789, %1791, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1793 = stablehlo.reshape %1792 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %1794 = stablehlo.transpose %1793, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1795 = "stablehlo.scatter"(%arg286, %39, %1794) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1796 = stablehlo.transpose %1795, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %1797 = stablehlo.reshape %1796 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %1798 = stablehlo.dot_general %1776, %1797, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1799 = stablehlo.reshape %1798 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %1800 = stablehlo.divide %1799, %cst : tensor<8x32x100x1024xf32> | |
| %1801 = stablehlo.add %1800, %66 : tensor<8x32x100x1024xf32> | |
| %1802 = stablehlo.reduce(%1801 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1803 = stablehlo.broadcast_in_dim %1802, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1804 = stablehlo.subtract %1801, %1803 : tensor<8x32x100x1024xf32> | |
| %1805 = stablehlo.exponential %1804 : tensor<8x32x100x1024xf32> | |
| %1806 = stablehlo.reduce(%1805 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1807 = stablehlo.broadcast_in_dim %1806, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1808 = stablehlo.divide %1805, %1807 : tensor<8x32x100x1024xf32> | |
| %1809 = stablehlo.reshape %1808 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1810 = stablehlo.transpose %arg90, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1811 = stablehlo.dot_general %1759, %1810, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1812 = stablehlo.reshape %1811 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1813 = "stablehlo.scatter"(%arg284, %39, %1812) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1814 = stablehlo.transpose %1813, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %1815 = stablehlo.reshape %1814 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %1816 = stablehlo.dot_general %1809, %1815, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %1817 = stablehlo.reshape %1816 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1818 = stablehlo.transpose %1817, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1819 = stablehlo.reshape %1818 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %1820 = stablehlo.transpose %arg89, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1821 = stablehlo.dot_general %1819, %1820, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1822 = stablehlo.reshape %1821 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1823 = stablehlo.add %1747, %1822 : tensor<8x100x4096xf32> | |
| %1824 = stablehlo.power %1823, %cst_3 : tensor<8x100x4096xf32> | |
| %1825 = stablehlo.reduce(%1824 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1826 = stablehlo.multiply %1825, %cst_2 : tensor<8x100xf32> | |
| %1827 = stablehlo.reshape %1826 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1828 = stablehlo.add %1827, %cst_1 : tensor<8x100x1xf32> | |
| %1829 = stablehlo.rsqrt %1828 : tensor<8x100x1xf32> | |
| %1830 = stablehlo.reshape %1829 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1831 = stablehlo.broadcast_in_dim %1830, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1832 = stablehlo.multiply %1823, %1831 : tensor<8x100x4096xf32> | |
| %1833 = stablehlo.broadcast_in_dim %arg88, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1834 = stablehlo.multiply %1832, %1833 : tensor<8x100x4096xf32> | |
| %1835 = stablehlo.reshape %1834 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1836 = stablehlo.transpose %arg288, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1837 = stablehlo.dot_general %1835, %1836, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1838 = stablehlo.reshape %1837 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1839 = stablehlo.logistic %1838 : tensor<8x100x11008xf32> | |
| %1840 = stablehlo.multiply %1838, %1839 : tensor<8x100x11008xf32> | |
| %1841 = stablehlo.transpose %arg87, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1842 = stablehlo.dot_general %1835, %1841, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1843 = stablehlo.reshape %1842 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1844 = stablehlo.multiply %1840, %1843 : tensor<8x100x11008xf32> | |
| %1845 = stablehlo.reshape %1844 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %1846 = stablehlo.transpose %arg86, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %1847 = stablehlo.dot_general %1845, %1846, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %1848 = stablehlo.reshape %1847 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1849 = stablehlo.add %1823, %1848 : tensor<8x100x4096xf32> | |
| %1850 = stablehlo.power %1849, %cst_3 : tensor<8x100x4096xf32> | |
| %1851 = stablehlo.reduce(%1850 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1852 = stablehlo.multiply %1851, %cst_2 : tensor<8x100xf32> | |
| %1853 = stablehlo.reshape %1852 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1854 = stablehlo.add %1853, %cst_1 : tensor<8x100x1xf32> | |
| %1855 = stablehlo.rsqrt %1854 : tensor<8x100x1xf32> | |
| %1856 = stablehlo.reshape %1855 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1857 = stablehlo.broadcast_in_dim %1856, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1858 = stablehlo.multiply %1849, %1857 : tensor<8x100x4096xf32> | |
| %1859 = stablehlo.broadcast_in_dim %arg85, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1860 = stablehlo.multiply %1858, %1859 : tensor<8x100x4096xf32> | |
| %1861 = stablehlo.reshape %1860 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1862 = stablehlo.transpose %arg292, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1863 = stablehlo.dot_general %1861, %1862, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1864 = stablehlo.reshape %1863 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1865 = stablehlo.transpose %1864, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1866 = stablehlo.reshape %1865 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1867 = stablehlo.slice %1866 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1868 = stablehlo.reshape %1867 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1869 = stablehlo.slice %1866 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1870 = stablehlo.reshape %1869 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1871 = stablehlo.complex %1868, %1870 : tensor<256x100x64xcomplex<f32>> | |
| %1872 = stablehlo.multiply %1871, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1873 = stablehlo.real %1872 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1874 = stablehlo.reshape %1873 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1875 = stablehlo.imag %1872 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1876 = stablehlo.reshape %1875 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1877 = stablehlo.concatenate %1874, %1876, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1878 = stablehlo.reshape %1877 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %1879 = stablehlo.transpose %arg290, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1880 = stablehlo.dot_general %1861, %1879, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1881 = stablehlo.reshape %1880 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1882 = stablehlo.transpose %1881, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1883 = stablehlo.reshape %1882 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1884 = stablehlo.slice %1883 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1885 = stablehlo.reshape %1884 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1886 = stablehlo.slice %1883 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1887 = stablehlo.reshape %1886 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1888 = stablehlo.complex %1885, %1887 : tensor<256x100x64xcomplex<f32>> | |
| %1889 = stablehlo.multiply %1888, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1890 = stablehlo.real %1889 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1891 = stablehlo.reshape %1890 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1892 = stablehlo.imag %1889 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1893 = stablehlo.reshape %1892 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1894 = stablehlo.concatenate %1891, %1893, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1895 = stablehlo.reshape %1894 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %1896 = stablehlo.transpose %1895, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1897 = "stablehlo.scatter"(%arg291, %39, %1896) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1898 = stablehlo.transpose %1897, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %1899 = stablehlo.reshape %1898 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %1900 = stablehlo.dot_general %1878, %1899, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1901 = stablehlo.reshape %1900 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %1902 = stablehlo.divide %1901, %cst : tensor<8x32x100x1024xf32> | |
| %1903 = stablehlo.add %1902, %66 : tensor<8x32x100x1024xf32> | |
| %1904 = stablehlo.reduce(%1903 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1905 = stablehlo.broadcast_in_dim %1904, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1906 = stablehlo.subtract %1903, %1905 : tensor<8x32x100x1024xf32> | |
| %1907 = stablehlo.exponential %1906 : tensor<8x32x100x1024xf32> | |
| %1908 = stablehlo.reduce(%1907 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %1909 = stablehlo.broadcast_in_dim %1908, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %1910 = stablehlo.divide %1907, %1909 : tensor<8x32x100x1024xf32> | |
| %1911 = stablehlo.reshape %1910 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %1912 = stablehlo.transpose %arg84, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1913 = stablehlo.dot_general %1861, %1912, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1914 = stablehlo.reshape %1913 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1915 = "stablehlo.scatter"(%arg289, %39, %1914) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %1916 = stablehlo.transpose %1915, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %1917 = stablehlo.reshape %1916 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %1918 = stablehlo.dot_general %1911, %1917, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %1919 = stablehlo.reshape %1918 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1920 = stablehlo.transpose %1919, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1921 = stablehlo.reshape %1920 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %1922 = stablehlo.transpose %arg83, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1923 = stablehlo.dot_general %1921, %1922, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1924 = stablehlo.reshape %1923 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1925 = stablehlo.add %1849, %1924 : tensor<8x100x4096xf32> | |
| %1926 = stablehlo.power %1925, %cst_3 : tensor<8x100x4096xf32> | |
| %1927 = stablehlo.reduce(%1926 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1928 = stablehlo.multiply %1927, %cst_2 : tensor<8x100xf32> | |
| %1929 = stablehlo.reshape %1928 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1930 = stablehlo.add %1929, %cst_1 : tensor<8x100x1xf32> | |
| %1931 = stablehlo.rsqrt %1930 : tensor<8x100x1xf32> | |
| %1932 = stablehlo.reshape %1931 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1933 = stablehlo.broadcast_in_dim %1932, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1934 = stablehlo.multiply %1925, %1933 : tensor<8x100x4096xf32> | |
| %1935 = stablehlo.broadcast_in_dim %arg82, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1936 = stablehlo.multiply %1934, %1935 : tensor<8x100x4096xf32> | |
| %1937 = stablehlo.reshape %1936 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1938 = stablehlo.transpose %arg293, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1939 = stablehlo.dot_general %1937, %1938, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1940 = stablehlo.reshape %1939 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1941 = stablehlo.logistic %1940 : tensor<8x100x11008xf32> | |
| %1942 = stablehlo.multiply %1940, %1941 : tensor<8x100x11008xf32> | |
| %1943 = stablehlo.transpose %arg81, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %1944 = stablehlo.dot_general %1937, %1943, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %1945 = stablehlo.reshape %1944 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %1946 = stablehlo.multiply %1942, %1945 : tensor<8x100x11008xf32> | |
| %1947 = stablehlo.reshape %1946 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %1948 = stablehlo.transpose %arg80, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %1949 = stablehlo.dot_general %1947, %1948, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %1950 = stablehlo.reshape %1949 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %1951 = stablehlo.add %1925, %1950 : tensor<8x100x4096xf32> | |
| %1952 = stablehlo.power %1951, %cst_3 : tensor<8x100x4096xf32> | |
| %1953 = stablehlo.reduce(%1952 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %1954 = stablehlo.multiply %1953, %cst_2 : tensor<8x100xf32> | |
| %1955 = stablehlo.reshape %1954 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %1956 = stablehlo.add %1955, %cst_1 : tensor<8x100x1xf32> | |
| %1957 = stablehlo.rsqrt %1956 : tensor<8x100x1xf32> | |
| %1958 = stablehlo.reshape %1957 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %1959 = stablehlo.broadcast_in_dim %1958, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %1960 = stablehlo.multiply %1951, %1959 : tensor<8x100x4096xf32> | |
| %1961 = stablehlo.broadcast_in_dim %arg79, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %1962 = stablehlo.multiply %1960, %1961 : tensor<8x100x4096xf32> | |
| %1963 = stablehlo.reshape %1962 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %1964 = stablehlo.transpose %arg297, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1965 = stablehlo.dot_general %1963, %1964, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1966 = stablehlo.reshape %1965 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1967 = stablehlo.transpose %1966, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1968 = stablehlo.reshape %1967 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1969 = stablehlo.slice %1968 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1970 = stablehlo.reshape %1969 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1971 = stablehlo.slice %1968 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1972 = stablehlo.reshape %1971 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1973 = stablehlo.complex %1970, %1972 : tensor<256x100x64xcomplex<f32>> | |
| %1974 = stablehlo.multiply %1973, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1975 = stablehlo.real %1974 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1976 = stablehlo.reshape %1975 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1977 = stablehlo.imag %1974 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1978 = stablehlo.reshape %1977 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1979 = stablehlo.concatenate %1976, %1978, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1980 = stablehlo.reshape %1979 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %1981 = stablehlo.transpose %arg295, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %1982 = stablehlo.dot_general %1963, %1981, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %1983 = stablehlo.reshape %1982 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %1984 = stablehlo.transpose %1983, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %1985 = stablehlo.reshape %1984 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %1986 = stablehlo.slice %1985 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1987 = stablehlo.reshape %1986 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1988 = stablehlo.slice %1985 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %1989 = stablehlo.reshape %1988 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %1990 = stablehlo.complex %1987, %1989 : tensor<256x100x64xcomplex<f32>> | |
| %1991 = stablehlo.multiply %1990, %28 : tensor<256x100x64xcomplex<f32>> | |
| %1992 = stablehlo.real %1991 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1993 = stablehlo.reshape %1992 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1994 = stablehlo.imag %1991 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %1995 = stablehlo.reshape %1994 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %1996 = stablehlo.concatenate %1993, %1995, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %1997 = stablehlo.reshape %1996 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %1998 = stablehlo.transpose %1997, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %1999 = "stablehlo.scatter"(%arg296, %39, %1998) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2000 = stablehlo.transpose %1999, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %2001 = stablehlo.reshape %2000 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %2002 = stablehlo.dot_general %1980, %2001, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2003 = stablehlo.reshape %2002 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %2004 = stablehlo.divide %2003, %cst : tensor<8x32x100x1024xf32> | |
| %2005 = stablehlo.add %2004, %66 : tensor<8x32x100x1024xf32> | |
| %2006 = stablehlo.reduce(%2005 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2007 = stablehlo.broadcast_in_dim %2006, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2008 = stablehlo.subtract %2005, %2007 : tensor<8x32x100x1024xf32> | |
| %2009 = stablehlo.exponential %2008 : tensor<8x32x100x1024xf32> | |
| %2010 = stablehlo.reduce(%2009 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2011 = stablehlo.broadcast_in_dim %2010, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2012 = stablehlo.divide %2009, %2011 : tensor<8x32x100x1024xf32> | |
| %2013 = stablehlo.reshape %2012 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2014 = stablehlo.transpose %arg78, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2015 = stablehlo.dot_general %1963, %2014, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2016 = stablehlo.reshape %2015 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2017 = "stablehlo.scatter"(%arg294, %39, %2016) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2018 = stablehlo.transpose %2017, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %2019 = stablehlo.reshape %2018 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %2020 = stablehlo.dot_general %2013, %2019, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %2021 = stablehlo.reshape %2020 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2022 = stablehlo.transpose %2021, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2023 = stablehlo.reshape %2022 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %2024 = stablehlo.transpose %arg77, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2025 = stablehlo.dot_general %2023, %2024, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2026 = stablehlo.reshape %2025 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2027 = stablehlo.add %1951, %2026 : tensor<8x100x4096xf32> | |
| %2028 = stablehlo.power %2027, %cst_3 : tensor<8x100x4096xf32> | |
| %2029 = stablehlo.reduce(%2028 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2030 = stablehlo.multiply %2029, %cst_2 : tensor<8x100xf32> | |
| %2031 = stablehlo.reshape %2030 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2032 = stablehlo.add %2031, %cst_1 : tensor<8x100x1xf32> | |
| %2033 = stablehlo.rsqrt %2032 : tensor<8x100x1xf32> | |
| %2034 = stablehlo.reshape %2033 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2035 = stablehlo.broadcast_in_dim %2034, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2036 = stablehlo.multiply %2027, %2035 : tensor<8x100x4096xf32> | |
| %2037 = stablehlo.broadcast_in_dim %arg76, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2038 = stablehlo.multiply %2036, %2037 : tensor<8x100x4096xf32> | |
| %2039 = stablehlo.reshape %2038 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2040 = stablehlo.transpose %arg298, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2041 = stablehlo.dot_general %2039, %2040, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2042 = stablehlo.reshape %2041 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2043 = stablehlo.logistic %2042 : tensor<8x100x11008xf32> | |
| %2044 = stablehlo.multiply %2042, %2043 : tensor<8x100x11008xf32> | |
| %2045 = stablehlo.transpose %arg75, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2046 = stablehlo.dot_general %2039, %2045, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2047 = stablehlo.reshape %2046 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2048 = stablehlo.multiply %2044, %2047 : tensor<8x100x11008xf32> | |
| %2049 = stablehlo.reshape %2048 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %2050 = stablehlo.transpose %arg74, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %2051 = stablehlo.dot_general %2049, %2050, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %2052 = stablehlo.reshape %2051 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2053 = stablehlo.add %2027, %2052 : tensor<8x100x4096xf32> | |
| %2054 = stablehlo.power %2053, %cst_3 : tensor<8x100x4096xf32> | |
| %2055 = stablehlo.reduce(%2054 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2056 = stablehlo.multiply %2055, %cst_2 : tensor<8x100xf32> | |
| %2057 = stablehlo.reshape %2056 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2058 = stablehlo.add %2057, %cst_1 : tensor<8x100x1xf32> | |
| %2059 = stablehlo.rsqrt %2058 : tensor<8x100x1xf32> | |
| %2060 = stablehlo.reshape %2059 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2061 = stablehlo.broadcast_in_dim %2060, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2062 = stablehlo.multiply %2053, %2061 : tensor<8x100x4096xf32> | |
| %2063 = stablehlo.broadcast_in_dim %arg73, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2064 = stablehlo.multiply %2062, %2063 : tensor<8x100x4096xf32> | |
| %2065 = stablehlo.reshape %2064 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2066 = stablehlo.transpose %arg302, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2067 = stablehlo.dot_general %2065, %2066, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2068 = stablehlo.reshape %2067 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2069 = stablehlo.transpose %2068, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2070 = stablehlo.reshape %2069 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2071 = stablehlo.slice %2070 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2072 = stablehlo.reshape %2071 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2073 = stablehlo.slice %2070 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2074 = stablehlo.reshape %2073 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2075 = stablehlo.complex %2072, %2074 : tensor<256x100x64xcomplex<f32>> | |
| %2076 = stablehlo.multiply %2075, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2077 = stablehlo.real %2076 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2078 = stablehlo.reshape %2077 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2079 = stablehlo.imag %2076 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2080 = stablehlo.reshape %2079 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2081 = stablehlo.concatenate %2078, %2080, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2082 = stablehlo.reshape %2081 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %2083 = stablehlo.transpose %arg300, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2084 = stablehlo.dot_general %2065, %2083, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2085 = stablehlo.reshape %2084 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2086 = stablehlo.transpose %2085, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2087 = stablehlo.reshape %2086 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2088 = stablehlo.slice %2087 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2089 = stablehlo.reshape %2088 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2090 = stablehlo.slice %2087 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2091 = stablehlo.reshape %2090 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2092 = stablehlo.complex %2089, %2091 : tensor<256x100x64xcomplex<f32>> | |
| %2093 = stablehlo.multiply %2092, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2094 = stablehlo.real %2093 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2095 = stablehlo.reshape %2094 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2096 = stablehlo.imag %2093 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2097 = stablehlo.reshape %2096 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2098 = stablehlo.concatenate %2095, %2097, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2099 = stablehlo.reshape %2098 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %2100 = stablehlo.transpose %2099, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2101 = "stablehlo.scatter"(%arg301, %39, %2100) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2102 = stablehlo.transpose %2101, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %2103 = stablehlo.reshape %2102 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %2104 = stablehlo.dot_general %2082, %2103, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2105 = stablehlo.reshape %2104 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %2106 = stablehlo.divide %2105, %cst : tensor<8x32x100x1024xf32> | |
| %2107 = stablehlo.add %2106, %66 : tensor<8x32x100x1024xf32> | |
| %2108 = stablehlo.reduce(%2107 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2109 = stablehlo.broadcast_in_dim %2108, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2110 = stablehlo.subtract %2107, %2109 : tensor<8x32x100x1024xf32> | |
| %2111 = stablehlo.exponential %2110 : tensor<8x32x100x1024xf32> | |
| %2112 = stablehlo.reduce(%2111 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2113 = stablehlo.broadcast_in_dim %2112, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2114 = stablehlo.divide %2111, %2113 : tensor<8x32x100x1024xf32> | |
| %2115 = stablehlo.reshape %2114 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2116 = stablehlo.transpose %arg72, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2117 = stablehlo.dot_general %2065, %2116, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2118 = stablehlo.reshape %2117 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2119 = "stablehlo.scatter"(%arg299, %39, %2118) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2120 = stablehlo.transpose %2119, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %2121 = stablehlo.reshape %2120 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %2122 = stablehlo.dot_general %2115, %2121, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %2123 = stablehlo.reshape %2122 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2124 = stablehlo.transpose %2123, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2125 = stablehlo.reshape %2124 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %2126 = stablehlo.transpose %arg71, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2127 = stablehlo.dot_general %2125, %2126, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2128 = stablehlo.reshape %2127 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2129 = stablehlo.add %2053, %2128 : tensor<8x100x4096xf32> | |
| %2130 = stablehlo.power %2129, %cst_3 : tensor<8x100x4096xf32> | |
| %2131 = stablehlo.reduce(%2130 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2132 = stablehlo.multiply %2131, %cst_2 : tensor<8x100xf32> | |
| %2133 = stablehlo.reshape %2132 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2134 = stablehlo.add %2133, %cst_1 : tensor<8x100x1xf32> | |
| %2135 = stablehlo.rsqrt %2134 : tensor<8x100x1xf32> | |
| %2136 = stablehlo.reshape %2135 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2137 = stablehlo.broadcast_in_dim %2136, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2138 = stablehlo.multiply %2129, %2137 : tensor<8x100x4096xf32> | |
| %2139 = stablehlo.broadcast_in_dim %arg70, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2140 = stablehlo.multiply %2138, %2139 : tensor<8x100x4096xf32> | |
| %2141 = stablehlo.reshape %2140 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2142 = stablehlo.transpose %arg303, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2143 = stablehlo.dot_general %2141, %2142, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2144 = stablehlo.reshape %2143 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2145 = stablehlo.logistic %2144 : tensor<8x100x11008xf32> | |
| %2146 = stablehlo.multiply %2144, %2145 : tensor<8x100x11008xf32> | |
| %2147 = stablehlo.transpose %arg69, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2148 = stablehlo.dot_general %2141, %2147, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2149 = stablehlo.reshape %2148 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2150 = stablehlo.multiply %2146, %2149 : tensor<8x100x11008xf32> | |
| %2151 = stablehlo.reshape %2150 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %2152 = stablehlo.transpose %arg68, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %2153 = stablehlo.dot_general %2151, %2152, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %2154 = stablehlo.reshape %2153 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2155 = stablehlo.add %2129, %2154 : tensor<8x100x4096xf32> | |
| %2156 = stablehlo.power %2155, %cst_3 : tensor<8x100x4096xf32> | |
| %2157 = stablehlo.reduce(%2156 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2158 = stablehlo.multiply %2157, %cst_2 : tensor<8x100xf32> | |
| %2159 = stablehlo.reshape %2158 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2160 = stablehlo.add %2159, %cst_1 : tensor<8x100x1xf32> | |
| %2161 = stablehlo.rsqrt %2160 : tensor<8x100x1xf32> | |
| %2162 = stablehlo.reshape %2161 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2163 = stablehlo.broadcast_in_dim %2162, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2164 = stablehlo.multiply %2155, %2163 : tensor<8x100x4096xf32> | |
| %2165 = stablehlo.broadcast_in_dim %arg67, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2166 = stablehlo.multiply %2164, %2165 : tensor<8x100x4096xf32> | |
| %2167 = stablehlo.reshape %2166 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2168 = stablehlo.transpose %arg307, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2169 = stablehlo.dot_general %2167, %2168, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2170 = stablehlo.reshape %2169 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2171 = stablehlo.transpose %2170, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2172 = stablehlo.reshape %2171 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2173 = stablehlo.slice %2172 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2174 = stablehlo.reshape %2173 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2175 = stablehlo.slice %2172 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2176 = stablehlo.reshape %2175 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2177 = stablehlo.complex %2174, %2176 : tensor<256x100x64xcomplex<f32>> | |
| %2178 = stablehlo.multiply %2177, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2179 = stablehlo.real %2178 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2180 = stablehlo.reshape %2179 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2181 = stablehlo.imag %2178 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2182 = stablehlo.reshape %2181 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2183 = stablehlo.concatenate %2180, %2182, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2184 = stablehlo.reshape %2183 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %2185 = stablehlo.transpose %arg305, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2186 = stablehlo.dot_general %2167, %2185, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2187 = stablehlo.reshape %2186 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2188 = stablehlo.transpose %2187, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2189 = stablehlo.reshape %2188 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2190 = stablehlo.slice %2189 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2191 = stablehlo.reshape %2190 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2192 = stablehlo.slice %2189 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2193 = stablehlo.reshape %2192 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2194 = stablehlo.complex %2191, %2193 : tensor<256x100x64xcomplex<f32>> | |
| %2195 = stablehlo.multiply %2194, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2196 = stablehlo.real %2195 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2197 = stablehlo.reshape %2196 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2198 = stablehlo.imag %2195 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2199 = stablehlo.reshape %2198 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2200 = stablehlo.concatenate %2197, %2199, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2201 = stablehlo.reshape %2200 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %2202 = stablehlo.transpose %2201, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2203 = "stablehlo.scatter"(%arg306, %39, %2202) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2204 = stablehlo.transpose %2203, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %2205 = stablehlo.reshape %2204 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %2206 = stablehlo.dot_general %2184, %2205, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2207 = stablehlo.reshape %2206 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %2208 = stablehlo.divide %2207, %cst : tensor<8x32x100x1024xf32> | |
| %2209 = stablehlo.add %2208, %66 : tensor<8x32x100x1024xf32> | |
| %2210 = stablehlo.reduce(%2209 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2211 = stablehlo.broadcast_in_dim %2210, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2212 = stablehlo.subtract %2209, %2211 : tensor<8x32x100x1024xf32> | |
| %2213 = stablehlo.exponential %2212 : tensor<8x32x100x1024xf32> | |
| %2214 = stablehlo.reduce(%2213 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2215 = stablehlo.broadcast_in_dim %2214, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2216 = stablehlo.divide %2213, %2215 : tensor<8x32x100x1024xf32> | |
| %2217 = stablehlo.reshape %2216 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2218 = stablehlo.transpose %arg66, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2219 = stablehlo.dot_general %2167, %2218, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2220 = stablehlo.reshape %2219 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2221 = "stablehlo.scatter"(%arg304, %39, %2220) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2222 = stablehlo.transpose %2221, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %2223 = stablehlo.reshape %2222 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %2224 = stablehlo.dot_general %2217, %2223, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %2225 = stablehlo.reshape %2224 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2226 = stablehlo.transpose %2225, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2227 = stablehlo.reshape %2226 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %2228 = stablehlo.transpose %arg65, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2229 = stablehlo.dot_general %2227, %2228, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2230 = stablehlo.reshape %2229 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2231 = stablehlo.add %2155, %2230 : tensor<8x100x4096xf32> | |
| %2232 = stablehlo.power %2231, %cst_3 : tensor<8x100x4096xf32> | |
| %2233 = stablehlo.reduce(%2232 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2234 = stablehlo.multiply %2233, %cst_2 : tensor<8x100xf32> | |
| %2235 = stablehlo.reshape %2234 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2236 = stablehlo.add %2235, %cst_1 : tensor<8x100x1xf32> | |
| %2237 = stablehlo.rsqrt %2236 : tensor<8x100x1xf32> | |
| %2238 = stablehlo.reshape %2237 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2239 = stablehlo.broadcast_in_dim %2238, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2240 = stablehlo.multiply %2231, %2239 : tensor<8x100x4096xf32> | |
| %2241 = stablehlo.broadcast_in_dim %arg64, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2242 = stablehlo.multiply %2240, %2241 : tensor<8x100x4096xf32> | |
| %2243 = stablehlo.reshape %2242 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2244 = stablehlo.transpose %arg308, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2245 = stablehlo.dot_general %2243, %2244, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2246 = stablehlo.reshape %2245 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2247 = stablehlo.logistic %2246 : tensor<8x100x11008xf32> | |
| %2248 = stablehlo.multiply %2246, %2247 : tensor<8x100x11008xf32> | |
| %2249 = stablehlo.transpose %arg63, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2250 = stablehlo.dot_general %2243, %2249, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2251 = stablehlo.reshape %2250 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2252 = stablehlo.multiply %2248, %2251 : tensor<8x100x11008xf32> | |
| %2253 = stablehlo.reshape %2252 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %2254 = stablehlo.transpose %arg62, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %2255 = stablehlo.dot_general %2253, %2254, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %2256 = stablehlo.reshape %2255 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2257 = stablehlo.add %2231, %2256 : tensor<8x100x4096xf32> | |
| %2258 = stablehlo.power %2257, %cst_3 : tensor<8x100x4096xf32> | |
| %2259 = stablehlo.reduce(%2258 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2260 = stablehlo.multiply %2259, %cst_2 : tensor<8x100xf32> | |
| %2261 = stablehlo.reshape %2260 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2262 = stablehlo.add %2261, %cst_1 : tensor<8x100x1xf32> | |
| %2263 = stablehlo.rsqrt %2262 : tensor<8x100x1xf32> | |
| %2264 = stablehlo.reshape %2263 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2265 = stablehlo.broadcast_in_dim %2264, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2266 = stablehlo.multiply %2257, %2265 : tensor<8x100x4096xf32> | |
| %2267 = stablehlo.broadcast_in_dim %arg61, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2268 = stablehlo.multiply %2266, %2267 : tensor<8x100x4096xf32> | |
| %2269 = stablehlo.reshape %2268 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2270 = stablehlo.transpose %arg312, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2271 = stablehlo.dot_general %2269, %2270, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2272 = stablehlo.reshape %2271 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2273 = stablehlo.transpose %2272, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2274 = stablehlo.reshape %2273 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2275 = stablehlo.slice %2274 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2276 = stablehlo.reshape %2275 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2277 = stablehlo.slice %2274 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2278 = stablehlo.reshape %2277 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2279 = stablehlo.complex %2276, %2278 : tensor<256x100x64xcomplex<f32>> | |
| %2280 = stablehlo.multiply %2279, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2281 = stablehlo.real %2280 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2282 = stablehlo.reshape %2281 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2283 = stablehlo.imag %2280 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2284 = stablehlo.reshape %2283 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2285 = stablehlo.concatenate %2282, %2284, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2286 = stablehlo.reshape %2285 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %2287 = stablehlo.transpose %arg310, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2288 = stablehlo.dot_general %2269, %2287, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2289 = stablehlo.reshape %2288 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2290 = stablehlo.transpose %2289, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2291 = stablehlo.reshape %2290 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2292 = stablehlo.slice %2291 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2293 = stablehlo.reshape %2292 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2294 = stablehlo.slice %2291 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2295 = stablehlo.reshape %2294 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2296 = stablehlo.complex %2293, %2295 : tensor<256x100x64xcomplex<f32>> | |
| %2297 = stablehlo.multiply %2296, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2298 = stablehlo.real %2297 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2299 = stablehlo.reshape %2298 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2300 = stablehlo.imag %2297 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2301 = stablehlo.reshape %2300 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2302 = stablehlo.concatenate %2299, %2301, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2303 = stablehlo.reshape %2302 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %2304 = stablehlo.transpose %2303, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2305 = "stablehlo.scatter"(%arg311, %39, %2304) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2306 = stablehlo.transpose %2305, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %2307 = stablehlo.reshape %2306 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %2308 = stablehlo.dot_general %2286, %2307, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2309 = stablehlo.reshape %2308 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %2310 = stablehlo.divide %2309, %cst : tensor<8x32x100x1024xf32> | |
| %2311 = stablehlo.add %2310, %66 : tensor<8x32x100x1024xf32> | |
| %2312 = stablehlo.reduce(%2311 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2313 = stablehlo.broadcast_in_dim %2312, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2314 = stablehlo.subtract %2311, %2313 : tensor<8x32x100x1024xf32> | |
| %2315 = stablehlo.exponential %2314 : tensor<8x32x100x1024xf32> | |
| %2316 = stablehlo.reduce(%2315 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2317 = stablehlo.broadcast_in_dim %2316, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2318 = stablehlo.divide %2315, %2317 : tensor<8x32x100x1024xf32> | |
| %2319 = stablehlo.reshape %2318 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2320 = stablehlo.transpose %arg60, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2321 = stablehlo.dot_general %2269, %2320, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2322 = stablehlo.reshape %2321 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2323 = "stablehlo.scatter"(%arg309, %39, %2322) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2324 = stablehlo.transpose %2323, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %2325 = stablehlo.reshape %2324 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %2326 = stablehlo.dot_general %2319, %2325, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %2327 = stablehlo.reshape %2326 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2328 = stablehlo.transpose %2327, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2329 = stablehlo.reshape %2328 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %2330 = stablehlo.transpose %arg59, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2331 = stablehlo.dot_general %2329, %2330, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2332 = stablehlo.reshape %2331 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2333 = stablehlo.add %2257, %2332 : tensor<8x100x4096xf32> | |
| %2334 = stablehlo.power %2333, %cst_3 : tensor<8x100x4096xf32> | |
| %2335 = stablehlo.reduce(%2334 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2336 = stablehlo.multiply %2335, %cst_2 : tensor<8x100xf32> | |
| %2337 = stablehlo.reshape %2336 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2338 = stablehlo.add %2337, %cst_1 : tensor<8x100x1xf32> | |
| %2339 = stablehlo.rsqrt %2338 : tensor<8x100x1xf32> | |
| %2340 = stablehlo.reshape %2339 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2341 = stablehlo.broadcast_in_dim %2340, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2342 = stablehlo.multiply %2333, %2341 : tensor<8x100x4096xf32> | |
| %2343 = stablehlo.broadcast_in_dim %arg58, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2344 = stablehlo.multiply %2342, %2343 : tensor<8x100x4096xf32> | |
| %2345 = stablehlo.reshape %2344 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2346 = stablehlo.transpose %arg313, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2347 = stablehlo.dot_general %2345, %2346, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2348 = stablehlo.reshape %2347 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2349 = stablehlo.logistic %2348 : tensor<8x100x11008xf32> | |
| %2350 = stablehlo.multiply %2348, %2349 : tensor<8x100x11008xf32> | |
| %2351 = stablehlo.transpose %arg57, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2352 = stablehlo.dot_general %2345, %2351, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2353 = stablehlo.reshape %2352 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2354 = stablehlo.multiply %2350, %2353 : tensor<8x100x11008xf32> | |
| %2355 = stablehlo.reshape %2354 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %2356 = stablehlo.transpose %arg56, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %2357 = stablehlo.dot_general %2355, %2356, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %2358 = stablehlo.reshape %2357 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2359 = stablehlo.add %2333, %2358 : tensor<8x100x4096xf32> | |
| %2360 = stablehlo.power %2359, %cst_3 : tensor<8x100x4096xf32> | |
| %2361 = stablehlo.reduce(%2360 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2362 = stablehlo.multiply %2361, %cst_2 : tensor<8x100xf32> | |
| %2363 = stablehlo.reshape %2362 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2364 = stablehlo.add %2363, %cst_1 : tensor<8x100x1xf32> | |
| %2365 = stablehlo.rsqrt %2364 : tensor<8x100x1xf32> | |
| %2366 = stablehlo.reshape %2365 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2367 = stablehlo.broadcast_in_dim %2366, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2368 = stablehlo.multiply %2359, %2367 : tensor<8x100x4096xf32> | |
| %2369 = stablehlo.broadcast_in_dim %arg55, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2370 = stablehlo.multiply %2368, %2369 : tensor<8x100x4096xf32> | |
| %2371 = stablehlo.reshape %2370 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2372 = stablehlo.transpose %arg317, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2373 = stablehlo.dot_general %2371, %2372, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2374 = stablehlo.reshape %2373 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2375 = stablehlo.transpose %2374, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2376 = stablehlo.reshape %2375 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2377 = stablehlo.slice %2376 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2378 = stablehlo.reshape %2377 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2379 = stablehlo.slice %2376 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2380 = stablehlo.reshape %2379 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2381 = stablehlo.complex %2378, %2380 : tensor<256x100x64xcomplex<f32>> | |
| %2382 = stablehlo.multiply %2381, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2383 = stablehlo.real %2382 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2384 = stablehlo.reshape %2383 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2385 = stablehlo.imag %2382 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2386 = stablehlo.reshape %2385 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2387 = stablehlo.concatenate %2384, %2386, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2388 = stablehlo.reshape %2387 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %2389 = stablehlo.transpose %arg315, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2390 = stablehlo.dot_general %2371, %2389, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2391 = stablehlo.reshape %2390 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2392 = stablehlo.transpose %2391, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2393 = stablehlo.reshape %2392 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2394 = stablehlo.slice %2393 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2395 = stablehlo.reshape %2394 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2396 = stablehlo.slice %2393 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2397 = stablehlo.reshape %2396 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2398 = stablehlo.complex %2395, %2397 : tensor<256x100x64xcomplex<f32>> | |
| %2399 = stablehlo.multiply %2398, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2400 = stablehlo.real %2399 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2401 = stablehlo.reshape %2400 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2402 = stablehlo.imag %2399 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2403 = stablehlo.reshape %2402 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2404 = stablehlo.concatenate %2401, %2403, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2405 = stablehlo.reshape %2404 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %2406 = stablehlo.transpose %2405, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2407 = "stablehlo.scatter"(%arg316, %39, %2406) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2408 = stablehlo.transpose %2407, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %2409 = stablehlo.reshape %2408 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %2410 = stablehlo.dot_general %2388, %2409, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2411 = stablehlo.reshape %2410 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %2412 = stablehlo.divide %2411, %cst : tensor<8x32x100x1024xf32> | |
| %2413 = stablehlo.add %2412, %66 : tensor<8x32x100x1024xf32> | |
| %2414 = stablehlo.reduce(%2413 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2415 = stablehlo.broadcast_in_dim %2414, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2416 = stablehlo.subtract %2413, %2415 : tensor<8x32x100x1024xf32> | |
| %2417 = stablehlo.exponential %2416 : tensor<8x32x100x1024xf32> | |
| %2418 = stablehlo.reduce(%2417 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2419 = stablehlo.broadcast_in_dim %2418, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2420 = stablehlo.divide %2417, %2419 : tensor<8x32x100x1024xf32> | |
| %2421 = stablehlo.reshape %2420 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2422 = stablehlo.transpose %arg54, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2423 = stablehlo.dot_general %2371, %2422, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2424 = stablehlo.reshape %2423 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2425 = "stablehlo.scatter"(%arg314, %39, %2424) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2426 = stablehlo.transpose %2425, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %2427 = stablehlo.reshape %2426 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %2428 = stablehlo.dot_general %2421, %2427, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %2429 = stablehlo.reshape %2428 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2430 = stablehlo.transpose %2429, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2431 = stablehlo.reshape %2430 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %2432 = stablehlo.transpose %arg53, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2433 = stablehlo.dot_general %2431, %2432, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2434 = stablehlo.reshape %2433 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2435 = stablehlo.add %2359, %2434 : tensor<8x100x4096xf32> | |
| %2436 = stablehlo.power %2435, %cst_3 : tensor<8x100x4096xf32> | |
| %2437 = stablehlo.reduce(%2436 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2438 = stablehlo.multiply %2437, %cst_2 : tensor<8x100xf32> | |
| %2439 = stablehlo.reshape %2438 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2440 = stablehlo.add %2439, %cst_1 : tensor<8x100x1xf32> | |
| %2441 = stablehlo.rsqrt %2440 : tensor<8x100x1xf32> | |
| %2442 = stablehlo.reshape %2441 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2443 = stablehlo.broadcast_in_dim %2442, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2444 = stablehlo.multiply %2435, %2443 : tensor<8x100x4096xf32> | |
| %2445 = stablehlo.broadcast_in_dim %arg52, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2446 = stablehlo.multiply %2444, %2445 : tensor<8x100x4096xf32> | |
| %2447 = stablehlo.reshape %2446 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2448 = stablehlo.transpose %arg318, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2449 = stablehlo.dot_general %2447, %2448, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2450 = stablehlo.reshape %2449 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2451 = stablehlo.logistic %2450 : tensor<8x100x11008xf32> | |
| %2452 = stablehlo.multiply %2450, %2451 : tensor<8x100x11008xf32> | |
| %2453 = stablehlo.transpose %arg51, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2454 = stablehlo.dot_general %2447, %2453, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2455 = stablehlo.reshape %2454 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2456 = stablehlo.multiply %2452, %2455 : tensor<8x100x11008xf32> | |
| %2457 = stablehlo.reshape %2456 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %2458 = stablehlo.transpose %arg50, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %2459 = stablehlo.dot_general %2457, %2458, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %2460 = stablehlo.reshape %2459 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2461 = stablehlo.add %2435, %2460 : tensor<8x100x4096xf32> | |
| %2462 = stablehlo.power %2461, %cst_3 : tensor<8x100x4096xf32> | |
| %2463 = stablehlo.reduce(%2462 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2464 = stablehlo.multiply %2463, %cst_2 : tensor<8x100xf32> | |
| %2465 = stablehlo.reshape %2464 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2466 = stablehlo.add %2465, %cst_1 : tensor<8x100x1xf32> | |
| %2467 = stablehlo.rsqrt %2466 : tensor<8x100x1xf32> | |
| %2468 = stablehlo.reshape %2467 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2469 = stablehlo.broadcast_in_dim %2468, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2470 = stablehlo.multiply %2461, %2469 : tensor<8x100x4096xf32> | |
| %2471 = stablehlo.broadcast_in_dim %arg49, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2472 = stablehlo.multiply %2470, %2471 : tensor<8x100x4096xf32> | |
| %2473 = stablehlo.reshape %2472 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2474 = stablehlo.transpose %arg322, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2475 = stablehlo.dot_general %2473, %2474, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2476 = stablehlo.reshape %2475 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2477 = stablehlo.transpose %2476, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2478 = stablehlo.reshape %2477 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2479 = stablehlo.slice %2478 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2480 = stablehlo.reshape %2479 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2481 = stablehlo.slice %2478 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2482 = stablehlo.reshape %2481 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2483 = stablehlo.complex %2480, %2482 : tensor<256x100x64xcomplex<f32>> | |
| %2484 = stablehlo.multiply %2483, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2485 = stablehlo.real %2484 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2486 = stablehlo.reshape %2485 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2487 = stablehlo.imag %2484 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2488 = stablehlo.reshape %2487 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2489 = stablehlo.concatenate %2486, %2488, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2490 = stablehlo.reshape %2489 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %2491 = stablehlo.transpose %arg320, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2492 = stablehlo.dot_general %2473, %2491, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2493 = stablehlo.reshape %2492 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2494 = stablehlo.transpose %2493, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2495 = stablehlo.reshape %2494 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2496 = stablehlo.slice %2495 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2497 = stablehlo.reshape %2496 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2498 = stablehlo.slice %2495 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2499 = stablehlo.reshape %2498 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2500 = stablehlo.complex %2497, %2499 : tensor<256x100x64xcomplex<f32>> | |
| %2501 = stablehlo.multiply %2500, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2502 = stablehlo.real %2501 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2503 = stablehlo.reshape %2502 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2504 = stablehlo.imag %2501 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2505 = stablehlo.reshape %2504 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2506 = stablehlo.concatenate %2503, %2505, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2507 = stablehlo.reshape %2506 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %2508 = stablehlo.transpose %2507, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2509 = "stablehlo.scatter"(%arg321, %39, %2508) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2510 = stablehlo.transpose %2509, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %2511 = stablehlo.reshape %2510 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %2512 = stablehlo.dot_general %2490, %2511, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2513 = stablehlo.reshape %2512 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %2514 = stablehlo.divide %2513, %cst : tensor<8x32x100x1024xf32> | |
| %2515 = stablehlo.add %2514, %66 : tensor<8x32x100x1024xf32> | |
| %2516 = stablehlo.reduce(%2515 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2517 = stablehlo.broadcast_in_dim %2516, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2518 = stablehlo.subtract %2515, %2517 : tensor<8x32x100x1024xf32> | |
| %2519 = stablehlo.exponential %2518 : tensor<8x32x100x1024xf32> | |
| %2520 = stablehlo.reduce(%2519 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2521 = stablehlo.broadcast_in_dim %2520, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2522 = stablehlo.divide %2519, %2521 : tensor<8x32x100x1024xf32> | |
| %2523 = stablehlo.reshape %2522 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2524 = stablehlo.transpose %arg48, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2525 = stablehlo.dot_general %2473, %2524, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2526 = stablehlo.reshape %2525 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2527 = "stablehlo.scatter"(%arg319, %39, %2526) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2528 = stablehlo.transpose %2527, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %2529 = stablehlo.reshape %2528 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %2530 = stablehlo.dot_general %2523, %2529, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %2531 = stablehlo.reshape %2530 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2532 = stablehlo.transpose %2531, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2533 = stablehlo.reshape %2532 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %2534 = stablehlo.transpose %arg47, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2535 = stablehlo.dot_general %2533, %2534, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2536 = stablehlo.reshape %2535 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2537 = stablehlo.add %2461, %2536 : tensor<8x100x4096xf32> | |
| %2538 = stablehlo.power %2537, %cst_3 : tensor<8x100x4096xf32> | |
| %2539 = stablehlo.reduce(%2538 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2540 = stablehlo.multiply %2539, %cst_2 : tensor<8x100xf32> | |
| %2541 = stablehlo.reshape %2540 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2542 = stablehlo.add %2541, %cst_1 : tensor<8x100x1xf32> | |
| %2543 = stablehlo.rsqrt %2542 : tensor<8x100x1xf32> | |
| %2544 = stablehlo.reshape %2543 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2545 = stablehlo.broadcast_in_dim %2544, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2546 = stablehlo.multiply %2537, %2545 : tensor<8x100x4096xf32> | |
| %2547 = stablehlo.broadcast_in_dim %arg46, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2548 = stablehlo.multiply %2546, %2547 : tensor<8x100x4096xf32> | |
| %2549 = stablehlo.reshape %2548 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2550 = stablehlo.transpose %arg323, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2551 = stablehlo.dot_general %2549, %2550, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2552 = stablehlo.reshape %2551 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2553 = stablehlo.logistic %2552 : tensor<8x100x11008xf32> | |
| %2554 = stablehlo.multiply %2552, %2553 : tensor<8x100x11008xf32> | |
| %2555 = stablehlo.transpose %arg45, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2556 = stablehlo.dot_general %2549, %2555, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2557 = stablehlo.reshape %2556 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2558 = stablehlo.multiply %2554, %2557 : tensor<8x100x11008xf32> | |
| %2559 = stablehlo.reshape %2558 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %2560 = stablehlo.transpose %arg44, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %2561 = stablehlo.dot_general %2559, %2560, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %2562 = stablehlo.reshape %2561 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2563 = stablehlo.add %2537, %2562 : tensor<8x100x4096xf32> | |
| %2564 = stablehlo.power %2563, %cst_3 : tensor<8x100x4096xf32> | |
| %2565 = stablehlo.reduce(%2564 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2566 = stablehlo.multiply %2565, %cst_2 : tensor<8x100xf32> | |
| %2567 = stablehlo.reshape %2566 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2568 = stablehlo.add %2567, %cst_1 : tensor<8x100x1xf32> | |
| %2569 = stablehlo.rsqrt %2568 : tensor<8x100x1xf32> | |
| %2570 = stablehlo.reshape %2569 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2571 = stablehlo.broadcast_in_dim %2570, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2572 = stablehlo.multiply %2563, %2571 : tensor<8x100x4096xf32> | |
| %2573 = stablehlo.broadcast_in_dim %arg43, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2574 = stablehlo.multiply %2572, %2573 : tensor<8x100x4096xf32> | |
| %2575 = stablehlo.reshape %2574 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2576 = stablehlo.transpose %arg327, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2577 = stablehlo.dot_general %2575, %2576, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2578 = stablehlo.reshape %2577 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2579 = stablehlo.transpose %2578, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2580 = stablehlo.reshape %2579 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2581 = stablehlo.slice %2580 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2582 = stablehlo.reshape %2581 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2583 = stablehlo.slice %2580 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2584 = stablehlo.reshape %2583 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2585 = stablehlo.complex %2582, %2584 : tensor<256x100x64xcomplex<f32>> | |
| %2586 = stablehlo.multiply %2585, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2587 = stablehlo.real %2586 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2588 = stablehlo.reshape %2587 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2589 = stablehlo.imag %2586 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2590 = stablehlo.reshape %2589 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2591 = stablehlo.concatenate %2588, %2590, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2592 = stablehlo.reshape %2591 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %2593 = stablehlo.transpose %arg325, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2594 = stablehlo.dot_general %2575, %2593, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2595 = stablehlo.reshape %2594 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2596 = stablehlo.transpose %2595, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2597 = stablehlo.reshape %2596 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2598 = stablehlo.slice %2597 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2599 = stablehlo.reshape %2598 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2600 = stablehlo.slice %2597 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2601 = stablehlo.reshape %2600 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2602 = stablehlo.complex %2599, %2601 : tensor<256x100x64xcomplex<f32>> | |
| %2603 = stablehlo.multiply %2602, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2604 = stablehlo.real %2603 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2605 = stablehlo.reshape %2604 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2606 = stablehlo.imag %2603 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2607 = stablehlo.reshape %2606 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2608 = stablehlo.concatenate %2605, %2607, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2609 = stablehlo.reshape %2608 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %2610 = stablehlo.transpose %2609, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2611 = "stablehlo.scatter"(%arg326, %39, %2610) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2612 = stablehlo.transpose %2611, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %2613 = stablehlo.reshape %2612 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %2614 = stablehlo.dot_general %2592, %2613, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2615 = stablehlo.reshape %2614 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %2616 = stablehlo.divide %2615, %cst : tensor<8x32x100x1024xf32> | |
| %2617 = stablehlo.add %2616, %66 : tensor<8x32x100x1024xf32> | |
| %2618 = stablehlo.reduce(%2617 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2619 = stablehlo.broadcast_in_dim %2618, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2620 = stablehlo.subtract %2617, %2619 : tensor<8x32x100x1024xf32> | |
| %2621 = stablehlo.exponential %2620 : tensor<8x32x100x1024xf32> | |
| %2622 = stablehlo.reduce(%2621 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2623 = stablehlo.broadcast_in_dim %2622, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2624 = stablehlo.divide %2621, %2623 : tensor<8x32x100x1024xf32> | |
| %2625 = stablehlo.reshape %2624 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2626 = stablehlo.transpose %arg42, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2627 = stablehlo.dot_general %2575, %2626, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2628 = stablehlo.reshape %2627 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2629 = "stablehlo.scatter"(%arg324, %39, %2628) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2630 = stablehlo.transpose %2629, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %2631 = stablehlo.reshape %2630 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %2632 = stablehlo.dot_general %2625, %2631, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %2633 = stablehlo.reshape %2632 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2634 = stablehlo.transpose %2633, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2635 = stablehlo.reshape %2634 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %2636 = stablehlo.transpose %arg41, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2637 = stablehlo.dot_general %2635, %2636, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2638 = stablehlo.reshape %2637 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2639 = stablehlo.add %2563, %2638 : tensor<8x100x4096xf32> | |
| %2640 = stablehlo.power %2639, %cst_3 : tensor<8x100x4096xf32> | |
| %2641 = stablehlo.reduce(%2640 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2642 = stablehlo.multiply %2641, %cst_2 : tensor<8x100xf32> | |
| %2643 = stablehlo.reshape %2642 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2644 = stablehlo.add %2643, %cst_1 : tensor<8x100x1xf32> | |
| %2645 = stablehlo.rsqrt %2644 : tensor<8x100x1xf32> | |
| %2646 = stablehlo.reshape %2645 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2647 = stablehlo.broadcast_in_dim %2646, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2648 = stablehlo.multiply %2639, %2647 : tensor<8x100x4096xf32> | |
| %2649 = stablehlo.broadcast_in_dim %arg40, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2650 = stablehlo.multiply %2648, %2649 : tensor<8x100x4096xf32> | |
| %2651 = stablehlo.reshape %2650 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2652 = stablehlo.transpose %arg328, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2653 = stablehlo.dot_general %2651, %2652, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2654 = stablehlo.reshape %2653 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2655 = stablehlo.logistic %2654 : tensor<8x100x11008xf32> | |
| %2656 = stablehlo.multiply %2654, %2655 : tensor<8x100x11008xf32> | |
| %2657 = stablehlo.transpose %arg39, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2658 = stablehlo.dot_general %2651, %2657, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2659 = stablehlo.reshape %2658 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2660 = stablehlo.multiply %2656, %2659 : tensor<8x100x11008xf32> | |
| %2661 = stablehlo.reshape %2660 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %2662 = stablehlo.transpose %arg38, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %2663 = stablehlo.dot_general %2661, %2662, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %2664 = stablehlo.reshape %2663 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2665 = stablehlo.add %2639, %2664 : tensor<8x100x4096xf32> | |
| %2666 = stablehlo.power %2665, %cst_3 : tensor<8x100x4096xf32> | |
| %2667 = stablehlo.reduce(%2666 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2668 = stablehlo.multiply %2667, %cst_2 : tensor<8x100xf32> | |
| %2669 = stablehlo.reshape %2668 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2670 = stablehlo.add %2669, %cst_1 : tensor<8x100x1xf32> | |
| %2671 = stablehlo.rsqrt %2670 : tensor<8x100x1xf32> | |
| %2672 = stablehlo.reshape %2671 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2673 = stablehlo.broadcast_in_dim %2672, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2674 = stablehlo.multiply %2665, %2673 : tensor<8x100x4096xf32> | |
| %2675 = stablehlo.broadcast_in_dim %arg37, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2676 = stablehlo.multiply %2674, %2675 : tensor<8x100x4096xf32> | |
| %2677 = stablehlo.reshape %2676 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2678 = stablehlo.transpose %arg332, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2679 = stablehlo.dot_general %2677, %2678, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2680 = stablehlo.reshape %2679 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2681 = stablehlo.transpose %2680, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2682 = stablehlo.reshape %2681 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2683 = stablehlo.slice %2682 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2684 = stablehlo.reshape %2683 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2685 = stablehlo.slice %2682 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2686 = stablehlo.reshape %2685 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2687 = stablehlo.complex %2684, %2686 : tensor<256x100x64xcomplex<f32>> | |
| %2688 = stablehlo.multiply %2687, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2689 = stablehlo.real %2688 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2690 = stablehlo.reshape %2689 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2691 = stablehlo.imag %2688 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2692 = stablehlo.reshape %2691 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2693 = stablehlo.concatenate %2690, %2692, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2694 = stablehlo.reshape %2693 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %2695 = stablehlo.transpose %arg330, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2696 = stablehlo.dot_general %2677, %2695, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2697 = stablehlo.reshape %2696 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2698 = stablehlo.transpose %2697, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2699 = stablehlo.reshape %2698 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2700 = stablehlo.slice %2699 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2701 = stablehlo.reshape %2700 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2702 = stablehlo.slice %2699 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2703 = stablehlo.reshape %2702 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2704 = stablehlo.complex %2701, %2703 : tensor<256x100x64xcomplex<f32>> | |
| %2705 = stablehlo.multiply %2704, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2706 = stablehlo.real %2705 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2707 = stablehlo.reshape %2706 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2708 = stablehlo.imag %2705 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2709 = stablehlo.reshape %2708 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2710 = stablehlo.concatenate %2707, %2709, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2711 = stablehlo.reshape %2710 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %2712 = stablehlo.transpose %2711, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2713 = "stablehlo.scatter"(%arg331, %39, %2712) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2714 = stablehlo.transpose %2713, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %2715 = stablehlo.reshape %2714 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %2716 = stablehlo.dot_general %2694, %2715, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2717 = stablehlo.reshape %2716 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %2718 = stablehlo.divide %2717, %cst : tensor<8x32x100x1024xf32> | |
| %2719 = stablehlo.add %2718, %66 : tensor<8x32x100x1024xf32> | |
| %2720 = stablehlo.reduce(%2719 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2721 = stablehlo.broadcast_in_dim %2720, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2722 = stablehlo.subtract %2719, %2721 : tensor<8x32x100x1024xf32> | |
| %2723 = stablehlo.exponential %2722 : tensor<8x32x100x1024xf32> | |
| %2724 = stablehlo.reduce(%2723 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2725 = stablehlo.broadcast_in_dim %2724, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2726 = stablehlo.divide %2723, %2725 : tensor<8x32x100x1024xf32> | |
| %2727 = stablehlo.reshape %2726 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2728 = stablehlo.transpose %arg36, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2729 = stablehlo.dot_general %2677, %2728, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2730 = stablehlo.reshape %2729 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2731 = "stablehlo.scatter"(%arg329, %39, %2730) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2732 = stablehlo.transpose %2731, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %2733 = stablehlo.reshape %2732 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %2734 = stablehlo.dot_general %2727, %2733, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %2735 = stablehlo.reshape %2734 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2736 = stablehlo.transpose %2735, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2737 = stablehlo.reshape %2736 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %2738 = stablehlo.transpose %arg35, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2739 = stablehlo.dot_general %2737, %2738, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2740 = stablehlo.reshape %2739 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2741 = stablehlo.add %2665, %2740 : tensor<8x100x4096xf32> | |
| %2742 = stablehlo.power %2741, %cst_3 : tensor<8x100x4096xf32> | |
| %2743 = stablehlo.reduce(%2742 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2744 = stablehlo.multiply %2743, %cst_2 : tensor<8x100xf32> | |
| %2745 = stablehlo.reshape %2744 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2746 = stablehlo.add %2745, %cst_1 : tensor<8x100x1xf32> | |
| %2747 = stablehlo.rsqrt %2746 : tensor<8x100x1xf32> | |
| %2748 = stablehlo.reshape %2747 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2749 = stablehlo.broadcast_in_dim %2748, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2750 = stablehlo.multiply %2741, %2749 : tensor<8x100x4096xf32> | |
| %2751 = stablehlo.broadcast_in_dim %arg34, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2752 = stablehlo.multiply %2750, %2751 : tensor<8x100x4096xf32> | |
| %2753 = stablehlo.reshape %2752 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2754 = stablehlo.transpose %arg333, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2755 = stablehlo.dot_general %2753, %2754, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2756 = stablehlo.reshape %2755 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2757 = stablehlo.logistic %2756 : tensor<8x100x11008xf32> | |
| %2758 = stablehlo.multiply %2756, %2757 : tensor<8x100x11008xf32> | |
| %2759 = stablehlo.transpose %arg33, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2760 = stablehlo.dot_general %2753, %2759, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2761 = stablehlo.reshape %2760 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2762 = stablehlo.multiply %2758, %2761 : tensor<8x100x11008xf32> | |
| %2763 = stablehlo.reshape %2762 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %2764 = stablehlo.transpose %arg32, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %2765 = stablehlo.dot_general %2763, %2764, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %2766 = stablehlo.reshape %2765 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2767 = stablehlo.add %2741, %2766 : tensor<8x100x4096xf32> | |
| %2768 = stablehlo.power %2767, %cst_3 : tensor<8x100x4096xf32> | |
| %2769 = stablehlo.reduce(%2768 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2770 = stablehlo.multiply %2769, %cst_2 : tensor<8x100xf32> | |
| %2771 = stablehlo.reshape %2770 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2772 = stablehlo.add %2771, %cst_1 : tensor<8x100x1xf32> | |
| %2773 = stablehlo.rsqrt %2772 : tensor<8x100x1xf32> | |
| %2774 = stablehlo.reshape %2773 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2775 = stablehlo.broadcast_in_dim %2774, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2776 = stablehlo.multiply %2767, %2775 : tensor<8x100x4096xf32> | |
| %2777 = stablehlo.broadcast_in_dim %arg31, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2778 = stablehlo.multiply %2776, %2777 : tensor<8x100x4096xf32> | |
| %2779 = stablehlo.reshape %2778 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2780 = stablehlo.transpose %arg337, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2781 = stablehlo.dot_general %2779, %2780, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2782 = stablehlo.reshape %2781 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2783 = stablehlo.transpose %2782, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2784 = stablehlo.reshape %2783 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2785 = stablehlo.slice %2784 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2786 = stablehlo.reshape %2785 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2787 = stablehlo.slice %2784 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2788 = stablehlo.reshape %2787 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2789 = stablehlo.complex %2786, %2788 : tensor<256x100x64xcomplex<f32>> | |
| %2790 = stablehlo.multiply %2789, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2791 = stablehlo.real %2790 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2792 = stablehlo.reshape %2791 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2793 = stablehlo.imag %2790 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2794 = stablehlo.reshape %2793 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2795 = stablehlo.concatenate %2792, %2794, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2796 = stablehlo.reshape %2795 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %2797 = stablehlo.transpose %arg335, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2798 = stablehlo.dot_general %2779, %2797, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2799 = stablehlo.reshape %2798 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2800 = stablehlo.transpose %2799, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2801 = stablehlo.reshape %2800 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2802 = stablehlo.slice %2801 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2803 = stablehlo.reshape %2802 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2804 = stablehlo.slice %2801 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2805 = stablehlo.reshape %2804 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2806 = stablehlo.complex %2803, %2805 : tensor<256x100x64xcomplex<f32>> | |
| %2807 = stablehlo.multiply %2806, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2808 = stablehlo.real %2807 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2809 = stablehlo.reshape %2808 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2810 = stablehlo.imag %2807 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2811 = stablehlo.reshape %2810 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2812 = stablehlo.concatenate %2809, %2811, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2813 = stablehlo.reshape %2812 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %2814 = stablehlo.transpose %2813, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2815 = "stablehlo.scatter"(%arg336, %39, %2814) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2816 = stablehlo.transpose %2815, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %2817 = stablehlo.reshape %2816 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %2818 = stablehlo.dot_general %2796, %2817, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2819 = stablehlo.reshape %2818 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %2820 = stablehlo.divide %2819, %cst : tensor<8x32x100x1024xf32> | |
| %2821 = stablehlo.add %2820, %66 : tensor<8x32x100x1024xf32> | |
| %2822 = stablehlo.reduce(%2821 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2823 = stablehlo.broadcast_in_dim %2822, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2824 = stablehlo.subtract %2821, %2823 : tensor<8x32x100x1024xf32> | |
| %2825 = stablehlo.exponential %2824 : tensor<8x32x100x1024xf32> | |
| %2826 = stablehlo.reduce(%2825 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2827 = stablehlo.broadcast_in_dim %2826, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2828 = stablehlo.divide %2825, %2827 : tensor<8x32x100x1024xf32> | |
| %2829 = stablehlo.reshape %2828 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2830 = stablehlo.transpose %arg30, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2831 = stablehlo.dot_general %2779, %2830, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2832 = stablehlo.reshape %2831 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2833 = "stablehlo.scatter"(%arg334, %39, %2832) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2834 = stablehlo.transpose %2833, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %2835 = stablehlo.reshape %2834 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %2836 = stablehlo.dot_general %2829, %2835, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %2837 = stablehlo.reshape %2836 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2838 = stablehlo.transpose %2837, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2839 = stablehlo.reshape %2838 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %2840 = stablehlo.transpose %arg29, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2841 = stablehlo.dot_general %2839, %2840, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2842 = stablehlo.reshape %2841 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2843 = stablehlo.add %2767, %2842 : tensor<8x100x4096xf32> | |
| %2844 = stablehlo.power %2843, %cst_3 : tensor<8x100x4096xf32> | |
| %2845 = stablehlo.reduce(%2844 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2846 = stablehlo.multiply %2845, %cst_2 : tensor<8x100xf32> | |
| %2847 = stablehlo.reshape %2846 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2848 = stablehlo.add %2847, %cst_1 : tensor<8x100x1xf32> | |
| %2849 = stablehlo.rsqrt %2848 : tensor<8x100x1xf32> | |
| %2850 = stablehlo.reshape %2849 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2851 = stablehlo.broadcast_in_dim %2850, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2852 = stablehlo.multiply %2843, %2851 : tensor<8x100x4096xf32> | |
| %2853 = stablehlo.broadcast_in_dim %arg28, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2854 = stablehlo.multiply %2852, %2853 : tensor<8x100x4096xf32> | |
| %2855 = stablehlo.reshape %2854 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2856 = stablehlo.transpose %arg338, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2857 = stablehlo.dot_general %2855, %2856, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2858 = stablehlo.reshape %2857 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2859 = stablehlo.logistic %2858 : tensor<8x100x11008xf32> | |
| %2860 = stablehlo.multiply %2858, %2859 : tensor<8x100x11008xf32> | |
| %2861 = stablehlo.transpose %arg27, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2862 = stablehlo.dot_general %2855, %2861, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2863 = stablehlo.reshape %2862 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2864 = stablehlo.multiply %2860, %2863 : tensor<8x100x11008xf32> | |
| %2865 = stablehlo.reshape %2864 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %2866 = stablehlo.transpose %arg26, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %2867 = stablehlo.dot_general %2865, %2866, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %2868 = stablehlo.reshape %2867 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2869 = stablehlo.add %2843, %2868 : tensor<8x100x4096xf32> | |
| %2870 = stablehlo.power %2869, %cst_3 : tensor<8x100x4096xf32> | |
| %2871 = stablehlo.reduce(%2870 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2872 = stablehlo.multiply %2871, %cst_2 : tensor<8x100xf32> | |
| %2873 = stablehlo.reshape %2872 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2874 = stablehlo.add %2873, %cst_1 : tensor<8x100x1xf32> | |
| %2875 = stablehlo.rsqrt %2874 : tensor<8x100x1xf32> | |
| %2876 = stablehlo.reshape %2875 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2877 = stablehlo.broadcast_in_dim %2876, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2878 = stablehlo.multiply %2869, %2877 : tensor<8x100x4096xf32> | |
| %2879 = stablehlo.broadcast_in_dim %arg25, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2880 = stablehlo.multiply %2878, %2879 : tensor<8x100x4096xf32> | |
| %2881 = stablehlo.reshape %2880 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2882 = stablehlo.transpose %arg342, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2883 = stablehlo.dot_general %2881, %2882, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2884 = stablehlo.reshape %2883 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2885 = stablehlo.transpose %2884, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2886 = stablehlo.reshape %2885 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2887 = stablehlo.slice %2886 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2888 = stablehlo.reshape %2887 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2889 = stablehlo.slice %2886 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2890 = stablehlo.reshape %2889 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2891 = stablehlo.complex %2888, %2890 : tensor<256x100x64xcomplex<f32>> | |
| %2892 = stablehlo.multiply %2891, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2893 = stablehlo.real %2892 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2894 = stablehlo.reshape %2893 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2895 = stablehlo.imag %2892 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2896 = stablehlo.reshape %2895 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2897 = stablehlo.concatenate %2894, %2896, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2898 = stablehlo.reshape %2897 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %2899 = stablehlo.transpose %arg340, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2900 = stablehlo.dot_general %2881, %2899, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2901 = stablehlo.reshape %2900 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2902 = stablehlo.transpose %2901, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2903 = stablehlo.reshape %2902 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2904 = stablehlo.slice %2903 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2905 = stablehlo.reshape %2904 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2906 = stablehlo.slice %2903 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2907 = stablehlo.reshape %2906 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2908 = stablehlo.complex %2905, %2907 : tensor<256x100x64xcomplex<f32>> | |
| %2909 = stablehlo.multiply %2908, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2910 = stablehlo.real %2909 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2911 = stablehlo.reshape %2910 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2912 = stablehlo.imag %2909 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2913 = stablehlo.reshape %2912 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2914 = stablehlo.concatenate %2911, %2913, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %2915 = stablehlo.reshape %2914 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %2916 = stablehlo.transpose %2915, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2917 = "stablehlo.scatter"(%arg341, %39, %2916) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2918 = stablehlo.transpose %2917, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %2919 = stablehlo.reshape %2918 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %2920 = stablehlo.dot_general %2898, %2919, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2921 = stablehlo.reshape %2920 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %2922 = stablehlo.divide %2921, %cst : tensor<8x32x100x1024xf32> | |
| %2923 = stablehlo.add %2922, %66 : tensor<8x32x100x1024xf32> | |
| %2924 = stablehlo.reduce(%2923 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2925 = stablehlo.broadcast_in_dim %2924, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2926 = stablehlo.subtract %2923, %2925 : tensor<8x32x100x1024xf32> | |
| %2927 = stablehlo.exponential %2926 : tensor<8x32x100x1024xf32> | |
| %2928 = stablehlo.reduce(%2927 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %2929 = stablehlo.broadcast_in_dim %2928, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %2930 = stablehlo.divide %2927, %2929 : tensor<8x32x100x1024xf32> | |
| %2931 = stablehlo.reshape %2930 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %2932 = stablehlo.transpose %arg24, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2933 = stablehlo.dot_general %2881, %2932, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2934 = stablehlo.reshape %2933 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2935 = "stablehlo.scatter"(%arg339, %39, %2934) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %2936 = stablehlo.transpose %2935, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %2937 = stablehlo.reshape %2936 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %2938 = stablehlo.dot_general %2931, %2937, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %2939 = stablehlo.reshape %2938 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2940 = stablehlo.transpose %2939, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %2941 = stablehlo.reshape %2940 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %2942 = stablehlo.transpose %arg23, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2943 = stablehlo.dot_general %2941, %2942, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2944 = stablehlo.reshape %2943 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2945 = stablehlo.add %2869, %2944 : tensor<8x100x4096xf32> | |
| %2946 = stablehlo.power %2945, %cst_3 : tensor<8x100x4096xf32> | |
| %2947 = stablehlo.reduce(%2946 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2948 = stablehlo.multiply %2947, %cst_2 : tensor<8x100xf32> | |
| %2949 = stablehlo.reshape %2948 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2950 = stablehlo.add %2949, %cst_1 : tensor<8x100x1xf32> | |
| %2951 = stablehlo.rsqrt %2950 : tensor<8x100x1xf32> | |
| %2952 = stablehlo.reshape %2951 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2953 = stablehlo.broadcast_in_dim %2952, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2954 = stablehlo.multiply %2945, %2953 : tensor<8x100x4096xf32> | |
| %2955 = stablehlo.broadcast_in_dim %arg22, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2956 = stablehlo.multiply %2954, %2955 : tensor<8x100x4096xf32> | |
| %2957 = stablehlo.reshape %2956 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2958 = stablehlo.transpose %arg343, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2959 = stablehlo.dot_general %2957, %2958, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2960 = stablehlo.reshape %2959 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2961 = stablehlo.logistic %2960 : tensor<8x100x11008xf32> | |
| %2962 = stablehlo.multiply %2960, %2961 : tensor<8x100x11008xf32> | |
| %2963 = stablehlo.transpose %arg21, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %2964 = stablehlo.dot_general %2957, %2963, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %2965 = stablehlo.reshape %2964 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %2966 = stablehlo.multiply %2962, %2965 : tensor<8x100x11008xf32> | |
| %2967 = stablehlo.reshape %2966 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %2968 = stablehlo.transpose %arg20, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %2969 = stablehlo.dot_general %2967, %2968, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %2970 = stablehlo.reshape %2969 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %2971 = stablehlo.add %2945, %2970 : tensor<8x100x4096xf32> | |
| %2972 = stablehlo.power %2971, %cst_3 : tensor<8x100x4096xf32> | |
| %2973 = stablehlo.reduce(%2972 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %2974 = stablehlo.multiply %2973, %cst_2 : tensor<8x100xf32> | |
| %2975 = stablehlo.reshape %2974 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %2976 = stablehlo.add %2975, %cst_1 : tensor<8x100x1xf32> | |
| %2977 = stablehlo.rsqrt %2976 : tensor<8x100x1xf32> | |
| %2978 = stablehlo.reshape %2977 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %2979 = stablehlo.broadcast_in_dim %2978, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %2980 = stablehlo.multiply %2971, %2979 : tensor<8x100x4096xf32> | |
| %2981 = stablehlo.broadcast_in_dim %arg19, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %2982 = stablehlo.multiply %2980, %2981 : tensor<8x100x4096xf32> | |
| %2983 = stablehlo.reshape %2982 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %2984 = stablehlo.transpose %arg347, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %2985 = stablehlo.dot_general %2983, %2984, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %2986 = stablehlo.reshape %2985 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %2987 = stablehlo.transpose %2986, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %2988 = stablehlo.reshape %2987 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %2989 = stablehlo.slice %2988 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2990 = stablehlo.reshape %2989 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2991 = stablehlo.slice %2988 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %2992 = stablehlo.reshape %2991 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %2993 = stablehlo.complex %2990, %2992 : tensor<256x100x64xcomplex<f32>> | |
| %2994 = stablehlo.multiply %2993, %28 : tensor<256x100x64xcomplex<f32>> | |
| %2995 = stablehlo.real %2994 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2996 = stablehlo.reshape %2995 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2997 = stablehlo.imag %2994 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %2998 = stablehlo.reshape %2997 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %2999 = stablehlo.concatenate %2996, %2998, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %3000 = stablehlo.reshape %2999 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %3001 = stablehlo.transpose %arg345, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %3002 = stablehlo.dot_general %2983, %3001, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %3003 = stablehlo.reshape %3002 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %3004 = stablehlo.transpose %3003, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %3005 = stablehlo.reshape %3004 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %3006 = stablehlo.slice %3005 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %3007 = stablehlo.reshape %3006 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %3008 = stablehlo.slice %3005 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %3009 = stablehlo.reshape %3008 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %3010 = stablehlo.complex %3007, %3009 : tensor<256x100x64xcomplex<f32>> | |
| %3011 = stablehlo.multiply %3010, %28 : tensor<256x100x64xcomplex<f32>> | |
| %3012 = stablehlo.real %3011 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %3013 = stablehlo.reshape %3012 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %3014 = stablehlo.imag %3011 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %3015 = stablehlo.reshape %3014 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %3016 = stablehlo.concatenate %3013, %3015, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %3017 = stablehlo.reshape %3016 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %3018 = stablehlo.transpose %3017, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %3019 = "stablehlo.scatter"(%arg346, %39, %3018) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %3020 = stablehlo.transpose %3019, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %3021 = stablehlo.reshape %3020 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %3022 = stablehlo.dot_general %3000, %3021, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %3023 = stablehlo.reshape %3022 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %3024 = stablehlo.divide %3023, %cst : tensor<8x32x100x1024xf32> | |
| %3025 = stablehlo.add %3024, %66 : tensor<8x32x100x1024xf32> | |
| %3026 = stablehlo.reduce(%3025 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %3027 = stablehlo.broadcast_in_dim %3026, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %3028 = stablehlo.subtract %3025, %3027 : tensor<8x32x100x1024xf32> | |
| %3029 = stablehlo.exponential %3028 : tensor<8x32x100x1024xf32> | |
| %3030 = stablehlo.reduce(%3029 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %3031 = stablehlo.broadcast_in_dim %3030, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %3032 = stablehlo.divide %3029, %3031 : tensor<8x32x100x1024xf32> | |
| %3033 = stablehlo.reshape %3032 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %3034 = stablehlo.transpose %arg18, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %3035 = stablehlo.dot_general %2983, %3034, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %3036 = stablehlo.reshape %3035 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %3037 = "stablehlo.scatter"(%arg344, %39, %3036) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %3038 = stablehlo.transpose %3037, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %3039 = stablehlo.reshape %3038 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %3040 = stablehlo.dot_general %3033, %3039, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %3041 = stablehlo.reshape %3040 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %3042 = stablehlo.transpose %3041, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %3043 = stablehlo.reshape %3042 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %3044 = stablehlo.transpose %arg17, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %3045 = stablehlo.dot_general %3043, %3044, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %3046 = stablehlo.reshape %3045 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %3047 = stablehlo.add %2971, %3046 : tensor<8x100x4096xf32> | |
| %3048 = stablehlo.power %3047, %cst_3 : tensor<8x100x4096xf32> | |
| %3049 = stablehlo.reduce(%3048 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %3050 = stablehlo.multiply %3049, %cst_2 : tensor<8x100xf32> | |
| %3051 = stablehlo.reshape %3050 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %3052 = stablehlo.add %3051, %cst_1 : tensor<8x100x1xf32> | |
| %3053 = stablehlo.rsqrt %3052 : tensor<8x100x1xf32> | |
| %3054 = stablehlo.reshape %3053 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %3055 = stablehlo.broadcast_in_dim %3054, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %3056 = stablehlo.multiply %3047, %3055 : tensor<8x100x4096xf32> | |
| %3057 = stablehlo.broadcast_in_dim %arg16, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %3058 = stablehlo.multiply %3056, %3057 : tensor<8x100x4096xf32> | |
| %3059 = stablehlo.reshape %3058 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %3060 = stablehlo.transpose %arg348, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %3061 = stablehlo.dot_general %3059, %3060, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %3062 = stablehlo.reshape %3061 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %3063 = stablehlo.logistic %3062 : tensor<8x100x11008xf32> | |
| %3064 = stablehlo.multiply %3062, %3063 : tensor<8x100x11008xf32> | |
| %3065 = stablehlo.transpose %arg15, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %3066 = stablehlo.dot_general %3059, %3065, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %3067 = stablehlo.reshape %3066 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %3068 = stablehlo.multiply %3064, %3067 : tensor<8x100x11008xf32> | |
| %3069 = stablehlo.reshape %3068 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %3070 = stablehlo.transpose %arg14, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %3071 = stablehlo.dot_general %3069, %3070, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %3072 = stablehlo.reshape %3071 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %3073 = stablehlo.add %3047, %3072 : tensor<8x100x4096xf32> | |
| %3074 = stablehlo.power %3073, %cst_3 : tensor<8x100x4096xf32> | |
| %3075 = stablehlo.reduce(%3074 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %3076 = stablehlo.multiply %3075, %cst_2 : tensor<8x100xf32> | |
| %3077 = stablehlo.reshape %3076 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %3078 = stablehlo.add %3077, %cst_1 : tensor<8x100x1xf32> | |
| %3079 = stablehlo.rsqrt %3078 : tensor<8x100x1xf32> | |
| %3080 = stablehlo.reshape %3079 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %3081 = stablehlo.broadcast_in_dim %3080, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %3082 = stablehlo.multiply %3073, %3081 : tensor<8x100x4096xf32> | |
| %3083 = stablehlo.broadcast_in_dim %arg13, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %3084 = stablehlo.multiply %3082, %3083 : tensor<8x100x4096xf32> | |
| %3085 = stablehlo.reshape %3084 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %3086 = stablehlo.transpose %arg352, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %3087 = stablehlo.dot_general %3085, %3086, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %3088 = stablehlo.reshape %3087 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %3089 = stablehlo.transpose %3088, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %3090 = stablehlo.reshape %3089 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %3091 = stablehlo.slice %3090 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %3092 = stablehlo.reshape %3091 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %3093 = stablehlo.slice %3090 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %3094 = stablehlo.reshape %3093 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %3095 = stablehlo.complex %3092, %3094 : tensor<256x100x64xcomplex<f32>> | |
| %3096 = stablehlo.multiply %3095, %28 : tensor<256x100x64xcomplex<f32>> | |
| %3097 = stablehlo.real %3096 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %3098 = stablehlo.reshape %3097 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %3099 = stablehlo.imag %3096 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %3100 = stablehlo.reshape %3099 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %3101 = stablehlo.concatenate %3098, %3100, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %3102 = stablehlo.reshape %3101 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %3103 = stablehlo.transpose %arg350, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %3104 = stablehlo.dot_general %3085, %3103, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %3105 = stablehlo.reshape %3104 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %3106 = stablehlo.transpose %3105, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %3107 = stablehlo.reshape %3106 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %3108 = stablehlo.slice %3107 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %3109 = stablehlo.reshape %3108 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %3110 = stablehlo.slice %3107 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %3111 = stablehlo.reshape %3110 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %3112 = stablehlo.complex %3109, %3111 : tensor<256x100x64xcomplex<f32>> | |
| %3113 = stablehlo.multiply %3112, %28 : tensor<256x100x64xcomplex<f32>> | |
| %3114 = stablehlo.real %3113 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %3115 = stablehlo.reshape %3114 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %3116 = stablehlo.imag %3113 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %3117 = stablehlo.reshape %3116 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %3118 = stablehlo.concatenate %3115, %3117, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %3119 = stablehlo.reshape %3118 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %3120 = stablehlo.transpose %3119, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %3121 = "stablehlo.scatter"(%arg351, %39, %3120) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %3122 = stablehlo.transpose %3121, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %3123 = stablehlo.reshape %3122 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %3124 = stablehlo.dot_general %3102, %3123, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %3125 = stablehlo.reshape %3124 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %3126 = stablehlo.divide %3125, %cst : tensor<8x32x100x1024xf32> | |
| %3127 = stablehlo.add %3126, %66 : tensor<8x32x100x1024xf32> | |
| %3128 = stablehlo.reduce(%3127 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %3129 = stablehlo.broadcast_in_dim %3128, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %3130 = stablehlo.subtract %3127, %3129 : tensor<8x32x100x1024xf32> | |
| %3131 = stablehlo.exponential %3130 : tensor<8x32x100x1024xf32> | |
| %3132 = stablehlo.reduce(%3131 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %3133 = stablehlo.broadcast_in_dim %3132, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %3134 = stablehlo.divide %3131, %3133 : tensor<8x32x100x1024xf32> | |
| %3135 = stablehlo.reshape %3134 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %3136 = stablehlo.transpose %arg12, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %3137 = stablehlo.dot_general %3085, %3136, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %3138 = stablehlo.reshape %3137 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %3139 = "stablehlo.scatter"(%arg349, %39, %3138) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %3140 = stablehlo.transpose %3139, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %3141 = stablehlo.reshape %3140 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %3142 = stablehlo.dot_general %3135, %3141, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %3143 = stablehlo.reshape %3142 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %3144 = stablehlo.transpose %3143, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %3145 = stablehlo.reshape %3144 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %3146 = stablehlo.transpose %arg11, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %3147 = stablehlo.dot_general %3145, %3146, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %3148 = stablehlo.reshape %3147 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %3149 = stablehlo.add %3073, %3148 : tensor<8x100x4096xf32> | |
| %3150 = stablehlo.power %3149, %cst_3 : tensor<8x100x4096xf32> | |
| %3151 = stablehlo.reduce(%3150 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %3152 = stablehlo.multiply %3151, %cst_2 : tensor<8x100xf32> | |
| %3153 = stablehlo.reshape %3152 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %3154 = stablehlo.add %3153, %cst_1 : tensor<8x100x1xf32> | |
| %3155 = stablehlo.rsqrt %3154 : tensor<8x100x1xf32> | |
| %3156 = stablehlo.reshape %3155 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %3157 = stablehlo.broadcast_in_dim %3156, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %3158 = stablehlo.multiply %3149, %3157 : tensor<8x100x4096xf32> | |
| %3159 = stablehlo.broadcast_in_dim %arg10, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %3160 = stablehlo.multiply %3158, %3159 : tensor<8x100x4096xf32> | |
| %3161 = stablehlo.reshape %3160 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %3162 = stablehlo.transpose %arg353, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %3163 = stablehlo.dot_general %3161, %3162, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %3164 = stablehlo.reshape %3163 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %3165 = stablehlo.logistic %3164 : tensor<8x100x11008xf32> | |
| %3166 = stablehlo.multiply %3164, %3165 : tensor<8x100x11008xf32> | |
| %3167 = stablehlo.transpose %arg9, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %3168 = stablehlo.dot_general %3161, %3167, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %3169 = stablehlo.reshape %3168 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %3170 = stablehlo.multiply %3166, %3169 : tensor<8x100x11008xf32> | |
| %3171 = stablehlo.reshape %3170 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %3172 = stablehlo.transpose %arg8, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %3173 = stablehlo.dot_general %3171, %3172, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %3174 = stablehlo.reshape %3173 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %3175 = stablehlo.add %3149, %3174 : tensor<8x100x4096xf32> | |
| %3176 = stablehlo.power %3175, %cst_3 : tensor<8x100x4096xf32> | |
| %3177 = stablehlo.reduce(%3176 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %3178 = stablehlo.multiply %3177, %cst_2 : tensor<8x100xf32> | |
| %3179 = stablehlo.reshape %3178 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %3180 = stablehlo.add %3179, %cst_1 : tensor<8x100x1xf32> | |
| %3181 = stablehlo.rsqrt %3180 : tensor<8x100x1xf32> | |
| %3182 = stablehlo.reshape %3181 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %3183 = stablehlo.broadcast_in_dim %3182, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %3184 = stablehlo.multiply %3175, %3183 : tensor<8x100x4096xf32> | |
| %3185 = stablehlo.broadcast_in_dim %arg7, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %3186 = stablehlo.multiply %3184, %3185 : tensor<8x100x4096xf32> | |
| %3187 = stablehlo.reshape %3186 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %3188 = stablehlo.transpose %arg357, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %3189 = stablehlo.dot_general %3187, %3188, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %3190 = stablehlo.reshape %3189 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %3191 = stablehlo.transpose %3190, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %3192 = stablehlo.reshape %3191 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %3193 = stablehlo.slice %3192 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %3194 = stablehlo.reshape %3193 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %3195 = stablehlo.slice %3192 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %3196 = stablehlo.reshape %3195 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %3197 = stablehlo.complex %3194, %3196 : tensor<256x100x64xcomplex<f32>> | |
| %3198 = stablehlo.multiply %3197, %28 : tensor<256x100x64xcomplex<f32>> | |
| %3199 = stablehlo.real %3198 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %3200 = stablehlo.reshape %3199 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %3201 = stablehlo.imag %3198 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %3202 = stablehlo.reshape %3201 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %3203 = stablehlo.concatenate %3200, %3202, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %3204 = stablehlo.reshape %3203 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32> | |
| %3205 = stablehlo.transpose %arg355, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %3206 = stablehlo.dot_general %3187, %3205, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %3207 = stablehlo.reshape %3206 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %3208 = stablehlo.transpose %3207, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32> | |
| %3209 = stablehlo.reshape %3208 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32> | |
| %3210 = stablehlo.slice %3209 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %3211 = stablehlo.reshape %3210 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %3212 = stablehlo.slice %3209 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32> | |
| %3213 = stablehlo.reshape %3212 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32> | |
| %3214 = stablehlo.complex %3211, %3213 : tensor<256x100x64xcomplex<f32>> | |
| %3215 = stablehlo.multiply %3214, %28 : tensor<256x100x64xcomplex<f32>> | |
| %3216 = stablehlo.real %3215 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %3217 = stablehlo.reshape %3216 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %3218 = stablehlo.imag %3215 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32> | |
| %3219 = stablehlo.reshape %3218 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32> | |
| %3220 = stablehlo.concatenate %3217, %3219, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32> | |
| %3221 = stablehlo.reshape %3220 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32> | |
| %3222 = stablehlo.transpose %3221, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %3223 = "stablehlo.scatter"(%arg356, %39, %3222) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %3224 = stablehlo.transpose %3223, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32> | |
| %3225 = stablehlo.reshape %3224 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32> | |
| %3226 = stablehlo.dot_general %3204, %3225, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32> | |
| %3227 = stablehlo.reshape %3226 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32> | |
| %3228 = stablehlo.divide %3227, %cst : tensor<8x32x100x1024xf32> | |
| %3229 = stablehlo.add %3228, %66 : tensor<8x32x100x1024xf32> | |
| %3230 = stablehlo.reduce(%3229 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %3231 = stablehlo.broadcast_in_dim %3230, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %3232 = stablehlo.subtract %3229, %3231 : tensor<8x32x100x1024xf32> | |
| %3233 = stablehlo.exponential %3232 : tensor<8x32x100x1024xf32> | |
| %3234 = stablehlo.reduce(%3233 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32> | |
| %3235 = stablehlo.broadcast_in_dim %3234, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32> | |
| %3236 = stablehlo.divide %3233, %3235 : tensor<8x32x100x1024xf32> | |
| %3237 = stablehlo.reshape %3236 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32> | |
| %3238 = stablehlo.transpose %arg6, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %3239 = stablehlo.dot_general %3187, %3238, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %3240 = stablehlo.reshape %3239 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32> | |
| %3241 = "stablehlo.scatter"(%arg354, %39, %3240) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({ | |
| ^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>): | |
| stablehlo.return %arg360 : tensor<f32> | |
| }) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32> | |
| %3242 = stablehlo.transpose %3241, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32> | |
| %3243 = stablehlo.reshape %3242 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32> | |
| %3244 = stablehlo.dot_general %3237, %3243, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32> | |
| %3245 = stablehlo.reshape %3244 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32> | |
| %3246 = stablehlo.transpose %3245, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32> | |
| %3247 = stablehlo.reshape %3246 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32> | |
| %3248 = stablehlo.transpose %arg5, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32> | |
| %3249 = stablehlo.dot_general %3247, %3248, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32> | |
| %3250 = stablehlo.reshape %3249 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %3251 = stablehlo.add %3175, %3250 : tensor<8x100x4096xf32> | |
| %3252 = stablehlo.power %3251, %cst_3 : tensor<8x100x4096xf32> | |
| %3253 = stablehlo.reduce(%3252 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %3254 = stablehlo.multiply %3253, %cst_2 : tensor<8x100xf32> | |
| %3255 = stablehlo.reshape %3254 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %3256 = stablehlo.add %3255, %cst_1 : tensor<8x100x1xf32> | |
| %3257 = stablehlo.rsqrt %3256 : tensor<8x100x1xf32> | |
| %3258 = stablehlo.reshape %3257 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %3259 = stablehlo.broadcast_in_dim %3258, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %3260 = stablehlo.multiply %3251, %3259 : tensor<8x100x4096xf32> | |
| %3261 = stablehlo.broadcast_in_dim %arg4, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %3262 = stablehlo.multiply %3260, %3261 : tensor<8x100x4096xf32> | |
| %3263 = stablehlo.reshape %3262 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %3264 = stablehlo.transpose %arg358, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %3265 = stablehlo.dot_general %3263, %3264, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %3266 = stablehlo.reshape %3265 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %3267 = stablehlo.logistic %3266 : tensor<8x100x11008xf32> | |
| %3268 = stablehlo.multiply %3266, %3267 : tensor<8x100x11008xf32> | |
| %3269 = stablehlo.transpose %arg3, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32> | |
| %3270 = stablehlo.dot_general %3263, %3269, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32> | |
| %3271 = stablehlo.reshape %3270 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32> | |
| %3272 = stablehlo.multiply %3268, %3271 : tensor<8x100x11008xf32> | |
| %3273 = stablehlo.reshape %3272 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32> | |
| %3274 = stablehlo.transpose %arg2, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32> | |
| %3275 = stablehlo.dot_general %3273, %3274, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32> | |
| %3276 = stablehlo.reshape %3275 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32> | |
| %3277 = stablehlo.add %3251, %3276 : tensor<8x100x4096xf32> | |
| %3278 = stablehlo.power %3277, %cst_3 : tensor<8x100x4096xf32> | |
| %3279 = stablehlo.reduce(%3278 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32> | |
| %3280 = stablehlo.multiply %3279, %cst_2 : tensor<8x100xf32> | |
| %3281 = stablehlo.reshape %3280 : (tensor<8x100xf32>) -> tensor<8x100x1xf32> | |
| %3282 = stablehlo.add %3281, %cst_1 : tensor<8x100x1xf32> | |
| %3283 = stablehlo.rsqrt %3282 : tensor<8x100x1xf32> | |
| %3284 = stablehlo.reshape %3283 : (tensor<8x100x1xf32>) -> tensor<8x100xf32> | |
| %3285 = stablehlo.broadcast_in_dim %3284, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32> | |
| %3286 = stablehlo.multiply %3277, %3285 : tensor<8x100x4096xf32> | |
| %3287 = stablehlo.broadcast_in_dim %arg1, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32> | |
| %3288 = stablehlo.multiply %3286, %3287 : tensor<8x100x4096xf32> | |
| %3289 = stablehlo.reshape %3288 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32> | |
| %3290 = stablehlo.transpose %arg0, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,32000]{0,1}"} : (tensor<32000x4096xf32>) -> tensor<4096x32000xf32> | |
| %3291 = stablehlo.dot_general %3289, %3290, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x32000xf32>) -> tensor<800x32000xf32> | |
| %3292 = stablehlo.reshape %3291 : (tensor<800x32000xf32>) -> tensor<8x100x32000xf32> | |
| return %3292 : tensor<8x100x32000xf32> | |
| } | |
| } |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.