Last active
          January 22, 2020 23:06 
        
      - 
      
- 
        Save vsay01/7cbbd5c4abd265f9ac4e023ffb0f0e7e to your computer and use it in GitHub Desktop. 
    Load model checkpoint in PyTorch
  
        
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | def load_ckp(checkpoint_fpath, model, optimizer): | |
| """ | |
| checkpoint_path: path to save checkpoint | |
| model: model that we want to load checkpoint parameters into | |
| optimizer: optimizer we defined in previous training | |
| """ | |
| # load check point | |
| checkpoint = torch.load(checkpoint_fpath) | |
| # initialize state_dict from checkpoint to model | |
| model.load_state_dict(checkpoint['state_dict']) | |
| # initialize optimizer from checkpoint to optimizer | |
| optimizer.load_state_dict(checkpoint['optimizer']) | |
| # initialize valid_loss_min from checkpoint to valid_loss_min | |
| valid_loss_min = checkpoint['valid_loss_min'] | |
| # return model, optimizer, epoch value, min validation loss | |
| return model, optimizer, checkpoint['epoch'], valid_loss_min.item() | 
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment