Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save rahulvigneswaran/f17f8785ab6308ba7181b2a1470b5c17 to your computer and use it in GitHub Desktop.
Save rahulvigneswaran/f17f8785ab6308ba7181b2a1470b5c17 to your computer and use it in GitHub Desktop.
Boilerplate code to save model checkpoints and resume from them in case the run crashes.
def save_latest(epoch, model_dir, model, optimizer, scheduler=None, wandb_id=None):
    """Save the latest epoch's state so an interrupted run can be resumed.

    The checkpoint is a dict with the keys ``"epoch"``, ``"state_dict"``,
    ``"opt_state_dict"``, ``"sch_state_dict"`` and ``"wandb_id_save"``.

    Args:
        epoch: Value stored under ``"epoch"``.
        model_dir: Destination handed to ``torch.save``. Despite the name,
            this is the checkpoint *file* path (``str`` / ``os.PathLike``) or
            an open binary file-like object, not a directory.
        model: Object whose ``state_dict()`` is stored under ``"state_dict"``.
        optimizer: Object whose ``state_dict()`` is stored under
            ``"opt_state_dict"``.
        scheduler: Optional object whose ``state_dict()`` is stored under
            ``"sch_state_dict"``; ``None`` is stored when omitted.
        wandb_id: Optional value stored under ``"wandb_id_save"``.
    """
    import os

    model_states = {
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "opt_state_dict": optimizer.state_dict(),
        "sch_state_dict": scheduler.state_dict() if scheduler is not None else None,
        "wandb_id_save": wandb_id,  # ----> Remove this if you don't use wandb for logging
    }

    # File-like destinations cannot be swapped into place; write them directly.
    if not isinstance(model_dir, (str, os.PathLike)):
        torch.save(model_states, model_dir)
        return

    # Write to a sibling temp file first and only then move it over the real
    # checkpoint. If the process dies while serializing, the previous
    # checkpoint is still intact instead of being left half-written.
    target = os.fspath(model_dir)
    tmp_path = target + ".tmp"
    try:
        torch.save(model_states, tmp_path)
        os.replace(tmp_path, target)
    finally:
        # After a successful replace the temp file no longer exists; after a
        # failure this removes the partial file.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
def resume(saved_model_states_dict, model, optimizer, scheduler=None, wandb_id=None,
           map_location=None):
    """Restore training state from a checkpoint dict loaded via ``torch.load``.

    The checkpoint is expected to hold the keys ``"epoch"``, ``"state_dict"``
    and ``"opt_state_dict"``; ``"sch_state_dict"`` and ``"wandb_id_save"`` are
    optional.

    Args:
        saved_model_states_dict: Checkpoint source passed to ``torch.load``
            (despite the name, this is what gets loaded, not an in-memory
            dict).
        model: Object that receives the stored ``"state_dict"`` through
            ``load_state_dict``.
        optimizer: Object that receives the stored ``"opt_state_dict"``
            through ``load_state_dict``.
        scheduler: Optional object that receives the stored
            ``"sch_state_dict"``. If the checkpoint holds no scheduler state,
            the scheduler is returned untouched and a warning is issued.
        wandb_id: Fallback returned only when the checkpoint has no
            ``"wandb_id_save"`` key; otherwise the stored value wins.
        map_location: Forwarded to ``torch.load`` (e.g. ``"cpu"``) so tensors
            can be remapped to another device. ``None`` keeps the default.

    Returns:
        Tuple ``(epoch, model, optimizer, scheduler, wandb_id)`` where
        ``epoch`` is the value stored in the checkpoint and the other objects
        are the ones passed in, updated in place.
    """
    import warnings

    # NOTE(security): torch.load deserializes with pickle; only load
    # checkpoints from sources you trust.
    loaded_dict = torch.load(saved_model_states_dict, map_location=map_location)
    epoch = loaded_dict["epoch"]
    model.load_state_dict(loaded_dict["state_dict"])
    optimizer.load_state_dict(loaded_dict["opt_state_dict"])

    if scheduler is not None:
        sch_state = loaded_dict.get("sch_state_dict")
        if sch_state is not None:
            scheduler.load_state_dict(sch_state)
        else:
            # The checkpoint was written without a scheduler; passing None to
            # load_state_dict would crash, so keep the scheduler as given.
            warnings.warn(
                "Checkpoint has no scheduler state; scheduler left unchanged."
            )

    # The key may have been removed by users who do not log with wandb.
    wandb_id = loaded_dict.get("wandb_id_save", wandb_id)  # ----> Remove this if you don't use wandb for logging
    return epoch, model, optimizer, scheduler, wandb_id
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment