Weighted BinaryCrossEntropy Loss in Keras
import tensorflow as tf


def weighted_binary_loss(weight, from_logits=True, reduction="mean"):
    def inverse_sigmoid(sigmoidal):
        # Logit function: maps sigmoid probabilities back to raw logits.
        return -tf.math.log(1. / sigmoidal - 1.)

    def weighted_loss(labels, predictions):
        predictions = tf.convert_to_tensor(predictions)
        labels = tf.cast(labels, predictions.dtype)
        # Size of the last axis; used by the "sum" reduction so the result
        # matches Keras' Reduction.SUM (a sum of per-sample means).
        num_samples = tf.cast(tf.shape(labels)[-1], dtype=labels.dtype)
        # If predictions are probabilities, convert them back to logits first.
        logits = tf.cond(
            tf.cast(from_logits, dtype=tf.bool),
            lambda: predictions,
            lambda: inverse_sigmoid(sigmoidal=predictions),
        )
        loss = tf.nn.weighted_cross_entropy_with_logits(
            tf.cast(labels, dtype=tf.float32), logits, pos_weight=weight
        )
        if reduction.lower() == "mean":
            return tf.reduce_mean(loss)
        elif reduction.lower() == "sum":
            return tf.reduce_sum(loss) / num_samples
        else:
            raise ValueError(
                "Reduction type should be `mean` or `sum`, "
                f"but received {reduction}"
            )

    return weighted_loss
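# ------------- Usage sketch (illustrative, not from the original gist)
# A minimal sketch of plugging the closure above into `model.compile`,
# assuming a toy model whose final Dense layer emits raw logits. The
# architecture and random data below are assumptions made purely for
# illustration.
import numpy as np

toy_model = tf.keras.Sequential([
    tf.keras.layers.Dense(3)  # raw logits, hence from_logits=True below
])
toy_model.compile(
    optimizer="adam",
    loss=weighted_binary_loss(weight=5, from_logits=True, reduction="mean"),
)
x = np.random.rand(8, 4).astype("float32")
y = np.random.randint(0, 2, size=(8, 3)).astype("float32")
toy_model.fit(x, y, epochs=1, verbose=0)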
# ------------- Test 1
# samples
y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[-18.6, 0.51, 0.21], [2.94, -12.8, 10.3]]

# Official
bce = tf.keras.losses.BinaryCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
)
bce(y_true, y_pred).numpy()
weighted_binary_loss(weight=1, from_logits=True, reduction="mean")(y_true, y_pred)
# 0.7109192
# 0.7109192
# ------------- Test 2
# samples
y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[-18.6, 0.51, 0.21], [2.94, -12.8, 10.3]]
y_pred = tf.nn.sigmoid(y_pred)  # predictions are now probabilities

# Official
bce = tf.keras.losses.BinaryCrossentropy(
    from_logits=False, reduction=tf.keras.losses.Reduction.SUM
)
bce(y_true, y_pred).numpy()
weighted_binary_loss(weight=1, from_logits=False, reduction="sum")(y_true, y_pred)
# 1.4218384
# 1.4218384
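# ------------- Round-trip sketch (illustrative, not from the original gist)
# A hedged check that `inverse_sigmoid` recovers the original logits, which
# is why `from_logits=False` works on the sigmoided predictions above.
logits_in = tf.constant([-18.6, 0.51, 0.21])
probs = tf.nn.sigmoid(logits_in)
recovered = -tf.math.log(1. / probs - 1.)  # same formula as inverse_sigmoid
# `recovered` is approximately `logits_in` up to floating-point error;
# extreme logits such as -18.6 saturate the sigmoid and lose some precision.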
# ------------------- Test 3
# Add weight 5
weighted_binary_loss(weight=5, from_logits=False, reduction="sum")(y_true, y_pred)
# <tf.Tensor: shape=(), dtype=float32, numpy=2.0489676>
# What is weight?
# A value `weight > 1` decreases the false-negative count and hence
# increases the recall. Conversely, a value `weight < 1` decreases the
# false-positive count and increases the precision.
# More details: tf.nn.weighted_cross_entropy_with_logits
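# ------------- Formula sketch (illustrative, not from the original gist)
# A hedged, manual check of how `pos_weight` enters the per-element loss,
# following the tf.nn.weighted_cross_entropy_with_logits documentation:
#   loss = -(pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))
# The variable names below are illustrative assumptions.
x_logit = tf.constant([2.0])  # a single logit
y_label = tf.constant([1.0])  # a positive label
pos_weight = 5.0
p = tf.nn.sigmoid(x_logit)
manual = -(pos_weight * y_label * tf.math.log(p)
           + (1. - y_label) * tf.math.log(1. - p))
builtin = tf.nn.weighted_cross_entropy_with_logits(
    y_label, x_logit, pos_weight=pos_weight
)
# `manual` and `builtin` agree up to floating-point error.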