Skip to content

Instantly share code, notes, and snippets.

@timotta
Created August 16, 2024 17:30
Show Gist options
  • Save timotta/db6b037d75c921d707e1c3f4f50aa690 to your computer and use it in GitHub Desktop.
Save timotta/db6b037d75c921d707e1c3f4f50aa690 to your computer and use it in GitHub Desktop.
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import OneHotEncoder
class OneHotEconderByColumn(TransformerMixin, BaseEstimator):
    """One-hot encode selected DataFrame columns, passing the others through.

    The selected columns are replaced by their one-hot encoded counterparts
    (first category of each column dropped); all remaining columns of the
    input frame are kept unchanged. The encoded columns are appended after
    the untouched ones.

    The class name keeps its original spelling so existing imports keep
    working.

    Parameters
    ----------
    columns : str or list of str
        Name(s) of the columns to encode. Keyword-only.

    Attributes
    ----------
    ohe : OneHotEncoder
        The wrapped scikit-learn encoder (``drop="first"``, dense output).
    """

    def __init__(self, *, columns):
        self.columns = columns
        # scikit-learn renamed ``sparse`` to ``sparse_output`` in 1.2 and
        # removed the old keyword in 1.4; try the new spelling first and
        # fall back so older installations keep working.
        try:
            self.ohe = OneHotEncoder(drop="first", sparse_output=False)
        except TypeError:
            self.ohe = OneHotEncoder(drop="first", sparse=False)

    def _column_list(self):
        """Return ``self.columns`` as a list without mutating the parameter."""
        if isinstance(self.columns, str):
            # A bare string would select a Series, which the encoder rejects.
            return [self.columns]
        return list(self.columns)

    def fit(self, X, y=None):
        """Fit the wrapped encoder on the configured columns of ``X``.

        Parameters
        ----------
        X : pandas.DataFrame
            Input frame; must contain every configured column.
        y : ignored
            Present only for scikit-learn pipeline compatibility.

        Returns
        -------
        self
        """
        self.ohe.fit(X[self._column_list()])
        return self

    def transform(self, X):
        """Return ``X`` with the configured columns replaced by dummies.

        Parameters
        ----------
        X : pandas.DataFrame
            Input frame; must contain every configured column.

        Returns
        -------
        pandas.DataFrame
            The non-encoded columns of ``X`` followed by the encoded
            columns, sharing ``X``'s index.
        """
        columns = self._column_list()
        encoded_values = self.ohe.transform(X[columns])
        encoded_df = pd.DataFrame(
            encoded_values,
            columns=self.ohe.get_feature_names_out(),
            # Reuse the input index: ``pd.concat(axis=1)`` aligns on index,
            # so a fresh RangeIndex would misalign rows (and inject NaNs)
            # whenever X has been filtered, shuffled or split.
            index=X.index,
        )
        return pd.concat([X.drop(columns=columns), encoded_df], axis=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment