Last active
August 24, 2018 17:20
-
-
Save nuric/0c1fde80f0d1d4e703485a48f9c375e6 to your computer and use it in GitHub Desktop.
Stateful Checkpoint for Keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import os
import socket

from keras.callbacks import ModelCheckpoint
class StatefulCheckpoint(ModelCheckpoint):
    """Checkpoint callback that also persists training state to a JSON file.

    In addition to whatever the parent ``ModelCheckpoint`` does, a small JSON
    document is written after every epoch. It holds ``epoch`` (the finished
    epoch index plus one), ``best`` (the parent's best monitored value),
    ``hostname``, the epoch ``logs`` and the training ``params``. If that
    file is readable when the callback is constructed, its contents are
    loaded so a later run can continue from it.

    Sketch of the intended use::

        checkpoint = StatefulCheckpoint("weights.h5", "state.json")
        model.fit(..., initial_epoch=checkpoint.get_last_epoch(),
                  callbacks=[checkpoint])
    """

    def __init__(self, weight_file, state_file=None, **kwargs):
        """Set up the checkpoint and restore a previously saved state.

        Args:
            weight_file: Passed to ``ModelCheckpoint`` as its first argument.
            state_file: Optional path of the JSON state file. When falsy, no
                state is loaded and none is written.
            **kwargs: Forwarded unchanged to ``ModelCheckpoint``.
        """
        super().__init__(weight_file, **kwargs)
        self.state_f = state_file
        self.hostname = socket.gethostname()
        self.state = {}
        if self.state_f:
            self._load_state()

    def _load_state(self):
        """Restore ``self.state`` and ``self.best`` from the state file.

        Loading is all-or-nothing: unless the file holds a JSON object with a
        ``best`` entry, nothing is restored and a message is printed.
        """
        try:
            with open(self.state_f, 'r') as f:
                loaded = json.load(f)
        except Exception as e:  # pylint: disable=broad-except
            # Deliberately best effort: a missing or unreadable state file
            # (e.g. on the very first run) must never prevent training.
            print("Skipping last state:", e)
            return
        if not isinstance(loaded, dict) or 'best' not in loaded:
            print("Skipping last state:", "no 'best' entry in state file")
            return
        self.state = loaded
        self.best = loaded['best']

    @staticmethod
    def _json_default(value):
        """Convert values ``json`` cannot encode, e.g. NumPy scalars/arrays.

        Objects exposing ``tolist()`` are converted through it; anything else
        raises ``TypeError``, as ``json.dump`` expects from a ``default`` hook.
        """
        to_builtin = getattr(value, 'tolist', None)
        if callable(to_builtin):
            return to_builtin()
        raise TypeError("Object of type {} is not JSON serializable".format(
            type(value).__name__))

    def on_train_begin(self, logs=None):
        """Announce whether training starts fresh or resumes a saved state."""
        # Delegate first so any behaviour the parent attaches to this hook
        # still runs.
        super().on_train_begin(logs)
        prefix = "Resuming" if self.state else "Starting"
        print("{} training on {}".format(prefix, self.hostname))

    def on_epoch_end(self, epoch, logs=None):
        """Run the parent's epoch-end handling, then save the training state.

        Args:
            epoch: Index of the epoch that just finished; ``epoch + 1`` is
                stored as the epoch to resume from.
            logs: Optional dict of epoch metrics; merged into the saved state.
        """
        super().on_epoch_end(epoch, logs)
        if not self.state_f:
            return
        state = {'epoch': epoch + 1, 'best': self.best,
                 'hostname': self.hostname}
        # ``logs`` defaults to None, so guard before merging.
        state.update(logs or {})
        state.update(self.params or {})
        # Write to a sibling temp file and swap it in, so an interruption
        # mid-write cannot leave a truncated state file behind.
        tmp_path = "{}.tmp".format(self.state_f)
        with open(tmp_path, 'w') as f:
            json.dump(state, f, default=self._json_default)
        os.replace(tmp_path, self.state_f)
        # Keep the in-memory copy current for get_last_epoch().
        self.state = state

    def get_last_epoch(self, initial_epoch=0):
        """Return the last saved epoch, or ``initial_epoch`` if none exists."""
        return self.state.get('epoch', initial_epoch)

    def on_train_end(self, logs=None):
        """Announce the end of training on this host."""
        super().on_train_end(logs)
        print("Training ending on {}".format(self.hostname))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment