Last active
August 1, 2021 18:52
-
-
Save aflansburg/12e64fc9bfb7cbf744b1595c37703e28 to your computer and use it in GitHub Desktop.
Calculate GridSearchCV runtime
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# runtime info based on solution below and fit_time results of the gridsearchcv return object | |
# based on a response on StackExchange Data Science - Naveen Vuppula | |
# https://datascience.stackexchange.com/a/93524/41883 | |
# from time import time | |
def gridsearch_runtime(grid_obj, X_train, y_train):
    """Fit a grid search and report how long it took, measured two ways.

    Parameters:
        grid_obj: GridSearchCV-style object that has not yet been fit to the
            training data. After ``fit`` it must expose ``cv_results_`` (with
            ``'mean_fit_time'`` and ``'mean_score_time'`` entries) and
            ``n_splits_``.
        X_train: split training data, independent variables.
        y_train: split training data, dependent variable.

    Returns:
        A ``(time_from_cv_result, time_from_sys_time)`` tuple, both in seconds.
        The first is reconstructed from the per-candidate timings stored in
        ``cv_results_``; the second is the elapsed time measured around the
        ``fit`` call itself.

    Note:
        The reconstructed figure only covers the fit/score timings recorded in
        ``cv_results_``. Any other work done inside ``fit`` (such as a final
        refit, if the search performs one) is visible only in the measured
        figure, so the two numbers are expected to differ slightly.
    """
    # Local imports keep the helper self-contained: the snippet's own
    # `from time import time` line is commented out.
    from time import perf_counter

    import numpy as np

    # perf_counter is monotonic, so the interval cannot be distorted by
    # system clock adjustments during a long search.
    start = perf_counter()
    grid_obj.fit(X_train, y_train)
    time_from_sys_time = perf_counter() - start

    results = grid_obj.cv_results_
    # One entry per parameter candidate: mean seconds per split spent
    # fitting plus scoring.
    per_candidate = (
        np.asarray(results['mean_fit_time'], dtype=float)
        + np.asarray(results['mean_score_time'], dtype=float)
    )
    # Each candidate is evaluated once per split, so the total is the sum of
    # the per-split means multiplied by the number of splits.
    time_from_cv_result = per_candidate.sum() * grid_obj.n_splits_

    return time_from_cv_result, time_from_sys_time
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# obviously you would use this (and write the method above) in a different manner | |
# but this is more for use as a very basic compute performance testing tool | |
# assuming some estimator (tuned_estimator), parameters, and scorer | |
# Build the unfitted search; gridsearch_runtime() performs the fit and
# returns both timing figures (in seconds).
grid_obj = GridSearchCV(
    tuned_estimator,
    parameters,
    scoring=acc_scorer,
    cv=5,
)
res_time, res_sys_time = gridsearch_runtime(grid_obj, X_train, y_train)

# Print each figure under its label, one value per line.
for label, seconds in (
    ('calculated runtime from GridSearchCV cv_results_ attributes', res_time),
    ('calculated runtime from using time()', res_sys_time),
):
    print(label)
    print(seconds)
# Output (time units are in seconds): | |
# calculated runtime from GridSearchCV cv_results_ attributes | |
# 293.37535548210144 | |
# calculated runtime from using time() | |
# 295.12124705314636 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment