Created: August 5, 2022 13:11
-
-
Save Rocketknight1/242e891c87c736fc2e7cc886332a8d82 to your computer and use it in GitHub Desktop.
Keras metrics for use with Hugging Face transformers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
class MaskedAccuracy(tf.keras.metrics.Metric):
    """Accuracy that skips positions whose label equals ``label_to_ignore``.

    Intended for Hugging Face transformers outputs, where padding / non-target
    tokens are labelled -100. With ``clm=True`` predictions and labels are
    shifted one position relative to each other, as required for causal LM
    evaluation (the logits at position t predict the token at t+1).

    Args:
        clm: If True, apply the causal-LM shift before comparing.
        label_to_ignore: Label value excluded from both numerator and
            denominator (default -100, the Hugging Face masking convention).
    """

    def __init__(self, name=None, dtype=None, clm=False, label_to_ignore=-100, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        self.label_to_ignore = label_to_ignore
        # Running counters so accuracy accumulates correctly across batches.
        self.correct_predictions = self.add_weight(
            name='correct_predictions', initializer='zeros', dtype=tf.int64)
        self.all_predictions = self.add_weight(
            name='all_predictions', initializer='zeros', dtype=tf.int64)
        self.clm = clm

    def update_state(self, y_true, y_pred, sample_weight=None):
        class_predictions = tf.math.argmax(y_pred, axis=-1)
        # argmax yields int64; labels often arrive as int32, and TF refuses to
        # compare tensors of different integer dtypes — cast to match.
        y_true = tf.cast(y_true, tf.int64)
        if self.clm:
            # Causal LM: drop the last prediction and the first label so that
            # prediction[t] lines up with label[t+1].
            class_predictions = class_predictions[:, :-1]
            y_true = y_true[:, 1:]
        valid = y_true != self.label_to_ignore
        self.correct_predictions.assign_add(
            tf.math.count_nonzero((class_predictions == y_true) & valid))
        self.all_predictions.assign_add(tf.math.count_nonzero(valid))

    def result(self):
        # divide_no_nan returns 0 (instead of NaN) before any update is seen.
        return tf.math.divide_no_nan(
            tf.cast(self.correct_predictions, tf.float32),
            tf.cast(self.all_predictions, tf.float32))
class MultiClassPrecision(tf.keras.metrics.Metric):
    """Per-class precision for multi-class classification.

    Args:
        label_id_to_name: Dict mapping class ids 0..n-1 to class names;
            ``result()`` returns ``{f"{name}_precision": value}`` per class.
        label_to_ignore: Label value whose positions are excluded from the
            statistics (default -100, the Hugging Face masking convention).
    """

    def __init__(self, label_id_to_name, label_to_ignore=-100, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        # Keys must be exactly 0..n-1. (The original compared against
        # range(max(keys)), which is always one element short and therefore
        # rejected every valid mapping.)
        if sorted(label_id_to_name.keys()) != list(range(len(label_id_to_name))):
            raise ValueError("label_id_to_name should be a dict whose keys are sequential integers from 0!")
        self.num_classes = len(label_id_to_name)
        self.label_id_to_name = label_id_to_name
        self.label_to_ignore = label_to_ignore
        self.per_class_statistics = False  # NOTE(review): unused leftover; kept for compatibility
        self.tp = self.add_weight(shape=(self.num_classes,), name="tp",
                                  initializer='zeros', dtype=tf.int64)
        self.tp_plus_fp = self.add_weight(shape=(self.num_classes,), name="tp_plus_fp",
                                          initializer='zeros', dtype=tf.int64)

    def update_state(self, y_true, y_pred):
        class_predictions = tf.math.argmax(y_pred, axis=-1)
        # Cast so the equality comparison below works for int32 labels.
        y_true = tf.cast(y_true, tf.int64)
        # Remap predictions at ignored positions to num_classes, which lies
        # outside [0, maxlength) and is therefore dropped by bincount; the
        # original counted those positions into tp_plus_fp.
        class_predictions = tf.where(
            y_true != self.label_to_ignore, class_predictions, self.num_classes)
        true_positive_weights_mask = tf.cast(class_predictions == y_true, dtype=tf.int64)
        # dtype=tf.int64 is required: bincount defaults to int32, which cannot
        # be assign_add-ed into these int64 variables.
        self.tp_plus_fp.assign_add(tf.math.bincount(
            class_predictions, minlength=self.num_classes, maxlength=self.num_classes,
            dtype=tf.int64))
        self.tp.assign_add(tf.math.bincount(
            class_predictions, minlength=self.num_classes, maxlength=self.num_classes,
            weights=true_positive_weights_mask))

    def result(self):
        # Classes never predicted get precision 0 rather than NaN.
        precisions = tf.math.divide_no_nan(
            tf.cast(self.tp, tf.float32), tf.cast(self.tp_plus_fp, tf.float32))
        return {f"{class_name}_precision": precisions[i] for i, class_name in self.label_id_to_name.items()}
class MultiClassRecall(tf.keras.metrics.Metric):
    """Per-class recall for multi-class classification.

    Args:
        label_id_to_name: Dict mapping class ids 0..n-1 to class names;
            ``result()`` returns ``{f"{name}_recall": value}`` per class.
        label_to_ignore: Label value whose positions are excluded from the
            statistics (default -100, the Hugging Face masking convention).
    """

    def __init__(self, label_id_to_name, label_to_ignore=-100, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        # Keys must be exactly 0..n-1. (The original compared against
        # range(max(keys)), which is always one element short and therefore
        # rejected every valid mapping.)
        if sorted(label_id_to_name.keys()) != list(range(len(label_id_to_name))):
            raise ValueError("label_id_to_name should be a dict whose keys are sequential integers from 0!")
        self.num_classes = len(label_id_to_name)
        self.label_id_to_name = label_id_to_name
        self.label_to_ignore = label_to_ignore
        self.tp = self.add_weight(shape=(self.num_classes,), name="tp",
                                  initializer='zeros', dtype=tf.int64)
        self.tp_plus_fn = self.add_weight(shape=(self.num_classes,), name="tp_plus_fn",
                                          initializer='zeros', dtype=tf.int64)

    def update_state(self, y_true, y_pred):
        # Cast so comparisons against int64 argmax output are legal for
        # int32 labels.
        y_true = tf.cast(y_true, tf.int64)
        # Remap ignored labels to num_classes: that value lies outside
        # [0, maxlength) and is silently dropped by bincount.
        y_true = tf.where(y_true != self.label_to_ignore, y_true, self.num_classes)
        true_positive_weights_mask = tf.cast(
            tf.math.argmax(y_pred, axis=-1) == y_true, dtype=tf.int64)
        # dtype=tf.int64 is required: bincount defaults to int32, which cannot
        # be assign_add-ed into these int64 variables.
        self.tp_plus_fn.assign_add(tf.math.bincount(
            y_true, minlength=self.num_classes, maxlength=self.num_classes,
            dtype=tf.int64))
        self.tp.assign_add(tf.math.bincount(
            y_true, minlength=self.num_classes, maxlength=self.num_classes,
            weights=true_positive_weights_mask))

    def result(self):
        # Classes absent from the labels get recall 0 rather than NaN.
        recalls = tf.math.divide_no_nan(
            tf.cast(self.tp, tf.float32), tf.cast(self.tp_plus_fn, tf.float32))
        return {f"{class_name}_recall": recalls[i] for i, class_name in self.label_id_to_name.items()}
class MultiClassF1(tf.keras.metrics.Metric):
    """Per-class F1 score for multi-class classification.

    Args:
        label_id_to_name: Dict mapping class ids 0..n-1 to class names;
            ``result()`` returns ``{f"{name}_f1": value}`` per class.
        label_to_ignore: Label value whose positions are excluded from the
            statistics (default -100, the Hugging Face masking convention).
    """

    def __init__(self, label_id_to_name, label_to_ignore=-100, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        # Keys must be exactly 0..n-1. (The original compared against
        # range(max(keys)), which is always one element short and therefore
        # rejected every valid mapping.)
        if sorted(label_id_to_name.keys()) != list(range(len(label_id_to_name))):
            raise ValueError("label_id_to_name should be a dict whose keys are sequential integers from 0!")
        self.num_classes = len(label_id_to_name)
        self.label_id_to_name = label_id_to_name
        self.label_to_ignore = label_to_ignore
        self.tp = self.add_weight(shape=(self.num_classes,), name="tp",
                                  initializer='zeros', dtype=tf.int64)
        self.tp_plus_fn = self.add_weight(shape=(self.num_classes,), name="tp_plus_fn",
                                          initializer='zeros', dtype=tf.int64)
        self.tp_plus_fp = self.add_weight(shape=(self.num_classes,), name="tp_plus_fp",
                                          initializer='zeros', dtype=tf.int64)

    def update_state(self, y_true, y_pred):
        class_predictions = tf.math.argmax(y_pred, axis=-1)
        # Cast so the equality comparison below works for int32 labels.
        y_true = tf.cast(y_true, tf.int64)
        valid = y_true != self.label_to_ignore
        # Remap ignored positions to num_classes in BOTH tensors: that value
        # lies outside [0, maxlength) and is dropped by bincount. (The
        # original only masked y_true, so ignored positions still inflated
        # tp_plus_fp.)
        y_true = tf.where(valid, y_true, self.num_classes)
        class_predictions = tf.where(valid, class_predictions, self.num_classes)
        true_positive_weights_mask = tf.cast(class_predictions == y_true, dtype=tf.int64)
        # dtype=tf.int64 is required: bincount defaults to int32, which cannot
        # be assign_add-ed into these int64 variables.
        self.tp_plus_fp.assign_add(tf.math.bincount(
            class_predictions, minlength=self.num_classes, maxlength=self.num_classes,
            dtype=tf.int64))
        self.tp.assign_add(tf.math.bincount(
            class_predictions, minlength=self.num_classes, maxlength=self.num_classes,
            weights=true_positive_weights_mask))
        self.tp_plus_fn.assign_add(tf.math.bincount(
            y_true, minlength=self.num_classes, maxlength=self.num_classes,
            dtype=tf.int64))

    def result(self):
        precisions = tf.math.divide_no_nan(
            tf.cast(self.tp, tf.float32), tf.cast(self.tp_plus_fp, tf.float32))
        recalls = tf.math.divide_no_nan(
            tf.cast(self.tp, tf.float32), tf.cast(self.tp_plus_fn, tf.float32))
        # divide_no_nan, NOT clip: clipping (precision + recall) up to 1
        # silently distorted every F1 where the true denominator was in (0, 1).
        f1_scores = tf.math.divide_no_nan(2 * precisions * recalls, precisions + recalls)
        return {f"{class_name}_f1": f1_scores[i] for i, class_name in self.label_id_to_name.items()}
class Precision(tf.keras.metrics.Metric):
    """Binary precision over argmax predictions.

    Assumes y_pred is a per-class score tensor whose argmax over the last
    axis gives the predicted class, and y_true holds 0/1 labels.
    ``result()`` returns ``{"Precision": value}``.
    """

    def __init__(self, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        self.tp = self.add_weight(name="tp", initializer='zeros', dtype=tf.int64)
        self.tp_plus_fp = self.add_weight(name="tp_plus_fp", initializer='zeros', dtype=tf.int64)

    def update_state(self, y_true, y_pred):
        class_predictions = tf.math.argmax(y_pred, axis=-1)
        # Predicted positives: every nonzero (i.e. class-1) prediction.
        self.tp_plus_fp.assign_add(tf.math.count_nonzero(class_predictions))
        # Parentheses are required: `a == 1 & b == 1` parses as
        # `a == (1 & b) == 1` because `&` binds tighter than `==`.
        self.tp.assign_add(
            tf.math.count_nonzero((class_predictions == 1) & (y_true == 1)))

    def result(self):
        # 0 rather than NaN when nothing was ever predicted positive.
        precision = tf.math.divide_no_nan(
            tf.cast(self.tp, tf.float32), tf.cast(self.tp_plus_fp, tf.float32))
        return {"Precision": precision}
class BinaryRecall(tf.keras.metrics.Metric):
    """Binary recall over argmax predictions.

    Assumes y_pred is a per-class score tensor whose argmax over the last
    axis gives the predicted class, and y_true holds 0/1 labels.
    ``result()`` returns ``{"Recall": value}``.
    """

    def __init__(self, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        self.tp = self.add_weight(name="tp", initializer='zeros', dtype=tf.int64)
        self.tp_plus_fn = self.add_weight(name="tp_plus_fn", initializer='zeros', dtype=tf.int64)

    def update_state(self, y_true, y_pred):
        class_predictions = tf.math.argmax(y_pred, axis=-1)
        # Actual positives: every nonzero label.
        self.tp_plus_fn.assign_add(tf.math.count_nonzero(y_true))
        # Parentheses are required: `a == 1 & b == 1` parses as
        # `a == (1 & b) == 1` because `&` binds tighter than `==`.
        self.tp.assign_add(
            tf.math.count_nonzero((class_predictions == 1) & (y_true == 1)))

    def result(self):
        # 0 rather than NaN when no positive labels were ever seen.
        recall = tf.math.divide_no_nan(
            tf.cast(self.tp, tf.float32), tf.cast(self.tp_plus_fn, tf.float32))
        return {"Recall": recall}
class BinaryF1(tf.keras.metrics.Metric):
    """Binary F1 score over argmax predictions.

    Assumes y_pred is a per-class score tensor whose argmax over the last
    axis gives the predicted class, and y_true holds 0/1 labels.
    ``result()`` returns ``{"F1": value}``.
    """

    def __init__(self, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        self.tp = self.add_weight(name="tp", initializer='zeros', dtype=tf.int64)
        self.tp_plus_fp = self.add_weight(name="tp_plus_fp", initializer='zeros', dtype=tf.int64)
        self.tp_plus_fn = self.add_weight(name="tp_plus_fn", initializer='zeros', dtype=tf.int64)

    def update_state(self, y_true, y_pred):
        class_predictions = tf.math.argmax(y_pred, axis=-1)
        self.tp_plus_fp.assign_add(tf.math.count_nonzero(class_predictions))
        self.tp_plus_fn.assign_add(tf.math.count_nonzero(y_true))
        # Parentheses are required: `a == 1 & b == 1` parses as
        # `a == (1 & b) == 1` because `&` binds tighter than `==`.
        self.tp.assign_add(
            tf.math.count_nonzero((class_predictions == 1) & (y_true == 1)))

    def result(self):
        precision = tf.math.divide_no_nan(
            tf.cast(self.tp, tf.float32), tf.cast(self.tp_plus_fp, tf.float32))
        recall = tf.math.divide_no_nan(
            tf.cast(self.tp, tf.float32), tf.cast(self.tp_plus_fn, tf.float32))
        # Guard the F1 denominator too; the original returned NaN whenever
        # precision + recall was 0.
        return {"F1": tf.math.divide_no_nan(2 * precision * recall, precision + recall)}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment