Created
August 21, 2018 15:02
-
-
Save aronwc/0ade41235e2470271e4bb994ac410b3d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Show that max_depth changes how fine-grained RandomForest predict_proba values are.
#
# This can look like a floating-point precision difference, but it comes from how
# the probabilities are built: each tree reports the fraction of training samples
# of a class in the leaf a sample lands in, and the forest averages those
# fractions over its trees.
#   * With no max_depth the trees are grown until their leaves are pure, so every
#     tree votes 0 or 1 and the mean over 10 trees is a multiple of 0.1.
#   * With max_depth=5 the leaves stay impure, so each tree contributes an
#     arbitrary fraction and the averaged values look like "noisy" floats.
from collections import Counter

from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=1000, n_features=20,
                           n_informative=5, n_redundant=10,
                           random_state=42)

# The forests are seeded as well as the data; otherwise bootstrap sampling and
# feature sub-sampling make every run print different counts.
print('with no max depth:')
clf = RandomForestClassifier(n_estimators=10, random_state=42)
# .tolist() turns the NumPy scalars into plain Python floats so the Counter keys
# print as 0.9 rather than np.float64(0.9) on newer NumPy versions.
print(Counter(clf.fit(X, y).predict_proba(X)[:, 1].tolist()).most_common(10))

print('\nwith max depth=5:')
clf = RandomForestClassifier(n_estimators=10, max_depth=5, random_state=42)
print(Counter(clf.fit(X, y).predict_proba(X)[:, 1].tolist()).most_common(10))

# Example output, recorded from a run made before the forests were seeded; the
# exact values and counts of a seeded run will differ, the pattern will not:
# with no max depth:
# [(1.0, 370), (0.0, 359), (0.9, 93), (0.1, 89), (0.2, 32), (0.8, 24), (0.3, 12), (0.7, 8), (0.6, 6), (0.4, 4)]
#
# with max depth=5:
# [(0.035039564760949896, 48), (0.9961472577676984, 37), (0.01221724785227863, 32), (0.9961229150218367, 19), (0.013133914518945296, 18), (0.992163639003737, 17), (0.9980459919449135, 10), (0.025767534109608903, 10), (0.30845687643280967, 9), (0.01815332091163936, 8)]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment