Skip to content

Instantly share code, notes, and snippets.

@AFAgarap
Last active September 29, 2019 05:19
Show Gist options
Save AFAgarap/7f2221003a3b25f72ed648c87f75c048 to your computer and use it in GitHub Desktop.
Custom training loop for a model built with the TensorFlow 2.0 Subclassing API. Blog post: https://towardsdatascience.com/how-can-i-trust-you-fb433a06256c?source=friends_link&sk=0af208dc53be2a326d2407577184686b
# Number of passes over the training data. Presumably intended to be passed
# as train(..., epochs=epochs); the call site is not shown here — TODO confirm.
epochs = 60
def loss(actual, predicted):
    """Return the batch-averaged categorical cross-entropy.

    Args:
        actual: target labels; presumably one-hot encoded (the training loop
            takes ``argmax`` over axis 1 of them) — TODO confirm.
        predicted: model outputs. ``from_logits`` is left at its default, so
            these are treated as probabilities, not logits.

    Returns:
        A scalar tensor holding the mean of the per-example losses.
    """
    return tf.reduce_mean(tf.losses.categorical_crossentropy(actual, predicted))
def train(train_dataset, validation_dataset, epochs, learning_rate=1e-1, momentum=9e-1, decay=1e-6):
    """Train the module-level ``model`` with SGD and print per-epoch metrics.

    Relies on a global ``model`` (called on feature batches and exposing
    ``trainable_variables``) and on the ``loss`` function defined alongside it.

    Args:
        train_dataset: iterable of ``(features, labels)`` batches used for
            gradient updates. Labels are compared via ``argmax`` over axis 1,
            so they are presumably one-hot — TODO confirm.
        validation_dataset: iterable of ``(features, labels)`` batches used
            only for evaluation at the end of each epoch.
        epochs: number of full passes over ``train_dataset``.
        learning_rate: SGD step size.
        momentum: SGD momentum factor.
        decay: learning-rate decay forwarded to ``tf.optimizers.SGD``.
    """
    # NOTE(review): ``decay`` is accepted by the TF 2.0-era optimizer; newer
    # Keras optimizers reject it — confirm the targeted TensorFlow version.
    optimizer = tf.optimizers.SGD(learning_rate=learning_rate,
                                  momentum=momentum,
                                  decay=decay)
    for epoch in range(epochs):
        # Fresh streaming metrics per epoch. Accuracy accumulates over every
        # example, so the epoch figure is not skewed by a short final batch
        # (as an unweighted mean of per-batch accuracies would be).
        epoch_loss = tf.metrics.Mean()
        train_accuracy = tf.metrics.Accuracy()
        validation_accuracy = tf.metrics.Accuracy()

        # Iterate the whole training set on its own: zipping it with the
        # validation set would stop after the shorter of the two and silently
        # skip training batches.
        for train_batch_features, train_batch_labels in train_dataset:
            with tf.GradientTape() as tape:
                predictions = model(train_batch_features)
                train_loss = loss(train_batch_labels, predictions)
            gradients = tape.gradient(train_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            # Mean of per-batch losses, as before.
            epoch_loss.update_state(train_loss)
            train_accuracy.update_state(tf.argmax(train_batch_labels, 1),
                                        tf.argmax(predictions, 1))

        for validation_batch_features, validation_batch_labels in validation_dataset:
            validation_predictions = model(validation_batch_features)
            validation_accuracy.update_state(tf.argmax(validation_batch_labels, 1),
                                             tf.argmax(validation_predictions, 1))

        # The template has five placeholders; supply all five (1-based epoch
        # number first) and print plain floats rather than tensor reprs.
        print('Epoch {} / {} : train loss = {}, train accuracy = {}, validation accuracy = {}'.format(
            epoch + 1,
            epochs,
            float(epoch_loss.result()),
            float(train_accuracy.result()),
            float(validation_accuracy.result())))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.