Last active
July 19, 2024 17:31
-
-
Save firas-jolha/e04c8a868bee5acbb1319ced78124640 to your computer and use it in GitHub Desktop.
An example of column transformers and pipelines in sklearn
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# configs/data/bank-marketing.yaml
data_url: /home/firasj/project/data/raw/data.csv
num_cols: ['age', 'balance', 'duration', 'campaign', 'pdays', 'previous']
bin_cols: ['default', 'housing', 'loan']
dt_cols: {"day":['day_of_week'], "month": ['month']}
target_cols: ['y']
cat_cols: ['job', 'marital', 'education', 'contact', 'poutcome']
categories_cat_file_path: data/raw/categories_cat.npy
categories_bin_file_path: data/raw/categories_bin.npy
labels: ["yes", "no"]
dataset_name: "bank_marketing_dataset"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
array([['no', 'yes'],
       ['yes', 'no'],
       ['no', 'yes']], dtype=object)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
array([array(['management', 'technician', 'entrepreneur', 'blue-collar',
              'retired', 'admin.', 'services', 'self-employed', 'unemployed',
              'housemaid', 'student'], dtype=object),
       array(['married', 'single', 'divorced'], dtype=object),
       array(['tertiary', 'secondary', 'primary'], dtype=object),
       array(['cellular', 'telephone'], dtype=object),
       array(['failure', 'other', 'success'], dtype=object)], dtype=object)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# src/transform_data.py | |
from utils import init_hydra | |
import os | |
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, FunctionTransformer, StandardScaler | |
from sklearn.pipeline import Pipeline, make_pipeline | |
from sklearn.compose import ColumnTransformer | |
import numpy as np | |
import calendar | |
import pandas as pd | |
pd.options.mode.chained_assignment = None # default='warn' | |
from sklearn.impute import SimpleImputer | |
import dvc.api | |
from zenml import save_artifact | |
from zenml.integrations.sklearn.materializers.sklearn_materializer import SklearnMaterializer | |
from model import load_artifact | |
# Project root, taken from the PROJECTPATH environment variable.
# NOTE: os.path.expandvars leaves unknown variables untouched, so if PROJECTPATH
# is not set this evaluates to the literal string "$PROJECTPATH".
BASE_PATH: str = os.path.expandvars("$PROJECTPATH")
def extract_data(base_path = BASE_PATH, cfg = None):
    """
    Extracts a data sample from the DVC remote data store and loads it as a dataframe.

    :param str base_path: The project path. It is used both as the DVC repository
        to query and as the prefix prepended to the URL returned by DVC.
    :param cfg: Hydra config; composed via ``init_hydra()`` when None. Reads
        ``cfg.data_version``, ``cfg.data.sample_path`` and ``cfg.data.data_store``.
    :return: tuple ``(df, version)`` with the loaded ``pd.DataFrame`` and the data
        version (git revision) it was read from.
    """
    if cfg is None:
        # initialize Hydra using Compose API
        cfg = init_hydra()
    version = cfg.data_version
    data_path = cfg.data.sample_path
    data_store = cfg.data.data_store
    # Ask DVC where the tracked sample lives in the configured remote.
    # Use the caller-supplied base_path (previously the module-level BASE_PATH was
    # hard-coded here, so passing base_path only changed the prefix below).
    url = dvc.api.get_url(
        rev=version,
        path=data_path,
        remote=data_store,
        repo=base_path,
    )
    # NOTE(review): assumes the returned URL is a path that becomes readable once
    # prefixed with base_path (plain string concatenation, as before) — confirm
    # for the configured remote type.
    url = base_path + url
    df = pd.read_csv(url)
    return df, version
def convert_month_abr(month):
    """
    A converter to transform English month abbreviations into zero-based numbers,
    e.g. `jan` to 0 and `dec` to 11.

    :param month: a pandas Series of month abbreviations (as passed column-wise by
        ``DataFrame.apply``) or a single abbreviation string. Matching is
        case-insensitive and ignores surrounding whitespace.
    :return: a Series of month numbers (missing values are left as they are), an
        int for a string input, or None when ``month`` is None.
    :raises KeyError: if a value is not a known English month abbreviation.
    """
    if month is None:
        return month
    # Fixed English table: calendar.month_abbr follows the active locale, so on a
    # non-English system it would not contain abbreviations such as `jan`.
    months = {
        abbr: index
        for index, abbr in enumerate(
            ("jan", "feb", "mar", "apr", "may", "jun",
             "jul", "aug", "sep", "oct", "nov", "dec")
        )
    }

    def _to_number(m):
        return months[m.strip().lower()]

    if isinstance(month, str):
        return _to_number(month)
    # na_action="ignore" propagates NaN/None instead of failing on `.lower()`.
    return month.map(_to_number, na_action="ignore")
def convert_data(X, cfg = None):
    """
    Converts and fixes specific data features: the month-abbreviation columns listed
    in ``cfg.data.dt_cols['month']`` are mapped to month numbers via
    ``convert_month_abr``.

    :param pd.DataFrame X: the input dataframe; it is not modified.
    :param cfg: Hydra config; composed via ``init_hydra()`` when None.
    :return: a converted copy of ``X``.
    """
    if cfg is None:
        # initialize Hydra using Compose API
        cfg = init_hydra()
    # Get all cols of type month
    month_cols = list(cfg.data.dt_cols['month'])
    # Work on a copy: X is typically a column selection of the caller's dataframe.
    # Writing into it in place mutates caller data, triggers pandas' chained-assignment
    # warning, and makes a second call fail (numbers have no `.lower()`).
    X = X.copy()
    if month_cols:
        X[month_cols] = X[month_cols].apply(convert_month_abr)
    return X
def _build_feature_pipeline(cfg):
    """
    Build the unfitted feature-preprocessing pipeline described by ``cfg.data``.

    Numerical columns are standardised. Categorical and binary columns are imputed
    with the most frequent value and one-hot encoded, using the category lists stored
    in the ``.npy`` files named in the config. Day and month columns become sine/cosine
    pairs (periods 31 and 12). All other columns are dropped.

    :param cfg: Hydra config whose ``data`` node holds the column groups and file paths.
    :return: an unfitted ``sklearn.pipeline.Pipeline``.
    """
    # Define the category of features
    categorical_features = list(cfg.data.cat_cols)
    binary_features = list(cfg.data.bin_cols)
    numerical_features = list(cfg.data.num_cols)
    day_features = list(cfg.data.dt_cols['day'])
    month_features = list(cfg.data.dt_cols['month'])
    # Flatten all datetime column groups (in config order) into a single list.
    dt_features = [col for cols in cfg.data.dt_cols.values() for col in cols]

    # allow_pickle is needed because the files hold object arrays of strings.
    # Only load them from a trusted location: unpickling can execute code.
    categories_cat = list(np.load(cfg.data.categories_cat_file_path, allow_pickle=True))
    categories_bin = list(np.load(cfg.data.categories_bin_file_path, allow_pickle=True))
    print(categories_cat)

    # Define the preprocessing transformers
    categorical_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='most_frequent', keep_empty_features=True)),
        ('onehot', OneHotEncoder(handle_unknown='ignore', categories=categories_cat))
    ])
    binary_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='most_frequent', keep_empty_features=True)),
        ('onehot', OneHotEncoder(handle_unknown='ignore', categories=categories_bin))
    ])
    numerical_transformer = Pipeline(steps=[
        ('scaler', StandardScaler())
    ])

    # Cyclical encoders: project a value with the given period onto the unit circle,
    # so that e.g. the last month of the year ends up close to the first one.
    def sin_transformer(period):
        return FunctionTransformer(lambda x: np.sin(x.astype(float) / period * 2 * np.pi))

    def cos_transformer(period):
        return FunctionTransformer(lambda x: np.cos(x.astype(float) / period * 2 * np.pi))

    dt_transformer = ColumnTransformer(transformers=[
        ('day_sin', sin_transformer(31), day_features),
        ('day_cos', cos_transformer(31), day_features),
        ('month_sin', sin_transformer(12), month_features),
        ('month_cos', cos_transformer(12), month_features)
    ])

    # Combine the preprocessing transformers
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numerical_transformer, numerical_features),
            ('cat', categorical_transformer, categorical_features),
            ('bin', binary_transformer, binary_features),
            ('dt', dt_transformer, dt_features)
        ],
        remainder="drop",  # Drop all other features which did not pass through any feature transformer
        n_jobs=4  # parallelism
    )
    pipe = make_pipeline(preprocessor)
    # This will draw a diagram if you run it in a Jupyter notebook.
    from sklearn import set_config
    set_config(display="diagram")
    print(pipe)
    return pipe


def transform_data(df, version = None, return_df = False, cfg = None, only_transform = False, transformer_version=None, only_X = False):
    """
    Transform the raw data into features (and, unless ``only_X``, an encoded target).

    :param pd.DataFrame df: the input raw dataframe
    :param str version: the version of the data sample; defaults to "v1". When fitting,
        it is used as the tag of the saved transformer artifacts.
    :param bool return_df: True if the returned value is a concatenated dataframe, False if
        the returned value is the tuple (X, y).
    :param cfg: Hydra config; composed via ``init_hydra()`` when None.
    :param bool only_transform: if True, load previously saved transformers instead of
        fitting (and saving) new ones.
    :param str transformer_version: version of the saved transformers to load when
        ``only_transform`` is True; defaults to ``version``.
    :param bool only_X: if True, the target column is neither read nor encoded and only
        the feature dataframe is returned.
    :return: the feature dataframe X if ``only_X``; otherwise ``pd.concat([X, y])`` when
        ``return_df`` else the tuple ``(X, y)``. Column names are strings.
    """
    if cfg is None:
        # initialize Hydra using Compose API
        cfg = init_hydra()
    if version is None:
        version = "v1"
    # Define labels and features
    target_col = list(cfg.data.target_cols)[0]
    print("df columns is ", df.columns)
    # Exact comparison: `col not in target_col` would be a substring test on the
    # target's name and could silently drop unrelated feature columns.
    X_cols = [col for col in df.columns if col != target_col]
    X = df[X_cols]
    y = None if only_X else df[target_col]

    if only_transform:
        if transformer_version is None:
            transformer_version = version
        X_model = load_artifact(name="X_transform_pipeline", version=transformer_version)
        y_model = None if only_X else load_artifact(name="y_transform_pipeline", version=transformer_version)
        # Convert and fix some specific data features
        X = convert_data(X, cfg=cfg)
        X_preprocessed = X_model.transform(X)
        if not only_X:
            y_encoded = y_model.transform(y)
    else:
        # Convert and fix some specific data features
        X = convert_data(X, cfg=cfg)
        pipe = _build_feature_pipeline(cfg)
        # Fit and transform input data X
        X_model = pipe.fit(X)
        X_preprocessed = X_model.transform(X)
        # The label encoder is fitted on the configured labels, not on y, so it can be
        # fitted and saved even when only_X is set.
        le = LabelEncoder()  # This encoder cannot be used in a column transformer
        y_model = le.fit(np.array(cfg.data.labels))
        if not only_X:
            y_encoded = y_model.transform(y.values.ravel())
        save_artifact(data=X_model, name="X_transform_pipeline", tags=[version], materializer=SklearnMaterializer)
        save_artifact(data=y_model, name="y_transform_pipeline", tags=[version], materializer=SklearnMaterializer)

    X = pd.DataFrame(X_preprocessed)
    # Do not forget to make columns of string type
    X.columns = X.columns.astype(str)
    if only_X:
        return X

    y = pd.DataFrame(y_encoded, columns=list(cfg.data.target_cols))  # type: ignore
    y.columns = y.columns.astype(str)
    print("features shape is : ", X.shape)
    if return_df:
        return pd.concat([X, y], axis=1)
    return X, y
if __name__ == "__main__":
    # Script entry point: pull the versioned sample, build features plus target,
    # and preview the first rows of the combined dataframe.
    raw_df, data_version = extract_data()
    features_df = transform_data(raw_df, data_version, return_df=True)
    print(features_df.head())  # type: ignore
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment