Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active December 19, 2022 07:11
Show Gist options
  • Save AmosLewis/df6feb5c9618bd4b5a843c3ce490a2f9 to your computer and use it in GitHub Desktop.
Save AmosLewis/df6feb5c9618bd4b5a843c3ce490a2f9 to your computer and use it in GitHub Desktop.
deberta
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import tempfile
import torch_mlir
def prepare_sentence_tokens(hf_model: str, sentence: str):
    """Encode *sentence* with the tokenizer of *hf_model* as a batch of one.

    Args:
        hf_model: Hugging Face hub id or local path whose tokenizer is loaded.
        sentence: Text to encode.

    Returns:
        A ``torch.Tensor`` holding the encoded token ids wrapped in an outer
        list, i.e. with a leading batch dimension of size 1.
    """
    ids = AutoTokenizer.from_pretrained(hf_model).encode(sentence)
    batch = [ids]
    return torch.tensor(batch)
class HfMaskedLM(torch.nn.Module):
    """Wrap a Hugging Face sequence-classification model and return its first output.

    NOTE(review): despite the class name, this loads
    ``AutoModelForSequenceClassification`` with a 2-label head, not a
    masked-LM head. The name is kept so existing references keep working.

    Args:
        model_name: Hugging Face hub id or local path passed to ``from_pretrained``.
    """

    def __init__(self, model_name: str):
        super().__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,  # The pretrained model name.
            # The number of output labels--2 for binary classification.
            num_labels=2,
            # Whether the model returns attention weights.
            output_attentions=False,
            # Whether the model returns all hidden-states.
            output_hidden_states=False,
            # With torchscript=True, transformers models return plain tuples
            # instead of ModelOutput objects, which is why forward() indexes
            # the result with [0].
            torchscript=True,
        )
        # Inference mode for the wrapped model (e.g. disables dropout).
        self.model.eval()

    def forward(self, tokens):
        """Run the wrapped model on *tokens* and return its first output.

        Args:
            tokens: Token-id tensor passed unchanged to the wrapped model.

        Returns:
            Element 0 of the wrapped model's output tuple. For this 2-label
            classifier that is presumably the logits tensor — confirm.
        """
        # Call the module itself rather than ``.forward`` so that nn.Module
        # hooks registered on the wrapped model are honoured.
        return self.model(tokens)[0]
# Checkpoint selection. The commented-out id is an alternative checkpoint,
# presumably a small random-weight DeBERTa for quicker repros. NOTE(review):
# the IR dump in the notes of this gist (hidden size 32, vocab 1000) looks like
# it was produced with that tiny checkpoint rather than deberta-v3-base — confirm.
#hf_minilm_model = "hf-internal-testing/tiny-random-deberta"
# NOTE(review): the variable is named "minilm" but holds a DeBERTa checkpoint id.
hf_minilm_model = "microsoft/deberta-v3-base"
# Random token ids drawn from {0, 1} with shape (1, 128), presumably
# (batch, seq_len). Used for the eager run, for tracing, and for compilation.
test_input = torch.randint(2, (1, 128))
model = HfMaskedLM(hf_minilm_model)
# Eager reference output; the SHARK output is printed at the end of the
# script so the two can be compared.
print("model(test_input): ")
print(model(test_input))
# Trace the model into an aten-level FX GraphModule, applying the listed
# decompositions during tracing.
fx_g = make_fx(
model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(test_input)
# print(fx_g.graph)
# Replace the graph's codegen with the plain fx CodeGen and regenerate the
# module's Python code. Presumably this drops make_fx's pytree input/output
# handling so the later torch.jit.script call accepts the module — confirm.
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
    """Rewrite call targets in *gm* from ``OpOverload`` to their overload packet.

    Every node whose target is a specific op overload (e.g.
    ``torch.ops.aten.add.Tensor``) is retargeted to the overload packet
    (``torch.ops.aten.add``). The module's Python code is then regenerated.
    The graph is modified in place and nothing is returned.

    Args:
        gm (fx.GraphModule): The FX graph module to rewrite.
    """
    overload_type = torch._ops.OpOverload
    for node in gm.graph.nodes:
        target = node.target
        if not isinstance(target, overload_type):
            continue
        node.target = target.overloadpacket
    # Regenerate forward() from the edited graph.
    gm.recompile()
# Retarget OpOverload call targets to overload packets (in place), then script
# the FX module with TorchScript. Presumably the overload stripping is needed
# for TorchScript/torch-mlir import to succeed — confirm.
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
# Alternative lowering to linalg-on-tensors, kept for reference.
# module = torch_mlir.compile(
# ts_g,
# (test_input),
# torch_mlir.OutputType.LINALG_ON_TENSORS,
# use_tracing=True,
# verbose=False,
# )
# Lower to the TOSA dialect. Per the notes in this gist, this is the step that
# fails: the torch-backend-to-tosa-backend pipeline leaves some torch.aten ops
# unconverted. NOTE(review): `(test_input)` is just the tensor in parentheses,
# not a 1-tuple; write `(test_input,)` if a tuple was intended.
module = torch_mlir.compile(
ts_g,
(test_input),
torch_mlir.OutputType.TOSA,
use_tracing=True,
verbose=False,
)
# Print the lowered module's IR.
module.dump()
# Hand the TOSA module to the SHARK runtime: target the "forward" function,
# build for the CPU device, and compile it.
from shark.shark_inference import SharkInference
mlir_model = module
func_name = "forward"
shark_module = SharkInference(
mlir_model, func_name, device="cpu", mlir_dialect="tosa"
)
shark_module.compile()
def shark_result(x):
    """Run tensor *x* through the module-level ``shark_module``.

    Args:
        x: Input ``torch.Tensor``. It is detached and converted with
            ``.numpy()``, so it must be NumPy-convertible (e.g. on the CPU).

    Returns:
        ``torch.from_numpy`` applied to whatever ``shark_module.forward``
        returns; presumably a NumPy array — confirm.
    """
    host_array = x.detach().numpy()
    output = shark_module.forward((host_array,))
    return torch.from_numpy(output)
# Run the same random input through the compiled SHARK module and print it,
# for comparison with the eager output printed earlier.
observed_out = shark_result(test_input)
print(observed_out)
@AmosLewis
Copy link
Author

AmosLewis commented Nov 15, 2022

Run:
python tank/pytorch/deberta/deberta_tosa.py
Output: a command that reproduces the error and generates the IR:

For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' /tmp/_lambda.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.

Run:
torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' /tmp/_lambda.mlir -mlir-print-ir-after-all -mlir-disable-threading --debug
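Output: the IR after each conversion step, in which already-converted tosa ops are interleaved with not-yet-converted torch.aten ops. The remaining torch.aten ops are the ones whose lowering still needs to be fixed. Below is the IR printed after the last successfully converted op.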

The last successful conversion is:
Convert from

  %162 = "tosa.reshape"(%cast) {new_shape = [1, 1, 1, 128]} : (tensor<1x1x128xf32>) -> tensor<1x1x128x1xf32> loc("<eval_with_key>.2":36:18)
  %163 = torch.aten.unsqueeze %161, %int-1 : !torch.vtensor<[1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128,1],f32> loc("<eval_with_key>.2":36:18)
  %164 = "tosa.mul"(%158, %162) {shift = 0 : i32} : (tensor<1x1x1x128xf32>, tensor<1x1x128x1xf32>) -> tensor<1x1x128x128xf32> loc("<eval_with_key>.2":37:12)
  %165 = torch.aten.mul.Tensor %159, %163 : !torch.vtensor<[1,1,1,128],f32>, !torch.vtensor<[1,1,128,1],f32> -> !torch.vtensor<[1,1,128,128],f32> loc("<eval_with_key>.2":37:12)
  %166 = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
  %167 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
  %168 = "tosa.cast"(%166) : (tensor<i64>) -> tensor<i8>
  %169 = torch.aten.to.dtype %167, %int1, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],si8>
  %170 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %171 = torch.aten.broadcast_to %169, %170 : !torch.vtensor<[],si8>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],si8>
  %172 = torch.aten.copy %171, %165, %false : !torch.vtensor<[1,1,128,128],si8>, !torch.vtensor<[1,1,128,128],f32>, !torch.bool -> !torch.vtensor<[1,1,128,128],si8>

TO:

  %162 = "tosa.reshape"(%cast) {new_shape = [1, 1, 1, 128]} : (tensor<1x1x128xf32>) -> tensor<1x1x128x1xf32> loc("<eval_with_key>.2":36:18)
  %163 = torch.aten.unsqueeze %161, %int-1 : !torch.vtensor<[1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128,1],f32> loc("<eval_with_key>.2":36:18)
  %164 = "tosa.mul"(%158, %162) {shift = 0 : i32} : (tensor<1x1x1x128xf32>, tensor<1x1x128x1xf32>) -> tensor<1x1x128x128xf32> loc("<eval_with_key>.2":37:12)
  %165 = torch.aten.mul.Tensor %159, %163 : !torch.vtensor<[1,1,1,128],f32>, !torch.vtensor<[1,1,128,1],f32> -> !torch.vtensor<[1,1,128,128],f32> loc("<eval_with_key>.2":37:12)
  %166 = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64> loc("<eval_with_key>.2":38:15)
  %167 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64> loc("<eval_with_key>.2":38:15)
  %168 = "tosa.cast"(%166) : (tensor<i64>) -> tensor<i8>
  %169 = torch.aten.to.dtype %167, %int1, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],si8>
  %170 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %171 = torch.aten.broadcast_to %169, %170 : !torch.vtensor<[],si8>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],si8>
  %172 = torch.aten.copy %169, %165, %false : !torch.vtensor<[],si8>, !torch.vtensor<[1,1,128,128],f32>, !torch.bool -> !torch.vtensor<[1,1,128,128],si8>

BY:

//===-------------------------------------------===//
Legalizing operation : 'torch.prim.ListConstruct'(0x996cd90) {
  %195 = "torch.prim.ListConstruct"(%1, %1, %5, %5) : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'torch.aten.broadcast_to'(0x996e130) {
  %196 = "torch.aten.broadcast_to"(%194, %195) : (!torch.vtensor<[],si8>, !torch.list<int>) -> !torch.vtensor<[1,1,128,128],si8>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.broadcast_to -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenBroadcastToOp>"
    ** Erase   : 'torch.aten.broadcast_to'(0x996e130)
"(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenBroadcastToOp>" result 1
  } -> SUCCESS : pattern applied successfully


@AmosLewis
Copy link
Author

AmosLewis commented Nov 15, 2022

IR dump after the last successful pattern application

func.func @forward(%arg0: !torch.vtensor<[1,128],si64>) -> !torch.vtensor<[1,2],f32> {
  %0 = builtin.unrealized_conversion_cast %arg0 : !torch.vtensor<[1,128],si64> to tensor<1x128xi64>
  %int1 = torch.constant.int 1
  %1 = builtin.unrealized_conversion_cast %int1 : !torch.int to i64
  %int32 = torch.constant.int 32
  %2 = builtin.unrealized_conversion_cast %int32 : !torch.int to i64
  %int128 = torch.constant.int 128
  %3 = builtin.unrealized_conversion_cast %int128 : !torch.int to i64
  %float1.000000e00 = torch.constant.float 1.000000e+00
  %4 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
  %5 = torch.vtensor.literal(dense<0.000000e+00> : tensor<2xf32>) : !torch.vtensor<[2],f32>
  %6 = "tosa.const"() {value = dense_resource<__elided__> : tensor<2x32xf32>} : () -> tensor<2x32xf32>
  %7 = torch.vtensor.literal(dense_resource<__elided__> : tensor<2x32xf32>) : !torch.vtensor<[2,32],f32>
  %8 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x32xf32>} : () -> tensor<32x32xf32>
  %9 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
  %10 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x37xf32>} : () -> tensor<32x37xf32>
  %11 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
  %12 = "tosa.const"() {value = dense_resource<__elided__> : tensor<37x32xf32>} : () -> tensor<37x32xf32>
  %13 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
  %14 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x32xf32>} : () -> tensor<32x32xf32>
  %15 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
  %16 = "tosa.const"() {value = dense_resource<__elided__> : tensor<96x32xf32>} : () -> tensor<96x32xf32>
  %17 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
  %18 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x37xf32>} : () -> tensor<32x37xf32>
  %19 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
  %20 = "tosa.const"() {value = dense_resource<__elided__> : tensor<37x32xf32>} : () -> tensor<37x32xf32>
  %21 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
  %22 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x32xf32>} : () -> tensor<32x32xf32>
  %23 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
  %24 = "tosa.const"() {value = dense_resource<__elided__> : tensor<96x32xf32>} : () -> tensor<96x32xf32>
  %25 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
  %26 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x37xf32>} : () -> tensor<32x37xf32>
  %27 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
  %28 = "tosa.const"() {value = dense_resource<__elided__> : tensor<37x32xf32>} : () -> tensor<37x32xf32>
  %29 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
  %30 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x32xf32>} : () -> tensor<32x32xf32>
  %31 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
  %32 = "tosa.const"() {value = dense_resource<__elided__> : tensor<96x32xf32>} : () -> tensor<96x32xf32>
  %33 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
  %34 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x37xf32>} : () -> tensor<32x37xf32>
  %35 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
  %36 = "tosa.const"() {value = dense_resource<__elided__> : tensor<37x32xf32>} : () -> tensor<37x32xf32>
  %37 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
  %38 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x32xf32>} : () -> tensor<32x32xf32>
  %39 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
  %40 = "tosa.const"() {value = dense_resource<__elided__> : tensor<96x32xf32>} : () -> tensor<96x32xf32>
  %41 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
  %42 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x37xf32>} : () -> tensor<32x37xf32>
  %43 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
  %44 = "tosa.const"() {value = dense<0.000000e+00> : tensor<37xf32>} : () -> tensor<37xf32>
  %45 = torch.vtensor.literal(dense<0.000000e+00> : tensor<37xf32>) : !torch.vtensor<[37],f32>
  %46 = "tosa.const"() {value = dense_resource<__elided__> : tensor<37x32xf32>} : () -> tensor<37x32xf32>
  %47 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
  %48 = "tosa.const"() {value = dense_resource<__elided__> : tensor<32x32xf32>} : () -> tensor<32x32xf32>
  %49 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
  %50 = "tosa.const"() {value = dense<-3.40282347E+38> : tensor<f32>} : () -> tensor<f32>
  %51 = torch.vtensor.literal(dense<-3.40282347E+38> : tensor<f32>) : !torch.vtensor<[],f32>
  %52 = "tosa.const"() {value = dense_resource<__elided__> : tensor<96x32xf32>} : () -> tensor<96x32xf32>
  %53 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
  %54 = "tosa.const"() {value = dense<0.000000e+00> : tensor<32xf32>} : () -> tensor<32xf32>
  %55 = torch.vtensor.literal(dense<0.000000e+00> : tensor<32xf32>) : !torch.vtensor<[32],f32>
  %56 = "tosa.const"() {value = dense<1.000000e+00> : tensor<32xf32>} : () -> tensor<32xf32>
  %57 = torch.vtensor.literal(dense<1.000000e+00> : tensor<32xf32>) : !torch.vtensor<[32],f32>
  %58 = "tosa.const"() {value = dense_resource<__elided__> : tensor<16x32xf32>} : () -> tensor<16x32xf32>
  %59 = torch.vtensor.literal(dense_resource<__elided__> : tensor<16x32xf32>) : !torch.vtensor<[16,32],f32>
  %60 = "tosa.const"() {value = dense_resource<__elided__> : tensor<512x32xf32>} : () -> tensor<512x32xf32>
  %61 = torch.vtensor.literal(dense_resource<__elided__> : tensor<512x32xf32>) : !torch.vtensor<[512,32],f32>
  %62 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1000x32xf32>} : () -> tensor<1000x32xf32>
  %63 = torch.vtensor.literal(dense_resource<__elided__> : tensor<1000x32xf32>) : !torch.vtensor<[1000,32],f32>
  %64 = "tosa.const"() {value = dense_resource<__elided__> : tensor<1x512xsi64>} : () -> tensor<1x512xi64>
  %65 = torch.vtensor.literal(dense_resource<__elided__> : tensor<1x512xsi64>) : !torch.vtensor<[1,512],si64>
  %false = torch.constant.bool false
  %66 = builtin.unrealized_conversion_cast %false : !torch.bool to i1
  %none = torch.constant.none
  %int-1 = torch.constant.int -1
  %67 = builtin.unrealized_conversion_cast %int-1 : !torch.int to i64
  %true = torch.constant.bool true
  %68 = builtin.unrealized_conversion_cast %true : !torch.bool to i1
  %int-2 = torch.constant.int -2
  %69 = builtin.unrealized_conversion_cast %int-2 : !torch.int to i64
  %int11 = torch.constant.int 11
  %str = torch.constant.str "none"
  %int0 = torch.constant.int 0
  %70 = builtin.unrealized_conversion_cast %int0 : !torch.int to i64
  %int9223372036854775807 = torch.constant.int 9223372036854775807
  %71 = builtin.unrealized_conversion_cast %int9223372036854775807 : !torch.int to i64
  %int2 = torch.constant.int 2
  %72 = builtin.unrealized_conversion_cast %int2 : !torch.int to i64
  %float9.999990e-08 = torch.constant.float 9.9999999999999995E-8
  %73 = builtin.unrealized_conversion_cast %float9.999990e-08 : !torch.float to f64
  %int96 = torch.constant.int 96
  %int4 = torch.constant.int 4
  %74 = builtin.unrealized_conversion_cast %int4 : !torch.int to i64
  %int3 = torch.constant.int 3
  %int8 = torch.constant.int 8
  %int16 = torch.constant.int 16
  %int24 = torch.constant.int 24
  %float4.000000e00 = torch.constant.float 4.000000e+00
  %int37 = torch.constant.int 37
  %cpu = torch.constant.device "cpu"
  %75 = torch.prim.ListConstruct %int1, %int128 : (!torch.int, !torch.int) -> !torch.list<int>
  %76 = "tosa.const"() {value = dense<1> : tensor<1x128xi32>} : () -> tensor<1x128xi32>
  %77 = "tosa.cast"(%76) : (tensor<1x128xi32>) -> tensor<1x128xf32>
  %78 = torch.aten.ones %75, %none, %none, %cpu, %false : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1,128],f32>
  %79 = "tosa.const"() {value = dense<0> : tensor<1x128xi32>} : () -> tensor<1x128xi32>
  %80 = "tosa.cast"(%79) : (tensor<1x128xi32>) -> tensor<1x128xi64>
  %81 = torch.aten.zeros %75, %int4, %none, %cpu, %false : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1,128],si64>
  %82 = "tosa.slice"(%64) {size = [9223372036854775807, 512], start = [0, 0]} : (tensor<1x512xi64>) -> tensor<1x512xi64>
  %83 = torch.aten.slice.Tensor %65, %int0, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,512],si64>
  %84 = "tosa.slice"(%82) {size = [1, 128], start = [0, 0]} : (tensor<1x512xi64>) -> tensor<1x128xi64>
  %85 = torch.aten.slice.Tensor %83, %int1, %int0, %int128, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128],si64>
  %86 = "tosa.reshape"(%62) {new_shape = [1, 1000, 32]} : (tensor<1000x32xf32>) -> tensor<1x1000x32xf32>
  %87 = "tosa.reshape"(%0) {new_shape = [1, 128]} : (tensor<1x128xi64>) -> tensor<1x128xi64>
  %88 = "tosa.cast"(%87) : (tensor<1x128xi64>) -> tensor<1x128xi32>
  %89 = "tosa.gather"(%86, %88) : (tensor<1x1000x32xf32>, tensor<1x128xi32>) -> tensor<1x128x32xf32>
  %90 = "tosa.reshape"(%89) {new_shape = [1, 128, 32]} : (tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
  %91 = torch.aten.embedding %63, %arg0, %int0, %false, %false : !torch.vtensor<[1000,32],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,32],f32>
  %92 = "tosa.reshape"(%60) {new_shape = [1, 512, 32]} : (tensor<512x32xf32>) -> tensor<1x512x32xf32>
  %93 = "tosa.reshape"(%84) {new_shape = [1, 128]} : (tensor<1x128xi64>) -> tensor<1x128xi64>
  %94 = "tosa.cast"(%93) : (tensor<1x128xi64>) -> tensor<1x128xi32>
  %95 = "tosa.gather"(%92, %94) : (tensor<1x512x32xf32>, tensor<1x128xi32>) -> tensor<1x128x32xf32>
  %96 = "tosa.reshape"(%95) {new_shape = [1, 128, 32]} : (tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
  %97 = torch.aten.embedding %61, %85, %int-1, %false, %false : !torch.vtensor<[512,32],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,32],f32>
  %98 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %99 = "tosa.mul"(%96, %98) {shift = 0 : i32} : (tensor<1x128x32xf32>, tensor<f32>) -> tensor<1x128x32xf32>
  %100 = "tosa.add"(%90, %99) : (tensor<1x128x32xf32>, tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
  %101 = torch.aten.add.Tensor %91, %97, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %102 = "tosa.reshape"(%58) {new_shape = [1, 16, 32]} : (tensor<16x32xf32>) -> tensor<1x16x32xf32>
  %103 = "tosa.reshape"(%80) {new_shape = [1, 128]} : (tensor<1x128xi64>) -> tensor<1x128xi64>
  %104 = "tosa.cast"(%103) : (tensor<1x128xi64>) -> tensor<1x128xi32>
  %105 = "tosa.gather"(%102, %104) : (tensor<1x16x32xf32>, tensor<1x128xi32>) -> tensor<1x128x32xf32>
  %106 = "tosa.reshape"(%105) {new_shape = [1, 128, 32]} : (tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
  %107 = torch.aten.embedding %59, %81, %int-1, %false, %false : !torch.vtensor<[16,32],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,32],f32>
  %108 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %109 = "tosa.mul"(%106, %108) {shift = 0 : i32} : (tensor<1x128x32xf32>, tensor<f32>) -> tensor<1x128x32xf32>
  %110 = "tosa.add"(%100, %109) : (tensor<1x128x32xf32>, tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
  %111 = torch.aten.add.Tensor %101, %107, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %112 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
  %113 = "tosa.reduce_sum"(%110) {axis = 2 : i64} : (tensor<1x128x32xf32>) -> tensor<1x128x1xf32>
  %114 = torch.aten.sum.dim_IntList %111, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %115 = "tosa.const"() {value = dense<3.200000e+01> : tensor<f32>} : () -> tensor<f32>
  %116 = "tosa.reciprocal"(%115) : (tensor<f32>) -> tensor<f32>
  %117 = "tosa.mul"(%113, %116) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<f32>) -> tensor<1x128x1xf32>
  %118 = torch.aten.div.Scalar %114, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %119 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %120 = "tosa.mul"(%117, %119) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<f32>) -> tensor<1x128x1xf32>
  %121 = "tosa.sub"(%110, %120) : (tensor<1x128x32xf32>, tensor<1x128x1xf32>) -> tensor<1x128x32xf32>
  %122 = torch.aten.sub.Tensor %111, %118, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %123 = "tosa.const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %124 = "tosa.pow"(%121, %123) : (tensor<1x128x32xf32>, tensor<f32>) -> tensor<1x128x32xf32>
  %125 = torch.aten.pow.Tensor_Scalar %122, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %126 = "tosa.reduce_sum"(%124) {axis = 2 : i64} : (tensor<1x128x32xf32>) -> tensor<1x128x1xf32>
  %127 = torch.aten.sum.dim_IntList %125, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %128 = "tosa.const"() {value = dense<3.200000e+01> : tensor<f32>} : () -> tensor<f32>
  %129 = "tosa.reciprocal"(%128) : (tensor<f32>) -> tensor<f32>
  %130 = "tosa.mul"(%126, %129) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<f32>) -> tensor<1x128x1xf32>
  %131 = torch.aten.div.Scalar %127, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %132 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %133 = "tosa.mul"(%117, %132) {shift = 0 : i32} : (tensor<1x128x1xf32>, tensor<f32>) -> tensor<1x128x1xf32>
  %134 = "tosa.sub"(%110, %133) : (tensor<1x128x32xf32>, tensor<1x128x1xf32>) -> tensor<1x128x32xf32>
  %135 = torch.aten.sub.Tensor %111, %118, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %136 = "tosa.const"() {value = dense<1.000000e-07> : tensor<f32>} : () -> tensor<f32>
  %137 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %138 = "tosa.mul"(%136, %137) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
  %139 = "tosa.add"(%130, %138) : (tensor<1x128x1xf32>, tensor<f32>) -> tensor<1x128x1xf32>
  %140 = torch.aten.add.Scalar %131, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %141 = torch.aten.sqrt %140 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
  %142 = builtin.unrealized_conversion_cast %141 : !torch.vtensor<[1,128,1],f32> to tensor<1x128x1xf32>
  %143 = "tosa.reciprocal"(%142) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32>
  %144 = "tosa.mul"(%134, %143) {shift = 0 : i32} : (tensor<1x128x32xf32>, tensor<1x128x1xf32>) -> tensor<1x128x32xf32>
  %145 = torch.aten.div.Tensor %135, %141 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
  %146 = "tosa.mul"(%56, %144) {shift = 0 : i32} : (tensor<32xf32>, tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
  %147 = torch.aten.mul.Tensor %57, %145 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
  %148 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %149 = "tosa.mul"(%54, %148) {shift = 0 : i32} : (tensor<32xf32>, tensor<f32>) -> tensor<32xf32>
  %150 = "tosa.add"(%146, %149) : (tensor<1x128x32xf32>, tensor<32xf32>) -> tensor<1x128x32xf32>
  %151 = torch.aten.add.Tensor %147, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %152 = "tosa.reshape"(%77) {new_shape = [1, 128]} : (tensor<1x128xf32>) -> tensor<1x128x1xf32>
  %153 = torch.aten.unsqueeze %78, %int2 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %154 = "tosa.mul"(%150, %152) {shift = 0 : i32} : (tensor<1x128x32xf32>, tensor<1x128x1xf32>) -> tensor<1x128x32xf32>
  %155 = torch.aten.mul.Tensor %151, %153 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
  %156 = "tosa.reshape"(%77) {new_shape = [1, 1, 128]} : (tensor<1x128xf32>) -> tensor<1x1x128xf32>
  %157 = torch.aten.unsqueeze %78, %int1 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128],f32>
  %158 = "tosa.reshape"(%156) {new_shape = [1, 1, 1, 128]} : (tensor<1x1x128xf32>) -> tensor<1x1x1x128xf32>
  %159 = torch.aten.unsqueeze %157, %int2 : !torch.vtensor<[1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,1,128],f32>
  %160 = "tosa.reshape"(%158) {new_shape = [1, 1, 128]} : (tensor<1x1x1x128xf32>) -> tensor<1x1x128xf32>
  %cast = tensor.cast %160 : tensor<1x1x128xf32> to tensor<1x1x128xf32>
  %161 = torch.aten.squeeze.dim %159, %int-2 : !torch.vtensor<[1,1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128],f32>
  %162 = "tosa.reshape"(%cast) {new_shape = [1, 1, 1, 128]} : (tensor<1x1x128xf32>) -> tensor<1x1x128x1xf32>
  %163 = torch.aten.unsqueeze %161, %int-1 : !torch.vtensor<[1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128,1],f32>
  %164 = "tosa.mul"(%158, %162) {shift = 0 : i32} : (tensor<1x1x1x128xf32>, tensor<1x1x128x1xf32>) -> tensor<1x1x128x128xf32>
  %165 = torch.aten.mul.Tensor %159, %163 : !torch.vtensor<[1,1,1,128],f32>, !torch.vtensor<[1,1,128,1],f32> -> !torch.vtensor<[1,1,128,128],f32>
  %166 = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
  %167 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
  %168 = "tosa.cast"(%166) : (tensor<i64>) -> tensor<i8>
  %169 = torch.aten.to.dtype %167, %int1, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],si8>
  %170 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %171 = torch.aten.broadcast_to %169, %170 : !torch.vtensor<[],si8>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],si8>
  %172 = torch.aten.copy %169, %165, %false : !torch.vtensor<[],si8>, !torch.vtensor<[1,1,128,128],f32>, !torch.bool -> !torch.vtensor<[1,1,128,128],si8>
  %173 = torch.aten.transpose.int %53, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
  %174 = torch.prim.ListConstruct %int128, %int32 : (!torch.int, !torch.int) -> !torch.list<int>
  %175 = torch.aten.view %155, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %176 = torch.aten.mm %175, %173 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
  %177 = torch.prim.ListConstruct %int1, %int128, %int96 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %178 = torch.aten.view %176, %177 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
  %179 = torch.prim.ListConstruct %int1, %int128, %int4, %int-1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %180 = torch.aten.view %178, %179 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
  %181 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %182 = torch.aten.permute %180, %181 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
  %183 = torch.aten.slice.Tensor %182, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %184 = torch.aten.slice.Tensor %182, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %185 = torch.aten.slice.Tensor %182, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %186 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
  %187 = torch.aten.unsqueeze %186, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %188 = torch.aten.slice.Tensor %187, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %189 = torch.prim.ListConstruct %int1, %int1, %int4, %int-1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %190 = torch.aten.view %188, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
  %191 = torch.aten.permute %190, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
  %192 = torch.aten.add.Tensor %183, %191, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %193 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
  %194 = torch.aten.unsqueeze %193, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %195 = torch.aten.slice.Tensor %194, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %196 = torch.aten.view %195, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
  %197 = torch.aten.permute %196, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
  %198 = torch.aten.add.Tensor %185, %197, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %199 = torch.aten.div.Scalar %192, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
  %200 = torch.aten.transpose.int %184, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
  %201 = torch.prim.ListConstruct %int1, %int4, %int128, %int8 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %202 = torch.aten.broadcast_to %199, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %203 = torch.prim.ListConstruct %int4, %int128, %int8 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %204 = torch.aten.view %202, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
  %205 = torch.prim.ListConstruct %int1, %int4, %int8, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %206 = torch.aten.broadcast_to %200, %205 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
  %207 = torch.prim.ListConstruct %int4, %int8, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %208 = torch.aten.view %206, %207 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
  %209 = torch.aten.bmm %204, %208 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
  %210 = torch.prim.ListConstruct %int1, %int4, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %211 = torch.aten.view %209, %210 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
  %212 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
  %213 = torch.aten.to.dtype %212, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
  %214 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %215 = torch.aten.broadcast_to %213, %214 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
  %216 = torch.aten.copy %215, %172, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
  %217 = torch.aten.bitwise_not %216 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
  %218 = torch.aten.clone %51, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
  %219 = torch.aten.masked_fill.Tensor %211, %217, %218 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %values, %indices = torch.aten.max.dim %219, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
  %220 = torch.aten.sub.Tensor %219, %values, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
  %221 = torch.aten.exp %220 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %222 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
  %223 = torch.aten.sum.dim_IntList %221, %222, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
  %224 = torch.aten.div.Tensor %221, %223 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %225 = torch.aten.masked_fill.Scalar %224, %217, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
  %226 = torch.aten.broadcast_to %225, %210 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
  %227 = torch.prim.ListConstruct %int4, %int128, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %228 = torch.aten.view %226, %227 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
  %229 = torch.aten.broadcast_to %198, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %230 = torch.aten.view %229, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
  %231 = torch.aten.bmm %228, %230 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
  %232 = torch.aten.view %231, %201 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %233 = torch.aten.permute %232, %181 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
  %234 = torch.aten.clone %233, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
  %235 = torch.prim.ListConstruct %int1, %int128, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %236 = torch.aten.view %234, %235 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %237 = torch.aten.transpose.int %49, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
  %238 = torch.aten.view %236, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %239 = torch.aten.mm %238, %237 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
  %240 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
  %241 = torch.aten.add.Tensor %240, %239, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
  %242 = torch.prim.ListConstruct %int1, %int128, %int32 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %243 = torch.aten.view %241, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %244 = torch.aten.add.Tensor %243, %155, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %245 = torch.aten.sum.dim_IntList %244, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %246 = torch.aten.div.Scalar %245, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %247 = torch.aten.sub.Tensor %244, %246, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %248 = torch.aten.pow.Tensor_Scalar %247, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %249 = torch.aten.sum.dim_IntList %248, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %250 = torch.aten.div.Scalar %249, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %251 = torch.aten.sub.Tensor %244, %246, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %252 = torch.aten.add.Scalar %250, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %253 = torch.aten.sqrt %252 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
  %254 = torch.aten.div.Tensor %251, %253 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
  %255 = torch.aten.mul.Tensor %57, %254 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
  %256 = torch.aten.add.Tensor %255, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %257 = torch.aten.transpose.int %47, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
  %258 = torch.aten.view %256, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %259 = torch.aten.mm %258, %257 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
  %260 = torch.aten.mul.Scalar %45, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
  %261 = torch.aten.add.Tensor %260, %259, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
  %262 = torch.prim.ListConstruct %int1, %int128, %int37 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %263 = torch.aten.view %261, %262 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
  %264 = torch.aten.gelu %263, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
  %265 = torch.aten.transpose.int %43, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
  %266 = torch.prim.ListConstruct %int128, %int37 : (!torch.int, !torch.int) -> !torch.list<int>
  %267 = torch.aten.view %264, %266 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
  %268 = torch.aten.mm %267, %265 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
  %269 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
  %270 = torch.aten.add.Tensor %269, %268, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
  %271 = torch.aten.view %270, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %272 = torch.aten.add.Tensor %271, %256, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %273 = torch.aten.sum.dim_IntList %272, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %274 = torch.aten.div.Scalar %273, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %275 = torch.aten.sub.Tensor %272, %274, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %276 = torch.aten.pow.Tensor_Scalar %275, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %277 = torch.aten.sum.dim_IntList %276, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %278 = torch.aten.div.Scalar %277, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %279 = torch.aten.sub.Tensor %272, %274, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %280 = torch.aten.add.Scalar %278, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %281 = torch.aten.sqrt %280 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
  %282 = torch.aten.div.Tensor %279, %281 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
  %283 = torch.aten.mul.Tensor %57, %282 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
  %284 = torch.aten.add.Tensor %283, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %285 = torch.aten.transpose.int %41, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
  %286 = torch.aten.view %284, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %287 = torch.aten.mm %286, %285 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
  %288 = torch.aten.view %287, %177 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
  %289 = torch.aten.view %288, %179 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
  %290 = torch.aten.permute %289, %181 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
  %291 = torch.aten.slice.Tensor %290, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %292 = torch.aten.slice.Tensor %290, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %293 = torch.aten.slice.Tensor %290, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %294 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
  %295 = torch.aten.unsqueeze %294, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %296 = torch.aten.slice.Tensor %295, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %297 = torch.aten.view %296, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
  %298 = torch.aten.permute %297, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
  %299 = torch.aten.add.Tensor %291, %298, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %300 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
  %301 = torch.aten.unsqueeze %300, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %302 = torch.aten.slice.Tensor %301, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %303 = torch.aten.view %302, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
  %304 = torch.aten.permute %303, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
  %305 = torch.aten.add.Tensor %293, %304, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %306 = torch.aten.div.Scalar %299, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
  %307 = torch.aten.transpose.int %292, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
  %308 = torch.aten.broadcast_to %306, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %309 = torch.aten.view %308, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
  %310 = torch.aten.broadcast_to %307, %205 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
  %311 = torch.aten.view %310, %207 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
  %312 = torch.aten.bmm %309, %311 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
  %313 = torch.aten.view %312, %210 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
  %314 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
  %315 = torch.aten.to.dtype %314, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
  %316 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %317 = torch.aten.broadcast_to %315, %316 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
  %318 = torch.aten.copy %317, %172, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
  %319 = torch.aten.bitwise_not %318 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
  %320 = torch.aten.clone %51, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
  %321 = torch.aten.masked_fill.Tensor %313, %319, %320 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %values_0, %indices_1 = torch.aten.max.dim %321, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
  %322 = torch.aten.sub.Tensor %321, %values_0, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
  %323 = torch.aten.exp %322 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %324 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
  %325 = torch.aten.sum.dim_IntList %323, %324, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
  %326 = torch.aten.div.Tensor %323, %325 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %327 = torch.aten.masked_fill.Scalar %326, %319, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
  %328 = torch.aten.broadcast_to %327, %210 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
  %329 = torch.aten.view %328, %227 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
  %330 = torch.aten.broadcast_to %305, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %331 = torch.aten.view %330, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
  %332 = torch.aten.bmm %329, %331 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
  %333 = torch.aten.view %332, %201 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %334 = torch.aten.permute %333, %181 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
  %335 = torch.aten.clone %334, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
  %336 = torch.aten.view %335, %235 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %337 = torch.aten.transpose.int %39, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
  %338 = torch.aten.view %336, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %339 = torch.aten.mm %338, %337 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
  %340 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
  %341 = torch.aten.add.Tensor %340, %339, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
  %342 = torch.aten.view %341, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %343 = torch.aten.add.Tensor %342, %284, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %344 = torch.aten.sum.dim_IntList %343, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %345 = torch.aten.div.Scalar %344, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %346 = torch.aten.sub.Tensor %343, %345, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %347 = torch.aten.pow.Tensor_Scalar %346, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %348 = torch.aten.sum.dim_IntList %347, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %349 = torch.aten.div.Scalar %348, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %350 = torch.aten.sub.Tensor %343, %345, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %351 = torch.aten.add.Scalar %349, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %352 = torch.aten.sqrt %351 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
  %353 = torch.aten.div.Tensor %350, %352 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
  %354 = torch.aten.mul.Tensor %57, %353 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
  %355 = torch.aten.add.Tensor %354, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %356 = torch.aten.transpose.int %37, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
  %357 = torch.aten.view %355, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %358 = torch.aten.mm %357, %356 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
  %359 = torch.aten.mul.Scalar %45, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
  %360 = torch.aten.add.Tensor %359, %358, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
  %361 = torch.aten.view %360, %262 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
  %362 = torch.aten.gelu %361, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
  %363 = torch.aten.transpose.int %35, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
  %364 = torch.aten.view %362, %266 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
  %365 = torch.aten.mm %364, %363 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
  %366 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
  %367 = torch.aten.add.Tensor %366, %365, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
  %368 = torch.aten.view %367, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %369 = torch.aten.add.Tensor %368, %355, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %370 = torch.aten.sum.dim_IntList %369, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %371 = torch.aten.div.Scalar %370, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %372 = torch.aten.sub.Tensor %369, %371, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %373 = torch.aten.pow.Tensor_Scalar %372, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %374 = torch.aten.sum.dim_IntList %373, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %375 = torch.aten.div.Scalar %374, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %376 = torch.aten.sub.Tensor %369, %371, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %377 = torch.aten.add.Scalar %375, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %378 = torch.aten.sqrt %377 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
  %379 = torch.aten.div.Tensor %376, %378 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
  %380 = torch.aten.mul.Tensor %57, %379 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
  %381 = torch.aten.add.Tensor %380, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %382 = torch.aten.transpose.int %33, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
  %383 = torch.aten.view %381, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %384 = torch.aten.mm %383, %382 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
  %385 = torch.aten.view %384, %177 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
  %386 = torch.aten.view %385, %179 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
  %387 = torch.aten.permute %386, %181 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
  %388 = torch.aten.slice.Tensor %387, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %389 = torch.aten.slice.Tensor %387, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %390 = torch.aten.slice.Tensor %387, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %391 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
  %392 = torch.aten.unsqueeze %391, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %393 = torch.aten.slice.Tensor %392, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %394 = torch.aten.view %393, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
  %395 = torch.aten.permute %394, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
  %396 = torch.aten.add.Tensor %388, %395, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %397 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
  %398 = torch.aten.unsqueeze %397, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %399 = torch.aten.slice.Tensor %398, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %400 = torch.aten.view %399, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
  %401 = torch.aten.permute %400, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
  %402 = torch.aten.add.Tensor %390, %401, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %403 = torch.aten.div.Scalar %396, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
  %404 = torch.aten.transpose.int %389, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
  %405 = torch.aten.broadcast_to %403, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %406 = torch.aten.view %405, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
  %407 = torch.aten.broadcast_to %404, %205 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
  %408 = torch.aten.view %407, %207 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
  %409 = torch.aten.bmm %406, %408 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
  %410 = torch.aten.view %409, %210 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
  %411 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
  %412 = torch.aten.to.dtype %411, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
  %413 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %414 = torch.aten.broadcast_to %412, %413 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
  %415 = torch.aten.copy %414, %172, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
  %416 = torch.aten.bitwise_not %415 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
  %417 = torch.aten.clone %51, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
  %418 = torch.aten.masked_fill.Tensor %410, %416, %417 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %values_2, %indices_3 = torch.aten.max.dim %418, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
  %419 = torch.aten.sub.Tensor %418, %values_2, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
  %420 = torch.aten.exp %419 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %421 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
  %422 = torch.aten.sum.dim_IntList %420, %421, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
  %423 = torch.aten.div.Tensor %420, %422 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %424 = torch.aten.masked_fill.Scalar %423, %416, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
  %425 = torch.aten.broadcast_to %424, %210 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
  %426 = torch.aten.view %425, %227 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
  %427 = torch.aten.broadcast_to %402, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %428 = torch.aten.view %427, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
  %429 = torch.aten.bmm %426, %428 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
  %430 = torch.aten.view %429, %201 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %431 = torch.aten.permute %430, %181 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
  %432 = torch.aten.clone %431, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
  %433 = torch.aten.view %432, %235 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %434 = torch.aten.transpose.int %31, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
  %435 = torch.aten.view %433, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %436 = torch.aten.mm %435, %434 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
  %437 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
  %438 = torch.aten.add.Tensor %437, %436, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
  %439 = torch.aten.view %438, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %440 = torch.aten.add.Tensor %439, %381, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %441 = torch.aten.sum.dim_IntList %440, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %442 = torch.aten.div.Scalar %441, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %443 = torch.aten.sub.Tensor %440, %442, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %444 = torch.aten.pow.Tensor_Scalar %443, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %445 = torch.aten.sum.dim_IntList %444, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %446 = torch.aten.div.Scalar %445, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %447 = torch.aten.sub.Tensor %440, %442, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %448 = torch.aten.add.Scalar %446, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %449 = torch.aten.sqrt %448 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
  %450 = torch.aten.div.Tensor %447, %449 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
  %451 = torch.aten.mul.Tensor %57, %450 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
  %452 = torch.aten.add.Tensor %451, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %453 = torch.aten.transpose.int %29, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
  %454 = torch.aten.view %452, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %455 = torch.aten.mm %454, %453 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
  %456 = torch.aten.mul.Scalar %45, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
  %457 = torch.aten.add.Tensor %456, %455, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
  %458 = torch.aten.view %457, %262 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
  %459 = torch.aten.gelu %458, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
  %460 = torch.aten.transpose.int %27, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
  %461 = torch.aten.view %459, %266 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
  %462 = torch.aten.mm %461, %460 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
  %463 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
  %464 = torch.aten.add.Tensor %463, %462, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
  %465 = torch.aten.view %464, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %466 = torch.aten.add.Tensor %465, %452, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %467 = torch.aten.sum.dim_IntList %466, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %468 = torch.aten.div.Scalar %467, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %469 = torch.aten.sub.Tensor %466, %468, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %470 = torch.aten.pow.Tensor_Scalar %469, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %471 = torch.aten.sum.dim_IntList %470, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %472 = torch.aten.div.Scalar %471, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %473 = torch.aten.sub.Tensor %466, %468, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %474 = torch.aten.add.Scalar %472, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %475 = torch.aten.sqrt %474 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
  %476 = torch.aten.div.Tensor %473, %475 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
  %477 = torch.aten.mul.Tensor %57, %476 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
  %478 = torch.aten.add.Tensor %477, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %479 = torch.aten.transpose.int %25, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
  %480 = torch.aten.view %478, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %481 = torch.aten.mm %480, %479 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
  %482 = torch.aten.view %481, %177 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
  %483 = torch.aten.view %482, %179 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
  %484 = torch.aten.permute %483, %181 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
  %485 = torch.aten.slice.Tensor %484, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %486 = torch.aten.slice.Tensor %484, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %487 = torch.aten.slice.Tensor %484, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %488 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
  %489 = torch.aten.unsqueeze %488, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %490 = torch.aten.slice.Tensor %489, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %491 = torch.aten.view %490, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
  %492 = torch.aten.permute %491, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
  %493 = torch.aten.add.Tensor %485, %492, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %494 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
  %495 = torch.aten.unsqueeze %494, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %496 = torch.aten.slice.Tensor %495, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %497 = torch.aten.view %496, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
  %498 = torch.aten.permute %497, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
  %499 = torch.aten.add.Tensor %487, %498, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %500 = torch.aten.div.Scalar %493, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
  %501 = torch.aten.transpose.int %486, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
  %502 = torch.aten.broadcast_to %500, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %503 = torch.aten.view %502, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
  %504 = torch.aten.broadcast_to %501, %205 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
  %505 = torch.aten.view %504, %207 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
  %506 = torch.aten.bmm %503, %505 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
  %507 = torch.aten.view %506, %210 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
  %508 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
  %509 = torch.aten.to.dtype %508, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
  %510 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %511 = torch.aten.broadcast_to %509, %510 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
  %512 = torch.aten.copy %511, %172, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
  %513 = torch.aten.bitwise_not %512 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
  %514 = torch.aten.clone %51, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
  %515 = torch.aten.masked_fill.Tensor %507, %513, %514 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %values_4, %indices_5 = torch.aten.max.dim %515, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
  %516 = torch.aten.sub.Tensor %515, %values_4, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
  %517 = torch.aten.exp %516 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %518 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
  %519 = torch.aten.sum.dim_IntList %517, %518, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
  %520 = torch.aten.div.Tensor %517, %519 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %521 = torch.aten.masked_fill.Scalar %520, %513, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
  %522 = torch.aten.broadcast_to %521, %210 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
  %523 = torch.aten.view %522, %227 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
  %524 = torch.aten.broadcast_to %499, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %525 = torch.aten.view %524, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
  %526 = torch.aten.bmm %523, %525 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
  %527 = torch.aten.view %526, %201 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %528 = torch.aten.permute %527, %181 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
  %529 = torch.aten.clone %528, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
  %530 = torch.aten.view %529, %235 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %531 = torch.aten.transpose.int %23, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
  %532 = torch.aten.view %530, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %533 = torch.aten.mm %532, %531 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
  %534 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
  %535 = torch.aten.add.Tensor %534, %533, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
  %536 = torch.aten.view %535, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %537 = torch.aten.add.Tensor %536, %478, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %538 = torch.aten.sum.dim_IntList %537, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %539 = torch.aten.div.Scalar %538, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %540 = torch.aten.sub.Tensor %537, %539, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %541 = torch.aten.pow.Tensor_Scalar %540, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %542 = torch.aten.sum.dim_IntList %541, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %543 = torch.aten.div.Scalar %542, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %544 = torch.aten.sub.Tensor %537, %539, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %545 = torch.aten.add.Scalar %543, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %546 = torch.aten.sqrt %545 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
  %547 = torch.aten.div.Tensor %544, %546 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
  %548 = torch.aten.mul.Tensor %57, %547 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
  %549 = torch.aten.add.Tensor %548, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %550 = torch.aten.transpose.int %21, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
  %551 = torch.aten.view %549, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %552 = torch.aten.mm %551, %550 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
  %553 = torch.aten.mul.Scalar %45, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
  %554 = torch.aten.add.Tensor %553, %552, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
  %555 = torch.aten.view %554, %262 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
  %556 = torch.aten.gelu %555, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
  %557 = torch.aten.transpose.int %19, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
  %558 = torch.aten.view %556, %266 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
  %559 = torch.aten.mm %558, %557 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
  %560 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
  %561 = torch.aten.add.Tensor %560, %559, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
  %562 = torch.aten.view %561, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %563 = torch.aten.add.Tensor %562, %549, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %564 = torch.aten.sum.dim_IntList %563, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %565 = torch.aten.div.Scalar %564, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %566 = torch.aten.sub.Tensor %563, %565, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %567 = torch.aten.pow.Tensor_Scalar %566, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %568 = torch.aten.sum.dim_IntList %567, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %569 = torch.aten.div.Scalar %568, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %570 = torch.aten.sub.Tensor %563, %565, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %571 = torch.aten.add.Scalar %569, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %572 = torch.aten.sqrt %571 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
  %573 = torch.aten.div.Tensor %570, %572 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
  %574 = torch.aten.mul.Tensor %57, %573 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
  %575 = torch.aten.add.Tensor %574, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %576 = torch.aten.transpose.int %17, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
  %577 = torch.aten.view %575, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %578 = torch.aten.mm %577, %576 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
  %579 = torch.aten.view %578, %177 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
  %580 = torch.aten.view %579, %179 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
  %581 = torch.aten.permute %580, %181 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
  %582 = torch.aten.slice.Tensor %581, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %583 = torch.aten.slice.Tensor %581, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %584 = torch.aten.slice.Tensor %581, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %585 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
  %586 = torch.aten.unsqueeze %585, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %587 = torch.aten.slice.Tensor %586, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %588 = torch.aten.view %587, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
  %589 = torch.aten.permute %588, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
  %590 = torch.aten.add.Tensor %582, %589, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %591 = torch.aten.unsqueeze %55, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
  %592 = torch.aten.unsqueeze %591, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %593 = torch.aten.slice.Tensor %592, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %594 = torch.aten.view %593, %189 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
  %595 = torch.aten.permute %594, %181 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
  %596 = torch.aten.add.Tensor %584, %595, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
  %597 = torch.aten.div.Scalar %590, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
  %598 = torch.aten.transpose.int %583, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
  %599 = torch.aten.broadcast_to %597, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %600 = torch.aten.view %599, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
  %601 = torch.aten.broadcast_to %598, %205 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
  %602 = torch.aten.view %601, %207 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
  %603 = torch.aten.bmm %600, %602 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
  %604 = torch.aten.view %603, %210 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
  %605 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
  %606 = torch.aten.to.dtype %605, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
  %607 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %608 = torch.aten.broadcast_to %606, %607 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
  %609 = torch.aten.copy %608, %172, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
  %610 = torch.aten.bitwise_not %609 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
  %611 = torch.aten.clone %51, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
  %612 = torch.aten.masked_fill.Tensor %604, %610, %611 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %values_6, %indices_7 = torch.aten.max.dim %612, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
  %613 = torch.aten.sub.Tensor %612, %values_6, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
  %614 = torch.aten.exp %613 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %615 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
  %616 = torch.aten.sum.dim_IntList %614, %615, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
  %617 = torch.aten.div.Tensor %614, %616 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
  %618 = torch.aten.masked_fill.Scalar %617, %610, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
  %619 = torch.aten.broadcast_to %618, %210 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
  %620 = torch.aten.view %619, %227 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
  %621 = torch.aten.broadcast_to %596, %201 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %622 = torch.aten.view %621, %203 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
  %623 = torch.aten.bmm %620, %622 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
  %624 = torch.aten.view %623, %201 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
  %625 = torch.aten.permute %624, %181 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
  %626 = torch.aten.clone %625, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
  %627 = torch.aten.view %626, %235 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %628 = torch.aten.transpose.int %15, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
  %629 = torch.aten.view %627, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %630 = torch.aten.mm %629, %628 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
  %631 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
  %632 = torch.aten.add.Tensor %631, %630, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
  %633 = torch.aten.view %632, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %634 = torch.aten.add.Tensor %633, %575, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %635 = torch.aten.sum.dim_IntList %634, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %636 = torch.aten.div.Scalar %635, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %637 = torch.aten.sub.Tensor %634, %636, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %638 = torch.aten.pow.Tensor_Scalar %637, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %639 = torch.aten.sum.dim_IntList %638, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %640 = torch.aten.div.Scalar %639, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %641 = torch.aten.sub.Tensor %634, %636, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %642 = torch.aten.add.Scalar %640, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %643 = torch.aten.sqrt %642 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
  %644 = torch.aten.div.Tensor %641, %643 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
  %645 = torch.aten.mul.Tensor %57, %644 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
  %646 = torch.aten.add.Tensor %645, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %647 = torch.aten.transpose.int %13, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
  %648 = torch.aten.view %646, %174 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
  %649 = torch.aten.mm %648, %647 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
  %650 = torch.aten.mul.Scalar %45, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
  %651 = torch.aten.add.Tensor %650, %649, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
  %652 = torch.aten.view %651, %262 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
  %653 = torch.aten.gelu %652, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
  %654 = torch.aten.transpose.int %11, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
  %655 = torch.aten.view %653, %266 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
  %656 = torch.aten.mm %655, %654 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
  %657 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
  %658 = torch.aten.add.Tensor %657, %656, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
  %659 = torch.aten.view %658, %242 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
  %660 = torch.aten.add.Tensor %659, %646, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %661 = torch.aten.sum.dim_IntList %660, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %662 = torch.aten.div.Scalar %661, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %663 = torch.aten.sub.Tensor %660, %662, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %664 = torch.aten.pow.Tensor_Scalar %663, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %665 = torch.aten.sum.dim_IntList %664, %112, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
  %666 = torch.aten.div.Scalar %665, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %667 = torch.aten.sub.Tensor %660, %662, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %668 = torch.aten.add.Scalar %666, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
  %669 = torch.aten.sqrt %668 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
  %670 = torch.aten.div.Tensor %667, %669 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
  %671 = torch.aten.mul.Tensor %57, %670 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
  %672 = torch.aten.add.Tensor %671, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %673 = torch.aten.slice.Tensor %672, %int0, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32],f32>
  %674 = torch.aten.slice.Tensor %673, %int1, %int0, %int1, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
  %675 = torch.aten.squeeze.dim %674, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
  %676 = torch.aten.transpose.int %9, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
  %677 = torch.aten.mm %675, %676 : !torch.vtensor<[1,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1,32],f32>
  %678 = torch.aten.mul.Scalar %55, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
  %679 = torch.aten.add.Tensor %678, %677, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
  %680 = torch.aten.gelu %679, %str : !torch.vtensor<[1,32],f32>, !torch.str -> !torch.vtensor<[1,32],f32>
  %681 = torch.aten.transpose.int %7, %int0, %int1 : !torch.vtensor<[2,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,2],f32>
  %682 = torch.aten.mm %680, %681 : !torch.vtensor<[1,32],f32>, !torch.vtensor<[32,2],f32> -> !torch.vtensor<[1,2],f32>
  %683 = torch.aten.mul.Scalar %5, %int1 : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32>
  %684 = torch.aten.add.Tensor %683, %682, %int1 : !torch.vtensor<[2],f32>, !torch.vtensor<[1,2],f32>, !torch.int -> !torch.vtensor<[1,2],f32>
  return %684 : !torch.vtensor<[1,2],f32>
}

@AmosLewis
Copy link
Author

AmosLewis commented Nov 15, 2022

Error when legalizing the `torch.aten.copy` op:

Legalizing operation : 'torch.aten.copy'(0x996e240) {
  %197 = "torch.aten.copy"(%194, %190, %70) : (!torch.vtensor<[],si8>, !torch.vtensor<[1,1,128,128],f32>, !torch.bool) -> !torch.vtensor<[1,1,128,128],si8>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.copy -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenCopyOp>"
    ** Failure : casting to result dtype is invalid or unsupported
    ** Failure : unimplemented: cast to result type not supported
"(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenCopyOp>" result 0
  } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
<eval_with_key>.2:38:15: error: failed to legalize operation 'torch.aten.copy' that was explicitly marked illegal
<eval_with_key>.2:38:15: note: see current operation: %197 = "torch.aten.copy"(%194, %190, %70) : (!torch.vtensor<[],si8>, !torch.vtensor<[1,1,128,128],f32>, !torch.bool) -> !torch.vtensor<[1,1,128,128],si8>

Fixed by adding llvm/torch-mlir#1592.

@AmosLewis
Copy link
Author

All fixed.
Here is the output: deberta_tosa.mlir

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment