Checkpoint averaging script for Hugging Face models in PyTorch
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Slightly tweaked version of
# https://github.com/pytorch/fairseq/blob/master/scripts/average_checkpoints.py
# to work with Huggingface wav2vec2.
import argparse
import collections
import os
import re

import torch
from fairseq.file_io import PathManager


def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
        inputs: An iterable of string paths of checkpoints to load from.

    Returns:
        An OrderedDict mapping string parameter names to torch Tensors
        containing the element-wise average of the input checkpoints.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)
    for fpath in inputs:
        print(f"filepath: {fpath}")
        with PathManager.open(fpath, "rb") as f:
            state = torch.load(
                f,
                map_location="cuda",
            )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        # Huggingface checkpoints are plain state dicts, so the parameters
        # live at the top level rather than under a "model" key as in fairseq.
        model_params = state

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                "For checkpoint {}, expected list of params: {}, "
                "but found: {}".format(fpath, params_keys, model_params_keys)
            )
        for k in params_keys:
            p = model_params[k]
            # isinstance(p, torch.HalfTensor) would miss CUDA half tensors,
            # so check the dtype instead before upcasting to float32.
            if p.dtype == torch.float16:
                p = p.float()
            if k not in params_dict:
                # NOTE: clone() is needed in case p is a shared parameter.
                params_dict[k] = p.clone()
            else:
                params_dict[k] += p

    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        if averaged_params[k].is_floating_point():
            averaged_params[k].div_(num_models)
        else:
            averaged_params[k] //= num_models
    new_state = averaged_params
    return new_state
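
# For intuition, a toy illustration of what average_checkpoints computes
# (the values below are hypothetical):
#   checkpoint A: {"w": tensor([1., 2.])}
#   checkpoint B: {"w": tensor([3., 4.])}
#   result:       {"w": tensor([2., 3.])}  # element-wise mean over checkpoints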


def last_n_checkpoints(paths, n, update_based, upper_bound=None):
    assert len(paths) == 1
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt")
    else:
        pt_regexp = re.compile(r"checkpoint(\d+)\.pt")
    files = PathManager.ls(path)

    entries = []
    for f in files:
        m = pt_regexp.fullmatch(f)
        if m is not None:
            sort_key = int(m.group(1))
            if upper_bound is None or sort_key <= upper_bound:
                entries.append((sort_key, m.group(0)))
    if len(entries) < n:
        raise Exception(
            "Found {} checkpoint files but need at least {}".format(len(entries), n)
        )
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]


def main():
    parser = argparse.ArgumentParser(
        description="Tool to average the params of input checkpoints to "
        "produce a new checkpoint",
    )
    # fmt: off
    parser.add_argument('--inputs', required=True, nargs='+',
                        help='Input checkpoint file paths.')
    parser.add_argument('--output', required=True, metavar='FILE',
                        help='Write the new checkpoint containing the averaged weights to this path.')
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument('--num-epoch-checkpoints', type=int,
                           help='if set, will look for checkpoints with names checkpointNN.pt in the path '
                                'specified by --inputs, and average that many of the most recent ones.')
    num_group.add_argument('--num-update-checkpoints', type=int,
                           help='if set, will look for checkpoints with names checkpoint_EE_NN.pt in the path '
                                'specified by --inputs, and average that many of the most recent ones.')
    parser.add_argument('--checkpoint-upper-bound', type=int,
                        help='when using --num-epoch-checkpoints, this sets an upper bound on which epoch to use; '
                             'when using --num-update-checkpoints, this sets an upper bound on which update to use. '
                             'E.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 '
                             'would be averaged; with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, '
                             'checkpoints 40500-50000 would be averaged, assuming --save-interval-updates 500.')
    # fmt: on
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    assert args.checkpoint_upper_bound is None or (
        args.num_epoch_checkpoints is not None
        or args.num_update_checkpoints is not None
    ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints"
    assert (
        args.num_epoch_checkpoints is None or args.num_update_checkpoints is None
    ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints"

    if num is not None:
        args.inputs = last_n_checkpoints(
            args.inputs,
            num,
            is_update_based,
            upper_bound=args.checkpoint_upper_bound,
        )
        print("averaging checkpoints: ", args.inputs)

    new_state = average_checkpoints(args.inputs)
    with PathManager.open(args.output, "wb") as f:
        torch.save(new_state, f)
    print("Finished writing averaged checkpoint to {}".format(args.output))


if __name__ == "__main__":
    main()
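
For context, a minimal usage sketch. The checkpoint directory, output path, and model ID below are hypothetical, and it assumes the saved checkpoints are plain wav2vec2 state dicts with fairseq-style names (checkpoint1.pt, checkpoint2.pt, ...).

# Average the last 5 epoch checkpoints (run from a shell):
#   python average_checkpoints.py \
#       --inputs /path/to/checkpoints \
#       --num-epoch-checkpoints 5 \
#       --output averaged.pt

# Then load the averaged state dict back into a Huggingface wav2vec2 model:
import torch
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
state_dict = torch.load("averaged.pt", map_location="cpu")
model.load_state_dict(state_dict)

Note that load_state_dict only succeeds if the averaged checkpoints were saved from a model with the same architecture and parameter names.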