import numpy as np
import matplotlib.pyplot as plt
import keras
from keras import backend as K

def jaccard_distance_loss(y_true, y_pred, smooth=100):
    """
    Jaccard = |X & Y| / (|X| + |Y| - |X & Y|)
            = sum(|A*B|) / (sum(|A|) + sum(|B|) - sum(|A*B|))

    The Jaccard distance loss is useful for unbalanced datasets. This has been
    shifted so it converges on 0 and is smoothed to avoid an exploding or
    disappearing gradient.

    Ref: https://en.wikipedia.org/wiki/Jaccard_index
    @url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
    @author: wassname
    """
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return (1 - jac) * smooth
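# Usage sketch (hypothetical model, not part of the original gist; any Keras
# model with a sigmoid output could be compiled against this loss like so):
# model.compile(optimizer='adam', loss=jaccard_distance_loss)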
# Test and plot
y_pred = np.array([np.arange(-10, 10 + 0.1, 0.1)]).T
y_true = np.zeros(y_pred.shape)
name = 'jaccard_distance_loss'
try:
    loss = jaccard_distance_loss(
        K.variable(y_true), K.variable(y_pred)
    ).eval(session=K.get_session())
except Exception as e:
    print("error plotting", name, e)
else:
    plt.title(name)
    plt.plot(y_pred, loss)
    plt.show()
# Test
print("TYPE |Almost_right |half right |all_wrong")
y_true = np.array([[0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 1., 0.]])
y_pred = np.array([[0, 0, 0.9, 0], [0, 0, 0.1, 0], [1, 1, 0.1, 1.]])

r = jaccard_distance_loss(
    K.variable(y_true),
    K.variable(y_pred),
).eval(session=K.get_session())
print('jaccard_distance_loss', r)
assert r[0] < r[1]
assert r[1] < r[2]

r = keras.losses.binary_crossentropy(
    K.variable(y_true),
    K.variable(y_pred),
).eval(session=K.get_session())
print('binary_crossentropy', r)
print('binary_crossentropy_scaled', r / r.max())
assert r[0] < r[1]
assert r[1] < r[2]
""" | |
TYPE |Almost_right |half right |all_wrong | |
jaccard_distance_loss [ 0.09900928 0.89108944 3.75000238] | |
binary_crossentropy [ 0.02634021 0.57564634 12.53243446] | |
binary_crossentropy_scaled [ 0.00210176 0.04593252 1. ] | |
""" |
In my case (cell segmentation), most of each image is foreground rather than background, so the training samples are very unbalanced. I used binary cross-entropy and the model ended up outputting all-white masks.
Do you have any good suggestions about which loss to use?
I once tried swapping foreground and background in the loss calculation and then computing a dice loss, but Keras reported a disappearing gradient.
Thank you very much!
You helped a lot @wassname, thanks! Could you please explain a little about the "smooth" parameter? Why did you add it, and why is the value 100?
The Jaccard similarity jac moves towards 1 at convergence, like other metrics, so I added (1 - jac) to make the loss converge on 0 like other losses. And I wanted it to have a similar curve to mean squared error when plotted, which is why I chose 100. Ideally I wanted a similar gradient, from a similar change in loss, at all possible values of X; this way it learns equally well when it's close to and far from convergence.
You can plot it with different values and compare it to the MSE and MAE losses to see the effect; that's what I did to select 100.
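For instance, here is a minimal sketch of that comparison (a NumPy re-implementation of the loss above, used for plotting only; the smooth values tried are just illustrative choices, not recommendations):

import numpy as np
import matplotlib.pyplot as plt

def jaccard_np(y_true, y_pred, smooth):
    # NumPy copy of jaccard_distance_loss above, for plotting only
    intersection = np.sum(np.abs(y_true * y_pred), axis=-1)
    sum_ = np.sum(np.abs(y_true) + np.abs(y_pred), axis=-1)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return (1 - jac) * smooth

x = np.arange(-10, 10 + 0.1, 0.1)
y_pred = x[:, None]
y_true = np.zeros_like(y_pred)

for smooth in [1, 10, 100, 1000]:  # illustrative values
    plt.plot(x, jaccard_np(y_true, y_pred, smooth),
             label='jaccard, smooth=%d' % smooth)

plt.plot(x, np.mean((y_true - y_pred) ** 2, axis=-1), '--', label='MSE')
plt.plot(x, np.mean(np.abs(y_true - y_pred), axis=-1), '--', label='MAE')
plt.legend()
plt.title('jaccard distance loss vs MSE/MAE')
plt.show()

Larger smooth values flatten the curve near zero error, which is what makes its shape (and gradient) resemble MSE around convergence.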
I have similar questions here:
keras-team/keras-contrib#329