Skip to content

Instantly share code, notes, and snippets.

@Hanrui-Wang
Created July 17, 2019 20:59
Show Gist options
  • Select an option

  • Save Hanrui-Wang/723829ac44a85ca105891760f3b2e39e to your computer and use it in GitHub Desktop.

Select an option

Save Hanrui-Wang/723829ac44a85ca105891760f3b2e39e to your computer and use it in GitHub Desktop.
How to save and load a model in PyTorch
# Save/Load state_dict
# Writes only the model's state_dict to PATH, then builds a new model
# instance and loads the saved state_dict into it.
torch.save(model.state_dict(), PATH)
# A new instance is constructed before loading; *args/**kwargs are
# placeholders (presumably the same constructor arguments as the saved
# model — confirm).
model = TheModelClass(*args, **kwargs)
# NOTE(review): torch.load is documented as pickle-based — only load files
# that come from a trusted source.
model.load_state_dict(torch.load(PATH))
# Put the model in evaluation mode (intended for inference).
model.eval()
# Save/Load Entire Model
# Passes the model object itself (not its state_dict) to torch.save.
torch.save(model, PATH)
# Model class must be defined somewhere
# NOTE(review): newer PyTorch releases reportedly default torch.load to
# weights_only=True, in which case loading a whole pickled model requires
# weights_only=False — confirm against the installed version's docs.
# NOTE(review): loading a full model unpickles arbitrary objects — only do
# this with files from a trusted source.
model = torch.load(PATH)
# Put the loaded model in evaluation mode (intended for inference).
model.eval()
# Saving & Loading a General Checkpoint for Inference and/or Resuming Training
# Bundles the epoch, model state_dict, optimizer state_dict and loss into one
# dict and saves that dict to PATH.
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
# The "..." below is a placeholder for any further entries to store; it is
# not valid inside a dict literal and must be replaced or removed.
...
}, PATH)
# Model and optimizer are constructed first, then restored from the
# checkpoint (constructor arguments are placeholders).
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
# NOTE(review): torch.load is documented as pickle-based — only load
# checkpoints from a trusted source.
checkpoint = torch.load(PATH)
# Keys read here must match the keys used in the torch.save call above.
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# Choose ONE of the two calls below: eval() for inference, train() to resume
# training. As written both run in order, so the model ends in training mode.
model.eval()
# - or -
model.train()
# Saving Multiple Models in One File
# Stores the state_dicts of two models and their two optimizers in a single
# dict saved to PATH.
torch.save({
'modelA_state_dict': modelA.state_dict(),
'modelB_state_dict': modelB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
# The "..." below is a placeholder for any further entries to store; it is
# not valid inside a dict literal and must be replaced or removed.
...
}, PATH)
# Each model and optimizer is constructed first, then restored from the
# checkpoint (constructor arguments are placeholders).
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)
# NOTE(review): torch.load is documented as pickle-based — only load
# checkpoints from a trusted source.
checkpoint = torch.load(PATH)
# Keys read here must match the keys used in the torch.save call above.
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
# Choose ONE pair of calls below: eval() for inference, train() to resume
# training. As written both pairs run in order, so both models end in
# training mode.
modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
# Warmstarting Model Using Parameters from a Different Model
# Saves modelA's state_dict, then loads it into a newly constructed modelB.
torch.save(modelA.state_dict(), PATH)
modelB = TheModelBClass(*args, **kwargs)
# strict=False is passed so that loading does not require the saved keys and
# modelB's keys to match exactly (presumably mismatched keys are skipped —
# confirm against the load_state_dict docs).
modelB.load_state_dict(torch.load(PATH), strict=False)
# If you want to load parameters from one layer into another but some keys do
# not match, rename the parameter keys in the state_dict you are loading so
# that they match the keys of the model you are loading into.
# Save on GPU, Load on CPU
torch.save(model.state_dict(), PATH)
# Target device for the loaded state_dict.
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
# map_location=device is passed to torch.load so the stored tensors are
# mapped to the CPU device defined above when the file is read.
model.load_state_dict(torch.load(PATH, map_location=device))
# Save on GPU, Load on GPU
torch.save(model.state_dict(), PATH)
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
# No map_location is passed to torch.load here; the model is moved to the
# CUDA device afterwards with model.to(device).
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
# Save on CPU, Load on GPU
torch.save(model.state_dict(), PATH)
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
# map_location="cuda:0" is passed to torch.load so the stored tensors are
# mapped to GPU 0 when the file is read; change the index to pick another GPU.
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
# Move the model itself to the CUDA device as well.
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
# Saving torch.nn.DataParallel Models
# Saves the state_dict of model.module rather than of model itself
# (presumably model is the DataParallel wrapper and .module is the wrapped
# network — confirm).
torch.save(model.module.state_dict(), PATH)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment