Skip to content

Instantly share code, notes, and snippets.

@SS1031
Last active August 11, 2020 19:41
Show Gist options
  • Save SS1031/ca4adb16dd13dcd2ab5859a332b779fb to your computer and use it in GitHub Desktop.
Save SS1031/ca4adb16dd13dcd2ab5859a332b779fb to your computer and use it in GitHub Desktop.
GBDT-LogisticRegression
"""
GBDT + LogisticRegression
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import OneHotEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.pipeline import Pipeline
class GBDTTransformer(TransformerMixin, BaseEstimator):
    """Encode samples as one-hot vectors of the GBDT leaves they land in.

    A ``GradientBoostingClassifier`` is fitted on ``(X, y)``. Each sample is
    then described by the index of the leaf it reaches in every tree, and
    those leaf indices are one-hot encoded. The resulting sparse matrix is
    intended as input to a linear model (the "GBDT + LR" feature trick).

    Parameters
    ----------
    n_estimators : int, default=100
        Number of boosting stages passed to ``GradientBoostingClassifier``.
    random_state : int, RandomState instance or None, default=None
        Seed forwarded to ``GradientBoostingClassifier``.

    Attributes (set by ``fit``)
    ---------------------------
    gbdt : GradientBoostingClassifier
        The fitted boosting model whose leaves define the features.
    gbdt_ohe : OneHotEncoder
        Encoder fitted on the training-set leaf indices.
    """

    def __init__(self, n_estimators=100, random_state=None):
        # scikit-learn convention: __init__ only stores hyper-parameters so
        # BaseEstimator's get_params/set_params/clone stay consistent; the
        # actual models are built in fit().
        self.n_estimators = n_estimators
        self.random_state = random_state

    def _leaves(self, X):
        """Return leaf indices of X as a 2-D array (n_samples, n_trees_total)."""
        leaves = self.gbdt.apply(X)
        # apply() returns (n_samples, n_estimators, K) with K == 1 for binary
        # targets and K == n_classes otherwise. Flatten the tree axes so every
        # tree contributes features (for binary this equals [:, :, 0]).
        return leaves.reshape(leaves.shape[0], -1)

    def fit(self, X, y):
        """Fit the GBDT on (X, y), then fit the one-hot encoder on its leaves.

        Returns
        -------
        self
        """
        self.gbdt = GradientBoostingClassifier(
            n_estimators=self.n_estimators,
            random_state=self.random_state,
        )
        self.gbdt.fit(X, y)
        # 'ignore' encodes a leaf index not seen during fit as all zeros
        # instead of raising at transform time.
        self.gbdt_ohe = OneHotEncoder(handle_unknown='ignore')
        self.gbdt_ohe.fit(self._leaves(X))
        return self

    def transform(self, X):
        """Return the one-hot encoded leaf indices of X (sparse matrix)."""
        return self.gbdt_ohe.transform(self._leaves(X))
if __name__ == '__main__':
    # Compare plain LR, plain GBDT and GBDT-leaf-features + LR on the
    # two-moons toy problem, drawing one decision-surface panel per model.
    X, y = make_moons(n_samples=1000, noise=0.3, random_state=0)
    # Fixed seed so the split (and therefore the plots and scores) is
    # reproducible across runs.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    classifiers = [
        ('LogisticRegression', LogisticRegression()),
        ('GBDT', GradientBoostingClassifier()),
        ('GBDT + LR', Pipeline([('gbdt', GBDTTransformer()),
                                ('lr', LogisticRegression())])),
    ]

    h = .02  # mesh step
    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    # The mesh is identical for every model, so build it once.
    grid = np.c_[xx.ravel(), yy.ravel()]

    cm = plt.cm.RdBu
    cm_bright = ListedColormap(['#FF0000', '#0000FF'])

    # One axis per classifier: drawing every surface on the same implicit
    # axes would make later contours paint over earlier ones.
    fig, axes = plt.subplots(1, len(classifiers),
                             figsize=(5 * len(classifiers), 4))
    for ax, (name, clf) in zip(axes, classifiers):
        clf.fit(X_train, y_train)
        Z = clf.predict_proba(grid)[:, 1].reshape(xx.shape)
        ax.contourf(xx, yy, Z, cmap=cm, alpha=.8)
        ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright,
                   alpha=0.6)
        ax.set_xlim(xx.min(), xx.max())
        ax.set_ylim(yy.min(), yy.max())
        ax.set_title('{} (test acc = {:.3f})'.format(
            name, clf.score(X_test, y_test)))
    fig.tight_layout()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment