@pkubik
Created November 2, 2020 21:48
Surgically rename a variable within a PyTorch model checkpoint.
import torch
PATH = 'model.pth'
ckpt = torch.load(PATH)
# `ckpt` is typically a nested dictionary that holds the weights under the `state_dict` key
ckpt['state_dict'] = {
    key.replace('old_name', 'new_name'): value
    for key, value in ckpt['state_dict'].items()
}
# If the names are not distinctive, match a longer substring (e.g. include a prefix) so unrelated keys are not renamed
# Consider saving under a different file name, so the original checkpoint survives if the rename went wrong
torch.save(ckpt, PATH)
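
When the old name is not distinctive, a plain `str.replace` can also hit unrelated keys (renaming `fc` would touch `fc2.weight` too). The following is a minimal sketch of one way to be stricter, assuming parameter names are dot-separated as in a typical `state_dict`. The helper `rename_state_dict_keys` is hypothetical, not part of PyTorch; it matches only whole name components and refuses to overwrite an existing key.

```python
def rename_state_dict_keys(state_dict, old_name, new_name):
    """Return a copy of `state_dict` with `old_name` renamed to `new_name`.

    Hypothetical helper: only whole dot-separated components are matched,
    so renaming `fc` leaves a key such as `fc2.weight` untouched.
    """
    renamed = {}
    for key, value in state_dict.items():
        parts = [new_name if part == old_name else part for part in key.split('.')]
        new_key = '.'.join(parts)
        # Fail loudly instead of silently overwriting another entry.
        if new_key in renamed:
            raise KeyError(f'renaming produced a duplicate key: {new_key}')
        renamed[new_key] = value
    return renamed


# Plain values stand in for tensors here; the helper never inspects them.
example = {'fc.weight': 1, 'fc.bias': 2, 'fc2.weight': 3}
result = rename_state_dict_keys(example, 'fc', 'head')
print(sorted(result))  # ['fc2.weight', 'head.bias', 'head.weight']
```

With the snippet above, this would be used as `ckpt['state_dict'] = rename_state_dict_keys(ckpt['state_dict'], 'old_name', 'new_name')`, followed by `torch.save` to a path other than `PATH` so the original file is kept.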