Created May 11, 2021 10:33
Gist ivanpanshin/3822ba024795636fd0f48a04f503fb79
SWA (Stochastic Weight Averaging) for PyTorch
from collections import OrderedDict

import torch


def swa(paths):
    # Load the "model_state_dict" entry of every checkpoint onto the CPU
    state_dicts = [torch.load(path, map_location="cpu")["model_state_dict"] for path in paths]
    average_dict = OrderedDict()
    for k in state_dicts[0].keys():
        # Element-wise mean of tensor k across all checkpoints
        average_dict[k] = sum(state_dict[k] for state_dict in state_dicts) / len(state_dicts)
    return average_dict
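A minimal usage sketch, assuming the checkpoints were written with `torch.save` as a dict holding a `"model_state_dict"` key (the key the snippet reads). The model, the file names (`epoch_{i}.pt`), and the random perturbation standing in for a training epoch are hypothetical. Unlike the snippet above, this version averages with `torch.stack(...).mean(dim=0)` and copies non-floating-point buffers (such as BatchNorm's `num_batches_tracked`) from the first checkpoint instead of dividing them. That is a design choice of this sketch, not part of the original gist.

```python
import os
import tempfile
from collections import OrderedDict

import torch
import torch.nn as nn


def swa(paths):
    """Average the "model_state_dict" entries of several checkpoint files."""
    state_dicts = [torch.load(p, map_location="cpu")["model_state_dict"] for p in paths]
    average_dict = OrderedDict()
    for k, first in state_dicts[0].items():
        if first.is_floating_point():
            # Mean of floating-point parameters/buffers across checkpoints
            average_dict[k] = torch.stack([sd[k] for sd in state_dicts]).mean(dim=0)
        else:
            # Assumption of this sketch: integer buffers (e.g. num_batches_tracked)
            # are copied from the first checkpoint rather than averaged
            average_dict[k] = first.clone()
    return average_dict


# Hypothetical demo: save three checkpoints of a small model, then average them
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
tmpdir = tempfile.mkdtemp()
paths = []
for i in range(3):
    with torch.no_grad():
        for p in model.parameters():
            p.add_(0.1 * torch.randn_like(p))  # stand-in for one training epoch
    path = os.path.join(tmpdir, f"epoch_{i}.pt")
    torch.save({"model_state_dict": model.state_dict()}, path)
    paths.append(path)

averaged = swa(paths)
model.load_state_dict(averaged)
print(averaged["0.weight"].shape)  # torch.Size([8, 4])
```

Averaging weights does not bring BatchNorm running statistics in line with the averaged weights; `torch.optim.swa_utils.update_bn(loader, model)` can recompute them with one pass over the training data. For averaging during training rather than from saved files, PyTorch also ships `torch.optim.swa_utils.AveragedModel`.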