Created
April 6, 2021 00:23
-
-
Save PatrickRudgeri/a45a1b38c3c998a03320bfdfe14463b3 to your computer and use it in GitHub Desktop.
Modified Grid Search CV
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.linear_model import LogisticRegression | |
from sklearn.neighbors import KNeighborsClassifier | |
from sklearn.model_selection import StratifiedKFold, cross_validate | |
from sklearn.ensemble import VotingClassifier | |
import pandas as pd | |
from IPython.display import display | |
def cross_validation(estimator, model_name, X, y, n_splits=5, random_state=None):
    """Score ``estimator`` with shuffled, stratified K-fold cross-validation.

    Args:
        estimator: Estimator forwarded unchanged to ``cross_validate``.
        model_name: Label used as the row index of the summary DataFrame.
        X: Feature data forwarded to ``cross_validate``.
        y: Target data forwarded to ``cross_validate`` (also drives the
            stratification of the folds).
        n_splits: Number of folds.
        random_state: Seed for the fold shuffling (``None`` = not fixed).

    Returns:
        A tuple ``(summary, scores)`` where ``summary`` is a one-row
        DataFrame holding a ``"mean (std)"`` string per metric and
        ``scores`` maps each metric label to its per-fold values.
    """
    # Define the metrics here. A custom metric has to be wrapped with
    # `make_scorer` (see the scikit-learn docs).
    # NOTE(review): "f1", "recall" and "precision" look like the binary
    # scorers; confirm the target is binary before reusing this elsewhere.
    metricas = {
        "F1": "f1",
        "Acc": "accuracy",
        "Recall": "recall",
        "Precision": "precision",
        "AUC": "roc_auc",
    }

    # K-fold CV
    folds = StratifiedKFold(n_splits=n_splits, random_state=random_state, shuffle=True)
    cv_result = cross_validate(
        estimator, X, y, cv=folds, n_jobs=-1, scoring=metricas, verbose=5
    )

    # To print the training time as well:
    # print(f"{cv_result['fit_time'].mean():.05f} ({cv_result['fit_time'].std():.05f})")

    # Keep only the "test_<label>" entries and drop that exact prefix.
    # str.lstrip("test_") must not be used here: it strips any leading run of
    # the characters t/e/s/_ and would mangle labels such as "sensitivity".
    prefix = "test_"
    scores = {
        key[len(prefix):]: values
        for key, values in cv_result.items()
        if key.startswith(prefix)
    }

    # Returns a tuple with a DataFrame of the evaluations and the raw scores.
    summary = {
        label: f"{values.mean():.04f} ({values.std():.04f})"
        for label, values in scores.items()
    }
    return pd.DataFrame(summary, index=[model_name]), scores
# Base learners combined by the soft-voting ensemble.
estimators = (
    ("logReg", LogisticRegression()),
    ("knC", KNeighborsClassifier()),
)
param_weights = [(0.1, 0.9), (0.2, 0.8), (0.5, 0.5)]

# Evaluate one soft-voting ensemble per weight tuple.
for weights in param_weights:
    K = 10  # 10-fold CV
    vot_clf = VotingClassifier(estimators, voting="soft", weights=weights, n_jobs=-1)
    # NOTE(review): X and y are not defined in this snippet; they must already
    # exist in the running session.
    df_results, scores = cross_validation(
        vot_clf,
        f"Soft_voting-{weights}",
        X,
        y,
        n_splits=K,
        random_state=42,
    )
    display(df_results)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment