Last active
June 1, 2021 12:10
-
-
Save Kautenja/69d306c587ccdf464c45d28c1545e580 to your computer and use it in GitHub Desktop.
An implementation of the Intersection over Union (IoU) metric for Keras.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""An implementation of the Intersection over Union (IoU) metric for Keras.""" | |
from keras import backend as K | |
def iou(y_true, y_pred, label: int):
    """
    Compute the Intersection over Union (IoU) of a single label.

    Args:
        y_true: the expected y values as a one-hot
        y_pred: the predicted y values as a one-hot or softmax output
        label: the label to return the IoU for

    Returns:
        the IoU for the given label, or 1 when the label occurs in neither
        the truths nor the predictions (empty union)

    """
    def label_mask(values):
        """Return a float mask that is 1 wherever argmax(values) == label."""
        return K.cast(K.equal(K.argmax(values), label), K.floatx())

    true_mask = label_mask(y_true)
    pred_mask = label_mask(y_pred)
    # |intersection| (AND): both masks are 0/1, so the product marks overlap
    overlap = K.sum(true_mask * pred_mask)
    # |union| (OR) = |truth| + |prediction| - |intersection|
    combined = K.sum(true_mask) + K.sum(pred_mask) - overlap
    # an empty union would divide by zero, so report a score of 1 instead
    return K.switch(K.equal(combined, 0), 1.0, overlap / combined)
def build_iou_for(label: int, name: str=None):
    """
    Build an Intersection over Union (IoU) metric for a label.

    Args:
        label: the label to build the IoU metric for
        name: an optional name for debugging the built method

    Returns:
        a keras metric to evaluate IoU for the given label

    Note:
        label and name support list inputs for multiple labels; names are
        paired with labels via zip, so pairing stops at the shorter list

    """
    # handle recursive inputs (e.g. a list of labels and names)
    if isinstance(label, list):
        if isinstance(name, list):
            return [
                build_iou_for(single_label, single_name)
                for (single_label, single_name) in zip(label, name)
            ]
        return [build_iou_for(single_label) for single_label in label]

    # build the method for returning the IoU of the given label
    def label_iou(y_true, y_pred):
        return iou(y_true, y_pred, label)

    # A docstring has to be a bare string literal: a `"...".format(...)`
    # expression as the first statement is just evaluated and thrown away on
    # every call, leaving __doc__ empty. Attach the formatted text explicitly,
    # once, when the metric is built.
    label_iou.__doc__ = """
    Return the Intersection over Union (IoU) score for {0}.

    Args:
        y_true: the expected y values as a one-hot
        y_pred: the predicted y values as a one-hot or softmax output

    Returns:
        the scalar IoU value for the given label ({0})

    """.format(label)

    # if no name is provided, use the label
    if name is None:
        name = label
    # change the name of the method for debugging
    label_iou.__name__ = 'iou_{}'.format(name)
    return label_iou
def mean_iou(y_true, y_pred):
    """
    Return the Intersection over Union (IoU) score.

    Args:
        y_true: the expected y values as a one-hot
        y_pred: the predicted y values as a one-hot or softmax output

    Returns:
        the scalar IoU value (mean over all labels)

    """
    # the number of labels is the static size of the last axis of y_pred
    num_labels = K.int_shape(y_pred)[-1]
    # Accumulate with ordinary tensor arithmetic. A backend variable is not
    # needed as a zero accumulator, and creating one here would allocate a new
    # stateful variable every time the metric is built.
    total_iou = 0.0
    for label in range(num_labels):
        total_iou = total_iou + iou(y_true, y_pred, label)
    # divide total IoU by number of labels to get mean IoU
    return total_iou / num_labels
# the public API of this module: the per-label metric factory and the mean
__all__ = ['build_iou_for', 'mean_iou']
In Keras, no gradient exists for argmax, so this does not work as a loss.
It's a metric, not a loss. So yes, as a loss it does not work, but as a metric (the intended design) it works fine.
Yep. As a metric it works.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I tried to use this IoU as a loss for a U-Net model, but the error below comes up. What's the problem? Thank you.
model.compile(optimizer = Adam(lr = 1e-4), loss = mean_iou, metrics = [mean_iou])
ValueError: No gradients provided for any variable: ['conv2d/kernel:0', 'conv2d/bias:0', 'conv2d_1/kernel:0', 'conv2d_1/bias:0', 'conv2d_2/kernel:0', 'conv2d_2/bias:0', 'conv2d_3/kernel:0', 'conv2d_3/bias:0', 'conv2d_4/kernel:0', 'conv2d_4/bias:0', 'conv2d_5/kernel:0', 'conv2d_5/bias:0', 'conv2d_6/kernel:0', 'conv2d_6/bias:0', 'conv2d_7/kernel:0', 'conv2d_7/bias:0', 'conv2d_8/kernel:0', 'conv2d_8/bias:0', 'conv2d_9/kernel:0', 'conv2d_9/bias:0', 'conv2d_10/kernel:0', 'conv2d_10/bias:0', 'conv2d_11/kernel:0', 'conv2d_11/bias:0', 'conv2d_12/kernel:0', 'conv2d_12/bias:0', 'conv2d_13/kernel:0', 'conv2d_13/bias:0', 'conv2d_14/kernel:0', 'conv2d_14/bias:0', 'conv2d_15/kernel:0', 'conv2d_15/bias:0', 'conv2d_16/kernel:0', 'conv2d_16/bias:0', 'conv2d_17/kernel:0', 'conv2d_17/bias:0', 'conv2d_18/kernel:0', 'conv2d_18/bias:0', 'conv2d_19/kernel:0', 'conv2d_19/bias:0', 'conv2d_20/kernel:0', 'conv2d_20/bias:0', 'conv2d_21/kernel:0', 'conv2d_21/bias:0', 'conv2d_22/kernel:0', 'conv2d_22/bias:0', 'conv2d_23/kernel:0', 'conv2d_23/bias:0'].