Skip to content

Instantly share code, notes, and snippets.

@pangyuteng
Last active May 4, 2022 22:26
Show Gist options
Save pangyuteng/a1b99a4e3ac4f8142b7181c520929adf to your computer and use it in GitHub Desktop.
Keras multi-class Dice coefficient and loss
#
# gattia commented on Apr 6, 2018
# https://github.com/keras-team/keras/issues/9395#issuecomment-379276452
#
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
def dice_coef(y_true, y_pred, smooth=1e-7):
    """Return the soft Dice coefficient between ``y_true`` and ``y_pred``.

    Both tensors are flattened to 1-D before reducing, so the score is pooled
    over every element (all batch items and positions together), not
    averaged per sample.

    Args:
        y_true: ground-truth tensor.
        y_pred: predicted tensor; presumably the same shape and a compatible
            dtype as ``y_true`` (TODO confirm against callers).
        smooth: constant added to both numerator and denominator so the
            ratio stays defined when both inputs sum to zero (the result is
            then exactly 1).

    Returns:
        Scalar tensor ``(2*sum(t*p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    # Numerator and denominator are built in the same order as the
    # textbook formula: 2*|A∩B| over |A| + |B|, each offset by `smooth`.
    overlap = 2. * K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)
    return (overlap + smooth) / (total + smooth)
def dice_coef_multilabel(y_true, y_pred, weight_list=(0.0, 1.0)):
    """Return the negated, weighted sum of per-class Dice coefficients.

    Suitable as a Keras loss: minimising the returned value maximises the
    weighted Dice overlap. Class ``i`` is read from the LAST axis of both
    tensors (``[..., i]``), so channels-last inputs of any rank work. For
    example, ``(batch, H, W, C)`` images and ``(batch, D, H, W, C)`` volumes
    both work; the original code only accepted rank-5 tensors.

    Args:
        y_true: ground-truth tensor with per-class channels on the last axis
            (presumably one-hot; TODO confirm with the data pipeline).
        y_pred: predicted tensor with the same channel layout.
        weight_list: one weight per class channel, applied in order. Only the
            first ``len(weight_list)`` channels are scored. The default
            ``(0.0, 1.0)`` ignores channel 0 and scores channel 1 with
            weight 1. Any indexable sequence (list or tuple) is accepted.

    Returns:
        Scalar tensor equal to
        ``-sum_i weight_list[i] * dice_coef(y_true[..., i], y_pred[..., i])``.
    """
    # Start from 0 and subtract, so a higher Dice overlap gives a lower loss.
    dice = 0
    for index, weight in enumerate(weight_list):
        # Ellipsis selects channel `index` on the last axis regardless of rank.
        dice -= weight * dice_coef(y_true[..., index], y_pred[..., index])
    return dice
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment