Gist by @NMZivkovic, created August 30, 2019 14:52
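import tensorflow as tf
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import Mean, SparseCategoricalAccuracy

# Per-token cross-entropy on raw logits. reduction='none' keeps one loss value
# per position so that padded positions can be masked out afterwards.
loss_objective_function = SparseCategoricalCrossentropy(from_logits=True, reduction='none')

def padded_loss_function(real, prediction):
    # True for real tokens, False for padding (token id 0).
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss = loss_objective_function(real, prediction)
    mask = tf.cast(mask, dtype=loss.dtype)
    # Zero out the loss at padded positions.
    loss *= mask
    # Averages over every position, padding included, so batches with more
    # padding report a proportionally smaller loss.
    return tf.reduce_mean(loss)

# Running metrics accumulated across training steps.
training_loss = Mean(name='training_loss')
training_accuracy = SparseCategoricalAccuracy(name='training_accuracy')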
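To make the arithmetic concrete, here is a plain-Python sketch (no TensorFlow) of what `padded_loss_function` computes for one batch, assuming padding uses token id 0 and logits arrive as nested lists of shape `[batch][seq_len][vocab]`. The helper names (`sparse_ce_from_logits`, `padded_loss`, `padded_accuracy`) are hypothetical. The `masked_mean=True` branch shows a common alternative normalization (dividing by the number of real tokens instead of all positions); it is not what the snippet above does.

```python
import math
from typing import List


def sparse_ce_from_logits(label: int, logits: List[float]) -> float:
    """Cross-entropy of one integer label against unnormalized logits."""
    # Numerically stable log-sum-exp: subtract the max logit first.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


def padded_loss(real: List[List[int]],
                prediction: List[List[List[float]]],
                masked_mean: bool = False) -> float:
    """Mirror of padded_loss_function: the loss is zeroed where the label is 0."""
    total, real_tokens, positions = 0.0, 0, 0
    for seq_labels, seq_logits in zip(real, prediction):
        for label, logits in zip(seq_labels, seq_logits):
            mask = 0.0 if label == 0 else 1.0
            total += sparse_ce_from_logits(label, logits) * mask
            real_tokens += int(mask)
            positions += 1
    if masked_mean:
        # Alternative (not used above): average over real tokens only.
        return total / max(real_tokens, 1)
    # Same normalization as tf.reduce_mean over the whole [batch, seq_len] tensor.
    return total / positions


def padded_accuracy(real: List[List[int]],
                    prediction: List[List[List[float]]]) -> float:
    """Token accuracy that skips padded positions (label 0)."""
    correct, count = 0, 0
    for seq_labels, seq_logits in zip(real, prediction):
        for label, logits in zip(seq_labels, seq_logits):
            if label == 0:
                continue
            predicted = max(range(len(logits)), key=logits.__getitem__)  # argmax
            correct += int(predicted == label)
            count += 1
    return correct / max(count, 1)


# One sequence of length 4: two real tokens followed by two padding positions.
real = [[2, 1, 0, 0]]
prediction = [[[0.0, 0.0, 5.0], [0.0, 5.0, 0.0], [9.0, 0.0, 0.0], [9.0, 0.0, 0.0]]]
print(padded_loss(real, prediction))
print(padded_loss(real, prediction, masked_mean=True))
print(padded_accuracy(real, prediction))
```

With half the positions padded, the snippet's `tf.reduce_mean` normalization yields half the per-token loss. Note also that `SparseCategoricalAccuracy` as defined above counts padded positions unless the mask is passed as `sample_weight` to `training_accuracy.update_state(...)`; the gist does not show how the metrics are updated.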