Last active
January 11, 2017 15:18
-
-
Save lesteve/89754ebae39d441bc4edee4ca788a6dd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import copy | |
import numpy as np | |
from sklearn.base import clone | |
from sklearn.utils import check_random_state | |
from sklearn.model_selection import train_test_split | |
from sklearn.ensemble import RandomForestClassifier | |
from sklearn.externals import joblib | |
def combine(all_ensembles):
    """Merge several fitted ensembles into one ensemble.

    Parameters
    ----------
    all_ensembles : iterable of fitted ensemble estimators
        Each item must expose an ``estimators_`` list. The first item is
        used as the template for the returned object.

    Returns
    -------
    final_ensemble
        A shallow copy of the first ensemble whose ``estimators_`` is the
        concatenation, in order, of every input's ``estimators_``.

    Raises
    ------
    IndexError
        If ``all_ensembles`` is empty.
    """
    # Materialize once so generators (not only lists) are accepted.
    all_ensembles = list(all_ensembles)

    # Shallow copy: rebinding attributes below leaves the first input
    # ensemble untouched.
    final_ensemble = copy(all_ensembles[0])
    final_ensemble.estimators_ = [
        estimator
        for ensemble in all_ensembles
        for estimator in ensemble.estimators_
    ]

    # Keep the advertised size consistent with the merged list; otherwise
    # the copy would still report the first ensemble's n_estimators.
    if hasattr(final_ensemble, "n_estimators"):
        final_ensemble.n_estimators = len(final_ensemble.estimators_)
    return final_ensemble
def train_model(model, X, y, sample_weight=None, random_state=None):
    """Seed ``model`` with ``random_state``, fit it on ``(X, y)`` and return it.

    ``sample_weight`` is forwarded to ``model.fit`` only when it is not
    ``None``; otherwise ``fit`` is called with ``X`` and ``y`` alone.
    The model is modified in place and the same object is returned.
    """
    model.set_params(random_state=random_state)

    # Only pass the keyword when weights were actually supplied.
    fit_kwargs = {}
    if sample_weight is not None:
        fit_kwargs["sample_weight"] = sample_weight

    model.fit(X, y, **fit_kwargs)
    return model
def grow_ensemble(base_model, X, y, sample_weight=None, n_estimators=1,
                  n_jobs=1, random_state=None):
    """Fit several independently seeded clones of ``base_model`` and merge them.

    Parameters
    ----------
    base_model : estimator
        Template estimator; it is cloned for every fit and never fitted
        itself.
    X, y
        Training data, passed unchanged to each clone's ``fit``.
    sample_weight : optional
        Forwarded to ``train_model`` for each clone.
    n_estimators : int, default 1
        Number of clones of ``base_model`` to train (one seed is drawn per
        clone). Note this counts fitted clones, not the sub-estimators
        inside each clone.
    n_jobs : int, default 1
        Passed to ``joblib.Parallel``.
    random_state : int, RandomState instance or None
        Source of the per-clone seeds, normalized via ``check_random_state``.

    Returns
    -------
    The result of ``combine`` applied to the fitted clones.
    """
    rng = check_random_state(random_state)

    # Seeds span the full uint32 range [0, 2**32 - 1] (``high`` is exclusive).
    # The dtype is requested explicitly because randint's default integer
    # type is 32-bit on some platforms (e.g. Windows with NumPy < 2), where
    # an upper bound of 2**32 raises "high is out of bounds".
    max_seed = np.iinfo('uint32').max
    seeds = rng.randint(max_seed + 1, size=n_estimators, dtype=np.int64)

    # Each task gets its own fresh clone so fits do not share state.
    tasks = (
        joblib.delayed(train_model)(
            clone(base_model), X, y,
            sample_weight=sample_weight, random_state=seed)
        for seed in seeds
    )
    fitted_models = joblib.Parallel(n_jobs=n_jobs)(tasks)
    return combine(fitted_models)
if __name__ == '__main__':
    # Demo: grow a combined random forest on the digits dataset and report
    # its size and held-out accuracy.
    from sklearn.datasets import load_digits

    dataset = load_digits()
    features, labels = dataset.data, dataset.target
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, random_state=0)

    forest = grow_ensemble(
        RandomForestClassifier(),
        X_train,
        y_train,
        n_estimators=10,
        n_jobs=2,
        random_state=42,
    )

    tree_count = len(forest.estimators_)
    print("number of trees: {}".format(tree_count))

    accuracy = forest.score(X_test, y_test)
    print("score: {:.3f}".format(accuracy))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
cc @ogrisel