Skip to content

Instantly share code, notes, and snippets.

@mrdrozdov
Created May 21, 2018 18:42
Show Gist options
  • Save mrdrozdov/2d291d39c7cc7ea6a60bb670143d852f to your computer and use it in GitHub Desktop.
Save mrdrozdov/2d291d39c7cc7ea6a60bb670143d852f to your computer and use it in GitHub Desktop.
test_checkpoints.py
import unittest
import tempfile
import os
import torch
def models_eq(a, b):
    """Return whether two models hold identical parameters.

    The comparison is done on the ``state_dict()`` of each model: both must
    expose exactly the same set of keys, and for every key the two tensors
    must have the same shape and be element-wise equal.

    Args:
        a: First model; any object exposing ``state_dict()``.
        b: Second model, compared against ``a``.

    Returns:
        bool: ``True`` if every entry matches, ``False`` otherwise. The result
        is always a plain Python ``bool`` (never a 0-dim tensor), so it is
        safe to use with ``is True`` / ``is False`` or as a dictionary key.
        Two models with empty state dicts compare equal.
    """
    sa = a.state_dict()
    sb = b.state_dict()

    # Identical key sets are a precondition for an entry-by-entry comparison.
    if set(sa) != set(sb):
        return False

    for key, tensor_a in sa.items():
        tensor_b = sb[key]
        # Compare shapes first so ``eq`` is never asked to broadcast
        # mismatched tensors (which could raise or give a false positive).
        if tensor_a.shape != tensor_b.shape:
            return False
        # ``.all()`` yields a 0-dim tensor; convert it so callers get a bool.
        if not bool(tensor_a.eq(tensor_b).all()):
            return False
    return True
class TestCheckpoints(unittest.TestCase):
    """Round-trip test for saving and loading a model ``state_dict``."""

    def test_save_load(self):
        """Save one model's weights to disk and load them into another.

        Two ``torch.nn.Linear`` layers are created with independent random
        initialization, so they must differ before the checkpoint is loaded
        and must compare equal (via ``models_eq``) afterwards.
        """
        shape = (3, 3)
        expected = torch.nn.Linear(*shape)
        actual = torch.nn.Linear(*shape)

        # ``mkstemp`` returns an already-open OS-level descriptor; close it
        # right away so it is not leaked and so ``torch.save`` can reopen the
        # path on platforms that forbid a second open handle.
        fd, path = tempfile.mkstemp()
        os.close(fd)
        # Registered cleanup runs whether the test passes, fails or errors.
        self.addCleanup(os.remove, path)

        self.assertFalse(
            models_eq(expected, actual),
            "Models are randomly initialized and should not be equal.",
        )

        # Any exception raised here propagates unchanged so the test report
        # shows the real cause rather than a generic failure message.
        torch.save(expected.state_dict(), path)
        state_dict = torch.load(path)
        actual.load_state_dict(state_dict)

        self.assertTrue(
            models_eq(expected, actual),
            "Models should be equal after loading checkpoint.",
        )
# Run the test suite only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
@mrdrozdov
Copy link
Author

python3

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment