import torch

class Good(torch.nn.Module):
    """Toy module that adds one to every element of its input."""

    def forward(self, x):
        # A ones tensor with the same shape/dtype/device as ``x``.
        increment = torch.ones_like(x)
        return torch.add(x, increment)

class Bad(torch.nn.Module):
    """Toy module that adds NaN to every element, so floating-point outputs are all NaN."""

    def forward(self, x):
        # A NaN-filled tensor with the same shape/dtype/device as ``x``.
        poison = torch.full_like(x, float('nan'))
        return torch.add(x, poison)


def _iter_tensors(obj):
    """Yield every tensor in ``obj``, recursing into lists, tuples and dict values.

    Forward hooks receive whatever the module's ``forward`` returned, which need
    not be a single tensor (e.g. a tuple of tensors). Non-tensor leaves such as
    ``None`` or Python scalars are skipped.
    """
    if isinstance(obj, torch.Tensor):
        yield obj
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            yield from _iter_tensors(item)
    elif isinstance(obj, dict):
        for item in obj.values():
            yield from _iter_tensors(item)


def isfinite_module_forward_hook(module, input, output):
    """Forward hook that prints a report when a module's output has NaN or Inf values.

    Intended for ``Module.register_forward_hook`` or the global
    ``register_module_forward_hook``.

    Args:
        module: The module that just ran. Its ``name`` attribute, if present,
            is used in the report; otherwise the module itself is printed.
        input: The module's positional inputs (unused).
        output: The module's return value. This can be a tensor or an arbitrarily
            nested list/tuple/dict of tensors. Non-tensor parts are ignored.

    Returns:
        None, so the module's output is left unchanged.
    """
    nonfinite = [t for t in _iter_tensors(output) if not bool(torch.isfinite(t).all())]
    if nonfinite:
        # breakpoint() # for interactive debugging
        # Counts are summed over every tensor in the output; finite tensors add 0.
        num_nan = sum(int(torch.isnan(t).sum()) for t in nonfinite)
        num_inf = sum(int(torch.isinf(t).sum()) for t in nonfinite)
        print(getattr(module, 'name', module), 'isnan:', num_nan, 'isinf:', num_inf)

def assign_module_names(module, prefix = ''):
    """Set ``.name`` on ``module`` and each descendant to its dotted path.

    Paths come from ``module.named_modules(prefix=prefix)``. The root gets
    ``prefix`` itself, and descendants get paths such as ``'prefix.1.0'``.
    """
    for qualified_name, child in module.named_modules(prefix = prefix):
        setattr(child, 'name', qualified_name)

if __name__ == '__main__':
    # Demo: a nested model in which two Bad modules inject NaNs. Names are
    # assigned first so the hook can print dotted paths like 'model.1.1'.
    model = torch.nn.Sequential(
        Good(),
        torch.nn.Sequential(Good(), Bad()),
        Bad(),
    )
    assign_module_names(model, prefix = 'model')

    def _attach_isfinite_hook(submodule):
        submodule.register_forward_hook(isfinite_module_forward_hook)

    # Module.apply calls the function on every submodule and on the root itself.
    # To hook every module in the process instead, call
    # torch.nn.modules.module.register_module_forward_hook(isfinite_module_forward_hook).
    model.apply(_attach_isfinite_hook)
    model(torch.ones(4))