Skip to content

Instantly share code, notes, and snippets.

@yearofthewhopper
Last active February 22, 2020 18:06
Show Gist options
  • Save yearofthewhopper/cf787f32b784d0fdf5a97dffb1207166 to your computer and use it in GitHub Desktop.
Save yearofthewhopper/cf787f32b784d0fdf5a97dffb1207166 to your computer and use it in GitHub Desktop.
# save:
# Bundle everything needed to resume later into a single dictionary and
# write it to PATH in one call. `epoch`, `model`, `optimizer`, `loss` and
# `PATH` are expected to be defined by the surrounding training script.
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # ... add any other entries you want to restore later (a bare `...`
    # here would be a SyntaxError inside a dict literal)
}, PATH)
# load:
# Re-create the model and optimizer objects first; the checkpoint only
# supplies state dicts, which are loaded into these instances below.
# TheModelClass / TheOptimizerClass and their arguments are placeholders.
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
# Read back the dictionary stored at PATH and restore each component from
# the key it is looked up under.
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Recover the bookkeeping values kept alongside the state dicts.
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# Choose ONE of the two mode calls below. If both lines are kept, they run
# in order and the train() call is the last one applied.
# presumably eval() is meant for inference and train() for resuming
# training; confirm against the PyTorch docs.
model.eval()
# or
model.train()
# save entire model
# save
# Passes the model object itself to torch.save, rather than its state_dict().
torch.save(model, PATH)
# load
# model class must be defined somewhere
# NOTE(review): loading a whole-model file presumably relies on unpickling,
# so only load files from a trusted source. Recent PyTorch releases may also
# need an explicit `weights_only` setting here; confirm for your version.
model = torch.load(PATH)
# Switch the restored model's mode before using it (see the eval()/train()
# choice in the state-dict example).
model.eval()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment