@shashankprasanna
Created April 24, 2020 22:00
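from tensorflow.keras.callbacks import ModelCheckpoint

# SageMaker syncs this local directory to S3 when the estimator sets checkpoint_s3_uri
checkpoint_path = "/opt/ml/checkpoints"
# model_type, model, train_dataset, epochs and epoch_number are defined elsewhere in the training script
checkpoint_names = 'cifar10-' + model_type + '.{epoch:03d}.h5'  # Keras fills in {epoch} at save time
# Save the full model (architecture, weights and optimizer state) after every epoch.
# monitor only changes which files are kept when save_best_only=True; by default every epoch is saved.
checkpoint_callback = ModelCheckpoint(filepath=f'{checkpoint_path}/{checkpoint_names}',
                                      save_weights_only=False,
                                      monitor='val_loss')
# Start from epoch_number so a restarted job continues where the last checkpoint left off
model.fit(train_dataset, ...
          epochs=epochs,
          initial_epoch=epoch_number,
          callbacks=[checkpoint_callback])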
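The snippet passes `initial_epoch=epoch_number` but does not show where `epoch_number` comes from. One way to recover it on restart is to scan the checkpoint directory for the newest file. The sketch below assumes checkpoints follow the `cifar10-<model_type>.<epoch:03d>.h5` pattern written by the callback above; the function name `find_latest_checkpoint` is hypothetical. Keras fills `{epoch}` with the 1-based count of completed epochs, so the number parsed from the newest file can be passed directly as the 0-based `initial_epoch`.

```python
import os
import re
import tempfile
from typing import Optional, Tuple


def find_latest_checkpoint(checkpoint_dir: str, model_type: str) -> Tuple[int, Optional[str]]:
    """Return (epoch_number, path) of the newest checkpoint, or (0, None) if none exist.

    Hypothetical helper: assumes files are named 'cifar10-<model_type>.<epoch:03d>.h5'.
    """
    pattern = re.compile(r'^cifar10-' + re.escape(model_type) + r'\.(\d+)\.h5$')
    latest_epoch, latest_path = 0, None
    if not os.path.isdir(checkpoint_dir):
        # First run: nothing to resume from, so training starts at epoch 0
        return latest_epoch, latest_path
    for name in os.listdir(checkpoint_dir):
        match = pattern.match(name)
        if match:
            epoch = int(match.group(1))
            if epoch > latest_epoch:
                latest_epoch, latest_path = epoch, os.path.join(checkpoint_dir, name)
    return latest_epoch, latest_path


if __name__ == '__main__':
    # Demo with empty placeholder files standing in for real checkpoints
    with tempfile.TemporaryDirectory() as d:
        for n in ('cifar10-resnet.001.h5', 'cifar10-resnet.012.h5',
                  'cifar10-resnet.003.h5', 'other.h5'):
            open(os.path.join(d, n), 'w').close()
        epoch, path = find_latest_checkpoint(d, 'resnet')
        print(epoch, os.path.basename(path))  # prints: 12 cifar10-resnet.012.h5
```

Because the callback uses `save_weights_only=False`, a returned path could then be reloaded with `tf.keras.models.load_model(path)` before calling `model.fit(..., initial_epoch=epoch_number)`; when the helper returns `(0, None)`, a freshly built model would be used instead.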