The best version of the model checkpoint logging system I have built for GDrive. It kinda depends on wandb for naming, but is easily customizable to any other service provider.
# Model Logging setup
import os
from os.path import isfile, join
import torch
def log_checkpoint(
    epoch: int,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: object,  # typed as object because the base class (_LRScheduler) looks protected
    metric: float = None,
    track_greater_metric: bool = False,
):
    """Logs a checkpoint (a torch model) only if the metric passed is one of
    the top 3 metrics (to save space, of course) in the directory that the
    models are being saved in.

    Format: if a wandb run exists, the checkpoints are stored at:
        checkpoint directory ->
            {args["checkpoint_path"]}/trainings/{wandb.run.name}/
    else,
        checkpoint directory -> {args["checkpoint_path"]}/trainings/temp/
    where the file is named:
        {metric}_{project_name}_checkpoint.h5
    If metric is not provided, we overwrite:
        {project_name}_checkpoint.h5 in the above checkpoint directory.

    Beyond the function params, the function also expects an 'args' dictionary
    that has:
    - args['checkpoint_epoch_step'] = number of epochs after which we want to
      try to log a checkpoint. This can help save time if you want to skip a
      few epochs between checkpoints.
    - args['project_name'] = a project name for the experiments that we are
      running. We add this detail to the checkpoint file saved.
    and a 'hyper_params' dictionary where:
    - hyper_params['epochs'] = total number of epochs, so that we definitely
      log the last epoch's checkpoint.

    :param epoch: current epoch index
    :type epoch: int
    :param model: model that you are trying to log using torch.save()
    :type model: torch.nn.Module
    :param optimizer: optimizer to be stored with its state
    :type optimizer: torch.optim.Optimizer
    :param lr_scheduler: LR scheduler used for the experiment.
    :type lr_scheduler: _LRScheduler - github/torch/optim/lr_scheduler.py#L25
    :param metric: value used to rank checkpoints (e.g. validation accuracy);
        if not provided, a single checkpoint file is overwritten every time,
        defaults to None
    :type metric: float, optional
    :param track_greater_metric: to be set to True if a higher metric value
        means that the model is better
    :type track_greater_metric: bool
    """
    # `args`, `hyper_params`, and `wandb` are expected to be available in the
    # calling scope, as described in the docstring.
    if (
        epoch % args["checkpoint_epoch_step"] == 0
        or epoch == hyper_params["epochs"]
    ):
        state = {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        }
        # we add a project name tag to the checkpoints saved.
        project_name = args["project_name"]

        # create path (`wandb` should be the wandb module when a run is
        # active, or None otherwise)
        if wandb:
            check_point_dir = join(
                args["checkpoint_path"], "trainings", wandb.run.name
            )
        else:
            check_point_dir = join(args["checkpoint_path"], "trainings", "temp")
        if not os.path.exists(check_point_dir):
            os.makedirs(check_point_dir)

        # metrics of the checkpoints already saved in this directory
        onlyfiles = [
            float(f.split(f"_{project_name}_checkpoint.h5")[0])
            for f in os.listdir(check_point_dir)
            if isfile(join(check_point_dir, f))
            and f"_{project_name}_checkpoint.h5" in f
        ]

        if metric is not None:
            checkpoint_file_path = join(
                check_point_dir, f"{metric}_{project_name}_checkpoint.h5"
            )
            if len(onlyfiles) >= 3:
                # keep only the top 3 checkpoints: save the new one if it
                # beats the 3rd best, then evict the current worst.
                if track_greater_metric and metric > sorted(onlyfiles, reverse=True)[2]:
                    torch.save(state, checkpoint_file_path)
                    os.remove(
                        join(
                            check_point_dir,
                            f"{min(onlyfiles)}_{project_name}_checkpoint.h5",
                        )
                    )
                elif (not track_greater_metric) and metric < sorted(onlyfiles)[2]:
                    torch.save(state, checkpoint_file_path)
                    os.remove(
                        join(
                            check_point_dir,
                            f"{max(onlyfiles)}_{project_name}_checkpoint.h5",
                        )
                    )
            else:
                torch.save(state, checkpoint_file_path)
        else:
            checkpoint_file_path = join(
                check_point_dir, f"{project_name}_checkpoint.h5"
            )
            torch.save(state, checkpoint_file_path)
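
A minimal usage sketch follows. All names in it (the Drive path, the "demo" project name, the toy model, and the placeholder metric) are hypothetical and only illustrate how `log_checkpoint` is meant to be called; the gist itself defines only the function above.

# Minimal usage sketch with hypothetical values.
import os
from os.path import join
from glob import glob
import torch
import torch.nn as nn

wandb = None  # or: import wandb; wandb.init(project="my-project")

args = {
    "checkpoint_path": "/content/drive/MyDrive/checkpoints/",  # hypothetical mounted GDrive folder
    "checkpoint_epoch_step": 5,  # attempt to log every 5 epochs
    "project_name": "demo",      # tag added to every checkpoint filename
}
hyper_params = {"epochs": 20}

model = nn.Linear(10, 2)  # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

for epoch in range(1, hyper_params["epochs"] + 1):
    # ... training / validation for this epoch ...
    val_acc = 0.5 + epoch * 0.01  # placeholder metric
    log_checkpoint(
        epoch, model, optimizer, lr_scheduler,
        metric=val_acc, track_greater_metric=True,
    )

# Restoring the best saved checkpoint later (no-wandb case, so the 'temp' directory):
saved = glob(
    join(args["checkpoint_path"], "trainings", "temp",
         f"*_{args['project_name']}_checkpoint.h5")
)
best = max(saved, key=lambda p: float(os.path.basename(p).split("_")[0]))
checkpoint = torch.load(best)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])

Note that when no metric is passed, the same {project_name}_checkpoint.h5 file is simply overwritten on every logging step.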