deberta
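Script that traces a DeBERTa sequence-classification model with torch.fx, lowers it to TOSA through torch-mlir, and runs it with SHARK on CPU. The IR dump at the bottom appears to come from the hf-internal-testing/tiny-random-deberta variant, judging by the 32-wide hidden dimensions and the 1000x32 embedding table.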
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import torch_mlir
def prepare_sentence_tokens(hf_model: str, sentence: str):
    """Tokenize a sentence with the model's tokenizer (unused below; kept for reference)."""
    tokenizer = AutoTokenizer.from_pretrained(hf_model)
    return torch.tensor([tokenizer.encode(sentence)])
class HfMaskedLM(torch.nn.Module):
    # Note: despite the class name, this wraps a sequence-classification head.
    def __init__(self, model_name: str):
        super().__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,  # The pretrained model name.
            num_labels=2,  # The number of output labels: 2 for binary classification.
            output_attentions=False,  # Whether the model returns attention weights.
            output_hidden_states=False,  # Whether the model returns all hidden states.
            torchscript=True,
        )
        self.model.eval()

    def forward(self, tokens):
        # Return only the logits from the TorchScript output tuple.
        return self.model.forward(tokens)[0]
# hf_model_name = "hf-internal-testing/tiny-random-deberta"
hf_model_name = "microsoft/deberta-v3-base"

# Random token ids in [0, 2); a fixed-shape stand-in for real tokenized input.
test_input = torch.randint(2, (1, 128))

model = HfMaskedLM(hf_model_name)

print("model(test_input): ")
print(model(test_input))
# Trace the model to an FX graph, decomposing ops that lack direct lowerings.
fx_g = make_fx(
    model,
    decomposition_table=get_decompositions(
        [
            torch.ops.aten.embedding_dense_backward,
            torch.ops.aten.native_layer_norm_backward,
            torch.ops.aten.slice_backward,
            torch.ops.aten.select_backward,
            torch.ops.aten.norm.ScalarOpt_dim,
            torch.ops.aten.native_group_norm,
            torch.ops.aten.upsample_bilinear2d.vec,
            torch.ops.aten.split.Tensor,
            torch.ops.aten.split_with_sizes,
        ]
    ),
)(test_input)

# print(fx_g.graph)

# Regenerate plain Python code for the graph, then recompile it.
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Args:
        gm (fx.GraphModule): The input FX graph module to be modified.
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()


# TorchScript cannot script OpOverload targets, so strip them first.
strip_overloads(fx_g)

ts_g = torch.jit.script(fx_g)
# Alternative: lower to linalg-on-tensors instead of TOSA.
# module = torch_mlir.compile(
#     ts_g,
#     (test_input),
#     torch_mlir.OutputType.LINALG_ON_TENSORS,
#     use_tracing=True,
#     verbose=False,
# )
module = torch_mlir.compile(
    ts_g,
    (test_input),
    torch_mlir.OutputType.TOSA,
    use_tracing=True,
    verbose=False,
)
module.dump()
from shark.shark_inference import SharkInference

mlir_model = module
func_name = "forward"

shark_module = SharkInference(
    mlir_model, func_name, device="cpu", mlir_dialect="tosa"
)
shark_module.compile()


def shark_result(x):
    # Run the compiled SHARK module on a torch tensor, returning a torch tensor.
    x_ny = x.detach().numpy()
    inputs = (x_ny,)
    result = shark_module.forward(inputs)
    return torch.from_numpy(result)


observed_out = shark_result(test_input)
print(observed_out)
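For a quick sanity check, the SHARK result can be compared against the eager PyTorch output. This is a minimal sketch assuming the script above ran to completion; the tolerances are illustrative, not tuned:

# Minimal sketch: compare SHARK output with eager PyTorch.
# Assumes `model`, `test_input`, and `observed_out` from above;
# rtol/atol are illustrative placeholders, not tuned values.
expected_out = model(test_input)
print("max abs diff:", (expected_out - observed_out).abs().max().item())
print("allclose:", torch.allclose(expected_out, observed_out, rtol=1e-3, atol=1e-4))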
IR Dump After Last Successful Pattern Application
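The dump below interleaves the original torch-dialect ops with their TOSA lowerings, bridged by builtin.unrealized_conversion_cast while the conversion is still in flight; large weight constants are elided as dense_resource<__elided__>, and the dump is truncated.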
func.func @forward(%arg0: !torch.vtensor<[1,128],si64>) -> !torch.vtensor<[1,2],f32> {
%0 = builtin.unrealized_conversion_cast %arg0 : !torch.vtensor<[1,128],si64> to tensor<1x128xi64>
%int1 = torch.constant.int 1
%1 = builtin.unrealized_conversion_cast %int1 : !torch.int to i64
%int32 = torch.constant.int 32
%2 = builtin.unrealized_conversion_cast %int32 : !torch.int to i64
%int128 = torch.constant.int 128
%3 = builtin.unrealized_conversion_cast %int128 : !torch.int to i64
%float1.000000e00 = torch.constant.float 1.000000e+00
%4 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
%5 = torch.vtensor.literal(dense<0.000000e+00> : tensor<2xf32>) : !torch.vtensor<[2],f32>
%6 = "tosa.const"() {value = dense_resource<__elided__> : tensor<2x32xf32>} : () -> tensor<2x32xf32>
%7 = torch.vtensor.literal(dense_resource<__elided__> : tensor<2x32xf32>) : !torch.vtensor<[2,32],f32>
%8 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x32xf32>} : () -> tensor<32x32xf32>
%9 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%10 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x37xf32>} : () -> tensor<32x37xf32>
%11 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%12 = "tosa.const"() {value = dense_resource<__elided__> : tensor<37x32xf32>} : () -> tensor<37x32xf32>
%13 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%14 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x32xf32>} : () -> tensor<32x32xf32>
%15 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%16 = "tosa.const"() {value = dense_resource<__elided__> : tensor<96x32xf32>} : () -> tensor<96x32xf32>
%17 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%18 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x37xf32>} : () -> tensor<32x37xf32>
%19 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%20 = "tosa.const"() {value = dense_resource<__elided__> : tensor<37x32xf32>} : () -> tensor<37x32xf32>
%21 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%22 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x32xf32>} : () -> tensor<32x32xf32>
%23 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%24 = "tosa.const"() {value = dense_resource<__elided__> : tensor<96x32xf32>} : () -> tensor<96x32xf32>
%25 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%26 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x37xf32>} : () -> tensor<32x37xf32>
%27 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%28 = "tosa.const"() {value = dense_resource<__elided__> : tensor<37x32xf32>} : () -> tensor<37x32xf32>
%29 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%30 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x32xf32>} : () -> tensor<32x32xf32>
%31 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%32 = "tosa.const"() {value = dense_resource<__elided__> : tensor<96x32xf32>} : () -> tensor<96x32xf32>
%33 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%34 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x37xf32>} : () -> tensor<32x37xf32>
%35 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%36 = "tosa.const"() {value = dense_resource<__elided__> : tensor<37x32xf32>} : () -> tensor<37x32xf32>
%37 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%38 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x32xf32>} : () -> tensor<32x32xf32>
%39 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%40 = "tosa.const"() {value = dense_resource<__elided__> : tensor<96x32xf32>} : () -> tensor<96x32xf32>
%41 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%42 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x37xf32>} : () -> tensor<32x37xf32>
%43 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%44 = "tosa.const"() {value = dense<0.000000e+00> : tensor<37xf32>} : () -> tensor<37xf32>
%45 = torch.vtensor.literal(dense<0.000000e+00> : tensor<37xf32>) : !torch.vtensor<[37],f32>
%46 = "tosa.const"() {value = dense_resource<__elided__> : tensor<37x32xf32>} : () -> tensor<37x32xf32>
%47 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%48 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x32xf32>} : () -> tensor<32x32xf32>
%49 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%50 = "tosa.const"() {value = dense<-3.40282347E+38> : tensor<f32>} : () -> tensor<f32>
%51 = torch.vtensor.literal(dense<-3.40282347E+38> : tensor<f32>) : !torch.vtensor<[],f32>
%52 = "tosa.const"() {value = dense_resource<__elided__> : tensor<96x32xf32>} : () -> tensor<96x32xf32>
%53 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%54 = "tosa.const"() {value = dense<0.000000e+00> : tensor<32xf32>} : () -> tensor<32xf32>
%55 = torch.vtensor.literal(dense<0.000000e+00> : tensor<32xf32>) : !torch.vtensor<[32],f32>
%56 = "tosa.const"() {value = dense<1.000000e+00> : tensor<32xf32>} : () -> tensor<32xf32>
%57 = torch.vtensor.literal(dense<1.000000e+00> : tensor<32xf32>) : !torch.vtensor<[32],f32>
%58 = "tosa.const"() {value = dense_resource<__elided__> : tensor<16x32xf32>} : () -> tensor<16x32xf32>
%59 = torch.vtensor.literal(dense_resource<__elided__> : tensor<16x32xf32>) : !torch.vtensor<[16,32],f32>
%60 = "tosa.const"() {value = dense_resource<__elided__> : tensor<512x32xf32>} : () -> tensor<512x32xf32>
%61 = torch.vtensor.literal(dense_resource<__elided__> : tensor<512x32xf32>) : !torch.vtensor<[512,32],f32>
%62 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1000x32xf32>} : () -> tensor<1000x32xf32>
%63 = torch.vtensor.literal(dense_resource<__elided__> : tensor<1000x32xf32>) : !torch.vtensor<[1000,32],f32>
%64 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x512xsi64>} : () -> tensor<1x512xi64>
%65 = torch.vtensor.literal(dense_resource<__elided__> : tensor<1x512xsi64>) : !torch.vtensor<[1,512],si64>
%false = torch.constant.bool false
%66 = builtin.unrealized_conversion_cast %false : !torch.bool to i1
%none = torch.constant.none
%int-1 = torch.constant.int -1
%67 = builtin.unrealized_conversion_cast %int-1 : !torch.int to i64
%true = torch.constant.bool true
%68 = builtin.unrealized_conversion_cast %true : !torch.bool to i1
%int-2 = torch.constant.int -2
%69 = builtin.unrealized_conversion_cast %int-2 : !torch.int to i64
%int11 = torch.constant.int 11
%str = torch.constant.str "none"
%int0 = torch.constant.int 0
%70 = builtin.unrealized_conversion_cast %int0 : !torch.int to i64
%int9223372036854775807 = torch.constant.int 9223372036854775807
%71 = builtin.unrealized_conversion_cast %int9223372036854775807 : !torch.int to i64
%int2 = torch.constant.int 2
%72 = builtin.unrealized_conversion_cast %int2 : !torch.int to i64
%float9.999990e-08 = torch.constant.float 9.9999999999999995E-8
%73 = builtin.unrealized_conversion_cast %float9.999990e-08 : !torch.float to f64
%int96 = torch.constant.int 96
%int4 = torch.constant.int 4
%74 = builtin.unrealized_conversion_cast %int4 : !torch.int to i64
%int3 = torch.constant.int 3
%int8 = torch.constant.int 8
%int16 = torch.constant.int 16
%int24 = torch.constant.int 24
%float4.000000e00 = torch.constant.float 4.000000e+00
%int37 = torch.constant.int 37
%cpu = torch.constant.device "cpu"
%75 = torch.prim.ListConstruct %int1, %int128 : (!torch.int, !torch.int) -> !torch.list<int>
%76 = "tosa.const"() {value = dense<1> : tensor<1x128xi32>} : () -> tensor<1x128xi32>
%77 = "tosa.cast"(%76) : (tensor<1x128xi32>) -> tensor<1x128xf32>
%78 = torch.aten.ones %75, %none, %none, %cpu, %false : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1,128],f32>
%79 = "tosa.const"() {value = dense<0> : tensor<1x128xi32>} : () -> tensor<1x128xi32>
%80 = "tosa.cast"(%79) : (tensor<1x128xi32>) -> tensor<1x128xi64>
%81 = torch.aten.zeros %75, %int4, %none, %cpu, %false : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1,128],si64>
%82 = "tosa.slice"(%64) {size = [9223372036854775807, 512], start = [0, 0]} : (tensor<1x512xi64>) -> tensor<1x512xi64>
%83 = torch.aten.slice.Tensor %65, %int0, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,512],si64>
%84 = "tosa.slice"(%82) {size = [1, 128], start = [0, 0]} : (tensor<1x512xi64>) -> tensor<1x128xi64>
%85 = torch.aten.slice.Tensor %83, %int1, %int0, %int128, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128],si64>
%86 = "tosa.reshape"(%62) {new_shape = [1, 1000, 32]} : (tensor<1000x32xf32>) -> tensor<1x1000x32xf32>
%87 = "tosa.reshape"(%0) {new_shape = [1, 128]} : (tensor<1x128xi64>) -> tensor<1x128xi64>
%88 = "tosa.cast"(%87) : (tensor<1x128xi64>) -> tensor<1x128xi32>
%89 = "tosa.gather"(%86, %88) : (tensor<1x1000x32xf32>, tensor<1x128xi32>) -> tensor<1x128x32xf32>
%90 = "tosa.reshape"(%89) {new_shape = [1, 128, 32]} : (tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
%91 = torch.aten.embedding %63, %arg0, %int0, %false, %false : !torch.vtensor<[1000,32],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,32],f32>
%92 = "tosa.reshape"(%60) {new_shape = [1, 512, 32]} : (tensor<512x32xf32>) -> tensor<1x512x32xf32>
%93 = "tosa.reshape"(%84) {new_shape = [1, 128]} : (tensor<1x128xi64>) -> tensor<1x128xi64>
%94 = "tosa.cast"(%93) : (tensor<1x128xi64>) -> tensor<1x128xi32>
%95 = "tosa.gather"(%92, %94) : (tensor<1x512x32xf32>, tensor<1x128xi32>) -> tensor<1x128x32xf32>
%96 = "tosa.reshape"(%95) {new_shape = [1, 128, 32]} : (tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
%97 = torch.aten.embedding %61, %85, %int-1, %false, %false : !torch.vtensor<[512,32],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,32],f32>
%98 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%99 = "tosa.mul"(%96, %98) {shift = 0 : i32} : (tensor<1x128x32xf32>, tensor<f32>) -> tensor<1x128x32xf32>
%100 = "tosa.add"(%90, %99) : (tensor<1x128x32xf32>, tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
%101 = torch.aten.add.Tensor %91, %97, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%102 = "tosa.reshape"(%58) {new_shape = [1, 16, 32]} : (tensor<16x32xf32>) -> tensor<1x16x32xf32>
%103 = "tosa.reshape"(%80) {new_shape = [1, 128]} : (tensor<1x128xi64>) -> tensor<1x128xi64>
%104 = "tosa.cast"(%103) : (tensor<1x128xi64>) -> tensor<1x128xi32>
%105 = "tosa.gather"(%102, %104) : (tensor<1x16x32xf32>, tensor<1x128xi32>) -> tensor<1x128x32xf32>
%106 = "tosa.reshape"(%105) {new_shape = [1, 128, 32]} : (tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
%107 = torch.aten.embedding %59, %81, %int-1, %false, %false : !torch.vtensor<[16,32],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,32],f32>
%108 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%109 = "tosa.mul"(%106, %108) {shift = 0 : i32} : (tensor<1x128x32xf32>, tensor<f32>) -> tensor<1x128x32xf32>
%110 = "tosa.add"(%100, %109) : (tensor<1x128x32xf32>, tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
%111 = torch.aten.add.Tensor %101, %107, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%112 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%113 = "tosa.reduce_sum"(%110) {axis = 2 : i64} : (tensor<1x128x32xf32>) -> tensor<1x128x1xf32>
%114 = torch.aten.sum.dim_IntList %111, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%115 = "tosa.const"() {value = dense<3.200000e+01> : tensor<f32>} : () -> tensor<f32>
%116 = "tosa.reciprocal"(%115) : (tensor<f32>) -> tensor<f32>
%117 = "tosa.mul"(%113, %116) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<f32>) -> tensor<1x128x1xf32>
%118 = torch.aten.div.Scalar %114, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%119 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%120 = "tosa.mul"(%117, %119) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<f32>) -> tensor<1x128x1xf32>
%121 = "tosa.sub"(%110, %120) : (tensor<1x128x32xf32>, tensor<1x128x1xf32>) -> tensor<1x128x32xf32>
%122 = torch.aten.sub.Tensor %111, %118, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%123 = "tosa.const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
%124 = "tosa.pow"(%121, %123) : (tensor<1x128x32xf32>, tensor<f32>) -> tensor<1x128x32xf32>
%125 = torch.aten.pow.Tensor_Scalar %122, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%126 = "tosa.reduce_sum"(%124) {axis = 2 : i64} : (tensor<1x128x32xf32>) -> tensor<1x128x1xf32>
%127 = torch.aten.sum.dim_IntList %125, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%128 = "tosa.const"() {value = dense<3.200000e+01> : tensor<f32>} : () -> tensor<f32>
%129 = "tosa.reciprocal"(%128) : (tensor<f32>) -> tensor<f32>
%130 = "tosa.mul"(%126, %129) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<f32>) -> tensor<1x128x1xf32>
%131 = torch.aten.div.Scalar %127, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%132 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%133 = "tosa.mul"(%117, %132) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<f32>) -> tensor<1x128x1xf32>
%134 = "tosa.sub"(%110, %133) : (tensor<1x128x32xf32>, tensor<1x128x1xf32>) -> tensor<1x128x32xf32>
%135 = torch.aten.sub.Tensor %111, %118, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%136 = "tosa.const"() {value = dense<1.000000e-07> : tensor<f32>} : () -> tensor<f32>
%137 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%138 = "tosa.mul"(%136, %137) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
%139 = "tosa.add"(%130, %138) : (tensor<1x128x1xf32>, tensor<f32>) -> tensor<1x128x1xf32>
%140 = torch.aten.add.Scalar %131, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%141 = torch.aten.sqrt %140 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%142 = builtin.unrealized_conversion_cast %141 : !torch.vtensor<[1,128,1],f32> to tensor<1x128x1xf32>
%143 = "tosa.reciprocal"(%142) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
%144 = "tosa.mul"(%134, %143) {shift = 0 : i32} : (tensor<1x128x32xf32>, tensor<1x128x1xf32>) -> tensor<1x128x32xf32>
%145 = torch.aten.div.Tensor %135, %141 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%146 = "tosa.mul"(%56, %144) {shift = 0 : i32} : (tensor<32xf32>, tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
%147 = torch.aten.mul.Tensor %57, %145 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%148 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%149 = "tosa.mul"(%54, %148) {shift = 0 : i32} : (tensor<32xf32>, tensor<f32>) -> tensor<32xf32>
%150 = "tosa.add"(%146, %149) : (tensor<1x128x32xf32>, tensor<32xf32>) -> tensor<1x128x32xf32>
%151 = torch.aten.add.Tensor %147, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%152 = "tosa.reshape"(%77) {new_shape = [1, 128]} : (tensor<1x128xf32>) -> tensor<1x128x1xf32>
%153 = torch.aten.unsqueeze %78, %int2 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%154 = "tosa.mul"(%150, %152) {shift = 0 : i32} : (tensor<1x128x32xf32>, tensor<1x128x1xf32>) -> tensor<1x128x32xf32>
%155 = torch.aten.mul.Tensor %151, %153 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%156 = "tosa.reshape"(%77) {new_shape = [1, 1, 128]} : (tensor<1x128xf32>) -> tensor<1x1x128xf32>
%157 = torch.aten.unsqueeze %78, %int1 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128],f32>
%158 = "tosa.reshape"(%156) {new_shape = [1, 1, 1, 128]} : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32>
%159 = torch.aten.unsqueeze %157, %int2 : !torch.vtensor<[1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,1,128],f32>
%160 = "tosa.reshape"(%158) {new_shape = [1, 1, 128]} : (tensor<1x1x1x128xf32>) -> tensor<1x1x128xf32>
%cast = tensor.cast %160 : tensor<1x1x128xf32> to tensor<1x1x128xf32>
%161 = torch.aten.squeeze.dim %159, %int-2 : !torch.vtensor<[1,1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128],f32>
%162 = "tosa.reshape"(%cast) {new_shape = [1, 1, 1, 128]} : (tensor<1x1x128xf32>) -> tensor<1x1x128x1xf32>
%163 = torch.aten.unsqueeze %161, %int-1 : !torch.vtensor<[1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128,1],f32>
%164 = "tosa.mul"(%158, %162) {shift = 0 : i32} : (tensor<1x1x1x128xf32>, tensor<1x1x128x1xf32>) -> tensor<1x1x128x128xf32>
%165 = torch.aten.mul.Tensor %159, %163 : !torch.vtensor<[1,1,1,128],f32>, !torch.vtensor<[1,1,128,1],f32> -> !torch.vtensor<[1,1,128,128],f32>
%166 = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
%167 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%168 = "tosa.cast"(%166) : (tensor<i64>) -> tensor<i8>
%169 = torch.aten.to.dtype %167, %int1, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],si8>
%170 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%171 = torch.aten.broadcast_to %169, %170 : !torch.vtensor<[],si8>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],si8>
%172 = torch.aten.copy %169, %165, %false : !torch.vtensor<[],si8>, !torch.vtensor<[1,1,128,128],f32>, !torch.bool -> !torch.vtensor<[1,1,128,128],si8>
%173 = torch.aten.transpose.int %53, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
%174 = torch.prim.ListConstruct %int128, %int32 : (!torch.int, !torch.int) -> !torch.list<int>
%175 = torch.aten.view %155, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%176 = torch.aten.mm %175, %173 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
%177 = torch.prim.ListConstruct %int1, %int128, %int96 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%178 = torch.aten.view %176, %177 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
%179 = torch.prim.ListConstruct %int1, %int128, %int4, %int-1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%180 = torch.aten.view %178, %179 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
%181 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%182 = torch.aten.permute %180, %181 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
%183 = torch.aten.slice.Tensor %182, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%184 = torch.aten.slice.Tensor %182, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%185 = torch.aten.slice.Tensor %182, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%186 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%187 = torch.aten.unsqueeze %186, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%188 = torch.aten.slice.Tensor %187, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%189 = torch.prim.ListConstruct %int1, %int1, %int4, %int-1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%190 = torch.aten.view %188, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%191 = torch.aten.permute %190, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%192 = torch.aten.add.Tensor %183, %191, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%193 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%194 = torch.aten.unsqueeze %193, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%195 = torch.aten.slice.Tensor %194, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%196 = torch.aten.view %195, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%197 = torch.aten.permute %196, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%198 = torch.aten.add.Tensor %185, %197, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%199 = torch.aten.div.Scalar %192, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
%200 = torch.aten.transpose.int %184, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
%201 = torch.prim.ListConstruct %int1, %int4, %int128, %int8 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%202 = torch.aten.broadcast_to %199, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%203 = torch.prim.ListConstruct %int4, %int128, %int8 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%204 = torch.aten.view %202, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%205 = torch.prim.ListConstruct %int1, %int4, %int8, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%206 = torch.aten.broadcast_to %200, %205 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
%207 = torch.prim.ListConstruct %int4, %int8, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%208 = torch.aten.view %206, %207 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
%209 = torch.aten.bmm %204, %208 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
%210 = torch.prim.ListConstruct %int1, %int4, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%211 = torch.aten.view %209, %210 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%212 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%213 = torch.aten.to.dtype %212, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
%214 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%215 = torch.aten.broadcast_to %213, %214 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
%216 = torch.aten.copy %215, %172, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
%217 = torch.aten.bitwise_not %216 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
%218 = torch.aten.clone %51, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
%219 = torch.aten.masked_fill.Tensor %211, %217, %218 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
%values, %indices = torch.aten.max.dim %219, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
%220 = torch.aten.sub.Tensor %219, %values, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
%221 = torch.aten.exp %220 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
%222 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%223 = torch.aten.sum.dim_IntList %221, %222, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
%224 = torch.aten.div.Tensor %221, %223 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
%225 = torch.aten.masked_fill.Scalar %224, %217, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
%226 = torch.aten.broadcast_to %225, %210 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%227 = torch.prim.ListConstruct %int4, %int128, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%228 = torch.aten.view %226, %227 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
%229 = torch.aten.broadcast_to %198, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%230 = torch.aten.view %229, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%231 = torch.aten.bmm %228, %230 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
%232 = torch.aten.view %231, %201 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%233 = torch.aten.permute %232, %181 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
%234 = torch.aten.clone %233, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
%235 = torch.prim.ListConstruct %int1, %int128, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%236 = torch.aten.view %234, %235 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%237 = torch.aten.transpose.int %49, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
%238 = torch.aten.view %236, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%239 = torch.aten.mm %238, %237 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
%240 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%241 = torch.aten.add.Tensor %240, %239, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%242 = torch.prim.ListConstruct %int1, %int128, %int32 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%243 = torch.aten.view %241, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%244 = torch.aten.add.Tensor %243, %155, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%245 = torch.aten.sum.dim_IntList %244, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%246 = torch.aten.div.Scalar %245, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%247 = torch.aten.sub.Tensor %244, %246, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%248 = torch.aten.pow.Tensor_Scalar %247, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%249 = torch.aten.sum.dim_IntList %248, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%250 = torch.aten.div.Scalar %249, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%251 = torch.aten.sub.Tensor %244, %246, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%252 = torch.aten.add.Scalar %250, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%253 = torch.aten.sqrt %252 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%254 = torch.aten.div.Tensor %251, %253 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%255 = torch.aten.mul.Tensor %57, %254 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%256 = torch.aten.add.Tensor %255, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%257 = torch.aten.transpose.int %47, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
%258 = torch.aten.view %256, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%259 = torch.aten.mm %258, %257 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
%260 = torch.aten.mul.Scalar %45, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
%261 = torch.aten.add.Tensor %260, %259, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
%262 = torch.prim.ListConstruct %int1, %int128, %int37 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%263 = torch.aten.view %261, %262 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
%264 = torch.aten.gelu %263, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
%265 = torch.aten.transpose.int %43, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
%266 = torch.prim.ListConstruct %int128, %int37 : (!torch.int, !torch.int) -> !torch.list<int>
%267 = torch.aten.view %264, %266 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
%268 = torch.aten.mm %267, %265 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
%269 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%270 = torch.aten.add.Tensor %269, %268, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%271 = torch.aten.view %270, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%272 = torch.aten.add.Tensor %271, %256, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%273 = torch.aten.sum.dim_IntList %272, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%274 = torch.aten.div.Scalar %273, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%275 = torch.aten.sub.Tensor %272, %274, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%276 = torch.aten.pow.Tensor_Scalar %275, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%277 = torch.aten.sum.dim_IntList %276, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%278 = torch.aten.div.Scalar %277, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%279 = torch.aten.sub.Tensor %272, %274, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%280 = torch.aten.add.Scalar %278, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%281 = torch.aten.sqrt %280 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%282 = torch.aten.div.Tensor %279, %281 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%283 = torch.aten.mul.Tensor %57, %282 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%284 = torch.aten.add.Tensor %283, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%285 = torch.aten.transpose.int %41, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
%286 = torch.aten.view %284, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%287 = torch.aten.mm %286, %285 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
%288 = torch.aten.view %287, %177 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
%289 = torch.aten.view %288, %179 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
%290 = torch.aten.permute %289, %181 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
%291 = torch.aten.slice.Tensor %290, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%292 = torch.aten.slice.Tensor %290, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%293 = torch.aten.slice.Tensor %290, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%294 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%295 = torch.aten.unsqueeze %294, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%296 = torch.aten.slice.Tensor %295, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%297 = torch.aten.view %296, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%298 = torch.aten.permute %297, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%299 = torch.aten.add.Tensor %291, %298, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%300 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%301 = torch.aten.unsqueeze %300, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%302 = torch.aten.slice.Tensor %301, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%303 = torch.aten.view %302, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%304 = torch.aten.permute %303, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%305 = torch.aten.add.Tensor %293, %304, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%306 = torch.aten.div.Scalar %299, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
%307 = torch.aten.transpose.int %292, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
%308 = torch.aten.broadcast_to %306, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%309 = torch.aten.view %308, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%310 = torch.aten.broadcast_to %307, %205 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
%311 = torch.aten.view %310, %207 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
%312 = torch.aten.bmm %309, %311 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
%313 = torch.aten.view %312, %210 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%314 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%315 = torch.aten.to.dtype %314, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
%316 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%317 = torch.aten.broadcast_to %315, %316 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
%318 = torch.aten.copy %317, %172, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
%319 = torch.aten.bitwise_not %318 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
%320 = torch.aten.clone %51, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
%321 = torch.aten.masked_fill.Tensor %313, %319, %320 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
%values_0, %indices_1 = torch.aten.max.dim %321, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
%322 = torch.aten.sub.Tensor %321, %values_0, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
%323 = torch.aten.exp %322 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
%324 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%325 = torch.aten.sum.dim_IntList %323, %324, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
%326 = torch.aten.div.Tensor %323, %325 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
%327 = torch.aten.masked_fill.Scalar %326, %319, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
%328 = torch.aten.broadcast_to %327, %210 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%329 = torch.aten.view %328, %227 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
%330 = torch.aten.broadcast_to %305, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%331 = torch.aten.view %330, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%332 = torch.aten.bmm %329, %331 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
%333 = torch.aten.view %332, %201 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%334 = torch.aten.permute %333, %181 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
%335 = torch.aten.clone %334, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
%336 = torch.aten.view %335, %235 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%337 = torch.aten.transpose.int %39, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
%338 = torch.aten.view %336, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%339 = torch.aten.mm %338, %337 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
%340 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%341 = torch.aten.add.Tensor %340, %339, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%342 = torch.aten.view %341, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%343 = torch.aten.add.Tensor %342, %284, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%344 = torch.aten.sum.dim_IntList %343, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%345 = torch.aten.div.Scalar %344, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%346 = torch.aten.sub.Tensor %343, %345, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%347 = torch.aten.pow.Tensor_Scalar %346, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%348 = torch.aten.sum.dim_IntList %347, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%349 = torch.aten.div.Scalar %348, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%350 = torch.aten.sub.Tensor %343, %345, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%351 = torch.aten.add.Scalar %349, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%352 = torch.aten.sqrt %351 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%353 = torch.aten.div.Tensor %350, %352 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%354 = torch.aten.mul.Tensor %57, %353 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%355 = torch.aten.add.Tensor %354, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%356 = torch.aten.transpose.int %37, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
%357 = torch.aten.view %355, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%358 = torch.aten.mm %357, %356 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
%359 = torch.aten.mul.Scalar %45, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
%360 = torch.aten.add.Tensor %359, %358, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
%361 = torch.aten.view %360, %262 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
%362 = torch.aten.gelu %361, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
%363 = torch.aten.transpose.int %35, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
%364 = torch.aten.view %362, %266 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
%365 = torch.aten.mm %364, %363 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
%366 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%367 = torch.aten.add.Tensor %366, %365, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%368 = torch.aten.view %367, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%369 = torch.aten.add.Tensor %368, %355, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%370 = torch.aten.sum.dim_IntList %369, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%371 = torch.aten.div.Scalar %370, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%372 = torch.aten.sub.Tensor %369, %371, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%373 = torch.aten.pow.Tensor_Scalar %372, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%374 = torch.aten.sum.dim_IntList %373, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%375 = torch.aten.div.Scalar %374, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%376 = torch.aten.sub.Tensor %369, %371, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%377 = torch.aten.add.Scalar %375, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%378 = torch.aten.sqrt %377 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%379 = torch.aten.div.Tensor %376, %378 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%380 = torch.aten.mul.Tensor %57, %379 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%381 = torch.aten.add.Tensor %380, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%382 = torch.aten.transpose.int %33, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
%383 = torch.aten.view %381, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%384 = torch.aten.mm %383, %382 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
%385 = torch.aten.view %384, %177 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
%386 = torch.aten.view %385, %179 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
%387 = torch.aten.permute %386, %181 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
%388 = torch.aten.slice.Tensor %387, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%389 = torch.aten.slice.Tensor %387, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%390 = torch.aten.slice.Tensor %387, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%391 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%392 = torch.aten.unsqueeze %391, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%393 = torch.aten.slice.Tensor %392, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%394 = torch.aten.view %393, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%395 = torch.aten.permute %394, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%396 = torch.aten.add.Tensor %388, %395, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%397 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%398 = torch.aten.unsqueeze %397, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%399 = torch.aten.slice.Tensor %398, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%400 = torch.aten.view %399, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%401 = torch.aten.permute %400, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%402 = torch.aten.add.Tensor %390, %401, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%403 = torch.aten.div.Scalar %396, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
%404 = torch.aten.transpose.int %389, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
%405 = torch.aten.broadcast_to %403, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%406 = torch.aten.view %405, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%407 = torch.aten.broadcast_to %404, %205 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
%408 = torch.aten.view %407, %207 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
%409 = torch.aten.bmm %406, %408 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
%410 = torch.aten.view %409, %210 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%411 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%412 = torch.aten.to.dtype %411, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
%413 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%414 = torch.aten.broadcast_to %412, %413 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
%415 = torch.aten.copy %414, %172, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
%416 = torch.aten.bitwise_not %415 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
%417 = torch.aten.clone %51, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
%418 = torch.aten.masked_fill.Tensor %410, %416, %417 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
%values_2, %indices_3 = torch.aten.max.dim %418, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
%419 = torch.aten.sub.Tensor %418, %values_2, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
%420 = torch.aten.exp %419 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
%421 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%422 = torch.aten.sum.dim_IntList %420, %421, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
%423 = torch.aten.div.Tensor %420, %422 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
%424 = torch.aten.masked_fill.Scalar %423, %416, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
%425 = torch.aten.broadcast_to %424, %210 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%426 = torch.aten.view %425, %227 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
%427 = torch.aten.broadcast_to %402, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%428 = torch.aten.view %427, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%429 = torch.aten.bmm %426, %428 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
%430 = torch.aten.view %429, %201 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%431 = torch.aten.permute %430, %181 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
%432 = torch.aten.clone %431, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
%433 = torch.aten.view %432, %235 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%434 = torch.aten.transpose.int %31, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
%435 = torch.aten.view %433, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%436 = torch.aten.mm %435, %434 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
%437 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%438 = torch.aten.add.Tensor %437, %436, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%439 = torch.aten.view %438, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%440 = torch.aten.add.Tensor %439, %381, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%441 = torch.aten.sum.dim_IntList %440, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%442 = torch.aten.div.Scalar %441, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%443 = torch.aten.sub.Tensor %440, %442, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%444 = torch.aten.pow.Tensor_Scalar %443, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%445 = torch.aten.sum.dim_IntList %444, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%446 = torch.aten.div.Scalar %445, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%447 = torch.aten.sub.Tensor %440, %442, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%448 = torch.aten.add.Scalar %446, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%449 = torch.aten.sqrt %448 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%450 = torch.aten.div.Tensor %447, %449 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%451 = torch.aten.mul.Tensor %57, %450 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%452 = torch.aten.add.Tensor %451, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%453 = torch.aten.transpose.int %29, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
%454 = torch.aten.view %452, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%455 = torch.aten.mm %454, %453 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
%456 = torch.aten.mul.Scalar %45, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
%457 = torch.aten.add.Tensor %456, %455, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
%458 = torch.aten.view %457, %262 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
%459 = torch.aten.gelu %458, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
%460 = torch.aten.transpose.int %27, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
%461 = torch.aten.view %459, %266 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
%462 = torch.aten.mm %461, %460 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
%463 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%464 = torch.aten.add.Tensor %463, %462, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%465 = torch.aten.view %464, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%466 = torch.aten.add.Tensor %465, %452, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%467 = torch.aten.sum.dim_IntList %466, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%468 = torch.aten.div.Scalar %467, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%469 = torch.aten.sub.Tensor %466, %468, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%470 = torch.aten.pow.Tensor_Scalar %469, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%471 = torch.aten.sum.dim_IntList %470, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%472 = torch.aten.div.Scalar %471, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%473 = torch.aten.sub.Tensor %466, %468, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%474 = torch.aten.add.Scalar %472, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%475 = torch.aten.sqrt %474 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%476 = torch.aten.div.Tensor %473, %475 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%477 = torch.aten.mul.Tensor %57, %476 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%478 = torch.aten.add.Tensor %477, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%479 = torch.aten.transpose.int %25, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
%480 = torch.aten.view %478, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%481 = torch.aten.mm %480, %479 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
%482 = torch.aten.view %481, %177 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
%483 = torch.aten.view %482, %179 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
%484 = torch.aten.permute %483, %181 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
%485 = torch.aten.slice.Tensor %484, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%486 = torch.aten.slice.Tensor %484, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%487 = torch.aten.slice.Tensor %484, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%488 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%489 = torch.aten.unsqueeze %488, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%490 = torch.aten.slice.Tensor %489, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%491 = torch.aten.view %490, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%492 = torch.aten.permute %491, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%493 = torch.aten.add.Tensor %485, %492, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%494 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%495 = torch.aten.unsqueeze %494, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%496 = torch.aten.slice.Tensor %495, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%497 = torch.aten.view %496, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%498 = torch.aten.permute %497, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%499 = torch.aten.add.Tensor %487, %498, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%500 = torch.aten.div.Scalar %493, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
%501 = torch.aten.transpose.int %486, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
%502 = torch.aten.broadcast_to %500, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%503 = torch.aten.view %502, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%504 = torch.aten.broadcast_to %501, %205 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
%505 = torch.aten.view %504, %207 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
%506 = torch.aten.bmm %503, %505 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
%507 = torch.aten.view %506, %210 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%508 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%509 = torch.aten.to.dtype %508, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
%510 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%511 = torch.aten.broadcast_to %509, %510 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
%512 = torch.aten.copy %511, %172, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
%513 = torch.aten.bitwise_not %512 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
%514 = torch.aten.clone %51, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
%515 = torch.aten.masked_fill.Tensor %507, %513, %514 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
%values_4, %indices_5 = torch.aten.max.dim %515, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
%516 = torch.aten.sub.Tensor %515, %values_4, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
%517 = torch.aten.exp %516 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
%518 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%519 = torch.aten.sum.dim_IntList %517, %518, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
%520 = torch.aten.div.Tensor %517, %519 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
%521 = torch.aten.masked_fill.Scalar %520, %513, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
%522 = torch.aten.broadcast_to %521, %210 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%523 = torch.aten.view %522, %227 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
%524 = torch.aten.broadcast_to %499, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%525 = torch.aten.view %524, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%526 = torch.aten.bmm %523, %525 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
%527 = torch.aten.view %526, %201 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%528 = torch.aten.permute %527, %181 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
%529 = torch.aten.clone %528, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
%530 = torch.aten.view %529, %235 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%531 = torch.aten.transpose.int %23, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
%532 = torch.aten.view %530, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%533 = torch.aten.mm %532, %531 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
%534 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%535 = torch.aten.add.Tensor %534, %533, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%536 = torch.aten.view %535, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%537 = torch.aten.add.Tensor %536, %478, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%538 = torch.aten.sum.dim_IntList %537, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%539 = torch.aten.div.Scalar %538, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%540 = torch.aten.sub.Tensor %537, %539, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%541 = torch.aten.pow.Tensor_Scalar %540, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%542 = torch.aten.sum.dim_IntList %541, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%543 = torch.aten.div.Scalar %542, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%544 = torch.aten.sub.Tensor %537, %539, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%545 = torch.aten.add.Scalar %543, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%546 = torch.aten.sqrt %545 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%547 = torch.aten.div.Tensor %544, %546 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%548 = torch.aten.mul.Tensor %57, %547 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%549 = torch.aten.add.Tensor %548, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%550 = torch.aten.transpose.int %21, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
%551 = torch.aten.view %549, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%552 = torch.aten.mm %551, %550 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
%553 = torch.aten.mul.Scalar %45, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
%554 = torch.aten.add.Tensor %553, %552, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
%555 = torch.aten.view %554, %262 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
%556 = torch.aten.gelu %555, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
%557 = torch.aten.transpose.int %19, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
%558 = torch.aten.view %556, %266 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
%559 = torch.aten.mm %558, %557 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
%560 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%561 = torch.aten.add.Tensor %560, %559, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%562 = torch.aten.view %561, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%563 = torch.aten.add.Tensor %562, %549, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%564 = torch.aten.sum.dim_IntList %563, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%565 = torch.aten.div.Scalar %564, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%566 = torch.aten.sub.Tensor %563, %565, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%567 = torch.aten.pow.Tensor_Scalar %566, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%568 = torch.aten.sum.dim_IntList %567, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%569 = torch.aten.div.Scalar %568, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%570 = torch.aten.sub.Tensor %563, %565, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%571 = torch.aten.add.Scalar %569, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%572 = torch.aten.sqrt %571 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%573 = torch.aten.div.Tensor %570, %572 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%574 = torch.aten.mul.Tensor %57, %573 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%575 = torch.aten.add.Tensor %574, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%576 = torch.aten.transpose.int %17, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
%577 = torch.aten.view %575, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%578 = torch.aten.mm %577, %576 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
%579 = torch.aten.view %578, %177 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
%580 = torch.aten.view %579, %179 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
%581 = torch.aten.permute %580, %181 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
%582 = torch.aten.slice.Tensor %581, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%583 = torch.aten.slice.Tensor %581, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%584 = torch.aten.slice.Tensor %581, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%585 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%586 = torch.aten.unsqueeze %585, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%587 = torch.aten.slice.Tensor %586, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%588 = torch.aten.view %587, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%589 = torch.aten.permute %588, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%590 = torch.aten.add.Tensor %582, %589, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%591 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%592 = torch.aten.unsqueeze %591, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%593 = torch.aten.slice.Tensor %592, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%594 = torch.aten.view %593, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%595 = torch.aten.permute %594, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%596 = torch.aten.add.Tensor %584, %595, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%597 = torch.aten.div.Scalar %590, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
%598 = torch.aten.transpose.int %583, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
%599 = torch.aten.broadcast_to %597, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%600 = torch.aten.view %599, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%601 = torch.aten.broadcast_to %598, %205 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
%602 = torch.aten.view %601, %207 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
%603 = torch.aten.bmm %600, %602 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
%604 = torch.aten.view %603, %210 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%605 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%606 = torch.aten.to.dtype %605, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
%607 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%608 = torch.aten.broadcast_to %606, %607 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
%609 = torch.aten.copy %608, %172, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
%610 = torch.aten.bitwise_not %609 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
%611 = torch.aten.clone %51, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
%612 = torch.aten.masked_fill.Tensor %604, %610, %611 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
%values_6, %indices_7 = torch.aten.max.dim %612, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
%613 = torch.aten.sub.Tensor %612, %values_6, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
%614 = torch.aten.exp %613 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
%615 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%616 = torch.aten.sum.dim_IntList %614, %615, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
%617 = torch.aten.div.Tensor %614, %616 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
%618 = torch.aten.masked_fill.Scalar %617, %610, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
%619 = torch.aten.broadcast_to %618, %210 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%620 = torch.aten.view %619, %227 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
%621 = torch.aten.broadcast_to %596, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%622 = torch.aten.view %621, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%623 = torch.aten.bmm %620, %622 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
%624 = torch.aten.view %623, %201 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%625 = torch.aten.permute %624, %181 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
%626 = torch.aten.clone %625, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
%627 = torch.aten.view %626, %235 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%628 = torch.aten.transpose.int %15, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
%629 = torch.aten.view %627, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%630 = torch.aten.mm %629, %628 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
%631 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%632 = torch.aten.add.Tensor %631, %630, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%633 = torch.aten.view %632, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%634 = torch.aten.add.Tensor %633, %575, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%635 = torch.aten.sum.dim_IntList %634, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%636 = torch.aten.div.Scalar %635, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%637 = torch.aten.sub.Tensor %634, %636, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%638 = torch.aten.pow.Tensor_Scalar %637, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%639 = torch.aten.sum.dim_IntList %638, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%640 = torch.aten.div.Scalar %639, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%641 = torch.aten.sub.Tensor %634, %636, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%642 = torch.aten.add.Scalar %640, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%643 = torch.aten.sqrt %642 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%644 = torch.aten.div.Tensor %641, %643 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%645 = torch.aten.mul.Tensor %57, %644 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%646 = torch.aten.add.Tensor %645, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%647 = torch.aten.transpose.int %13, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
%648 = torch.aten.view %646, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%649 = torch.aten.mm %648, %647 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
%650 = torch.aten.mul.Scalar %45, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
%651 = torch.aten.add.Tensor %650, %649, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
%652 = torch.aten.view %651, %262 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
%653 = torch.aten.gelu %652, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
%654 = torch.aten.transpose.int %11, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
%655 = torch.aten.view %653, %266 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
%656 = torch.aten.mm %655, %654 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
%657 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%658 = torch.aten.add.Tensor %657, %656, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%659 = torch.aten.view %658, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%660 = torch.aten.add.Tensor %659, %646, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%661 = torch.aten.sum.dim_IntList %660, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%662 = torch.aten.div.Scalar %661, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%663 = torch.aten.sub.Tensor %660, %662, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%664 = torch.aten.pow.Tensor_Scalar %663, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%665 = torch.aten.sum.dim_IntList %664, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%666 = torch.aten.div.Scalar %665, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%667 = torch.aten.sub.Tensor %660, %662, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%668 = torch.aten.add.Scalar %666, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%669 = torch.aten.sqrt %668 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%670 = torch.aten.div.Tensor %667, %669 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%671 = torch.aten.mul.Tensor %57, %670 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%672 = torch.aten.add.Tensor %671, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%673 = torch.aten.slice.Tensor %672, %int0, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32],f32>
%674 = torch.aten.slice.Tensor %673, %int1, %int0, %int1, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%675 = torch.aten.squeeze.dim %674, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%676 = torch.aten.transpose.int %9, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
%677 = torch.aten.mm %675, %676 : !torch.vtensor<[1,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1,32],f32>
%678 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%679 = torch.aten.add.Tensor %678, %677, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%680 = torch.aten.gelu %679, %str : !torch.vtensor<[1,32],f32>, !torch.str -> !torch.vtensor<[1,32],f32>
%681 = torch.aten.transpose.int %7, %int0, %int1 : !torch.vtensor<[2,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,2],f32>
%682 = torch.aten.mm %680, %681 : !torch.vtensor<[1,32],f32>, !torch.vtensor<[32,2],f32> -> !torch.vtensor<[1,2],f32>
%683 = torch.aten.mul.Scalar %5, %int1 : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32>
%684 = torch.aten.add.Tensor %683, %682, %int1 : !torch.vtensor<[2],f32>, !torch.vtensor<[1,2],f32>, !torch.int -> !torch.vtensor<[1,2],f32>
return %684 : !torch.vtensor<[1,2],f32>
}
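For orientation in the dump above: the masked_fill / max.dim / sub / exp / sum / div / masked_fill sequence (e.g. %418 through %424) is the decomposed form of DeBERTa's masked softmax (XSoftmax). A rough Python equivalent, as a sketch only, where rmask stands for the inverted attention mask (the %416-style i1 tensor):

import torch

def masked_softmax(scores, rmask):
    # Fill masked positions with a large negative value (the %51 constant),
    # subtract the row max for numerical stability, exponentiate,
    # normalize, then zero out the masked positions again.
    scores = scores.masked_fill(rmask, torch.finfo(scores.dtype).min)
    scores = scores - scores.max(dim=-1, keepdim=True).values
    probs = torch.exp(scores)
    probs = probs / probs.sum(dim=-1, keepdim=True)
    return probs.masked_fill(rmask, 0.0)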
torch.aten.copy legalization error:
Legalizing operation : 'torch.aten.copy'(0x996e240) {
%197 = "torch.aten.copy"(%194, %190, %70) : (!torch.vtensor<[],si8>, !torch.vtensor<[1,1,128,128],f32>, !torch.bool) -> !torch.vtensor<[1,1,128,128],si8>
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.aten.copy -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenCopyOp>"
** Failure : casting to result dtype is invalid or unsupported
** Failure : unimplemented: cast to result type not supported
"(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenCopyOp>" result 0
} -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
<eval_with_key>.2:38:15: error: failed to legalize operation 'torch.aten.copy' that was explicitly marked illegal
<eval_with_key>.2:38:15: note: see current operation: %197 = "torch.aten.copy"(%194, %190, %70) : (!torch.vtensor<[],si8>, !torch.vtensor<[1,1,128,128],f32>, !torch.bool) -> !torch.vtensor<[1,1,128,128],si8>
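What the error means: the f32 attention mask is being copied into an integer mask tensor, and the TOSA ConvertAtenOp pattern for aten.copy could not legalize the f32 -> si8 cast. A hypothetical minimal repro, under the assumption (not confirmed by this gist) that the failing copy comes from DeBERTa's mask cast (something like attention_mask.byte()), which gets decomposed into to.dtype + broadcast_to + copy:

import torch
import torch_mlir
from torch.fx.experimental.proxy_tensor import make_fx

def mask_cast(mask):
    # f32 -> byte cast; a guess at the source of the failing aten.copy
    return mask.byte()

fx_g = make_fx(mask_cast)(torch.rand(1, 1, 128, 128))
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
for node in fx_g.graph.nodes:
    # strip overloads so torch.jit.script accepts the graph
    if isinstance(node.target, torch._ops.OpOverload):
        node.target = node.target.overloadpacket
fx_g.recompile()
ts_g = torch.jit.script(fx_g)
module = torch_mlir.compile(
    ts_g,
    torch.rand(1, 1, 128, 128),
    torch_mlir.OutputType.TOSA,
    use_tracing=True,
)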
Fixed by llvm/torch-mlir#1592.
All issues are now fixed.
Here is the output deberta_tosa.mlir:
Run:
python tank/pytorch/deberta/deberta_tosa.py
Get: the command to reproduce the error and generate the IR.
Run:
torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' /tmp/_lambda.mlir -mlir-print-ir-after-all -mlir-disable-threading --debug
Get: the IR after each conversion pass, with tosa ops and the remaining torch.aten ops mixed together. The leftover torch.aten ops are the ones that still need fixing. Here is the IR printed after the last successfully applied pattern.
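To see at a glance which ops are still unconverted, one can count the leftover torch.aten ops in that debug output (a small helper; debug.log is an assumed filename for the redirected torch-mlir-opt output, not something this gist creates):

import collections
import re

# Assumes the torch-mlir-opt output above was saved to debug.log.
text = open("debug.log").read()
ops = collections.Counter(re.findall(r"torch\.aten\.[A-Za-z0-9_.]+", text))
for op, count in ops.most_common():
    print(f"{count:6d}  {op}")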
The last successful conversion is:
Convert from:
TO:
BY: