Skip to content

Instantly share code, notes, and snippets.

@smly
Last active July 9, 2023 08:20
Show Gist options
Save smly/5a5ddf968d59492b79e4cbf90b2d3430 to your computer and use it in GitHub Desktop.
Usage of a custom evaluation metric function with Optuna's LightGBM pruning callback
diff --git a/examples/pruning/lightgbm_integration.py b/examples/pruning/lightgbm_integration.py
index 8e623772..4c0c315c 100644
--- a/examples/pruning/lightgbm_integration.py
+++ b/examples/pruning/lightgbm_integration.py
@@ -21,6 +21,19 @@ import optuna
 # FYI: Objective functions can take additional arguments
 # (https://optuna.readthedocs.io/en/stable/faq.html#objective-func-additional-args).
 def objective(trial):
+    def custom_accuracy_pct(preds, data):
+        """LightGBM feval returning ('custom_accuracy', accuracy in %, True)."""
+        y_true = data.get_label()
+        # Threshold at 0.5; assumes preds are positive-class probabilities
+        # (binary objective) -- confirm against ``param``.
+        acc = custom_accuracy_numpy(preds > 0.5, y_true)
+        return 'custom_accuracy', acc, True
+
+    def custom_accuracy_numpy(y_pred, y_true):
+        """Return the percentage (0-100) of positions where y_pred equals y_true."""
+        acc = np.mean(y_true == y_pred) * 100
+        return acc
+
     data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
     train_x, test_x, train_y, test_y = train_test_split(data, target, test_size=0.25)
     dtrain = lgb.Dataset(train_x, label=train_y)
@@ -41,14 +54,17 @@ def objective(trial):
     }
 
     # Add a callback for pruning.
-    pruning_callback = optuna.integration.LightGBMPruningCallback(trial, 'auc')
+    # The metric name must match the name returned by custom_accuracy_pct.
+    pruning_callback = optuna.integration.LightGBMPruningCallback(trial, 'custom_accuracy')
     gbm = lgb.train(
-        param, dtrain, valid_sets=[dtest], verbose_eval=False, callbacks=[pruning_callback])
+        param, dtrain, valid_sets=[dtest], verbose_eval=False, callbacks=[pruning_callback],
+        feval=custom_accuracy_pct)
 
     preds = gbm.predict(test_x)
     pred_labels = np.rint(preds)
-    accuracy = sklearn.metrics.accuracy_score(test_y, pred_labels)
-    return accuracy
+
+    accuracy_pct = custom_accuracy_numpy(pred_labels, test_y)
+    return accuracy_pct
 
 
 if __name__ == '__main__':
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment