Last active: February 5, 2021 15:58
-
-
Save ntakouris/d21b7f6d027359ca332f158dfa77d88d to your computer and use it in GitHub Desktop.
Multi Hot Binary Accuracy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MultiHotAccuracy(keras.metrics.Metric):
    """Top-k binary accuracy for multi-hot (multi-label) targets.

    For every sample, k is the number of positive labels in ``y_true``
    (the row sum of ``y_true`` cast to int32). Every entry of ``y_pred``
    whose score is >= the k-th largest score in that row is treated as a
    predicted positive. Ties can therefore select more than k entries.
    All other entries are predicted negatives. The resulting prediction is
    compared element-wise with ``y_true == 1``.

    Args:
        name: Metric name. When left at its default and ``absolute`` is
            true, the name becomes ``'multihot_absolute_accuracy'``.
        absolute: If False, each sample contributes the fraction of label
            positions predicted correctly. If True, each sample contributes
            1.0 only when every position is correct, else 0.0.
        **kwargs: Forwarded to ``keras.metrics.Metric``.

    ``result()`` returns the summed per-sample scores divided by the number
    of samples seen since the last reset, as a float32 tensor of shape (1,).
    """

    def __init__(self, name='multihot_accuracy', absolute=False, **kwargs):
        # Derive the name from ``absolute`` only when the caller kept the
        # default; an explicitly passed ``name`` is respected.
        if name == 'multihot_accuracy' and absolute:
            name = 'multihot_absolute_accuracy'
        super().__init__(name=name, **kwargs)
        self.absolute = absolute
        # Running sum of per-sample scores.
        self.binary_accuracy = self.add_weight(
            name='binary_accuracy', shape=(1,), initializer='zeros',
            dtype=tf.float32)
        # Running count of samples seen since the last reset.
        self.num_samples = self.add_weight(
            name='num_samples', shape=(1,), initializer='zeros',
            dtype=tf.int32)

    def update_state(self, y_true, y_pred, **kwargs):
        """Accumulate per-sample scores for one batch.

        Args:
            y_true: Multi-hot targets; entries equal to 1 are positives.
                Assumed to be (batch, num_labels) -- TODO confirm.
            y_pred: Scores with the same shape as ``y_true``.
            **kwargs: Accepted for Keras compatibility (e.g.
                ``sample_weight``) but not used.
        """
        y_pred = tf.cast(y_pred, tf.float32)
        num_labels = tf.shape(y_pred)[-1]
        # k = number of positive labels per sample.
        k = tf.reduce_sum(tf.cast(y_true, tf.int32), axis=-1)

        # Column k-1 of the descending sort is the k-th largest score, the
        # same value ``tf.math.top_k(row, k=k).values[-1]`` yields. The index
        # is clamped so k == 0 (handled below) and k > num_labels cannot
        # index out of range.
        sorted_pred = tf.sort(y_pred, axis=-1, direction='DESCENDING')
        kth_index = tf.clip_by_value(k - 1, 0, num_labels - 1)
        kth_score = tf.gather(sorted_pred, kth_index, batch_dims=1)

        # A sample with no positive labels predicts nothing as positive
        # instead of crashing on an empty top-k.
        pred_pos = tf.logical_and(
            y_pred >= tf.expand_dims(kth_score, axis=-1),
            tf.expand_dims(k > 0, axis=-1))
        true_pos = tf.equal(y_true, 1)
        matches = tf.equal(pred_pos, true_pos)

        if self.absolute:
            # 1.0 only if every label position of the sample matches.
            per_sample = tf.cast(tf.reduce_all(matches, axis=-1), tf.float32)
        else:
            per_sample = tf.reduce_mean(tf.cast(matches, tf.float32), axis=-1)

        # Both accumulators have shape (1,), so reshape the increments to
        # match.
        self.binary_accuracy.assign_add(
            tf.reshape(tf.reduce_sum(per_sample), [1]))
        # Add the batch size once (the old code added ``num_samples + bs``,
        # which fed the running count back into itself).
        self.num_samples.assign_add(tf.reshape(tf.shape(y_true)[0], [1]))

    def result(self):
        """Return the mean per-sample score; 0 (not NaN) before any update."""
        return tf.math.divide_no_nan(
            self.binary_accuracy, tf.cast(self.num_samples, dtype=tf.float32))

    def reset_state(self):
        """Zero both accumulators."""
        self.binary_accuracy.assign([0.])
        self.num_samples.assign([0])

    def reset_states(self):
        # Older Keras versions call ``reset_states``; newer ones call
        # ``reset_state``. Keep both entry points working.
        self.reset_state()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment