@ivanpanshin
Created May 11, 2021 10:33
SWA (stochastic weight averaging) for PyTorch
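```python
from collections import OrderedDict

import torch


def swa(paths):
    """Equal-weight average of the "model_state_dict" weights in several checkpoints."""
    # Load on CPU so checkpoints saved on a GPU can be averaged on any machine.
    state_dicts = [torch.load(path, map_location="cpu")["model_state_dict"] for path in paths]
    average_dict = OrderedDict()
    for k in state_dicts[0].keys():
        # Element-wise mean of this entry across all checkpoints.
        average_dict[k] = sum(sd[k] for sd in state_dicts) / len(state_dicts)
    return average_dict
```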
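Loading every checkpoint up front keeps K full copies of the weights in memory at once. A minimal sketch of a running-mean variant follows. It is not part of the original gist, `swa_streaming` is a hypothetical name, and it assumes the same checkpoint layout (a dict holding the weights under `"model_state_dict"`). It uses the update avg_n = avg_(n-1) + (x_n - avg_(n-1)) / n, which gives the same equal-weight average while holding only one checkpoint plus the running average. The usage part saves three toy `nn.Linear` checkpoints with weights 0, 1 and 2, then loads the average back into a fresh model.

```python
import os
import tempfile
from collections import OrderedDict

import torch
from torch import nn


def swa_streaming(paths):
    """Equal-weight average of the "model_state_dict" entries stored in `paths`.

    Hypothetical helper (not from the original gist): keeps a running mean,
    so only one checkpoint and the average are in memory at a time.
    """
    average_dict = OrderedDict()
    for n, path in enumerate(paths, start=1):
        state_dict = torch.load(path, map_location="cpu")["model_state_dict"]
        for k, v in state_dict.items():
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) are averaged
            # as floats; load_state_dict casts them back to the buffer's dtype.
            if not v.is_floating_point():
                v = v.float()
            if n == 1:
                average_dict[k] = v.clone()
            else:
                # Running mean: avg_n = avg_{n-1} + (x_n - avg_{n-1}) / n
                average_dict[k] += (v - average_dict[k]) / n
    return average_dict


# Usage: three toy checkpoints whose weights and biases are all 0, 1 and 2.
with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for i in range(3):
        ckpt_model = nn.Linear(2, 1)
        with torch.no_grad():
            ckpt_model.weight.fill_(float(i))
            ckpt_model.bias.fill_(float(i))
        path = os.path.join(tmp, f"ckpt_{i}.pt")
        torch.save({"model_state_dict": ckpt_model.state_dict()}, path)
        paths.append(path)

    averaged = swa_streaming(paths)

model = nn.Linear(2, 1)
model.load_state_dict(averaged)
print(averaged["weight"])  # tensor([[1., 1.]])
```

The in-place update is a design choice for memory rather than speed: the gist's version is simpler and gives the same result when all checkpoints fit in memory.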