Created
June 2, 2021 23:39
-
-
Save rahulvigneswaran/f17f8785ab6308ba7181b2a1470b5c17 to your computer and use it in GitHub Desktop.
Boilerplate code for saving model checkpoints and resuming from them in case the run crashes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def save_latest(epoch, model_dir, model, optimizer, scheduler=None, wandb_id=None):
    """Save the latest epoch's training state so a crashed run can be resumed.

    The checkpoint is a dict with the keys ``"epoch"``, ``"state_dict"``
    (model weights), ``"opt_state_dict"``, ``"sch_state_dict"`` (``None`` when
    no scheduler is given) and ``"wandb_id_save"``.

    Args:
        epoch: Epoch number to record in the checkpoint.
        model_dir: Destination handed to ``torch.save``. Despite the name this
            is the checkpoint *file* path (``str`` / ``os.PathLike``), or a
            writable file-like object.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: Optional LR scheduler whose ``state_dict()`` is saved.
        wandb_id: Optional W&B run id, stored so logging can resume into the
            same run. Leave as ``None`` if you don't use wandb.
    """
    import os  # local import keeps this snippet self-contained

    model_states = {
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "opt_state_dict": optimizer.state_dict(),
        "sch_state_dict": scheduler.state_dict() if scheduler is not None else None,
        "wandb_id_save": wandb_id,
    }

    if not isinstance(model_dir, (str, os.PathLike)):
        # File-like object: there is no path to swap atomically, write through.
        torch.save(model_states, model_dir)
        return

    # Write to a sibling temp file, then atomically swap it into place. A crash
    # during torch.save then never leaves a truncated checkpoint where the last
    # good one used to be. Same directory => same filesystem => os.replace is atomic.
    path = os.fspath(model_dir)
    tmp_path = path + (b".tmp" if isinstance(path, bytes) else ".tmp")
    try:
        torch.save(model_states, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a half-written temp file behind on an in-process failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
def resume(saved_model_states_dict, model, optimizer, scheduler=None, wandb_id=None, *, map_location=None):
    """Restore training state from a checkpoint written by ``save_latest``.

    Args:
        saved_model_states_dict: Checkpoint path (or file-like object) passed
            to ``torch.load``.
        model: Module to load the saved weights into (modified in place).
        optimizer: Optimizer to load the saved state into (modified in place).
        scheduler: Optional LR scheduler. It is restored only when the
            checkpoint actually holds scheduler state; otherwise it is
            returned unchanged.
        wandb_id: Fallback W&B run id, returned when the checkpoint has no
            ``"wandb_id_save"`` entry.
        map_location: Forwarded to ``torch.load``. For example, ``"cpu"``
            resumes a GPU-saved checkpoint on a CPU-only machine. The default
            ``None`` matches the previous behaviour.

    Returns:
        ``(epoch, model, optimizer, scheduler, wandb_id)``. ``epoch`` is the
        value that was passed to ``save_latest``.
    """
    loaded_dict = torch.load(saved_model_states_dict, map_location=map_location)

    model.load_state_dict(loaded_dict["state_dict"])
    optimizer.load_state_dict(loaded_dict["opt_state_dict"])

    # A checkpoint saved without a scheduler stores None here, and
    # load_state_dict(None) would crash, so restore only when both exist.
    sch_state = loaded_dict.get("sch_state_dict")
    if scheduler is not None and sch_state is not None:
        scheduler.load_state_dict(sch_state)

    # .get() so checkpoints written without the wandb entry still load; the
    # caller-supplied id becomes the fallback instead of being discarded.
    wandb_id = loaded_dict.get("wandb_id_save", wandb_id)

    return loaded_dict["epoch"], model, optimizer, scheduler, wandb_id
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment