@maciejskorski
Created April 19, 2020 13:38
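
```python
import tensorflow as tf

def SparseCategoricalCrossentropy(labels, logits):
    '''labels: shape [n_batch], true classes as integers from 0 to n_classes-1
    logits: shape [n_batch, n_classes], unnormalized log-probabilities (logits)
    returns: shape [n_batch], per-example cross-entropy loss'''
    # log of the softmax normalizer for each row: log(sum_j exp(logits[i, j]))
    Z = tf.reduce_logsumexp(logits, axis=-1)
    # (row, class) index pairs that pick out each example's true-class logit
    lookup_labels = tf.stack([tf.range(tf.shape(labels)[0]), tf.cast(labels, tf.int32)], 1)
    true_logits = tf.gather_nd(logits, lookup_labels)
    # -log softmax(logits)[label] = -logits[label] + logsumexp(logits)
    return -true_logits + Z
```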
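
The function relies on the identity -log softmax(x)[y] = -x[y] + log(sum_j exp(x[j])), so it never forms the softmax probabilities explicitly. As a dependency-free sketch of the same computation, assuming plain Python lists instead of tensors, the version below uses only the standard library. The helper names (`logsumexp`, `sparse_categorical_crossentropy`, `softmax_nll`) are hypothetical and chosen for illustration. The max-subtraction inside `logsumexp` plays the role of the numerical stabilization that `tf.reduce_logsumexp` performs.

```python
import math
from typing import List, Sequence

def logsumexp(row: Sequence[float]) -> float:
    # Subtract the row maximum before exponentiating so exp() cannot overflow.
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))

def sparse_categorical_crossentropy(labels: Sequence[int],
                                    logits: Sequence[Sequence[float]]) -> List[float]:
    # Per-example loss: -logits[i][labels[i]] + logsumexp(logits[i]).
    return [-row[int(y)] + logsumexp(row) for y, row in zip(labels, logits)]

def softmax_nll(labels: Sequence[int],
                logits: Sequence[Sequence[float]]) -> List[float]:
    # Naive definition -log(softmax(row)[y]), for comparison only;
    # it overflows for large logits, unlike the version above.
    out = []
    for y, row in zip(labels, logits):
        exps = [math.exp(x) for x in row]
        out.append(-math.log(exps[int(y)] / sum(exps)))
    return out

labels = [0, 2]
logits = [[2.0, 1.0, 0.1], [0.5, 0.5, 3.0]]
print(sparse_categorical_crossentropy(labels, logits))
print(softmax_nll(labels, logits))
```

Both printed lists should match up to floating-point rounding. In TensorFlow, the gist's function should likewise agree with the built-in `tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)`.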