Created
August 27, 2022 14:26
-
-
Save andrea-dagostino/76a0b34e440a98fa0c7b5203e9ea765e to your computer and use it in GitHub Desktop.
cross_val
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from sklearn.model_selection import KFold | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn import datasets | |
| from sklearn import metrics | |
| # creiamo un dataset per un task di classificazione | |
| X, y = datasets.make_classification(n_samples=2000, n_features=20, n_classes=2, random_state=42) | |
| # creiamo l'oggetto KFold applicando la regola di Sturges | |
| sturges = int(1 + np.log(len(X))) | |
| kf = KFold(n_splits=sturges, shuffle=True, random_state=42) | |
| fold = 0 | |
| aucs = [] | |
| # QUESTO È IL LOOP DI CROSS-VALIDAZIONE! | |
| for train_idx, val_idx, in kf.split(X, y): | |
| # l'oggetto kf genera gli indici e i valori per le rispettive X e y, creando il set di validazione su cui testare il modello nello split. | |
| X_tr = X[train_idx] | |
| y_tr = y[train_idx] | |
| X_val = X[val_idx] | |
| y_val = y[val_idx] | |
| # ---- | |
| # applicare qui le ipotesi che vogliamo testare | |
| # ... | |
| # ---- | |
| # qui addestriamo il modello | |
| clf = RandomForestClassifier(n_estimators=100) | |
| clf.fit(X_tr, y_tr) | |
| # creiamo le predizioni e salviamo lo score nella lista aucs | |
| pred = clf.predict(X_val) | |
| pred_prob = clf.predict_proba(X_val)[:, 1] | |
| acc_score = metrics.accuracy_score(y_val, pred) | |
| auc_score = metrics.roc_auc_score(y_val, pred_prob) | |
| print(f"======= Fold {fold} ========") | |
| print( | |
| f"Accuracy on the validation set is {acc_score:0.4f} and AUC is {auc_score:0.4f}" | |
| ) | |
| # aggiorniamo il valore di fold così possiamo stampare il progresso | |
| fold += 1 | |
| aucs.append(auc_score) | |
| general_auc_score = np.mean(aucs) | |
| print(f'\nOur out of fold AUC score is {general_auc_score:0.4f}') |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment