```python
import numpy as np

def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    '''
    Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
    Assumes the `channels_last` format.

    # Arguments
        y_true: b x X x Y( x Z...) x c One hot encoding of ground truth
        y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax)
        epsilon: Used for numerical stability to avoid divide by zero errors

    # References
        V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
        https://arxiv.org/abs/1606.04797
        More details on Dice loss formulation
        https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)

    Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
    '''
    # skip the batch and class axis for calculating Dice score
    axes = tuple(range(1, len(y_pred.shape)-1))
    numerator = 2. * np.sum(y_pred * y_true, axes)
    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)
    # average over classes and batch
    return 1 - np.mean((numerator + epsilon) / (denominator + epsilon))

# thanks @mfernezir for catching a bug in an earlier version of this implementation!
```
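Written out as a formula, the function computes the quantity below. This is a reading of the code rather than a formula quoted from the references; the notation is introduced here for illustration: $p$ is `y_pred`, $g$ is `y_true`, $i$ runs over all spatial positions, $B$ is the batch size and $C$ the number of classes.

```latex
\mathcal{L}_{\text{Dice}} = 1 - \frac{1}{BC} \sum_{b=1}^{B} \sum_{c=1}^{C}
  \frac{2 \sum_{i} p_{b,i,c}\, g_{b,i,c} + \epsilon}
       {\sum_{i} p_{b,i,c}^{2} + \sum_{i} g_{b,i,c}^{2} + \epsilon}
```

One consequence of adding $\epsilon$ to both numerator and denominator: a class that is absent from both the ground truth and the prediction of a sample gives $\epsilon / \epsilon = 1$, so it contributes zero loss instead of a division by zero.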
Good!
With TensorFlow 2.5 and Python 3.9.5, I got the following error when training a model:
Cannot convert a symbolic Tensor (soft_dice_loss/mul:0) to a NumPy array.
I modified Jeremy's code (thanks btw for sharing!) by replacing NumPy functions with their tf equivalents, and here is the result:
```python
import tensorflow as tf

def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    # make sure both tensors share a dtype (integer one-hot labels would otherwise fail to multiply)
    y_true = tf.cast(y_true, y_pred.dtype)
    # skip the batch and class axis for calculating Dice score
    axes = tuple(range(1, len(y_pred.shape)-1))
    numerator = 2. * tf.reduce_sum(y_pred * y_true, axes)
    denominator = tf.reduce_sum(tf.square(y_pred) + tf.square(y_true), axes)
    # average over classes and batch
    result = 1 - tf.reduce_mean((numerator + epsilon) / (denominator + epsilon))
    return result
```
The code does not throw any errors. However, I am not 100% sure there are no miscalculations in it, so use it at your own risk...
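One way to gain some confidence in the formula is to compare the vectorized version against a slow, explicit-loop version on random data. The sketch below does this with NumPy only. `soft_dice_loss_reference` is a hypothetical helper written for this check, and the shapes and random inputs are arbitrary illustrative choices. Assuming TensorFlow 2.x eager mode, the TF port could be checked the same way by calling it on `tf.constant` inputs and converting the result with `.numpy()`.

```python
import numpy as np

def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    # vectorized NumPy version from the gist above
    axes = tuple(range(1, len(y_pred.shape) - 1))
    numerator = 2. * np.sum(y_pred * y_true, axes)
    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)
    return 1 - np.mean((numerator + epsilon) / (denominator + epsilon))

def soft_dice_loss_reference(y_true, y_pred, epsilon=1e-6):
    # hypothetical helper: same formula, one (sample, class) pair at a time
    batch, n_classes = y_pred.shape[0], y_pred.shape[-1]
    ratios = []
    for b in range(batch):
        for c in range(n_classes):
            p = y_pred[b, ..., c].ravel()  # all spatial positions of this sample/class
            g = y_true[b, ..., c].ravel()
            num = 2. * np.sum(p * g) + epsilon
            den = np.sum(p ** 2) + np.sum(g ** 2) + epsilon
            ratios.append(num / den)
    return 1 - np.mean(ratios)

rng = np.random.default_rng(0)
# illustrative 2-D segmentation: batch of 2, 8x8 pixels, 3 classes (channels_last)
logits = rng.normal(size=(2, 8, 8, 3))
y_pred = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax over classes
labels = rng.integers(0, 3, size=(2, 8, 8))
y_true = np.eye(3)[labels]  # one-hot ground truth, shape (2, 8, 8, 3)

fast = soft_dice_loss(y_true, y_pred)
slow = soft_dice_loss_reference(y_true, y_pred)
print(np.isclose(fast, slow))          # True
print(soft_dice_loss(y_true, y_true))  # perfect prediction: 0.0
```

The loop version makes the reduction explicit: the sums run over spatial positions only, and the mean is taken over every (sample, class) pair, which is what `axes = tuple(range(1, len(y_pred.shape)-1))` selects in the vectorized code.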
👍