Created
April 20, 2023 22:47
-
-
Save wanchaol/93056bca54b4ef46c8a59ed8821a4652 to your computer and use it in GitHub Desktop.
Repro: the `val` entry is missing from `node.meta` when proxy-tracing `aten._fused_adam` with a registered meta function
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch._meta_registrations import register_meta | |
aten = torch.ops.aten | |
@register_meta([aten._fused_adam.default])
def meta__fused_adam(
    params,
    grads,
    exp_avgs,
    exp_avg_sqs,
    max_exp_avg_sqs,
    state_steps,
    *,
    lr,
    beta1,
    beta2,
    weight_decay,
    eps,
    amsgrad,
    maximize,
    grad_scale=None,
    found_inf=None,
):
    """Meta (shape-inference) function for the functional ``aten._fused_adam``.

    Returns five lists of fresh meta tensors mirroring the shapes/dtypes of
    the op's outputs: updated params, grads, exp_avgs, exp_avg_sqs and
    max_exp_avg_sqs. No values are computed; only metadata is propagated.
    """

    def empty_like_list(tensor_list):
        return [torch.empty_like(t) for t in tensor_list]

    return (
        empty_like_list(params),
        empty_like_list(grads),
        empty_like_list(exp_avgs),
        empty_like_list(exp_avg_sqs),
        # BUG FIX: the fifth output of _fused_adam corresponds to
        # max_exp_avg_sqs, not state_steps. The traced graph in this very
        # repro shows the fifth output group as f32[10, 10] tensors, whereas
        # state_steps are i64[] scalars — returning state_steps here yields
        # wrong meta shapes/dtypes.
        empty_like_list(max_exp_avg_sqs),
    )
# Build a toy optimizer state: ten 10x10 CUDA tensors per Adam state slot.
def _cuda_randn_list(requires_grad=False):
    """Return ten fresh 10x10 random CUDA tensors."""
    return [
        torch.randn(10, 10, requires_grad=requires_grad, device="cuda")
        for _ in range(10)
    ]

params = _cuda_randn_list(requires_grad=True)
grads = _cuda_randn_list()
exp_avgs = _cuda_randn_list()
exp_avg_sqs = _cuda_randn_list()
max_exp_avg_sqs = _cuda_randn_list()
# Step counters start at zero; scalar integer tensors, one per parameter.
state_steps = [torch.tensor(0, device="cuda") for _ in range(10)]
def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps):
    """Invoke the functional fused-Adam ATen op with fixed hyperparameters."""
    hyperparams = {
        "lr": 0.1,
        "beta1": 0.9,
        "beta2": 0.999,
        "weight_decay": 0.01,
        "eps": 1e-8,
        "amsgrad": False,
        "maximize": False,
    }
    return aten._fused_adam.default(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        **hyperparams,
    )
# trace
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule

# Gather the positional inputs once so the trace call stays readable.
_trace_args = (
    params,
    grads,
    exp_avgs,
    exp_avg_sqs,
    max_exp_avg_sqs,
    state_steps,
)
gm = get_isolated_graphmodule(fused_adam, args=_trace_args, kwargs={})
gm.print_readable()
# Report whether tracing attached a 'val' entry (FakeTensor output metadata)
# to the fused-adam call node — the point this repro demonstrates.
for n in gm.graph.nodes:
    if n.op != "call_function":
        continue
    if n.target != aten._fused_adam.default:
        continue
    print(f"node: {n}, meta: {n.meta}, val in meta: {'val' in n.meta}")
======== | |
output: | |
class wrapped(torch.nn.Module): | |
def forward(self, flat_args): | |
flat_args_1: f32[10, 10], flat_args_2: f32[10, 10], flat_args_3: f32[10, 10], flat_args_4: f32[10, 10], flat_args_5: f32[10, 10], flat_args_6: f32[10, 10], flat_args_7: f32[10, 10], flat_args_8: f32[10, 10], flat_args_9: f32[10, 10], flat_args_10: f32[10, 10], flat_args_11: f32[10, 10], flat_args_12: f32[10, 10], flat_args_13: f32[10, 10], flat_args_14: f32[10, 10], flat_args_15: f32[10, 10], flat_args_16: f32[10, 10], flat_args_17: f32[10, 10], flat_args_18: f32[10, 10], flat_args_19: f32[10, 10], flat_args_20: f32[10, 10], flat_args_21: f32[10, 10], flat_args_22: f32[10, 10], flat_args_23: f32[10, 10], flat_args_24: f32[10, 10], flat_args_25: f32[10, 10], flat_args_26: f32[10, 10], flat_args_27: f32[10, 10], flat_args_28: f32[10, 10], flat_args_29: f32[10, 10], flat_args_30: f32[10, 10], flat_args_31: f32[10, 10], flat_args_32: f32[10, 10], flat_args_33: f32[10, 10], flat_args_34: f32[10, 10], flat_args_35: f32[10, 10], flat_args_36: f32[10, 10], flat_args_37: f32[10, 10], flat_args_38: f32[10, 10], flat_args_39: f32[10, 10], flat_args_40: f32[10, 10], flat_args_41: f32[10, 10], flat_args_42: f32[10, 10], flat_args_43: f32[10, 10], flat_args_44: f32[10, 10], flat_args_45: f32[10, 10], flat_args_46: f32[10, 10], flat_args_47: f32[10, 10], flat_args_48: f32[10, 10], flat_args_49: f32[10, 10], flat_args_50: f32[10, 10], flat_args_51: i64[], flat_args_52: i64[], flat_args_53: i64[], flat_args_54: i64[], flat_args_55: i64[], flat_args_56: i64[], flat_args_57: i64[], flat_args_58: i64[], flat_args_59: i64[], flat_args_60: i64[], = fx_pytree.tree_flatten_spec([flat_args], self._in_spec) | |
# No stacktrace found for following nodes | |
_fused_adam = torch.ops.aten._fused_adam.default([flat_args_1, flat_args_2, flat_args_3, flat_args_4, flat_args_5, flat_args_6, flat_args_7, flat_args_8, flat_args_9, flat_args_10], [flat_args_11, flat_args_12, flat_args_13, flat_args_14, flat_args_15, flat_args_16, flat_args_17, flat_args_18, flat_args_19, flat_args_20], [flat_args_21, flat_args_22, flat_args_23, flat_args_24, flat_args_25, flat_args_26, flat_args_27, flat_args_28, flat_args_29, flat_args_30], [flat_args_31, flat_args_32, flat_args_33, flat_args_34, flat_args_35, flat_args_36, flat_args_37, flat_args_38, flat_args_39, flat_args_40], [flat_args_41, flat_args_42, flat_args_43, flat_args_44, flat_args_45, flat_args_46, flat_args_47, flat_args_48, flat_args_49, flat_args_50], [flat_args_51, flat_args_52, flat_args_53, flat_args_54, flat_args_55, flat_args_56, flat_args_57, flat_args_58, flat_args_59, flat_args_60], lr = 0.1, beta1 = 0.9, beta2 = 0.999, weight_decay = 0.01, eps = 1e-08, amsgrad = False, maximize = False); flat_args_1 = flat_args_2 = flat_args_3 = flat_args_4 = flat_args_5 = flat_args_6 = flat_args_7 = flat_args_8 = flat_args_9 = flat_args_10 = flat_args_11 = flat_args_12 = flat_args_13 = flat_args_14 = flat_args_15 = flat_args_16 = flat_args_17 = flat_args_18 = flat_args_19 = flat_args_20 = flat_args_21 = flat_args_22 = flat_args_23 = flat_args_24 = flat_args_25 = flat_args_26 = flat_args_27 = flat_args_28 = flat_args_29 = flat_args_30 = flat_args_31 = flat_args_32 = flat_args_33 = flat_args_34 = flat_args_35 = flat_args_36 = flat_args_37 = flat_args_38 = flat_args_39 = flat_args_40 = flat_args_41 = flat_args_42 = flat_args_43 = flat_args_44 = flat_args_45 = flat_args_46 = flat_args_47 = flat_args_48 = flat_args_49 = flat_args_50 = flat_args_51 = flat_args_52 = flat_args_53 = flat_args_54 = flat_args_55 = flat_args_56 = flat_args_57 = flat_args_58 = flat_args_59 = flat_args_60 = None | |
getitem = _fused_adam[0] | |
getitem_1: f32[10, 10] = getitem[0] | |
getitem_2: f32[10, 10] = getitem[1] | |
getitem_3: f32[10, 10] = getitem[2] | |
getitem_4: f32[10, 10] = getitem[3] | |
getitem_5: f32[10, 10] = getitem[4] | |
getitem_6: f32[10, 10] = getitem[5] | |
getitem_7: f32[10, 10] = getitem[6] | |
getitem_8: f32[10, 10] = getitem[7] | |
getitem_9: f32[10, 10] = getitem[8] | |
getitem_10: f32[10, 10] = getitem[9]; getitem = None | |
getitem_11 = _fused_adam[1] | |
getitem_12: f32[10, 10] = getitem_11[0] | |
getitem_13: f32[10, 10] = getitem_11[1] | |
getitem_14: f32[10, 10] = getitem_11[2] | |
getitem_15: f32[10, 10] = getitem_11[3] | |
getitem_16: f32[10, 10] = getitem_11[4] | |
getitem_17: f32[10, 10] = getitem_11[5] | |
getitem_18: f32[10, 10] = getitem_11[6] | |
getitem_19: f32[10, 10] = getitem_11[7] | |
getitem_20: f32[10, 10] = getitem_11[8] | |
getitem_21: f32[10, 10] = getitem_11[9]; getitem_11 = None | |
getitem_22 = _fused_adam[2] | |
getitem_23: f32[10, 10] = getitem_22[0] | |
getitem_24: f32[10, 10] = getitem_22[1] | |
getitem_25: f32[10, 10] = getitem_22[2] | |
getitem_26: f32[10, 10] = getitem_22[3] | |
getitem_27: f32[10, 10] = getitem_22[4] | |
getitem_28: f32[10, 10] = getitem_22[5] | |
getitem_29: f32[10, 10] = getitem_22[6] | |
getitem_30: f32[10, 10] = getitem_22[7] | |
getitem_31: f32[10, 10] = getitem_22[8] | |
getitem_32: f32[10, 10] = getitem_22[9]; getitem_22 = None | |
getitem_33 = _fused_adam[3] | |
getitem_34: f32[10, 10] = getitem_33[0] | |
getitem_35: f32[10, 10] = getitem_33[1] | |
getitem_36: f32[10, 10] = getitem_33[2] | |
getitem_37: f32[10, 10] = getitem_33[3] | |
getitem_38: f32[10, 10] = getitem_33[4] | |
getitem_39: f32[10, 10] = getitem_33[5] | |
getitem_40: f32[10, 10] = getitem_33[6] | |
getitem_41: f32[10, 10] = getitem_33[7] | |
getitem_42: f32[10, 10] = getitem_33[8] | |
getitem_43: f32[10, 10] = getitem_33[9]; getitem_33 = None | |
getitem_44 = _fused_adam[4]; _fused_adam = None | |
getitem_45: f32[10, 10] = getitem_44[0] | |
getitem_46: f32[10, 10] = getitem_44[1] | |
getitem_47: f32[10, 10] = getitem_44[2] | |
getitem_48: f32[10, 10] = getitem_44[3] | |
getitem_49: f32[10, 10] = getitem_44[4] | |
getitem_50: f32[10, 10] = getitem_44[5] | |
getitem_51: f32[10, 10] = getitem_44[6] | |
getitem_52: f32[10, 10] = getitem_44[7] | |
getitem_53: f32[10, 10] = getitem_44[8] | |
getitem_54: f32[10, 10] = getitem_44[9]; getitem_44 = None | |
return pytree.tree_unflatten([getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, getitem_7, getitem_8, getitem_9, getitem_10, getitem_12, getitem_13, getitem_14, getitem_15, getitem_16, getitem_17, getitem_18, getitem_19, getitem_20, getitem_21, getitem_23, getitem_24, getitem_25, getitem_26, getitem_27, getitem_28, getitem_29, getitem_30, getitem_31, getitem_32, getitem_34, getitem_35, getitem_36, getitem_37, getitem_38, getitem_39, getitem_40, getitem_41, getitem_42, getitem_43, getitem_45, getitem_46, getitem_47, getitem_48, getitem_49, getitem_50, getitem_51, getitem_52, getitem_53, getitem_54], self._out_spec) | |
node: _fused_adam, meta: {}, val in meta: False |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment