Created
July 17, 2019 20:59
-
-
Save Hanrui-Wang/723829ac44a85ca105891760f3b2e39e to your computer and use it in GitHub Desktop.
How to save and load a model in PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Save/Load state_dict (the recommended approach).
# Saving the state_dict stores only the learned parameters, not the class.
torch.save(model.state_dict(), PATH)

# Loading: the model class must be instantiated first; the saved
# parameters are then copied into the fresh instance.
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
# Put dropout / batch-norm layers into evaluation mode before inference.
model.eval()
# Save/Load the entire model object.
# NOTE(review): this pickles the whole module, so loading is bound to the
# exact class/module layout that existed at save time — more fragile than
# saving a state_dict.
torch.save(model, PATH)

# Model class must be defined somewhere (importable) for unpickling to work.
model = torch.load(PATH)
# Switch to evaluation mode before running inference.
model.eval()
# Saving & Loading a General Checkpoint for Inference and/or Resuming Training.
# Bundle everything needed to resume training into a single dict.
# (Fix: the original snippet had a bare `...` as a dict element, which is a
# SyntaxError — a dict literal only accepts key: value pairs.)
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # ... add any other 'key': value entries you want to checkpoint ...
}, PATH)

# Loading: re-create the model and optimizer first, then restore their state.
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# For inference:
model.eval()
# - or -
# To resume training:
model.train()
# Saving Multiple Models in One File.
# Store each model's and optimizer's state_dict under its own key.
# (Fix: the original snippet had a bare `...` as a dict element, which is a
# SyntaxError — a dict literal only accepts key: value pairs.)
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
    # ... add any other 'key': value entries you want to save ...
}, PATH)

# Loading: instantiate every model/optimizer, then restore each state.
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

# For inference:
modelA.eval()
modelB.eval()
# - or -
# To resume training:
modelA.train()
modelB.train()
# Warmstarting Model Using Parameters from a Different Model.
torch.save(modelA.state_dict(), PATH)

modelB = TheModelBClass(*args, **kwargs)
# strict=False ignores keys in the loaded state_dict that modelB does not
# have, and leaves modelB parameters with no matching key untouched.
modelB.load_state_dict(torch.load(PATH), strict=False)
# If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
# Save on GPU, Load on CPU.
torch.save(model.state_dict(), PATH)

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
# map_location remaps the stored CUDA tensors onto the CPU while loading.
model.load_state_dict(torch.load(PATH, map_location=device))
# Save on GPU, Load on GPU.
torch.save(model.state_dict(), PATH)

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
# Move the model's parameters/buffers to the GPU after loading.
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
# Save on CPU, Load on GPU.
torch.save(model.state_dict(), PATH)

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
# map_location places the loaded tensors directly on the chosen GPU.
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
# Saving torch.nn.DataParallel Models.
# Save model.module's state_dict so the keys have no "module." prefix and
# the checkpoint can be loaded into a plain (non-DataParallel) model.
torch.save(model.module.state_dict(), PATH)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment