Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save brookisme/c56d20007dbc928ade6e0b13b1e4f974 to your computer and use it in GitHub Desktop.
Save brookisme/c56d20007dbc928ade6e0b13b1e4f974 to your computer and use it in GitHub Desktop.
Keras: Weighted Categorical Crossentropy
import numpy as np
import tensorflow as tf
import keras.backend as K
def weighted_categorical_crossentropy(weights):
    """ weighted_categorical_crossentropy

    Build a categorical crossentropy loss in which each term of
    ``target * log(output)`` is multiplied by ``weights`` before summing.

    Args:
        * weights<ktensor|nparray|list|tuple>: crossentropy weights. They
          are broadcast against ``target * log(output)``, so presumably
          one entry per class along the last axis. Lists, tuples and numpy
          arrays are wrapped in a Keras variable. Anything else (e.g. an
          existing tensor or a scalar) is used as-is.

    Returns:
        * weighted categorical crossentropy function
          ``loss(target, output, from_logits=False)``
    """
    # Array-likes are converted once here so the closure reuses a single
    # backend variable. Tensors and scalars are passed through unchanged.
    if isinstance(weights, (list, tuple, np.ndarray)):
        weights = K.variable(weights)

    def loss(target, output, from_logits=False):
        """Weighted categorical crossentropy between ``target`` and ``output``.

        Args:
            target: true labels tensor (presumably one-hot; TODO confirm).
            output: predicted tensor. It is renormalized along the last axis,
                so it is expected to hold non-negative scores.
            from_logits: must be False; logits are not supported.

        Returns:
            Tensor of per-sample losses (the last axis is reduced away).

        Raises:
            ValueError: if ``from_logits`` is True.
        """
        if from_logits:
            raise ValueError('WeightedCategoricalCrossentropy: not valid with logits')
        # Rescale so predictions sum to 1 along the last axis. axis=-1 also
        # works when the tensor's rank is not statically known.
        output = output / tf.reduce_sum(output, axis=-1, keepdims=True)
        # Keep values strictly inside (0, 1) so log() stays finite.
        epsilon = tf.convert_to_tensor(K.epsilon(), dtype=output.dtype.base_dtype)
        output = tf.clip_by_value(output, epsilon, 1. - epsilon)
        # tf.math.log works on TF >= 1.13 and TF2; tf.log was removed in TF2.
        weighted_losses = target * tf.math.log(output) * weights
        return -tf.reduce_sum(weighted_losses, axis=-1)

    return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment