Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created March 3, 2023 15:38
Show Gist options
  • Save pashu123/8cba884c52bee252dfbb31a8d5672504 to your computer and use it in GitHub Desktop.
Save pashu123/8cba884c52bee252dfbb31a8d5672504 to your computer and use it in GitHub Desktop.
func.func @forward(%arg0: tensor<512xf32>, %arg1: tensor<512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<512xf32>, %arg4: tensor<512xf32>, %arg5: tensor<512xf32>, %arg6: tensor<512xf32>, %arg7: tensor<512xf32>, %arg8: tensor<512xf32>, %arg9: tensor<512xf32>, %arg10: tensor<512xf32>, %arg11: tensor<512xf32>, %arg12: tensor<512xf32>, %arg13: tensor<512xf32>, %arg14: tensor<512xf32>, %arg15: tensor<512xf32>, %arg16: tensor<512xf32>, %arg17: tensor<1x512xf32>, %arg18: tensor<512x512xf32>, %arg19: tensor<512x512xf32>, %arg20: tensor<512x512xf32>, %arg21: tensor<512x512xf32>, %arg22: tensor<1536x512xf32>, %arg23: tensor<1536x512xf32>, %arg24: tensor<512x1536xf32>, %arg25: tensor<512x512xf32>, %arg26: tensor<512x512xf32>, %arg27: tensor<512x512xf32>, %arg28: tensor<512x512xf32>, %arg29: tensor<1536x512xf32>, %arg30: tensor<1536x512xf32>, %arg31: tensor<512x1536xf32>, %arg32: tensor<512x512xf32>, %arg33: tensor<512x512xf32>, %arg34: tensor<512x512xf32>, %arg35: tensor<512x512xf32>, %arg36: tensor<1536x512xf32>, %arg37: tensor<1536x512xf32>, %arg38: tensor<512x1536xf32>, %arg39: tensor<512x512xf32>, %arg40: tensor<512x512xf32>, %arg41: tensor<512x512xf32>, %arg42: tensor<512x512xf32>, %arg43: tensor<1536x512xf32>, %arg44: tensor<1536x512xf32>, %arg45: tensor<512x1536xf32>, %arg46: tensor<512x512xf32>, %arg47: tensor<512x512xf32>, %arg48: tensor<512x512xf32>, %arg49: tensor<512x512xf32>, %arg50: tensor<1536x512xf32>, %arg51: tensor<1536x512xf32>, %arg52: tensor<512x1536xf32>, %arg53: tensor<512x512xf32>, %arg54: tensor<512x512xf32>, %arg55: tensor<512x512xf32>, %arg56: tensor<512x512xf32>, %arg57: tensor<1536x512xf32>, %arg58: tensor<1536x512xf32>, %arg59: tensor<512x1536xf32>, %arg60: tensor<512x512xf32>, %arg61: tensor<512x512xf32>, %arg62: tensor<512x512xf32>, %arg63: tensor<512x512xf32>, %arg64: tensor<1536x512xf32>, %arg65: tensor<1536x512xf32>, %arg66: tensor<512x1536xf32>, %arg67: tensor<512x512xf32>, %arg68: tensor<512x512xf32>, %arg69: tensor<512x512xf32>, %arg70: 
tensor<512x512xf32>, %arg71: tensor<1536x512xf32>, %arg72: tensor<1536x512xf32>, %arg73: tensor<512x1536xf32>, %arg74: tensor<1x512xf32>, %arg75: tensor<1x1xi64>, %arg76: tensor<2048x32x2xf32>) -> (tensor<1x1xf32>, tensor<2048x32x2xf32>) {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant -3.40282347E+38 : f32
%cst_1 = arith.constant 1.000000e+00 : f32
%c0_i64 = arith.constant 0 : i64
%cst_2 = arith.constant 1.000000e-05 : f64
%cst_3 = arith.constant 2.000000e+00 : f32
%cst_4 = arith.constant 5.120000e+02 : f32
%cst_5 = arith.constant 8.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = torch_c.from_builtin_tensor %arg76 : tensor<2048x32x2xf32> -> !torch.vtensor<[2048,32,2],f32>
%1 = torch_c.from_builtin_tensor %arg75 : tensor<1x1xi64> -> !torch.vtensor<[1,1],si64>
%2 = torch_c.from_builtin_tensor %arg74 : tensor<1x512xf32> -> !torch.vtensor<[1,512],f32>
%3 = torch_c.from_builtin_tensor %arg73 : tensor<512x1536xf32> -> !torch.vtensor<[512,1536],f32>
%4 = torch_c.from_builtin_tensor %arg72 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%5 = torch_c.from_builtin_tensor %arg71 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%6 = torch_c.from_builtin_tensor %arg70 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%7 = torch_c.from_builtin_tensor %arg69 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%8 = torch_c.from_builtin_tensor %arg68 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%9 = torch_c.from_builtin_tensor %arg67 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%10 = torch_c.from_builtin_tensor %arg66 : tensor<512x1536xf32> -> !torch.vtensor<[512,1536],f32>
%11 = torch_c.from_builtin_tensor %arg65 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%12 = torch_c.from_builtin_tensor %arg64 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%13 = torch_c.from_builtin_tensor %arg63 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%14 = torch_c.from_builtin_tensor %arg62 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%15 = torch_c.from_builtin_tensor %arg61 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%16 = torch_c.from_builtin_tensor %arg60 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%17 = torch_c.from_builtin_tensor %arg59 : tensor<512x1536xf32> -> !torch.vtensor<[512,1536],f32>
%18 = torch_c.from_builtin_tensor %arg58 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%19 = torch_c.from_builtin_tensor %arg57 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%20 = torch_c.from_builtin_tensor %arg56 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%21 = torch_c.from_builtin_tensor %arg55 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%22 = torch_c.from_builtin_tensor %arg54 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%23 = torch_c.from_builtin_tensor %arg53 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%24 = torch_c.from_builtin_tensor %arg52 : tensor<512x1536xf32> -> !torch.vtensor<[512,1536],f32>
%25 = torch_c.from_builtin_tensor %arg51 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%26 = torch_c.from_builtin_tensor %arg50 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%27 = torch_c.from_builtin_tensor %arg49 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%28 = torch_c.from_builtin_tensor %arg48 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%29 = torch_c.from_builtin_tensor %arg47 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%30 = torch_c.from_builtin_tensor %arg46 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%31 = torch_c.from_builtin_tensor %arg45 : tensor<512x1536xf32> -> !torch.vtensor<[512,1536],f32>
%32 = torch_c.from_builtin_tensor %arg44 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%33 = torch_c.from_builtin_tensor %arg43 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%34 = torch_c.from_builtin_tensor %arg42 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%35 = torch_c.from_builtin_tensor %arg41 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%36 = torch_c.from_builtin_tensor %arg40 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%37 = torch_c.from_builtin_tensor %arg39 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%38 = torch_c.from_builtin_tensor %arg38 : tensor<512x1536xf32> -> !torch.vtensor<[512,1536],f32>
%39 = torch_c.from_builtin_tensor %arg37 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%40 = torch_c.from_builtin_tensor %arg36 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%41 = torch_c.from_builtin_tensor %arg35 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%42 = torch_c.from_builtin_tensor %arg34 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%43 = torch_c.from_builtin_tensor %arg33 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%44 = torch_c.from_builtin_tensor %arg32 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%45 = torch_c.from_builtin_tensor %arg31 : tensor<512x1536xf32> -> !torch.vtensor<[512,1536],f32>
%46 = torch_c.from_builtin_tensor %arg30 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%47 = torch_c.from_builtin_tensor %arg29 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%48 = torch_c.from_builtin_tensor %arg28 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%49 = torch_c.from_builtin_tensor %arg27 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%50 = torch_c.from_builtin_tensor %arg26 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%51 = torch_c.from_builtin_tensor %arg25 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%52 = torch_c.from_builtin_tensor %arg24 : tensor<512x1536xf32> -> !torch.vtensor<[512,1536],f32>
%53 = torch_c.from_builtin_tensor %arg23 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%54 = torch_c.from_builtin_tensor %arg22 : tensor<1536x512xf32> -> !torch.vtensor<[1536,512],f32>
%55 = torch_c.from_builtin_tensor %arg21 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%56 = torch_c.from_builtin_tensor %arg20 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%57 = torch_c.from_builtin_tensor %arg19 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%58 = torch_c.from_builtin_tensor %arg18 : tensor<512x512xf32> -> !torch.vtensor<[512,512],f32>
%59 = torch_c.from_builtin_tensor %arg17 : tensor<1x512xf32> -> !torch.vtensor<[1,512],f32>
%60 = torch_c.from_builtin_tensor %arg16 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%61 = torch_c.from_builtin_tensor %arg15 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%62 = torch_c.from_builtin_tensor %arg14 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%63 = torch_c.from_builtin_tensor %arg13 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%64 = torch_c.from_builtin_tensor %arg12 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%65 = torch_c.from_builtin_tensor %arg11 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%66 = torch_c.from_builtin_tensor %arg10 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%67 = torch_c.from_builtin_tensor %arg9 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%68 = torch_c.from_builtin_tensor %arg8 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%69 = torch_c.from_builtin_tensor %arg7 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%70 = torch_c.from_builtin_tensor %arg6 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%71 = torch_c.from_builtin_tensor %arg5 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%72 = torch_c.from_builtin_tensor %arg4 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%73 = torch_c.from_builtin_tensor %arg3 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%74 = torch_c.from_builtin_tensor %arg2 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%75 = torch_c.from_builtin_tensor %arg1 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%76 = torch_c.from_builtin_tensor %arg0 : tensor<512xf32> -> !torch.vtensor<[512],f32>
%77 = torch_c.to_builtin_tensor %59 : !torch.vtensor<[1,512],f32> -> tensor<1x512xf32>
%78 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[1,1],si64> -> tensor<1x1xi64>
%79 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[2048,32,2],f32> -> tensor<2048x32x2xf32>
%80 = torch_c.to_builtin_tensor %76 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%81 = torch_c.to_builtin_tensor %58 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%82 = torch_c.to_builtin_tensor %57 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%83 = torch_c.to_builtin_tensor %56 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%84 = torch_c.to_builtin_tensor %55 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%85 = torch_c.to_builtin_tensor %75 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%86 = torch_c.to_builtin_tensor %54 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%87 = torch_c.to_builtin_tensor %53 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%88 = torch_c.to_builtin_tensor %52 : !torch.vtensor<[512,1536],f32> -> tensor<512x1536xf32>
%89 = torch_c.to_builtin_tensor %74 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%90 = torch_c.to_builtin_tensor %51 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%91 = torch_c.to_builtin_tensor %50 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%92 = torch_c.to_builtin_tensor %49 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%93 = torch_c.to_builtin_tensor %48 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%94 = torch_c.to_builtin_tensor %73 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%95 = torch_c.to_builtin_tensor %47 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%96 = torch_c.to_builtin_tensor %46 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%97 = torch_c.to_builtin_tensor %45 : !torch.vtensor<[512,1536],f32> -> tensor<512x1536xf32>
%98 = torch_c.to_builtin_tensor %72 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%99 = torch_c.to_builtin_tensor %44 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%100 = torch_c.to_builtin_tensor %43 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%101 = torch_c.to_builtin_tensor %42 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%102 = torch_c.to_builtin_tensor %41 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%103 = torch_c.to_builtin_tensor %71 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%104 = torch_c.to_builtin_tensor %40 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%105 = torch_c.to_builtin_tensor %39 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%106 = torch_c.to_builtin_tensor %38 : !torch.vtensor<[512,1536],f32> -> tensor<512x1536xf32>
%107 = torch_c.to_builtin_tensor %70 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%108 = torch_c.to_builtin_tensor %37 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%109 = torch_c.to_builtin_tensor %36 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%110 = torch_c.to_builtin_tensor %35 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%111 = torch_c.to_builtin_tensor %34 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%112 = torch_c.to_builtin_tensor %69 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%113 = torch_c.to_builtin_tensor %33 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%114 = torch_c.to_builtin_tensor %32 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%115 = torch_c.to_builtin_tensor %31 : !torch.vtensor<[512,1536],f32> -> tensor<512x1536xf32>
%116 = torch_c.to_builtin_tensor %68 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%117 = torch_c.to_builtin_tensor %30 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%118 = torch_c.to_builtin_tensor %29 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%119 = torch_c.to_builtin_tensor %28 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%120 = torch_c.to_builtin_tensor %27 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%121 = torch_c.to_builtin_tensor %67 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%122 = torch_c.to_builtin_tensor %26 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%123 = torch_c.to_builtin_tensor %25 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%124 = torch_c.to_builtin_tensor %24 : !torch.vtensor<[512,1536],f32> -> tensor<512x1536xf32>
%125 = torch_c.to_builtin_tensor %66 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%126 = torch_c.to_builtin_tensor %23 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%127 = torch_c.to_builtin_tensor %22 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%128 = torch_c.to_builtin_tensor %21 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%129 = torch_c.to_builtin_tensor %20 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%130 = torch_c.to_builtin_tensor %65 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%131 = torch_c.to_builtin_tensor %19 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%132 = torch_c.to_builtin_tensor %18 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%133 = torch_c.to_builtin_tensor %17 : !torch.vtensor<[512,1536],f32> -> tensor<512x1536xf32>
%134 = torch_c.to_builtin_tensor %64 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%135 = torch_c.to_builtin_tensor %16 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%136 = torch_c.to_builtin_tensor %15 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%137 = torch_c.to_builtin_tensor %14 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%138 = torch_c.to_builtin_tensor %13 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%139 = torch_c.to_builtin_tensor %63 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%140 = torch_c.to_builtin_tensor %12 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%141 = torch_c.to_builtin_tensor %11 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%142 = torch_c.to_builtin_tensor %10 : !torch.vtensor<[512,1536],f32> -> tensor<512x1536xf32>
%143 = torch_c.to_builtin_tensor %62 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%144 = torch_c.to_builtin_tensor %9 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%145 = torch_c.to_builtin_tensor %8 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%146 = torch_c.to_builtin_tensor %7 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%147 = torch_c.to_builtin_tensor %6 : !torch.vtensor<[512,512],f32> -> tensor<512x512xf32>
%148 = torch_c.to_builtin_tensor %61 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%149 = torch_c.to_builtin_tensor %5 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%150 = torch_c.to_builtin_tensor %4 : !torch.vtensor<[1536,512],f32> -> tensor<1536x512xf32>
%151 = torch_c.to_builtin_tensor %3 : !torch.vtensor<[512,1536],f32> -> tensor<512x1536xf32>
%152 = torch_c.to_builtin_tensor %60 : !torch.vtensor<[512],f32> -> tensor<512xf32>
%153 = torch_c.to_builtin_tensor %2 : !torch.vtensor<[1,512],f32> -> tensor<1x512xf32>
%154 = tensor.empty() : tensor<1x1x512xf32>
%155 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%78 : tensor<1x1xi64>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: i64, %out: f32):
%788 = arith.index_cast %in : i64 to index
%789 = linalg.index 2 : index
%790 = arith.cmpi slt, %788, %c1 : index
cf.assert %790, "index must be smaller than dim size"
%791 = arith.cmpi sge, %in, %c0_i64 : i64
cf.assert %791, "index must be larger or equal to 0"
%extracted = tensor.extract %77[%788, %789] : tensor<1x512xf32>
linalg.yield %extracted : f32
} -> tensor<1x1x512xf32>
%extracted_slice = tensor.extract_slice %79[0, 0, 0] [1, 32, 2] [1, 1, 1] : tensor<2048x32x2xf32> to tensor<1x32x2xf32>
%156 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%155 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%157 = tensor.empty() : tensor<1x1x1xf32>
%158 = linalg.fill ins(%cst : f32) outs(%157 : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
%159 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%156 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%160 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%159 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%161 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%160 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
%162 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%161 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%163 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%155, %162 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%164 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%163, %80 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%165 = tensor.empty() : tensor<512x512xf32>
%166 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%81 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed = tensor.collapse_shape %164 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%167 = tensor.empty() : tensor<1x512xf32>
%168 = linalg.fill ins(%cst : f32) outs(%167 : tensor<1x512xf32>) -> tensor<1x512xf32>
%169 = linalg.matmul ins(%collapsed, %166 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%170 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%82 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%171 = linalg.matmul ins(%collapsed, %170 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%172 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%83 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%173 = linalg.matmul ins(%collapsed, %172 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded = tensor.expand_shape %173 [[0], [1, 2, 3]] : tensor<1x512xf32> into tensor<1x1x8x64xf32>
%expanded_6 = tensor.expand_shape %169 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
%174 = tensor.empty() : tensor<1x1x8x32xcomplex<f64>>
%175 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
%extracted = tensor.extract %expanded_6[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_6[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%expanded_7 = tensor.expand_shape %171 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
%176 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
%extracted = tensor.extract %expanded_7[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_7[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%177 = tensor.empty() : tensor<1x32xcomplex<f64>>
%178 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%177 : tensor<1x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%extracted = tensor.extract %extracted_slice[%788, %789, %c0] : tensor<1x32x2xf32>
%extracted_269 = tensor.extract %extracted_slice[%788, %789, %c1] : tensor<1x32x2xf32>
%790 = arith.extf %extracted : f32 to f64
%791 = arith.extf %extracted_269 : f32 to f64
%792 = complex.create %790, %791 : complex<f64>
linalg.yield %792 : complex<f64>
} -> tensor<1x32xcomplex<f64>>
%expanded_8 = tensor.expand_shape %178 [[0], [1, 2, 3]] : tensor<1x32xcomplex<f64>> into tensor<1x1x1x32xcomplex<f64>>
%179 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%175, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%180 = torch_c.from_builtin_tensor %179 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%181 = torch.aten.view_as_real %180 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%182 = torch_c.to_builtin_tensor %181 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_9 = tensor.extract_slice %182[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%183 = torch.aten.view_as_real %180 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%184 = torch_c.to_builtin_tensor %183 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_10 = tensor.extract_slice %184[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%collapsed_11 = tensor.collapse_shape %extracted_slice_9 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%185 = torch_c.from_builtin_tensor %collapsed_11 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_12 = tensor.collapse_shape %extracted_slice_10 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%186 = torch_c.from_builtin_tensor %collapsed_12 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%187 = torch_c.to_builtin_tensor %185 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%188 = torch_c.to_builtin_tensor %186 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%189 = tensor.empty() : tensor<256x2xf32>
%inserted_slice = tensor.insert_slice %187 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_13 = tensor.insert_slice %188 into %inserted_slice[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%expanded_14 = tensor.expand_shape %inserted_slice_13 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_15 = tensor.collapse_shape %expanded_14 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%190 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%176, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%191 = torch_c.from_builtin_tensor %190 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%192 = torch.aten.view_as_real %191 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%193 = torch_c.to_builtin_tensor %192 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_16 = tensor.extract_slice %193[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%194 = torch.aten.view_as_real %191 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%195 = torch_c.to_builtin_tensor %194 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_17 = tensor.extract_slice %195[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%collapsed_18 = tensor.collapse_shape %extracted_slice_16 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%196 = torch_c.from_builtin_tensor %collapsed_18 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_19 = tensor.collapse_shape %extracted_slice_17 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%197 = torch_c.from_builtin_tensor %collapsed_19 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%198 = torch_c.to_builtin_tensor %196 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%199 = torch_c.to_builtin_tensor %197 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%inserted_slice_20 = tensor.insert_slice %198 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_21 = tensor.insert_slice %199 into %inserted_slice_20[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%expanded_22 = tensor.expand_shape %inserted_slice_21 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_23 = tensor.collapse_shape %expanded_22 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%200 = tensor.empty() : tensor<1x1x8x64xf32>
%201 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_23 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%202 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%203 = tensor.empty() : tensor<1x8x1x64xf32>
%204 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_15 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%205 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%204 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_24 = tensor.collapse_shape %205 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%206 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%201 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%207 = tensor.empty() : tensor<1x8x64x1xf32>
%208 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%206 : tensor<1x8x1x64xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%209 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%208 : tensor<1x8x64x1xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%collapsed_25 = tensor.collapse_shape %209 [[0, 1], [2], [3]] : tensor<1x8x64x1xf32> into tensor<8x64x1xf32>
%210 = tensor.empty() : tensor<8x1x1xf32>
%211 = linalg.fill ins(%cst : f32) outs(%210 : tensor<8x1x1xf32>) -> tensor<8x1x1xf32>
%212 = linalg.batch_matmul ins(%collapsed_24, %collapsed_25 : tensor<8x1x64xf32>, tensor<8x64x1xf32>) outs(%211 : tensor<8x1x1xf32>) -> tensor<8x1x1xf32>
%expanded_26 = tensor.expand_shape %212 [[0, 1], [2], [3]] : tensor<8x1x1xf32> into tensor<1x8x1x1xf32>
%213 = tensor.empty() : tensor<1x8x1x1xf32>
%214 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_26 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_5 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Softmax over d3 of %214 (the batch_matmul result divided by %cst_5):
// max-reduce, subtract the max, exp, sum-reduce, divide.
// Every tensor in this group is 1x8x1x1, so the reduced dimension d3 has extent 1.
//
// Accumulator seeds for the max/argmax reduction. %c0_i64 and %cst_0 are defined
// outside this excerpt (presumably 0 and the lowest f32 / -inf -- TODO confirm).
%215 = tensor.empty() : tensor<1x8x1x1xi64>
%216 = linalg.fill ins(%c0_i64 : i64) outs(%215 : tensor<1x8x1x1xi64>) -> tensor<1x8x1x1xi64>
%217 = linalg.fill ins(%cst_0 : f32) outs(%213 : tensor<1x8x1x1xf32>) -> tensor<1x8x1x1xf32>
// Reduction over d3 producing two results: #0 is the running maximum, #1 is the
// d3 index recorded whenever the input is strictly greater (ogt) than the running
// maximum. Only #0 is consumed in the lines visible here.
%218:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%214 : tensor<1x8x1x1xf32>) outs(%217, %216 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>) {
^bb0(%in: f32, %out: f32, %out_269: i64):
%788 = linalg.index 3 : index
%789 = arith.index_cast %788 : index to i64
%790 = arith.maxf %in, %out : f32
%791 = arith.cmpf ogt, %in, %out : f32
%792 = arith.select %791, %789, %out_269 : i64
linalg.yield %790, %792 : f32, i64
} -> (tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>)
// Subtract the maximum (%218#0) from each element of %214.
%219 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%214, %218#0 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.subf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Element-wise exponential of the shifted values.
%220 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%219 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.exp %in : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Sum accumulator filled with %cst (defined outside this excerpt; presumably 0.0 -- TODO confirm).
%221 = linalg.fill ins(%cst : f32) outs(%213 : tensor<1x8x1x1xf32>) -> tensor<1x8x1x1xf32>
// Add-reduction of the exponentials over d3.
%222 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%220 : tensor<1x8x1x1xf32>) outs(%221 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Divide each exponential by the sum.
%223 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%220, %222 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.divf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Copy of %223: the body yields its input unchanged and the source and result types are identical.
%224 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%223 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x1xf32>
%collapsed_27 = tensor.collapse_shape %224 [[0, 1], [2], [3]] : tensor<1x8x1x1xf32> into tensor<8x1x1xf32>
%225 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%202 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%226 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%225 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_28 = tensor.collapse_shape %226 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%227 = tensor.empty() : tensor<8x1x64xf32>
%228 = linalg.fill ins(%cst : f32) outs(%227 : tensor<8x1x64xf32>) -> tensor<8x1x64xf32>
%229 = linalg.batch_matmul ins(%collapsed_27, %collapsed_28 : tensor<8x1x1xf32>, tensor<8x1x64xf32>) outs(%228 : tensor<8x1x64xf32>) -> tensor<8x1x64xf32>
%expanded_29 = tensor.expand_shape %229 [[0, 1], [2], [3]] : tensor<8x1x64xf32> into tensor<1x8x1x64xf32>
%230 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_29 : tensor<1x8x1x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%231 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%84 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_30 = tensor.collapse_shape %230 [[0], [1, 2, 3]] : tensor<1x1x8x64xf32> into tensor<1x512xf32>
%232 = linalg.matmul ins(%collapsed_30, %231 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_31 = tensor.expand_shape %232 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
%233 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%155, %expanded_31 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Normalisation of %233 (tensor<1x1x512xf32>) followed by a per-channel scale:
//   %240 = %233 * rsqrt(sum_d2(powf(%233, %cst_3)) / %cst_4 + trunc(%cst_2)) * %85
// %cst_2, %cst_3, %cst_4, %157 and %158 are defined outside this excerpt. The shape
// of the computation looks like an RMS norm (exponent 2.0, divisor 512.0, small
// epsilon, zero-filled accumulator) -- NOTE(review): confirm the constant values.
//
// Raise every element of %233 to the power %cst_3.
%234 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%233 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Add-reduction over d2 (512 elements) into the 1x1x1 accumulator %158.
%235 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%234 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Divide the sum by %cst_4.
%236 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%235 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Add %cst_2, an f64 constant truncated to f32 inside the body.
%237 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%236 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
// Reciprocal square root of the single resulting value.
%238 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%237 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Scale the original input %233 by that scalar; the (0, 0, 0) map broadcasts %238 along d2.
%239 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%233, %238 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Multiply element-wise along d2 by the 512-element vector %85 (defined outside this excerpt).
%240 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%239, %85 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%241 = tensor.empty() : tensor<512x1536xf32>
%242 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%86 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
%collapsed_32 = tensor.collapse_shape %240 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%243 = tensor.empty() : tensor<1x1536xf32>
%244 = linalg.fill ins(%cst : f32) outs(%243 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%245 = linalg.matmul ins(%collapsed_32, %242 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_33 = tensor.expand_shape %245 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
// Gate activation on the first projection result %expanded_33 (tensor<1x1x1536xf32>).
// %246 is the uninitialised destination shared by the generics below.
%246 = tensor.empty() : tensor<1x1x1536xf32>
// Per element: %cst_1 / (exp(-x) + %cst_1). %cst_1 is defined outside this excerpt;
// if it is 1.0 this is the logistic sigmoid -- TODO confirm.
%247 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_33 : tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.negf %in : f32
%789 = math.exp %788 : f32
%790 = arith.addf %789, %cst_1 : f32
%791 = arith.divf %cst_1, %790 : f32
linalg.yield %791 : f32
} -> tensor<1x1x1536xf32>
// Multiply the value above by the original input x (x * sigmoid(x), i.e. SiLU, under the same assumption).
%248 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%247, %expanded_33 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
%249 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%87 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
%250 = linalg.matmul ins(%collapsed_32, %249 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_34 = tensor.expand_shape %250 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
%251 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%248, %expanded_34 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
%252 = tensor.empty() : tensor<1536x512xf32>
%253 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%88 : tensor<512x1536xf32>) outs(%252 : tensor<1536x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1536x512xf32>
%collapsed_35 = tensor.collapse_shape %251 [[0], [1, 2]] : tensor<1x1x1536xf32> into tensor<1x1536xf32>
%254 = linalg.matmul ins(%collapsed_35, %253 : tensor<1x1536xf32>, tensor<1536x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_36 = tensor.expand_shape %254 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
%255 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%233, %expanded_36 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%256 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%255 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%257 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%256 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%258 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%257 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%259 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%258 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
%260 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%259 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%261 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%255, %260 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%262 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%261, %89 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%263 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%90 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_37 = tensor.collapse_shape %262 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%264 = linalg.matmul ins(%collapsed_37, %263 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%265 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%91 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%266 = linalg.matmul ins(%collapsed_37, %265 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%267 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%92 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%268 = linalg.matmul ins(%collapsed_37, %267 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_38 = tensor.expand_shape %268 [[0], [1, 2, 3]] : tensor<1x512xf32> into tensor<1x1x8x64xf32>
%expanded_39 = tensor.expand_shape %264 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
// Build a 1x1x8x32 complex<f64> tensor from %expanded_39 (tensor<1x1x8x32x2xf32>).
// The generic has no ins; the body reads %expanded_39 directly with tensor.extract,
// using the four iteration indices plus %c0 / %c1 for the last axis.
// %c0, %c1 and the destination %174 are defined outside this excerpt
// (presumably index constants 0 and 1 -- TODO confirm).
%269 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
%extracted = tensor.extract %expanded_39[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_39[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
// Widen both f32 components to f64, then pair them: first operand is the real part, second the imaginary part.
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%expanded_40 = tensor.expand_shape %266 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
%270 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
%extracted = tensor.extract %expanded_40[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_40[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%271 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%269, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
// Convert the complex product %271 back to a real tensor with a trailing axis of
// size 2, then slice that axis at offset 0 and offset 1 (one component each).
// The torch-dialect op torch.aten.view_as_real is still present here, bracketed by
// torch_c conversions to and from builtin tensors.
// NOTE(review): the operand element type is complex<f64> but the result is typed
// f32; PyTorch's view_as_real is documented as a view, so a complex128 input would
// be expected to give float64 -- confirm whether this f64 -> f32 change is intended.
// NOTE(review): %275 repeats the same view_as_real on %272 as %273.
%272 = torch_c.from_builtin_tensor %271 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%273 = torch.aten.view_as_real %272 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%274 = torch_c.to_builtin_tensor %273 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
// Last-axis offset 0, size 1.
%extracted_slice_41 = tensor.extract_slice %274[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%275 = torch.aten.view_as_real %272 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%276 = torch_c.to_builtin_tensor %275 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
// Last-axis offset 1, size 1.
%extracted_slice_42 = tensor.extract_slice %276[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%collapsed_43 = tensor.collapse_shape %extracted_slice_41 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%277 = torch_c.from_builtin_tensor %collapsed_43 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_44 = tensor.collapse_shape %extracted_slice_42 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%278 = torch_c.from_builtin_tensor %collapsed_44 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%279 = torch_c.to_builtin_tensor %277 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%280 = torch_c.to_builtin_tensor %278 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%inserted_slice_45 = tensor.insert_slice %279 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_46 = tensor.insert_slice %280 into %inserted_slice_45[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%expanded_47 = tensor.expand_shape %inserted_slice_46 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_48 = tensor.collapse_shape %expanded_47 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%281 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%270, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%282 = torch_c.from_builtin_tensor %281 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%283 = torch.aten.view_as_real %282 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%284 = torch_c.to_builtin_tensor %283 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_49 = tensor.extract_slice %284[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%285 = torch.aten.view_as_real %282 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%286 = torch_c.to_builtin_tensor %285 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_50 = tensor.extract_slice %286[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%collapsed_51 = tensor.collapse_shape %extracted_slice_49 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%287 = torch_c.from_builtin_tensor %collapsed_51 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_52 = tensor.collapse_shape %extracted_slice_50 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%288 = torch_c.from_builtin_tensor %collapsed_52 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%289 = torch_c.to_builtin_tensor %287 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%290 = torch_c.to_builtin_tensor %288 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%inserted_slice_53 = tensor.insert_slice %289 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_54 = tensor.insert_slice %290 into %inserted_slice_53[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%expanded_55 = tensor.expand_shape %inserted_slice_54 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_56 = tensor.collapse_shape %expanded_55 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%291 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_56 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%292 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_38 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%293 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_48 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%294 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%293 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_57 = tensor.collapse_shape %294 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%295 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%291 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%296 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%295 : tensor<1x8x1x64xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%297 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%296 : tensor<1x8x64x1xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%collapsed_58 = tensor.collapse_shape %297 [[0, 1], [2], [3]] : tensor<1x8x64x1xf32> into tensor<8x64x1xf32>
%298 = linalg.batch_matmul ins(%collapsed_57, %collapsed_58 : tensor<8x1x64xf32>, tensor<8x64x1xf32>) outs(%211 : tensor<8x1x1xf32>) -> tensor<8x1x1xf32>
%expanded_59 = tensor.expand_shape %298 [[0, 1], [2], [3]] : tensor<8x1x1xf32> into tensor<1x8x1x1xf32>
%299 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_59 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_5 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%300:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%299 : tensor<1x8x1x1xf32>) outs(%217, %216 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>) {
^bb0(%in: f32, %out: f32, %out_269: i64):
%788 = linalg.index 3 : index
%789 = arith.index_cast %788 : index to i64
%790 = arith.maxf %in, %out : f32
%791 = arith.cmpf ogt, %in, %out : f32
%792 = arith.select %791, %789, %out_269 : i64
linalg.yield %790, %792 : f32, i64
} -> (tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>)
%301 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%299, %300#0 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.subf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%302 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%301 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.exp %in : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%303 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%302 : tensor<1x8x1x1xf32>) outs(%221 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%304 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%302, %303 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.divf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%305 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%304 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x1xf32>
%collapsed_60 = tensor.collapse_shape %305 [[0, 1], [2], [3]] : tensor<1x8x1x1xf32> into tensor<8x1x1xf32>
%306 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%292 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%307 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%306 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_61 = tensor.collapse_shape %307 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%308 = linalg.batch_matmul ins(%collapsed_60, %collapsed_61 : tensor<8x1x1xf32>, tensor<8x1x64xf32>) outs(%228 : tensor<8x1x64xf32>) -> tensor<8x1x64xf32>
%expanded_62 = tensor.expand_shape %308 [[0, 1], [2], [3]] : tensor<8x1x64xf32> into tensor<1x8x1x64xf32>
%309 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_62 : tensor<1x8x1x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%310 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%93 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_63 = tensor.collapse_shape %309 [[0], [1, 2, 3]] : tensor<1x1x8x64xf32> into tensor<1x512xf32>
%311 = linalg.matmul ins(%collapsed_63, %310 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_64 = tensor.expand_shape %311 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
%312 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%255, %expanded_64 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// ---- Normalisation of %312 over its last (512-wide) dimension ----
// Sequence visible below: y = x * rsqrt(sum(powf(x, cst_3)) / cst_4 + truncf(cst_2)) * %94.
// NOTE(review): %cst_2 / %cst_3 / %cst_4 are defined outside this chunk; this is an
// RMS-norm only if cst_3 == 2.0 and cst_4 == 512.0 — confirm against their definitions.

// Step 1: elementwise powf(x, %cst_3) on %312.
%313 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%312 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Step 2: additive reduction over d2 (the only "reduction" iterator), accumulating
// into %158. The initial contents of %158 are set outside this chunk (presumably
// zero-filled — verify).
%314 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%313 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Step 3: divide the reduced scalar by %cst_4.
%315 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%314 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Step 4: add %cst_2, which is an f64 constant truncated to f32 inside the body.
%316 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%315 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
// Step 5: reciprocal square root of the result of step 4.
%317 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%316 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Step 6: scale the original input %312 by the scalar %317 (broadcast via the
// constant (0, 0, 0) indexing map).
%318 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%312, %317 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Step 7: elementwise multiply by the 512-element vector %94, indexed by d2 only.
%319 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%318, %94 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%320 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%95 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
%collapsed_65 = tensor.collapse_shape %319 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%321 = linalg.matmul ins(%collapsed_65, %320 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_66 = tensor.expand_shape %321 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
// Elementwise f(x) = cst_1 / (exp(-x) + cst_1) applied to %expanded_66.
// NOTE(review): this is the logistic sigmoid only if %cst_1 == 1.0; the constant is
// defined outside this chunk — confirm.
%322 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_66 : tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.negf %in : f32
%789 = math.exp %788 : f32
%790 = arith.addf %789, %cst_1 : f32
%791 = arith.divf %cst_1, %790 : f32
linalg.yield %791 : f32
} -> tensor<1x1x1536xf32>
// Elementwise product f(x) * x: multiplies %322 by the same input %expanded_66
// (i.e. SiLU/swish if the sigmoid assumption above holds).
%323 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%322, %expanded_66 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
%324 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%96 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
%325 = linalg.matmul ins(%collapsed_65, %324 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_67 = tensor.expand_shape %325 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
%326 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%323, %expanded_67 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
%327 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%97 : tensor<512x1536xf32>) outs(%252 : tensor<1536x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1536x512xf32>
%collapsed_68 = tensor.collapse_shape %326 [[0], [1, 2]] : tensor<1x1x1536xf32> into tensor<1x1536xf32>
%328 = linalg.matmul ins(%collapsed_68, %327 : tensor<1x1536xf32>, tensor<1536x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_69 = tensor.expand_shape %328 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
%329 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%312, %expanded_69 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%330 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%329 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%331 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%330 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%332 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%331 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%333 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%332 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
%334 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%333 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%335 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%329, %334 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%336 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%335, %98 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%337 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%99 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_70 = tensor.collapse_shape %336 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%338 = linalg.matmul ins(%collapsed_70, %337 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%339 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%100 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%340 = linalg.matmul ins(%collapsed_70, %339 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%341 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%101 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%342 = linalg.matmul ins(%collapsed_70, %341 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_71 = tensor.expand_shape %342 [[0], [1, 2, 3]] : tensor<1x512xf32> into tensor<1x1x8x64xf32>
%expanded_72 = tensor.expand_shape %338 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
// Packs %expanded_72 (1x1x8x32x2 f32) into a 1x1x8x32 tensor of complex<f64>.
// The op has no `ins`; each output element is built by reading the source tensor
// directly with tensor.extract at the current iteration indices.
%343 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
// Current position in the 4-D iteration space.
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
// Last-dimension index %c0 supplies the real part and %c1 the imaginary part
// (%c0 / %c1 are defined outside this chunk; presumably index constants 0 and 1).
%extracted = tensor.extract %expanded_72[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_72[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
// Widen both components from f32 to f64 before forming the complex value.
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%expanded_73 = tensor.expand_shape %340 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
%344 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
%extracted = tensor.extract %expanded_73[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_73[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%345 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%343, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
// Converts the complex product %345 back to an interleaved real tensor.
// The torch / torch_c ops below are still in the torch dialect, surrounded by
// builtin-tensor code: the value is cast to !torch.vtensor, passed through
// torch.aten.view_as_real, and cast back.
// NOTE(review): view_as_real is given a complex<f64> operand but its result is typed
// f32; torch.view_as_real normally keeps the component precision — confirm whether
// this element-type mismatch is intended.
%346 = torch_c.from_builtin_tensor %345 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%347 = torch.aten.view_as_real %346 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%348 = torch_c.to_builtin_tensor %347 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
// Slice at last-dimension offset 0 (first component of each pair).
%extracted_slice_74 = tensor.extract_slice %348[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
// The same view_as_real of %346 is computed a second time for the other slice.
%349 = torch.aten.view_as_real %346 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%350 = torch_c.to_builtin_tensor %349 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
// Slice at last-dimension offset 1 (second component of each pair).
%extracted_slice_75 = tensor.extract_slice %350[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
// Flatten each slice to 256x1; the torch_c from/to casts that follow are a
// round trip back to the same builtin tensor type.
%collapsed_76 = tensor.collapse_shape %extracted_slice_74 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%351 = torch_c.from_builtin_tensor %collapsed_76 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_77 = tensor.collapse_shape %extracted_slice_75 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%352 = torch_c.from_builtin_tensor %collapsed_77 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%353 = torch_c.to_builtin_tensor %351 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%354 = torch_c.to_builtin_tensor %352 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
// Write the two 256x1 columns into columns 0 and 1 of the 256x2 tensor %189
// (defined outside this chunk), re-interleaving the pairs.
%inserted_slice_78 = tensor.insert_slice %353 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_79 = tensor.insert_slice %354 into %inserted_slice_78[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
// Reshape 256x2 -> 1x1x8x32x2 -> 1x1x8x64 (the last two dimensions are merged).
%expanded_80 = tensor.expand_shape %inserted_slice_79 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_81 = tensor.collapse_shape %expanded_80 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%355 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%344, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%356 = torch_c.from_builtin_tensor %355 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%357 = torch.aten.view_as_real %356 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%358 = torch_c.to_builtin_tensor %357 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_82 = tensor.extract_slice %358[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%359 = torch.aten.view_as_real %356 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%360 = torch_c.to_builtin_tensor %359 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_83 = tensor.extract_slice %360[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%collapsed_84 = tensor.collapse_shape %extracted_slice_82 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%361 = torch_c.from_builtin_tensor %collapsed_84 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_85 = tensor.collapse_shape %extracted_slice_83 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%362 = torch_c.from_builtin_tensor %collapsed_85 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%363 = torch_c.to_builtin_tensor %361 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%364 = torch_c.to_builtin_tensor %362 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%inserted_slice_86 = tensor.insert_slice %363 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_87 = tensor.insert_slice %364 into %inserted_slice_86[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%expanded_88 = tensor.expand_shape %inserted_slice_87 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_89 = tensor.collapse_shape %expanded_88 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%365 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_89 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%366 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_71 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%367 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_81 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%368 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%367 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_90 = tensor.collapse_shape %368 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%369 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%365 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%370 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%369 : tensor<1x8x1x64xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%371 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%370 : tensor<1x8x64x1xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%collapsed_91 = tensor.collapse_shape %371 [[0, 1], [2], [3]] : tensor<1x8x64x1xf32> into tensor<8x64x1xf32>
%372 = linalg.batch_matmul ins(%collapsed_90, %collapsed_91 : tensor<8x1x64xf32>, tensor<8x64x1xf32>) outs(%211 : tensor<8x1x1xf32>) -> tensor<8x1x1xf32>
%expanded_92 = tensor.expand_shape %372 [[0, 1], [2], [3]] : tensor<8x1x1xf32> into tensor<1x8x1x1xf32>
%373 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_92 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_5 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// ---- Softmax-style normalisation of %373 along d3 ----
// Sequence visible below: max-reduce, subtract the max, exp, sum-reduce, divide.
// In this instance d3 has extent 1 (all tensors are 1x8x1x1).

// Step 1: running max over d3 together with the index at which it occurs.
// Accumulators %217 (max, f32) and %216 (index, i64) are initialised outside this
// chunk. The index result (#1) is not used within the visible lines.
%374:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%373 : tensor<1x8x1x1xf32>) outs(%217, %216 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>) {
^bb0(%in: f32, %out: f32, %out_269: i64):
%788 = linalg.index 3 : index
%789 = arith.index_cast %788 : index to i64
%790 = arith.maxf %in, %out : f32
// Strict ordered greater-than: the stored index only changes when the new value
// is strictly larger than the running max.
%791 = arith.cmpf ogt, %in, %out : f32
%792 = arith.select %791, %789, %out_269 : i64
linalg.yield %790, %792 : f32, i64
} -> (tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>)
// Step 2: subtract the max (%374#0) from the input.
%375 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%373, %374#0 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.subf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Step 3: elementwise exponential.
%376 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%375 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.exp %in : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Step 4: additive reduction of the exponentials over d3 into %221 (initial
// contents set outside this chunk; presumably zero-filled — verify).
%377 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%376 : tensor<1x8x1x1xf32>) outs(%221 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Step 5: divide each exponential by the sum.
%378 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%376, %377 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.divf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%379 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%378 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x1xf32>
%collapsed_93 = tensor.collapse_shape %379 [[0, 1], [2], [3]] : tensor<1x8x1x1xf32> into tensor<8x1x1xf32>
%380 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%366 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%381 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%380 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_94 = tensor.collapse_shape %381 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%382 = linalg.batch_matmul ins(%collapsed_93, %collapsed_94 : tensor<8x1x1xf32>, tensor<8x1x64xf32>) outs(%228 : tensor<8x1x64xf32>) -> tensor<8x1x64xf32>
%expanded_95 = tensor.expand_shape %382 [[0, 1], [2], [3]] : tensor<8x1x64xf32> into tensor<1x8x1x64xf32>
%383 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_95 : tensor<1x8x1x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%384 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%102 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_96 = tensor.collapse_shape %383 [[0], [1, 2, 3]] : tensor<1x1x8x64xf32> into tensor<1x512xf32>
%385 = linalg.matmul ins(%collapsed_96, %384 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_97 = tensor.expand_shape %385 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
%386 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%329, %expanded_97 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%387 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%386 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%388 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%387 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%389 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%388 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%390 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%389 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
%391 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%390 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%392 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%386, %391 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// ---- Weight scale, gated feed-forward and residual add: %393 .. %403 ----
// Scale %392 element-wise by the 512-element vector %103 (broadcast along d2).
// NOTE(review): %392 and %103 are defined outside this excerpt; this looks like the
// weight multiply that finishes a normalization step -- confirm.
%393 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%392, %103 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Transpose %104 (1536x512 -> 512x1536) so it can be the right-hand side of linalg.matmul.
%394 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%104 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
// Fold the unit dimension and project 512 -> 1536 features.
// The accumulator %244 is defined outside this excerpt (presumably zero-filled -- confirm).
%collapsed_98 = tensor.collapse_shape %393 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%395 = linalg.matmul ins(%collapsed_98, %394 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_99 = tensor.expand_shape %395 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
// Element-wise %cst_1 / (exp(-x) + %cst_1). This is a logistic sigmoid if %cst_1 is 1.0
// (the constant is defined outside this excerpt -- confirm).
%396 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_99 : tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.negf %in : f32
%789 = math.exp %788 : f32
%790 = arith.addf %789, %cst_1 : f32
%791 = arith.divf %cst_1, %790 : f32
linalg.yield %791 : f32
} -> tensor<1x1x1536xf32>
// Multiply that result by the pre-activation input %expanded_99
// (x * sigmoid(x) under the assumption above).
%397 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%396, %expanded_99 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
// Second 512 -> 1536 projection of the same input (%collapsed_98), using transposed %105.
%398 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%105 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
%399 = linalg.matmul ins(%collapsed_98, %398 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_100 = tensor.expand_shape %399 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
// Gate: element-wise product of the activated branch (%397) and the second projection.
%400 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%397, %expanded_100 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
// Project back 1536 -> 512 with transposed %106.
%401 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%106 : tensor<512x1536xf32>) outs(%252 : tensor<1536x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1536x512xf32>
%collapsed_101 = tensor.collapse_shape %400 [[0], [1, 2]] : tensor<1x1x1536xf32> into tensor<1x1536xf32>
%402 = linalg.matmul ins(%collapsed_101, %401 : tensor<1x1536xf32>, tensor<1536x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_102 = tensor.expand_shape %402 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
// Residual connection: add the projected result to %386 (defined outside this excerpt).
%403 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%386, %expanded_102 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// ---- Normalization of %403: %404 .. %410 ----
// Overall form: x * rsqrt(sum(pow(x, %cst_3)) / %cst_4 + %cst_2) * %107.
// NOTE(review): this is an RMS normalization if %cst_3 = 2.0, %cst_4 = 512.0 and %cst_2 is a
// small epsilon; all three constants are defined outside this excerpt -- confirm.
// Raise each element to the power %cst_3.
%404 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%403 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Sum over d2 (the 512-wide axis) into a 1x1x1 tensor.
// The initial accumulator %158 is defined outside this excerpt.
%405 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%404 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Divide the sum by %cst_4.
%406 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%405 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Add %cst_2, an f64 constant truncated to f32 inside the body.
%407 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%406 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
// Reciprocal square root.
%408 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%407 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Scale %403 by the 1x1x1 factor (broadcast over d2).
%409 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%403, %408 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Scale element-wise by the 512-element vector %107 (defined outside this excerpt).
%410 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%409, %107 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// ---- Three 512 -> 512 projections of %410: %411 .. %expanded_105 ----
// Each weight (%108, %109, %110) is transposed and used as the right-hand side of a
// 1x512 by 512x512 matmul over the same input %collapsed_103.
// NOTE(review): later uses suggest %412 / %414 / %416 play the query / key / value roles of an
// attention layer -- confirm against the original model.
// The accumulator %168 is defined outside this excerpt (presumably zero-filled -- confirm).
%411 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%108 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_103 = tensor.collapse_shape %410 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%412 = linalg.matmul ins(%collapsed_103, %411 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%413 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%109 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%414 = linalg.matmul ins(%collapsed_103, %413 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%415 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%110 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%416 = linalg.matmul ins(%collapsed_103, %415 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
// Reshape %416 to 1x1x8x64 and %412 to 1x1x8x32x2 (the trailing dimension of size 2
// groups adjacent elements into pairs).
%expanded_104 = tensor.expand_shape %416 [[0], [1, 2, 3]] : tensor<1x512xf32> into tensor<1x1x8x64xf32>
%expanded_105 = tensor.expand_shape %412 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
// ---- Pack adjacent f32 pairs into complex<f64>: %417, %418 ----
// For every (d0, d1, d2, d3), read elements [.., %c0] and [.., %c1] of %expanded_105, widen
// them to f64 and use them as the real and imaginary parts of one complex value.
// %c0 / %c1 are index constants defined outside this excerpt (presumably 0 and 1 -- confirm).
// The op has no ins operands; the source tensor is read through tensor.extract in the body.
%417 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
%extracted = tensor.extract %expanded_105[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_105[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
// Same reshape and packing applied to %414.
%expanded_106 = tensor.expand_shape %414 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
%418 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
%extracted = tensor.extract %expanded_106[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_106[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
// ---- Complex multiply and unpack for the %417 branch: %419 .. %collapsed_114 ----
// Multiply %417 by %expanded_8 (1x1x1x32 complex, defined outside this excerpt), broadcast over d2.
// NOTE(review): this matches the usual complex-number formulation of rotary position
// embeddings -- confirm.
%419 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%417, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
// Convert to a torch value tensor and view the complex tensor as (real, imaginary) pairs.
// These ops are still in the torch dialect, bracketed by torch_c conversion casts.
// NOTE(review): the operand element type is complex<f64> but the declared result element type
// is f32; torch.view_as_real on a complex128 tensor yields float64 -- confirm whether this
// type is intended or a lowering problem.
%420 = torch_c.from_builtin_tensor %419 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%421 = torch.aten.view_as_real %420 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%422 = torch_c.to_builtin_tensor %421 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
// Slice out component 0 of the last dimension.
%extracted_slice_107 = tensor.extract_slice %422[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
// The same view_as_real / cast sequence is repeated on the identical operand %420 in order to
// slice out component 1.
%423 = torch.aten.view_as_real %420 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%424 = torch_c.to_builtin_tensor %423 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_108 = tensor.extract_slice %424[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
// Flatten each slice to 256x1. The from_builtin_tensor / to_builtin_tensor pairs
// (%425 -> %427, %426 -> %428) convert to a torch tensor and straight back.
%collapsed_109 = tensor.collapse_shape %extracted_slice_107 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%425 = torch_c.from_builtin_tensor %collapsed_109 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_110 = tensor.collapse_shape %extracted_slice_108 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%426 = torch_c.from_builtin_tensor %collapsed_110 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%427 = torch_c.to_builtin_tensor %425 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%428 = torch_c.to_builtin_tensor %426 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
// Write the two columns into the 256x2 tensor %189 (defined outside this excerpt), then reshape
// to 1x1x8x32x2 and merge the last two dimensions into 64.
%inserted_slice_111 = tensor.insert_slice %427 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_112 = tensor.insert_slice %428 into %inserted_slice_111[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%expanded_113 = tensor.expand_shape %inserted_slice_112 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_114 = tensor.collapse_shape %expanded_113 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
// ---- Complex multiply and unpack for the %418 branch: %429 .. %collapsed_122 ----
// Multiply %418 by %expanded_8 (1x1x1x32 complex, defined outside this excerpt), broadcast over d2.
// NOTE(review): this matches the usual complex-number formulation of rotary position
// embeddings -- confirm.
%429 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%418, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
// Convert to a torch value tensor and view the complex tensor as (real, imaginary) pairs.
// These ops are still in the torch dialect, bracketed by torch_c conversion casts.
// NOTE(review): the operand element type is complex<f64> but the declared result element type
// is f32; torch.view_as_real on a complex128 tensor yields float64 -- confirm whether this
// type is intended or a lowering problem.
%430 = torch_c.from_builtin_tensor %429 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%431 = torch.aten.view_as_real %430 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%432 = torch_c.to_builtin_tensor %431 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
// Slice out component 0 of the last dimension.
%extracted_slice_115 = tensor.extract_slice %432[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
// The same view_as_real / cast sequence is repeated on the identical operand %430 in order to
// slice out component 1.
%433 = torch.aten.view_as_real %430 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%434 = torch_c.to_builtin_tensor %433 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_116 = tensor.extract_slice %434[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
// Flatten each slice to 256x1. The from_builtin_tensor / to_builtin_tensor pairs
// (%435 -> %437, %436 -> %438) convert to a torch tensor and straight back.
%collapsed_117 = tensor.collapse_shape %extracted_slice_115 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%435 = torch_c.from_builtin_tensor %collapsed_117 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_118 = tensor.collapse_shape %extracted_slice_116 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%436 = torch_c.from_builtin_tensor %collapsed_118 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%437 = torch_c.to_builtin_tensor %435 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%438 = torch_c.to_builtin_tensor %436 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
// Write the two columns into the 256x2 tensor %189 (defined outside this excerpt), then reshape
// to 1x1x8x32x2 and merge the last two dimensions into 64.
%inserted_slice_119 = tensor.insert_slice %437 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_120 = tensor.insert_slice %438 into %inserted_slice_119[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%expanded_121 = tensor.expand_shape %inserted_slice_120 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_122 = tensor.collapse_shape %expanded_121 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
// ---- Score computation: %439 .. %447 ----
// %439 and %440 copy %collapsed_122 and %expanded_104 into tensors of the same shape
// (the bodies only yield the input).
%439 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_122 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%440 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_104 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
// Left operand: swap dimensions 1 and 2 of %collapsed_114 (1x1x8x64 -> 1x8x1x64), copy, and
// fold the leading unit dimension to get 8x1x64.
%441 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_114 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%442 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%441 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_123 = tensor.collapse_shape %442 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
// Right operand: swap dimensions 1 and 2 of %439, then swap the last two dimensions
// (-> 1x8x64x1), copy, and fold to 8x64x1.
%443 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%439 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%444 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%443 : tensor<1x8x1x64xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%445 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%444 : tensor<1x8x64x1xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%collapsed_124 = tensor.collapse_shape %445 [[0, 1], [2], [3]] : tensor<1x8x64x1xf32> into tensor<8x64x1xf32>
// Batched (8) product of 1x64 by 64x1, giving one scalar per batch entry.
// The accumulator %211 is defined outside this excerpt.
%446 = linalg.batch_matmul ins(%collapsed_123, %collapsed_124 : tensor<8x1x64xf32>, tensor<8x64x1xf32>) outs(%211 : tensor<8x1x1xf32>) -> tensor<8x1x1xf32>
%expanded_125 = tensor.expand_shape %446 [[0, 1], [2], [3]] : tensor<8x1x1xf32> into tensor<1x8x1x1xf32>
// Divide by %cst_5 (defined outside this excerpt; presumably the attention scaling factor,
// e.g. sqrt(64) -- confirm).
%447 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_125 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_5 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// ---- Softmax-style normalization of %447 over d3: %448 .. %453 ----
// In these static shapes the reduced dimension d3 has extent 1.
// Running maximum over d3 together with the index at which it occurs. Within this excerpt
// only result #0 (the maximum) is referenced. Initial values %217 / %216 are defined outside
// this excerpt.
%448:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%447 : tensor<1x8x1x1xf32>) outs(%217, %216 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>) {
^bb0(%in: f32, %out: f32, %out_269: i64):
%788 = linalg.index 3 : index
%789 = arith.index_cast %788 : index to i64
%790 = arith.maxf %in, %out : f32
%791 = arith.cmpf ogt, %in, %out : f32
%792 = arith.select %791, %789, %out_269 : i64
linalg.yield %790, %792 : f32, i64
} -> (tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>)
// Subtract the maximum from every element.
%449 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%447, %448#0 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.subf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Exponentiate.
%450 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%449 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.exp %in : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Sum the exponentials over d3. The initial accumulator %221 is defined outside this excerpt.
%451 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%450 : tensor<1x8x1x1xf32>) outs(%221 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Divide each exponential by the sum.
%452 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%450, %451 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.divf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Copy of %452 into a tensor of the same shape (the body only yields the input).
%453 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%452 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x1xf32>
// ---- Weighted combination of %440 with the normalized scores %453: %collapsed_126 .. %457 ----
// Fold the leading unit dimension of %453 to get the 8x1x1 left operand.
%collapsed_126 = tensor.collapse_shape %453 [[0, 1], [2], [3]] : tensor<1x8x1x1xf32> into tensor<8x1x1xf32>
// Swap dimensions 1 and 2 of %440 (1x1x8x64 -> 1x8x1x64), copy, and fold to 8x1x64.
%454 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%440 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%455 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%454 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_127 = tensor.collapse_shape %455 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
// Batched (8) product of 1x1 by 1x64. The accumulator %228 is defined outside this excerpt.
%456 = linalg.batch_matmul ins(%collapsed_126, %collapsed_127 : tensor<8x1x1xf32>, tensor<8x1x64xf32>) outs(%228 : tensor<8x1x64xf32>) -> tensor<8x1x64xf32>
%expanded_128 = tensor.expand_shape %456 [[0, 1], [2], [3]] : tensor<8x1x64xf32> into tensor<1x8x1x64xf32>
// Swap dimensions 1 and 2 back (1x8x1x64 -> 1x1x8x64).
%457 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_128 : tensor<1x8x1x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
// ---- 512 -> 512 projection of %457 and residual add: %458 .. %460 ----
// Transpose the 512x512 weight %111 so it can be the right-hand side of linalg.matmul.
%458 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%111 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
// Merge the 1x8x64 trailing dimensions of %457 into 512 and apply the projection.
// The accumulator %168 is defined outside this excerpt.
%collapsed_129 = tensor.collapse_shape %457 [[0], [1, 2, 3]] : tensor<1x1x8x64xf32> into tensor<1x512xf32>
%459 = linalg.matmul ins(%collapsed_129, %458 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_130 = tensor.expand_shape %459 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
// Residual connection: add the projected result to %403.
%460 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%403, %expanded_130 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// ---- Normalization of %460: %461 .. %467 ----
// Overall form: x * rsqrt(sum(pow(x, %cst_3)) / %cst_4 + %cst_2) * %112.
// NOTE(review): this is an RMS normalization if %cst_3 = 2.0, %cst_4 = 512.0 and %cst_2 is a
// small epsilon; all three constants are defined outside this excerpt -- confirm.
// Raise each element to the power %cst_3.
%461 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%460 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Sum over d2 (the 512-wide axis) into a 1x1x1 tensor.
// The initial accumulator %158 is defined outside this excerpt.
%462 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%461 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Divide the sum by %cst_4.
%463 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%462 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Add %cst_2, an f64 constant truncated to f32 inside the body.
%464 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%463 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
// Reciprocal square root.
%465 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%464 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Scale %460 by the 1x1x1 factor (broadcast over d2).
%466 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%460, %465 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Scale element-wise by the 512-element vector %112 (defined outside this excerpt).
%467 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%466, %112 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// ---- Gated feed-forward on %467 and residual add: %468 .. %477 ----
// Transpose %113 (1536x512 -> 512x1536) so it can be the right-hand side of linalg.matmul.
%468 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%113 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
// Fold the unit dimension and project 512 -> 1536 features.
// The accumulator %244 is defined outside this excerpt (presumably zero-filled -- confirm).
%collapsed_131 = tensor.collapse_shape %467 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%469 = linalg.matmul ins(%collapsed_131, %468 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_132 = tensor.expand_shape %469 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
// Element-wise %cst_1 / (exp(-x) + %cst_1). This is a logistic sigmoid if %cst_1 is 1.0
// (the constant is defined outside this excerpt -- confirm).
%470 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_132 : tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.negf %in : f32
%789 = math.exp %788 : f32
%790 = arith.addf %789, %cst_1 : f32
%791 = arith.divf %cst_1, %790 : f32
linalg.yield %791 : f32
} -> tensor<1x1x1536xf32>
// Multiply that result by the pre-activation input %expanded_132
// (x * sigmoid(x) under the assumption above).
%471 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%470, %expanded_132 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
// Second 512 -> 1536 projection of the same input (%collapsed_131), using transposed %114.
%472 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%114 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
%473 = linalg.matmul ins(%collapsed_131, %472 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_133 = tensor.expand_shape %473 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
// Gate: element-wise product of the activated branch (%471) and the second projection.
%474 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%471, %expanded_133 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
// Project back 1536 -> 512 with transposed %115.
%475 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%115 : tensor<512x1536xf32>) outs(%252 : tensor<1536x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1536x512xf32>
%collapsed_134 = tensor.collapse_shape %474 [[0], [1, 2]] : tensor<1x1x1536xf32> into tensor<1x1536xf32>
%476 = linalg.matmul ins(%collapsed_134, %475 : tensor<1x1536xf32>, tensor<1536x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_135 = tensor.expand_shape %476 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
// Residual connection: add the projected result to %460.
%477 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%460, %expanded_135 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// ---- Normalization of %477: %478 .. %484 ----
// Overall form: x * rsqrt(sum(pow(x, %cst_3)) / %cst_4 + %cst_2) * %116.
// NOTE(review): this is an RMS normalization if %cst_3 = 2.0, %cst_4 = 512.0 and %cst_2 is a
// small epsilon; all three constants are defined outside this excerpt -- confirm.
// Raise each element to the power %cst_3.
%478 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%477 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Sum over d2 (the 512-wide axis) into a 1x1x1 tensor.
// The initial accumulator %158 is defined outside this excerpt.
%479 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%478 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Divide the sum by %cst_4.
%480 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%479 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Add %cst_2, an f64 constant truncated to f32 inside the body.
%481 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%480 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
// Reciprocal square root.
%482 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%481 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Scale %477 by the 1x1x1 factor (broadcast over d2).
%483 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%477, %482 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Scale element-wise by the 512-element vector %116 (defined outside this excerpt).
%484 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%483, %116 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// ---- Three 512 -> 512 projections of %484: %485 .. %expanded_138 ----
// Each weight (%117, %118, %119) is transposed and used as the right-hand side of a
// 1x512 by 512x512 matmul over the same input %collapsed_136.
// NOTE(review): these look like the query / key / value projections of an attention layer --
// confirm against the original model.
// The accumulator %168 is defined outside this excerpt (presumably zero-filled -- confirm).
%485 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%117 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_136 = tensor.collapse_shape %484 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%486 = linalg.matmul ins(%collapsed_136, %485 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%487 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%118 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%488 = linalg.matmul ins(%collapsed_136, %487 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%489 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%119 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%490 = linalg.matmul ins(%collapsed_136, %489 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
// Reshape %490 to 1x1x8x64 and %486 to 1x1x8x32x2 (the trailing dimension of size 2
// groups adjacent elements into pairs).
%expanded_137 = tensor.expand_shape %490 [[0], [1, 2, 3]] : tensor<1x512xf32> into tensor<1x1x8x64xf32>
%expanded_138 = tensor.expand_shape %486 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
// Build a 1x1x8x32 complex<f64> tensor from %expanded_138. The op has no ins operands; the body
// uses linalg.index to recover the four iteration indices and reads %expanded_138 directly.
// NOTE(review): %c0 and %c1 are defined outside this chunk; presumably the index constants 0 and 1
// selecting the two slots of the trailing size-2 axis — confirm.
%491 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
// Element at trailing index %c0 becomes the real part, element at %c1 the imaginary part.
%extracted = tensor.extract %expanded_138[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_138[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
// Widen both components from f32 to f64 before forming the complex value.
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%expanded_139 = tensor.expand_shape %488 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
%492 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
%extracted = tensor.extract %expanded_139[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_139[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
// Element-wise complex multiply of %491 by %expanded_8. The second input map is (0, 0, 0, d3), so
// the 1x1x1x32 operand is broadcast across the size-8 axis (d2).
// NOTE(review): %expanded_8 is defined outside this chunk; its contents cannot be confirmed here.
%493 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%491, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
// Cast back into the torch dialect so the remaining torch op below can consume the value.
%494 = torch_c.from_builtin_tensor %493 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
// NOTE(review): the operand element type is complex<f64> but the result is typed f32. PyTorch's
// view_as_real maps complex128 to float64, so the dtypes here look inconsistent. This op is also
// still in the torch dialect while its neighbours are already linalg/tensor ops — verify whether
// the mismatch is intended.
%495 = torch.aten.view_as_real %494 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%496 = torch_c.to_builtin_tensor %495 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
// Slot 0 of the trailing axis (the real components).
%extracted_slice_140 = tensor.extract_slice %496[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
// Same op on the same operand as %495; the two results are textually identical computations.
%497 = torch.aten.view_as_real %494 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%498 = torch_c.to_builtin_tensor %497 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
// Slot 1 of the trailing axis (the imaginary components).
%extracted_slice_141 = tensor.extract_slice %498[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
// Flatten each component to 256x1. The from_builtin/to_builtin pairs (%499 -> %501, %500 -> %502)
// convert to a torch tensor type and straight back without any op in between.
%collapsed_142 = tensor.collapse_shape %extracted_slice_140 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%499 = torch_c.from_builtin_tensor %collapsed_142 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_143 = tensor.collapse_shape %extracted_slice_141 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%500 = torch_c.from_builtin_tensor %collapsed_143 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%501 = torch_c.to_builtin_tensor %499 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%502 = torch_c.to_builtin_tensor %500 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
// Write the two 256x1 columns side by side into the 256x2 tensor %189 (column 0, then column 1).
%inserted_slice_144 = tensor.insert_slice %501 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_145 = tensor.insert_slice %502 into %inserted_slice_144[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
// Restore the 1x1x8x32x2 layout, then fuse the last two axes so each row holds 64 interleaved values.
%expanded_146 = tensor.expand_shape %inserted_slice_145 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_147 = tensor.collapse_shape %expanded_146 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%503 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%492, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%504 = torch_c.from_builtin_tensor %503 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%505 = torch.aten.view_as_real %504 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%506 = torch_c.to_builtin_tensor %505 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_148 = tensor.extract_slice %506[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%507 = torch.aten.view_as_real %504 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%508 = torch_c.to_builtin_tensor %507 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_149 = tensor.extract_slice %508[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%collapsed_150 = tensor.collapse_shape %extracted_slice_148 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%509 = torch_c.from_builtin_tensor %collapsed_150 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_151 = tensor.collapse_shape %extracted_slice_149 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%510 = torch_c.from_builtin_tensor %collapsed_151 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%511 = torch_c.to_builtin_tensor %509 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%512 = torch_c.to_builtin_tensor %510 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%inserted_slice_152 = tensor.insert_slice %511 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_153 = tensor.insert_slice %512 into %inserted_slice_152[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%expanded_154 = tensor.expand_shape %inserted_slice_153 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_155 = tensor.collapse_shape %expanded_154 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%513 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_155 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%514 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_137 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%515 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_147 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%516 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%515 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_156 = tensor.collapse_shape %516 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%517 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%513 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%518 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%517 : tensor<1x8x1x64xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%519 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%518 : tensor<1x8x64x1xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%collapsed_157 = tensor.collapse_shape %519 [[0, 1], [2], [3]] : tensor<1x8x64x1xf32> into tensor<8x64x1xf32>
// Batched product over 8 batches: (1x64) x (64x1) -> 1x1 per batch, accumulated into %211.
// NOTE(review): %211 is defined outside this chunk; presumably a zero-filled accumulator — confirm.
%520 = linalg.batch_matmul ins(%collapsed_156, %collapsed_157 : tensor<8x1x64xf32>, tensor<8x64x1xf32>) outs(%211 : tensor<8x1x1xf32>) -> tensor<8x1x1xf32>
// Re-introduce the leading unit axis: 8x1x1 -> 1x8x1x1.
%expanded_158 = tensor.expand_shape %520 [[0, 1], [2], [3]] : tensor<8x1x1xf32> into tensor<1x8x1x1xf32>
// Divide every element by the scalar %cst_5.
// NOTE(review): %cst_5 is defined outside this chunk; presumably a sqrt(64) scaling factor — confirm.
%521 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_158 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_5 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Max-subtract / exp / sum / divide sequence along d3 (the softmax pattern). In this instance d3
// has extent 1 (tensor<1x8x1x1xf32>), so every reduction below covers a single element.
// NOTE(review): the init tensors %217, %216, %221 and %213 are defined outside this chunk; the
// reductions are only correct if they hold suitable identities (e.g. -inf / 0) — confirm.
//
// Step 1: running maximum over d3 together with the index at which it occurs. Within this chunk
// only the maximum (%522#0) is consumed.
%522:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%521 : tensor<1x8x1x1xf32>) outs(%217, %216 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>) {
^bb0(%in: f32, %out: f32, %out_269: i64):
%788 = linalg.index 3 : index
%789 = arith.index_cast %788 : index to i64
%790 = arith.maxf %in, %out : f32
// Keep the current d3 index only when the new element is strictly greater than the running max.
%791 = arith.cmpf ogt, %in, %out : f32
%792 = arith.select %791, %789, %out_269 : i64
linalg.yield %790, %792 : f32, i64
} -> (tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>)
// Step 2: subtract the maximum from each element.
%523 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%521, %522#0 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.subf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Step 3: exponentiate.
%524 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%523 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.exp %in : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Step 4: sum the exponentials over d3.
%525 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%524 : tensor<1x8x1x1xf32>) outs(%221 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Step 5: divide each exponential by the sum.
%526 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%524, %525 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.divf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Identity copy of %526 (the body yields its input unchanged).
%527 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%526 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x1xf32>
// Fold the leading unit axis into the batch axis: 1x8x1x1 -> 8x1x1, ready for the next batch_matmul.
%collapsed_159 = tensor.collapse_shape %527 [[0, 1], [2], [3]] : tensor<1x8x1x1xf32> into tensor<8x1x1xf32>
%528 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%514 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%529 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%528 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_160 = tensor.collapse_shape %529 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%530 = linalg.batch_matmul ins(%collapsed_159, %collapsed_160 : tensor<8x1x1xf32>, tensor<8x1x64xf32>) outs(%228 : tensor<8x1x64xf32>) -> tensor<8x1x64xf32>
%expanded_161 = tensor.expand_shape %530 [[0, 1], [2], [3]] : tensor<8x1x64xf32> into tensor<1x8x1x64xf32>
%531 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_161 : tensor<1x8x1x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%532 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%120 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_162 = tensor.collapse_shape %531 [[0], [1, 2, 3]] : tensor<1x1x8x64xf32> into tensor<1x512xf32>
%533 = linalg.matmul ins(%collapsed_162, %532 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_163 = tensor.expand_shape %533 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
%534 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%477, %expanded_163 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Normalisation of %534 along its 512-wide axis, followed by a per-element scale by %121:
//   pow(x, %cst_3) -> sum over d2 -> / %cst_4 -> + trunc(%cst_2) -> rsqrt -> x * result -> * %121
// NOTE(review): %cst_2, %cst_3, %cst_4, %157 and %158 are defined outside this chunk. If %cst_3 is
// 2.0, %cst_4 is 512 and %cst_2 is a small epsilon, this is an RMS-style normalisation — confirm.
//
// Raise every element to the power %cst_3.
%535 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%534 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Sum over d2 (the reduction iterator) into the 1x1x1 accumulator %158.
%536 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%535 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Divide the sum by %cst_4.
%537 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%536 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Add %cst_2 after narrowing it from f64 to f32.
%538 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%537 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
// Reciprocal square root of the single remaining value.
%539 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%538 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Multiply the original input %534 by that scalar; the (0, 0, 0) map broadcasts it over d2.
%540 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%534, %539 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Element-wise multiply by the 512-element vector %121, indexed by d2.
%541 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%540, %121 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%542 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%122 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
%collapsed_164 = tensor.collapse_shape %541 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%543 = linalg.matmul ins(%collapsed_164, %542 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_165 = tensor.expand_shape %543 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
// Element-wise %cst_1 / (exp(-x) + %cst_1) applied to %expanded_165.
// NOTE(review): %cst_1 is defined outside this chunk; if it is 1.0 this is the logistic sigmoid,
// and the multiply below then yields x * sigmoid(x) — confirm.
%544 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_165 : tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.negf %in : f32
%789 = math.exp %788 : f32
%790 = arith.addf %789, %cst_1 : f32
%791 = arith.divf %cst_1, %790 : f32
linalg.yield %791 : f32
} -> tensor<1x1x1536xf32>
// Multiply the gate computed above by the same input it was derived from (%expanded_165).
%545 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%544, %expanded_165 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
%546 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%123 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
%547 = linalg.matmul ins(%collapsed_164, %546 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_166 = tensor.expand_shape %547 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
%548 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%545, %expanded_166 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
%549 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%124 : tensor<512x1536xf32>) outs(%252 : tensor<1536x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1536x512xf32>
%collapsed_167 = tensor.collapse_shape %548 [[0], [1, 2]] : tensor<1x1x1536xf32> into tensor<1x1536xf32>
%550 = linalg.matmul ins(%collapsed_167, %549 : tensor<1x1536xf32>, tensor<1536x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_168 = tensor.expand_shape %550 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
%551 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%534, %expanded_168 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%552 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%551 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%553 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%552 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%554 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%553 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%555 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%554 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
%556 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%555 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%557 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%551, %556 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%558 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%557, %125 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%559 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%126 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_169 = tensor.collapse_shape %558 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%560 = linalg.matmul ins(%collapsed_169, %559 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%561 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%127 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%562 = linalg.matmul ins(%collapsed_169, %561 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%563 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%128 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%564 = linalg.matmul ins(%collapsed_169, %563 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_170 = tensor.expand_shape %564 [[0], [1, 2, 3]] : tensor<1x512xf32> into tensor<1x1x8x64xf32>
%expanded_171 = tensor.expand_shape %560 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
%565 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
%extracted = tensor.extract %expanded_171[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_171[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%expanded_172 = tensor.expand_shape %562 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
%566 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
%extracted = tensor.extract %expanded_172[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_172[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%567 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%565, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%568 = torch_c.from_builtin_tensor %567 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%569 = torch.aten.view_as_real %568 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%570 = torch_c.to_builtin_tensor %569 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_173 = tensor.extract_slice %570[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%571 = torch.aten.view_as_real %568 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%572 = torch_c.to_builtin_tensor %571 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_174 = tensor.extract_slice %572[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%collapsed_175 = tensor.collapse_shape %extracted_slice_173 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%573 = torch_c.from_builtin_tensor %collapsed_175 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_176 = tensor.collapse_shape %extracted_slice_174 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%574 = torch_c.from_builtin_tensor %collapsed_176 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%575 = torch_c.to_builtin_tensor %573 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%576 = torch_c.to_builtin_tensor %574 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%inserted_slice_177 = tensor.insert_slice %575 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_178 = tensor.insert_slice %576 into %inserted_slice_177[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%expanded_179 = tensor.expand_shape %inserted_slice_178 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_180 = tensor.collapse_shape %expanded_179 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%577 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%566, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%578 = torch_c.from_builtin_tensor %577 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%579 = torch.aten.view_as_real %578 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%580 = torch_c.to_builtin_tensor %579 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_181 = tensor.extract_slice %580[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%581 = torch.aten.view_as_real %578 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%582 = torch_c.to_builtin_tensor %581 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_182 = tensor.extract_slice %582[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%collapsed_183 = tensor.collapse_shape %extracted_slice_181 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%583 = torch_c.from_builtin_tensor %collapsed_183 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_184 = tensor.collapse_shape %extracted_slice_182 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%584 = torch_c.from_builtin_tensor %collapsed_184 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%585 = torch_c.to_builtin_tensor %583 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%586 = torch_c.to_builtin_tensor %584 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%inserted_slice_185 = tensor.insert_slice %585 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_186 = tensor.insert_slice %586 into %inserted_slice_185[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%expanded_187 = tensor.expand_shape %inserted_slice_186 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_188 = tensor.collapse_shape %expanded_187 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%587 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_188 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%588 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_170 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%589 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_180 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%590 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%589 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_189 = tensor.collapse_shape %590 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%591 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%587 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%592 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%591 : tensor<1x8x1x64xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%593 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%592 : tensor<1x8x64x1xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%collapsed_190 = tensor.collapse_shape %593 [[0, 1], [2], [3]] : tensor<1x8x64x1xf32> into tensor<8x64x1xf32>
%594 = linalg.batch_matmul ins(%collapsed_189, %collapsed_190 : tensor<8x1x64xf32>, tensor<8x64x1xf32>) outs(%211 : tensor<8x1x1xf32>) -> tensor<8x1x1xf32>
%expanded_191 = tensor.expand_shape %594 [[0, 1], [2], [3]] : tensor<8x1x1xf32> into tensor<1x8x1x1xf32>
// Scale, then normalize along d3: the five generics below form the sequence
// x / c -> max -> x - max -> exp -> sum -> exp / sum, i.e. a max-subtracted softmax over d3.
// NOTE(review): the init tensors %213, %216, %217 and %221 are defined outside this excerpt;
// presumably %217 is filled with -inf and %221 with 0 — confirm.

// Step 1: divide every score by %cst_5.
// NOTE(review): %cst_5 is defined outside this excerpt; presumably the attention scale — confirm.
%595 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_191 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_5 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Step 2: reduce over d3, producing the running maximum (#0) and the d3 index at which
// it was found (#1). The comparison is strict (ogt), so on ties the previously stored
// index is kept. Only result #0 is consumed by the ops shown below.
%596:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%595 : tensor<1x8x1x1xf32>) outs(%217, %216 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>) {
^bb0(%in: f32, %out: f32, %out_269: i64):
%788 = linalg.index 3 : index
%789 = arith.index_cast %788 : index to i64
%790 = arith.maxf %in, %out : f32
%791 = arith.cmpf ogt, %in, %out : f32
%792 = arith.select %791, %789, %out_269 : i64
linalg.yield %790, %792 : f32, i64
} -> (tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>)
// Step 3: subtract the per-row maximum from the scaled scores.
%597 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%595, %596#0 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.subf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Step 4: exponentiate the shifted scores.
%598 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%597 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.exp %in : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Step 5: sum the exponentials over d3, accumulating into %221.
%599 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%598 : tensor<1x8x1x1xf32>) outs(%221 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
// Step 6: divide each exponential by the row sum.
%600 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%598, %599 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.divf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%601 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%600 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x1xf32>
%collapsed_192 = tensor.collapse_shape %601 [[0, 1], [2], [3]] : tensor<1x8x1x1xf32> into tensor<8x1x1xf32>
%602 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%588 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%603 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%602 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_193 = tensor.collapse_shape %603 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%604 = linalg.batch_matmul ins(%collapsed_192, %collapsed_193 : tensor<8x1x1xf32>, tensor<8x1x64xf32>) outs(%228 : tensor<8x1x64xf32>) -> tensor<8x1x64xf32>
%expanded_194 = tensor.expand_shape %604 [[0, 1], [2], [3]] : tensor<8x1x64xf32> into tensor<1x8x1x64xf32>
%605 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_194 : tensor<1x8x1x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%606 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%129 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_195 = tensor.collapse_shape %605 [[0], [1, 2, 3]] : tensor<1x1x8x64xf32> into tensor<1x512xf32>
%607 = linalg.matmul ins(%collapsed_195, %606 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_196 = tensor.expand_shape %607 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
%608 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%551, %expanded_196 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Normalizes %608 along d2 (512 elements) and applies a per-channel weight:
//   %615 = %608 * rsqrt(sum(%608 ^ %cst_3) / %cst_4 + trunc(%cst_2)) * %130
// NOTE(review): %cst_2, %cst_3, %cst_4 and the init tensors %154, %157, %158 are defined
// outside this excerpt. If %cst_3 is 2.0 and %cst_4 is 512.0 this is an RMS normalization
// with epsilon %cst_2 — confirm against the constant definitions.

// Raise every element to the power %cst_3.
%609 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%608 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Sum over d2, accumulating into %158 (presumably zero-initialized — confirm).
%610 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%609 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Divide the sum by %cst_4 (turns the sum into a mean if %cst_4 is the element count).
%611 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%610 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Add the f64 constant %cst_2 after truncating it to f32.
%612 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%611 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
// Reciprocal square root of the result.
%613 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%612 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Scale the original activations %608 by that scalar (broadcast over d2).
%614 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%608, %613 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Multiply element-wise by the 512-element vector %130, indexed by d2.
%615 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%614, %130 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%616 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%131 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
%collapsed_197 = tensor.collapse_shape %615 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%617 = linalg.matmul ins(%collapsed_197, %616 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_198 = tensor.expand_shape %617 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
// Gate activation on the 1536-wide projection %expanded_198.
// %618 computes %cst_1 / (exp(-x) + %cst_1) per element.
// NOTE(review): %cst_1 is defined outside this excerpt; if it is 1.0 this is the logistic
// sigmoid, and %619 below is then x * sigmoid(x) — confirm.
%618 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_198 : tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.negf %in : f32
%789 = math.exp %788 : f32
%790 = arith.addf %789, %cst_1 : f32
%791 = arith.divf %cst_1, %790 : f32
linalg.yield %791 : f32
} -> tensor<1x1x1536xf32>
// Multiply the gate value by the same input %expanded_198 it was computed from.
%619 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%618, %expanded_198 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
%620 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%132 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
%621 = linalg.matmul ins(%collapsed_197, %620 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_199 = tensor.expand_shape %621 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
%622 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%619, %expanded_199 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
%623 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%133 : tensor<512x1536xf32>) outs(%252 : tensor<1536x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1536x512xf32>
%collapsed_200 = tensor.collapse_shape %622 [[0], [1, 2]] : tensor<1x1x1536xf32> into tensor<1x1536xf32>
%624 = linalg.matmul ins(%collapsed_200, %623 : tensor<1x1536xf32>, tensor<1536x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_201 = tensor.expand_shape %624 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
%625 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%608, %expanded_201 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%626 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%625 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%627 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%626 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%628 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%627 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%629 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%628 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
%630 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%629 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%631 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%625, %630 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%632 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%631, %134 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%633 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%135 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_202 = tensor.collapse_shape %632 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%634 = linalg.matmul ins(%collapsed_202, %633 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%635 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%136 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%636 = linalg.matmul ins(%collapsed_202, %635 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%637 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%137 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%638 = linalg.matmul ins(%collapsed_202, %637 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_203 = tensor.expand_shape %638 [[0], [1, 2, 3]] : tensor<1x512xf32> into tensor<1x1x8x64xf32>
%expanded_204 = tensor.expand_shape %634 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
// Packs adjacent f32 pairs of %expanded_204 into complex<f64> values.
// The op has no `ins`; the body reads %expanded_204 directly with tensor.extract, using the
// four loop indices plus %c0 / %c1 for the trailing dimension of size 2. The %out block
// argument is never read, so %174 only supplies the result shape and type.
// NOTE(review): %c0 and %c1 are defined outside this excerpt; presumably index constants
// 0 and 1, making element [..., 0] the real part and [..., 1] the imaginary part — confirm.
%639 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
%extracted = tensor.extract %expanded_204[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_204[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
// Widen both components from f32 to f64 before building the complex value.
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%expanded_205 = tensor.expand_shape %636 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
%640 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
%extracted = tensor.extract %expanded_205[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_205[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%641 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%639, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%642 = torch_c.from_builtin_tensor %641 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%643 = torch.aten.view_as_real %642 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%644 = torch_c.to_builtin_tensor %643 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_206 = tensor.extract_slice %644[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%645 = torch.aten.view_as_real %642 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%646 = torch_c.to_builtin_tensor %645 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_207 = tensor.extract_slice %646[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%collapsed_208 = tensor.collapse_shape %extracted_slice_206 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%647 = torch_c.from_builtin_tensor %collapsed_208 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_209 = tensor.collapse_shape %extracted_slice_207 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%648 = torch_c.from_builtin_tensor %collapsed_209 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%649 = torch_c.to_builtin_tensor %647 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%650 = torch_c.to_builtin_tensor %648 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%inserted_slice_210 = tensor.insert_slice %649 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_211 = tensor.insert_slice %650 into %inserted_slice_210[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%expanded_212 = tensor.expand_shape %inserted_slice_211 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_213 = tensor.collapse_shape %expanded_212 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%651 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%640, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%652 = torch_c.from_builtin_tensor %651 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%653 = torch.aten.view_as_real %652 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%654 = torch_c.to_builtin_tensor %653 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_214 = tensor.extract_slice %654[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%655 = torch.aten.view_as_real %652 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%656 = torch_c.to_builtin_tensor %655 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_215 = tensor.extract_slice %656[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%collapsed_216 = tensor.collapse_shape %extracted_slice_214 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%657 = torch_c.from_builtin_tensor %collapsed_216 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_217 = tensor.collapse_shape %extracted_slice_215 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%658 = torch_c.from_builtin_tensor %collapsed_217 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%659 = torch_c.to_builtin_tensor %657 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%660 = torch_c.to_builtin_tensor %658 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%inserted_slice_218 = tensor.insert_slice %659 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_219 = tensor.insert_slice %660 into %inserted_slice_218[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%expanded_220 = tensor.expand_shape %inserted_slice_219 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_221 = tensor.collapse_shape %expanded_220 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%661 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_221 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%662 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_203 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%663 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_213 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%664 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%663 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_222 = tensor.collapse_shape %664 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%665 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%661 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%666 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%665 : tensor<1x8x1x64xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%667 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%666 : tensor<1x8x64x1xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%collapsed_223 = tensor.collapse_shape %667 [[0, 1], [2], [3]] : tensor<1x8x64x1xf32> into tensor<8x64x1xf32>
%668 = linalg.batch_matmul ins(%collapsed_222, %collapsed_223 : tensor<8x1x64xf32>, tensor<8x64x1xf32>) outs(%211 : tensor<8x1x1xf32>) -> tensor<8x1x1xf32>
%expanded_224 = tensor.expand_shape %668 [[0, 1], [2], [3]] : tensor<8x1x1xf32> into tensor<1x8x1x1xf32>
%669 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_224 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_5 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%670:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%669 : tensor<1x8x1x1xf32>) outs(%217, %216 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>) {
^bb0(%in: f32, %out: f32, %out_269: i64):
%788 = linalg.index 3 : index
%789 = arith.index_cast %788 : index to i64
%790 = arith.maxf %in, %out : f32
%791 = arith.cmpf ogt, %in, %out : f32
%792 = arith.select %791, %789, %out_269 : i64
linalg.yield %790, %792 : f32, i64
} -> (tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>)
%671 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%669, %670#0 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.subf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%672 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%671 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.exp %in : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%673 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%672 : tensor<1x8x1x1xf32>) outs(%221 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%674 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%672, %673 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.divf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%675 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%674 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x1xf32>
%collapsed_225 = tensor.collapse_shape %675 [[0, 1], [2], [3]] : tensor<1x8x1x1xf32> into tensor<8x1x1xf32>
%676 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%662 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%677 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%676 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_226 = tensor.collapse_shape %677 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%678 = linalg.batch_matmul ins(%collapsed_225, %collapsed_226 : tensor<8x1x1xf32>, tensor<8x1x64xf32>) outs(%228 : tensor<8x1x64xf32>) -> tensor<8x1x64xf32>
%expanded_227 = tensor.expand_shape %678 [[0, 1], [2], [3]] : tensor<8x1x64xf32> into tensor<1x8x1x64xf32>
%679 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_227 : tensor<1x8x1x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%680 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%138 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_228 = tensor.collapse_shape %679 [[0], [1, 2, 3]] : tensor<1x1x8x64xf32> into tensor<1x512xf32>
%681 = linalg.matmul ins(%collapsed_228, %680 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_229 = tensor.expand_shape %681 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
%682 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%625, %expanded_229 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Normalization of %682 followed by an elementwise scale with %139. The chain computes
//   %682 * rsqrt(sum(powf(%682, %cst_3)) / %cst_4 + truncf(%cst_2)) * %139
// NOTE(review): this has the shape of an RMS norm if %cst_3 is 2.0, %cst_4 is 512.0 and %cst_2 is
// a small epsilon — those constants are defined outside this section; confirm there.

// Step 1: elementwise power of the input with exponent %cst_3.
%683 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%682 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Step 2: sum over d2 (the "reduction" iterator), accumulating into %158
// (presumably zero-initialised — defined outside this section).
%684 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%683 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Step 3: divide the sum by %cst_4.
%685 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%684 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Step 4: add %cst_2, which is an f64 constant and is narrowed to f32 before the add.
%686 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%685 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
// Step 5: reciprocal square root of the result.
%687 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%686 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Step 6: scale the original input %682 by the scalar factor; the (0, 0, 0) map broadcasts the
// single-element tensor %687 across d2.
%688 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%682, %687 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Step 7: elementwise multiply with the 512-element vector %139, indexed by d2
// (presumably the learned norm weight — %139 is defined outside this section).
%689 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%688, %139 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%690 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%140 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
%collapsed_230 = tensor.collapse_shape %689 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%691 = linalg.matmul ins(%collapsed_230, %690 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_231 = tensor.expand_shape %691 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
// Elementwise %cst_1 / (exp(-x) + %cst_1) over %expanded_231.
// NOTE(review): this is the logistic sigmoid if %cst_1 is 1.0 — the constant is defined outside
// this section; confirm there.
%692 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_231 : tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.negf %in : f32
%789 = math.exp %788 : f32
%790 = arith.addf %789, %cst_1 : f32
%791 = arith.divf %cst_1, %790 : f32
linalg.yield %791 : f32
} -> tensor<1x1x1536xf32>
// Multiplies the result above by the same tensor it was computed from (%expanded_231), i.e.
// x * f(x); with f = sigmoid this would be the SiLU activation (see the note on %cst_1 above).
%693 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%692, %expanded_231 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
%694 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%141 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
%695 = linalg.matmul ins(%collapsed_230, %694 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_232 = tensor.expand_shape %695 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
%696 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%693, %expanded_232 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
%697 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%142 : tensor<512x1536xf32>) outs(%252 : tensor<1536x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1536x512xf32>
%collapsed_233 = tensor.collapse_shape %696 [[0], [1, 2]] : tensor<1x1x1536xf32> into tensor<1x1536xf32>
%698 = linalg.matmul ins(%collapsed_233, %697 : tensor<1x1536xf32>, tensor<1536x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_234 = tensor.expand_shape %698 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
%699 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%682, %expanded_234 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%700 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%699 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%701 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%700 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%702 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%701 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%703 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%702 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
%704 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%703 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
%705 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%699, %704 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%706 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%705, %143 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%707 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%144 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_235 = tensor.collapse_shape %706 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
%708 = linalg.matmul ins(%collapsed_235, %707 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%709 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%145 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%710 = linalg.matmul ins(%collapsed_235, %709 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%711 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%146 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%712 = linalg.matmul ins(%collapsed_235, %711 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_236 = tensor.expand_shape %712 [[0], [1, 2, 3]] : tensor<1x512xf32> into tensor<1x1x8x64xf32>
%expanded_237 = tensor.expand_shape %708 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
// Builds a 1x1x8x32 tensor of complex<f64> from the 1x1x8x32x2 f32 tensor %expanded_237.
// The op has no `ins`; each element is gathered explicitly with tensor.extract at the current
// iteration indices.
%713 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
// The trailing dimension (size 2) is indexed with %c0 and %c1, which are defined outside this
// section (presumably the index constants 0 and 1 — TODO confirm).
%extracted = tensor.extract %expanded_237[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_237[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
// Widen both f32 values to f64; the %c0 element becomes the real part and the %c1 element the
// imaginary part of the complex value.
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%expanded_238 = tensor.expand_shape %710 [[0], [1, 2, 3, 4]] : tensor<1x512xf32> into tensor<1x1x8x32x2xf32>
%714 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%out: complex<f64>):
%788 = linalg.index 0 : index
%789 = linalg.index 1 : index
%790 = linalg.index 2 : index
%791 = linalg.index 3 : index
%extracted = tensor.extract %expanded_238[%788, %789, %790, %791, %c0] : tensor<1x1x8x32x2xf32>
%extracted_269 = tensor.extract %expanded_238[%788, %789, %790, %791, %c1] : tensor<1x1x8x32x2xf32>
%792 = arith.extf %extracted : f32 to f64
%793 = arith.extf %extracted_269 : f32 to f64
%794 = complex.create %792, %793 : complex<f64>
linalg.yield %794 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%715 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%713, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
// Converts the complex product %715 back into a real 1x1x8x64 f32 tensor. The torch.aten.view_as_real
// ops are still in the torch dialect here, so the data is wrapped/unwrapped with torch_c casts.
%716 = torch_c.from_builtin_tensor %715 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
// NOTE(review): the operand is complex<f64> but the declared result element type is f32;
// torch.view_as_real is documented to return the matching real dtype — confirm this f32 is intended.
%717 = torch.aten.view_as_real %716 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%718 = torch_c.to_builtin_tensor %717 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
// Slice at offset 0 of the last dimension (the real components per view_as_real's layout).
%extracted_slice_239 = tensor.extract_slice %718[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
// NOTE(review): %719 repeats the exact view_as_real of %716 already computed as %717.
%719 = torch.aten.view_as_real %716 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%720 = torch_c.to_builtin_tensor %719 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
// Slice at offset 1 of the last dimension (the imaginary components).
%extracted_slice_240 = tensor.extract_slice %720[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
// Flatten each slice to 256x1; each is then cast to a torch value tensor and straight back to a
// builtin tensor of the same shape and element type.
%collapsed_241 = tensor.collapse_shape %extracted_slice_239 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%721 = torch_c.from_builtin_tensor %collapsed_241 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_242 = tensor.collapse_shape %extracted_slice_240 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%722 = torch_c.from_builtin_tensor %collapsed_242 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%723 = torch_c.to_builtin_tensor %721 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%724 = torch_c.to_builtin_tensor %722 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
// Write the two 256x1 columns side by side into the 256x2 tensor %189 (defined outside this
// section): column 0 from the offset-0 slice, column 1 from the offset-1 slice.
%inserted_slice_243 = tensor.insert_slice %723 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_244 = tensor.insert_slice %724 into %inserted_slice_243[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
// Restore the 1x1x8x32x2 shape, then fold the last two dimensions (32 x 2) into one of size 64,
// which leaves the two columns interleaved along the final dimension.
%expanded_245 = tensor.expand_shape %inserted_slice_244 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_246 = tensor.collapse_shape %expanded_245 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%725 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%714, %expanded_8 : tensor<1x1x8x32xcomplex<f64>>, tensor<1x1x1x32xcomplex<f64>>) outs(%174 : tensor<1x1x8x32xcomplex<f64>>) {
^bb0(%in: complex<f64>, %in_269: complex<f64>, %out: complex<f64>):
%788 = complex.mul %in, %in_269 : complex<f64>
linalg.yield %788 : complex<f64>
} -> tensor<1x1x8x32xcomplex<f64>>
%726 = torch_c.from_builtin_tensor %725 : tensor<1x1x8x32xcomplex<f64>> -> !torch.vtensor<[1,1,8,32],complex<f64>>
%727 = torch.aten.view_as_real %726 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%728 = torch_c.to_builtin_tensor %727 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_247 = tensor.extract_slice %728[0, 0, 0, 0, 0] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%729 = torch.aten.view_as_real %726 : !torch.vtensor<[1,1,8,32],complex<f64>> -> !torch.vtensor<[1,1,8,32,2],f32>
%730 = torch_c.to_builtin_tensor %729 : !torch.vtensor<[1,1,8,32,2],f32> -> tensor<1x1x8x32x2xf32>
%extracted_slice_248 = tensor.extract_slice %730[0, 0, 0, 0, 1] [1, 1, 8, 32, 1] [1, 1, 1, 1, 1] : tensor<1x1x8x32x2xf32> to tensor<1x1x8x32x1xf32>
%collapsed_249 = tensor.collapse_shape %extracted_slice_247 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%731 = torch_c.from_builtin_tensor %collapsed_249 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%collapsed_250 = tensor.collapse_shape %extracted_slice_248 [[0, 1, 2, 3], [4]] : tensor<1x1x8x32x1xf32> into tensor<256x1xf32>
%732 = torch_c.from_builtin_tensor %collapsed_250 : tensor<256x1xf32> -> !torch.vtensor<[256,1],f32>
%733 = torch_c.to_builtin_tensor %731 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%734 = torch_c.to_builtin_tensor %732 : !torch.vtensor<[256,1],f32> -> tensor<256x1xf32>
%inserted_slice_251 = tensor.insert_slice %733 into %189[0, 0] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%inserted_slice_252 = tensor.insert_slice %734 into %inserted_slice_251[0, 1] [256, 1] [1, 1] : tensor<256x1xf32> into tensor<256x2xf32>
%expanded_253 = tensor.expand_shape %inserted_slice_252 [[0, 1, 2, 3], [4]] : tensor<256x2xf32> into tensor<1x1x8x32x2xf32>
%collapsed_254 = tensor.collapse_shape %expanded_253 [[0], [1], [2], [3, 4]] : tensor<1x1x8x32x2xf32> into tensor<1x1x8x64xf32>
%735 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_254 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%736 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_236 : tensor<1x1x8x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%737 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed_246 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%738 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%737 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_255 = tensor.collapse_shape %738 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%739 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%735 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%740 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%739 : tensor<1x8x1x64xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%741 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%740 : tensor<1x8x64x1xf32>) outs(%207 : tensor<1x8x64x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x64x1xf32>
%collapsed_256 = tensor.collapse_shape %741 [[0, 1], [2], [3]] : tensor<1x8x64x1xf32> into tensor<8x64x1xf32>
%742 = linalg.batch_matmul ins(%collapsed_255, %collapsed_256 : tensor<8x1x64xf32>, tensor<8x64x1xf32>) outs(%211 : tensor<8x1x1xf32>) -> tensor<8x1x1xf32>
%expanded_257 = tensor.expand_shape %742 [[0, 1], [2], [3]] : tensor<8x1x1xf32> into tensor<1x8x1x1xf32>
%743 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_257 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_5 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%744:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%743 : tensor<1x8x1x1xf32>) outs(%217, %216 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>) {
^bb0(%in: f32, %out: f32, %out_269: i64):
%788 = linalg.index 3 : index
%789 = arith.index_cast %788 : index to i64
%790 = arith.maxf %in, %out : f32
%791 = arith.cmpf ogt, %in, %out : f32
%792 = arith.select %791, %789, %out_269 : i64
linalg.yield %790, %792 : f32, i64
} -> (tensor<1x8x1x1xf32>, tensor<1x8x1x1xi64>)
%745 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%743, %744#0 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.subf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%746 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%745 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.exp %in : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%747 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%746 : tensor<1x8x1x1xf32>) outs(%221 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%748 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%746, %747 : tensor<1x8x1x1xf32>, tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.divf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x8x1x1xf32>
%749 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%748 : tensor<1x8x1x1xf32>) outs(%213 : tensor<1x8x1x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x1xf32>
%collapsed_258 = tensor.collapse_shape %749 [[0, 1], [2], [3]] : tensor<1x8x1x1xf32> into tensor<8x1x1xf32>
%750 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%736 : tensor<1x1x8x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%751 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%750 : tensor<1x8x1x64xf32>) outs(%203 : tensor<1x8x1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x1x64xf32>
%collapsed_259 = tensor.collapse_shape %751 [[0, 1], [2], [3]] : tensor<1x8x1x64xf32> into tensor<8x1x64xf32>
%752 = linalg.batch_matmul ins(%collapsed_258, %collapsed_259 : tensor<8x1x1xf32>, tensor<8x1x64xf32>) outs(%228 : tensor<8x1x64xf32>) -> tensor<8x1x64xf32>
%expanded_260 = tensor.expand_shape %752 [[0, 1], [2], [3]] : tensor<8x1x64xf32> into tensor<1x8x1x64xf32>
%753 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_260 : tensor<1x8x1x64xf32>) outs(%200 : tensor<1x1x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x64xf32>
%754 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%147 : tensor<512x512xf32>) outs(%165 : tensor<512x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x512xf32>
%collapsed_261 = tensor.collapse_shape %753 [[0], [1, 2, 3]] : tensor<1x1x8x64xf32> into tensor<1x512xf32>
%755 = linalg.matmul ins(%collapsed_261, %754 : tensor<1x512xf32>, tensor<512x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_262 = tensor.expand_shape %755 [[0], [1, 2]] : tensor<1x512xf32> into tensor<1x1x512xf32>
%756 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%699, %expanded_262 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
%757 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%756 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Normalization chain over the 512-wide last axis; the result is %763.
// %756, %757, %148, %154, %157, %158, %cst_2 and %cst_4 are all defined earlier in the function.
// NOTE(review): sum -> divide -> add constant -> rsqrt -> scale has the shape of an RMS-norm,
// assuming %757 is the element-wise square of %756 and %158 is a zero-filled accumulator — confirm.
// Step 1: add-reduce %757 along d2 (the only "reduction" iterator) into the 1x1x1 tensor %158.
%758 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%757 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Step 2: divide the sum by %cst_4 (presumably the element count, 512 — TODO confirm).
%759 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%758 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Step 3: add the f64 constant %cst_2, truncated to f32 (presumably an epsilon — TODO confirm).
%760 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%759 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
// Step 4: reciprocal square root of the single element.
%761 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%760 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Step 5: multiply every element of %756 by the scalar in %761 (second input map is constant (0, 0, 0)).
%762 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%756, %761 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Step 6: element-wise multiply by the 512-element vector %148, indexed by d2 (presumably the norm weight — TODO confirm).
%763 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%762, %148 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Two 512 -> 1536 projections of %763 combined by an element-wise product; the result is %770.
// %149, %150, %241, %244, %246 and %cst_1 are defined earlier in the function.
// NOTE(review): this reads like a gated feed-forward (activation(x*W1^T) * (x*W2^T)) — confirm against the model.
// Transpose %149 from 1536x512 to 512x1536 (output map swaps d0 and d1).
%764 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%149 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
// Drop one unit dimension so the activation can be used as a 1x512 matmul operand.
%collapsed_263 = tensor.collapse_shape %763 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
// (1x512) x (512x1536) matmul accumulating into %244 (presumably zero-filled — TODO confirm).
%765 = linalg.matmul ins(%collapsed_263, %764 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
// Restore the rank-3 shape 1x1x1536.
%expanded_264 = tensor.expand_shape %765 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
// Element-wise %cst_1 / (exp(-x) + %cst_1); this is the logistic sigmoid if %cst_1 is 1.0 (TODO confirm).
%766 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_264 : tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.negf %in : f32
%789 = math.exp %788 : f32
%790 = arith.addf %789, %cst_1 : f32
%791 = arith.divf %cst_1, %790 : f32
linalg.yield %791 : f32
} -> tensor<1x1x1536xf32>
// Multiply the value from %766 by its own pre-activation input, i.e. x * f(x).
%767 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%766, %expanded_264 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
// Transpose %150 from 1536x512 to 512x1536 for the second projection.
%768 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%150 : tensor<1536x512xf32>) outs(%241 : tensor<512x1536xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1536xf32>
// Second projection of the same input (%collapsed_263), again accumulating into %244.
%769 = linalg.matmul ins(%collapsed_263, %768 : tensor<1x512xf32>, tensor<512x1536xf32>) outs(%244 : tensor<1x1536xf32>) -> tensor<1x1536xf32>
%expanded_265 = tensor.expand_shape %769 [[0], [1, 2]] : tensor<1x1536xf32> into tensor<1x1x1536xf32>
// Element-wise product of the activated first projection (%767) and the second projection.
%770 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%767, %expanded_265 : tensor<1x1x1536xf32>, tensor<1x1x1536xf32>) outs(%246 : tensor<1x1x1536xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1536xf32>
// Project %770 from 1536 back to 512 features and add the result to %756; the result is %773.
// %151, %252, %168, %154 and %756 are defined earlier in the function.
// Transpose %151 from 512x1536 to 1536x512 (output map swaps d0 and d1).
%771 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%151 : tensor<512x1536xf32>) outs(%252 : tensor<1536x512xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1536x512xf32>
// Drop one unit dimension so %770 can be used as a 1x1536 matmul operand.
%collapsed_266 = tensor.collapse_shape %770 [[0], [1, 2]] : tensor<1x1x1536xf32> into tensor<1x1536xf32>
// (1x1536) x (1536x512) matmul accumulating into %168 (presumably zero-filled — TODO confirm).
%772 = linalg.matmul ins(%collapsed_266, %771 : tensor<1x1536xf32>, tensor<1536x512xf32>) outs(%168 : tensor<1x512xf32>) -> tensor<1x512xf32>
%expanded_267 = tensor.expand_shape %772 [[0], [1, 2]] : tensor<1x1x512xf32> into tensor<1x1x512xf32>
// Element-wise add of %756 and the projection output.
// NOTE(review): %756 is the same tensor that is scaled in the normalization feeding this projection,
// so this looks like a residual connection — confirm.
%773 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%756, %expanded_267 : tensor<1x1x512xf32>, tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.addf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Normalization chain applied to %773 over the 512-wide last axis; the result is %780.
// %152, %154, %157, %158, %cst_2, %cst_3 and %cst_4 are defined earlier in the function.
// NOTE(review): power -> sum -> divide -> add constant -> rsqrt -> scale has the shape of an RMS-norm,
// assuming %cst_3 is 2.0 and %158 is a zero-filled accumulator — confirm.
// Step 1: raise each element of %773 to the power %cst_3.
%774 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%773 : tensor<1x1x512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.powf %in, %cst_3 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Step 2: add-reduce along d2 (the only "reduction" iterator) into the 1x1x1 tensor %158.
%775 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%774 : tensor<1x1x512xf32>) outs(%158 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.addf %in, %out : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Step 3: divide the sum by %cst_4 (presumably the element count, 512 — TODO confirm).
%776 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%775 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.divf %in, %cst_4 : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Step 4: add the f64 constant %cst_2, truncated to f32 (presumably an epsilon — TODO confirm).
%777 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%776 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = arith.truncf %cst_2 : f64 to f32
%789 = arith.addf %in, %788 : f32
linalg.yield %789 : f32
} -> tensor<1x1x1xf32>
// Step 5: reciprocal square root of the single element.
%778 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%777 : tensor<1x1x1xf32>) outs(%157 : tensor<1x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%788 = math.rsqrt %in : f32
linalg.yield %788 : f32
} -> tensor<1x1x1xf32>
// Step 6: multiply every element of %773 by the scalar in %778 (second input map is constant (0, 0, 0)).
%779 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%773, %778 : tensor<1x1x512xf32>, tensor<1x1x1xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Step 7: element-wise multiply by the 512-element vector %152, indexed by d2 (presumably the norm weight — TODO confirm).
%780 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%779, %152 : tensor<1x1x512xf32>, tensor<512xf32>) outs(%154 : tensor<1x1x512xf32>) {
^bb0(%in: f32, %in_269: f32, %out: f32):
%788 = arith.mulf %in, %in_269 : f32
linalg.yield %788 : f32
} -> tensor<1x1x512xf32>
// Final 512 -> 1 projection of %780 and function return.
// %153, %cst and %arg76 are defined earlier in the function.
// Merge the two leading unit dimensions so %780 can be used as a 1x512 matmul operand.
%collapsed_268 = tensor.collapse_shape %780 [[0, 1], [2]] : tensor<1x1x512xf32> into tensor<1x512xf32>
// Transpose %153 from 1x512 to 512x1 into a freshly created (uninitialized) destination.
%781 = tensor.empty() : tensor<512x1xf32>
%782 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%153 : tensor<1x512xf32>) outs(%781 : tensor<512x1xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<512x1xf32>
// Build the 1x1 accumulator by filling it with %cst (presumably 0.0 — TODO confirm).
%783 = tensor.empty() : tensor<1x1xf32>
%784 = linalg.fill ins(%cst : f32) outs(%783 : tensor<1x1xf32>) -> tensor<1x1xf32>
// (1x512) x (512x1) matmul accumulating into %784, giving a single f32 value.
%785 = linalg.matmul ins(%collapsed_268, %782 : tensor<1x512xf32>, tensor<512x1xf32>) outs(%784 : tensor<1x1xf32>) -> tensor<1x1xf32>
// NOTE(review): the next two ops convert %785 to !torch.vtensor and straight back to a builtin tensor.
// This looks like a type-conversion materialization pair that has not been folded yet — confirm whether
// a later cleanup pass is expected to remove it.
%786 = torch_c.from_builtin_tensor %785 : tensor<1x1xf32> -> !torch.vtensor<[1,1],f32>
%787 = torch_c.to_builtin_tensor %786 : !torch.vtensor<[1,1],f32> -> tensor<1x1xf32>
// Return the 1x1 result together with %arg76, which is passed through without any visible op applied here.
return %787, %arg76 : tensor<1x1xf32>, tensor<2048x32x2xf32>
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment