Created
August 22, 2022 10:45
-
-
Save andrea-dagostino/22740d74fe437dcc9731146c68591cdc to your computer and use it in GitHub Desktop.
overfitting_example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Overfitting example: train a shallow decision tree on the red-wine quality
# dataset and compare train vs. test accuracy.
#
# Download the dataset from
# https://www.kaggle.com/datasets/piyushgoyal443/red-wine-dataset
# and save it as ``wineQualityReds.csv`` next to this script.
import pandas as pd

# The dataset's quality scores only take the values 3..8; remap them to
# contiguous class labels 0..5.
QUALITY_MAPPING = {
    3: 0,
    4: 1,
    5: 2,
    6: 3,
    7: 4,
    8: 5,
}

# Physico-chemical feature columns used as model inputs.
FEATURE_COLS = (
    'fixed.acidity', 'volatile.acidity', 'citric.acid', 'residual.sugar',
    'chlorides', 'free.sulfur.dioxide', 'total.sulfur.dioxide', 'density',
    'pH', 'sulphates', 'alcohol',
)

# Seed used for both the shuffle and the tree so that runs are reproducible.
SEED = 42


def remap_quality(df, mapping=None):
    """Return a copy of *df* whose ``quality`` column is remapped via *mapping*.

    Args:
        df: DataFrame with a ``quality`` column.
        mapping: dict from original quality score to class label; defaults
            to ``QUALITY_MAPPING`` (3..8 -> 0..5).

    Returns:
        A new DataFrame; *df* itself is not modified.

    Raises:
        ValueError: if ``quality`` contains a value not present in *mapping*
            (``Series.map`` would otherwise silently produce NaN labels).
    """
    mapping = QUALITY_MAPPING if mapping is None else mapping
    out = df.copy()
    unknown = set(out['quality'].unique()) - set(mapping)
    if unknown:
        raise ValueError(f"Unexpected quality values: {sorted(unknown)}")
    out['quality'] = out['quality'].map(mapping)
    return out


def split_train_test(df, n_train=1000, random_state=None):
    """Shuffle *df* and split it into non-overlapping train and test parts.

    With the 1599-row red-wine dataset and the default ``n_train=1000`` this
    gives 1000 training rows and 599 test rows.

    Args:
        df: DataFrame to split.
        n_train: number of rows to put in the training set; the remaining
            rows form the test set.
        random_state: seed for the shuffle (``None`` = non-reproducible).

    Returns:
        ``(df_train, df_test)`` with a fresh 0..N-1 index on the shuffled data.

    Raises:
        ValueError: if either split would be empty.
    """
    if not 0 < n_train < len(df):
        raise ValueError(
            f"n_train must be between 1 and {len(df) - 1}, got {n_train}"
        )
    # frac=1 shuffles every row; drop=True discards the old index.
    shuffled = df.sample(frac=1, random_state=random_state).reset_index(drop=True)
    # iloc slices partition the rows exactly (unlike head/tail with
    # hard-coded sizes, which overlap or leave gaps if the row count differs).
    return shuffled.iloc[:n_train], shuffled.iloc[n_train:]


def train_and_evaluate(df_train, df_test, cols=None, max_depth=3,
                       random_state=None):
    """Fit a decision tree on *df_train* and score it on both splits.

    Args:
        df_train: training rows, with feature columns and ``quality`` labels.
        df_test: test rows with the same columns.
        cols: feature column names; defaults to ``FEATURE_COLS``.
        max_depth: maximum tree depth (larger values overfit more).
        random_state: seed passed to ``DecisionTreeClassifier``.

    Returns:
        ``(train_accuracy, test_accuracy)`` as floats.
    """
    # Imported locally so the pandas-only helpers above can be imported
    # without scikit-learn installed.
    from sklearn import metrics, tree

    cols = list(FEATURE_COLS if cols is None else cols)
    clf = tree.DecisionTreeClassifier(max_depth=max_depth,
                                      random_state=random_state)
    clf.fit(df_train[cols], df_train['quality'])

    train_predictions = clf.predict(df_train[cols])
    test_predictions = clf.predict(df_test[cols])

    train_accuracy = metrics.accuracy_score(df_train['quality'], train_predictions)
    test_accuracy = metrics.accuracy_score(df_test['quality'], test_predictions)
    return train_accuracy, test_accuracy


def main(path='wineQualityReds.csv'):
    """Load the dataset from *path*, train the tree and print both accuracies."""
    df = remap_quality(pd.read_csv(path))
    df_train, df_test = split_train_test(df, n_train=1000, random_state=SEED)
    train_accuracy, test_accuracy = train_and_evaluate(
        df_train, df_test, random_state=SEED
    )
    print(f"Train accuracy: {round(train_accuracy, 3)}")
    print(f"Test accuracy: {round(test_accuracy, 3)}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment