@AlessandroMondin
Last active May 10, 2022 17:03
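```python
import tensorflow as tf
import torch

def cross_entropy(self, scaled_logits, one_hot):
    # scaled_logits: softmax probabilities, shape (batch, n_classes)
    # one_hot: one-hot targets of the same shape, used as a boolean mask
    if self.library == "tf":
        # keep only the probability assigned to each sample's true class
        masked_logits = tf.boolean_mask(scaled_logits, tf.cast(one_hot, tf.bool))
        ce = -tf.math.log(masked_logits)
    else:
        masked_logits = torch.masked_select(scaled_logits, one_hot.bool())
        ce = -torch.log(masked_logits)
    # per-sample cross-entropy, shape (batch,); averaging is left to the caller
    return ce
```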
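The method applies `-log` directly to `scaled_logits`. It therefore assumes those values are already softmax probabilities (the name suggests softmax-scaled logits, but this is an inference, not something the gist states). Masking with the one-hot target keeps exactly one probability per row, so each sample's loss is $-\log p_{i,y_i}$. Below is a framework-free sketch of the same computation in plain Python (standard library only), useful for checking the tf/torch results by hand. The names `softmax` and `masked_cross_entropy` are hypothetical and do not come from the gist.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Hypothetical helper: subtract the max logit before exponentiating
    # so large logits do not overflow math.exp.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def masked_cross_entropy(probs: Sequence[Sequence[float]],
                         one_hot: Sequence[Sequence[int]]) -> List[float]:
    # Hypothetical stand-in for the tf/torch method: select the probability
    # wherever the one-hot entry is nonzero, then take -log of it.
    # Like boolean_mask / masked_select, the result is a flat list.
    losses: List[float] = []
    for p_row, y_row in zip(probs, one_hot):
        selected = [p for p, y in zip(p_row, y_row) if y]
        losses.extend(-math.log(p) for p in selected)
    return losses


logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.0]]
targets = [[1, 0, 0], [0, 1, 0]]
probs = [softmax(row) for row in logits]
print(masked_cross_entropy(probs, targets))
```

One caveat applies to both versions: if a selected probability underflows to 0, `-log` becomes infinite (in the torch/tf method) or raises `ValueError` (with `math.log`). Library losses that take raw logits, such as `tf.nn.softmax_cross_entropy_with_logits` or `torch.nn.functional.cross_entropy`, work in log-space and avoid this.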