Last active
June 11, 2024 08:33
-
-
Save jonnyli1125/5384bb9a41caaac983f1cd737359c6c2 to your computer and use it in GitHub Desktop.
SparseCategoricalCrossentropy with class weights for Keras/TensorFlow 2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Model.fit does not support class_weight for models with multiple outputs,
so this custom loss subclass, which applies the class weights inside the
loss itself, may be useful instead.
Relevant issues:
https://github.com/keras-team/keras/issues/11735
https://github.com/tensorflow/tensorflow/issues/40457
https://github.com/tensorflow/tensorflow/issues/41448
"""
import tensorflow as tf | |
from tensorflow import keras | |
class WeightedSCCE(keras.losses.Loss):
    """Sparse categorical cross-entropy with per-class weights.

    The element-wise (unreduced) cross-entropy computed by
    ``keras.losses.SparseCategoricalCrossentropy`` is multiplied by the weight
    of each element's true class. The result is returned unreduced
    (``Reduction.NONE``), one value per label position.

    Args:
        class_weight: Per-class weights. Accepts either a sequence indexed by
            class id (``[w0, w1, ...]``) or a dict ``{class_id: weight}`` in
            the format ``Model.fit`` accepts; classes missing from a dict get
            weight 1.0. ``None``, an empty container, or all-ones weights
            disable weighting entirely.
        from_logits: Whether ``y_pred`` holds logits rather than
            probabilities; forwarded to ``SparseCategoricalCrossentropy``.
        name: Name of the loss (also used for the inner SCCE loss).
    """

    def __init__(self, class_weight, from_logits=False, name='weighted_scce'):
        # Initialise the Keras base class so `name`, `reduction` and
        # serialization behave like a built-in loss.
        super().__init__(reduction=keras.losses.Reduction.NONE, name=name)
        self.from_logits = from_logits
        # Plain-Python copy of the weights, kept for `get_config`.
        self._class_weight_list = self._to_weight_list(class_weight)
        if self._class_weight_list is None:
            self.class_weight = None
        else:
            self.class_weight = tf.convert_to_tensor(
                self._class_weight_list, dtype=tf.float32)
        self.unreduced_scce = keras.losses.SparseCategoricalCrossentropy(
            from_logits=from_logits, name=name, reduction=self.reduction)

    @staticmethod
    def _to_weight_list(class_weight):
        """Normalise `class_weight` to a list of floats, or None if a no-op."""
        if class_weight is None:
            return None
        if isinstance(class_weight, dict):
            if not class_weight:
                return None
            # Keras-style {class_id: weight}: densify, filling gaps with 1.0,
            # so the list can be indexed directly by integer label. Iterating
            # the dict itself would yield keys, not weights.
            num_classes = int(max(class_weight)) + 1
            class_weight = [class_weight.get(i, 1.) for i in range(num_classes)]
        weights = [float(w) for w in class_weight]
        if all(w == 1. for w in weights):
            return None
        return weights

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Return the class-weighted, unreduced sparse categorical loss.

        Args:
            y_true: Integer class ids (integer or float dtype). Either shaped
                like ``y_pred`` without its last axis, or with an extra
                trailing axis of size 1.
            y_pred: Class probabilities or logits; the last axis is classes.
            sample_weight: Optional weights forwarded to the inner SCCE loss.

        Returns:
            Per-element loss values, each scaled by its true class's weight.
        """
        loss = self.unreduced_scce(y_true, y_pred, sample_weight)
        if self.class_weight is None:
            return loss
        # tf.gather requires integer indices; labels often arrive as floats.
        labels = tf.cast(y_true, tf.int32)
        # SCCE squeezes a trailing size-1 label axis, e.g. (batch, 1). Reshape
        # labels to y_pred's leading dims so the mask lines up with `loss`
        # instead of broadcasting (batch,) * (batch, 1) -> (batch, batch).
        labels = tf.reshape(labels, tf.shape(y_pred)[:-1])
        weight_mask = tf.gather(self.class_weight, labels)
        # Match the loss dtype (e.g. float16 under mixed precision).
        return loss * tf.cast(weight_mask, loss.dtype)

    def get_config(self):
        """Return the constructor arguments needed to re-create this loss."""
        return {
            'class_weight': self._class_weight_list,
            'from_logits': self.from_logits,
            'name': self.name,
        }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment