@himanshurawlani
Last active March 17, 2019 16:57
Using the train_on_batch() method to customize the training loop
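import tensorflow_datasets as tfds
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Assumes `model`, `train`, `epochs`, `BATCH_SIZE`, `train_losses` and
# `train_accuracies` are defined earlier, and that `model` was compiled with
# a single accuracy metric so train_on_batch() returns [loss, accuracy].
datagen = ImageDataGenerator(rotation_range=20,
                             width_shift_range=0.2,
                             height_shift_range=0.2,
                             horizontal_flip=True)

# Performing model training for the given number of epochs
for e in range(epochs):
    print('Epoch', e)
    batches = 0
    loss = 0
    accuracy = 0
    for example in tfds.as_numpy(train):
        x_train, y_train = example[0], example[1]
        # Count batches per chunk so that every chunk yielded by `train`
        # gets its full share of augmented batches
        chunk_batches = 0
        for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=BATCH_SIZE):
            curr_loss, curr_accuracy = model.train_on_batch(x_batch, y_batch)
            loss += curr_loss
            accuracy += curr_accuracy
            batches += 1
            chunk_batches += 1
            if chunk_batches >= len(x_train) // BATCH_SIZE:
                # we need to break the loop by hand because
                # the generator loops indefinitely
                break
    # Record the mean loss and accuracy over all batches seen this epoch
    train_losses.append(loss / batches)
    train_accuracies.append(accuracy / batches)
    print('Train Loss:', loss / batches, 'Train Accuracy:', accuracy / batches)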
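
The manual break in the loop above is needed because the generator never stops on its own. A minimal, framework-free sketch of that pattern follows, assuming a hypothetical `endless_batches()` generator standing in for `datagen.flow()` and a hypothetical `fake_train_step()` standing in for `model.train_on_batch()`; neither is part of the original snippet. It caps the number of steps with `itertools.islice` instead of a hand-written counter, which is one possible alternative rather than the approach used above.

```python
import itertools
import math


def endless_batches(xs, ys, batch_size):
    """Hypothetical stand-in for datagen.flow(): yields batches forever."""
    n = len(xs)
    while True:
        for start in range(0, n, batch_size):
            yield xs[start:start + batch_size], ys[start:start + batch_size]


def run_chunk(batch_iter, train_step, steps):
    """Consume exactly `steps` batches; return summed loss, summed accuracy, count."""
    total_loss, total_acc, count = 0.0, 0.0, 0
    # islice stops after `steps` items, so no manual break is required
    for x_batch, y_batch in itertools.islice(batch_iter, steps):
        loss, acc = train_step(x_batch, y_batch)
        total_loss += loss
        total_acc += acc
        count += 1
    return total_loss, total_acc, count


def fake_train_step(x_batch, y_batch):
    """Placeholder for model.train_on_batch(); returns (loss, accuracy)."""
    return 1.0 / len(x_batch), 0.5


xs = list(range(10))
ys = [v % 2 for v in xs]
batch_size = 4
# Rounding up means the last, smaller batch is included in the pass
steps = math.ceil(len(xs) / batch_size)

loss_sum, acc_sum, n_batches = run_chunk(
    endless_batches(xs, ys, batch_size), fake_train_step, steps)
print('Train Loss:', loss_sum / n_batches, 'Train Accuracy:', acc_sum / n_batches)
```

Note that this sketch rounds the step count up, whereas the loop above uses floor division (`len(x_train) // BATCH_SIZE`) and so skips the final partial batch of each chunk.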