@lesteve (last active January 11, 2017)
from copy import copy

import numpy as np

from sklearn.base import clone
from sklearn.utils import check_random_state
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.externals import joblib


def combine(all_ensembles):
    """Combine the sub-estimators of several fitted ensembles into a single one.

    The result is a shallow copy of the first ensemble whose ``estimators_``
    list is the concatenation of the sub-estimators of all the ensembles.
    """
    final_ensemble = copy(all_ensembles[0])
    final_ensemble.estimators_ = []

    for ensemble in all_ensembles:
        final_ensemble.estimators_ += ensemble.estimators_

    return final_ensemble


def train_model(model, X, y, sample_weight=None, random_state=None):
    """Fit ``model`` on (X, y), optionally with sample weights and a fixed seed."""
    model.set_params(random_state=random_state)
    if sample_weight is not None:
        model.fit(X, y, sample_weight=sample_weight)
    else:
        model.fit(X, y)
    return model


def grow_ensemble(base_model, X, y, sample_weight=None, n_estimators=1,
                  n_jobs=1, random_state=None):
    """Fit ``n_estimators`` clones of ``base_model`` in parallel and combine them."""
    random_state = check_random_state(random_state)
    max_seed = np.iinfo('uint32').max
    # Draw an independent seed for each clone so that the fitted models differ.
    random_states = random_state.randint(max_seed + 1, size=n_estimators)

    results = joblib.Parallel(n_jobs=n_jobs)(
        joblib.delayed(train_model)(
            clone(base_model), X, y,
            sample_weight=sample_weight, random_state=rs)
        for rs in random_states)

    return combine(results)


if __name__ == '__main__':
    from sklearn.datasets import load_digits

    digits = load_digits()
    X_train, X_test, y_train, y_test = train_test_split(
        digits.data, digits.target, random_state=0)

    final_model = grow_ensemble(RandomForestClassifier(), X_train, y_train,
                                n_estimators=10, n_jobs=2, random_state=42)

    print("number of trees: {}".format(len(final_model.estimators_)))
    score = final_model.score(X_test, y_test)
    print("score: {:.3f}".format(score))