Skip to content

Instantly share code, notes, and snippets.

@audhiaprilliant
Created December 24, 2020 03:05
Show Gist options
  • Select an option

  • Save audhiaprilliant/56ae7d00c7bf59f92e84563392e1f300 to your computer and use it in GitHub Desktop.

Select an option

Save audhiaprilliant/56ae7d00c7bf59f92e84563392e1f300 to your computer and use it in GitHub Desktop.
How to choose the optimal threshold for imbalanced classification
# Candidate decision thresholds: 10,000 evenly spaced values in [0.0, 1.0).
# np.linspace states the number of points explicitly, whereas np.arange with
# a fractional step derives its length from a floating-point division.
thresholds = np.linspace(0.0, 1.0, num=10000, endpoint=False)
fscore = np.zeros(shape=(len(thresholds)))
print(f'Length of sequence: {len(thresholds)}')

# Score every candidate threshold
for index, elem in enumerate(thresholds):
    # Binarise the scores: 1 where y_pred is strictly above the threshold.
    # NOTE(review): despite its name, y_pred_prob holds 0/1 labels here;
    # y_pred is presumably the positive-class probability -- confirm upstream.
    y_pred_prob = (y_pred > elem).astype('int')
    # F-score of the thresholded labels against the ground truth
    fscore[index] = f1_score(y_test, y_pred_prob)

# Find the optimal threshold. np.argmax returns the first maximum, so ties
# resolve to the lowest threshold.
index = np.argmax(fscore)
thresholdOpt = round(thresholds[index], ndigits=4)
fscoreOpt = round(fscore[index], ndigits=4)
print(f'Best Threshold: {thresholdOpt} with F-Score: {fscoreOpt}')
# Plot the threshold tuning
df_threshold_tuning = pd.DataFrame({'F-score': fscore,
                                    'Threshold': thresholds})
df_threshold_tuning.head()
# One-row frame holding the optimum. Layers that put scalar constants in
# aes() inherit the full tuning frame and get the constant broadcast to every
# row, so the marker and the label would be drawn once per threshold on top
# of each other. Giving those layers their own single-row data draws each
# exactly once.
df_threshold_opt = pd.DataFrame({'F-score': [fscoreOpt],
                                 'Threshold': [thresholdOpt]})
plotnine.options.figure_size = (8, 4.8)
(
    ggplot(data=df_threshold_tuning)
    # F-score at every candidate threshold
    + geom_point(aes(x='Threshold', y='F-score'),
                 size=0.4)
    # Best threshold
    + geom_point(aes(x='Threshold', y='F-score'),
                 data=df_threshold_opt,
                 color='#981220',
                 size=4)
    + geom_line(aes(x='Threshold', y='F-score'))
    # Annotate the text (nudged below the marker so it does not cover it)
    + geom_text(aes(x='Threshold', y='F-score'),
                data=df_threshold_opt,
                label='Optimal threshold \n for class: {}'.format(thresholdOpt),
                nudge_x=0,
                nudge_y=-0.10,
                size=10,
                fontstyle='italic')
    + labs(title='Threshold Tuning Curve')
    + xlab('Threshold')
    + ylab('F-score')
    + theme_minimal()
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment