Skip to content

Instantly share code, notes, and snippets.

@dpoulopoulos
Created July 9, 2021 10:11
Show Gist options
  • Save dpoulopoulos/42265be408aa18fafc7cb0a1fbada4f1 to your computer and use it in GitHub Desktop.
Save dpoulopoulos/42265be408aa18fafc7cb0a1fbada4f1 to your computer and use it in GitHub Desktop.
# Objective and optimizer used by train_step / test_step below.
# from_logits=False is the Keras default: model outputs are treated as
# probabilities, so the model presumably ends in a softmax -- confirm.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=False)
optimizer = keras.optimizers.Adam(learning_rate=0.001)  # Keras default rate

# Metrics updated once per batch by train_step (train_*) and
# test_step (test_*).
train_loss = keras.metrics.Mean(name='train_loss')
train_accuracy = keras.metrics.CategoricalAccuracy(name='train_accuracy')
test_loss = keras.metrics.Mean(name='test_loss')
test_accuracy = keras.metrics.CategoricalAccuracy(name='test_accuracy')
def train_step(images, labels):
    """Run a single optimization step on one batch.

    Runs a forward pass with ``training=True`` and computes the batch loss
    with ``loss_fn``. Differentiates it with a ``tf.GradientTape`` and lets
    ``optimizer`` update the model's trainable variables. Finally records
    the loss and accuracy in the ``train_loss`` / ``train_accuracy`` metrics.

    Args:
        images: Input batch passed to ``model``. Its shape and dtype depend
            on the model, which is defined outside this view.
        labels: Targets for ``loss_fn``. They are presumably one-hot encoded,
            since a categorical loss and accuracy are used; confirm against
            the data pipeline.

    Returns:
        None. All effects are on the model weights, the optimizer state and
        the training metrics.
    """
    # Record only the forward pass and loss; the tape is not persistent,
    # so it is queried exactly once below.
    with tf.GradientTape() as tape:
        preds = model(images, training=True)
        batch_loss = loss_fn(labels, preds)

    params = model.trainable_variables
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))

    # Accumulate per-batch statistics for reporting.
    train_loss(batch_loss)
    train_accuracy(labels, preds)
def test_step(images, labels):
    """Evaluate the model on one batch without touching its weights.

    Runs a forward pass with ``training=False`` and feeds the batch loss
    (from ``loss_fn``) and the predictions into the ``test_loss`` /
    ``test_accuracy`` metrics. No gradients are computed and the optimizer
    is not used.

    Args:
        images: Input batch passed to ``model``. Its shape and dtype depend
            on the model, which is defined outside this view.
        labels: Targets for ``loss_fn``. They are presumably one-hot encoded;
            confirm against the data pipeline.

    Returns:
        None. Results are accumulated in the test metrics.
    """
    outputs = model(images, training=False)
    # The loss is evaluated first, then recorded, then accuracy is updated.
    test_loss(loss_fn(labels, outputs))
    test_accuracy(labels, outputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment