@himanshurawlani
Last active March 21, 2019 20:54
Training a Keras model using the fit() method
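import os

from tensorflow import keras


def train_model(model):
    # Assumes train and validation (batched tf.data.Dataset objects), epochs,
    # steps_per_epoch, validation_steps and log_dir are defined at module level.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',  # integer class labels
                  metrics=['accuracy'])

    # Creating Keras callbacks
    tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    # Create the checkpoint directory before training starts.
    os.makedirs('training_checkpoints/', exist_ok=True)
    # Save the model every 5 epochs (`period` is deprecated in later
    # tf.keras releases in favor of `save_freq`).
    model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
        'training_checkpoints/weights.{epoch:02d}-{val_loss:.2f}.hdf5', period=5)
    # Stop after 5 epochs without improvement in val_loss (the default monitor).
    early_stopping_callback = keras.callbacks.EarlyStopping(patience=5)

    # Both datasets repeat indefinitely, so the number of steps per epoch
    # (and per validation run) must be given explicitly.
    history = model.fit(train.repeat(),
                        epochs=epochs,
                        steps_per_epoch=steps_per_epoch,
                        validation_data=validation.repeat(),
                        validation_steps=validation_steps,
                        callbacks=[tensorboard_callback,
                                   model_checkpoint_callback,
                                   early_stopping_callback])
    return history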
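The compile step uses `sparse_categorical_crossentropy`, which expects integer class labels rather than one-hot vectors. As a rough pure-Python sketch (not Keras's own implementation; the clipping constant `eps=1e-7` is an assumption), the loss is the mean negative log of the probability the model assigns to the true class:

```python
import math


def sparse_categorical_crossentropy(probs, labels, eps=1e-7):
    """Mean negative log-probability of the true class.

    probs:  list of per-example probability lists (e.g. softmax outputs)
    labels: list of integer class indices (not one-hot vectors)
    eps:    assumed clipping constant that keeps log() away from 0
    """
    losses = []
    for p, y in zip(probs, labels):
        # Take the probability assigned to the true class and clip it.
        p_true = min(max(p[y], eps), 1.0 - eps)
        losses.append(-math.log(p_true))
    return sum(losses) / len(losses)


probs = [[0.7, 0.2, 0.1],
         [0.1, 0.8, 0.1]]
labels = [0, 1]
print(round(sparse_categorical_crossentropy(probs, labels), 4))  # 0.2899
```

With one-hot labels the ordinary `categorical_crossentropy` loss would be used instead.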
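The `ModelCheckpoint` path is a Python format string that Keras fills in with the 1-based epoch number and the logged metrics, so `{val_loss}` can only be resolved because `validation_data` is passed to `fit()`. The plain-Python sketch below approximates which files `period=5` would produce for a hypothetical list of validation losses; it assumes a save happens at the end of every fifth epoch and does not call Keras.

```python
TEMPLATE = 'training_checkpoints/weights.{epoch:02d}-{val_loss:.2f}.hdf5'


def checkpoint_paths(val_losses, period=5):
    """Return the paths that would be written when saving every `period` epochs.

    val_losses[i] is the (hypothetical) validation loss at the end of epoch i + 1.
    """
    paths = []
    epochs_since_last_save = 0
    for i, val_loss in enumerate(val_losses):
        epochs_since_last_save += 1
        if epochs_since_last_save >= period:
            epochs_since_last_save = 0
            # Epoch numbers in the filename are 1-based and zero-padded.
            paths.append(TEMPLATE.format(epoch=i + 1, val_loss=val_loss))
    return paths


losses = [0.9, 0.7, 0.6, 0.55, 0.513, 0.5, 0.49, 0.48, 0.47, 0.462]
for path in checkpoint_paths(losses):
    print(path)
# training_checkpoints/weights.05-0.51.hdf5
# training_checkpoints/weights.10-0.46.hdf5
```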
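`EarlyStopping(patience=5)` watches `val_loss` by default and stops once five consecutive epochs pass without a strict improvement. A minimal sketch of that patience rule, assuming a "lower is better" metric and the default `min_delta=0`, applied to a made-up loss curve:

```python
def early_stopping_epoch(val_losses, patience=5, min_delta=0.0):
    """Return the 1-based epoch after which training would stop, or None.

    An epoch counts as an improvement only if its loss is lower than the
    best loss so far by more than min_delta.
    """
    best = float('inf')
    wait = 0  # epochs since the last improvement
    for i, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return i + 1
    return None


losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.70, 0.73, 0.9, 1.0]
print(early_stopping_epoch(losses))  # 8
```

Note that epoch 7 (loss 0.70) only ties the best value, so it does not reset the counter.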
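Because `train.repeat()` and `validation.repeat()` never end, `fit()` cannot infer how long an epoch is, which is why `steps_per_epoch` and `validation_steps` are passed. One common choice is roughly one pass over the data per epoch; the helper and dataset sizes below are hypothetical:

```python
import math


def steps_for(num_examples, batch_size):
    """Batches needed to cover every example once (the last batch may be partial)."""
    return math.ceil(num_examples / batch_size)


# Hypothetical dataset sizes, for illustration only.
num_train, num_validation, batch_size = 3000, 670, 32
steps_per_epoch = steps_for(num_train, batch_size)
validation_steps = steps_for(num_validation, batch_size)
print(steps_per_epoch, validation_steps)  # 94 21
```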