Skip to content

Instantly share code, notes, and snippets.

@AlessandroMondin
Created May 11, 2022 10:39
Show Gist options
  • Save AlessandroMondin/517906d5a420350fa2d93b8b5ae4593e to your computer and use it in GitHub Desktop.
Save AlessandroMondin/517906d5a420350fa2d93b8b5ae4593e to your computer and use it in GitHub Desktop.
def accuracy(self, y_hat, Y):
    """Return the fraction of samples whose arg-max prediction equals the label.

    Dispatches on ``self.library``: ``"tf"`` uses TensorFlow, any other
    value uses PyTorch.

    Args:
        y_hat: Prediction scores; the predicted class is the arg-max over
            axis/dim 1 (presumably shape ``(batch, num_classes)`` -- TODO
            confirm against callers).
        Y: Class-index labels, shape ``(batch,)``. Labels shaped
            ``(batch, 1)`` are squeezed to ``(batch,)`` before comparison.

    Returns:
        A scalar tensor holding the accuracy in ``[0, 1]`` (NaN for an
        empty batch). TensorFlow returns float64; PyTorch returns the
        default floating dtype.
    """
    if self.library == "tf":
        # A (batch, 1) label tensor would otherwise broadcast against the
        # (batch,) predictions into a (batch, batch) comparison.
        if Y.shape.rank == 2 and Y.shape[-1] == 1:
            Y = tf.squeeze(Y, axis=-1)
        argmax = tf.cast(tf.argmax(y_hat, axis=1), Y.dtype)
        # reduce_mean instead of sum / Y.shape[0]: the static batch size is
        # None when the batch dimension is dynamic (e.g. inside tf.function).
        acc = tf.math.reduce_mean(tf.cast(tf.equal(argmax, Y), tf.float64))
    else:
        # Same broadcasting guard as the TF branch.
        if Y.dim() == 2 and Y.shape[-1] == 1:
            Y = Y.squeeze(-1)
        argmax = torch.argmax(y_hat, dim=1)
        # sum() of the bool mask is an integer tensor; true division by the
        # batch size yields a floating-point accuracy.
        acc = torch.sum(torch.eq(argmax, Y)) / Y.shape[0]
    return acc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment