Created
August 5, 2022 13:11
-
-
Save Rocketknight1/16f43aa5df728f72e7dbe6b2aca65571 to your computer and use it in GitHub Desktop.
keras_metrics.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
class MaskedAccuracy(tf.keras.metrics.Metric):
    """Accuracy over all positions whose label is not ``label_to_ignore``.

    Predictions are the argmax of ``y_pred`` over its last axis. Running
    totals of correct and scored positions are kept in two int64 weights.

    Args:
        name: Metric name, forwarded to ``tf.keras.metrics.Metric``.
        dtype: Metric dtype, forwarded to ``tf.keras.metrics.Metric``.
        clm: When true, predictions are shifted against the labels
            (prediction at position ``i`` is compared with the label at
            position ``i + 1``) before scoring.
        label_to_ignore: Label value excluded from both totals.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name=None, dtype=None, clm=False, label_to_ignore=-100, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        self.label_to_ignore = label_to_ignore
        self.clm = clm
        self.correct_predictions = self.add_weight(
            name='correct_predictions', initializer='zeros', dtype=tf.int64)
        self.all_predictions = self.add_weight(
            name='all_predictions', initializer='zeros', dtype=tf.int64)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate correct/scored counts for one batch.

        Args:
            y_true: Integer labels; positions equal to ``label_to_ignore``
                are skipped. With ``clm=True`` it is sliced as ``[:, 1:]``,
                so it must have at least two axes in that mode.
            y_pred: Per-class scores with the class axis last.
            sample_weight: Accepted for Keras compatibility; not used.
        """
        class_predictions = tf.math.argmax(y_pred, axis=-1)
        # Align the label dtype with the argmax output so that the
        # element-wise comparison below does not fail on e.g. int32 labels.
        y_true = tf.cast(y_true, class_predictions.dtype)
        if self.clm:
            # Drop the last prediction and the first label so that each
            # prediction lines up with the label that follows it.
            class_predictions = class_predictions[:, :-1]
            y_true = y_true[:, 1:]
        is_scored = tf.not_equal(y_true, self.label_to_ignore)
        is_correct = tf.equal(class_predictions, y_true)
        self.correct_predictions.assign_add(
            tf.math.count_nonzero(tf.logical_and(is_correct, is_scored)))
        self.all_predictions.assign_add(tf.math.count_nonzero(is_scored))

    def result(self):
        """Return correct / scored as float32 (0 when nothing was scored)."""
        correct = tf.cast(self.correct_predictions, tf.float32)
        total = tf.cast(self.all_predictions, tf.float32)
        # divide_no_nan yields 0 instead of NaN when no position was scored.
        return tf.math.divide_no_nan(correct, total)
class MultiClassPrecision(tf.keras.metrics.Metric):
    """Per-class precision, reported as one entry per class name.

    Predictions are the argmax of ``y_pred`` over its last axis. Positions
    whose label equals ``label_to_ignore`` contribute to neither the true
    positive nor the predicted-positive counts.

    Args:
        label_id_to_name: Dict mapping class ids ``0..N-1`` to class names.
        label_to_ignore: Label value marking positions to skip.
        name: Metric name, forwarded to ``tf.keras.metrics.Metric``.
        dtype: Metric dtype, forwarded to ``tf.keras.metrics.Metric``.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.

    Raises:
        ValueError: If the keys of ``label_id_to_name`` are not exactly the
            integers ``0..N-1`` (an empty mapping is rejected as well).
    """

    def __init__(self, label_id_to_name, label_to_ignore=-100, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        # Keys must be exactly 0..N-1 because they are used as indices into
        # the per-class count vectors.
        expected_ids = list(range(len(label_id_to_name)))
        if not label_id_to_name or sorted(label_id_to_name.keys()) != expected_ids:
            raise ValueError("label_id_to_name should be a dict whose keys are sequential integers from 0!")
        self.num_classes = len(label_id_to_name)
        self.label_id_to_name = label_id_to_name
        self.label_to_ignore = label_to_ignore
        self.per_class_statistics = False
        self.tp = self.add_weight(shape=(self.num_classes,), name="tp",
                                  initializer='zeros', dtype=tf.int64)
        self.tp_plus_fp = self.add_weight(shape=(self.num_classes,), name="tp_plus_fp",
                                          initializer='zeros', dtype=tf.int64)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate per-class TP and TP+FP counts for one batch.

        Args:
            y_true: Integer labels; ``label_to_ignore`` marks skipped
                positions.
            y_pred: Per-class scores with the class axis last
                (presumably ``num_classes`` wide -- confirm against caller).
            sample_weight: Accepted for Keras compatibility; not used.
        """
        # int32 ids keep bincount usable on TF releases that only accept
        # int32 input; class ids comfortably fit.
        class_predictions = tf.reshape(
            tf.cast(tf.math.argmax(y_pred, axis=-1), tf.int32), [-1])
        y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        is_true_positive = tf.cast(tf.equal(class_predictions, y_true), tf.int64)
        # Send predictions at ignored positions to an out-of-range bin;
        # bincount drops values >= maxlength, so they are not counted.
        overflow_bin = tf.constant(self.num_classes, dtype=tf.int32)
        class_predictions = tf.where(
            tf.not_equal(y_true, self.label_to_ignore), class_predictions, overflow_bin)
        # dtype=tf.int64 so the unweighted counts match the int64 variable.
        self.tp_plus_fp.assign_add(tf.math.bincount(
            class_predictions, minlength=self.num_classes,
            maxlength=self.num_classes, dtype=tf.int64))
        self.tp.assign_add(tf.math.bincount(
            class_predictions, minlength=self.num_classes,
            maxlength=self.num_classes, weights=is_true_positive))

    def result(self):
        """Return ``{"<class_name>_precision": scalar}`` for every class.

        A class that was never predicted gets precision 0.
        """
        precisions = tf.math.divide_no_nan(
            tf.cast(self.tp, tf.float32), tf.cast(self.tp_plus_fp, tf.float32))
        return {f"{class_name}_precision": precisions[i] for i, class_name in self.label_id_to_name.items()}
class MultiClassRecall(tf.keras.metrics.Metric):
    """Per-class recall, reported as one entry per class name.

    Predictions are the argmax of ``y_pred`` over its last axis. Positions
    whose label equals ``label_to_ignore`` are excluded from the counts.

    Args:
        label_id_to_name: Dict mapping class ids ``0..N-1`` to class names.
        label_to_ignore: Label value marking positions to skip.
        name: Metric name, forwarded to ``tf.keras.metrics.Metric``.
        dtype: Metric dtype, forwarded to ``tf.keras.metrics.Metric``.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.

    Raises:
        ValueError: If the keys of ``label_id_to_name`` are not exactly the
            integers ``0..N-1`` (an empty mapping is rejected as well).
    """

    def __init__(self, label_id_to_name, label_to_ignore=-100, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        # Keys must be exactly 0..N-1 because they are used as indices into
        # the per-class count vectors.
        expected_ids = list(range(len(label_id_to_name)))
        if not label_id_to_name or sorted(label_id_to_name.keys()) != expected_ids:
            raise ValueError("label_id_to_name should be a dict whose keys are sequential integers from 0!")
        self.num_classes = len(label_id_to_name)
        self.label_id_to_name = label_id_to_name
        self.label_to_ignore = label_to_ignore
        self.tp = self.add_weight(shape=(self.num_classes,), name="tp",
                                  initializer='zeros', dtype=tf.int64)
        self.tp_plus_fn = self.add_weight(shape=(self.num_classes,), name="tp_plus_fn",
                                          initializer='zeros', dtype=tf.int64)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate per-class TP and TP+FN counts for one batch.

        Args:
            y_true: Integer labels; ``label_to_ignore`` marks skipped
                positions.
            y_pred: Per-class scores with the class axis last
                (presumably ``num_classes`` wide -- confirm against caller).
            sample_weight: Accepted for Keras compatibility; not used.
        """
        # int32 ids keep bincount usable on TF releases that only accept
        # int32 input; class ids comfortably fit.
        class_predictions = tf.reshape(
            tf.cast(tf.math.argmax(y_pred, axis=-1), tf.int32), [-1])
        y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        # Mask with a too-high value that will be ignored by bincount.
        overflow_bin = tf.constant(self.num_classes, dtype=tf.int32)
        y_true = tf.where(
            tf.not_equal(y_true, self.label_to_ignore), y_true, overflow_bin)
        is_true_positive = tf.cast(tf.equal(class_predictions, y_true), tf.int64)
        # dtype=tf.int64 so the unweighted counts match the int64 variable.
        self.tp_plus_fn.assign_add(tf.math.bincount(
            y_true, minlength=self.num_classes,
            maxlength=self.num_classes, dtype=tf.int64))
        self.tp.assign_add(tf.math.bincount(
            y_true, minlength=self.num_classes,
            maxlength=self.num_classes, weights=is_true_positive))

    def result(self):
        """Return ``{"<class_name>_recall": scalar}`` for every class.

        A class with no labelled examples gets recall 0.
        """
        recalls = tf.math.divide_no_nan(
            tf.cast(self.tp, tf.float32), tf.cast(self.tp_plus_fn, tf.float32))
        return {f"{class_name}_recall": recalls[i] for i, class_name in self.label_id_to_name.items()}
class MultiClassF1(tf.keras.metrics.Metric):
    """Per-class F1 score, reported as one entry per class name.

    Predictions are the argmax of ``y_pred`` over its last axis. Positions
    whose label equals ``label_to_ignore`` are excluded from all counts.

    Args:
        label_id_to_name: Dict mapping class ids ``0..N-1`` to class names.
        label_to_ignore: Label value marking positions to skip.
        name: Metric name, forwarded to ``tf.keras.metrics.Metric``.
        dtype: Metric dtype, forwarded to ``tf.keras.metrics.Metric``.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.

    Raises:
        ValueError: If the keys of ``label_id_to_name`` are not exactly the
            integers ``0..N-1`` (an empty mapping is rejected as well).
    """

    def __init__(self, label_id_to_name, label_to_ignore=-100, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        # Keys must be exactly 0..N-1 because they are used as indices into
        # the per-class count vectors.
        expected_ids = list(range(len(label_id_to_name)))
        if not label_id_to_name or sorted(label_id_to_name.keys()) != expected_ids:
            raise ValueError("label_id_to_name should be a dict whose keys are sequential integers from 0!")
        self.num_classes = len(label_id_to_name)
        self.label_id_to_name = label_id_to_name
        self.label_to_ignore = label_to_ignore
        self.tp = self.add_weight(shape=(self.num_classes,), name="tp",
                                  initializer='zeros', dtype=tf.int64)
        self.tp_plus_fn = self.add_weight(shape=(self.num_classes,), name="tp_plus_fn",
                                          initializer='zeros', dtype=tf.int64)
        self.tp_plus_fp = self.add_weight(shape=(self.num_classes,), name="tp_plus_fp",
                                          initializer='zeros', dtype=tf.int64)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate per-class TP, TP+FP and TP+FN counts for one batch.

        Args:
            y_true: Integer labels; ``label_to_ignore`` marks skipped
                positions.
            y_pred: Per-class scores with the class axis last
                (presumably ``num_classes`` wide -- confirm against caller).
            sample_weight: Accepted for Keras compatibility; not used.
        """
        # int32 ids keep bincount usable on TF releases that only accept
        # int32 input; class ids comfortably fit.
        class_predictions = tf.reshape(
            tf.cast(tf.math.argmax(y_pred, axis=-1), tf.int32), [-1])
        y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        is_scored = tf.not_equal(y_true, self.label_to_ignore)
        # Both labels and predictions at ignored positions go to an
        # out-of-range bin; bincount drops values >= maxlength.
        overflow_bin = tf.constant(self.num_classes, dtype=tf.int32)
        y_true = tf.where(is_scored, y_true, overflow_bin)
        class_predictions = tf.where(is_scored, class_predictions, overflow_bin)
        # Ignored positions carry the overflow id on both sides, which is
        # harmless: their bin is discarded together with its weight.
        is_true_positive = tf.cast(tf.equal(class_predictions, y_true), tf.int64)
        # dtype=tf.int64 so the unweighted counts match the int64 variables.
        self.tp_plus_fp.assign_add(tf.math.bincount(
            class_predictions, minlength=self.num_classes,
            maxlength=self.num_classes, dtype=tf.int64))
        self.tp.assign_add(tf.math.bincount(
            class_predictions, minlength=self.num_classes,
            maxlength=self.num_classes, weights=is_true_positive))
        self.tp_plus_fn.assign_add(tf.math.bincount(
            y_true, minlength=self.num_classes,
            maxlength=self.num_classes, dtype=tf.int64))

    def result(self):
        """Return ``{"<class_name>_f1": scalar}`` for every class.

        F1 is ``2PR / (P + R)``; it is 0 when both precision and recall
        are 0.
        """
        true_positives = tf.cast(self.tp, tf.float32)
        precisions = tf.math.divide_no_nan(
            true_positives, tf.cast(self.tp_plus_fp, tf.float32))
        recalls = tf.math.divide_no_nan(
            true_positives, tf.cast(self.tp_plus_fn, tf.float32))
        # Guard only the 0/0 case; clamping the denominator to >= 1 would
        # understate F1 whenever precision + recall is below 1.
        f1_scores = tf.math.divide_no_nan(2 * precisions * recalls, precisions + recalls)
        return {f"{class_name}_f1": f1_scores[i] for i, class_name in self.label_id_to_name.items()}
class Precision(tf.keras.metrics.Metric):
    """Binary precision for the positive class (class id 1).

    Predictions are the argmax of ``y_pred`` over its last axis.

    Args:
        name: Metric name, forwarded to ``tf.keras.metrics.Metric``.
        dtype: Metric dtype, forwarded to ``tf.keras.metrics.Metric``.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        self.tp = self.add_weight(name="tp", initializer='zeros', dtype=tf.int64)
        self.tp_plus_fp = self.add_weight(name="tp_plus_fp", initializer='zeros', dtype=tf.int64)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate TP and TP+FP counts for one batch.

        Args:
            y_true: Integer labels, 1 marking the positive class.
            y_pred: Per-class scores with the class axis last.
            sample_weight: Accepted for Keras compatibility; not used.
        """
        class_predictions = tf.math.argmax(y_pred, axis=-1)
        y_true = tf.cast(y_true, class_predictions.dtype)
        # Build each mask separately: in `a == 1 & b == 1` the `&` binds
        # tighter than `==`, which is not the intended conjunction.
        predicted_positive = tf.equal(class_predictions, 1)
        actual_positive = tf.equal(y_true, 1)
        self.tp_plus_fp.assign_add(tf.math.count_nonzero(predicted_positive))
        self.tp.assign_add(
            tf.math.count_nonzero(tf.logical_and(predicted_positive, actual_positive)))

    def result(self):
        """Return ``{"Precision": scalar}`` (0 when nothing was predicted positive)."""
        precision = tf.math.divide_no_nan(
            tf.cast(self.tp, tf.float32), tf.cast(self.tp_plus_fp, tf.float32))
        return {"Precision": precision}
class BinaryRecall(tf.keras.metrics.Metric):
    """Binary recall for the positive class (class id 1).

    Predictions are the argmax of ``y_pred`` over its last axis.

    Args:
        name: Metric name, forwarded to ``tf.keras.metrics.Metric``.
        dtype: Metric dtype, forwarded to ``tf.keras.metrics.Metric``.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        self.tp = self.add_weight(name="tp", initializer='zeros', dtype=tf.int64)
        self.tp_plus_fn = self.add_weight(name="tp_plus_fn", initializer='zeros', dtype=tf.int64)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate TP and TP+FN counts for one batch.

        Args:
            y_true: Integer labels, 1 marking the positive class.
            y_pred: Per-class scores with the class axis last.
            sample_weight: Accepted for Keras compatibility; not used.
        """
        class_predictions = tf.math.argmax(y_pred, axis=-1)
        y_true = tf.cast(y_true, class_predictions.dtype)
        # Build each mask separately: in `a == 1 & b == 1` the `&` binds
        # tighter than `==`, which is not the intended conjunction.
        predicted_positive = tf.equal(class_predictions, 1)
        # Only label 1 counts as positive, so other non-zero label values
        # do not inflate the denominator.
        actual_positive = tf.equal(y_true, 1)
        self.tp_plus_fn.assign_add(tf.math.count_nonzero(actual_positive))
        self.tp.assign_add(
            tf.math.count_nonzero(tf.logical_and(predicted_positive, actual_positive)))

    def result(self):
        """Return ``{"Recall": scalar}`` (0 when no positive label was seen)."""
        recall = tf.math.divide_no_nan(
            tf.cast(self.tp, tf.float32), tf.cast(self.tp_plus_fn, tf.float32))
        return {"Recall": recall}
class BinaryF1(tf.keras.metrics.Metric):
    """Binary F1 score for the positive class (class id 1).

    Predictions are the argmax of ``y_pred`` over its last axis.

    Args:
        name: Metric name, forwarded to ``tf.keras.metrics.Metric``.
        dtype: Metric dtype, forwarded to ``tf.keras.metrics.Metric``.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        self.tp = self.add_weight(name="tp", initializer='zeros', dtype=tf.int64)
        self.tp_plus_fp = self.add_weight(name="tp_plus_fp", initializer='zeros', dtype=tf.int64)
        self.tp_plus_fn = self.add_weight(name="tp_plus_fn", initializer='zeros', dtype=tf.int64)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate TP, TP+FP and TP+FN counts for one batch.

        Args:
            y_true: Integer labels, 1 marking the positive class.
            y_pred: Per-class scores with the class axis last.
            sample_weight: Accepted for Keras compatibility; not used.
        """
        class_predictions = tf.math.argmax(y_pred, axis=-1)
        y_true = tf.cast(y_true, class_predictions.dtype)
        # Build each mask separately: in `a == 1 & b == 1` the `&` binds
        # tighter than `==`, which is not the intended conjunction.
        predicted_positive = tf.equal(class_predictions, 1)
        actual_positive = tf.equal(y_true, 1)
        self.tp_plus_fp.assign_add(tf.math.count_nonzero(predicted_positive))
        self.tp_plus_fn.assign_add(tf.math.count_nonzero(actual_positive))
        self.tp.assign_add(
            tf.math.count_nonzero(tf.logical_and(predicted_positive, actual_positive)))

    def result(self):
        """Return ``{"F1": scalar}``; 0 when precision and recall are both 0."""
        true_positives = tf.cast(self.tp, tf.float32)
        precision = tf.math.divide_no_nan(
            true_positives, tf.cast(self.tp_plus_fp, tf.float32))
        recall = tf.math.divide_no_nan(
            true_positives, tf.cast(self.tp_plus_fn, tf.float32))
        # divide_no_nan avoids NaN from 0/0 when there are no true positives.
        return {"F1": tf.math.divide_no_nan(2 * precision * recall, precision + recall)}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment