Last active
March 22, 2019 06:22
-
-
Save priancho/63a2ab3072862247ace30e179669ff79 to your computer and use it in GitHub Desktop.
Save the model at the beginning/end of training
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CustomModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that also saves at the start and end of training.

    On top of the per-epoch behaviour inherited from ``ModelCheckpoint``,
    this callback writes:

    * the initialized (untrained) model in ``on_train_begin``, with
      ``epoch`` in ``filepath`` formatted as the currently recorded epoch
      index (0 on a freshly constructed callback);
    * the final model in ``on_train_end``, with ``epoch`` formatted as the
      1-based number of the last epoch that started.

    Both extra saves honour ``save_weights_only`` and ``verbose`` and pass
    ``overwrite=True`` to the save call.
    """

    def __init__(self, *args, **kwargs):
        super(CustomModelCheckpoint, self).__init__(*args, **kwargs)
        # 0-based index of the most recently started epoch; set in
        # on_epoch_begin and read by the train begin/end hooks. Note it is
        # not reset between fit() calls that reuse this callback.
        self.current_epoch = 0

    def on_epoch_begin(self, epoch, logs=None):
        """Record the 0-based index of the epoch that is starting."""
        # Defer to the parent first so any epoch bookkeeping it performs
        # is not lost by this override.
        super(CustomModelCheckpoint, self).on_epoch_begin(epoch, logs)
        self.current_epoch = epoch

    def on_train_begin(self, logs=None):
        """Save the initialized model before the first epoch runs."""
        super(CustomModelCheckpoint, self).on_train_begin(logs)
        self._save_snapshot(self.current_epoch, logs)

    def on_train_end(self, logs=None):
        """Save the last model once training has finished."""
        super(CustomModelCheckpoint, self).on_train_end(logs)
        # current_epoch is 0-based; label the final file with the 1-based
        # number of the last epoch.
        self._save_snapshot(self.current_epoch + 1, logs)

    def _save_snapshot(self, epoch_label, logs):
        """Format ``self.filepath`` and save the model (or its weights).

        Args:
            epoch_label: Value substituted for ``{epoch}`` in the filepath
                and shown in the verbose message, so both always agree.
            logs: Metric dict (or None) whose entries are substituted for
                the remaining filepath placeholders.

        Raises:
            KeyError: if ``self.filepath`` references a key that ``logs``
                does not provide (e.g. a metric name at train start, when
                no metrics exist yet).
        """
        logs = logs or {}
        try:
            filepath = self.filepath.format(epoch=epoch_label, **logs)
        except KeyError as err:
            # Re-raise with the offending template so the failure is
            # diagnosable; the exception type is unchanged for callers.
            raise KeyError(
                'Failed to format this callback filepath: "{}". '
                'Reason: {}'.format(self.filepath, err))
        if self.verbose > 0:
            print('\nEpoch %05d: saving model to %s' % (epoch_label, filepath))
        if self.save_weights_only:
            self.model.save_weights(filepath, overwrite=True)
        else:
            self.model.save(filepath, overwrite=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Saves the model at both the beginning and the END of training,
so there is no need to save the model manually afterwards :-)