Last active
May 26, 2024 20:35
-
-
Save rohithteja/885a0d231016b24bc0c3c248e53d1692 to your computer and use it in GitHub Desktop.
Optuna Hyperparameter Tuning with XGBoost
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import optuna | |
| from xgboost import XGBClassifier | |
| from optuna.trial import TrialState | |
| from sklearn.metrics import accuracy_score | |
| # optuna's objective function | |
def objective(trial):
    """Train an XGBoost classifier for one Optuna trial and score it.

    Samples ``learning_rate``, ``max_depth`` and ``n_estimators`` from the
    trial, fits an ``XGBClassifier`` on the module-level ``x_train`` /
    ``y_train`` and evaluates it on ``x_val`` / ``y_val``.

    Args:
        trial: The Optuna trial object that supplies the hyperparameters.

    Returns:
        The validation accuracy computed by ``accuracy_score``.

    Raises:
        optuna.exceptions.TrialPruned: If the study's pruner rejects the
            reported accuracy.
    """
    # Search space: log-uniform learning rate, stepped integer grids for
    # tree depth (2, 4, ..., 10) and number of trees (100, 200, 300).
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    max_depth = trial.suggest_int("max_depth", 2, 10, step=2, log=False)
    n_estimators = trial.suggest_int(
        "n_estimators", 100, 300, step=100, log=False
    )

    model = XGBClassifier(
        objective="multi:softprob",
        learning_rate=learning_rate,
        n_estimators=n_estimators,
        max_depth=max_depth,
        seed=42,
    )
    model.fit(x_train, y_train)

    y_pred = model.predict(x_val)
    accuracy = accuracy_score(y_val, y_pred)

    # The pruner only looks at values passed to trial.report(); without this
    # call should_prune() has nothing to evaluate and the check below could
    # never trigger.
    trial.report(accuracy, step=0)
    if trial.should_prune():
        raise optuna.exceptions.TrialPruned()

    # Keep the fitted model on the trial so the study callback can promote
    # the best one.
    # NOTE(review): this stores a live model object as a user attribute;
    # confirm the configured storage accepts non-serializable values.
    trial.set_user_attr(key="best_model", value=model)
    return accuracy
| # callback function to save the best model as user attribute | |
def callback(study, trial):
    """Copy the just-finished trial's model onto the study if it is the best.

    Args:
        study: The study being optimized; its ``best_trial`` is compared
            against ``trial`` and its user attributes are updated.
        trial: The trial that has just finished.
    """
    # Nothing to do unless the trial that just finished is the current best.
    if study.best_trial.number != trial.number:
        return

    winning_model = trial.user_attrs["best_model"]
    study.set_user_attr(key="best_model", value=winning_model)
| # study to maximize the accuracy metric | |
# Higher accuracy is better, so the study searches for the maximum.
study = optuna.create_study(direction="maximize")

# Run 20 trials with no wall-clock limit; after each trial the callback
# copies the best trial's model onto the study.
study.optimize(
    objective,
    n_trials=20,
    timeout=None,
    callbacks=[callback],
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Great example of hyperparameter optimization with optuna!
You can also retrieve the best parameters in Optuna via
`study.best_params` and fit a new final model directly, rather than using a callback.