An example of detecting NaNs / infs in module outputs for some basic debugging
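The idea is simple: a PyTorch forward hook receives (module, input, output) after every forward call, so one hook function registered on the root module and all of its submodules can check each intermediate output with torch.isfinite and report which module first produced a non-finite value. Assigning every submodule its fully qualified name beforehand makes the report readable.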
import torch

class Good(torch.nn.Module):
    def forward(self, x):
        return x + torch.ones_like(x)

class Bad(torch.nn.Module):
    def forward(self, x):
        # deliberately corrupts the output with NaNs
        return x + torch.full_like(x, float('nan'))

def isfinite_module_forward_hook(module, input, output):
    # forward hook: called as (module, input, output) after every forward pass
    x = output
    if torch.isfinite(x).logical_not().any():
        # breakpoint() # for interactive debugging
        print(getattr(module, 'name', module), 'isnan:', int(torch.isnan(x).long().sum()), 'isinf:', int(torch.isinf(x).long().sum()))

def assign_module_names(module, prefix = ''):
    # store the fully qualified name on every submodule so the hook can report it
    for name, submodule in module.named_modules(prefix = prefix):
        submodule.name = name

if __name__ == '__main__':
    model = torch.nn.Sequential(Good(), torch.nn.Sequential(Good(), Bad()), Bad())
    assign_module_names(model, prefix = 'model')
    # register the checking hook on the root module and all submodules
    model.apply(lambda module: module.register_forward_hook(isfinite_module_forward_hook))
    # alternatively, install the hook globally for every module in the process:
    # torch.nn.modules.module.register_module_forward_hook(isfinite_module_forward_hook)
    model(torch.ones(4))
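Running the script should print one line per module whose output contains non-finite values: inner hooks fire before outer ones, so with the toy model above the expected output is along the lines of

    model.1.1 isnan: 4 isinf: 0
    model.1 isnan: 4 isinf: 0
    model.2 isnan: 4 isinf: 0
    model isnan: 4 isinf: 0

The Good modules stay silent, while each Bad module and every container wrapping it reports four NaNs (the input has four elements). Note that this only inspects forward outputs; for NaNs that first appear in gradients during the backward pass, PyTorch's built-in anomaly detection (torch.autograd.set_detect_anomaly(True)) is a complementary tool.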