Last active
September 2, 2022 10:04
-
-
Save Pangoraw/847c256203bf0736843af73164e8ff67 to your computer and use it in GitHub Desktop.
A helper class similar to `torch.autograd.detect_anomaly()` but for the forward pass!
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Tuple | |
import torch | |
from torch import Tensor, nn | |
def get_hook(mod_name: str, layer_name: str, check_inf: bool = True):
    """Build a forward hook that raises as soon as NaNs/Infs flow through a layer.

    Args:
        mod_name: name of the top-level module (used in error messages only).
        layer_name: name of the hooked submodule (used in error messages only).
        check_inf: when True, also flag infinite values, not just NaNs.

    Returns:
        A callable suitable for ``nn.Module.register_forward_hook``.

    Raises (from the returned hook):
        Exception: if any tensor input/output contains NaNs, or Infs when
            ``check_inf`` is set.
    """

    def _check(kind: str, i: int, t: Tensor) -> None:
        # Shared NaN/Inf check for both inputs and outputs (was duplicated).
        if torch.any(t.isnan()):
            raise Exception(
                f"{kind} #{i} to layer {layer_name} of module {mod_name} has NaNs ({t.isnan().sum()} NaNs / {t.numel()})"
            )
        if check_inf and torch.any(t.isinf()):
            raise Exception(
                f"{kind} #{i} to layer {layer_name} of module {mod_name} has Infs ({t.isinf().sum()} Infs / {t.numel()})"
            )

    def hook(_module: nn.Module, inputs: Tuple, outputs: Tuple) -> None:
        # register_forward_hook commonly delivers the output as a bare Tensor,
        # not a tuple; iterating a Tensor would enumerate its rows (and fail on
        # 0-dim tensors), so normalize both arguments to tuples first.
        if isinstance(inputs, Tensor):
            inputs = (inputs,)
        if isinstance(outputs, Tensor):
            outputs = (outputs,)
        for i, t in enumerate(inputs):
            # Skip non-Tensor entries (e.g. None or flags passed alongside tensors).
            if isinstance(t, Tensor):
                _check("input", i, t)
        for i, t in enumerate(outputs):
            if isinstance(t, Tensor):
                _check("output", i, t)

    return hook
class forward_detect_anomaly:
    """Context manager that registers a forward hook on every submodule of
    ``module`` so an Exception is raised as soon as a NaN (or Inf, when
    ``check_inf`` is True) appears in any layer's input or output during the
    forward pass. Analogous to ``torch.autograd.detect_anomaly()`` but for
    the forward pass. All hooks are removed on exit.
    """

    def __init__(self, module: nn.Module, check_inf: bool = True) -> None:
        # RemovableHandle objects for the hooks currently installed.
        self.handles = []
        self.module = module
        self.check_inf = check_inf

    def __enter__(self):
        mod_name = self.module.__class__.__name__
        for name, submod in self.module.named_modules():
            self.handles.append(
                submod.register_forward_hook(
                    get_hook(mod_name, name, check_inf=self.check_inf)
                )
            )
        # Return self so `with forward_detect_anomaly(m) as ctx:` works
        # (previously returned None).
        return self

    def __exit__(self, *_):
        for handle in self.handles:
            handle.remove()
        # Clear the list so re-entering this context manager does not keep
        # stale, already-removed handles around.
        self.handles = []
if __name__ == "__main__":
    # Minimal demo: a layer that deliberately poisons its activations with
    # NaN, wrapped in a tiny network and run under the anomaly detector.

    class ToNaN(nn.Module):
        """Layer that turns every activation into NaN on purpose."""

        def forward(self, x):
            return x + torch.nan

    class MyNetwork(nn.Module):
        """Tiny network whose first stage injects NaNs before a Linear layer."""

        def __init__(self):
            super().__init__()
            self.layer1 = nn.Sequential(ToNaN(), nn.Linear(10, 2))

        def forward(self, x):
            return self.layer1(x)

    model = MyNetwork()
    # The forward hooks fire during model(x): the NaNs produced by ToNaN are
    # detected immediately and an Exception naming the offending layer is raised.
    with forward_detect_anomaly(model):
        x = torch.randn(2, 10)
        y = model(x)
        print(y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.