Skip to content

Instantly share code, notes, and snippets.

@shravankumar147
Created April 6, 2023 10:10
Show Gist options
  • Save shravankumar147/036198ab1fbb7be39f334913ce9a5e0f to your computer and use it in GitHub Desktop.
Save shravankumar147/036198ab1fbb7be39f334913ce9a5e0f to your computer and use it in GitHub Desktop.
Implementation of early stopping and checkpointing in PyTorch
import os
from collections import defaultdict

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image
from tqdm import tqdm
class EarlyStoppingCheckpoint:
    """Checkpoint the best model seen so far and signal when to stop training.

    Call the instance once per epoch with the monitored metric. Whenever the
    metric strictly improves on the best value seen so far, the model's
    ``state_dict()`` is written to ``save_path`` and the patience counter is
    reset. Otherwise the counter is incremented, and once it reaches
    ``patience`` the call returns ``True``.

    Args:
        model: Object exposing ``state_dict()``; its result is what gets saved.
        save_path: Destination handed to ``torch.save``. When it is a
            filesystem path, missing parent directories are created first.
        metric_name: Human-readable metric name, used only in log output.
        mode: ``'max'`` if larger metric values are better, ``'min'`` if
            smaller values are better.
        patience: Number of consecutive non-improving calls tolerated before
            stopping is signalled.

    Raises:
        ValueError: If ``mode`` is neither ``'max'`` nor ``'min'``.
    """

    def __init__(self, model, save_path, metric_name, mode='max', patience=10):
        # Validate before touching any state so a bad mode never yields a
        # half-initialised object.
        if mode not in ('max', 'min'):
            raise ValueError(f"Invalid mode: {mode}. Must be 'max' or 'min'.")

        self.model = model
        self.save_path = save_path
        self.metric_name = metric_name
        self.mode = mode
        self.patience = patience
        # Start from the worst possible value so the first real (non-NaN)
        # metric always counts as an improvement.
        self.best_metric = float('-inf') if mode == 'max' else float('inf')
        self.best_epoch = None
        self.epochs_since_improvement = 0

    def _is_improvement(self, metric):
        """Return whether ``metric`` strictly beats the best value so far."""
        if self.mode == 'max':
            return metric > self.best_metric
        return metric < self.best_metric

    def __call__(self, metric, epoch):
        """Record one epoch's metric.

        Args:
            metric: Value of the monitored metric for this epoch. Ties with
                the current best do not count as an improvement, and neither
                does NaN (every comparison against NaN is false).
            epoch: Epoch identifier stored alongside the checkpoint.

        Returns:
            ``True`` if training should stop, ``False`` otherwise.
        """
        if self._is_improvement(metric):
            self.best_metric = metric
            self.best_epoch = epoch
            self.epochs_since_improvement = 0
            self.save_checkpoint()
            # An improving epoch never triggers early stopping.
            return False

        self.epochs_since_improvement += 1
        if self.epochs_since_improvement >= self.patience:
            print(f"Stopping early. No improvement in {self.patience} epochs.")
            return True  # Early stopping
        return False  # Continue training

    def save_checkpoint(self):
        """Write the best epoch, model state dict and best metric to ``save_path``."""
        checkpoint = {
            'epoch': self.best_epoch,
            'model': self.model.state_dict(),
            'metric': self.best_metric,
        }
        # torch.save does not create missing directories. Only path-like
        # targets have a parent to create; file-like objects are passed
        # through untouched.
        if isinstance(self.save_path, (str, os.PathLike)):
            parent = os.path.dirname(os.fspath(self.save_path))
            if parent:
                os.makedirs(parent, exist_ok=True)
        torch.save(checkpoint, self.save_path)
        print(f"Saved checkpoint at epoch {self.best_epoch}, with {self.metric_name} of {self.best_metric:.4f}.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment