@jeremyjordan
Last active September 20, 2023 12:50
Generic calculation of the soft Dice loss used as the objective function in image segmentation tasks.
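For reference, the function that follows computes the quantity below. The notation is ours, not the gist's: $p$ stands for `y_pred`, $g$ for `y_true`, $B$ is the batch size, $C$ the number of classes, and $i$ runs over all spatial positions. The Dice ratio is computed per image and per class, and the ratios are then averaged:

$$
\mathcal{L} = 1 - \frac{1}{BC} \sum_{b=1}^{B} \sum_{c=1}^{C}
\frac{2 \sum_{i} p_{b,i,c}\, g_{b,i,c} + \epsilon}
     {\sum_{i} p_{b,i,c}^{2} + \sum_{i} g_{b,i,c}^{2} + \epsilon}
$$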
```python
import numpy as np

def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    '''
    Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
    Assumes the `channels_last` format.

    # Arguments
        y_true: b x X x Y( x Z...) x c One hot encoding of ground truth
        y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax)
        epsilon: Used for numerical stability to avoid divide by zero errors

    # References
        V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
        https://arxiv.org/abs/1606.04797
        More details on Dice loss formulation
        https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)

        Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
    '''

    # skip the batch and class axis for calculating Dice score
    axes = tuple(range(1, len(y_pred.shape)-1))
    numerator = 2. * np.sum(y_pred * y_true, axes)
    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)

    return 1 - np.mean((numerator + epsilon) / (denominator + epsilon)) # average over classes and batch
    # thanks @mfernezir for catching a bug in an earlier version of this implementation!
```
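A quick usage sketch (our own illustration, not part of the gist): a random batch of 2 images of 4 x 4 pixels with 3 classes, one-hot ground truth, and predictions produced by a softmax so they sum to 1 over the class axis as the docstring requires. The `softmax` helper, the shapes, and the random seed are arbitrary assumptions; the loss function body mirrors the gist so the snippet runs on its own.

```python
import numpy as np


def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    # Same computation as the gist: sum over spatial axes only,
    # keeping one Dice score per (image, class) pair, then average.
    axes = tuple(range(1, len(y_pred.shape) - 1))
    numerator = 2. * np.sum(y_pred * y_true, axes)
    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)
    return 1 - np.mean((numerator + epsilon) / (denominator + epsilon))


def softmax(logits, axis=-1):
    # Numerically stable softmax over the class (last) axis.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
b, h, w, c = 2, 4, 4, 3                          # batch, height, width, classes (arbitrary)
labels = rng.integers(0, c, size=(b, h, w))      # random integer label map
y_true = np.eye(c)[labels]                       # one-hot, channels_last: (2, 4, 4, 3)
y_pred = softmax(rng.normal(size=(b, h, w, c)))  # sums to 1 over the class axis

loss_random = soft_dice_loss(y_true, y_pred)     # somewhere strictly between 0 and 1
loss_perfect = soft_dice_loss(y_true, y_true)    # identical masks: 0.0
print(loss_random, loss_perfect)
```

Because each ratio satisfies $2pg \le p^2 + g^2$, every per-class Dice score lies in $[0, 1]$, so the loss does too.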
@jeremyjordan
Author

jeremyjordan commented May 24, 2018

Laplace smoothing is sometimes used instead; it can be implemented trivially by adding 1 to both the numerator and the denominator (i.e., by setting `epsilon=1`).
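In terms of the gist's function, Laplace smoothing amounts to passing `epsilon=1.0` instead of a tiny value. The sketch below uses toy data of our own choosing (a single foreground channel viewed in isolation, so the softmax constraint from the docstring is not exercised) to show one side effect worth keeping in mind: for very small structures the +1 terms are comparable to the pixel counts, so a complete miss is penalized much less.

```python
import numpy as np


def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    # The gist's computation; epsilon is added to numerator and denominator.
    axes = tuple(range(1, len(y_pred.shape) - 1))
    numerator = 2. * np.sum(y_pred * y_true, axes)
    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)
    return 1 - np.mean((numerator + epsilon) / (denominator + epsilon))


# Toy data (our own): one 2x2 image with one foreground pixel,
# and a prediction that misses it completely.
y_true = np.array([1., 0., 0., 0.]).reshape(1, 2, 2, 1)
y_pred = np.zeros_like(y_true)

loss_eps = soft_dice_loss(y_true, y_pred)                   # (0 + 1e-6) / (1 + 1e-6): loss close to 1
loss_laplace = soft_dice_loss(y_true, y_pred, epsilon=1.0)  # (0 + 1) / (1 + 1): loss 0.5
print(loss_eps, loss_laplace)
```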

@daltongriner

Can this loss be minimized with a standard optimizer, such as Adam?

@xiaomaxiao

Whether to use image Dice loss or batch Dice loss needs to be decided experimentally. See "Segmentation of Head and Neck Organs at Risk Using CNN with Batch Dice Loss": https://arxiv.org/pdf/1812.02427.pdf
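A minimal sketch of the difference, assuming that "batch Dice" means also summing over the batch axis so that the whole batch is scored as one large image per class (our reading of the linked paper). The `batch_dice` flag is a hypothetical addition, not part of the gist, and the toy batch is our own.

```python
import numpy as np


def soft_dice_loss(y_true, y_pred, epsilon=1e-6, batch_dice=False):
    # Image Dice (default, as in the gist): sum over spatial axes,
    # one score per (image, class).
    axes = tuple(range(1, y_pred.ndim - 1))
    if batch_dice:
        # Batch Dice (hypothetical flag): also sum over the batch axis,
        # giving one score per class for the whole batch.
        axes = (0,) + axes
    numerator = 2. * np.sum(y_pred * y_true, axes)
    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)
    return 1 - np.mean((numerator + epsilon) / (denominator + epsilon))


# Toy batch (our own): two 1x2 images, one foreground channel.
# Image 0 is predicted perfectly; image 1 has no foreground but a weak false positive.
y_true = np.array([[1., 1.], [0., 0.]]).reshape(2, 1, 2, 1)
y_pred = np.array([[1., 1.], [0.5, 0.]]).reshape(2, 1, 2, 1)

image_loss = soft_dice_loss(y_true, y_pred)                   # about 0.5: image 1's Dice is ~0
batch_loss = soft_dice_loss(y_true, y_pred, batch_dice=True)  # 1 - 4 / 4.25, about 0.059
print(image_loss, batch_loss)
```

With image Dice, an image that contains no foreground gets a Dice score near 0 from even a weak false positive, which drags the average down; batch Dice pools the counts first, so the same false positive costs far less.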

@mfernezir

@jeremyjordan, thanks for the implementation, and especially for the reference to the original Dice loss thesis, which argues why, at least in theory, the formulation with squared terms in the denominator is better.

However, there is a critical bug in your implementation.

Suppose we have two channels, a single example, and a prediction equal to the ground truth, with all zeros on one channel and all ones on the other. The Dice loss should be zero.

What happens in your implementation is that you end up dividing [0, 2] / [epsilon, epsilon + 2] before the final averaging across channels. Since there are two channels, your implementation averages 0 and (approximately) 1 to get 0.5, so the loss is 0.5.

We should add epsilon to the numerator as well, so that the first channel gives epsilon / epsilon = 1 too.

You can check this with, e.g., `y_pred = y_true = np.array([0., 1.]).reshape(1, 1, 1, 2)`: your function returns 0.5 instead of 0.0 for the loss.

TL;DR:

Line 25 (the `return` statement) should be:

```python
return 1 - np.mean((numerator + epsilon) / (denominator + epsilon))
```
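The example above can be reproduced with a small sketch. The earlier version is reconstructed only from the description in this comment (epsilon added to the denominator but not the numerator); the `smooth_numerator` flag is hypothetical and exists just to switch between the two variants.

```python
import numpy as np


def soft_dice_loss(y_true, y_pred, epsilon=1e-6, smooth_numerator=True):
    # smooth_numerator=False mimics the earlier version described in the comment,
    # where epsilon was added to the denominator only. The flag is hypothetical.
    axes = tuple(range(1, len(y_pred.shape) - 1))
    numerator = 2. * np.sum(y_pred * y_true, axes)
    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)
    if smooth_numerator:
        numerator = numerator + epsilon
    return 1 - np.mean(numerator / (denominator + epsilon))


# The example from the comment: one pixel, two channels, prediction equal to ground truth.
y = np.array([0., 1.]).reshape(1, 1, 1, 2)

buggy = soft_dice_loss(y, y, smooth_numerator=False)  # [0, 2] / [eps, 2 + eps]: loss about 0.5
fixed = soft_dice_loss(y, y)                          # [eps, 2 + eps] / [eps, 2 + eps]: loss 0.0
print(buggy, fixed)
```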

@jeremyjordan
Author

@mfernezir, thank you for this very detailed example! It makes total sense; I've updated the gist with a correction :)

@mfernezir

👍

@yuleichin

Good!

@MasoudAali

With TensorFlow 2.5 and Python 3.9.5, I got the following error when training a model:

`Cannot convert a symbolic Tensor (soft_dice_loss/mul:0) to a NumPy array.`

I modified Jeremy's code (thanks, by the way, for sharing!) by replacing the NumPy functions with their TensorFlow equivalents, and here is the result:
```python
import tensorflow as tf

def soft_dice_loss(y_true, y_pred, epsilon=1e-6):
    # skip the batch and class axis for calculating Dice score
    axes = tuple(range(1, len(y_pred.shape)-1))
    numerator = 2. * tf.reduce_sum(y_pred * y_true, axes)
    denominator = tf.reduce_sum(tf.square(y_pred) + tf.square(y_true), axes)
    result = 1 - tf.reduce_mean((numerator + epsilon) / (denominator + epsilon))
    return result  # average over classes and batch
```
The code runs without errors. However, I am not 100% sure that there are no miscalculations in it, so use it at your own risk.
