Removes Adam optimizer state from fairseq models, greatly reducing their size
#!/usr/bin/env python3
"""
This is code to take a trained Fairseq model and discard the Adam optimizer state,
which is not needed at test time. Because Adam keeps two moment tensors (exp_avg and
exp_avg_sq) for every parameter, dropping them can reduce the model size by ~70%.
Original author: Brian Thompson
"""
import argparse

import torch
from fairseq import checkpoint_utils

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Strip Adam optimizer states out of a fairseq checkpoint to make it smaller for release.')
    parser.add_argument('fin', type=str, help='input checkpoint file')
    parser.add_argument('fout', type=str, help='output checkpoint file')
    args = parser.parse_args()

    # Refuse to overwrite the input checkpoint.
    assert args.fin != args.fout, 'input and output paths must differ'

    model = checkpoint_utils.load_checkpoint_to_cpu(args.fin)

    # Drop the Adam first- and second-moment estimates for every parameter;
    # the model weights themselves are left untouched.
    for key in model['last_optimizer_state']['state']:
        del model['last_optimizer_state']['state'][key]['exp_avg_sq']
        del model['last_optimizer_state']['state'][key]['exp_avg']

    torch.save(model, args.fout)
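
A minimal usage sketch: run the script as "python3 strip_adam.py checkpoint_best.pt checkpoint_best.slim.pt" (the script and checkpoint filenames here are hypothetical), then verify that the moment tensors are gone. The check below assumes the slimmed checkpoint keeps fairseq's usual 'model' and 'last_optimizer_state' keys:

# Hypothetical sanity check for the slimmed checkpoint (filename is an assumption).
import torch

ckpt = torch.load('checkpoint_best.slim.pt', map_location='cpu')
state = ckpt['last_optimizer_state']['state']
# No parameter should still carry Adam moment estimates.
assert all('exp_avg' not in s and 'exp_avg_sq' not in s for s in state.values())
print(f"model weights kept: {len(ckpt['model'])} tensors")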