-
-
Save daveebbelaar/a31e9ebadaf02472bc744c5b91d02321 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import optuna | |
import xgboost as xgb | |
from sklearn.metrics import mean_squared_error # or any other metric | |
from sklearn.model_selection import train_test_split | |
# Load the dataset | |
X, y = ... # load your own | |
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) | |
# Define the objective function for Optuna | |
def objective(trial): | |
# Define the search space for hyperparameters | |
param = { | |
'objective': 'reg:squarederror', | |
'eval_metric': 'rmse', | |
'eta': trial.suggest_float('eta', 0.01, 0.3), | |
'num_boost_round': 100000, # Fix the boosting round and use early stopping | |
'max_depth': trial.suggest_int('max_depth', 3, 10), | |
'subsample': trial.suggest_float('subsample', 0.5, 1.0), | |
'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0), | |
'gamma': trial.suggest_float('gamma', 0.0, 10.0), | |
'min_child_weight': trial.suggest_float('min_child_weight', 0.1, 10.0), | |
'lambda': trial.suggest_float('lambda', 0.1, 10.0), | |
'alpha': trial.suggest_float('alpha', 0.0, 10.0), | |
} | |
# Split the data into further training and validation sets (three sets are preferable) | |
train_data, valid_data, train_target, valid_target = train_test_split(X_train, y_train, test_size=0.2, random_state=42) | |
# Convert the data into DMatrix format | |
dtrain = xgb.DMatrix(train_data, label=train_target) | |
dvalid = xgb.DMatrix(valid_data, label=valid_target) | |
# Define the pruning callback for early stopping | |
pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'validation-rmse') | |
# Train the model with early stopping | |
model = xgb.train(param, dtrain, evals=[(dvalid, 'validation')], early_stopping_rounds=100, callbacks=[pruning_callback]) | |
# Make predictions on the test set | |
dtest = xgb.DMatrix(valid_data) | |
y_pred = model.predict(dtest) | |
# Calculate the root mean squared error | |
rmse = mean_squared_error(valid_test, y_pred, squared=False) | |
return rmse | |
# Create an Optuna study and optimize the objective function | |
study = optuna.create_study(direction='minimize') | |
study.optimize(objective, n_trials=100) # Control the number of trials | |
# Print the best hyperparameters and the best RMSE | |
best_params = study.best_params | |
best_rmse = study.best_value | |
print("Best Hyperparameters: ", best_params) | |
print("Best RMSE: ", best_rmse) | |
#---------------------------------------------------------------------# | |
# You can also tune for multiple metrics. See here: https://stackoverflow.com/questions/69071684/how-to-optimize-for-multiple-metrics-in-optuna |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment