Last active
March 31, 2022 08:46
-
-
Save mberr/78559dfec160d5bf7245d674ca18f42f to your computer and use it in GitHub Desktop.
Determine optimal threshold for Macro F1 score
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Determine optimal threshold for Macro F1 score.""" | |
from typing import Tuple | |
import numpy | |
from sklearn.metrics._ranking import _binary_clf_curve | |
def f1_scores(
    precision: numpy.ndarray,
    recall: numpy.ndarray,
) -> numpy.ndarray:
    """Compute element-wise F1 scores from precision and recall.

    F1 is the harmonic mean ``2 * P * R / (P + R)``. Wherever ``P + R`` is
    exactly zero the denominator is replaced by one, so the result there is
    simply the numerator (``0`` for valid, non-negative inputs).

    Scalars and 0-d inputs are accepted as well as arrays; the inputs are
    never modified.

    :param precision: precision value(s).
    :param recall: recall value(s), broadcastable against ``precision``.
    :return: the F1 score(s), with the broadcast shape of the inputs.
    """
    precision = numpy.asarray(precision)
    recall = numpy.asarray(recall)
    denom = precision + recall
    # numpy.where instead of in-place mask assignment: the latter fails on
    # scalar results, which do not support item assignment.
    safe_denom = numpy.where(denom == 0.0, 1.0, denom)
    return 2 * (precision * recall) / safe_denom
def recall(
    tps: numpy.ndarray,
    tps_fns: numpy.ndarray,
) -> numpy.ndarray:
    """Compute recall ``TP / (TP + FN)`` element-wise.

    Positions where the denominator is zero (no actual members of the class)
    yield ``0.0`` instead of ``NaN``, and no division warning is emitted.

    :param tps: true-positive count(s).
    :param tps_fns: true positives plus false negatives; may be a scalar
        that is broadcast against ``tps``.
    :return: float64 array of recall values with the broadcast shape of the
        inputs.
    """
    tps = numpy.asarray(tps, dtype=float)
    tps_fns = numpy.asarray(tps_fns, dtype=float)
    # Pre-fill with zeros; the masked divide only writes where it is defined.
    result = numpy.zeros(numpy.broadcast(tps, tps_fns).shape, dtype=float)
    numpy.divide(tps, tps_fns, out=result, where=tps_fns != 0)
    return result
def precision(
    tps: numpy.ndarray,
    tps_fps: numpy.ndarray,
) -> numpy.ndarray:
    """Compute precision ``TP / (TP + FP)`` element-wise.

    Positions where the denominator is zero (nothing was predicted as the
    class) yield ``0.0`` instead of ``NaN``, and no division warning is
    emitted.

    :param tps: true-positive count(s).
    :param tps_fps: true positives plus false positives, broadcastable
        against ``tps``.
    :return: float64 array of precision values with the broadcast shape of
        the inputs.
    """
    tps = numpy.asarray(tps, dtype=float)
    tps_fps = numpy.asarray(tps_fps, dtype=float)
    # Pre-fill with zeros; the masked divide only writes where it is defined.
    result = numpy.zeros(numpy.broadcast(tps, tps_fps).shape, dtype=float)
    numpy.divide(tps, tps_fps, out=result, where=tps_fps != 0)
    return result
def all_f1_scores(
    y_true: numpy.ndarray,
    y_score: numpy.ndarray,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
    """Compute the macro-averaged F1 score for every candidate threshold.

    The per-threshold counts come from ``_binary_clf_curve``. F1 is computed
    once treating the positive class as the target and once treating the
    negative class as the target, and the two are averaged with equal weight.

    References:
        https://stats.stackexchange.com/questions/518616/how-to-find-the-optimal-threshold-for-the-weighted-f1-score-in-a-binary-classifi
        https://arxiv.org/abs/1911.03347

    :param y_true: ground-truth binary labels.
    :param y_score: predicted scores for the positive class.
    :return: pair ``(thresholds, f1)`` with the macro F1 score per threshold.
    """
    false_pos, true_pos, thresholds = _binary_clf_curve(y_true, y_score)

    # The last entry of each count array is used as the class total, so the
    # complementary counts follow by subtraction.
    total_pos = true_pos[-1]
    total_neg = false_pos[-1]
    true_neg = total_neg - false_pos
    false_neg = total_pos - true_pos

    # Positive class as target; TP + FN collapses to the positive total.
    pos_precision = precision(tps=true_pos, tps_fps=true_pos + false_pos)
    pos_recall = recall(tps=true_pos, tps_fns=total_pos)
    f1_pos = f1_scores(precision=pos_precision, recall=pos_recall)

    # Negative class as target: true negatives play the role of "hits".
    neg_precision = precision(tps=true_neg, tps_fps=true_neg + false_neg)
    neg_recall = recall(tps=true_neg, tps_fns=true_neg + false_pos)
    f1_neg = f1_scores(precision=neg_precision, recall=neg_recall)

    # Unweighted mean over the two classes.
    macro_f1 = 0.5 * (f1_pos + f1_neg)
    return thresholds, macro_f1
def optimal_f1_score(
    y_true: numpy.ndarray,
    y_score: numpy.ndarray,
) -> Tuple[float, float]:
    """Find the threshold that maximizes the macro F1 score.

    :param y_true: ground-truth binary labels.
    :param y_score: predicted scores for the positive class.
    :return: pair ``(threshold, f1)`` for the best-scoring threshold; NaN
        scores are skipped when searching for the maximum.
    """
    candidates, scores = all_f1_scores(y_true=y_true, y_score=y_score)
    best = numpy.nanargmax(scores)
    best_threshold = candidates[best]
    best_score = scores[best]
    return best_threshold, best_score
if __name__ == "__main__":
    # Demo: fit a small classifier on the two-moons toy data set, then plot
    # macro F1 against the decision threshold and mark the optimum.
    from matplotlib import pyplot as plt
    from sklearn.datasets import make_moons
    from sklearn.neural_network import MLPClassifier

    model = MLPClassifier()
    features, labels = make_moons()
    model.fit(features, labels)
    # Column 1 of predict_proba is used as the positive-class score.
    positive_scores = model.predict_proba(features)[:, 1]

    candidate_thresholds, macro_f1 = all_f1_scores(
        y_true=labels,
        y_score=positive_scores,
    )
    best_threshold, best_f1 = optimal_f1_score(
        y_true=labels,
        y_score=positive_scores,
    )

    figure, axis = plt.subplots()
    axis.plot(candidate_thresholds, macro_f1)
    # Dashed cross-hairs at the optimal threshold and its score.
    axis.axvline(best_threshold, ls="dashed", color="black")
    axis.axhline(best_f1, ls="dashed", color="black")
    axis.set_xlabel("score")
    axis.set_ylabel("$F_1$")
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment