Last active
August 27, 2018 12:13
-
-
Save atamborrino/d9fe0ca806a64f87cdf46b0b1a9ea20c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# src: https://www.kaggle.com/aglotero/another-iou-metric
def iou_metric(y_true_in, y_pred_in, print_table=False):
    """Score one predicted mask against its ground-truth mask.

    Both masks are treated as binary: any non-zero pixel is foreground. The
    foreground of each mask is considered a single object. For every IoU
    threshold in ``np.arange(0.5, 1.0, 0.05)`` the prediction is a hit when
    its IoU with the truth is strictly greater than the threshold. The
    per-threshold precision is TP / (TP + FP + FN), and the result is the
    mean over thresholds. Two empty masks count as a perfect match (1.0).

    Args:
        y_true_in: ground-truth mask (array-like).
        y_pred_in: predicted mask, broadcast-compatible with ``y_true_in``
            (normally the same shape).
        print_table: if True, print one TP/FP/FN/precision row per threshold
            followed by the average.

    Returns:
        Mean precision over the thresholds, in [0, 1].
    """
    labels = np.asarray(y_true_in) > 0
    y_pred = np.asarray(y_pred_in) > 0
    # Count foreground pixels directly. Histogram bins derived from each
    # mask's own min/max misclassify an all-zero mask: numpy widens its range
    # to [-0.5, 0.5] and the zeros land in the upper ("foreground") bin.
    has_true = labels.any()
    has_pred = y_pred.any()
    intersection = np.logical_and(labels, y_pred).sum()
    union = np.logical_or(labels, y_pred).sum()
    # Both masks empty: nothing to find and nothing predicted -> perfect match.
    iou = 1.0 if union == 0 else intersection / union

    def precision_at(threshold):
        """Return (tp, fp, fn) for the single-object masks at ``threshold``."""
        matched = iou > threshold
        tp = int(matched)
        # An unmatched predicted object is a false positive (extra object).
        fp = int(bool(has_pred) and not matched)
        # An unmatched true object is a false negative (missed object).
        fn = int(bool(has_true) and not matched)
        return tp, fp, fn

    # Loop over IoU thresholds
    prec = []
    if print_table:
        print("Thresh\tTP\tFP\tFN\tPrec.")
    for t in np.arange(0.5, 1.0, 0.05):
        tp, fp, fn = precision_at(t)
        total = tp + fp + fn
        p = tp / total if total > 0 else 0
        if print_table:
            print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(t, tp, fp, fn, p))
        prec.append(p)
    if print_table:
        print("AP\t-\t-\t-\t{:1.3f}".format(np.mean(prec)))
    return np.mean(prec)
def iou_metric_batch(y_true_in, y_pred_in):
    """Average ``iou_metric`` over the leading (batch) axis.

    Sample ``i`` of ``y_true_in`` is scored against sample ``i`` of
    ``y_pred_in``, for every index of ``y_true_in``'s first axis. The
    per-sample scores are then averaged with ``np.mean``.
    """
    n_samples = y_true_in.shape[0]
    scores = [iou_metric(y_true_in[i], y_pred_in[i]) for i in range(n_samples)]
    return np.mean(scores)
def best_iou_and_threshold(y_true, y_pred, plot=False):
    """Find the binarization threshold that maximizes the batch IoU score.

    ``y_pred`` is binarized (``y_pred > t``, cast to int32) at 50 evenly
    spaced thresholds in [0, 1]. Each binarized batch is scored against
    ``y_true`` with ``iou_metric_batch``. The best score is chosen only
    among threshold indices 9..39 (roughly 0.18..0.80), so extreme
    thresholds are scored but never selected.

    Args:
        y_true: ground-truth masks, batch on axis 0.
        y_pred: raw predictions, same layout as ``y_true``.
        plot: if True, plot score vs threshold, mark the best point, and
            print the chosen threshold and score.

    Returns:
        ``(iou_best, threshold_best)``.
    """
    candidates = np.linspace(0, 1, 50)
    scores = []
    for cutoff in candidates:
        binarized = np.int32(y_pred > cutoff)
        scores.append(iou_metric_batch(y_true, binarized))
    scores = np.array(scores)

    # Restrict the search window; 9 entries are dropped at the low end and
    # 10 at the high end.
    lo, hi = 9, len(scores) - 10
    best_idx = lo + np.argmax(scores[lo:hi])
    best_score = scores[best_idx]
    best_cutoff = candidates[best_idx]

    if plot:
        plt.plot(candidates, scores)
        plt.plot(best_cutoff, best_score, "xr", label="Best threshold")
        plt.xlabel("Threshold")
        plt.ylabel("IoU")
        plt.title("Threshold vs IoU ({}, {})".format(best_cutoff, best_score))
        plt.legend()
        print(f'threshold_best: {best_cutoff}')
        print(f'iou_best: {best_score}')
    return best_score, best_cutoff
class ValGlobalMetrics(keras.callbacks.Callback):
    """Keras callback that adds a ``val_best_iou`` entry to the epoch logs.

    At the end of every epoch it predicts on the validation inputs, and
    ``best_iou_and_threshold`` picks the best-thresholded IoU score. The
    score is stored in ``logs['val_best_iou']`` and printed, so other
    callbacks that monitor ``'val_best_iou'`` can read it.
    NOTE(review): those callbacks read the same ``logs`` dict, so this one
    presumably has to precede them in the callback list.
    """

    def __init__(self, x_val=None, y_val=None):
        """Optionally bind explicit validation data.

        Args:
            x_val: validation inputs; when None, the module-level global
                ``x_valid`` is read at each epoch end.
            y_val: validation targets; when None, the module-level global
                ``y_valid`` is read at each epoch end.
        """
        super().__init__()
        self._x_val = x_val
        self._y_val = y_val

    def on_epoch_end(self, batch, logs=None):
        # The first argument is the epoch index despite its name; it is kept
        # as ``batch`` for compatibility and is unused here.
        # ``logs`` defaults to None rather than a shared ``{}``: this method
        # writes into it, so a mutable default would leak state across calls.
        if logs is None:
            logs = {}
        inputs = x_valid if self._x_val is None else self._x_val
        targ = y_valid if self._y_val is None else self._y_val
        predict = np.asarray(self.model.predict(inputs))
        best_iou, _ = best_iou_and_threshold(y_true=targ, y_pred=predict)
        logs['val_best_iou'] = best_iou
        print(f' - val_best_iou: {best_iou}')
# All three stock callbacks track the custom 'val_best_iou' metric (higher is
# better), which ValGlobalMetrics writes into the epoch logs. ValGlobalMetrics
# is listed first; assuming Keras runs callbacks in list order, the key then
# exists before the others look for it.
early_stopping = EarlyStopping(
    monitor='val_best_iou',
    mode='max',
    patience=10,
    verbose=1,
)
model_checkpoint = ModelCheckpoint(
    "./keras.model",
    monitor='val_best_iou',
    mode='max',
    save_best_only=True,
    verbose=1,
)
reduce_lr = ReduceLROnPlateau(
    monitor='val_best_iou',
    mode='max',
    factor=0.1,
    patience=5,
    min_lr=1e-5,
    verbose=1,
)
keras_callbacks = [
    ValGlobalMetrics(),
    early_stopping,
    model_checkpoint,
    reduce_lr,
]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment