Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save NeroHin/80a01b320405050e7aa82458e1b365b0 to your computer and use it in GitHub Desktop.
ensemble-learning-example-with-voting-and-stacking
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean radius</th>\n",
" <th>mean texture</th>\n",
" <th>mean perimeter</th>\n",
" <th>mean area</th>\n",
" <th>mean smoothness</th>\n",
" <th>mean compactness</th>\n",
" <th>mean concavity</th>\n",
" <th>mean concave points</th>\n",
" <th>mean symmetry</th>\n",
" <th>mean fractal dimension</th>\n",
" <th>...</th>\n",
" <th>worst radius</th>\n",
" <th>worst texture</th>\n",
" <th>worst perimeter</th>\n",
" <th>worst area</th>\n",
" <th>worst smoothness</th>\n",
" <th>worst compactness</th>\n",
" <th>worst concavity</th>\n",
" <th>worst concave points</th>\n",
" <th>worst symmetry</th>\n",
" <th>worst fractal dimension</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>17.99</td>\n",
" <td>10.38</td>\n",
" <td>122.80</td>\n",
" <td>1001.0</td>\n",
" <td>0.11840</td>\n",
" <td>0.27760</td>\n",
" <td>0.3001</td>\n",
" <td>0.14710</td>\n",
" <td>0.2419</td>\n",
" <td>0.07871</td>\n",
" <td>...</td>\n",
" <td>25.38</td>\n",
" <td>17.33</td>\n",
" <td>184.60</td>\n",
" <td>2019.0</td>\n",
" <td>0.1622</td>\n",
" <td>0.6656</td>\n",
" <td>0.7119</td>\n",
" <td>0.2654</td>\n",
" <td>0.4601</td>\n",
" <td>0.11890</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>20.57</td>\n",
" <td>17.77</td>\n",
" <td>132.90</td>\n",
" <td>1326.0</td>\n",
" <td>0.08474</td>\n",
" <td>0.07864</td>\n",
" <td>0.0869</td>\n",
" <td>0.07017</td>\n",
" <td>0.1812</td>\n",
" <td>0.05667</td>\n",
" <td>...</td>\n",
" <td>24.99</td>\n",
" <td>23.41</td>\n",
" <td>158.80</td>\n",
" <td>1956.0</td>\n",
" <td>0.1238</td>\n",
" <td>0.1866</td>\n",
" <td>0.2416</td>\n",
" <td>0.1860</td>\n",
" <td>0.2750</td>\n",
" <td>0.08902</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>19.69</td>\n",
" <td>21.25</td>\n",
" <td>130.00</td>\n",
" <td>1203.0</td>\n",
" <td>0.10960</td>\n",
" <td>0.15990</td>\n",
" <td>0.1974</td>\n",
" <td>0.12790</td>\n",
" <td>0.2069</td>\n",
" <td>0.05999</td>\n",
" <td>...</td>\n",
" <td>23.57</td>\n",
" <td>25.53</td>\n",
" <td>152.50</td>\n",
" <td>1709.0</td>\n",
" <td>0.1444</td>\n",
" <td>0.4245</td>\n",
" <td>0.4504</td>\n",
" <td>0.2430</td>\n",
" <td>0.3613</td>\n",
" <td>0.08758</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>11.42</td>\n",
" <td>20.38</td>\n",
" <td>77.58</td>\n",
" <td>386.1</td>\n",
" <td>0.14250</td>\n",
" <td>0.28390</td>\n",
" <td>0.2414</td>\n",
" <td>0.10520</td>\n",
" <td>0.2597</td>\n",
" <td>0.09744</td>\n",
" <td>...</td>\n",
" <td>14.91</td>\n",
" <td>26.50</td>\n",
" <td>98.87</td>\n",
" <td>567.7</td>\n",
" <td>0.2098</td>\n",
" <td>0.8663</td>\n",
" <td>0.6869</td>\n",
" <td>0.2575</td>\n",
" <td>0.6638</td>\n",
" <td>0.17300</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>20.29</td>\n",
" <td>14.34</td>\n",
" <td>135.10</td>\n",
" <td>1297.0</td>\n",
" <td>0.10030</td>\n",
" <td>0.13280</td>\n",
" <td>0.1980</td>\n",
" <td>0.10430</td>\n",
" <td>0.1809</td>\n",
" <td>0.05883</td>\n",
" <td>...</td>\n",
" <td>22.54</td>\n",
" <td>16.67</td>\n",
" <td>152.20</td>\n",
" <td>1575.0</td>\n",
" <td>0.1374</td>\n",
" <td>0.2050</td>\n",
" <td>0.4000</td>\n",
" <td>0.1625</td>\n",
" <td>0.2364</td>\n",
" <td>0.07678</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 30 columns</p>\n",
"</div>"
],
"text/plain": [
" mean radius mean texture mean perimeter mean area mean smoothness \\\n",
"0 17.99 10.38 122.80 1001.0 0.11840 \n",
"1 20.57 17.77 132.90 1326.0 0.08474 \n",
"2 19.69 21.25 130.00 1203.0 0.10960 \n",
"3 11.42 20.38 77.58 386.1 0.14250 \n",
"4 20.29 14.34 135.10 1297.0 0.10030 \n",
"\n",
" mean compactness mean concavity mean concave points mean symmetry \\\n",
"0 0.27760 0.3001 0.14710 0.2419 \n",
"1 0.07864 0.0869 0.07017 0.1812 \n",
"2 0.15990 0.1974 0.12790 0.2069 \n",
"3 0.28390 0.2414 0.10520 0.2597 \n",
"4 0.13280 0.1980 0.10430 0.1809 \n",
"\n",
" mean fractal dimension ... worst radius worst texture worst perimeter \\\n",
"0 0.07871 ... 25.38 17.33 184.60 \n",
"1 0.05667 ... 24.99 23.41 158.80 \n",
"2 0.05999 ... 23.57 25.53 152.50 \n",
"3 0.09744 ... 14.91 26.50 98.87 \n",
"4 0.05883 ... 22.54 16.67 152.20 \n",
"\n",
" worst area worst smoothness worst compactness worst concavity \\\n",
"0 2019.0 0.1622 0.6656 0.7119 \n",
"1 1956.0 0.1238 0.1866 0.2416 \n",
"2 1709.0 0.1444 0.4245 0.4504 \n",
"3 567.7 0.2098 0.8663 0.6869 \n",
"4 1575.0 0.1374 0.2050 0.4000 \n",
"\n",
" worst concave points worst symmetry worst fractal dimension \n",
"0 0.2654 0.4601 0.11890 \n",
"1 0.1860 0.2750 0.08902 \n",
"2 0.2430 0.3613 0.08758 \n",
"3 0.2575 0.6638 0.17300 \n",
"4 0.1625 0.2364 0.07678 \n",
"\n",
"[5 rows x 30 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.datasets import load_breast_cancer\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import VotingClassifier, StackingClassifier\n",
"from sklearn.svm import SVC\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import classification_report\n",
"import matplotlib.pyplot as plt\n",
"import warnings\n",
"import pandas as pd\n",
"\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"# Load the data and format it as a dataframe\n",
"cancer = load_breast_cancer(as_frame=True)\n",
"\n",
"# print the first 5 rows of the data\n",
"display(cancer.data.head())\n",
"\n",
"# training model\n",
"def model_training(model, train_data: pd.DataFrame, train_target: pd.DataFrame, test_data: pd.DataFrame, test_target: pd.DataFrame):\n",
" model.fit(train_data, train_target)\n",
" return model\n",
"\n",
"\n",
"# print the accuracy of the model\n",
"def model_evaluation(model, test_data: pd.DataFrame, predict_target: pd.DataFrame):\n",
" print(f'Model Evaluation:{ model.score(test_data,predict_target)}\\n')\n",
"\n",
"\n",
"# print classification report\n",
"def print_classification_report(model, test_data: pd.DataFrame, test_target: pd.DataFrame):\n",
" predict_target = model.predict(test_data)\n",
" print(f'Classification Report:\\n{classification_report(y_true=test_target, y_pred=predict_target)}\\n')\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======== SVC() Model Training ========\n",
"Model Evaluation:0.9473684210526315\n",
"\n",
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.86 0.92 43\n",
" 1 0.92 1.00 0.96 71\n",
"\n",
" accuracy 0.95 114\n",
" macro avg 0.96 0.93 0.94 114\n",
"weighted avg 0.95 0.95 0.95 114\n",
"\n",
"\n",
"====================================================================================================\n",
"======== DecisionTreeClassifier() Model Training ========\n",
"Model Evaluation:0.9385964912280702\n",
"\n",
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" 0 0.93 0.91 0.92 43\n",
" 1 0.94 0.96 0.95 71\n",
"\n",
" accuracy 0.94 114\n",
" macro avg 0.94 0.93 0.93 114\n",
"weighted avg 0.94 0.94 0.94 114\n",
"\n",
"\n",
"====================================================================================================\n",
"======== LogisticRegression() Model Training ========\n",
"Model Evaluation:0.956140350877193\n",
"\n",
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" 0 0.97 0.91 0.94 43\n",
" 1 0.95 0.99 0.97 71\n",
"\n",
" accuracy 0.96 114\n",
" macro avg 0.96 0.95 0.95 114\n",
"weighted avg 0.96 0.96 0.96 114\n",
"\n",
"\n",
"====================================================================================================\n"
]
}
],
"source": [
"\n",
"# split the data into training and testing\n",
"cancer_data_train,cancer_data_test,cancer_target_train,cancer_target_test = train_test_split(cancer.data,cancer.target, test_size=0.2, random_state=42)\n",
"\n",
"# single model list\n",
"models = [SVC(),DecisionTreeClassifier(),LogisticRegression()]\n",
"\n",
"\n",
"# single model training\n",
"for model in models:\n",
" print(f'======== { model } Model Training ========')\n",
" model_training(model=model, train_data=cancer_data_train, train_target=cancer_target_train, test_data=cancer_data_test, test_target=cancer_target_test)\n",
" model_evaluation(model=model, test_data=cancer_data_test, predict_target=cancer_target_test)\n",
" print_classification_report(model=model, test_data=cancer_data_test, test_target=cancer_target_test)\n",
" print('====================================================================================================')\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======== VotingClassifier(estimators=[('svm', SVC()), ('dt', DecisionTreeClassifier()),\n",
" ('lr', LogisticRegression())]) Model Training ========\n",
"Model Evaluation:0.9649122807017544\n",
"\n",
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.91 0.95 43\n",
" 1 0.95 1.00 0.97 71\n",
"\n",
" accuracy 0.96 114\n",
" macro avg 0.97 0.95 0.96 114\n",
"weighted avg 0.97 0.96 0.96 114\n",
"\n",
"\n",
"====================================================================================================\n",
"======== VotingClassifier(estimators=[('svm', SVC(probability=True)),\n",
" ('dt', DecisionTreeClassifier()),\n",
" ('lr', LogisticRegression())],\n",
" voting='soft') Model Training ========\n",
"Model Evaluation:0.9649122807017544\n",
"\n",
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.91 0.95 43\n",
" 1 0.95 1.00 0.97 71\n",
"\n",
" accuracy 0.96 114\n",
" macro avg 0.97 0.95 0.96 114\n",
"weighted avg 0.97 0.96 0.96 114\n",
"\n",
"\n",
"====================================================================================================\n",
"======== StackingClassifier(estimators=[('svm', SVC()),\n",
" ('ada', DecisionTreeClassifier()),\n",
" ('lr', LogisticRegression())],\n",
" final_estimator=DecisionTreeClassifier()) Model Training ========\n",
"Model Evaluation:0.9824561403508771\n",
"\n",
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.95 0.98 43\n",
" 1 0.97 1.00 0.99 71\n",
"\n",
" accuracy 0.98 114\n",
" macro avg 0.99 0.98 0.98 114\n",
"weighted avg 0.98 0.98 0.98 114\n",
"\n",
"\n",
"====================================================================================================\n",
"======== StackingClassifier(estimators=[('svm', SVC()),\n",
" ('ada', DecisionTreeClassifier()),\n",
" ('lr', LogisticRegression())],\n",
" final_estimator=LogisticRegression()) Model Training ========\n",
"Model Evaluation:0.9649122807017544\n",
"\n",
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.91 0.95 43\n",
" 1 0.95 1.00 0.97 71\n",
"\n",
" accuracy 0.96 114\n",
" macro avg 0.97 0.95 0.96 114\n",
"weighted avg 0.97 0.96 0.96 114\n",
"\n",
"\n",
"====================================================================================================\n"
]
}
],
"source": [
"# voting classifier, voting='hard' means majority voting\n",
"voting_model_hard = VotingClassifier(estimators=[('svm',SVC()),('dt',DecisionTreeClassifier()),('lr',LogisticRegression())],voting='hard')\n",
"\n",
"# voting classifier, voting='soft' means weighted voting\n",
"voting_model_soft = VotingClassifier(estimators=[('svm',SVC(probability=True)),('dt',DecisionTreeClassifier()),('lr',LogisticRegression())],voting='soft')\n",
"\n",
"# stacking classifier, final_estimator is the final estimator which will be used to predict the output\n",
"stack_model_dst = StackingClassifier(estimators=[('svm',SVC()),('ada',DecisionTreeClassifier()),('lr',LogisticRegression())],final_estimator=DecisionTreeClassifier())\n",
"\n",
"# stacking classifier, final_estimator is the final estimator which will be used to predict the output\n",
"stack_model_lr = StackingClassifier(estimators=[('svm',SVC()),('ada',DecisionTreeClassifier()),('lr',LogisticRegression())],final_estimator=LogisticRegression())\n",
"\n",
"\n",
"ensemble_models = [voting_model_hard, voting_model_soft, stack_model_dst, stack_model_lr]\n",
"# ensemble model training\n",
"for model in ensemble_models:\n",
" print(f'======== { model } Model Training ========')\n",
" model_training(model=model, train_data=cancer_data_train, train_target=cancer_target_train, test_data=cancer_data_test, test_target=cancer_target_test)\n",
" model_evaluation(model=model, test_data=cancer_data_test, predict_target=cancer_target_test)\n",
" print_classification_report(model=model, test_data=cancer_data_test, test_target=cancer_target_test)\n",
" print('====================================================================================================')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.11 ('Research')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "df831c49162da7e36ee390de774850456fe39ea2aa2f18de316c1c9a8a009a8f"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment