Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created November 10, 2022 06:06
Show Gist options
  • Save AmosLewis/1c654721fcf9a974484bef178e8d449d to your computer and use it in GitHub Desktop.
Save AmosLewis/1c654721fcf9a974484bef178e8d449d to your computer and use it in GitHub Desktop.
➜ SHARK git:(main) ✗ python tank/pytorch/deberta/deberta_tosa.py
Some weights of the model checkpoint at hf-internal-testing/tiny-random-deberta were not used when initializing DebertaForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'qa_outputs.bias', 'cls.predictions.transform.LayerNorm.weight', 'qa_outputs.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
model(test_input):
tensor([[-0.0038, -0.0069]], grad_fn=<AddmmBackward0>)
/home/chi/src/ubuntu20/shark/SHARK/shark.venv/lib/python3.10/site-packages/torch/jit/_check.py:181: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
warnings.warn("The TorchScript type system doesn't support "
/home/chi/src/ubuntu20/shark/SHARK/shark.venv/lib/python3.10/site-packages/torch/jit/_trace.py:744: UserWarning: The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is.
warnings.warn(
Traceback (most recent call last):
File "/home/chi/src/ubuntu20/shark/SHARK/tank/pytorch/deberta/deberta_tosa.py", line 89, in <module>
module = torch_mlir.compile(
File "/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 289, in compile
run_pipeline_with_repro_report(
File "/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 73, in run_pipeline_with_repro_report
raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Torch Backend IR -> TOSA Backend IR failed with the following diagnostics:
error: failed to legalize operation 'torch.aten.unsqueeze' that was explicitly marked illegal
note: see current operation: %175 = "torch.aten.unsqueeze"(%101, %84) : (!torch.vtensor<[1,128],f32>, !torch.int) -> !torch.vtensor<[1,128,1],f32>
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='torch-backend-to-tosa-backend-pipeline' /tmp/_lambda.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
➜ SHARK git:(main) ✗ torch-mlir-opt -pass-pipeline='torch-backend-to-tosa-backend-pipeline' /tmp/deberta1.mlir -mlir-print-ir-after-all -mlir-disable-threading
<eval_with_key>.2:31:16: error: failed to legalize operation 'torch.aten.unsqueeze' that was explicitly marked illegal
<eval_with_key>.2:31:16: note: see current operation: %175 = "torch.aten.unsqueeze"(%101, %84) : (!torch.vtensor<[1,128],f32>, !torch.int) -> !torch.vtensor<[1,128,1],f32>
// -----// IR Dump After ConvertTorchToTosa Failed (convert-torch-to-tosa) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,128],si64>) -> !torch.vtensor<[1,2],f32> {
%int1 = torch.constant.int 1
%int32 = torch.constant.int 32
%int128 = torch.constant.int 128
%float1.000000e00 = torch.constant.float 1.000000e+00
%0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<2xf32>) : !torch.vtensor<[2],f32>
%1 = torch.vtensor.literal(dense_resource<__elided__> : tensor<2x32xf32>) : !torch.vtensor<[2,32],f32>
%2 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%3 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%4 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%5 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%6 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%7 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%8 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%9 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%10 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%11 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%12 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%13 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%14 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%15 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%16 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%17 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%18 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%19 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%20 = torch.vtensor.literal(dense<0.000000e+00> : tensor<37xf32>) : !torch.vtensor<[37],f32>
%21 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%22 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%23 = torch.vtensor.literal(dense<-3.40282347E+38> : tensor<f32>) : !torch.vtensor<[],f32>
%24 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%25 = torch.vtensor.literal(dense<0.000000e+00> : tensor<32xf32>) : !torch.vtensor<[32],f32>
%26 = torch.vtensor.literal(dense<1.000000e+00> : tensor<32xf32>) : !torch.vtensor<[32],f32>
%27 = torch.vtensor.literal(dense_resource<__elided__> : tensor<16x32xf32>) : !torch.vtensor<[16,32],f32>
%28 = torch.vtensor.literal(dense_resource<__elided__> : tensor<512x32xf32>) : !torch.vtensor<[512,32],f32>
%29 = torch.vtensor.literal(dense_resource<__elided__> : tensor<1000x32xf32>) : !torch.vtensor<[1000,32],f32>
%30 = torch.vtensor.literal(dense_resource<__elided__> : tensor<1x512xsi64>) : !torch.vtensor<[1,512],si64>
%false = torch.constant.bool false
%none = torch.constant.none
%int-1 = torch.constant.int -1
%true = torch.constant.bool true
%int-2 = torch.constant.int -2
%int11 = torch.constant.int 11
%str = torch.constant.str "none"
%int0 = torch.constant.int 0
%int9223372036854775807 = torch.constant.int 9223372036854775807
%int2 = torch.constant.int 2
%float9.999990e-08 = torch.constant.float 9.9999999999999995E-8
%int96 = torch.constant.int 96
%int4 = torch.constant.int 4
%int3 = torch.constant.int 3
%int8 = torch.constant.int 8
%int16 = torch.constant.int 16
%int24 = torch.constant.int 24
%float4.000000e00 = torch.constant.float 4.000000e+00
%int37 = torch.constant.int 37
%cpu = torch.constant.device "cpu"
%31 = torch.prim.ListConstruct %int1, %int128 : (!torch.int, !torch.int) -> !torch.list<int>
%32 = torch.aten.ones %31, %none, %none, %cpu, %false : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1,128],f32>
%33 = torch.aten.zeros %31, %int4, %none, %cpu, %false : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1,128],si64>
%34 = torch.aten.slice.Tensor %30, %int0, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,512],si64>
%35 = torch.aten.slice.Tensor %34, %int1, %int0, %int128, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128],si64>
%36 = torch.aten.embedding %29, %arg0, %int0, %false, %false : !torch.vtensor<[1000,32],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,32],f32>
%37 = torch.aten.embedding %28, %35, %int-1, %false, %false : !torch.vtensor<[512,32],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,32],f32>
%38 = torch.aten.add.Tensor %36, %37, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%39 = torch.aten.embedding %27, %33, %int-1, %false, %false : !torch.vtensor<[16,32],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,32],f32>
%40 = torch.aten.add.Tensor %38, %39, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%41 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%42 = torch.aten.sum.dim_IntList %40, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%43 = torch.aten.div.Scalar %42, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%44 = torch.aten.sub.Tensor %40, %43, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%45 = torch.aten.pow.Tensor_Scalar %44, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%46 = torch.aten.sum.dim_IntList %45, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%47 = torch.aten.div.Scalar %46, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%48 = torch.aten.sub.Tensor %40, %43, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%49 = torch.aten.add.Scalar %47, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%50 = torch.aten.sqrt %49 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%51 = torch.aten.div.Tensor %48, %50 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%52 = torch.aten.mul.Tensor %26, %51 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%53 = torch.aten.add.Tensor %52, %25, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%54 = torch.aten.unsqueeze %32, %int2 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%55 = torch.aten.mul.Tensor %53, %54 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%56 = torch.aten.unsqueeze %32, %int1 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128],f32>
%57 = torch.aten.unsqueeze %56, %int2 : !torch.vtensor<[1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,1,128],f32>
%58 = torch.aten.squeeze.dim %57, %int-2 : !torch.vtensor<[1,1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128],f32>
%59 = torch.aten.unsqueeze %58, %int-1 : !torch.vtensor<[1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128,1],f32>
%60 = torch.aten.mul.Tensor %57, %59 : !torch.vtensor<[1,1,1,128],f32>, !torch.vtensor<[1,1,128,1],f32> -> !torch.vtensor<[1,1,128,128],f32>
%61 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%62 = torch.aten.to.dtype %61, %int1, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],si8>
%63 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%64 = torch.aten.broadcast_to %62, %63 : !torch.vtensor<[],si8>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],si8>
%65 = torch.aten.copy %64, %60, %false : !torch.vtensor<[1,1,128,128],si8>, !torch.vtensor<[1,1,128,128],f32>, !torch.bool -> !torch.vtensor<[1,1,128,128],si8>
%66 = torch.aten.transpose.int %24, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
%67 = torch.prim.ListConstruct %int128, %int32 : (!torch.int, !torch.int) -> !torch.list<int>
%68 = torch.aten.view %55, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%69 = torch.aten.mm %68, %66 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
%70 = torch.prim.ListConstruct %int1, %int128, %int96 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%71 = torch.aten.view %69, %70 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
%72 = torch.prim.ListConstruct %int1, %int128, %int4, %int-1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%73 = torch.aten.view %71, %72 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
%74 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%75 = torch.aten.permute %73, %74 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
%76 = torch.aten.slice.Tensor %75, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%77 = torch.aten.slice.Tensor %75, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%78 = torch.aten.slice.Tensor %75, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%79 = torch.aten.unsqueeze %25, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%80 = torch.aten.unsqueeze %79, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%81 = torch.aten.slice.Tensor %80, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%82 = torch.prim.ListConstruct %int1, %int1, %int4, %int-1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%83 = torch.aten.view %81, %82 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%84 = torch.aten.permute %83, %74 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%85 = torch.aten.add.Tensor %76, %84, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%86 = torch.aten.unsqueeze %25, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%87 = torch.aten.unsqueeze %86, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%88 = torch.aten.slice.Tensor %87, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%89 = torch.aten.view %88, %82 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%90 = torch.aten.permute %89, %74 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%91 = torch.aten.add.Tensor %78, %90, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%92 = torch.aten.div.Scalar %85, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
%93 = torch.aten.transpose.int %77, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
%94 = torch.prim.ListConstruct %int1, %int4, %int128, %int8 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%95 = torch.aten.broadcast_to %92, %94 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%96 = torch.prim.ListConstruct %int4, %int128, %int8 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%97 = torch.aten.view %95, %96 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%98 = torch.prim.ListConstruct %int1, %int4, %int8, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%99 = torch.aten.broadcast_to %93, %98 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
%100 = torch.prim.ListConstruct %int4, %int8, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%101 = torch.aten.view %99, %100 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
%102 = torch.aten.bmm %97, %101 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
%103 = torch.prim.ListConstruct %int1, %int4, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%104 = torch.aten.view %102, %103 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%105 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%106 = torch.aten.to.dtype %105, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
%107 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%108 = torch.aten.broadcast_to %106, %107 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
%109 = torch.aten.copy %108, %65, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
%110 = torch.aten.bitwise_not %109 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
%111 = torch.aten.clone %23, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
%112 = torch.aten.masked_fill.Tensor %104, %110, %111 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
%values, %indices = torch.aten.max.dim %112, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
%113 = torch.aten.sub.Tensor %112, %values, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
%114 = torch.aten.exp %113 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
%115 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%116 = torch.aten.sum.dim_IntList %114, %115, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
%117 = torch.aten.div.Tensor %114, %116 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
%118 = torch.aten.masked_fill.Scalar %117, %110, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
%119 = torch.aten.broadcast_to %118, %103 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%120 = torch.prim.ListConstruct %int4, %int128, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%121 = torch.aten.view %119, %120 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
%122 = torch.aten.broadcast_to %91, %94 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%123 = torch.aten.view %122, %96 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%124 = torch.aten.bmm %121, %123 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
%125 = torch.aten.view %124, %94 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%126 = torch.aten.permute %125, %74 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
%127 = torch.aten.clone %126, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
%128 = torch.prim.ListConstruct %int1, %int128, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%129 = torch.aten.view %127, %128 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%130 = torch.aten.transpose.int %22, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
%131 = torch.aten.view %129, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%132 = torch.aten.mm %131, %130 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
%133 = torch.aten.mul.Scalar %25, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%134 = torch.aten.add.Tensor %133, %132, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%135 = torch.prim.ListConstruct %int1, %int128, %int32 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%136 = torch.aten.view %134, %135 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%137 = torch.aten.add.Tensor %136, %55, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%138 = torch.aten.sum.dim_IntList %137, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%139 = torch.aten.div.Scalar %138, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%140 = torch.aten.sub.Tensor %137, %139, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%141 = torch.aten.pow.Tensor_Scalar %140, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%142 = torch.aten.sum.dim_IntList %141, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%143 = torch.aten.div.Scalar %142, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%144 = torch.aten.sub.Tensor %137, %139, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%145 = torch.aten.add.Scalar %143, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%146 = torch.aten.sqrt %145 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%147 = torch.aten.div.Tensor %144, %146 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%148 = torch.aten.mul.Tensor %26, %147 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%149 = torch.aten.add.Tensor %148, %25, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%150 = torch.aten.transpose.int %21, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
%151 = torch.aten.view %149, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%152 = torch.aten.mm %151, %150 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
%153 = torch.aten.mul.Scalar %20, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
%154 = torch.aten.add.Tensor %153, %152, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
%155 = torch.prim.ListConstruct %int1, %int128, %int37 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%156 = torch.aten.view %154, %155 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
%157 = torch.aten.gelu %156, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
%158 = torch.aten.transpose.int %19, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
%159 = torch.prim.ListConstruct %int128, %int37 : (!torch.int, !torch.int) -> !torch.list<int>
%160 = torch.aten.view %157, %159 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
%161 = torch.aten.mm %160, %158 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
%162 = torch.aten.mul.Scalar %25, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%163 = torch.aten.add.Tensor %162, %161, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%164 = torch.aten.view %163, %135 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%165 = torch.aten.add.Tensor %164, %149, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%166 = torch.aten.sum.dim_IntList %165, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%167 = torch.aten.div.Scalar %166, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%168 = torch.aten.sub.Tensor %165, %167, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%169 = torch.aten.pow.Tensor_Scalar %168, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%170 = torch.aten.sum.dim_IntList %169, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%171 = torch.aten.div.Scalar %170, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%172 = torch.aten.sub.Tensor %165, %167, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%173 = torch.aten.add.Scalar %171, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%174 = torch.aten.sqrt %173 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%175 = torch.aten.div.Tensor %172, %174 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%176 = torch.aten.mul.Tensor %26, %175 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%177 = torch.aten.add.Tensor %176, %25, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%178 = torch.aten.transpose.int %18, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
%179 = torch.aten.view %177, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%180 = torch.aten.mm %179, %178 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
%181 = torch.aten.view %180, %70 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
%182 = torch.aten.view %181, %72 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
%183 = torch.aten.permute %182, %74 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
%184 = torch.aten.slice.Tensor %183, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%185 = torch.aten.slice.Tensor %183, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%186 = torch.aten.slice.Tensor %183, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%187 = torch.aten.unsqueeze %25, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%188 = torch.aten.unsqueeze %187, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%189 = torch.aten.slice.Tensor %188, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%190 = torch.aten.view %189, %82 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%191 = torch.aten.permute %190, %74 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%192 = torch.aten.add.Tensor %184, %191, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%193 = torch.aten.unsqueeze %25, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%194 = torch.aten.unsqueeze %193, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%195 = torch.aten.slice.Tensor %194, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%196 = torch.aten.view %195, %82 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%197 = torch.aten.permute %196, %74 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%198 = torch.aten.add.Tensor %186, %197, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%199 = torch.aten.div.Scalar %192, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
%200 = torch.aten.transpose.int %185, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
%201 = torch.aten.broadcast_to %199, %94 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%202 = torch.aten.view %201, %96 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%203 = torch.aten.broadcast_to %200, %98 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
%204 = torch.aten.view %203, %100 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
%205 = torch.aten.bmm %202, %204 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
%206 = torch.aten.view %205, %103 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%207 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%208 = torch.aten.to.dtype %207, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
%209 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%210 = torch.aten.broadcast_to %208, %209 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
%211 = torch.aten.copy %210, %65, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
%212 = torch.aten.bitwise_not %211 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
%213 = torch.aten.clone %23, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
%214 = torch.aten.masked_fill.Tensor %206, %212, %213 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
%values_0, %indices_1 = torch.aten.max.dim %214, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
%215 = torch.aten.sub.Tensor %214, %values_0, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
%216 = torch.aten.exp %215 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
%217 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%218 = torch.aten.sum.dim_IntList %216, %217, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
%219 = torch.aten.div.Tensor %216, %218 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
%220 = torch.aten.masked_fill.Scalar %219, %212, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
%221 = torch.aten.broadcast_to %220, %103 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%222 = torch.aten.view %221, %120 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
%223 = torch.aten.broadcast_to %198, %94 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%224 = torch.aten.view %223, %96 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%225 = torch.aten.bmm %222, %224 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
%226 = torch.aten.view %225, %94 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%227 = torch.aten.permute %226, %74 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
%228 = torch.aten.clone %227, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
%229 = torch.aten.view %228, %128 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%230 = torch.aten.transpose.int %17, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
%231 = torch.aten.view %229, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%232 = torch.aten.mm %231, %230 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
%233 = torch.aten.mul.Scalar %25, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%234 = torch.aten.add.Tensor %233, %232, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%235 = torch.aten.view %234, %135 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%236 = torch.aten.add.Tensor %235, %177, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%237 = torch.aten.sum.dim_IntList %236, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%238 = torch.aten.div.Scalar %237, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%239 = torch.aten.sub.Tensor %236, %238, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%240 = torch.aten.pow.Tensor_Scalar %239, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%241 = torch.aten.sum.dim_IntList %240, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%242 = torch.aten.div.Scalar %241, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%243 = torch.aten.sub.Tensor %236, %238, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%244 = torch.aten.add.Scalar %242, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%245 = torch.aten.sqrt %244 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%246 = torch.aten.div.Tensor %243, %245 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%247 = torch.aten.mul.Tensor %26, %246 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%248 = torch.aten.add.Tensor %247, %25, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%249 = torch.aten.transpose.int %16, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
%250 = torch.aten.view %248, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%251 = torch.aten.mm %250, %249 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
%252 = torch.aten.mul.Scalar %20, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
%253 = torch.aten.add.Tensor %252, %251, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
%254 = torch.aten.view %253, %155 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
%255 = torch.aten.gelu %254, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
%256 = torch.aten.transpose.int %15, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
%257 = torch.aten.view %255, %159 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
%258 = torch.aten.mm %257, %256 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
%259 = torch.aten.mul.Scalar %25, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%260 = torch.aten.add.Tensor %259, %258, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%261 = torch.aten.view %260, %135 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%262 = torch.aten.add.Tensor %261, %248, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%263 = torch.aten.sum.dim_IntList %262, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%264 = torch.aten.div.Scalar %263, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%265 = torch.aten.sub.Tensor %262, %264, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%266 = torch.aten.pow.Tensor_Scalar %265, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%267 = torch.aten.sum.dim_IntList %266, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%268 = torch.aten.div.Scalar %267, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%269 = torch.aten.sub.Tensor %262, %264, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%270 = torch.aten.add.Scalar %268, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%271 = torch.aten.sqrt %270 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%272 = torch.aten.div.Tensor %269, %271 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%273 = torch.aten.mul.Tensor %26, %272 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%274 = torch.aten.add.Tensor %273, %25, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%275 = torch.aten.transpose.int %14, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
%276 = torch.aten.view %274, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%277 = torch.aten.mm %276, %275 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
%278 = torch.aten.view %277, %70 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
%279 = torch.aten.view %278, %72 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
%280 = torch.aten.permute %279, %74 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
%281 = torch.aten.slice.Tensor %280, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%282 = torch.aten.slice.Tensor %280, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%283 = torch.aten.slice.Tensor %280, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%284 = torch.aten.unsqueeze %25, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%285 = torch.aten.unsqueeze %284, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%286 = torch.aten.slice.Tensor %285, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%287 = torch.aten.view %286, %82 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%288 = torch.aten.permute %287, %74 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%289 = torch.aten.add.Tensor %281, %288, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%290 = torch.aten.unsqueeze %25, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%291 = torch.aten.unsqueeze %290, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%292 = torch.aten.slice.Tensor %291, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%293 = torch.aten.view %292, %82 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%294 = torch.aten.permute %293, %74 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%295 = torch.aten.add.Tensor %283, %294, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%296 = torch.aten.div.Scalar %289, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
%297 = torch.aten.transpose.int %282, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
%298 = torch.aten.broadcast_to %296, %94 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%299 = torch.aten.view %298, %96 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%300 = torch.aten.broadcast_to %297, %98 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
%301 = torch.aten.view %300, %100 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
%302 = torch.aten.bmm %299, %301 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
%303 = torch.aten.view %302, %103 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%304 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%305 = torch.aten.to.dtype %304, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
%306 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%307 = torch.aten.broadcast_to %305, %306 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
%308 = torch.aten.copy %307, %65, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
%309 = torch.aten.bitwise_not %308 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
%310 = torch.aten.clone %23, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
%311 = torch.aten.masked_fill.Tensor %303, %309, %310 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
%values_2, %indices_3 = torch.aten.max.dim %311, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
%312 = torch.aten.sub.Tensor %311, %values_2, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
%313 = torch.aten.exp %312 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
%314 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%315 = torch.aten.sum.dim_IntList %313, %314, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
%316 = torch.aten.div.Tensor %313, %315 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
%317 = torch.aten.masked_fill.Scalar %316, %309, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
%318 = torch.aten.broadcast_to %317, %103 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%319 = torch.aten.view %318, %120 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
%320 = torch.aten.broadcast_to %295, %94 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%321 = torch.aten.view %320, %96 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%322 = torch.aten.bmm %319, %321 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
%323 = torch.aten.view %322, %94 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%324 = torch.aten.permute %323, %74 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
%325 = torch.aten.clone %324, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
%326 = torch.aten.view %325, %128 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%327 = torch.aten.transpose.int %13, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
%328 = torch.aten.view %326, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%329 = torch.aten.mm %328, %327 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
%330 = torch.aten.mul.Scalar %25, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%331 = torch.aten.add.Tensor %330, %329, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%332 = torch.aten.view %331, %135 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%333 = torch.aten.add.Tensor %332, %274, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%334 = torch.aten.sum.dim_IntList %333, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%335 = torch.aten.div.Scalar %334, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%336 = torch.aten.sub.Tensor %333, %335, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%337 = torch.aten.pow.Tensor_Scalar %336, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%338 = torch.aten.sum.dim_IntList %337, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%339 = torch.aten.div.Scalar %338, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%340 = torch.aten.sub.Tensor %333, %335, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%341 = torch.aten.add.Scalar %339, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%342 = torch.aten.sqrt %341 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%343 = torch.aten.div.Tensor %340, %342 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%344 = torch.aten.mul.Tensor %26, %343 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%345 = torch.aten.add.Tensor %344, %25, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%346 = torch.aten.transpose.int %12, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
%347 = torch.aten.view %345, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%348 = torch.aten.mm %347, %346 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
%349 = torch.aten.mul.Scalar %20, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
%350 = torch.aten.add.Tensor %349, %348, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
%351 = torch.aten.view %350, %155 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
%352 = torch.aten.gelu %351, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
%353 = torch.aten.transpose.int %11, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
%354 = torch.aten.view %352, %159 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
%355 = torch.aten.mm %354, %353 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
%356 = torch.aten.mul.Scalar %25, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%357 = torch.aten.add.Tensor %356, %355, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%358 = torch.aten.view %357, %135 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%359 = torch.aten.add.Tensor %358, %345, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%360 = torch.aten.sum.dim_IntList %359, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%361 = torch.aten.div.Scalar %360, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%362 = torch.aten.sub.Tensor %359, %361, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%363 = torch.aten.pow.Tensor_Scalar %362, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%364 = torch.aten.sum.dim_IntList %363, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%365 = torch.aten.div.Scalar %364, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%366 = torch.aten.sub.Tensor %359, %361, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%367 = torch.aten.add.Scalar %365, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%368 = torch.aten.sqrt %367 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%369 = torch.aten.div.Tensor %366, %368 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%370 = torch.aten.mul.Tensor %26, %369 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%371 = torch.aten.add.Tensor %370, %25, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%372 = torch.aten.transpose.int %10, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
%373 = torch.aten.view %371, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%374 = torch.aten.mm %373, %372 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
%375 = torch.aten.view %374, %70 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
%376 = torch.aten.view %375, %72 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
%377 = torch.aten.permute %376, %74 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
%378 = torch.aten.slice.Tensor %377, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%379 = torch.aten.slice.Tensor %377, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%380 = torch.aten.slice.Tensor %377, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%381 = torch.aten.unsqueeze %25, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%382 = torch.aten.unsqueeze %381, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%383 = torch.aten.slice.Tensor %382, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%384 = torch.aten.view %383, %82 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%385 = torch.aten.permute %384, %74 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%386 = torch.aten.add.Tensor %378, %385, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%387 = torch.aten.unsqueeze %25, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%388 = torch.aten.unsqueeze %387, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%389 = torch.aten.slice.Tensor %388, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%390 = torch.aten.view %389, %82 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%391 = torch.aten.permute %390, %74 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%392 = torch.aten.add.Tensor %380, %391, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%393 = torch.aten.div.Scalar %386, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
%394 = torch.aten.transpose.int %379, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
%395 = torch.aten.broadcast_to %393, %94 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%396 = torch.aten.view %395, %96 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%397 = torch.aten.broadcast_to %394, %98 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
%398 = torch.aten.view %397, %100 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
%399 = torch.aten.bmm %396, %398 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
%400 = torch.aten.view %399, %103 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%401 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%402 = torch.aten.to.dtype %401, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
%403 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%404 = torch.aten.broadcast_to %402, %403 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
%405 = torch.aten.copy %404, %65, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
%406 = torch.aten.bitwise_not %405 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
%407 = torch.aten.clone %23, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
%408 = torch.aten.masked_fill.Tensor %400, %406, %407 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
%values_4, %indices_5 = torch.aten.max.dim %408, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
%409 = torch.aten.sub.Tensor %408, %values_4, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
%410 = torch.aten.exp %409 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
%411 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%412 = torch.aten.sum.dim_IntList %410, %411, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
%413 = torch.aten.div.Tensor %410, %412 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
%414 = torch.aten.masked_fill.Scalar %413, %406, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
%415 = torch.aten.broadcast_to %414, %103 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%416 = torch.aten.view %415, %120 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
%417 = torch.aten.broadcast_to %392, %94 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%418 = torch.aten.view %417, %96 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%419 = torch.aten.bmm %416, %418 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
%420 = torch.aten.view %419, %94 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%421 = torch.aten.permute %420, %74 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
%422 = torch.aten.clone %421, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
%423 = torch.aten.view %422, %128 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%424 = torch.aten.transpose.int %9, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
%425 = torch.aten.view %423, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%426 = torch.aten.mm %425, %424 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
%427 = torch.aten.mul.Scalar %25, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%428 = torch.aten.add.Tensor %427, %426, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%429 = torch.aten.view %428, %135 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%430 = torch.aten.add.Tensor %429, %371, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%431 = torch.aten.sum.dim_IntList %430, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%432 = torch.aten.div.Scalar %431, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%433 = torch.aten.sub.Tensor %430, %432, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%434 = torch.aten.pow.Tensor_Scalar %433, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%435 = torch.aten.sum.dim_IntList %434, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%436 = torch.aten.div.Scalar %435, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%437 = torch.aten.sub.Tensor %430, %432, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%438 = torch.aten.add.Scalar %436, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%439 = torch.aten.sqrt %438 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%440 = torch.aten.div.Tensor %437, %439 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%441 = torch.aten.mul.Tensor %26, %440 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%442 = torch.aten.add.Tensor %441, %25, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%443 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
%444 = torch.aten.view %442, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%445 = torch.aten.mm %444, %443 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
%446 = torch.aten.mul.Scalar %20, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
%447 = torch.aten.add.Tensor %446, %445, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
%448 = torch.aten.view %447, %155 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
%449 = torch.aten.gelu %448, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
%450 = torch.aten.transpose.int %7, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
%451 = torch.aten.view %449, %159 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
%452 = torch.aten.mm %451, %450 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
%453 = torch.aten.mul.Scalar %25, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%454 = torch.aten.add.Tensor %453, %452, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%455 = torch.aten.view %454, %135 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%456 = torch.aten.add.Tensor %455, %442, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%457 = torch.aten.sum.dim_IntList %456, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%458 = torch.aten.div.Scalar %457, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%459 = torch.aten.sub.Tensor %456, %458, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%460 = torch.aten.pow.Tensor_Scalar %459, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%461 = torch.aten.sum.dim_IntList %460, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%462 = torch.aten.div.Scalar %461, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%463 = torch.aten.sub.Tensor %456, %458, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%464 = torch.aten.add.Scalar %462, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%465 = torch.aten.sqrt %464 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%466 = torch.aten.div.Tensor %463, %465 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%467 = torch.aten.mul.Tensor %26, %466 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%468 = torch.aten.add.Tensor %467, %25, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%469 = torch.aten.transpose.int %6, %int0, %int1 : !torch.vtensor<[96,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,96],f32>
%470 = torch.aten.view %468, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%471 = torch.aten.mm %470, %469 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,96],f32> -> !torch.vtensor<[128,96],f32>
%472 = torch.aten.view %471, %70 : !torch.vtensor<[128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,96],f32>
%473 = torch.aten.view %472, %72 : !torch.vtensor<[1,128,96],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,24],f32>
%474 = torch.aten.permute %473, %74 : !torch.vtensor<[1,128,4,24],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,24],f32>
%475 = torch.aten.slice.Tensor %474, %int-1, %int0, %int8, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%476 = torch.aten.slice.Tensor %474, %int-1, %int8, %int16, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%477 = torch.aten.slice.Tensor %474, %int-1, %int16, %int24, %int1 : !torch.vtensor<[1,4,128,24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%478 = torch.aten.unsqueeze %25, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%479 = torch.aten.unsqueeze %478, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%480 = torch.aten.slice.Tensor %479, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%481 = torch.aten.view %480, %82 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%482 = torch.aten.permute %481, %74 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%483 = torch.aten.add.Tensor %475, %482, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%484 = torch.aten.unsqueeze %25, %int0 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%485 = torch.aten.unsqueeze %484, %int1 : !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,1,32],f32>
%486 = torch.aten.slice.Tensor %485, %int2, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%487 = torch.aten.view %486, %82 : !torch.vtensor<[1,1,32],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,8],f32>
%488 = torch.aten.permute %487, %74 : !torch.vtensor<[1,1,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,1,8],f32>
%489 = torch.aten.add.Tensor %477, %488, %int1 : !torch.vtensor<[1,4,128,8],f32>, !torch.vtensor<[1,4,1,8],f32>, !torch.int -> !torch.vtensor<[1,4,128,8],f32>
%490 = torch.aten.div.Scalar %483, %float4.000000e00 : !torch.vtensor<[1,4,128,8],f32>, !torch.float -> !torch.vtensor<[1,4,128,8],f32>
%491 = torch.aten.transpose.int %476, %int-1, %int-2 : !torch.vtensor<[1,4,128,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,4,8,128],f32>
%492 = torch.aten.broadcast_to %490, %94 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%493 = torch.aten.view %492, %96 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%494 = torch.aten.broadcast_to %491, %98 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,8,128],f32>
%495 = torch.aten.view %494, %100 : !torch.vtensor<[1,4,8,128],f32>, !torch.list<int> -> !torch.vtensor<[4,8,128],f32>
%496 = torch.aten.bmm %493, %495 : !torch.vtensor<[4,128,8],f32>, !torch.vtensor<[4,8,128],f32> -> !torch.vtensor<[4,128,128],f32>
%497 = torch.aten.view %496, %103 : !torch.vtensor<[4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%498 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%499 = torch.aten.to.dtype %498, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
%500 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%501 = torch.aten.broadcast_to %499, %500 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
%502 = torch.aten.copy %501, %65, %false : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,1,128,128],si8>, !torch.bool -> !torch.vtensor<[1,1,128,128],i1>
%503 = torch.aten.bitwise_not %502 : !torch.vtensor<[1,1,128,128],i1> -> !torch.vtensor<[1,1,128,128],i1>
%504 = torch.aten.clone %23, %none : !torch.vtensor<[],f32>, !torch.none -> !torch.vtensor<[],f32>
%505 = torch.aten.masked_fill.Tensor %497, %503, %504 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,4,128,128],f32>
%values_6, %indices_7 = torch.aten.max.dim %505, %int-1, %true : !torch.vtensor<[1,4,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,4,128,1],f32>, !torch.vtensor<[1,4,128,1],si64>
%506 = torch.aten.sub.Tensor %505, %values_6, %float1.000000e00 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32>, !torch.float -> !torch.vtensor<[1,4,128,128],f32>
%507 = torch.aten.exp %506 : !torch.vtensor<[1,4,128,128],f32> -> !torch.vtensor<[1,4,128,128],f32>
%508 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%509 = torch.aten.sum.dim_IntList %507, %508, %true, %none : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,4,128,1],f32>
%510 = torch.aten.div.Tensor %507, %509 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,4,128,1],f32> -> !torch.vtensor<[1,4,128,128],f32>
%511 = torch.aten.masked_fill.Scalar %510, %503, %int0 : !torch.vtensor<[1,4,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,4,128,128],f32>
%512 = torch.aten.broadcast_to %511, %103 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,128],f32>
%513 = torch.aten.view %512, %120 : !torch.vtensor<[1,4,128,128],f32>, !torch.list<int> -> !torch.vtensor<[4,128,128],f32>
%514 = torch.aten.broadcast_to %489, %94 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%515 = torch.aten.view %514, %96 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[4,128,8],f32>
%516 = torch.aten.bmm %513, %515 : !torch.vtensor<[4,128,128],f32>, !torch.vtensor<[4,128,8],f32> -> !torch.vtensor<[4,128,8],f32>
%517 = torch.aten.view %516, %94 : !torch.vtensor<[4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,4,128,8],f32>
%518 = torch.aten.permute %517, %74 : !torch.vtensor<[1,4,128,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,4,8],f32>
%519 = torch.aten.clone %518, %int0 : !torch.vtensor<[1,128,4,8],f32>, !torch.int -> !torch.vtensor<[1,128,4,8],f32>
%520 = torch.aten.view %519, %128 : !torch.vtensor<[1,128,4,8],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%521 = torch.aten.transpose.int %5, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
%522 = torch.aten.view %520, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%523 = torch.aten.mm %522, %521 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[128,32],f32>
%524 = torch.aten.mul.Scalar %25, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%525 = torch.aten.add.Tensor %524, %523, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%526 = torch.aten.view %525, %135 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%527 = torch.aten.add.Tensor %526, %468, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%528 = torch.aten.sum.dim_IntList %527, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%529 = torch.aten.div.Scalar %528, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%530 = torch.aten.sub.Tensor %527, %529, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%531 = torch.aten.pow.Tensor_Scalar %530, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%532 = torch.aten.sum.dim_IntList %531, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%533 = torch.aten.div.Scalar %532, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%534 = torch.aten.sub.Tensor %527, %529, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%535 = torch.aten.add.Scalar %533, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%536 = torch.aten.sqrt %535 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%537 = torch.aten.div.Tensor %534, %536 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%538 = torch.aten.mul.Tensor %26, %537 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%539 = torch.aten.add.Tensor %538, %25, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%540 = torch.aten.transpose.int %4, %int0, %int1 : !torch.vtensor<[37,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,37],f32>
%541 = torch.aten.view %539, %67 : !torch.vtensor<[1,128,32],f32>, !torch.list<int> -> !torch.vtensor<[128,32],f32>
%542 = torch.aten.mm %541, %540 : !torch.vtensor<[128,32],f32>, !torch.vtensor<[32,37],f32> -> !torch.vtensor<[128,37],f32>
%543 = torch.aten.mul.Scalar %20, %int1 : !torch.vtensor<[37],f32>, !torch.int -> !torch.vtensor<[37],f32>
%544 = torch.aten.add.Tensor %543, %542, %int1 : !torch.vtensor<[37],f32>, !torch.vtensor<[128,37],f32>, !torch.int -> !torch.vtensor<[128,37],f32>
%545 = torch.aten.view %544, %155 : !torch.vtensor<[128,37],f32>, !torch.list<int> -> !torch.vtensor<[1,128,37],f32>
%546 = torch.aten.gelu %545, %str : !torch.vtensor<[1,128,37],f32>, !torch.str -> !torch.vtensor<[1,128,37],f32>
%547 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[32,37],f32>, !torch.int, !torch.int -> !torch.vtensor<[37,32],f32>
%548 = torch.aten.view %546, %159 : !torch.vtensor<[1,128,37],f32>, !torch.list<int> -> !torch.vtensor<[128,37],f32>
%549 = torch.aten.mm %548, %547 : !torch.vtensor<[128,37],f32>, !torch.vtensor<[37,32],f32> -> !torch.vtensor<[128,32],f32>
%550 = torch.aten.mul.Scalar %25, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%551 = torch.aten.add.Tensor %550, %549, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[128,32],f32>, !torch.int -> !torch.vtensor<[128,32],f32>
%552 = torch.aten.view %551, %135 : !torch.vtensor<[128,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,32],f32>
%553 = torch.aten.add.Tensor %552, %539, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%554 = torch.aten.sum.dim_IntList %553, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%555 = torch.aten.div.Scalar %554, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%556 = torch.aten.sub.Tensor %553, %555, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%557 = torch.aten.pow.Tensor_Scalar %556, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%558 = torch.aten.sum.dim_IntList %557, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%559 = torch.aten.div.Scalar %558, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%560 = torch.aten.sub.Tensor %553, %555, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%561 = torch.aten.add.Scalar %559, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%562 = torch.aten.sqrt %561 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%563 = torch.aten.div.Tensor %560, %562 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%564 = torch.aten.mul.Tensor %26, %563 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%565 = torch.aten.add.Tensor %564, %25, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%566 = torch.aten.slice.Tensor %565, %int0, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32],f32>
%567 = torch.aten.slice.Tensor %566, %int1, %int0, %int1, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,32],f32>
%568 = torch.aten.squeeze.dim %567, %int1 : !torch.vtensor<[1,1,32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%569 = torch.aten.transpose.int %2, %int0, %int1 : !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,32],f32>
%570 = torch.aten.mm %568, %569 : !torch.vtensor<[1,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1,32],f32>
%571 = torch.aten.mul.Scalar %25, %int1 : !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32>
%572 = torch.aten.add.Tensor %571, %570, %int1 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,32],f32>, !torch.int -> !torch.vtensor<[1,32],f32>
%573 = torch.aten.gelu %572, %str : !torch.vtensor<[1,32],f32>, !torch.str -> !torch.vtensor<[1,32],f32>
%574 = torch.aten.transpose.int %1, %int0, %int1 : !torch.vtensor<[2,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[32,2],f32>
%575 = torch.aten.mm %573, %574 : !torch.vtensor<[1,32],f32>, !torch.vtensor<[32,2],f32> -> !torch.vtensor<[1,2],f32>
%576 = torch.aten.mul.Scalar %0, %int1 : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32>
%577 = torch.aten.add.Tensor %576, %575, %int1 : !torch.vtensor<[2],f32>, !torch.vtensor<[1,2],f32>, !torch.int -> !torch.vtensor<[1,2],f32>
return %577 : !torch.vtensor<[1,2],f32>
}
@AmosLewis
Copy link
Author

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import tempfile
import torch_mlir

def prepare_sentence_tokens(hf_model: str, sentence: str):
    """Tokenize *sentence* with the tokenizer of *hf_model*.

    Args:
        hf_model: Hugging Face hub id (or local path) whose tokenizer is used.
        sentence: Raw text to encode.

    Returns:
        A tensor of token ids wrapped in an outer batch dimension of size 1.
    """
    token_ids = AutoTokenizer.from_pretrained(hf_model).encode(sentence)
    # The extra list nesting adds the leading batch dimension.
    return torch.tensor([token_ids])


class HfMaskedLM(torch.nn.Module):
    """Wrap a Hugging Face sequence-classification model so it returns logits only.

    Despite the historical name, this loads ``AutoModelForSequenceClassification``
    (a classification head), not a masked-LM head. The model is loaded with
    ``torchscript=True`` so its forward returns a plain tuple; ``forward`` here
    returns the first element of that tuple.
    """

    def __init__(self, model_name: str, num_labels: int = 2):
        """Load *model_name* with a *num_labels*-way classification head.

        Args:
            model_name: Hugging Face hub id or local path of the checkpoint.
            num_labels: Number of output labels. Defaults to 2 (binary
                classification), the value this wrapper always used before.
        """
        super().__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,  # The pretrained model name.
            num_labels=num_labels,
            # Do not return attention weights.
            output_attentions=False,
            # Do not return all hidden states.
            output_hidden_states=False,
            # Tuple outputs (instead of ModelOutput objects) so the model can be
            # traced/scripted; this is why ``forward`` indexes ``[0]``.
            torchscript=True,
        )
        # Inference mode: disables dropout in the wrapped model.
        self.model.eval()

    def forward(self, tokens):
        """Return the first output (the logits) for a batch of token ids.

        ``tokens`` is presumably an integer tensor of shape (batch, seq_len);
        TODO confirm for callers other than the script below.
        """
        # Invoke the module itself rather than ``.forward`` so any hooks
        # registered on the wrapped model still run.
        return self.model(tokens)[0]



# Tiny randomly-initialised DeBERTa checkpoint from the HF test hub.
hf_minilm_model = "hf-internal-testing/tiny-random-deberta"

# A (1, 128) batch of token ids drawn uniformly from {0, 1}. It is drawn
# before the model is built, which keeps the RNG consumption order unchanged.
test_input = torch.randint(2, (1, 128))

model = HfMaskedLM(hf_minilm_model)

# Eager-mode reference output, for eyeball comparison with the SHARK result.
print("model(test_input): ")
print(model(test_input))

# Aten ops that make_fx should decompose while tracing.
_decomposed_ops = [
    torch.ops.aten.embedding_dense_backward,
    torch.ops.aten.native_layer_norm_backward,
    torch.ops.aten.slice_backward,
    torch.ops.aten.select_backward,
    torch.ops.aten.norm.ScalarOpt_dim,
    torch.ops.aten.native_group_norm,
    torch.ops.aten.upsample_bilinear2d.vec,
    torch.ops.aten.split.Tensor,
    torch.ops.aten.split_with_sizes,
]
fx_g = make_fx(
    model, decomposition_table=get_decompositions(_decomposed_ops)
)(test_input)

# Replace make_fx's codegen with the plain fx CodeGen and regenerate the code.
# NOTE(review): presumably done so torch.jit.script can compile the generated
# forward (the pytree-based codegen wrapper is not scriptable) -- confirm.
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()

def strip_overloads(gm):
    """Rewrite *gm* in place so op-overload targets become their overload packets.

    Every graph node whose target is a ``torch._ops.OpOverload`` (for example
    ``aten.add.Tensor``) is retargeted to its ``OpOverloadPacket``
    (``aten.add``). Other targets are left alone. The module's code is
    regenerated afterwards.

    Args:
        gm (fx.GraphModule): The graph module to modify in place.
    """
    overload_cls = torch._ops.OpOverload
    overloaded = filter(lambda n: isinstance(n.target, overload_cls), gm.graph.nodes)
    for op_node in overloaded:
        op_node.target = op_node.target.overloadpacket
    gm.recompile()

strip_overloads(fx_g)

# Script the overload-stripped FX graph so torch_mlir can import it.
ts_g = torch.jit.script(fx_g)

# Lower to the TOSA dialect. Use torch_mlir.OutputType.LINALG_ON_TENSORS here
# instead to get the linalg-on-tensors lowering.
# NOTE(review): ts_g is already a ScriptModule, so use_tracing=True makes
# torch.jit.trace warn and return it unchanged (see the "tracing it is a
# no-op" warning in the run log) -- consider dropping it.
module = torch_mlir.compile(
    ts_g,
    test_input,
    torch_mlir.OutputType.TOSA,
    use_tracing=True,
    verbose=False,
)
module.dump()


from shark.shark_inference import SharkInference
# Hand the TOSA module to SHARK and compile its "forward" entry point for CPU.
mlir_model, func_name = module, "forward"
shark_module = SharkInference(mlir_model, func_name, device="cpu", mlir_dialect="tosa")
shark_module.compile()

def shark_result(x):
    """Run tensor *x* through the compiled module-level ``shark_module``.

    The input is detached and converted to a NumPy array, passed as a
    one-element tuple, and the result is wrapped back into a torch tensor
    (``torch.from_numpy`` shares memory with the returned array).
    """
    numpy_input = x.detach().numpy()
    return torch.from_numpy(shark_module.forward((numpy_input,)))

# Run the SHARK-compiled model on the same input as the eager run above and
# print its output for comparison with the eager result printed earlier.
observed_out = shark_result(x=test_input)
print(observed_out)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment