Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save VolkerH/e9d6183e73c71e5af72f4efc76317c38 to your computer and use it in GitHub Desktop.
Save VolkerH/e9d6183e73c71e5af72f4efc76317c38 to your computer and use it in GitHub Desktop.
Custom loss function for weighted binary cross-entropy in Keras with TensorFlow
# Adapted from the binary cross-entropy code in https://github.com/keras-team/keras/blob/master/keras/backend/tensorflow_backend.py,
# replacing tf.nn.sigmoid_cross_entropy_with_logits with tf.nn.weighted_cross_entropy_with_logits so the positive class can be weighted via pos_weight.
import tensorflow as tf
from keras import backend as K
""" Weighted binary crossentropy between an output tensor and a target tensor.
# Arguments
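pos_weight: Coefficient applied to the loss term of positive examples; values > 1 penalize false negatives more heavily (favoring recall).
# Returns
A loss function `fn(y_true, y_pred)` to pass as the `loss` argument of `model.compile()`.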
"""
def weighted_binary_crossentropy(pos_weight=1, from_logits=False):
    """Build a weighted binary cross-entropy loss for ``model.compile()``.

    The returned loss applies ``tf.nn.weighted_cross_entropy_with_logits``
    element-wise and averages it over the last axis, mirroring Keras'
    ``binary_crossentropy`` but with the positive-class term scaled by
    ``pos_weight``.

    # Arguments
        pos_weight: Coefficient applied to the positive-example term of the
            loss. Values > 1 penalize false negatives more (favoring recall),
            values < 1 penalize false positives more; ``1`` reproduces the
            ordinary binary cross-entropy.
        from_logits: Whether ``y_pred`` given to the loss holds raw logits.
            Defaults to ``False``: ``y_pred`` is treated as probabilities
            (e.g. a sigmoid output) and converted back to logits first.

    # Returns
        A loss function ``loss(y_true, y_pred)`` returning the loss averaged
        over the last axis of the inputs.
    """

    def _to_tensor(x, dtype):
        """Convert ``x`` (numpy array, list, scalar or tensor) to a tensor of ``dtype``."""
        return tf.convert_to_tensor(x, dtype=dtype)

    def _calculate_weighted_binary_crossentropy(target, output, from_logits=from_logits):
        """Element-wise weighted binary cross-entropy between ``target`` and ``output``.

        # Arguments
            target: A tensor with the same shape as ``output``.
            output: A tensor of probabilities, or of logits if ``from_logits``.
            from_logits: Whether ``output`` is a logits tensor; defaults to
                the value given to the enclosing factory.

        # Returns
            A tensor with the same shape as ``output``.
        """
        if not from_logits:
            # The TF op expects logits while Keras models usually emit
            # probabilities: clip away exact 0/1 (whose logits are -inf/+inf),
            # then invert the sigmoid, logit(p) = log(p / (1 - p)).
            _epsilon = _to_tensor(K.epsilon(), output.dtype.base_dtype)
            output = tf.clip_by_value(output, _epsilon, 1 - _epsilon)
            output = tf.math.log(output / (1 - output))
        # Labels often arrive as ints/bools; the op requires labels and logits
        # to share a floating-point dtype.
        target = tf.cast(target, output.dtype)
        # Passed positionally: the first parameter is named `targets` in TF 1.x
        # and `labels` in TF 2.x, so a keyword would break on one of them.
        return tf.nn.weighted_cross_entropy_with_logits(target, output, pos_weight)

    def _weighted_binary_crossentropy(y_true, y_pred):
        """Keras loss entry point: per-element loss averaged over the last axis."""
        return K.mean(_calculate_weighted_binary_crossentropy(y_true, y_pred), axis=-1)

    return _weighted_binary_crossentropy
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment