Last active
February 23, 2024 12:11
-
-
Save RahulDas-dev/3fda38233a3fa9ace94663ed4cdd2be5 to your computer and use it in GitHub Desktop.
Scikit-Learn Pipeline Building + OptunaSearchCV + Joblib Memory
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "056ec550", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"\n", | |
"from sklearn.pipeline import Pipeline\n", | |
"from sklearn.compose import ColumnTransformer, make_column_selector\n", | |
"# from mlxtend.feature_selection import ColumnSelector\n", | |
"from sklearn.preprocessing import StandardScaler, OneHotEncoder, FunctionTransformer, OrdinalEncoder\n", | |
"from sklearn.impute import SimpleImputer\n", | |
"from sklearn.ensemble import RandomForestClassifier\n", | |
"from joblib import Memory\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c298be18", | |
"metadata": {}, | |
"source": [ | |
"## Data Load" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "cb238cf7", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(48842, 14) (48842,)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>age</th>\n", | |
" <th>workclass</th>\n", | |
" <th>fnlwgt</th>\n", | |
" <th>education</th>\n", | |
" <th>education-num</th>\n", | |
" <th>marital-status</th>\n", | |
" <th>occupation</th>\n", | |
" <th>relationship</th>\n", | |
" <th>race</th>\n", | |
" <th>sex</th>\n", | |
" <th>capitalgain</th>\n", | |
" <th>capitalloss</th>\n", | |
" <th>hoursperweek</th>\n", | |
" <th>native-country</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>2</td>\n", | |
" <td>State-gov</td>\n", | |
" <td>77516</td>\n", | |
" <td>Bachelors</td>\n", | |
" <td>13</td>\n", | |
" <td>Never-married</td>\n", | |
" <td>Adm-clerical</td>\n", | |
" <td>Not-in-family</td>\n", | |
" <td>White</td>\n", | |
" <td>Male</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>United-States</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>3</td>\n", | |
" <td>Self-emp-not-inc</td>\n", | |
" <td>83311</td>\n", | |
" <td>Bachelors</td>\n", | |
" <td>13</td>\n", | |
" <td>Married-civ-spouse</td>\n", | |
" <td>Exec-managerial</td>\n", | |
" <td>Husband</td>\n", | |
" <td>White</td>\n", | |
" <td>Male</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>United-States</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>2</td>\n", | |
" <td>Private</td>\n", | |
" <td>215646</td>\n", | |
" <td>HS-grad</td>\n", | |
" <td>9</td>\n", | |
" <td>Divorced</td>\n", | |
" <td>Handlers-cleaners</td>\n", | |
" <td>Not-in-family</td>\n", | |
" <td>White</td>\n", | |
" <td>Male</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>United-States</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>3</td>\n", | |
" <td>Private</td>\n", | |
" <td>234721</td>\n", | |
" <td>11th</td>\n", | |
" <td>7</td>\n", | |
" <td>Married-civ-spouse</td>\n", | |
" <td>Handlers-cleaners</td>\n", | |
" <td>Husband</td>\n", | |
" <td>Black</td>\n", | |
" <td>Male</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>United-States</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>1</td>\n", | |
" <td>Private</td>\n", | |
" <td>338409</td>\n", | |
" <td>Bachelors</td>\n", | |
" <td>13</td>\n", | |
" <td>Married-civ-spouse</td>\n", | |
" <td>Prof-specialty</td>\n", | |
" <td>Wife</td>\n", | |
" <td>Black</td>\n", | |
" <td>Female</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>Cuba</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" age workclass fnlwgt education education-num marital-status \\\n", | |
"0 2 State-gov 77516 Bachelors 13 Never-married \n", | |
"1 3 Self-emp-not-inc 83311 Bachelors 13 Married-civ-spouse \n", | |
"2 2 Private 215646 HS-grad 9 Divorced \n", | |
"3 3 Private 234721 11th 7 Married-civ-spouse \n", | |
"4 1 Private 338409 Bachelors 13 Married-civ-spouse \n", | |
"\n", | |
" occupation relationship race sex capitalgain capitalloss \\\n", | |
"0 Adm-clerical Not-in-family White Male 1 0 \n", | |
"1 Exec-managerial Husband White Male 0 0 \n", | |
"2 Handlers-cleaners Not-in-family White Male 0 0 \n", | |
"3 Handlers-cleaners Husband Black Male 0 0 \n", | |
"4 Prof-specialty Wife Black Female 0 0 \n", | |
"\n", | |
" hoursperweek native-country \n", | |
"0 2 United-States \n", | |
"1 0 United-States \n", | |
"2 2 United-States \n", | |
"3 2 United-States \n", | |
"4 2 Cuba " | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"def fetch_adult_data():\n", | |
"    \"\"\"Fetch OpenML dataset ``data_id=179`` (Adult) as a pandas DataFrame.\n", | |
"\n", | |
"    The returned frame still contains the ``class`` target column.\n", | |
"    \"\"\"\n", | |
"    from sklearn.datasets import fetch_openml\n", | |
"\n", | |
"    openml_ds = fetch_openml(data_id=179, as_frame=True, parser='pandas')\n", | |
"    dataset = openml_ds['frame']\n", | |
"    return dataset\n", | |
"\n", | |
"dataset = fetch_adult_data()\n", | |
"\n", | |
"# Split off the target; `dataset` keeps only the feature columns.\n", | |
"target = dataset.pop('class')\n", | |
"\n", | |
"print(dataset.shape, target.shape)\n", | |
"dataset.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "cff016e6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<class 'pandas.core.frame.DataFrame'>\n", | |
"RangeIndex: 48842 entries, 0 to 48841\n", | |
"Data columns (total 14 columns):\n", | |
" # Column Non-Null Count Dtype \n", | |
"--- ------ -------------- ----- \n", | |
" 0 age 48842 non-null category\n", | |
" 1 workclass 46043 non-null category\n", | |
" 2 fnlwgt 48842 non-null int64 \n", | |
" 3 education 48842 non-null category\n", | |
" 4 education-num 48842 non-null int64 \n", | |
" 5 marital-status 48842 non-null category\n", | |
" 6 occupation 46033 non-null category\n", | |
" 7 relationship 48842 non-null category\n", | |
" 8 race 48842 non-null category\n", | |
" 9 sex 48842 non-null category\n", | |
" 10 capitalgain 48842 non-null category\n", | |
" 11 capitalloss 48842 non-null category\n", | |
" 12 hoursperweek 48842 non-null category\n", | |
" 13 native-country 47985 non-null category\n", | |
"dtypes: category(12), int64(2)\n", | |
"memory usage: 1.3 MB\n" | |
] | |
} | |
], | |
"source": [ | |
"# Column dtypes and non-null counts; a few categorical columns have missing values.\n", | |
"dataset.info()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "16cbfa18", | |
"metadata": {}, | |
"source": [ | |
"## Preprocessor Builder" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "233f6925", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def bool_to_number(x: np.ndarray) -> np.ndarray:\n", | |
"    \"\"\"Cast a boolean array to integers (True -> 1, False -> 0).\"\"\"\n", | |
"    return np.multiply(x, 1)\n", | |
"\n", | |
"# Module-level transformer reused as the first step of every boolean branch built below.\n", | |
"BooleanTransformer = FunctionTransformer(bool_to_number, feature_names_out='one-to-one')\n", | |
"\n", | |
"def build_preprocessor_pipeline(dataset: pd.DataFrame, n_jobs_: int = -1, verbose_: bool = False) -> ColumnTransformer:\n", | |
"    \"\"\"Build an unfitted ColumnTransformer that imputes/encodes columns by dtype.\n", | |
"\n", | |
"    * numeric columns           -> mean imputation\n", | |
"    * category / object columns -> most-frequent imputation, then ordinal\n", | |
"                                   encoding (unknown and missing -> -1)\n", | |
"    * bool columns              -> cast to int, then most-frequent imputation\n", | |
"\n", | |
"    Columns of any other dtype are dropped (``remainder='drop'``). The\n", | |
"    transformer outputs pandas DataFrames.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    dataset : pd.DataFrame\n", | |
"        Frame inspected (not fitted) to route columns to the branches above and\n", | |
"        to size the ordinal-code dtype.\n", | |
"    n_jobs_ : int\n", | |
"        Forwarded to ``ColumnTransformer(n_jobs=...)``.\n", | |
"    verbose_ : bool\n", | |
"        Forwarded to ``ColumnTransformer(verbose=...)``.\n", | |
"\n", | |
"    Returns\n", | |
"    -------\n", | |
"    ColumnTransformer\n", | |
"    \"\"\"\n", | |
"    numerical_columns = make_column_selector(dtype_include=[np.number])(dataset)\n", | |
"    # 'object' is included so plain string columns are encoded rather than\n", | |
"    # silently discarded by remainder='drop'.\n", | |
"    categorical_columns = make_column_selector(dtype_include=['category', 'object'])(dataset)\n", | |
"    boolean_columns = make_column_selector(dtype_include=['bool'])(dataset)\n", | |
"\n", | |
"    transformers_ = []\n", | |
"\n", | |
"    if numerical_columns:\n", | |
"        transformers_.append((\"transformer_n\", SimpleImputer(strategy=\"mean\"), numerical_columns))\n", | |
"    if categorical_columns:\n", | |
"        # int8 codes save memory but wrap around silently beyond 127 categories,\n", | |
"        # so fall back to int32 for high-cardinality columns. Assumes the data\n", | |
"        # later passed to fit() has no more categories than `dataset`.\n", | |
"        max_cardinality = max(dataset[col].nunique(dropna=True) for col in categorical_columns)\n", | |
"        code_dtype = np.int8 if max_cardinality <= np.iinfo(np.int8).max else np.int32\n", | |
"        transformer_c = Pipeline(\n", | |
"            steps=[\n", | |
"                (\"imputer_c\", SimpleImputer(missing_values=np.nan, strategy='most_frequent')),\n", | |
"                (\"encoder_c\", OrdinalEncoder(handle_unknown=\"use_encoded_value\",\n", | |
"                                             dtype=code_dtype,\n", | |
"                                             encoded_missing_value=-1,\n", | |
"                                             unknown_value=-1)),\n", | |
"            ],\n", | |
"            verbose=False,\n", | |
"            memory=None,\n", | |
"        )\n", | |
"        transformers_.append((\"transformer_c\", transformer_c, categorical_columns))\n", | |
"    if boolean_columns:\n", | |
"        transformer_b = Pipeline(\n", | |
"            steps=[\n", | |
"                (\"to_int\", BooleanTransformer),\n", | |
"                (\"imputer_c\", SimpleImputer(missing_values=np.nan, strategy='most_frequent')),\n", | |
"            ],\n", | |
"            verbose=False,\n", | |
"            memory=None,\n", | |
"        )\n", | |
"        transformers_.append((\"transformer_b\", transformer_b, boolean_columns))\n", | |
"\n", | |
"    preprocessor = ColumnTransformer(\n", | |
"        transformers=transformers_,\n", | |
"        n_jobs=n_jobs_,\n", | |
"        remainder='drop',\n", | |
"        verbose_feature_names_out=False,\n", | |
"        verbose=verbose_,\n", | |
"    ).set_output(transform='pandas')\n", | |
"\n", | |
"    return preprocessor" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "8ea499d5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<style>#sk-container-id-2 {color: black;}#sk-container-id-2 pre{padding: 0;}#sk-container-id-2 div.sk-toggleable {background-color: white;}#sk-container-id-2 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-2 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-2 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-2 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-2 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-2 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-2 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-2 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label 
{background-color: #d4ebff;}#sk-container-id-2 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-2 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-2 div.sk-item {position: relative;z-index: 1;}#sk-container-id-2 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-2 div.sk-item::before, #sk-container-id-2 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-2 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-2 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-2 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-2 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-2 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-2 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-2 div.sk-label-container {text-align: center;}#sk-container-id-2 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-2 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>ColumnTransformer(n_jobs=-1,\n", | |
" transformers=[('transformer_n', SimpleImputer(),\n", | |
" ['fnlwgt', 'education-num']),\n", | |
" ('transformer_c',\n", | |
" Pipeline(steps=[('imputer_c',\n", | |
" SimpleImputer(strategy='most_frequent')),\n", | |
" ('encoder_c',\n", | |
" OrdinalEncoder(dtype=<class 'numpy.int8'>,\n", | |
" encoded_missing_value=-1,\n", | |
" handle_unknown='use_encoded_value',\n", | |
" unknown_value=-1))]),\n", | |
" ['age', 'workclass', 'education',\n", | |
" 'marital-status', 'occupation',\n", | |
" 'relationship', 'race', 'sex', 'capitalgain',\n", | |
" 'capitalloss', 'hoursperweek',\n", | |
" 'native-country'])],\n", | |
" verbose_feature_names_out=False)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-10\" type=\"checkbox\" ><label for=\"sk-estimator-id-10\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">ColumnTransformer</label><div class=\"sk-toggleable__content\"><pre>ColumnTransformer(n_jobs=-1,\n", | |
" transformers=[('transformer_n', SimpleImputer(),\n", | |
" ['fnlwgt', 'education-num']),\n", | |
" ('transformer_c',\n", | |
" Pipeline(steps=[('imputer_c',\n", | |
" SimpleImputer(strategy='most_frequent')),\n", | |
" ('encoder_c',\n", | |
" OrdinalEncoder(dtype=<class 'numpy.int8'>,\n", | |
" encoded_missing_value=-1,\n", | |
" handle_unknown='use_encoded_value',\n", | |
" unknown_value=-1))]),\n", | |
" ['age', 'workclass', 'education',\n", | |
" 'marital-status', 'occupation',\n", | |
" 'relationship', 'race', 'sex', 'capitalgain',\n", | |
" 'capitalloss', 'hoursperweek',\n", | |
" 'native-country'])],\n", | |
" verbose_feature_names_out=False)</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-11\" type=\"checkbox\" ><label for=\"sk-estimator-id-11\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">transformer_n</label><div class=\"sk-toggleable__content\"><pre>['fnlwgt', 'education-num']</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-12\" type=\"checkbox\" ><label for=\"sk-estimator-id-12\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">SimpleImputer</label><div class=\"sk-toggleable__content\"><pre>SimpleImputer()</pre></div></div></div></div></div></div><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-13\" type=\"checkbox\" ><label for=\"sk-estimator-id-13\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">transformer_c</label><div class=\"sk-toggleable__content\"><pre>['age', 'workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capitalgain', 'capitalloss', 'hoursperweek', 'native-country']</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-14\" type=\"checkbox\" ><label for=\"sk-estimator-id-14\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">SimpleImputer</label><div class=\"sk-toggleable__content\"><pre>SimpleImputer(strategy='most_frequent')</pre></div></div></div><div class=\"sk-item\"><div 
class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-15\" type=\"checkbox\" ><label for=\"sk-estimator-id-15\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">OrdinalEncoder</label><div class=\"sk-toggleable__content\"><pre>OrdinalEncoder(dtype=<class 'numpy.int8'>, encoded_missing_value=-1,\n", | |
" handle_unknown='use_encoded_value', unknown_value=-1)</pre></div></div></div></div></div></div></div></div></div></div></div></div>" | |
], | |
"text/plain": [ | |
"ColumnTransformer(n_jobs=-1,\n", | |
" transformers=[('transformer_n', SimpleImputer(),\n", | |
" ['fnlwgt', 'education-num']),\n", | |
" ('transformer_c',\n", | |
" Pipeline(steps=[('imputer_c',\n", | |
" SimpleImputer(strategy='most_frequent')),\n", | |
" ('encoder_c',\n", | |
" OrdinalEncoder(dtype=<class 'numpy.int8'>,\n", | |
" encoded_missing_value=-1,\n", | |
" handle_unknown='use_encoded_value',\n", | |
" unknown_value=-1))]),\n", | |
" ['age', 'workclass', 'education',\n", | |
" 'marital-status', 'occupation',\n", | |
" 'relationship', 'race', 'sex', 'capitalgain',\n", | |
" 'capitalloss', 'hoursperweek',\n", | |
" 'native-country'])],\n", | |
" verbose_feature_names_out=False)" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Unfitted preprocessor; columns are routed to branches by their dtype in `dataset`.\n", | |
"preprocessor = build_preprocessor_pipeline(dataset)\n", | |
"preprocessor" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "e696eae3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>fnlwgt</th>\n", | |
" <th>education-num</th>\n", | |
" <th>age</th>\n", | |
" <th>workclass</th>\n", | |
" <th>education</th>\n", | |
" <th>marital-status</th>\n", | |
" <th>occupation</th>\n", | |
" <th>relationship</th>\n", | |
" <th>race</th>\n", | |
" <th>sex</th>\n", | |
" <th>capitalgain</th>\n", | |
" <th>capitalloss</th>\n", | |
" <th>hoursperweek</th>\n", | |
" <th>native-country</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>77516.0</td>\n", | |
" <td>13.0</td>\n", | |
" <td>2</td>\n", | |
" <td>6</td>\n", | |
" <td>9</td>\n", | |
" <td>4</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>4</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>38</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>83311.0</td>\n", | |
" <td>13.0</td>\n", | |
" <td>3</td>\n", | |
" <td>5</td>\n", | |
" <td>9</td>\n", | |
" <td>2</td>\n", | |
" <td>3</td>\n", | |
" <td>0</td>\n", | |
" <td>4</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>38</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>215646.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>2</td>\n", | |
" <td>3</td>\n", | |
" <td>11</td>\n", | |
" <td>0</td>\n", | |
" <td>5</td>\n", | |
" <td>1</td>\n", | |
" <td>4</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>38</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>234721.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>3</td>\n", | |
" <td>3</td>\n", | |
" <td>1</td>\n", | |
" <td>2</td>\n", | |
" <td>5</td>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>38</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>338409.0</td>\n", | |
" <td>13.0</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>9</td>\n", | |
" <td>2</td>\n", | |
" <td>9</td>\n", | |
" <td>5</td>\n", | |
" <td>2</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" fnlwgt education-num age workclass education marital-status \\\n", | |
"0 77516.0 13.0 2 6 9 4 \n", | |
"1 83311.0 13.0 3 5 9 2 \n", | |
"2 215646.0 9.0 2 3 11 0 \n", | |
"3 234721.0 7.0 3 3 1 2 \n", | |
"4 338409.0 13.0 1 3 9 2 \n", | |
"\n", | |
" occupation relationship race sex capitalgain capitalloss \\\n", | |
"0 0 1 4 1 1 0 \n", | |
"1 3 0 4 1 0 0 \n", | |
"2 5 1 4 1 0 0 \n", | |
"3 5 0 2 1 0 0 \n", | |
"4 9 5 2 0 0 0 \n", | |
"\n", | |
" hoursperweek native-country \n", | |
"0 2 38 \n", | |
"1 0 38 \n", | |
"2 2 38 \n", | |
"3 2 38 \n", | |
"4 2 4 " | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Sanity check: fit the preprocessor on the full data and inspect the encoded frame.\n", | |
"# Note: this fits the same `preprocessor` object that the model pipeline below reuses.\n", | |
"dataset_trf = preprocessor.fit_transform(dataset)\n", | |
"\n", | |
"dataset_trf.head()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "759ffe3f", | |
"metadata": {}, | |
"source": [ | |
"## Adding the Estimator to the Pipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "3f86fc65", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Fixed RandomForest settings shared by every tuning trial.\n", | |
"args = dict(random_state=10, n_jobs=-1)\n", | |
"\n", | |
"# Preprocessing followed by the classifier; step names prefix the tuned params\n", | |
"# (e.g. `estimator__max_depth`).\n", | |
"model = Pipeline(\n", | |
"    steps=[\n", | |
"        (\"transformer\", preprocessor),\n", | |
"        (\"estimator\", RandomForestClassifier(**args)),\n", | |
"    ],\n", | |
"    verbose=False,\n", | |
"    memory=None,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a955d486", | |
"metadata": {}, | |
"source": [ | |
"## Fitting Data: Hyperparameter Search with OptunaSearchCV" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "6cadad96", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"[I 2024-01-15 22:04:51,540] A new study created in RDB with name: Randomforest Tuner\n", | |
"C:\\Users\\rdas6\\AppData\\Local\\Temp\\ipykernel_15780\\2964663329.py:38: ExperimentalWarning: OptunaSearchCV is experimental (supported from v0.17.0). The interface can change in the future.\n", | |
" optuna_search = OptunaSearchCV(model_,\n", | |
"[I 2024-01-15 22:04:51,579] Searching the best hyperparameters using 48842 samples...\n", | |
"[I 2024-01-15 22:05:09,718] Trial 1 finished with value: 0.7607182362198229 and parameters: {'estimator__n_estimators': 100, 'estimator__max_depth': 10, 'estimator__min_impurity_decrease': 0.3810441382224819, 'estimator__max_features': 'log2', 'estimator__bootstrap': False}. Best is trial 1 with value: 0.7607182362198229.\n", | |
"[I 2024-01-15 22:05:11,399] Trial 3 finished with value: 0.8387248805305925 and parameters: {'estimator__n_estimators': 170, 'estimator__max_depth': 6, 'estimator__min_impurity_decrease': 0.0006576437138114963, 'estimator__max_features': 'sqrt', 'estimator__bootstrap': True}. Best is trial 3 with value: 0.8387248805305925.\n", | |
"[I 2024-01-15 22:05:12,046] Trial 0 finished with value: 0.8497195138074449 and parameters: {'estimator__n_estimators': 150, 'estimator__max_depth': 9, 'estimator__min_impurity_decrease': 3.6984829700086176e-08, 'estimator__max_features': 'sqrt', 'estimator__bootstrap': False}. Best is trial 0 with value: 0.8497195138074449.\n", | |
"[I 2024-01-15 22:05:12,490] Trial 4 finished with value: 0.7607182362198229 and parameters: {'estimator__n_estimators': 260, 'estimator__max_depth': 6, 'estimator__min_impurity_decrease': 0.2013327427553822, 'estimator__max_features': 'sqrt', 'estimator__bootstrap': False}. Best is trial 0 with value: 0.8497195138074449.\n", | |
"[I 2024-01-15 22:05:12,814] Trial 2 finished with value: 0.8476516205761779 and parameters: {'estimator__n_estimators': 220, 'estimator__max_depth': 8, 'estimator__min_impurity_decrease': 4.4744726010704304e-05, 'estimator__max_features': 'sqrt', 'estimator__bootstrap': False}. Best is trial 0 with value: 0.8497195138074449.\n", | |
"[I 2024-01-15 22:05:12,814] Finished hyperparameter search!\n", | |
"[I 2024-01-15 22:05:12,833] Refitting the estimator using 48842 samples...\n", | |
"[I 2024-01-15 22:05:14,353] Finished refitting! (elapsed time: 1.520 sec.)\n", | |
"[Memory(location=C:\\Users\\rdas6\\AppData\\Local\\Temp\\tmp_q1q_jbu\\joblib)]: Flushing completely the cache\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"End 2 End Time - 22.852312326431274 secs\n" | |
] | |
} | |
], | |
"source": [ | |
"import tempfile\n", | |
"from sklearn.base import clone\n", | |
"import joblib\n", | |
"import time\n", | |
"\n", | |
"from optuna import samplers, create_study\n", | |
"from optuna.distributions import FloatDistribution, IntDistribution, CategoricalDistribution, IntUniformDistribution\n", | |
"from optuna.integration import OptunaSearchCV\n", | |
"\n", | |
"# Search space for the `estimator` step of the pipeline. Each key may appear only\n", | |
"# once: a repeated dict key silently replaces the earlier entry. To search a\n", | |
"# continuous fraction for max_features instead, use FloatDistribution(0.4, 1).\n", | |
"param_distributions = {\n", | |
"    \"estimator__n_estimators\": IntDistribution(10, 300, step=10),\n", | |
"    \"estimator__max_depth\": IntDistribution(1, 11),\n", | |
"    \"estimator__min_impurity_decrease\": FloatDistribution(0.000000001, 0.5, log=True),\n", | |
"    \"estimator__max_features\": CategoricalDistribution([1.0, \"sqrt\", \"log2\"]),\n", | |
"    \"estimator__bootstrap\": CategoricalDistribution([True, False]),\n", | |
"}\n", | |
"\n", | |
"storage_string_ = \"sqlite:///./test_2.db\"  # optional\n", | |
"sampler_ = samplers.TPESampler(seed=10)\n", | |
"# load_if_exists=True makes the cell re-runnable: the SQLite study persists\n", | |
"# between runs, and re-creating a study with the same name would otherwise\n", | |
"# raise. Note that trials from earlier runs stay in the study.\n", | |
"study_ = create_study(storage=storage_string_,\n", | |
"                      study_name='Randomforest Tuner',\n", | |
"                      direction=\"maximize\",\n", | |
"                      sampler=sampler_,\n", | |
"                      load_if_exists=True)\n", | |
"\n", | |
"cv_result, best_params, best_model, best_score = None, None, None, None\n", | |
"# Everything the `finally` clause touches is created before `try`, so cleanup\n", | |
"# cannot fail with a NameError and mask the real error.\n", | |
"st_time = time.time()\n", | |
"tempdir = tempfile.TemporaryDirectory()\n", | |
"memory_ = Memory(tempdir.name, verbose=0)  # joblib cache for the pipeline, used only during tuning\n", | |
"try:\n", | |
"    model_ = clone(model)\n", | |
"    model_.memory = memory_\n", | |
"    model_.verbose = False\n", | |
"    optuna_search = OptunaSearchCV(model_,\n", | |
"                                   param_distributions,\n", | |
"                                   cv=5,\n", | |
"                                   n_trials=5,\n", | |
"                                   n_jobs=-1,\n", | |
"                                   random_state=10,\n", | |
"                                   refit=True,\n", | |
"                                   verbose=10,\n", | |
"                                   timeout=60*60,\n", | |
"                                   study=study_)\n", | |
"    optuna_search.fit(dataset, target)\n", | |
"except Exception as err:\n", | |
"    print(err)\n", | |
"else:\n", | |
"    cv_result = pd.DataFrame.from_dict(optuna_search.cv_results_)\n", | |
"    best_score = optuna_search.best_score_\n", | |
"    best_model = optuna_search.best_estimator_\n", | |
"    best_params = optuna_search.best_params_\n", | |
"    # Detach the cache from the returned model: its directory is deleted below.\n", | |
"    best_model.memory = None\n", | |
"finally:\n", | |
"    memory_.clear()\n", | |
"    tempdir.cleanup()\n", | |
"    print(f'End 2 End Time - {time.time() - st_time} secs')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "addf1456", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment