Created
December 26, 2018 02:27
-
-
Save tamuhey/3b490998a4ea46e341870cd78f60d1fa to your computer and use it in GitHub Desktop.
PyTorch Early Stop Class
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy
from typing import Any, Dict, Optional

import torch as to
import torch.nn as nn
class EarlyStop:
    """Check early stop, and save best params.

    Lower ``value`` is treated as better. On every call the value is compared
    with the best value seen so far:

    * improvement (or a tie): the model's ``state_dict()`` is snapshotted as
      independent CPU tensors, extra keyword arguments are stored as
      attributes, and the patience counter is reset;
    * otherwise: the patience counter is incremented, and True is returned
      once it reaches ``num_patience``.

    Examples:
        >>> e=EarlyStop(10)
        >>> model=nn.Linear(3,5)
        >>> x=to.rand(3,3)
        >>> output=model(x).sum()
        >>> e(output, model)
        False
        >>> e(output, model, another_attribute=10)
        False
        >>> e.another_attribute  # save the value when model has best params
        10
    """

    def __init__(self, num_patience: int = 50):
        """
        Args:
            num_patience: number of consecutive non-improving calls after
                which ``__call__`` returns True.
        """
        self.num_patience = num_patience
        # Consecutive calls since the last improvement.
        self.count_early_stop = 0
        self.best_value = float("inf")
        # CPU snapshot taken at the best value; None until the first improvement.
        self.best_model_state_dict: Optional[Dict[str, Any]] = None
        self.cpu = to.device("cpu")

    def __call__(self, value: float, model: nn.Module, **kwargs: Any) -> bool:
        """Record ``value`` and report whether training should stop.

        Args:
            value: monitored metric, lower is better. Python numbers and
                single-element tensors are accepted. The value is converted
                with ``float()``, so no tensor (or autograd graph) is kept
                alive by this object.
            model: module whose ``state_dict()`` is snapshotted when ``value``
                improves on (or ties) the best value.
            **kwargs: extra values stored as attributes on this object when
                ``value`` improves (e.g. the epoch number). NOTE: a key equal
                to an existing attribute name (e.g. ``best_value``) overwrites
                that attribute.

        Returns:
            True once ``num_patience`` consecutive calls have failed to improve
            on the best value, False otherwise.
        """
        value = float(value)
        # ``<=`` keeps ties counting as an improvement, and makes NaN a
        # non-improvement: every comparison with NaN is False, so a NaN loss
        # can never become ``best_value`` and silently disable early stopping.
        if value <= self.best_value:
            self.best_value = value
            self.best_model_state_dict = self._state_dict_to_cpu(model.state_dict())
            self.count_early_stop = 0
            for k, v in kwargs.items():
                setattr(self, k, v)
            return False
        self.count_early_stop += 1
        return self.count_early_stop >= self.num_patience

    def state_dict(self) -> Optional[Dict[str, Any]]:
        """Return the snapshot taken at the best value, or None if none was taken yet."""
        return self.best_model_state_dict

    def _state_dict_to_cpu(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Replace each entry of ``state_dict`` with an independent CPU copy.

        The passed dict is updated in place and returned.
        """
        for k, v in state_dict.items():
            if isinstance(v, to.Tensor):
                # ``Tensor.to`` returns the same tensor when it is already on
                # the target device, and state_dict tensors share storage with
                # the live parameters. ``copy=True`` ensures the snapshot does
                # not change when the optimizer later updates the model.
                state_dict[k] = v.detach().to(self.cpu, copy=True)
            else:
                # Non-tensor entries (if any) cannot be moved with ``.to``;
                # copy them so later mutation does not leak into the snapshot.
                state_dict[k] = copy.deepcopy(v)
        return state_dict
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment