Skip to content

Instantly share code, notes, and snippets.

@innat
Last active April 3, 2023 07:50
Show Gist options
  • Save innat/20b25854e29acfcd9dc07eb1c5da2c49 to your computer and use it in GitHub Desktop.
Save innat/20b25854e29acfcd9dc07eb1c5da2c49 to your computer and use it in GitHub Desktop.
Weighted BinaryCrossEntropy Loss in Keras
import tensorflow as tf
def weighted_binary_loss(weight, from_logits=True, reduction="mean"):
    """Build a positively-weighted binary cross-entropy loss function.

    The returned callable wraps ``tf.nn.weighted_cross_entropy_with_logits``
    and passes ``weight`` as its ``pos_weight`` argument.

    Args:
        weight: Value forwarded as ``pos_weight``.
        from_logits: Plain Python flag. When truthy, ``predictions`` are used
            as logits unchanged. Otherwise they are treated as sigmoid
            probabilities and converted back to logits.
        reduction: ``"mean"`` or ``"sum"`` (case-insensitive).
            ``"mean"`` averages the element-wise loss over every element.
            ``"sum"`` sums the element-wise loss and divides by the size of
            the last axis of ``labels`` (i.e. mean over the last axis, summed
            over the remaining axes).

    Returns:
        A function ``weighted_loss(labels, predictions)`` returning a scalar
        tensor.

    Raises:
        ValueError: If ``reduction`` is neither ``"mean"`` nor ``"sum"``.
            Raised when the loss is built, not when it is first called.
    """
    # Validate once, up front, so a typo fails before any training starts.
    mode = reduction.lower()
    if mode not in ("mean", "sum"):
        raise ValueError(
            "Reduction type should be `mean` or `sum`, "
            f"but received {reduction!r}"
        )

    def weighted_loss(labels, predictions):
        predictions = tf.convert_to_tensor(predictions)
        # Labels must share the logits' dtype; follow whatever dtype the
        # predictions carry instead of forcing float32.
        labels = tf.cast(labels, predictions.dtype)

        if from_logits:
            logits = predictions
        else:
            # Keep probabilities strictly inside (0, 1): a saturated sigmoid
            # output of exactly 0 or 1 would otherwise map to an infinite
            # logit and produce NaN in the loss.
            eps = tf.keras.backend.epsilon()
            probs = tf.clip_by_value(predictions, eps, 1.0 - eps)
            # Inverse sigmoid: log(p / (1 - p)).
            logits = tf.math.log(probs / (1.0 - probs))

        loss = tf.nn.weighted_cross_entropy_with_logits(
            labels, logits, pos_weight=weight
        )

        if mode == "mean":
            return tf.reduce_mean(loss)

        # "sum": normalise by the length of the last axis only.
        last_axis_size = tf.cast(tf.shape(labels)[-1], dtype=loss.dtype)
        return tf.reduce_sum(loss) / last_axis_size

    return weighted_loss
# ------------- Test 1: raw logits, mean reduction
# samples
y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[-18.6, 0.51, 0.21], [2.94, -12.8, 10.3]]
# Official Keras loss, used as the reference value.
bce = tf.keras.losses.BinaryCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
)
bce(y_true, y_pred).numpy()
weighted_binary_loss(
    weight=1, from_logits=True, reduction="mean"
)(y_true, y_pred)
# Recorded outputs (official, then custom):
# 0.7109192
# 0.7109192

# ------------- Test 2: sigmoid probabilities, sum reduction
# samples
y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[-18.6, 0.51, 0.21], [2.94, -12.8, 10.3]]
y_pred = tf.nn.sigmoid(y_pred)
# Official Keras loss, used as the reference value.
bce = tf.keras.losses.BinaryCrossentropy(
    from_logits=False, reduction=tf.keras.losses.Reduction.SUM
)
bce(y_true, y_pred).numpy()
# y_pred holds probabilities here, so from_logits must be False to match
# the official call above.
weighted_binary_loss(
    weight=1, from_logits=False, reduction="sum"
)(y_true, y_pred)
# Recorded outputs (official, then custom):
# 1.4218384
# 1.4218384

# ------------------- Test 3: same probabilities as Test 2, weight 5
weighted_binary_loss(
    weight=5, from_logits=False, reduction="sum"
)(y_true, y_pred)
# Recorded output:
# <tf.Tensor: shape=(), dtype=float32, numpy=2.0489676>
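# What is `weight`?
# `weight` is forwarded as `pos_weight` to
# `tf.nn.weighted_cross_entropy_with_logits`, where it scales the loss term
# of the positive (label == 1) entries.
# A value `weight > 1` decreases the false negative count, hence increasing
# the recall. Conversely, setting `weight < 1` decreases the false positive
# count and increases the precision.
# More details: see the documentation of
# `tf.nn.weighted_cross_entropy_with_logits`.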
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment