Skip to content

Instantly share code, notes, and snippets.

@lmassaron
Created September 3, 2019 08:31
Show Gist options
  • Select an option

  • Save lmassaron/f4c00689ba2bab53c1fd7b5b63730a34 to your computer and use it in GitHub Desktop.

Select an option

Save lmassaron/f4c00689ba2bab53c1fd7b5b63730a34 to your computer and use it in GitHub Desktop.
ClassifierTransformer
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone, is_classifier
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
from sklearn.model_selection import check_cv
class ClassifierTransformer(BaseEstimator, TransformerMixin):
    """Classifier's estimates of a regression problem using oof.

    The continuous target ``y`` is binned into ``n_classes`` ordinal labels
    (see ``_get_labels``), and one clone of ``estimator`` is fitted per
    cross-validation fold on those labels. ``transform`` returns, for every
    row, the class probabilities and predicted label from the single fold
    model assigned to that row. On the training data, that is the model
    that did not see the row (out-of-fold).

    Parameters
    ----------
    estimator : classifier
        Unfitted scikit-learn classifier exposing ``predict_proba``; it is
        cloned once per fold.
    n_classes : int, default=2
        Number of bins the target is split into.
    cv : int, cross-validation generator or iterable, default=3
        Forwarded to ``sklearn.model_selection.check_cv``.
    """

    def __init__(self, estimator=None, n_classes=2, cv=3):
        self.estimator = estimator
        self.n_classes = n_classes
        self.cv = cv

    @staticmethod
    def _clean(X):
        """Return a fresh ndarray copy of ``X`` with NaN and +/-inf set to 0."""
        if hasattr(X, "fillna"):
            # pandas input: same column-wise cleaning as before.
            X = X.replace([np.inf, -np.inf], np.nan).fillna(0)
        # np.array copies, so the caller's array is never modified below.
        X = np.array(X)
        if np.issubdtype(X.dtype, np.floating):
            # ndarray input: apply the same NaN/inf -> 0 rule.
            X[~np.isfinite(X)] = 0
        return X

    def _get_labels(self, y):
        """Bin a continuous target into ``n_classes`` ordinal labels.

        The sorted distinct values of ``y`` are cut into ``n_classes``
        consecutive runs of ``len(unique) // n_classes`` values each. The
        last run also absorbs the remainder. Every sample gets the index of
        the run its value belongs to. ``y`` is assumed to contain no NaN.

        Returns a float array of labels in ``0 .. n_classes - 1``.

        Raises ValueError if ``y`` has fewer distinct values than
        ``n_classes``; otherwise every sample would collapse into one bin.
        """
        y = np.asarray(y)
        y_unique = np.unique(y)  # np.unique returns values sorted ascending
        step = len(y_unique) // self.n_classes
        if step == 0:
            raise ValueError(
                f"y has {len(y_unique)} distinct values, fewer than "
                f"n_classes={self.n_classes}"
            )
        # Lower edges of bins 1 .. n_classes-1. A sample's label is the
        # number of these edges that are <= its value.
        edges = y_unique[step * np.arange(1, self.n_classes)]
        return np.searchsorted(edges, y, side="right").astype(float)

    def fit(self, X, y):
        """Fit one clone of ``estimator`` per CV fold on the binned target.

        Sets ``estimators_`` (one fitted model per fold), ``test_folds_``
        (each fold's held-out row indices, in the same order) and
        ``n_samples_fit_`` (number of training rows). Returns ``self``.
        """
        X = self._clean(X)
        y_labels = self._get_labels(y)
        cv = check_cv(self.cv, y_labels, classifier=is_classifier(self.estimator))
        self.estimators_ = []
        self.test_folds_ = []
        # Record the partition now. Calling split() again later can give
        # different folds (another splitter type, or a shuffling splitter),
        # which would score training rows with models that saw them.
        for train, test in cv.split(X, y_labels):
            self.estimators_.append(
                clone(self.estimator).fit(X[train], y_labels[train])
            )
            self.test_folds_.append(test)
        self.n_samples_fit_ = X.shape[0]
        return self

    def transform(self, X, y=None):
        """Return per-row class probabilities plus the predicted bin label.

        The result has shape ``(n_samples, n_classes + 1)``. The first
        ``n_classes`` columns hold ``predict_proba`` for each bin; the last
        column holds ``predict``. Each row is scored by exactly one fold
        model.

        If ``X`` has as many rows as the data given to ``fit`` (for example
        in ``fit_transform``), the held-out folds recorded by ``fit`` are
        reused, so each training row is scored by the model that never saw
        it. For new data of that size this is simply another partition.
        Otherwise ``X`` is split with a fresh splitter built from ``self.cv``
        and ``y``.
        """
        X = self._clean(X)
        if X.shape[0] == self.n_samples_fit_:
            folds = self.test_folds_
        else:
            cv = check_cv(self.cv, y, classifier=is_classifier(self.estimator))
            folds = [test for _, test in cv.split(X, y)]

        X_prob = np.zeros((X.shape[0], self.n_classes))
        X_pred = np.zeros(X.shape[0])
        for estimator, test in zip(self.estimators_, folds):
            # A fold model may have seen only some of the bins. Place its
            # probability columns by class label, not by position.
            classes = getattr(estimator, "classes_", None)
            if classes is None:
                columns = np.arange(self.n_classes)
            else:
                columns = np.asarray(classes).astype(int)
            X_prob[np.ix_(test, columns)] = estimator.predict_proba(X[test])
            X_pred[test] = estimator.predict(X[test])
        return np.column_stack([X_prob, X_pred])
# Example configurations, both using 5 cross-validation folds and a
# RandomForestClassifier with default hyperparameters:
#   clf  -> target binned into 5 classes
#   clf1 -> target binned into 2 classes
clf = ClassifierTransformer(estimator=RandomForestClassifier(), n_classes=5, cv=5)
clf1 = ClassifierTransformer(estimator=RandomForestClassifier(), n_classes=2, cv=5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment