Last active
October 3, 2024 14:16
-
-
Save mgbckr/13fe59473a9135dd11d03cc5650932e0 to your computer and use it in GitHub Desktop.
Custom `cross_val` function for Python's scikit-learn library. As opposed to scikit-learn's implementations, it returns predictions, split information, and the fitted estimators.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np

from sklearn.base import clone
from sklearn.model_selection import KFold
def cross_val(
        estimator,
        X, y, groups=None,
        cv=5,
        # basic
        predict="regression",
        squeeze_pred=True,
        squeeze_pred_proba=True,
        return_y_pred=True,
        # data
        return_X=False,
        return_y=False,
        return_groups=False,
        # split infos
        return_estimator=False,
        return_y_pred_test=False,
        return_y_test=False,
        # return_y_train=True,
        random_state=None
):
    """
    Custom `cross_val` function for Python's scikit-learn library.

    As opposed to scikit-learn's `cross_validate`, this returns predictions,
    and as opposed to `cross_val_predict`, it returns split information and
    the fitted estimators.

    NOTE: `X` and `y` (and `groups`) must support fancy indexing with index
    arrays (e.g. numpy arrays) — plain Python lists will not work.

    TODO: enhance
    * parallelization via joblib
    * repeated cross validation
    * maybe more convenience methods (return training predictions, runtimes, etc.)

    Parameters
    ----------
    estimator : object
        An object with a `fit` method that is used to fit the data.
        It is cloned (unfitted) for each split via `sklearn.base.clone`.
    X : array-like of shape (n_samples, n_features)
        The data to fit.
    y : array-like of shape (n_samples,)
        The target variable to try to predict.
    groups : array-like of shape (n_samples,), default=None
        Group labels for the samples used while splitting the dataset into
        train/test set. Also passed to the estimator's `fit` method if that
        method accepts a `groups` argument.
    cv : int, cross-validation generator or an iterable, default=5
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:
        - None, to use the default 5-fold cross validation,
        - int, to specify the number of folds in a KFold (not grouped or stratified!),
        - An object with a `split` method to be used as a cross-validation generator,
        - An iterable yielding (train_idx, test_idx) splits.
    predict : str or list of str, default="regression"
        The prediction method to use. Possible values are:
        - "regression": Use the `predict` method.
        - "classification": Use the `predict_proba` method.
        - list of str: Use the methods with the names in the list.
          Non-existing methods are ignored.
    squeeze_pred : bool, default=True
        If True and only one predict function is actually available, the
        result is squeezed, i.e., the result is the prediction itself,
        not a dictionary keyed by predict function name.
    squeeze_pred_proba : bool, default=True
        If True and predict_proba is used, only predictions for the
        positive class are returned, i.e., `y_pred[:, 1]`.
    return_y_pred : bool, default=True
        If True, the predictions across all splits, sorted like `y`,
        are returned as `y_pred`.
    return_X : bool, default=False
        If True, the input data `X` is returned.
    return_y : bool, default=False
        If True, the target variable `y` is returned.
    return_groups : bool, default=False
        If True, the groups are returned.
    return_estimator : bool, default=False
        If True, the fitted estimator of each split is returned.
    return_y_pred_test : bool, default=False
        If True, the predictions for each test set are returned.
    return_y_test : bool, default=False
        If True, the target variable for each test set is returned.
    random_state : int, RandomState instance or None, default=None
        Determines random number generation for the default KFold.
        Pass an int for reproducible output across multiple function calls.
        Currently only used for the default cv (`None`, or `int`).

    Returns
    -------
    results : dict
        Always contains `test_idx` (list of test-index arrays, one per
        split); further keys (`y_pred`, `y_pred_test`, `y_test`,
        `estimator`, `X`, `y`, `groups`) are added according to the
        corresponding `return_*` flags.
    """

    # resolve `cv` into an iterable of (train_idx, test_idx) pairs
    if cv is None:
        cv = KFold(n_splits=5, shuffle=True, random_state=random_state)
    elif isinstance(cv, int):
        cv = KFold(n_splits=cv, shuffle=True, random_state=random_state)
    if hasattr(cv, "split"):
        splits = cv.split(X, y, groups=groups)
    else:
        # BUGFIX: the original silently did nothing here and later crashed
        # on `cv.split(...)`; assume an iterable of (train, test) splits
        # as promised in the docstring.
        splits = cv

    if predict == "regression":
        predict_functions = ["predict"]
    elif predict == "classification":
        predict_functions = ["predict_proba"]
    else:
        predict_functions = predict

    results = {
        "test_idx": [],
    }

    # data
    if return_X:
        results["X"] = X
    if return_y:
        results["y"] = y
    if return_groups:
        results["groups"] = groups

    # split infos
    if return_estimator:
        results["estimator"] = []
    if return_y_test:
        results["y_test"] = []

    y_pred_test = []
    for train_idx, test_idx in splits:

        X_train = X[train_idx]
        y_train = y[train_idx]
        X_test = X[test_idx]
        y_test = y[test_idx]

        if groups is not None:
            groups_train = groups[train_idx]
        else:
            groups_train = None

        estimator_split = clone(estimator)

        # call fit (with groups if a `groups` argument is available in fit)
        # TODO: check if using `__code__` is efficient / robust for
        # C-implemented or **kwargs-based fit methods
        if "groups" in estimator_split.fit.__code__.co_varnames:
            estimator_split.fit(X_train, y_train, groups=groups_train)
        else:
            estimator_split.fit(X_train, y_train)

        # call each requested predict method that actually exists
        y_pred_test_split = {}
        for predict_function in predict_functions:
            if hasattr(estimator_split, predict_function):
                y_pred_test_split[predict_function] = \
                    getattr(estimator_split, predict_function)(X_test)
        y_pred_test.append(y_pred_test_split)

        results["test_idx"].append(test_idx)
        if return_estimator:
            # BUGFIX: the original appended the unfitted `estimator`
            # instead of the fitted per-split clone
            results["estimator"].append(estimator_split)
        if return_y_test:
            results["y_test"].append(y_test)

    # actually available predict functions (guard against zero splits)
    predict_functions_actual = \
        list(y_pred_test[0].keys()) if y_pred_test else []

    def process_results(result, concatenate_splits):
        # squeeze predict_proba to the positive-class column, keeping the
        # dict structure intact.
        # BUGFIX: the original replaced each per-split dict with a bare
        # array, which broke the merge step below (string-indexing an
        # ndarray raises).
        if squeeze_pred_proba and "predict_proba" in predict_functions_actual:
            result = [
                {**x, "predict_proba": x["predict_proba"][:, 1]}
                for x in result
            ]

        # merge splits into one dict of concatenated arrays.
        # BUGFIX: the original built `merged` but only used it in the
        # single-function squeeze branch, so the multi-function
        # `return_y_pred` path returned a list of dicts and crashed later.
        if concatenate_splits:
            result = {
                predict_function:
                    np.concatenate([x[predict_function] for x in result])
                for predict_function in predict_functions_actual
            }

        # squeeze predictions if only one predict function is used
        if squeeze_pred and len(predict_functions_actual) == 1:
            predict_function = predict_functions_actual[0]
            if concatenate_splits:
                result = result[predict_function]
            else:
                result = [x[predict_function] for x in result]

        return result

    # return `y_pred` sorted like original y
    if return_y_pred:
        y_pred = process_results(y_pred_test, concatenate_splits=True)

        # derive original y order from the concatenated test indices
        test_idx = np.concatenate(results["test_idx"])
        test_idx_order = np.argsort(test_idx)

        # sort
        if squeeze_pred and len(predict_functions_actual) == 1:
            y_pred = y_pred[test_idx_order]
        else:
            y_pred = {k: v[test_idx_order] for k, v in y_pred.items()}

        results["y_pred"] = y_pred

    # return y_pred_test
    if return_y_pred_test:
        y_pred_test_processed = \
            process_results(y_pred_test, concatenate_splits=False)
        results["y_pred_test"] = y_pred_test_processed

    return results
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment