@vsay01
Last active January 22, 2020 23:06
Load model checkpoint in PyTorch
```python
import torch


def load_ckp(checkpoint_fpath, model, optimizer):
    """
    checkpoint_fpath: path to the saved checkpoint file
    model: model that we want to load the checkpoint parameters into
    optimizer: optimizer constructed as in the previous training run, to load the saved state into
    """
    # load the checkpoint dictionary from disk
    checkpoint = torch.load(checkpoint_fpath)
    # copy the saved model parameters (state_dict) into the model
    model.load_state_dict(checkpoint['state_dict'])
    # copy the saved optimizer state into the optimizer
    optimizer.load_state_dict(checkpoint['optimizer'])
    # read the minimum validation loss recorded so far; float() works whether it
    # was saved as a 0-dim tensor, a NumPy scalar, or a plain Python float
    valid_loss_min = float(checkpoint['valid_loss_min'])
    # return model, optimizer, epoch value, min validation loss
    return model, optimizer, checkpoint['epoch'], valid_loss_min
```
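`load_ckp` expects a checkpoint dictionary with the keys `'state_dict'`, `'optimizer'`, `'epoch'` and `'valid_loss_min'`. The sketch below shows one way such a file could be written and then restored into freshly constructed objects. The `save_ckp` helper, the toy `nn.Linear` model, the SGD settings and the temporary file path are all assumptions for illustration, not part of the original gist; `load_ckp` is repeated in compact form so the block runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_ckp(state, checkpoint_fpath):
    """Hypothetical counterpart to load_ckp: write the checkpoint dict with torch.save."""
    torch.save(state, checkpoint_fpath)


def load_ckp(checkpoint_fpath, model, optimizer):
    """Compact copy of load_ckp so this example is self-contained."""
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch'], float(checkpoint['valid_loss_min'])


# Toy model and optimizer, purely for illustration
torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer holds momentum state worth saving
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Save a checkpoint using the keys load_ckp reads (values are made up)
ckp_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
save_ckp({
    'epoch': 5,
    'valid_loss_min': torch.tensor(0.25),
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, ckp_path)

# Build fresh objects with the same architecture/settings, then restore into them
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
new_model, new_optimizer, start_epoch, valid_loss_min = load_ckp(
    ckp_path, new_model, new_optimizer
)
```

The optimizer must be created over the new model's parameters before calling `load_state_dict`, since the saved optimizer state is matched to parameters by position. If the checkpoint was saved on a GPU and is loaded on a CPU-only machine, `torch.load` also accepts a `map_location` argument (for example `map_location="cpu"`).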