@mjpost
Created May 15, 2020 14:37
Removes ADAM optimizer state from fairseq models, greatly reducing their size
#!/usr/bin/env python3
"""
This is code to take a trained Fairseq model and discard the ADAM optimizer state,
which is not needed at test time. It can reduce the checkpoint size by ~70%.
Original author: Brian Thompson
"""
from fairseq import checkpoint_utils
import torch
import argparse
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Strip ADAM optimizer states out of a fairseq checkpoint to make it smaller for release.')
    parser.add_argument('fin', type=str, help='input checkpoint file')
    parser.add_argument('fout', type=str, help='output checkpoint file')
    args = parser.parse_args()

    assert args.fin != args.fout, 'refusing to overwrite the input checkpoint'

    model = checkpoint_utils.load_checkpoint_to_cpu(args.fin)

    # Drop the Adam moment estimates for every parameter; the model weights
    # themselves (stored under 'model') are left untouched.
    for key in model['last_optimizer_state']['state']:
        del model['last_optimizer_state']['state'][key]['exp_avg_sq']
        del model['last_optimizer_state']['state'][key]['exp_avg']

    with open(args.fout, 'wb') as fout:
        torch.save(model, f=fout)
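
A quick sanity check, as a minimal sketch: the path below is a placeholder for whatever output file you passed as fout, and it assumes the usual fairseq checkpoint layout with the weights under the 'model' key.

from fairseq import checkpoint_utils

# Load the stripped checkpoint and confirm the Adam buffers are gone.
stripped = checkpoint_utils.load_checkpoint_to_cpu('model_stripped.pt')  # hypothetical output path
assert all(
    'exp_avg' not in s and 'exp_avg_sq' not in s
    for s in stripped['last_optimizer_state']['state'].values()
)
print(f"kept {len(stripped['model'])} weight tensors")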