Skip to content

Instantly share code, notes, and snippets.

@seahrh
Last active October 8, 2020 18:37
Show Gist options
  • Save seahrh/19c8779e159da35bcdc696245a2b24f6 to your computer and use it in GitHub Desktop.
Save seahrh/19c8779e159da35bcdc696245a2b24f6 to your computer and use it in GitHub Desktop.
Extends the `keras.callbacks.ModelCheckpoint` callback to also copy each saved checkpoint to Google Cloud Storage (GCS). Based on TensorFlow 2.3.
import glob
import os
import posixpath

from tensorflow import keras
from tensorflow.python.lib.io import file_io
class ModelCheckpointInGcs(keras.callbacks.ModelCheckpoint):
    """``keras.callbacks.ModelCheckpoint`` that mirrors checkpoints into ``gcs_dir``.

    The parent class writes the checkpoint to the local ``filepath``; after
    each save this subclass copies whatever exists at that path into
    ``gcs_dir`` (for example ``gs://my-bucket/checkpoints``), appending the
    local path to it.

    Three on-disk layouts are recognised:

    * a single file (e.g. an ``.h5`` model or weights file);
    * a directory (e.g. the SavedModel layout), copied file by file;
    * a TF-format weights checkpoint, i.e. ``<filepath>.index`` plus its
      ``<filepath>.data-*`` shards.

    Args:
        filepath: Local checkpoint path, forwarded unchanged to
            ``ModelCheckpoint`` (it may contain format placeholders such as
            ``{epoch:02d}``).
        gcs_dir: Destination directory/prefix that the copies are written under.
        monitor, verbose, save_best_only, save_weights_only, mode, save_freq,
        options, **kwargs: Forwarded unchanged to ``ModelCheckpoint``.

    Note:
        This overrides/uses the private ``ModelCheckpoint._save_model`` and
        ``_get_file_path`` methods (written against TF 2.3); re-check them
        when upgrading TensorFlow.
    """

    def __init__(
        self,
        filepath,
        gcs_dir: str,
        monitor="val_loss",
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode="auto",
        save_freq="epoch",
        options=None,
        **kwargs,
    ):
        super().__init__(
            filepath,
            monitor=monitor,
            verbose=verbose,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only,
            mode=mode,
            save_freq=save_freq,
            options=options,
            **kwargs,
        )
        self._gcs_dir = gcs_dir

    def _save_model(self, epoch, logs):
        """Let the parent save locally, then copy the result into ``gcs_dir``.

        Args:
            epoch: Epoch index, passed through to the parent implementation.
            logs: Metric logs used to format ``filepath``; may be ``None``.
        """
        super()._save_model(epoch, logs)
        # Recompute the local path the parent writes to. The path template is
        # formatted from ``logs``, so substitute an empty dict for ``None``.
        filepath = self._get_file_path(epoch, logs or {})
        for src in self._local_checkpoint_files(filepath):
            # file_io.copy streams across filesystems (local -> gs://), so the
            # checkpoint is not loaded into memory in one piece.
            file_io.copy(src, self._gcs_path_for(src), overwrite=True)

    @staticmethod
    def _local_checkpoint_files(filepath):
        """Return the local files that make up the checkpoint at ``filepath``.

        Returns an empty list when nothing recognisable exists at that path.
        """
        if os.path.isfile(filepath):
            return [filepath]
        if os.path.isdir(filepath):
            # Directory layout (e.g. SavedModel): every file below it.
            return [
                os.path.join(root, name)
                for root, _, names in os.walk(filepath)
                for name in names
            ]
        index = filepath + ".index"
        if os.path.isfile(index):
            # TF-format weights: an index file plus one or more data shards.
            shards = glob.glob(glob.escape(filepath) + ".data-*")
            return [index] + sorted(shards)
        return []

    def _gcs_path_for(self, local_path):
        """Map a local file path to its destination under ``gcs_dir``.

        Any drive/root is stripped first so an absolute ``local_path`` ends up
        nested under ``gcs_dir``: ``os.path.join(gcs_dir, "/abs/x")`` would
        return ``"/abs/x"`` and point the copy back at the local file itself.
        Separators are normalised to ``/``, the separator GCS paths use.
        """
        _, tail = os.path.splitdrive(local_path)
        relative = tail.replace(os.sep, "/").lstrip("/")
        return posixpath.join(self._gcs_dir, relative)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment