Forked from ksugar/keras_weighted_binary_crossentropy.py
Created February 9, 2019 14:27
-
-
Save VolkerH/e9d6183e73c71e5af72f4efc76317c38 to your computer and use it in GitHub Desktop.
Custom loss function for weighted binary crossentropy in Keras with Tensorflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Code adapted from https://github.com/keras-team/keras/blob/master/keras/backend/tensorflow_backend.py
# The only change: tf.nn.weighted_cross_entropy_with_logits (which takes a pos_weight) is used instead of tf.nn.sigmoid_cross_entropy_with_logits.
import tensorflow as tf | |
from keras import backend as K | |
def weighted_binary_crossentropy(pos_weight=1, from_logits=False):
    """Build a weighted binary crossentropy loss between an output tensor and a target tensor.

    Adapted from the Keras TensorFlow backend's ``binary_crossentropy``. It uses
    ``tf.nn.weighted_cross_entropy_with_logits`` so that the positive-target
    term of the loss is scaled by ``pos_weight``.

    # Arguments
        pos_weight: A coefficient to use on the positive examples.
        from_logits: Whether the model output passed to the loss is a logits
            tensor. Defaults to False, i.e. the output is treated as
            probabilities (the original behavior).

    # Returns
        A loss function ``loss(y_true, y_pred)`` to be used in
        ``model.compile()``. It averages the element-wise weighted
        crossentropy over the last axis.
    """
    def _to_tensor(x, dtype):
        """Convert the input `x` to a tensor of type `dtype`.

        # Arguments
            x: An object to be converted (numpy array, list, tensors).
            dtype: The destination type.

        # Returns
            A tensor.
        """
        return tf.convert_to_tensor(x, dtype=dtype)

    def _calculate_weighted_binary_crossentropy(target, output, from_logits=False):
        """Calculate weighted binary crossentropy between an output tensor and a target tensor.

        # Arguments
            target: A tensor with the same shape as `output`.
            output: A tensor.
            from_logits: Whether `output` is expected to be a logits tensor.
                By default, we consider that `output`
                encodes a probability distribution.

        # Returns
            A tensor of element-wise losses.
        """
        if not from_logits:
            # The TF op expects logits, but Keras passes probabilities.
            # Clip away from 0 and 1 so the log stays finite, then apply
            # the inverse sigmoid (logit) to recover logits.
            # K.epsilon() is the public backend API; K.common is a private
            # submodule that is missing in later Keras versions.
            _epsilon = _to_tensor(K.epsilon(), output.dtype.base_dtype)
            output = tf.clip_by_value(output, _epsilon, 1 - _epsilon)
            # tf.math.log exists in TF 1.x and 2.x; tf.log was removed in 2.x.
            output = tf.math.log(output / (1 - output))
        # The op requires labels and logits of the same dtype; y_true may
        # arrive with a different one (e.g. integer or float64 labels).
        target = tf.cast(target, output.dtype)
        # Pass arguments positionally: the first parameter is called `targets`
        # in older TF 1.x and `labels` in TF 2.x, so a keyword breaks one of them.
        return tf.nn.weighted_cross_entropy_with_logits(target, output, pos_weight)

    def _weighted_binary_crossentropy(y_true, y_pred):
        return K.mean(
            _calculate_weighted_binary_crossentropy(y_true, y_pred, from_logits=from_logits),
            axis=-1,
        )

    return _weighted_binary_crossentropy
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.