Skip to content

Instantly share code, notes, and snippets.

@realyanyang
Created November 4, 2020 14:26
Show Gist options
  • Save realyanyang/9d88e470c784d22b920e1e62038655f7 to your computer and use it in GitHub Desktop.
Save realyanyang/9d88e470c784d22b920e1e62038655f7 to your computer and use it in GitHub Desktop.
class EarlyStop:
    """
    Early stops the training if the monitored score doesn't improve after a given patience.

    Call the instance once per validation round with the current score. When the score
    improves on the best score seen so far by at least ``delta`` (in the direction given
    by ``mode``), ``model.state_dict()`` is saved to ``save_path``, ``best_epoch`` is
    updated and the patience counter is reset. Otherwise the counter is incremented, and
    once it reaches ``patience`` the ``early_stop`` attribute becomes True.
    """

    def __init__(self, patience=10, verbose=False, delta=0, mode='min'):
        """
        Args:
            patience (int): How many consecutive calls without improvement are tolerated
                before ``early_stop`` is set.
                Default: 10
            verbose (bool): If True, prints a message on every improvement (checkpoint
                save) and on every call without improvement.
                Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an
                improvement.
                Default: 0
            mode (str): The mode of training task, `min` means the less the better
                (e.g. a loss), `max` means the larger the better (e.g. accuracy).
                Default: min

        Raises:
            ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
        """
        if mode not in ('min', 'max'):
            raise ValueError("mode must be 'min' or 'max', got %r" % (mode,))
        self.mode = mode
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        # Number of consecutive calls without improvement.
        self.counter = 0
        # Best score so far, stored sign-flipped in 'min' mode so "larger is better" always holds.
        self.best_score = None
        # Epoch label passed in with the best score.
        self.best_epoch = None
        # Becomes True once `counter` reaches `patience`.
        self.early_stop = False
        # Score of the last saved checkpoint in its original orientation (used for messages).
        self.init_score = float('inf') if mode == 'min' else float('-inf')

    def __call__(self, score, model, save_path, epoch):
        """
        Record the loss or metric of one validation round and update the stopping state.

        Args:
            score: The monitored loss (``mode='min'``) or metric (``mode='max'``).
            model: Object exposing ``state_dict()``; saved whenever the score improves.
            save_path: Destination handed to ``torch.save`` for the checkpoint.
            epoch: Value recorded in ``best_epoch`` when the score improves.
        """
        # Negate in 'min' mode so the comparison below is always "larger is better".
        if self.mode == 'min':
            score = -score

        # An improvement must beat the best score by at least `delta`. NaN compares
        # False here, so a NaN score never counts as an improvement.
        if self.best_score is None or score >= self.best_score + self.delta:
            self.best_score = score
            self.save_checkpoint(score, model, save_path)
            self.counter = 0
            self.best_epoch = epoch
            return

        self.counter += 1
        if self.verbose:
            print('EarlyStop counter: %d out of %d' % (self.counter, self.patience))
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, score, model, save_path):
        """
        Save ``model.state_dict()`` to ``save_path`` and remember the saved score.

        ``score`` is in the internal orientation used by ``__call__`` (negated in
        'min' mode); it is flipped back before being printed and stored.
        """
        # Imported here so the class does not depend on a module-level `torch` import.
        import torch

        real_score = -score if self.mode == 'min' else score
        if self.verbose:
            direction = 'decreased' if self.mode == 'min' else 'increased'
            print('Validation score %s (%.4f --> %.4f). Saving model ...'
                  % (direction, self.init_score, real_score))
        self.init_score = real_score
        torch.save(model.state_dict(), save_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment