FastAI callback to find non-finite gradients and losses
from fastai.basics import *
from fastai import callbacks  # for callbacks.Hooks

class ErrorCallback(LearnerCallback):
    "Stop training and capture the offending batch when a non-finite loss or gradient is seen."
    def __init__(self, lrn:Learner):
        super().__init__(lrn)
        self.err_loss,self.err_input,self.err_output = None,None,None
    def on_train_begin(self, **kwargs):
        def hook(mod, inps, outs):
            # Backward hook: count non-finite values in each incoming gradient tensor.
            nfs = []
            for inp in inps:
                if inp is None: continue
                inp = inp.detach()
                nfs.append((
                    torch.isinf(inp).sum().cpu(), # Count infs and nans; kept as tensors
                    torch.isnan(inp).sum().cpu()  # and only checked in on_backward_end
                ))
            return (mod, nfs)
        # Map each module of the learner's model to its name for reporting.
        self.module_names = {m: n for n,m in self.learn.model.named_modules()}
        self.hooks = callbacks.Hooks([m for m in self.module_names.keys() if hasattr(m, 'weight')],
                                     hook, is_forward=False, detach=False)
    def on_batch_end(self, num_batch, last_loss, last_input, last_output, pbar, **kwargs):
        if not np.isfinite(last_loss) and self.err_loss is None:
            self.err_loss,self.err_input,self.err_output = last_loss,last_input,last_output
            pbar.write(f"Non-finite loss on batch {num_batch}")
            return {'stop_epoch': True, 'stop_training': True}
    def on_backward_end(self, num_batch, last_loss, last_input, last_output, pbar, **kwargs):
        # Check the per-module counts gathered by the backward hooks for this batch.
        for mod,nfs in self.hooks.stored:
            infs,nans = 0,0
            for inf,nan in nfs:
                infs += inf
                nans += nan
            if infs or nans:
                name = self.module_names[mod]
                pbar.write(f"Non-finite gradients on batch {num_batch} from child {name}, {infs} inf, {nans} nan. Aborting.")
                self.err_loss,self.err_input,self.err_output = last_loss,last_input,last_output
                return {'stop_epoch': True, 'stop_training': True}
    def on_train_end(self, **kwargs): self.hooks.remove()
    def on_epoch_end(self, **kwargs):
        if self.err_loss is not None: return {'stop_training': True}
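
A minimal usage sketch, not part of the gist itself: it assumes a fastai v1 Learner has already been built and is named learn. The callback is passed to fit_one_cycle; if training aborts, the loss, input and output of the offending batch are left on the callback instance for inspection.

# Usage sketch; `learn` is assumed to be an existing fastai v1 Learner.
err_cb = ErrorCallback(learn)
learn.fit_one_cycle(1, callbacks=[err_cb])
if err_cb.err_loss is not None:
    print(err_cb.err_loss)    # the non-finite loss tensor
    print(err_cb.err_input)   # the batch input that produced it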