import torch


class Good(torch.nn.Module):
    """Toy module whose output stays finite: adds one to every element."""

    def forward(self, x):
        return x + torch.ones_like(x)


class Bad(torch.nn.Module):
    """Toy module that poisons its output: adds NaN to every element."""

    def forward(self, x):
        return x + torch.full_like(x, float('nan'))


def _iter_tensors(value):
    """Yield every tensor found in ``value``, descending into containers.

    Tuples, lists and dict values are searched recursively; anything that is
    neither a tensor nor one of those containers is ignored.
    """
    if isinstance(value, torch.Tensor):
        yield value
    elif isinstance(value, dict):
        for item in value.values():
            yield from _iter_tensors(item)
    elif isinstance(value, (list, tuple)):
        for item in value:
            yield from _iter_tensors(item)


def isfinite_module_forward_hook(module, input, output):
    """Forward hook that reports NaN/Inf values in a module's output.

    Args:
        module: The module whose forward pass just ran. Its ``name``
            attribute (see ``assign_module_names``) is printed when present,
            otherwise the module itself is printed.
        input: The inputs passed to the module's forward; unused.
        output: The module's output. May be a tensor or a (possibly nested)
            tuple/list/dict of tensors; non-tensor leaves are skipped so the
            hook does not raise on modules with structured outputs.

    Returns:
        None, so the module's output is left untouched.
    """
    nan_count = 0
    inf_count = 0
    for tensor in _iter_tensors(output):
        # Cheap check first; only count when something is actually wrong.
        if torch.isfinite(tensor).all():
            continue
        nan_count += int(torch.isnan(tensor).sum())
        inf_count += int(torch.isinf(tensor).sum())

    if nan_count or inf_count:
        # breakpoint()  # uncomment for interactive debugging
        print(getattr(module, 'name', module),
              'isnan:', nan_count,
              'isinf:', inf_count)


def assign_module_names(module, prefix=''):
    """Store each submodule's qualified name on it as a ``name`` attribute.

    Names come from ``module.named_modules(prefix=prefix)``, so the root
    module receives ``prefix`` itself and children receive dotted paths
    below it.
    """
    for name, submodule in module.named_modules(prefix=prefix):
        submodule.name = name


if __name__ == '__main__':
    model = torch.nn.Sequential(Good(), torch.nn.Sequential(Good(), Bad()), Bad())
    assign_module_names(model, prefix='model')

    # Attach the check to every module in the tree (root included).
    for submodule in model.modules():
        submodule.register_forward_hook(isfinite_module_forward_hook)
    # Alternative: install the hook globally for all modules instead:
    # torch.nn.modules.module.register_module_forward_hook(isfinite_module_forward_hook)

    model(torch.ones(4))