Custom `cross_val` function for Python's scikit-learn library. As opposed to scikit-learn's implementations, it returns predictions, split information, and estimators.
import numpy as np

from sklearn.base import clone
from sklearn.model_selection import KFold


def cross_val(
    estimator,
    X, y, groups=None,
    cv=5,
    # basic
    predict="regression",
    squeeze_pred=True,
    squeeze_pred_proba=True,
    return_y_pred=True,
    # data
    return_X=False,
    return_y=False,
    return_groups=False,
    # split infos
    return_estimator=False,
    return_y_pred_test=False,
    return_y_test=False,
    # return_y_train=True,
    random_state=None
):
"""
Custom `cross_val` function for Pythons scikit-learn library.
I don't like scikit-learn's
`cross_val` because it does not return predictions and
`cross_val_predict` does not give me split information or estimators.
TODO: enhance
* parallelization via joblib
* repeated cross validation
* maybe more convenience methods (return training predictions, runtimes, etc.)
Parameters
----------
estimator : object
A object with a `fit` method that is used to fit the data.
X : array-like of shape (n_samples, n_features)
The data to fit.
y : array-like of shape (n_samples,)
The target variable to try to predict.
groups : array-like of shape (n_samples,), default=None
Group labels for the samples used while splitting the dataset into train/test set.
This is passed to the fit method of the estimator as well.
cv : int, cross-validation generator or an iterable, default=5
Determines the cross-validation splitting strategy.
Possible inputs for cv are:
- None, to use the default 5-fold cross validation,
- int, to specify the number of folds in a KFold (not grouped or stratified!),
- An object to be used as a cross-validation generator.
- An iterable yielding train, test splits.
predict : str or list of str, default="regression"
The prediction method to use. Possible values are:
- "regression": Use the `predict` method.
- "classification": Use the `predict_proba` method.
- list of str: Use the methods with the names in the list.
Non existing methods are ignored.
squeeze_pred : bool, default=True
If True and only one predict function is used, the result is squeezed,
i.e., the result is the prediction itself, not a dictionary.
squeeze_pred_proba : bool, default=True
If True and predict_proba is used,
only predictions for the positive class are returned, i.e., `y_pred[:,1]`.
return_y_pred : bool, default=True
If True, the predictions across all and sorted like `y`
are returned as `y_pred`.
return_X : bool, default=False
If True, the input data `X` is returned.
return_y : bool, default=False
If True, the target variable `y` is returned.
return_groups : bool, default=False
If True, the groups are returned.
return_estimator : bool, default=False
If True, the estimator is returned.
return_y_pred_test : bool, default=False
If True, the predictions for each test set are returned.
return_y_test : bool, default=False
If True, the target variable for each test set is returned.
random_state : int, RandomState instance or None, default=None
Determines random number generation for dataset creation.
Pass an int for reproducible output
across multiple function calls.
Currently only used for the default cv (`None`, or `int`).
Returns
-------
results : dict
"""

    # resolve the cross-validation strategy; splitter objects and plain
    # iterables of (train, test) splits are handled when iterating below
    if cv is None:
        cv = KFold(n_splits=5, shuffle=True, random_state=random_state)
    elif isinstance(cv, int):
        cv = KFold(n_splits=cv, shuffle=True, random_state=random_state)
if predict == "regression":
predict_functions = ["predict"]
elif predict == "classification":
predict_functions = ["predict_proba"]
else:
predict_functions = predict

    results = {
        "test_idx": [],
    }

    # data
    if return_X:
        results["X"] = X
    if return_y:
        results["y"] = y
    if return_groups:
        results["groups"] = groups

    # split infos
    if return_estimator:
        results["estimator"] = []
    if return_y_test:
        results["y_test"] = []

    # iterate over splits; `cv` may be a splitter object or an iterable of splits
    splits = cv.split(X, y, groups=groups) if hasattr(cv, "split") else cv

    y_pred_test = []
    for train_idx, test_idx in splits:

        X_train = X[train_idx]
        y_train = y[train_idx]
        X_test = X[test_idx]
        y_test = y[test_idx]

        if groups is not None:
            groups_train = groups[train_idx]
        else:
            groups_train = None

        estimator_split = clone(estimator)

        # call fit (with groups if a `groups` argument is available in the fit method)
        # TODO: check if using `__code__` is efficient
        if "groups" in estimator_split.fit.__code__.co_varnames:
            estimator_split.fit(X_train, y_train, groups=groups_train)
        else:
            estimator_split.fit(X_train, y_train)

        y_pred_test_split = {}
        for predict_function in predict_functions:
            # call the method named `predict_function` on the estimator and store
            # the result; methods the estimator does not implement are skipped
            if hasattr(estimator_split, predict_function):
                y_pred_test_split[predict_function] = \
                    getattr(estimator_split, predict_function)(X_test)
        y_pred_test.append(y_pred_test_split)
results["test_idx"].append(test_idx)
if return_estimator:
results["estimator"].append(estimator)
if return_y_test:
results["y_test"].append(y_test)

    # actually available predict functions
    predict_functions_actual = list(y_pred_test[0].keys())

    def process_results(result, concatenate_splits):

        # squeeze predict_proba to the positive class if requested
        # (keep the per-split dictionaries intact, only replace the entry)
        if squeeze_pred_proba and "predict_proba" in predict_functions_actual:
            result = [
                dict(x, predict_proba=x["predict_proba"][:, 1])
                for x in result
            ]

        # merge splits into one array per predict function
        if concatenate_splits:
            merged = {}
            for predict_function in predict_functions_actual:
                merged[predict_function] = \
                    np.concatenate([x[predict_function] for x in result])
            result = merged

        # squeeze predictions if only one predict function is used
        if squeeze_pred and len(predict_functions_actual) == 1:
            predict_function = predict_functions_actual[0]
            if concatenate_splits:
                result = result[predict_function]
            else:
                result = [x[predict_function] for x in result]

        return result

    # return `y_pred` sorted like the original y
    if return_y_pred:

        y_pred = process_results(y_pred_test, concatenate_splits=True)

        # derive original y order
        test_idx = np.concatenate(results["test_idx"])
        test_idx_order = np.argsort(test_idx)

        # sort
        if squeeze_pred and len(predict_functions_actual) == 1:
            y_pred = y_pred[test_idx_order]
        else:
            y_pred = {k: v[test_idx_order] for k, v in y_pred.items()}

        # return y_pred
        results["y_pred"] = y_pred

    # return y_pred_test
    if return_y_pred_test:
        y_pred_test_processed = process_results(y_pred_test, concatenate_splits=False)
        results["y_pred_test"] = y_pred_test_processed

    return results
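

# Example usage (a minimal sketch): classification with grouped splits.
# `make_classification`, `LogisticRegression`, and `GroupKFold` are standard
# scikit-learn imports; the data and group labels below are synthetic and
# only illustrate the expected shapes, not a recommended setup.
if __name__ == "__main__":

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import GroupKFold

    X, y = make_classification(n_samples=100, n_features=5, random_state=0)
    groups = np.repeat(np.arange(10), 10)  # 10 groups of 10 samples each

    results = cross_val(
        LogisticRegression(max_iter=1000),
        X, y, groups=groups,
        cv=GroupKFold(n_splits=5),
        predict="classification",      # uses `predict_proba`
        return_estimator=True,
        return_y_pred_test=True,
    )

    # `y_pred` holds positive-class probabilities aligned with `y`
    # (squeezed, because only one predict function was used)
    print(results["y_pred"].shape)     # (100,)
    print(len(results["test_idx"]))    # 5 splits
    print(len(results["estimator"]))   # one fitted estimator per split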