@Dref360
Created June 21, 2017 12:11
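import tensorflow as tf
from keras import backend as K


class weighted_categorical_crossentropy:
    """Per-pixel weighted cross-entropy for (batch, H, W, classes) tensors."""

    def __init__(self, weights):
        # weights[0] applies where channel 0 of y_true is set, weights[1] elsewhere.
        self.weights = weights
        self.__name__ = 'wcentroid_loss'

    def __call__(self, y_true, y_pred):
        class0 = K.ones_like(y_pred)[:, :, :, 0] * self.weights[0]
        class1 = K.ones_like(y_pred)[:, :, :, 0] * self.weights[1]
        x = tf.where(y_true[:, :, :, 0] > 0, class0, class1)
        # Newer Keras backends take (target, output); older ones took (output, target).
        result = x * K.categorical_crossentropy(y_true, y_pred)
        return result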
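
To check what the class above computes without building a model, the same arithmetic can be sketched in NumPy. This is only an illustrative re-implementation under assumptions: the helper name `weighted_cce_numpy`, the `eps` clipping value, and the toy arrays are hypothetical and not part of the original gist.

```python
import numpy as np


def weighted_cce_numpy(y_true, y_pred, weights, eps=1e-7):
    """Per-pixel weighted categorical cross-entropy for (batch, H, W, classes) arrays."""
    # Clip predictions to avoid log(0); eps is an assumed value.
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    # Plain per-pixel cross-entropy: -sum_c y_true * log(y_pred)
    cce = -np.sum(y_true * np.log(y_pred), axis=-1)
    # Pixels whose channel 0 is set get weights[0]; all others get weights[1].
    pixel_weights = np.where(y_true[..., 0] > 0, weights[0], weights[1])
    return pixel_weights * cce


# Hypothetical 1x2 image with two classes: one pixel of each class.
y_true = np.array([[[[1.0, 0.0], [0.0, 1.0]]]])
y_pred = np.array([[[[0.5, 0.5], [0.5, 0.5]]]])
loss = weighted_cce_numpy(y_true, y_pred, weights=(1.0, 3.0))
print(loss.shape)  # (1, 1, 2)
```

With identical predictions at both pixels, the second pixel's loss is three times the first, which is the only effect the weighting has. Note that the rule keys on channel 0 alone, so it distinguishes exactly two groups of pixels regardless of how many classes the tensors carry.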