Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created January 24, 2023 04:47
Show Gist options
  • Save AmosLewis/cb49c739949bc03f9cba6d366c07c7e8 to your computer and use it in GitHub Desktop.
Save AmosLewis/cb49c739949bc03f9cba6d366c07c7e8 to your computer and use it in GitHub Desktop.
fx_g.graph:
graph():
%arg0_1 : [#users=1] = placeholder[target=arg0_1]
%view : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%arg0_1, [-1, 128]), kwargs = {})
%arange : [#users=1] = call_function[target=torch.ops.aten.arange.start](args = (0, 128), kwargs = {dtype: torch.int64, device: cpu, pin_memory: False})
%unsqueeze : [#users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arange, 0), kwargs = {})
%view_1 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%unsqueeze, [-1, 128]), kwargs = {})
%_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
%embedding : [#users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%_param_constant0, %view), kwargs = {})
%_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
%embedding_1 : [#users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%_param_constant1, %view_1), kwargs = {})
%add : [#users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
%_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
%_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
%native_layer_norm : [#users=3] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%add, [768], %_param_constant2, %_param_constant3, 1e-05), kwargs = {})
%getitem : [#users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
%getitem_1 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm, 1), kwargs = {})
%getitem_2 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm, 2), kwargs = {})
%view_2 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%getitem, [-1, 768]), kwargs = {})
%_param_constant4 : [#users=1] = get_attr[target=_param_constant4]
%_param_constant5 : [#users=1] = get_attr[target=_param_constant5]
%addmm : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant4, %view_2, %_param_constant5), kwargs = {})
%view_3 : [#users=3] = call_function[target=torch.ops.aten.view.default](args = (%addmm, [1, 128, 2304]), kwargs = {})
%as_strided : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_3, [1, 128, 768], [294912, 2304, 1], 0), kwargs = {})
%as_strided_1 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_3, [1, 128, 768], [294912, 2304, 1], 768), kwargs = {})
%as_strided_2 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_3, [1, 128, 768], [294912, 2304, 1], 1536), kwargs = {})
%view_4 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided, [1, 128, 12, 64]), kwargs = {})
%permute : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_4, [0, 2, 1, 3]), kwargs = {})
%view_5 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_1, [1, 128, 12, 64]), kwargs = {})
%permute_1 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_5, [0, 2, 1, 3]), kwargs = {})
%view_6 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_2, [1, 128, 12, 64]), kwargs = {})
%permute_2 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_6, [0, 2, 1, 3]), kwargs = {})
%transpose : [#users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%permute_1, -1, -2), kwargs = {})
%expand : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute, [1, 12, 128, 64]), kwargs = {})
%view_7 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand, [12, 128, 64]), kwargs = {})
%expand_1 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%transpose, [1, 12, 64, 128]), kwargs = {})
%view_8 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_1, [12, 64, 128]), kwargs = {})
%bmm : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_7, %view_8), kwargs = {})
%_unsafe_view : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%bmm, [1, 12, 128, 128]), kwargs = {})
%_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0]
%lift_fresh_copy : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant0,), kwargs = {})
%div : [#users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%_unsafe_view, %lift_fresh_copy), kwargs = {})
%_tensor_constant1 : [#users=1] = get_attr[target=_tensor_constant1]
%as_strided_3 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%_tensor_constant1, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_4 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_3, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_5 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_4, [1, 1, 128, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_6 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_5, [1, 1, 128, 128], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%convert_element_type : [#users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%as_strided_6, torch.bool), kwargs = {})
%_tensor_constant2 : [#users=1] = get_attr[target=_tensor_constant2]
%lift_fresh_copy_1 : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant2,), kwargs = {})
%where : [#users=2] = call_function[target=torch.ops.aten.where.self](args = (%convert_element_type, %div, %lift_fresh_copy_1), kwargs = {})
%amax : [#users=1] = call_function[target=torch.ops.aten.amax.default](args = (%where, [-1], True), kwargs = {})
%sub : [#users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%where, %amax), kwargs = {})
%exp : [#users=2] = call_function[target=torch.ops.aten.exp.default](args = (%sub,), kwargs = {})
%sum_1 : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp, [-1], True), kwargs = {})
%div_1 : [#users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp, %sum_1), kwargs = {})
%detach : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%div_1,), kwargs = {})
%expand_2 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%div_1, [1, 12, 128, 128]), kwargs = {})
%view_9 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_2, [12, 128, 128]), kwargs = {})
%expand_3 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_2, [1, 12, 128, 64]), kwargs = {})
%view_10 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_3, [12, 128, 64]), kwargs = {})
%bmm_1 : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_9, %view_10), kwargs = {})
%_unsafe_view_1 : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%bmm_1, [1, 12, 128, 64]), kwargs = {})
%permute_3 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%_unsafe_view_1, [0, 2, 1, 3]), kwargs = {})
%clone : [#users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_3,), kwargs = {memory_format: torch.contiguous_format})
%view_11 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%clone, [1, 128, 768]), kwargs = {})
%view_12 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%view_11, [-1, 768]), kwargs = {})
%_param_constant6 : [#users=1] = get_attr[target=_param_constant6]
%_param_constant7 : [#users=1] = get_attr[target=_param_constant7]
%addmm_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant6, %view_12, %_param_constant7), kwargs = {})
%view_13 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm_1, [1, 128, 768]), kwargs = {})
%add_1 : [#users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_13, %add), kwargs = {})
%_param_constant8 : [#users=1] = get_attr[target=_param_constant8]
%_param_constant9 : [#users=1] = get_attr[target=_param_constant9]
%native_layer_norm_1 : [#users=3] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%add_1, [768], %_param_constant8, %_param_constant9, 1e-05), kwargs = {})
%getitem_3 : [#users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_1, 0), kwargs = {})
%getitem_4 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_1, 1), kwargs = {})
%getitem_5 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_1, 2), kwargs = {})
%view_14 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%getitem_3, [-1, 768]), kwargs = {})
%_param_constant10 : [#users=1] = get_attr[target=_param_constant10]
%_param_constant11 : [#users=1] = get_attr[target=_param_constant11]
%addmm_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant10, %view_14, %_param_constant11), kwargs = {})
%view_15 : [#users=3] = call_function[target=torch.ops.aten.view.default](args = (%addmm_2, [1, 128, 3072]), kwargs = {})
%mul : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_15, 0.5), kwargs = {})
%pow_1 : [#users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%view_15, 3.0), kwargs = {})
%mul_1 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%pow_1, 0.044715), kwargs = {})
%add_2 : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_15, %mul_1), kwargs = {})
%mul_2 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_2, 0.7978845608028654), kwargs = {})
%tanh : [#users=2] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_2,), kwargs = {})
%detach_1 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%tanh,), kwargs = {})
%add_3 : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh, 1.0), kwargs = {})
%mul_3 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %add_3), kwargs = {})
%view_16 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%mul_3, [-1, 3072]), kwargs = {})
%_param_constant12 : [#users=1] = get_attr[target=_param_constant12]
%_param_constant13 : [#users=1] = get_attr[target=_param_constant13]
%addmm_3 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant12, %view_16, %_param_constant13), kwargs = {})
%view_17 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm_3, [1, 128, 768]), kwargs = {})
%add_4 : [#users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, %view_17), kwargs = {})
%_param_constant14 : [#users=1] = get_attr[target=_param_constant14]
%_param_constant15 : [#users=1] = get_attr[target=_param_constant15]
%native_layer_norm_2 : [#users=3] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%add_4, [768], %_param_constant14, %_param_constant15, 1e-05), kwargs = {})
%getitem_6 : [#users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_2, 0), kwargs = {})
%getitem_7 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_2, 1), kwargs = {})
%getitem_8 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_2, 2), kwargs = {})
%view_18 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%getitem_6, [-1, 768]), kwargs = {})
%_param_constant16 : [#users=1] = get_attr[target=_param_constant16]
%_param_constant17 : [#users=1] = get_attr[target=_param_constant17]
%addmm_4 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant16, %view_18, %_param_constant17), kwargs = {})
%view_19 : [#users=3] = call_function[target=torch.ops.aten.view.default](args = (%addmm_4, [1, 128, 2304]), kwargs = {})
%as_strided_7 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_19, [1, 128, 768], [294912, 2304, 1], 0), kwargs = {})
%as_strided_8 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_19, [1, 128, 768], [294912, 2304, 1], 768), kwargs = {})
%as_strided_9 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_19, [1, 128, 768], [294912, 2304, 1], 1536), kwargs = {})
%view_20 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_7, [1, 128, 12, 64]), kwargs = {})
%permute_4 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_20, [0, 2, 1, 3]), kwargs = {})
%view_21 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_8, [1, 128, 12, 64]), kwargs = {})
%permute_5 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_21, [0, 2, 1, 3]), kwargs = {})
%view_22 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_9, [1, 128, 12, 64]), kwargs = {})
%permute_6 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_22, [0, 2, 1, 3]), kwargs = {})
%transpose_1 : [#users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%permute_5, -1, -2), kwargs = {})
%expand_4 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_4, [1, 12, 128, 64]), kwargs = {})
%view_23 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_4, [12, 128, 64]), kwargs = {})
%expand_5 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%transpose_1, [1, 12, 64, 128]), kwargs = {})
%view_24 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_5, [12, 64, 128]), kwargs = {})
%bmm_2 : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_23, %view_24), kwargs = {})
%_unsafe_view_2 : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%bmm_2, [1, 12, 128, 128]), kwargs = {})
%_tensor_constant3 : [#users=1] = get_attr[target=_tensor_constant3]
%lift_fresh_copy_2 : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant3,), kwargs = {})
%div_2 : [#users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%_unsafe_view_2, %lift_fresh_copy_2), kwargs = {})
%_tensor_constant4 : [#users=1] = get_attr[target=_tensor_constant4]
%as_strided_10 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%_tensor_constant4, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_11 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_10, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_12 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_11, [1, 1, 128, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_13 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_12, [1, 1, 128, 128], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%convert_element_type_1 : [#users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%as_strided_13, torch.bool), kwargs = {})
%_tensor_constant5 : [#users=1] = get_attr[target=_tensor_constant5]
%lift_fresh_copy_3 : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant5,), kwargs = {})
%where_1 : [#users=2] = call_function[target=torch.ops.aten.where.self](args = (%convert_element_type_1, %div_2, %lift_fresh_copy_3), kwargs = {})
%amax_1 : [#users=1] = call_function[target=torch.ops.aten.amax.default](args = (%where_1, [-1], True), kwargs = {})
%sub_1 : [#users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%where_1, %amax_1), kwargs = {})
%exp_1 : [#users=2] = call_function[target=torch.ops.aten.exp.default](args = (%sub_1,), kwargs = {})
%sum_2 : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp_1, [-1], True), kwargs = {})
%div_3 : [#users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_1, %sum_2), kwargs = {})
%detach_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%div_3,), kwargs = {})
%expand_6 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%div_3, [1, 12, 128, 128]), kwargs = {})
%view_25 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_6, [12, 128, 128]), kwargs = {})
%expand_7 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_6, [1, 12, 128, 64]), kwargs = {})
%view_26 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_7, [12, 128, 64]), kwargs = {})
%bmm_3 : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_25, %view_26), kwargs = {})
%_unsafe_view_3 : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%bmm_3, [1, 12, 128, 64]), kwargs = {})
%permute_7 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%_unsafe_view_3, [0, 2, 1, 3]), kwargs = {})
%clone_1 : [#users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_7,), kwargs = {memory_format: torch.contiguous_format})
%view_27 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%clone_1, [1, 128, 768]), kwargs = {})
%view_28 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%view_27, [-1, 768]), kwargs = {})
%_param_constant18 : [#users=1] = get_attr[target=_param_constant18]
%_param_constant19 : [#users=1] = get_attr[target=_param_constant19]
%addmm_5 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant18, %view_28, %_param_constant19), kwargs = {})
%view_29 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm_5, [1, 128, 768]), kwargs = {})
%add_5 : [#users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_29, %add_4), kwargs = {})
%_param_constant20 : [#users=1] = get_attr[target=_param_constant20]
%_param_constant21 : [#users=1] = get_attr[target=_param_constant21]
%native_layer_norm_3 : [#users=3] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%add_5, [768], %_param_constant20, %_param_constant21, 1e-05), kwargs = {})
%getitem_9 : [#users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_3, 0), kwargs = {})
%getitem_10 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_3, 1), kwargs = {})
%getitem_11 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_3, 2), kwargs = {})
%view_30 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%getitem_9, [-1, 768]), kwargs = {})
%_param_constant22 : [#users=1] = get_attr[target=_param_constant22]
%_param_constant23 : [#users=1] = get_attr[target=_param_constant23]
%addmm_6 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant22, %view_30, %_param_constant23), kwargs = {})
%view_31 : [#users=3] = call_function[target=torch.ops.aten.view.default](args = (%addmm_6, [1, 128, 3072]), kwargs = {})
%mul_4 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_31, 0.5), kwargs = {})
%pow_2 : [#users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%view_31, 3.0), kwargs = {})
%mul_5 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%pow_2, 0.044715), kwargs = {})
%add_6 : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_31, %mul_5), kwargs = {})
%mul_6 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_6, 0.7978845608028654), kwargs = {})
%tanh_1 : [#users=2] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_6,), kwargs = {})
%detach_3 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%tanh_1,), kwargs = {})
%add_7 : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh_1, 1.0), kwargs = {})
%mul_7 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_4, %add_7), kwargs = {})
%view_32 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%mul_7, [-1, 3072]), kwargs = {})
%_param_constant24 : [#users=1] = get_attr[target=_param_constant24]
%_param_constant25 : [#users=1] = get_attr[target=_param_constant25]
%addmm_7 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant24, %view_32, %_param_constant25), kwargs = {})
%view_33 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm_7, [1, 128, 768]), kwargs = {})
%add_8 : [#users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_5, %view_33), kwargs = {})
%_param_constant26 : [#users=1] = get_attr[target=_param_constant26]
%_param_constant27 : [#users=1] = get_attr[target=_param_constant27]
%native_layer_norm_4 : [#users=3] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%add_8, [768], %_param_constant26, %_param_constant27, 1e-05), kwargs = {})
%getitem_12 : [#users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_4, 0), kwargs = {})
%getitem_13 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_4, 1), kwargs = {})
%getitem_14 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_4, 2), kwargs = {})
%view_34 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%getitem_12, [-1, 768]), kwargs = {})
%_param_constant28 : [#users=1] = get_attr[target=_param_constant28]
%_param_constant29 : [#users=1] = get_attr[target=_param_constant29]
%addmm_8 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant28, %view_34, %_param_constant29), kwargs = {})
%view_35 : [#users=3] = call_function[target=torch.ops.aten.view.default](args = (%addmm_8, [1, 128, 2304]), kwargs = {})
%as_strided_14 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_35, [1, 128, 768], [294912, 2304, 1], 0), kwargs = {})
%as_strided_15 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_35, [1, 128, 768], [294912, 2304, 1], 768), kwargs = {})
%as_strided_16 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_35, [1, 128, 768], [294912, 2304, 1], 1536), kwargs = {})
%view_36 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_14, [1, 128, 12, 64]), kwargs = {})
%permute_8 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_36, [0, 2, 1, 3]), kwargs = {})
%view_37 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_15, [1, 128, 12, 64]), kwargs = {})
%permute_9 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_37, [0, 2, 1, 3]), kwargs = {})
%view_38 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_16, [1, 128, 12, 64]), kwargs = {})
%permute_10 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_38, [0, 2, 1, 3]), kwargs = {})
%transpose_2 : [#users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%permute_9, -1, -2), kwargs = {})
%expand_8 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_8, [1, 12, 128, 64]), kwargs = {})
%view_39 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_8, [12, 128, 64]), kwargs = {})
%expand_9 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%transpose_2, [1, 12, 64, 128]), kwargs = {})
%view_40 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_9, [12, 64, 128]), kwargs = {})
%bmm_4 : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_39, %view_40), kwargs = {})
%_unsafe_view_4 : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%bmm_4, [1, 12, 128, 128]), kwargs = {})
%_tensor_constant6 : [#users=1] = get_attr[target=_tensor_constant6]
%lift_fresh_copy_4 : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant6,), kwargs = {})
%div_4 : [#users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%_unsafe_view_4, %lift_fresh_copy_4), kwargs = {})
%_tensor_constant7 : [#users=1] = get_attr[target=_tensor_constant7]
%as_strided_17 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%_tensor_constant7, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_18 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_17, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_19 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_18, [1, 1, 128, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_20 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_19, [1, 1, 128, 128], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%convert_element_type_2 : [#users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%as_strided_20, torch.bool), kwargs = {})
%_tensor_constant8 : [#users=1] = get_attr[target=_tensor_constant8]
%lift_fresh_copy_5 : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant8,), kwargs = {})
%where_2 : [#users=2] = call_function[target=torch.ops.aten.where.self](args = (%convert_element_type_2, %div_4, %lift_fresh_copy_5), kwargs = {})
%amax_2 : [#users=1] = call_function[target=torch.ops.aten.amax.default](args = (%where_2, [-1], True), kwargs = {})
%sub_2 : [#users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%where_2, %amax_2), kwargs = {})
%exp_2 : [#users=2] = call_function[target=torch.ops.aten.exp.default](args = (%sub_2,), kwargs = {})
%sum_3 : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp_2, [-1], True), kwargs = {})
%div_5 : [#users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_2, %sum_3), kwargs = {})
%detach_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%div_5,), kwargs = {})
%expand_10 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%div_5, [1, 12, 128, 128]), kwargs = {})
%view_41 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_10, [12, 128, 128]), kwargs = {})
%expand_11 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_10, [1, 12, 128, 64]), kwargs = {})
%view_42 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_11, [12, 128, 64]), kwargs = {})
%bmm_5 : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_41, %view_42), kwargs = {})
%_unsafe_view_5 : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%bmm_5, [1, 12, 128, 64]), kwargs = {})
%permute_11 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%_unsafe_view_5, [0, 2, 1, 3]), kwargs = {})
%clone_2 : [#users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_11,), kwargs = {memory_format: torch.contiguous_format})
%view_43 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%clone_2, [1, 128, 768]), kwargs = {})
%view_44 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%view_43, [-1, 768]), kwargs = {})
%_param_constant30 : [#users=1] = get_attr[target=_param_constant30]
%_param_constant31 : [#users=1] = get_attr[target=_param_constant31]
%addmm_9 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant30, %view_44, %_param_constant31), kwargs = {})
%view_45 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm_9, [1, 128, 768]), kwargs = {})
%add_9 : [#users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_45, %add_8), kwargs = {})
%_param_constant32 : [#users=1] = get_attr[target=_param_constant32]
%_param_constant33 : [#users=1] = get_attr[target=_param_constant33]
%native_layer_norm_5 : [#users=3] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%add_9, [768], %_param_constant32, %_param_constant33, 1e-05), kwargs = {})
%getitem_15 : [#users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_5, 0), kwargs = {})
%getitem_16 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_5, 1), kwargs = {})
%getitem_17 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_5, 2), kwargs = {})
%view_46 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%getitem_15, [-1, 768]), kwargs = {})
%_param_constant34 : [#users=1] = get_attr[target=_param_constant34]
%_param_constant35 : [#users=1] = get_attr[target=_param_constant35]
%addmm_10 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant34, %view_46, %_param_constant35), kwargs = {})
%view_47 : [#users=3] = call_function[target=torch.ops.aten.view.default](args = (%addmm_10, [1, 128, 3072]), kwargs = {})
%mul_8 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_47, 0.5), kwargs = {})
%pow_3 : [#users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%view_47, 3.0), kwargs = {})
%mul_9 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%pow_3, 0.044715), kwargs = {})
%add_10 : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_47, %mul_9), kwargs = {})
%mul_10 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_10, 0.7978845608028654), kwargs = {})
%tanh_2 : [#users=2] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_10,), kwargs = {})
%detach_5 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%tanh_2,), kwargs = {})
%add_11 : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh_2, 1.0), kwargs = {})
%mul_11 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_8, %add_11), kwargs = {})
%view_48 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%mul_11, [-1, 3072]), kwargs = {})
%_param_constant36 : [#users=1] = get_attr[target=_param_constant36]
%_param_constant37 : [#users=1] = get_attr[target=_param_constant37]
%addmm_11 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant36, %view_48, %_param_constant37), kwargs = {})
%view_49 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm_11, [1, 128, 768]), kwargs = {})
%add_12 : [#users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_9, %view_49), kwargs = {})
%_param_constant38 : [#users=1] = get_attr[target=_param_constant38]
%_param_constant39 : [#users=1] = get_attr[target=_param_constant39]
%native_layer_norm_6 : [#users=3] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%add_12, [768], %_param_constant38, %_param_constant39, 1e-05), kwargs = {})
%getitem_18 : [#users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_6, 0), kwargs = {})
%getitem_19 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_6, 1), kwargs = {})
%getitem_20 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_6, 2), kwargs = {})
%view_50 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%getitem_18, [-1, 768]), kwargs = {})
%_param_constant40 : [#users=1] = get_attr[target=_param_constant40]
%_param_constant41 : [#users=1] = get_attr[target=_param_constant41]
%addmm_12 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant40, %view_50, %_param_constant41), kwargs = {})
%view_51 : [#users=3] = call_function[target=torch.ops.aten.view.default](args = (%addmm_12, [1, 128, 2304]), kwargs = {})
%as_strided_21 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_51, [1, 128, 768], [294912, 2304, 1], 0), kwargs = {})
%as_strided_22 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_51, [1, 128, 768], [294912, 2304, 1], 768), kwargs = {})
%as_strided_23 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_51, [1, 128, 768], [294912, 2304, 1], 1536), kwargs = {})
%view_52 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_21, [1, 128, 12, 64]), kwargs = {})
%permute_12 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_52, [0, 2, 1, 3]), kwargs = {})
%view_53 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_22, [1, 128, 12, 64]), kwargs = {})
%permute_13 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_53, [0, 2, 1, 3]), kwargs = {})
%view_54 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_23, [1, 128, 12, 64]), kwargs = {})
%permute_14 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_54, [0, 2, 1, 3]), kwargs = {})
%transpose_3 : [#users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%permute_13, -1, -2), kwargs = {})
%expand_12 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_12, [1, 12, 128, 64]), kwargs = {})
%view_55 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_12, [12, 128, 64]), kwargs = {})
%expand_13 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%transpose_3, [1, 12, 64, 128]), kwargs = {})
%view_56 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_13, [12, 64, 128]), kwargs = {})
%bmm_6 : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_55, %view_56), kwargs = {})
%_unsafe_view_6 : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%bmm_6, [1, 12, 128, 128]), kwargs = {})
%_tensor_constant9 : [#users=1] = get_attr[target=_tensor_constant9]
%lift_fresh_copy_6 : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant9,), kwargs = {})
%div_6 : [#users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%_unsafe_view_6, %lift_fresh_copy_6), kwargs = {})
%_tensor_constant10 : [#users=1] = get_attr[target=_tensor_constant10]
%as_strided_24 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%_tensor_constant10, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_25 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_24, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_26 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_25, [1, 1, 128, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_27 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_26, [1, 1, 128, 128], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%convert_element_type_3 : [#users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%as_strided_27, torch.bool), kwargs = {})
%_tensor_constant11 : [#users=1] = get_attr[target=_tensor_constant11]
%lift_fresh_copy_7 : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant11,), kwargs = {})
%where_3 : [#users=2] = call_function[target=torch.ops.aten.where.self](args = (%convert_element_type_3, %div_6, %lift_fresh_copy_7), kwargs = {})
%amax_3 : [#users=1] = call_function[target=torch.ops.aten.amax.default](args = (%where_3, [-1], True), kwargs = {})
%sub_3 : [#users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%where_3, %amax_3), kwargs = {})
%exp_3 : [#users=2] = call_function[target=torch.ops.aten.exp.default](args = (%sub_3,), kwargs = {})
%sum_4 : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp_3, [-1], True), kwargs = {})
%div_7 : [#users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_3, %sum_4), kwargs = {})
%detach_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%div_7,), kwargs = {})
%expand_14 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%div_7, [1, 12, 128, 128]), kwargs = {})
%view_57 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_14, [12, 128, 128]), kwargs = {})
%expand_15 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_14, [1, 12, 128, 64]), kwargs = {})
%view_58 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_15, [12, 128, 64]), kwargs = {})
%bmm_7 : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_57, %view_58), kwargs = {})
%_unsafe_view_7 : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%bmm_7, [1, 12, 128, 64]), kwargs = {})
%permute_15 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%_unsafe_view_7, [0, 2, 1, 3]), kwargs = {})
%clone_3 : [#users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_15,), kwargs = {memory_format: torch.contiguous_format})
%view_59 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%clone_3, [1, 128, 768]), kwargs = {})
%view_60 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%view_59, [-1, 768]), kwargs = {})
%_param_constant42 : [#users=1] = get_attr[target=_param_constant42]
%_param_constant43 : [#users=1] = get_attr[target=_param_constant43]
%addmm_13 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant42, %view_60, %_param_constant43), kwargs = {})
%view_61 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm_13, [1, 128, 768]), kwargs = {})
%add_13 : [#users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_61, %add_12), kwargs = {})
%_param_constant44 : [#users=1] = get_attr[target=_param_constant44]
%_param_constant45 : [#users=1] = get_attr[target=_param_constant45]
%native_layer_norm_7 : [#users=3] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%add_13, [768], %_param_constant44, %_param_constant45, 1e-05), kwargs = {})
%getitem_21 : [#users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_7, 0), kwargs = {})
%getitem_22 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_7, 1), kwargs = {})
%getitem_23 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_7, 2), kwargs = {})
%view_62 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%getitem_21, [-1, 768]), kwargs = {})
%_param_constant46 : [#users=1] = get_attr[target=_param_constant46]
%_param_constant47 : [#users=1] = get_attr[target=_param_constant47]
%addmm_14 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant46, %view_62, %_param_constant47), kwargs = {})
%view_63 : [#users=3] = call_function[target=torch.ops.aten.view.default](args = (%addmm_14, [1, 128, 3072]), kwargs = {})
%mul_12 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_63, 0.5), kwargs = {})
%pow_4 : [#users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%view_63, 3.0), kwargs = {})
%mul_13 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%pow_4, 0.044715), kwargs = {})
%add_14 : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_63, %mul_13), kwargs = {})
%mul_14 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_14, 0.7978845608028654), kwargs = {})
%tanh_3 : [#users=2] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_14,), kwargs = {})
%detach_7 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%tanh_3,), kwargs = {})
%add_15 : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh_3, 1.0), kwargs = {})
%mul_15 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_12, %add_15), kwargs = {})
%view_64 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%mul_15, [-1, 3072]), kwargs = {})
%_param_constant48 : [#users=1] = get_attr[target=_param_constant48]
%_param_constant49 : [#users=1] = get_attr[target=_param_constant49]
%addmm_15 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant48, %view_64, %_param_constant49), kwargs = {})
%view_65 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm_15, [1, 128, 768]), kwargs = {})
%add_16 : [#users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_13, %view_65), kwargs = {})
%_param_constant50 : [#users=1] = get_attr[target=_param_constant50]
%_param_constant51 : [#users=1] = get_attr[target=_param_constant51]
%native_layer_norm_8 : [#users=3] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%add_16, [768], %_param_constant50, %_param_constant51, 1e-05), kwargs = {})
%getitem_24 : [#users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_8, 0), kwargs = {})
%getitem_25 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_8, 1), kwargs = {})
%getitem_26 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_8, 2), kwargs = {})
%view_66 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%getitem_24, [-1, 768]), kwargs = {})
%_param_constant52 : [#users=1] = get_attr[target=_param_constant52]
%_param_constant53 : [#users=1] = get_attr[target=_param_constant53]
%addmm_16 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant52, %view_66, %_param_constant53), kwargs = {})
%view_67 : [#users=3] = call_function[target=torch.ops.aten.view.default](args = (%addmm_16, [1, 128, 2304]), kwargs = {})
%as_strided_28 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_67, [1, 128, 768], [294912, 2304, 1], 0), kwargs = {})
%as_strided_29 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_67, [1, 128, 768], [294912, 2304, 1], 768), kwargs = {})
%as_strided_30 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_67, [1, 128, 768], [294912, 2304, 1], 1536), kwargs = {})
%view_68 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_28, [1, 128, 12, 64]), kwargs = {})
%permute_16 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_68, [0, 2, 1, 3]), kwargs = {})
%view_69 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_29, [1, 128, 12, 64]), kwargs = {})
%permute_17 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_69, [0, 2, 1, 3]), kwargs = {})
%view_70 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_30, [1, 128, 12, 64]), kwargs = {})
%permute_18 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_70, [0, 2, 1, 3]), kwargs = {})
%transpose_4 : [#users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%permute_17, -1, -2), kwargs = {})
%expand_16 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_16, [1, 12, 128, 64]), kwargs = {})
%view_71 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_16, [12, 128, 64]), kwargs = {})
%expand_17 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%transpose_4, [1, 12, 64, 128]), kwargs = {})
%view_72 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_17, [12, 64, 128]), kwargs = {})
%bmm_8 : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_71, %view_72), kwargs = {})
%_unsafe_view_8 : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%bmm_8, [1, 12, 128, 128]), kwargs = {})
%_tensor_constant12 : [#users=1] = get_attr[target=_tensor_constant12]
%lift_fresh_copy_8 : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant12,), kwargs = {})
%div_8 : [#users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%_unsafe_view_8, %lift_fresh_copy_8), kwargs = {})
%_tensor_constant13 : [#users=1] = get_attr[target=_tensor_constant13]
%as_strided_31 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%_tensor_constant13, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_32 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_31, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_33 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_32, [1, 1, 128, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_34 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_33, [1, 1, 128, 128], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%convert_element_type_4 : [#users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%as_strided_34, torch.bool), kwargs = {})
%_tensor_constant14 : [#users=1] = get_attr[target=_tensor_constant14]
%lift_fresh_copy_9 : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant14,), kwargs = {})
%where_4 : [#users=2] = call_function[target=torch.ops.aten.where.self](args = (%convert_element_type_4, %div_8, %lift_fresh_copy_9), kwargs = {})
%amax_4 : [#users=1] = call_function[target=torch.ops.aten.amax.default](args = (%where_4, [-1], True), kwargs = {})
%sub_4 : [#users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%where_4, %amax_4), kwargs = {})
%exp_4 : [#users=2] = call_function[target=torch.ops.aten.exp.default](args = (%sub_4,), kwargs = {})
%sum_5 : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp_4, [-1], True), kwargs = {})
%div_9 : [#users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_4, %sum_5), kwargs = {})
%detach_8 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%div_9,), kwargs = {})
%expand_18 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%div_9, [1, 12, 128, 128]), kwargs = {})
%view_73 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_18, [12, 128, 128]), kwargs = {})
%expand_19 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_18, [1, 12, 128, 64]), kwargs = {})
%view_74 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_19, [12, 128, 64]), kwargs = {})
%bmm_9 : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_73, %view_74), kwargs = {})
%_unsafe_view_9 : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%bmm_9, [1, 12, 128, 64]), kwargs = {})
%permute_19 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%_unsafe_view_9, [0, 2, 1, 3]), kwargs = {})
%clone_4 : [#users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_19,), kwargs = {memory_format: torch.contiguous_format})
%view_75 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%clone_4, [1, 128, 768]), kwargs = {})
%view_76 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%view_75, [-1, 768]), kwargs = {})
%_param_constant54 : [#users=1] = get_attr[target=_param_constant54]
%_param_constant55 : [#users=1] = get_attr[target=_param_constant55]
%addmm_17 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant54, %view_76, %_param_constant55), kwargs = {})
%view_77 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm_17, [1, 128, 768]), kwargs = {})
%add_17 : [#users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_77, %add_16), kwargs = {})
%_param_constant56 : [#users=1] = get_attr[target=_param_constant56]
%_param_constant57 : [#users=1] = get_attr[target=_param_constant57]
%native_layer_norm_9 : [#users=3] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%add_17, [768], %_param_constant56, %_param_constant57, 1e-05), kwargs = {})
%getitem_27 : [#users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_9, 0), kwargs = {})
%getitem_28 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_9, 1), kwargs = {})
%getitem_29 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_9, 2), kwargs = {})
%view_78 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%getitem_27, [-1, 768]), kwargs = {})
%_param_constant58 : [#users=1] = get_attr[target=_param_constant58]
%_param_constant59 : [#users=1] = get_attr[target=_param_constant59]
%addmm_18 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant58, %view_78, %_param_constant59), kwargs = {})
%view_79 : [#users=3] = call_function[target=torch.ops.aten.view.default](args = (%addmm_18, [1, 128, 3072]), kwargs = {})
%mul_16 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_79, 0.5), kwargs = {})
%pow_5 : [#users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%view_79, 3.0), kwargs = {})
%mul_17 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%pow_5, 0.044715), kwargs = {})
%add_18 : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_79, %mul_17), kwargs = {})
%mul_18 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_18, 0.7978845608028654), kwargs = {})
%tanh_4 : [#users=2] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_18,), kwargs = {})
%detach_9 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%tanh_4,), kwargs = {})
%add_19 : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh_4, 1.0), kwargs = {})
%mul_19 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_16, %add_19), kwargs = {})
%view_80 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%mul_19, [-1, 3072]), kwargs = {})
%_param_constant60 : [#users=1] = get_attr[target=_param_constant60]
%_param_constant61 : [#users=1] = get_attr[target=_param_constant61]
%addmm_19 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant60, %view_80, %_param_constant61), kwargs = {})
%view_81 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm_19, [1, 128, 768]), kwargs = {})
%add_20 : [#users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_17, %view_81), kwargs = {})
%_param_constant62 : [#users=1] = get_attr[target=_param_constant62]
%_param_constant63 : [#users=1] = get_attr[target=_param_constant63]
%native_layer_norm_10 : [#users=3] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%add_20, [768], %_param_constant62, %_param_constant63, 1e-05), kwargs = {})
%getitem_30 : [#users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_10, 0), kwargs = {})
%getitem_31 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_10, 1), kwargs = {})
%getitem_32 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_10, 2), kwargs = {})
%view_82 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%getitem_30, [-1, 768]), kwargs = {})
%_param_constant64 : [#users=1] = get_attr[target=_param_constant64]
%_param_constant65 : [#users=1] = get_attr[target=_param_constant65]
%addmm_20 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant64, %view_82, %_param_constant65), kwargs = {})
%view_83 : [#users=3] = call_function[target=torch.ops.aten.view.default](args = (%addmm_20, [1, 128, 2304]), kwargs = {})
%as_strided_35 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_83, [1, 128, 768], [294912, 2304, 1], 0), kwargs = {})
%as_strided_36 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_83, [1, 128, 768], [294912, 2304, 1], 768), kwargs = {})
%as_strided_37 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%view_83, [1, 128, 768], [294912, 2304, 1], 1536), kwargs = {})
%view_84 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_35, [1, 128, 12, 64]), kwargs = {})
%permute_20 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_84, [0, 2, 1, 3]), kwargs = {})
%view_85 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_36, [1, 128, 12, 64]), kwargs = {})
%permute_21 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_85, [0, 2, 1, 3]), kwargs = {})
%view_86 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%as_strided_37, [1, 128, 12, 64]), kwargs = {})
%permute_22 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_86, [0, 2, 1, 3]), kwargs = {})
%transpose_5 : [#users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%permute_21, -1, -2), kwargs = {})
%expand_20 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_20, [1, 12, 128, 64]), kwargs = {})
%view_87 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_20, [12, 128, 64]), kwargs = {})
%expand_21 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%transpose_5, [1, 12, 64, 128]), kwargs = {})
%view_88 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_21, [12, 64, 128]), kwargs = {})
%bmm_10 : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_87, %view_88), kwargs = {})
%_unsafe_view_10 : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%bmm_10, [1, 12, 128, 128]), kwargs = {})
%_tensor_constant15 : [#users=1] = get_attr[target=_tensor_constant15]
%lift_fresh_copy_10 : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant15,), kwargs = {})
%div_10 : [#users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%_unsafe_view_10, %lift_fresh_copy_10), kwargs = {})
%_tensor_constant16 : [#users=1] = get_attr[target=_tensor_constant16]
%as_strided_38 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%_tensor_constant16, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_39 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_38, [1, 1, 1024, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_40 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_39, [1, 1, 128, 1024], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%as_strided_41 : [#users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%as_strided_40, [1, 1, 128, 128], [1048576, 1048576, 1024, 1], 0), kwargs = {})
%convert_element_type_5 : [#users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%as_strided_41, torch.bool), kwargs = {})
%_tensor_constant17 : [#users=1] = get_attr[target=_tensor_constant17]
%lift_fresh_copy_11 : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant17,), kwargs = {})
%where_5 : [#users=2] = call_function[target=torch.ops.aten.where.self](args = (%convert_element_type_5, %div_10, %lift_fresh_copy_11), kwargs = {})
%amax_5 : [#users=1] = call_function[target=torch.ops.aten.amax.default](args = (%where_5, [-1], True), kwargs = {})
%sub_5 : [#users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%where_5, %amax_5), kwargs = {})
%exp_5 : [#users=2] = call_function[target=torch.ops.aten.exp.default](args = (%sub_5,), kwargs = {})
%sum_6 : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp_5, [-1], True), kwargs = {})
%div_11 : [#users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_5, %sum_6), kwargs = {})
%detach_10 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%div_11,), kwargs = {})
%expand_22 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%div_11, [1, 12, 128, 128]), kwargs = {})
%view_89 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_22, [12, 128, 128]), kwargs = {})
%expand_23 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_22, [1, 12, 128, 64]), kwargs = {})
%view_90 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_23, [12, 128, 64]), kwargs = {})
%bmm_11 : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_89, %view_90), kwargs = {})
%_unsafe_view_11 : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%bmm_11, [1, 12, 128, 64]), kwargs = {})
%permute_23 : [#users=1] = call_function[target=torch.ops.aten.permute.default](args = (%_unsafe_view_11, [0, 2, 1, 3]), kwargs = {})
%clone_5 : [#users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_23,), kwargs = {memory_format: torch.contiguous_format})
%view_91 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%clone_5, [1, 128, 768]), kwargs = {})
%view_92 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%view_91, [-1, 768]), kwargs = {})
%_param_constant66 : [#users=1] = get_attr[target=_param_constant66]
%_param_constant67 : [#users=1] = get_attr[target=_param_constant67]
%addmm_21 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant66, %view_92, %_param_constant67), kwargs = {})
%view_93 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm_21, [1, 128, 768]), kwargs = {})
%add_21 : [#users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_93, %add_20), kwargs = {})
%_param_constant68 : [#users=1] = get_attr[target=_param_constant68]
%_param_constant69 : [#users=1] = get_attr[target=_param_constant69]
%native_layer_norm_11 : [#users=3] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%add_21, [768], %_param_constant68, %_param_constant69, 1e-05), kwargs = {})
%getitem_33 : [#users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_11, 0), kwargs = {})
%getitem_34 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_11, 1), kwargs = {})
%getitem_35 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_11, 2), kwargs = {})
%view_94 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%getitem_33, [-1, 768]), kwargs = {})
%_param_constant70 : [#users=1] = get_attr[target=_param_constant70]
%_param_constant71 : [#users=1] = get_attr[target=_param_constant71]
%addmm_22 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant70, %view_94, %_param_constant71), kwargs = {})
%view_95 : [#users=3] = call_function[target=torch.ops.aten.view.default](args = (%addmm_22, [1, 128, 3072]), kwargs = {})
%mul_20 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_95, 0.5), kwargs = {})
%pow_6 : [#users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%view_95, 3.0), kwargs = {})
%mul_21 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%pow_6, 0.044715), kwargs = {})
%add_22 : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_95, %mul_21), kwargs = {})
%mul_22 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_22, 0.7978845608028654), kwargs = {})
%tanh_5 : [#users=2] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_22,), kwargs = {})
%detach_11 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%tanh_5,), kwargs = {})
%add_23 : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh_5, 1.0), kwargs = {})
%mul_23 : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_20, %add_23), kwargs = {})
%view_96 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%mul_23, [-1, 3072]), kwargs = {})
%_param_constant72 : [#users=1] = get_attr[target=_param_constant72]
%_param_constant73 : [#users=1] = get_attr[target=_param_constant73]
%addmm_23 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant72, %view_96, %_param_constant73), kwargs = {})
%view_97 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm_23, [1, 128, 768]), kwargs = {})
%add_24 : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_21, %view_97), kwargs = {})
%_param_constant74 : [#users=1] = get_attr[target=_param_constant74]
%_param_constant75 : [#users=1] = get_attr[target=_param_constant75]
%native_layer_norm_12 : [#users=3] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%add_24, [768], %_param_constant74, %_param_constant75, 1e-05), kwargs = {})
%getitem_36 : [#users=1] = call_function[target=operator.getitem](args = (%native_layer_norm_12, 0), kwargs = {})
%getitem_37 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_12, 1), kwargs = {})
%getitem_38 : [#users=0] = call_function[target=operator.getitem](args = (%native_layer_norm_12, 2), kwargs = {})
%view_98 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%getitem_36, [1, 128, 768]), kwargs = {})
%_param_constant76 : [#users=1] = get_attr[target=_param_constant76]
%t : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant76,), kwargs = {})
%view_99 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%view_98, [128, 768]), kwargs = {})
%mm : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view_99, %t), kwargs = {})
%_unsafe_view_12 : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%mm, [1, 128, 2]), kwargs = {})
%arange_1 : [#users=1] = call_function[target=torch.ops.aten.arange.default](args = (1,), kwargs = {device: cpu, pin_memory: False})
%select : [#users=1] = call_function[target=torch.ops.aten.select.int](args = (%_unsafe_view_12, 1, -1), kwargs = {})
%index : [#users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%select, [%arange_1]), kwargs = {})
return index
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment