Skip to content

Instantly share code, notes, and snippets.

@WillKoehrsen
Created August 18, 2018 23:16
Show Gist options
  • Save WillKoehrsen/3c82f2227f7e1364bff3eace9c688b7f to your computer and use it in GitHub Desktop.
Save WillKoehrsen/3c82f2227f7e1364bff3eace9c688b7f to your computer and use it in GitHub Desktop.
# Train two random forests on the iris dataset -- one depth-limited, one
# unrestricted -- and pull the same-index tree out of each so the two trees
# can be exported and visualized side by side.
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

# Both forests share one seed. This makes the run reproducible, and because
# scikit-learn draws each tree's seed from the forest's random_state, tree
# TREE_INDEX in both forests is fit on the same bootstrap sample. The two
# plots then differ by the depth limit rather than by unrelated resampling.
RANDOM_STATE = 42
N_ESTIMATORS = 10
TREE_INDEX = 5  # which fitted tree to visualize; must be < N_ESTIMATORS

iris = load_iris()

# Limit max depth
model = RandomForestClassifier(max_depth=3, n_estimators=N_ESTIMATORS,
                               random_state=RANDOM_STATE)
model.fit(iris.data, iris.target)
# Extract single tree
estimator_limited = model.estimators_[TREE_INDEX]

# No max depth
model = RandomForestClassifier(max_depth=None, n_estimators=N_ESTIMATORS,
                               random_state=RANDOM_STATE)
model.fit(iris.data, iris.target)
estimator_nonlimited = model.estimators_[TREE_INDEX]
# Write each extracted tree to its own Graphviz .dot file. Both exports use
# the same options: iris feature/class names as labels, rounded + filled
# node styling, proportion=False, and precision=2 for the printed numbers.
from sklearn.tree import export_graphviz

graphviz_options = dict(
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    rounded=True,
    proportion=False,
    precision=2,
    filled=True,
)
export_graphviz(estimator_limited, out_file='tree_limited.dot',
                **graphviz_options)
export_graphviz(estimator_nonlimited, out_file='tree_nonlimited.dot',
                **graphviz_options)
# Render each Graphviz .dot file to a 600-dpi PNG with the `dot` CLI, then
# show it inline.
# - subprocess.run replaces the IPython-only `!dot ...` magic, so this section
#   is also valid plain Python.
# - check=True raises CalledProcessError (or FileNotFoundError if `dot` is not
#   installed) instead of silently leaving a stale or missing PNG behind.
import subprocess

from IPython.display import Image, display

for dot_file, png_file in (('tree_limited.dot', 'tree_limited.png'),
                           ('tree_nonlimited.dot', 'tree_nonlimited.png')):
    subprocess.run(['dot', '-Tpng', dot_file, '-o', png_file, '-Gdpi=600'],
                   check=True)
    # Explicit display(): only a cell's final bare expression is rendered
    # automatically, so without it every image but the last would be dropped.
    display(Image(filename=png_file))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment