Before training, save a copy of the weights:

```python
import copy
# state_dict() returns tensors that share storage with the model, so copy them.
last = copy.deepcopy(model.state_dict())
```

Inside the training loop, after computing the loss:

```python
if torch.isnan(loss).any():
    # NaN loss: restore the last good weights and skip backward()/step() for this batch.
    model.load_state_dict(last)
else:
    last = copy.deepcopy(model.state_dict())  # finite loss: refresh the backup
```
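The same rollback idea can be sketched as a complete, self-contained loop. The helper name `train_with_nan_rollback` and its `(inputs, targets)` batch format are hypothetical. The sketch also backs up the optimizer state. That part is an assumption on top of the snippet above, but it is worth considering for optimizers with running buffers (Adam, SGD with momentum), since the step that produced bad weights may have corrupted those buffers too.

```python
import copy

import torch
import torch.nn as nn


def train_with_nan_rollback(model, optimizer, batches, loss_fn):
    """Hypothetical helper: train on `batches`, rolling back on NaN loss.

    Returns how many batches were skipped because their loss was NaN.
    """
    # Deep copies, because state_dict() tensors share storage with the model.
    good_model = copy.deepcopy(model.state_dict())
    good_optim = copy.deepcopy(optimizer.state_dict())
    skipped = 0
    for inputs, targets in batches:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        if torch.isnan(loss).any():
            # Restore the last state that produced a finite loss and
            # skip backward()/step() for this batch.
            model.load_state_dict(good_model)
            optimizer.load_state_dict(good_optim)
            skipped += 1
            continue
        # Finite loss: these weights are good, so back them up before stepping.
        good_model = copy.deepcopy(model.state_dict())
        good_optim = copy.deepcopy(optimizer.state_dict())
        loss.backward()
        optimizer.step()
    return skipped
```

Because the backup is taken before each step, a rollback also discards the most recent successful update, which matches the behaviour of the snippet above. To treat infinite losses the same way, the check could use `not torch.isfinite(loss).all()` instead of `torch.isnan(loss).any()`.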