Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save pierrelouisbescond/f6ee711ca82cac14d794005ea2fb251a to your computer and use it in GitHub Desktop.
Save pierrelouisbescond/f6ee711ca82cac14d794005ea2fb251a to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Attrition_Model_CatBoost_and_Shap.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "oR3EvrN3Q1Xc"
},
"source": [
"!pip install shap\n",
"!pip install catboost"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "gKNi-41-jdzH"
},
"source": [
"import pandas as pd\n",
"\n",
"from imblearn.under_sampling import ClusterCentroids\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"from catboost import CatBoostClassifier\n",
"\n",
"import shap\n",
"\n",
"# The following lines should be used only on Google Colab\n",
"# to connectto Google Drive\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "8gh1V89a7aMy"
},
"source": [
"## Data Exploration"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ua_0AE50nT3F",
"outputId": "8fec3852-69be-4073-a6e2-c33f6c24f8b4",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"df = pd.read_csv(\"/content/drive/My Drive/Medium/WA_Fn-UseC_-HR-Employee-Attrition.csv\", index_col=9)\n",
"df.info()"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Int64Index: 1470 entries, 1 to 2068\n",
"Data columns (total 34 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Age 1470 non-null int64 \n",
" 1 Attrition 1470 non-null object\n",
" 2 BusinessTravel 1470 non-null object\n",
" 3 DailyRate 1470 non-null int64 \n",
" 4 Department 1470 non-null object\n",
" 5 DistanceFromHome 1470 non-null int64 \n",
" 6 Education 1470 non-null int64 \n",
" 7 EducationField 1470 non-null object\n",
" 8 EmployeeCount 1470 non-null int64 \n",
" 9 EnvironmentSatisfaction 1470 non-null int64 \n",
" 10 Gender 1470 non-null object\n",
" 11 HourlyRate 1470 non-null int64 \n",
" 12 JobInvolvement 1470 non-null int64 \n",
" 13 JobLevel 1470 non-null int64 \n",
" 14 JobRole 1470 non-null object\n",
" 15 JobSatisfaction 1470 non-null int64 \n",
" 16 MaritalStatus 1470 non-null object\n",
" 17 MonthlyIncome 1470 non-null int64 \n",
" 18 MonthlyRate 1470 non-null int64 \n",
" 19 NumCompaniesWorked 1470 non-null int64 \n",
" 20 Over18 1470 non-null object\n",
" 21 OverTime 1470 non-null object\n",
" 22 PercentSalaryHike 1470 non-null int64 \n",
" 23 PerformanceRating 1470 non-null int64 \n",
" 24 RelationshipSatisfaction 1470 non-null int64 \n",
" 25 StandardHours 1470 non-null int64 \n",
" 26 StockOptionLevel 1470 non-null int64 \n",
" 27 TotalWorkingYears 1470 non-null int64 \n",
" 28 TrainingTimesLastYear 1470 non-null int64 \n",
" 29 WorkLifeBalance 1470 non-null int64 \n",
" 30 YearsAtCompany 1470 non-null int64 \n",
" 31 YearsInCurrentRole 1470 non-null int64 \n",
" 32 YearsSinceLastPromotion 1470 non-null int64 \n",
" 33 YearsWithCurrManager 1470 non-null int64 \n",
"dtypes: int64(25), object(9)\n",
"memory usage: 402.0+ KB\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "pRstcmfltm6U",
"outputId": "2f56aff0-84d8-4827-d931-74f63ab0c27b",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"# Detect columns with one unique value\n",
"print(df.columns[df.nunique()==1])\n",
"\n",
"df = df.drop([\"EmployeeCount\",\"StandardHours\",\"Over18\"], axis=1)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"Index(['EmployeeCount', 'Over18', 'StandardHours'], dtype='object')\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "TYx_8bGkzR7E",
"outputId": "6a3dc8f7-2714-4c49-c97e-0d3d5ad44f5c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 571
}
},
"source": [
"# When relevant, convert categorical columns to numerical ones\n",
"yes_no_to_0_1_dict = {\"Yes\": 1, \"No\": 0}\n",
"business_travel_dict = {\"Non-Travel\": 0,\n",
" \"Travel_Rarely\": 1,\n",
" \"Travel_Frequently\": 2\n",
" }\n",
"\n",
"df = df.replace({\"Attrition\":yes_no_to_0_1_dict})\n",
"df = df.replace({\"OverTime\":yes_no_to_0_1_dict})\n",
"df = df.replace({\"BusinessTravel\":business_travel_dict})\n",
"\n",
"df.sample(10)"
],
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Age</th>\n",
" <th>Attrition</th>\n",
" <th>BusinessTravel</th>\n",
" <th>DailyRate</th>\n",
" <th>Department</th>\n",
" <th>DistanceFromHome</th>\n",
" <th>Education</th>\n",
" <th>EducationField</th>\n",
" <th>EnvironmentSatisfaction</th>\n",
" <th>Gender</th>\n",
" <th>HourlyRate</th>\n",
" <th>JobInvolvement</th>\n",
" <th>JobLevel</th>\n",
" <th>JobRole</th>\n",
" <th>JobSatisfaction</th>\n",
" <th>MaritalStatus</th>\n",
" <th>MonthlyIncome</th>\n",
" <th>MonthlyRate</th>\n",
" <th>NumCompaniesWorked</th>\n",
" <th>OverTime</th>\n",
" <th>PercentSalaryHike</th>\n",
" <th>PerformanceRating</th>\n",
" <th>RelationshipSatisfaction</th>\n",
" <th>StockOptionLevel</th>\n",
" <th>TotalWorkingYears</th>\n",
" <th>TrainingTimesLastYear</th>\n",
" <th>WorkLifeBalance</th>\n",
" <th>YearsAtCompany</th>\n",
" <th>YearsInCurrentRole</th>\n",
" <th>YearsSinceLastPromotion</th>\n",
" <th>YearsWithCurrManager</th>\n",
" </tr>\n",
" <tr>\n",
" <th>EmployeeNumber</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>424</th>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>106</td>\n",
" <td>Human Resources</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>Human Resources</td>\n",
" <td>1</td>\n",
" <td>Male</td>\n",
" <td>62</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>Human Resources</td>\n",
" <td>1</td>\n",
" <td>Married</td>\n",
" <td>6410</td>\n",
" <td>17822</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>12</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>9</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1604</th>\n",
" <td>28</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>329</td>\n",
" <td>Research &amp; Development</td>\n",
" <td>24</td>\n",
" <td>3</td>\n",
" <td>Medical</td>\n",
" <td>3</td>\n",
" <td>Male</td>\n",
" <td>51</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>Laboratory Technician</td>\n",
" <td>2</td>\n",
" <td>Married</td>\n",
" <td>2408</td>\n",
" <td>7324</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>17</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1408</th>\n",
" <td>42</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1147</td>\n",
" <td>Human Resources</td>\n",
" <td>10</td>\n",
" <td>3</td>\n",
" <td>Human Resources</td>\n",
" <td>3</td>\n",
" <td>Female</td>\n",
" <td>31</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>Manager</td>\n",
" <td>1</td>\n",
" <td>Married</td>\n",
" <td>16799</td>\n",
" <td>16616</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>14</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>21</td>\n",
" <td>5</td>\n",
" <td>3</td>\n",
" <td>20</td>\n",
" <td>7</td>\n",
" <td>0</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1981</th>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>771</td>\n",
" <td>Research &amp; Development</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>Life Sciences</td>\n",
" <td>2</td>\n",
" <td>Male</td>\n",
" <td>45</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>Healthcare Representative</td>\n",
" <td>3</td>\n",
" <td>Single</td>\n",
" <td>4617</td>\n",
" <td>14120</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>12</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1028</th>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>802</td>\n",
" <td>Research &amp; Development</td>\n",
" <td>10</td>\n",
" <td>3</td>\n",
" <td>Other</td>\n",
" <td>2</td>\n",
" <td>Male</td>\n",
" <td>45</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>Laboratory Technician</td>\n",
" <td>4</td>\n",
" <td>Divorced</td>\n",
" <td>3917</td>\n",
" <td>9541</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>20</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1489</th>\n",
" <td>34</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>790</td>\n",
" <td>Sales</td>\n",
" <td>24</td>\n",
" <td>4</td>\n",
" <td>Medical</td>\n",
" <td>1</td>\n",
" <td>Female</td>\n",
" <td>40</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>Sales Executive</td>\n",
" <td>2</td>\n",
" <td>Single</td>\n",
" <td>4599</td>\n",
" <td>7815</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>23</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>16</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>15</td>\n",
" <td>9</td>\n",
" <td>10</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>857</th>\n",
" <td>36</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>928</td>\n",
" <td>Sales</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>Life Sciences</td>\n",
" <td>2</td>\n",
" <td>Male</td>\n",
" <td>56</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>Sales Executive</td>\n",
" <td>4</td>\n",
" <td>Married</td>\n",
" <td>6201</td>\n",
" <td>2823</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>14</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>18</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>18</td>\n",
" <td>14</td>\n",
" <td>4</td>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <th>951</th>\n",
" <td>32</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1184</td>\n",
" <td>Research &amp; Development</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Life Sciences</td>\n",
" <td>3</td>\n",
" <td>Female</td>\n",
" <td>70</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>Laboratory Technician</td>\n",
" <td>2</td>\n",
" <td>Married</td>\n",
" <td>2332</td>\n",
" <td>3974</td>\n",
" <td>6</td>\n",
" <td>0</td>\n",
" <td>20</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>5</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1044</th>\n",
" <td>33</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1038</td>\n",
" <td>Sales</td>\n",
" <td>8</td>\n",
" <td>1</td>\n",
" <td>Life Sciences</td>\n",
" <td>2</td>\n",
" <td>Female</td>\n",
" <td>88</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>Sales Representative</td>\n",
" <td>4</td>\n",
" <td>Single</td>\n",
" <td>2342</td>\n",
" <td>21437</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>19</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>930</th>\n",
" <td>28</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>995</td>\n",
" <td>Research &amp; Development</td>\n",
" <td>9</td>\n",
" <td>3</td>\n",
" <td>Medical</td>\n",
" <td>3</td>\n",
" <td>Female</td>\n",
" <td>77</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>Research Scientist</td>\n",
" <td>3</td>\n",
" <td>Divorced</td>\n",
" <td>2377</td>\n",
" <td>9834</td>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>18</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>6</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Age Attrition ... YearsSinceLastPromotion YearsWithCurrManager\n",
"EmployeeNumber ... \n",
"424 31 0 ... 1 0\n",
"1604 28 1 ... 0 0\n",
"1408 42 0 ... 0 9\n",
"1981 24 0 ... 1 2\n",
"1028 35 0 ... 1 2\n",
"1489 34 1 ... 10 10\n",
"857 36 0 ... 4 11\n",
"951 32 0 ... 0 2\n",
"1044 33 0 ... 2 2\n",
"930 28 0 ... 2 2\n",
"\n",
"[10 rows x 31 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "l3GTlc6RfeJ4",
"outputId": "a3ca943c-8e2c-46e7-d22c-99ad876ac624",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 381
}
},
"source": [
"# Display original MaritalStatus column\n",
"df.filter(like=\"Marital\").sample(10, random_state=22)"
],
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>MaritalStatus</th>\n",
" </tr>\n",
" <tr>\n",
" <th>EmployeeNumber</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>862</th>\n",
" <td>Married</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1464</th>\n",
" <td>Married</td>\n",
" </tr>\n",
" <tr>\n",
" <th>349</th>\n",
" <td>Divorced</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1293</th>\n",
" <td>Single</td>\n",
" </tr>\n",
" <tr>\n",
" <th>55</th>\n",
" <td>Single</td>\n",
" </tr>\n",
" <tr>\n",
" <th>208</th>\n",
" <td>Married</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1792</th>\n",
" <td>Divorced</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2040</th>\n",
" <td>Single</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1858</th>\n",
" <td>Divorced</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2053</th>\n",
" <td>Married</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" MaritalStatus\n",
"EmployeeNumber \n",
"862 Married\n",
"1464 Married\n",
"349 Divorced\n",
"1293 Single\n",
"55 Single\n",
"208 Married\n",
"1792 Divorced\n",
"2040 Single\n",
"1858 Divorced\n",
"2053 Married"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "w-JVZn_5-UG-"
},
"source": [
"# Convert categorical columns to numerical vectors\n",
"df = pd.get_dummies(df)"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Ann9jm4ZhoKa",
"outputId": "00d50811-209b-4f8a-fb09-95879c55b206",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 381
}
},
"source": [
"# Display final MaritalStatus columns\n",
"df.filter(like=\"Marital\").sample(10, random_state=22)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>MaritalStatus_Divorced</th>\n",
" <th>MaritalStatus_Married</th>\n",
" <th>MaritalStatus_Single</th>\n",
" </tr>\n",
" <tr>\n",
" <th>EmployeeNumber</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>862</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1464</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>349</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1293</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>55</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>208</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1792</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2040</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1858</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2053</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" MaritalStatus_Divorced ... MaritalStatus_Single\n",
"EmployeeNumber ... \n",
"862 0 ... 0\n",
"1464 0 ... 0\n",
"349 1 ... 0\n",
"1293 0 ... 1\n",
"55 0 ... 1\n",
"208 0 ... 0\n",
"1792 1 ... 0\n",
"2040 0 ... 1\n",
"1858 1 ... 0\n",
"2053 0 ... 0\n",
"\n",
"[10 rows x 3 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "_far9-SVrCgd",
"outputId": "e2eb6f68-4dcb-47fb-bc00-dd6c6ecc07e1",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"# Assess target imbalance\n",
"df.groupby(\"Attrition\")[\"Attrition\"].count()"
],
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Attrition\n",
"0 1233\n",
"1 237\n",
"Name: Attrition, dtype: int64"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "QEzctfoKxolS"
},
"source": [
"# Prepare train and test dataset\n",
"X = df.drop(\"Attrition\", axis=1)\n",
"y = df[\"Attrition\"]\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=22)"
],
"execution_count": 10,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "9PPCl1mg7fox"
},
"source": [
"## Models training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pxk23-gCycMS"
},
"source": [
"### Model Training - Round 1"
]
},
{
"cell_type": "code",
"metadata": {
"id": "UHD28CdDRTdC",
"outputId": "48d6fdf8-6a1b-4d17-c9b6-4f4b08c9e74a",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"model = CatBoostClassifier(iterations=500, verbose=100, eval_metric=\"Recall\")\n",
"model.fit(X_train, y_train);"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"Learning rate set to 0.020276\n",
"0:\tlearn: 0.2701149\ttotal: 52.4ms\tremaining: 26.1s\n",
"100:\tlearn: 0.2816092\ttotal: 334ms\tremaining: 1.32s\n",
"200:\tlearn: 0.5287356\ttotal: 614ms\tremaining: 914ms\n",
"300:\tlearn: 0.7011494\ttotal: 901ms\tremaining: 596ms\n",
"400:\tlearn: 0.8275862\ttotal: 1.19s\tremaining: 293ms\n",
"499:\tlearn: 0.9022989\ttotal: 1.46s\tremaining: 0us\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "AA9jA1IWqnRR",
"outputId": "fd399e65-1c40-4b16-ee44-46468ee96774",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 393
}
},
"source": [
"print(\"TRAIN PERFORMANCE:\\n\")\n",
"confusion_matrix_train = confusion_matrix(y_train, model.predict(X_train))\n",
"confusion_matrix_train = pd.DataFrame(confusion_matrix_train,\n",
" index=[\"Actual_No\",\"Actual_Yes\"],\n",
" columns=[\"Predicted_No\",\"Predicted_Yes\"])\n",
"\n",
"display(confusion_matrix_train)\n",
"\n",
"recall_resignation_train = confusion_matrix_train.iloc[1,1] / confusion_matrix_train.iloc[1,:].sum()\n",
"print(\"Train Score: {}\".format(round(model.score(X_train,y_train),3)))\n",
"print(\"Train Recall: {}\".format(round(recall_resignation_train,3)))\n",
"print(\"\\n* * * * * * * * * * * * * * * * * * *\\n\")\n",
"print(\"TEST PERFORMANCE:\\n\")\n",
"confusion_matrix_test = confusion_matrix(y_test, model.predict(X_test))\n",
"confusion_matrix_test = pd.DataFrame(confusion_matrix_test,\n",
" index=[\"Actual_No\",\"Actual_Yes\"],\n",
" columns=[\"Predicted_No\",\"Predicted_Yes\"])\n",
"\n",
"display(confusion_matrix_test)\n",
"\n",
"recall_resignation_test = confusion_matrix_test.iloc[1,1] / confusion_matrix_test.iloc[1,:].sum()\n",
"print(\"Test Score: {}\".format(round(model.score(X_test,y_test),3)))\n",
"print(\"Test Recall: {}\".format(round(recall_resignation_test,3)))"
],
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"text": [
"TRAIN PERFORMANCE:\n",
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Predicted_No</th>\n",
" <th>Predicted_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Actual_No</th>\n",
" <td>928</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Actual_Yes</th>\n",
" <td>17</td>\n",
" <td>157</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Predicted_No Predicted_Yes\n",
"Actual_No 928 0\n",
"Actual_Yes 17 157"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Train Score: 0.985\n",
"Train Recall: 0.902\n",
"\n",
"* * * * * * * * * * * * * * * * * * *\n",
"\n",
"TEST PERFORMANCE:\n",
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Predicted_No</th>\n",
" <th>Predicted_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Actual_No</th>\n",
" <td>302</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Actual_Yes</th>\n",
" <td>51</td>\n",
" <td>12</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Predicted_No Predicted_Yes\n",
"Actual_No 302 3\n",
"Actual_Yes 51 12"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Test Score: 0.853\n",
"Test Recall: 0.19\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j80u2hno6zYB"
},
"source": [
"### Model Training - Round 2 (using class weights)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "blKrpA5TyWke",
"outputId": "a27f0724-97f3-4b40-bd7d-d2dac077fad1",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"class_weights = dict({0:1, 1:5})\n",
"\n",
"model = CatBoostClassifier(iterations=500,\n",
" verbose=100,\n",
" eval_metric=\"Recall\",\n",
" class_weights=class_weights)\n",
"\n",
"model.fit(X_train,y_train);"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": [
"Learning rate set to 0.020276\n",
"0:\tlearn: 0.6666667\ttotal: 3.4ms\tremaining: 1.7s\n",
"100:\tlearn: 0.8793103\ttotal: 309ms\tremaining: 1.22s\n",
"200:\tlearn: 0.9425287\ttotal: 584ms\tremaining: 868ms\n",
"300:\tlearn: 0.9712644\ttotal: 875ms\tremaining: 578ms\n",
"400:\tlearn: 0.9942529\ttotal: 1.18s\tremaining: 292ms\n",
"499:\tlearn: 1.0000000\ttotal: 1.5s\tremaining: 0us\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "7wwRMqj9ytx-",
"outputId": "208dc35e-345d-4fea-f2e9-33d2f1b1402e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 393
}
},
"source": [
"print(\"TRAIN PERFORMANCE:\\n\")\n",
"confusion_matrix_train = confusion_matrix(y_train, model.predict(X_train))\n",
"confusion_matrix_train = pd.DataFrame(confusion_matrix_train, index=[\"Actual_No\",\"Actual_Yes\"], columns=[\"Predicted_No\",\"Predicted_Yes\"])\n",
"\n",
"display(confusion_matrix_train)\n",
"\n",
"recall_resignation_train = confusion_matrix_train.iloc[1,1] / confusion_matrix_train.iloc[1,:].sum()\n",
"print(\"Train Score: {}\".format(round(model.score(X_train,y_train),3)))\n",
"print(\"Train Recall: {}\".format(round(recall_resignation_train,3)))\n",
"print(\"\\n* * * * * * * * * * * * * * * * * * *\\n\")\n",
"print(\"TEST PERFORMANCE:\\n\")\n",
"confusion_matrix_test = confusion_matrix(y_test, model.predict(X_test))\n",
"confusion_matrix_test = pd.DataFrame(confusion_matrix_test, index=[\"Actual_No\",\"Actual_Yes\"], columns=[\"Predicted_No\",\"Predicted_Yes\"])\n",
"\n",
"display(confusion_matrix_test)\n",
"\n",
"recall_resignation_test = confusion_matrix_test.iloc[1,1] / confusion_matrix_test.iloc[1,:].sum()\n",
"print(\"Test Score: {}\".format(round(model.score(X_test,y_test),3)))\n",
"print(\"Test Recall: {}\".format(round(recall_resignation_test,3)))"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": [
"TRAIN PERFORMANCE:\n",
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Predicted_No</th>\n",
" <th>Predicted_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Actual_No</th>\n",
" <td>923</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Actual_Yes</th>\n",
" <td>0</td>\n",
" <td>174</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Predicted_No Predicted_Yes\n",
"Actual_No 923 5\n",
"Actual_Yes 0 174"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Train Score: 0.995\n",
"Train Recall: 1.0\n",
"\n",
"* * * * * * * * * * * * * * * * * * *\n",
"\n",
"TEST PERFORMANCE:\n",
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Predicted_No</th>\n",
" <th>Predicted_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Actual_No</th>\n",
" <td>288</td>\n",
" <td>17</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Actual_Yes</th>\n",
" <td>38</td>\n",
" <td>25</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Predicted_No Predicted_Yes\n",
"Actual_No 288 17\n",
"Actual_Yes 38 25"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Test Score: 0.851\n",
"Test Recall: 0.397\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4Ui5wonY67gS"
},
"source": [
"### Model Training - Round 3 (using under-sampling)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZUWDDoh7EkB_"
},
"source": [
"cc = ClusterCentroids()\n",
"X_cc, y_cc = cc.fit_resample(X, y)\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X_cc, y_cc)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3FgCvg5JD66F",
"outputId": "fa66b305-4f3a-4fb9-b20c-cfdeef24b82c",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"class_weights = dict({0:1, 1:5})\n",
"\n",
"model = CatBoostClassifier(iterations=500,\n",
" verbose=100,\n",
" eval_metric=\"Recall\",\n",
" class_weights=class_weights)\n",
"\n",
"model.fit(X_train, y_train);"
],
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"text": [
"Learning rate set to 0.0125\n",
"0:\tlearn: 0.9891304\ttotal: 8.77ms\tremaining: 4.38s\n",
"100:\tlearn: 1.0000000\ttotal: 425ms\tremaining: 1.68s\n",
"200:\tlearn: 1.0000000\ttotal: 807ms\tremaining: 1.2s\n",
"300:\tlearn: 1.0000000\ttotal: 1.18s\tremaining: 782ms\n",
"400:\tlearn: 1.0000000\ttotal: 1.59s\tremaining: 393ms\n",
"499:\tlearn: 1.0000000\ttotal: 1.97s\tremaining: 0us\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ItbNUQneD_pE",
"outputId": "65f5ba04-624a-44b7-ef5c-b7259e255edb",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 393
}
},
"source": [
"print(\"TRAIN PERFORMANCE:\\n\")\n",
"confusion_matrix_train = confusion_matrix(y_train, model.predict(X_train))\n",
"confusion_matrix_train = pd.DataFrame(confusion_matrix_train, index=[\"Actual_No\",\"Actual_Yes\"], columns=[\"Predicted_No\",\"Predicted_Yes\"])\n",
"\n",
"display(confusion_matrix_train)\n",
"\n",
"recall_resignation_train = confusion_matrix_train.iloc[1,1] / confusion_matrix_train.iloc[1,:].sum()\n",
"print(\"Train Score: {}\".format(round(model.score(X_train,y_train),3)))\n",
"print(\"Train Recall: {}\".format(round(recall_resignation_train,3)))\n",
"print(\"\\n* * * * * * * * * * * * * * * * * * *\\n\")\n",
"print(\"TEST PERFORMANCE:\\n\")\n",
"confusion_matrix_test = confusion_matrix(y_test, model.predict(X_test))\n",
"confusion_matrix_test = pd.DataFrame(confusion_matrix_test, index=[\"Actual_No\",\"Actual_Yes\"], columns=[\"Predicted_No\",\"Predicted_Yes\"])\n",
"\n",
"display(confusion_matrix_test)\n",
"\n",
"recall_resignation_test = confusion_matrix_test.iloc[1,1] / confusion_matrix_test.iloc[1,:].sum()\n",
"print(\"Test Score: {}\".format(round(model.score(X_test,y_test),3)))\n",
"print(\"Test Recall: {}\".format(round(recall_resignation_test,3)))"
],
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"text": [
"TRAIN PERFORMANCE:\n",
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Predicted_No</th>\n",
" <th>Predicted_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Actual_No</th>\n",
" <td>169</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Actual_Yes</th>\n",
" <td>0</td>\n",
" <td>184</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Predicted_No Predicted_Yes\n",
"Actual_No 169 2\n",
"Actual_Yes 0 184"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Train Score: 0.994\n",
"Train Recall: 1.0\n",
"\n",
"* * * * * * * * * * * * * * * * * * *\n",
"\n",
"TEST PERFORMANCE:\n",
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Predicted_No</th>\n",
" <th>Predicted_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Actual_No</th>\n",
" <td>60</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Actual_Yes</th>\n",
" <td>0</td>\n",
" <td>53</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Predicted_No Predicted_Yes\n",
"Actual_No 60 6\n",
"Actual_Yes 0 53"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Test Score: 0.95\n",
"Test Recall: 1.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PeJiZmld7HvQ"
},
"source": [
"## Explaining predictions with SHAP"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cFjs2HYE7LXS"
},
"source": [
"### Built-in Feature Importance from CatBoost"
]
},
{
"cell_type": "code",
"metadata": {
"id": "jilO70BwWTfJ",
"outputId": "bee458cf-1ca3-46d7-c37c-0dd3b0b6ee2d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 502
}
},
"source": [
"# Tabulate the model's own importance scores per feature and show the 15 largest\n",
"feat_imp = pd.DataFrame({\"Importance\": model.feature_importances_}, index=X.columns)\n",
"feat_imp.sort_values(\"Importance\", ascending=False).head(15)"
],
"execution_count": 18,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Importance</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>MaritalStatus_Married</th>\n",
" <td>14.131075</td>\n",
" </tr>\n",
" <tr>\n",
" <th>StockOptionLevel</th>\n",
" <td>6.222870</td>\n",
" </tr>\n",
" <tr>\n",
" <th>OverTime</th>\n",
" <td>5.985478</td>\n",
" </tr>\n",
" <tr>\n",
" <th>MaritalStatus_Single</th>\n",
" <td>5.763319</td>\n",
" </tr>\n",
" <tr>\n",
" <th>EducationField_Medical</th>\n",
" <td>5.758970</td>\n",
" </tr>\n",
" <tr>\n",
" <th>MaritalStatus_Divorced</th>\n",
" <td>5.431973</td>\n",
" </tr>\n",
" <tr>\n",
" <th>EducationField_Life Sciences</th>\n",
" <td>4.629362</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Gender_Female</th>\n",
" <td>3.996923</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Department_Research &amp; Development</th>\n",
" <td>3.526615</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Gender_Male</th>\n",
" <td>2.968202</td>\n",
" </tr>\n",
" <tr>\n",
" <th>JobRole_Research Director</th>\n",
" <td>2.927788</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Department_Sales</th>\n",
" <td>2.926049</td>\n",
" </tr>\n",
" <tr>\n",
" <th>YearsWithCurrManager</th>\n",
" <td>2.184329</td>\n",
" </tr>\n",
" <tr>\n",
" <th>MonthlyIncome</th>\n",
" <td>2.056980</td>\n",
" </tr>\n",
" <tr>\n",
" <th>EnvironmentSatisfaction</th>\n",
" <td>1.841916</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Importance\n",
"MaritalStatus_Married 14.131075\n",
"StockOptionLevel 6.222870\n",
"OverTime 5.985478\n",
"MaritalStatus_Single 5.763319\n",
"EducationField_Medical 5.758970\n",
"MaritalStatus_Divorced 5.431973\n",
"EducationField_Life Sciences 4.629362\n",
"Gender_Female 3.996923\n",
"Department_Research & Development 3.526615\n",
"Gender_Male 2.968202\n",
"JobRole_Research Director 2.927788\n",
"Department_Sales 2.926049\n",
"YearsWithCurrManager 2.184329\n",
"MonthlyIncome 2.056980\n",
"EnvironmentSatisfaction 1.841916"
]
},
"metadata": {
"tags": []
},
"execution_count": 18
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_EoXcDEyWZrU"
},
"source": [
"### SHAP Explanations"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Ny8ls2GebF5q"
},
"source": [
"# Build a tree-model SHAP explainer for the fitted model and explain every row of X\n",
"shap_explainer = shap.TreeExplainer(model=model)\n",
"shap_values = shap_explainer.shap_values(X=X)"
],
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2cFKFFVhbIF3",
"outputId": "d74df3cf-4a41-4136-d398-2e713d38843a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 896
}
},
"source": [
"# Select one employee by EmployeeNumber (the index of df).\n",
"# .loc with a list keeps a one-row DataFrame and raises KeyError for an unknown ID,\n",
"# instead of silently producing an empty profile as a boolean mask would.\n",
"Employee_ID = 1\n",
"employees_profile = df.loc[[Employee_ID]]\n",
"# Keep only columns where this employee has a non-zero value (hides unset one-hot flags).\n",
"# Caveat: this also hides numeric features whose real value is 0 (e.g. a zero StockOptionLevel).\n",
"employees_profile = employees_profile.loc[:, employees_profile.ne(0).any(axis=0)]\n",
"employees_profile.T"
],
"execution_count": 20,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th>EmployeeNumber</th>\n",
" <th>1</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Age</th>\n",
" <td>41</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Attrition</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>BusinessTravel</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>DailyRate</th>\n",
" <td>1102</td>\n",
" </tr>\n",
" <tr>\n",
" <th>DistanceFromHome</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Education</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>EnvironmentSatisfaction</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>HourlyRate</th>\n",
" <td>94</td>\n",
" </tr>\n",
" <tr>\n",
" <th>JobInvolvement</th>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>JobLevel</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>JobSatisfaction</th>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>MonthlyIncome</th>\n",
" <td>5993</td>\n",
" </tr>\n",
" <tr>\n",
" <th>MonthlyRate</th>\n",
" <td>19479</td>\n",
" </tr>\n",
" <tr>\n",
" <th>NumCompaniesWorked</th>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>OverTime</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>PercentSalaryHike</th>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <th>PerformanceRating</th>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>RelationshipSatisfaction</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>TotalWorkingYears</th>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>WorkLifeBalance</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>YearsAtCompany</th>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>YearsInCurrentRole</th>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>YearsWithCurrManager</th>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Department_Sales</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>EducationField_Life Sciences</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Gender_Female</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>JobRole_Sales Executive</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>MaritalStatus_Single</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"EmployeeNumber 1\n",
"Age 41\n",
"Attrition 1\n",
"BusinessTravel 1\n",
"DailyRate 1102\n",
"DistanceFromHome 1\n",
"Education 2\n",
"EnvironmentSatisfaction 2\n",
"HourlyRate 94\n",
"JobInvolvement 3\n",
"JobLevel 2\n",
"JobSatisfaction 4\n",
"MonthlyIncome 5993\n",
"MonthlyRate 19479\n",
"NumCompaniesWorked 8\n",
"OverTime 1\n",
"PercentSalaryHike 11\n",
"PerformanceRating 3\n",
"RelationshipSatisfaction 1\n",
"TotalWorkingYears 8\n",
"WorkLifeBalance 1\n",
"YearsAtCompany 6\n",
"YearsInCurrentRole 4\n",
"YearsWithCurrManager 5\n",
"Department_Sales 1\n",
"EducationField_Life Sciences 1\n",
"Gender_Female 1\n",
"JobRole_Sales Executive 1\n",
"MaritalStatus_Single 1"
]
},
"metadata": {
"tags": []
},
"execution_count": 20
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "07kY92G7baNf"
},
"source": [
"shap.initjs()\n",
"# Positional row of the selected employee.\n",
"# NOTE(review): assumes X (and therefore shap_values) keeps df's row order — confirm where X is built.\n",
"index_choice = df.index.get_loc(Employee_ID)\n",
"shap.force_plot(\n",
"    base_value=shap_explainer.expected_value,\n",
"    shap_values=shap_values[index_choice],\n",
"    features=X.iloc[index_choice],\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "y9eTLRI8y7MN"
},
"source": [
"# Stack the per-row explanations of every row in X into one interactive force plot\n",
"shap.initjs()\n",
"shap.force_plot(base_value=shap_explainer.expected_value, shap_values=shap_values, features=X)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "99k078Try9ev"
},
"source": [
"# Global view: distribution of SHAP values per feature across all rows of X\n",
"shap.summary_plot(shap_values, features=X, feature_names=X.columns)"
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment