Created
August 5, 2021 14:46
-
-
Save aflansburg/5a252691dcae0c69061b186ddae40fa7 to your computer and use it in GitHub Desktop.
Check GridSearchCV fit Runtime
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Time a GridSearchCV hyperparameter search in two ways and compare them:
#   1. wall-clock time measured around the .fit() call
#   2. an estimate rebuilt from the search's own per-candidate timings
#
# Assumes the surrounding notebook has already defined/imported:
#   np, AdaBoostClassifier, DecisionTreeClassifier, GridSearchCV,
#   metrics (sklearn.metrics), X_train, y_train.

# `time` here is the stdlib function returning seconds since the epoch,
# used below to bracket the fit call.
from time import time

# ML stuff
ada_tuned_clf = AdaBoostClassifier(random_state=1)

# Hyperparameter grid: 3 tree depths x 10 estimator counts x 19 learning rates
parameters = {
    # NOTE(review): scikit-learn 1.2 renamed this parameter to "estimator"
    # and 1.4 removed "base_estimator"; update the key on newer versions.
    "base_estimator": [
        DecisionTreeClassifier(max_depth=1),
        DecisionTreeClassifier(max_depth=2),
        DecisionTreeClassifier(max_depth=3),
    ],
    "n_estimators": np.arange(10, 110, 10),
    # linspace fixes the count (19) and the endpoint (1.9) exactly, which a
    # float-step arange does not guarantee.
    "learning_rate": np.linspace(0.1, 1.9, 19),
}

# Scorer: recall (sensitivity) = TP / (TP + FN).
# The variable name is kept for compatibility; it scores recall, not accuracy.
acc_scorer = metrics.make_scorer(metrics.recall_score)

# Run the grid search over the classifier defined above
# (previously referenced the undefined name `abc_tuned`).
grid_obj = GridSearchCV(ada_tuned_clf, parameters, scoring=acc_scorer, cv=5)

## 1st Method: wall-clock time around fit ##
start = time()
grid_obj = grid_obj.fit(X_train, y_train)
end = time()

# With the default refit=True, GridSearchCV has already refit the best
# parameter combination on the whole of X_train/y_train, so best_estimator_
# is ready to use without another fit.
ada_tuned = grid_obj.best_estimator_

## 2nd Method: rebuild the runtime from cv_results_ ##
# Based on https://datascience.stackexchange.com/a/93524/41883
mean_fit_time = grid_obj.cv_results_["mean_fit_time"]
mean_score_time = grid_obj.cv_results_["mean_score_time"]
n_splits = grid_obj.n_splits_  # number of cross-validation folds
n_iter = len(grid_obj.cv_results_["params"])  # number of parameter candidates

# Every candidate is fit and scored once per fold. The wall-clock window
# above also contains the final refit of the best candidate, so refit_time_
# is added to make the two numbers comparable. The sum only approximates
# wall-clock time when folds run sequentially (n_jobs is not set here).
estimated_runtime = (
    np.mean(mean_fit_time + mean_score_time) * n_splits * n_iter
    + grid_obj.refit_time_
)

# print both for comparison
print('calculated runtime from using time()')
print(end - start)
print('calculated runtime from GridSearchCV cv_results_ attributes')
print(estimated_runtime)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
calculated runtime from using time()
295.12124705314636
calculated runtime from GridSearchCV cv_results_ attributes
293.37535548210144
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment