@aflansburg
Created August 5, 2021 14:46
Check GridSearchCV fit Runtime
# import time - not the abstract construct of 'time'
# but rather a library built into Python for
# dealing with time
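from time import time
# ML stuff
import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.ensemble import AdaBoostClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# X_train and y_train are assumed to exist already (training features and binary labels)
ada_tuned_clf = AdaBoostClassifier(random_state=1)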
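# some canned params for hyperparameter tuning
# (scikit-learn 1.2 renamed AdaBoostClassifier's "base_estimator" to "estimator";
# use the key that matches your installed version)
parameters = {
    "base_estimator": [
        DecisionTreeClassifier(max_depth=1),
        DecisionTreeClassifier(max_depth=2),
        DecisionTreeClassifier(max_depth=3),
    ],
    "n_estimators": np.arange(10, 110, 10),
    "learning_rate": np.arange(0.1, 2, 0.1),
}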
# scorer - recall (sensitivity) -> TP/(TP+FN)
recall_scorer = metrics.make_scorer(metrics.recall_score)
# Run the grid search
grid_obj = GridSearchCV(ada_tuned_clf, parameters, scoring=recall_scorer, cv=5)
## 1st Method Using time ##
# capture the current time in the variable 'start'
start = time()
# fit the grid object to the data
grid_obj = grid_obj.fit(X_train, y_train)
# capture the current time in the variable 'end'
end = time()
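# Set the clf to the best combination of parameters
ada_tuned = grid_obj.best_estimator_
# fit the best estimator to the data
# (redundant with refit=True, the GridSearchCV default: best_estimator_ was already
# refit on the full training set inside grid_obj.fit, so that refit is part of end - start)
ada_tuned.fit(X_train, y_train)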
## 2nd Method Using attributes of cv_results_ ##
# runtime estimate based on the answer below, using the fit and score times
# stored in the cv_results_ attribute of the fitted GridSearchCV object
# https://datascience.stackexchange.com/a/93524/41883
mean_fit_time = grid_obj.cv_results_['mean_fit_time']
mean_score_time = grid_obj.cv_results_['mean_score_time']
n_splits = grid_obj.n_splits_  # number of cross-validation splits (folds)
n_iter = pd.DataFrame(grid_obj.cv_results_).shape[0]  # number of parameter combinations tried
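# print both for comparison
# the cv_results_ estimate leaves out the final refit and the search overhead, so with
# sequential fits it should come out slightly lower than the time() measurement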
print('calculated runtime from GridSearchCV cv_results_ attributes')
print(np.mean(mean_fit_time + mean_score_time) * n_splits * n_iter)
print('calculated runtime from using time()')
print(end - start)
calculated runtime from GridSearchCV cv_results_ attributes
293.37535548210144
calculated runtime from using time()
295.12124705314636
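
The two measurements above can be wrapped in small reusable helpers. The following is a minimal sketch using only the standard library, not part of the original gist: `Timer` and `estimate_search_runtime` are hypothetical names, and the numbers in the usage lines are made up for illustration. `Timer` uses `time.perf_counter`, a monotonic clock intended for measuring elapsed intervals. The estimate sums fit and score time per parameter combination and multiplies by the number of splits, which is the same arithmetic as `np.mean(...) * n_splits * n_iter`. The optional `refit_time` argument assumes you pass in something like `grid_obj.refit_time_`, if your scikit-learn version provides it.

```python
from time import perf_counter


class Timer:
    """Context manager that stores elapsed wall-clock seconds in .elapsed."""

    def __enter__(self):
        self.elapsed = 0.0
        self._start = perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.elapsed = perf_counter() - self._start
        return False  # let any exception propagate


def estimate_search_runtime(mean_fit_time, mean_score_time, n_splits, refit_time=0.0):
    """Estimate the total sequential runtime of a cross-validated search.

    mean_fit_time / mean_score_time: one mean value (seconds) per parameter
    combination, e.g. the arrays found in cv_results_.
    n_splits: number of cross-validation folds.
    refit_time: optional time of the final refit (hypothetical extra term).
    """
    per_candidate = [f + s for f, s in zip(mean_fit_time, mean_score_time)]
    return sum(per_candidate) * n_splits + refit_time


# Usage with made-up numbers: 2 parameter combinations, 5 folds
with Timer() as t:
    sum(range(1000))  # stand-in for grid_obj.fit(X_train, y_train)
print(t.elapsed)

estimate = estimate_search_runtime([1.0, 2.0], [0.5, 0.5], n_splits=5)
print(estimate)
```

With the real objects, the call would look like `estimate_search_runtime(grid_obj.cv_results_['mean_fit_time'], grid_obj.cv_results_['mean_score_time'], grid_obj.n_splits_)`. The estimate only lines up with the wall clock when the fits run one after another.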