Skip to content

Instantly share code, notes, and snippets.

@ntakouris
Last active February 5, 2021 15:58
Show Gist options
  • Save ntakouris/d21b7f6d027359ca332f158dfa77d88d to your computer and use it in GitHub Desktop.
Save ntakouris/d21b7f6d027359ca332f158dfa77d88d to your computer and use it in GitHub Desktop.
Multi Hot Binary Accuracy
class MultiHotAccuracy(keras.metrics.Metric):
    """Binary accuracy for multi-hot targets, thresholded at each sample's top-k.

    For every sample, ``k`` is the number of positive labels in ``y_true``.
    A class is predicted positive when its score is ``>=`` the k-th largest
    score of that sample, so ties can select more than ``k`` classes. A sample
    with no positive labels predicts no positives. The predicted mask is then
    compared element-wise with ``y_true == 1``.

    With ``absolute=False`` the metric is the mean per-sample fraction of
    correctly classified labels. With ``absolute=True`` a sample only counts
    when every label matches (exact-match ratio).

    Assumes ``y_true`` / ``y_pred`` are rank-2 ``(batch, num_classes)``
    tensors and ``y_true`` is 0/1 valued -- TODO confirm against callers.

    Args:
        name: Metric name. When left at the default, the name is derived from
            ``absolute`` (``'multihot_accuracy'`` or
            ``'multihot_absolute_accuracy'``). An explicit name is honoured.
        absolute: If True, count only samples whose labels all match.
        **kwargs: Forwarded to ``keras.metrics.Metric``.
    """

    def __init__(self, name='multihot_accuracy', absolute=False, **kwargs):
        if name == 'multihot_accuracy':
            # Keep the original auto-naming for the default, but no longer
            # silently discard a name the caller chose explicitly.
            name = f'multihot{"_absolute" if absolute else ""}_accuracy'
        super().__init__(name=name, **kwargs)
        self.absolute = absolute
        # Running sum of per-sample accuracies (or exact-match indicators).
        self.binary_accuracy = self.add_weight(
            name='binary_accuracy', shape=(1,), initializer='zeros', dtype=tf.float32)
        # Number of samples accumulated since the last reset.
        self.num_samples = self.add_weight(
            name='num_samples', shape=(1,), initializer='zeros', dtype=tf.int32)

    def update_state(self, y_true, y_pred, **kwargs):
        """Accumulate the metric over one batch.

        Extra keyword arguments (e.g. ``sample_weight``) are accepted for
        Keras compatibility but ignored.
        """
        y_pred = tf.cast(y_pred, tf.float32)
        true_mask = tf.equal(tf.cast(y_true, tf.float32), 1.0)

        # k = number of positive labels per sample.
        num_positive = tf.reduce_sum(tf.cast(y_true, tf.int32), axis=-1)

        # The k-th largest score equals element k-1 of the descending sort.
        # This is the same value top_k(k)[0][-1] produced, but vectorized.
        sorted_pred = tf.sort(y_pred, axis=-1, direction='DESCENDING')
        num_classes = tf.shape(y_pred)[-1]
        kth_index = tf.clip_by_value(num_positive - 1, 0, num_classes - 1)
        kth_score = tf.gather(sorted_pred, kth_index, batch_dims=1)

        # k == 0 used to crash (empty top_k). Use a +inf threshold instead so
        # nothing is predicted positive for that sample.
        threshold = tf.where(
            num_positive > 0,
            kth_score,
            tf.fill(tf.shape(kth_score), float('inf')))

        pred_mask = y_pred >= tf.expand_dims(threshold, axis=-1)
        matches = tf.equal(pred_mask, true_mask)

        if self.absolute:
            per_sample = tf.cast(tf.reduce_all(matches, axis=-1), tf.float32)
        else:
            per_sample = tf.reduce_mean(tf.cast(matches, tf.float32), axis=-1)

        self.binary_accuracy.assign_add(tf.reshape(tf.reduce_sum(per_sample), [1]))
        # Add only this batch's size. The original added (num_samples + bs),
        # which doubled the running count on every update.
        batch_size = tf.shape(y_true)[0]
        self.num_samples.assign_add(tf.reshape(batch_size, [1]))

    def result(self):
        """Return the accumulated accuracy, or 0 before any sample is seen."""
        return tf.math.divide_no_nan(
            self.binary_accuracy, tf.cast(self.num_samples, dtype=tf.float32))

    def reset_state(self):
        """Zero the accumulators (current Keras API name)."""
        self.binary_accuracy.assign([0.])
        self.num_samples.assign([0])

    def reset_states(self):
        """Deprecated alias kept for older Keras versions and existing callers."""
        self.reset_state()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment