Created August 1, 2018 10:36
Gist twolodzko/e64d18d2b322800b0676b6e116c4ab85
Prune sklearn decision tree
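```python
# source: https://github.com/scikit-learn/scikit-learn/issues/10810#issuecomment-373164104
import copy


def prune(tree):
    """Collapse internal nodes whose two children are leaves predicting the same class.

    Intended for fitted classifiers such as DecisionTreeClassifier. Pruned-away
    nodes stay in the arrays but become unreachable; node_count is not updated.
    """
    tree = copy.deepcopy(tree)  # leave the caller's estimator untouched
    dat = tree.tree_
    nodes = range(0, dat.node_count)
    # Views on the tree's internal node arrays: writing to them edits the tree.
    ls = dat.children_left
    rs = dat.children_right
    # Predicted class index per node and per output (argmax over dat.value).
    classes = [[list(e).index(max(e)) for e in v] for v in dat.value]

    # A node is a leaf when both child pointers are TREE_LEAF (-1).
    leaves = [(ls[i] == rs[i]) for i in nodes]

    LEAF = -1
    # Children have larger indices than their parents, so a reverse scan
    # visits every subtree bottom-up and lets merges cascade upward.
    for i in reversed(nodes):
        if leaves[i]:
            continue
        if leaves[ls[i]] and leaves[rs[i]] and classes[ls[i]] == classes[rs[i]]:
            ls[i] = rs[i] = LEAF  # turn node i into a leaf
            leaves[i] = True
    return tree
```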
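The reverse scan in `prune()` is easiest to follow on a tree small enough to draw. The sketch below is not part of the original gist: it reimplements the same bottom-up rule on plain Python lists, using a hypothetical helper `prune_arrays` in place of the `tree_.children_left`, `tree_.children_right` and `tree_.value` arrays (single output only), and a hypothetical `reachable_leaves` helper to list the leaves that remain. As in sklearn's node arrays, every child is assumed to have a larger index than its parent.

```python
# Hypothetical, sklearn-free sketch of the pruning rule used by prune().
LEAF = -1  # sklearn marks a missing child with TREE_LEAF == -1


def prune_arrays(children_left, children_right, value):
    """Return pruned copies of the child arrays (single-output trees only).

    value[i] holds the per-class weights of node i. Assumes every child has
    a larger index than its parent, so a reverse scan is bottom-up.
    """
    ls = list(children_left)
    rs = list(children_right)
    classes = [v.index(max(v)) for v in value]  # predicted class per node
    leaves = [ls[i] == rs[i] for i in range(len(ls))]
    for i in reversed(range(len(ls))):
        if leaves[i]:
            continue
        if leaves[ls[i]] and leaves[rs[i]] and classes[ls[i]] == classes[rs[i]]:
            ls[i] = rs[i] = LEAF
            leaves[i] = True  # lets the parent be merged in turn
    return ls, rs


def reachable_leaves(ls, rs, node=0):
    """List the leaves reachable from `node`, left to right."""
    if ls[node] == LEAF:
        return [node]
    return reachable_leaves(ls, rs, ls[node]) + reachable_leaves(ls, rs, rs[node])


#         0
#       /   \
#      1     4
#     / \   / \
#    2   3 5   6
children_left = [1, 2, LEAF, LEAF, 5, LEAF, LEAF]
children_right = [4, 3, LEAF, LEAF, 6, LEAF, LEAF]
value = [[13, 10], [8, 3], [5, 1], [3, 2], [5, 7], [4, 1], [1, 6]]

ls, rs = prune_arrays(children_left, children_right, value)
print(reachable_leaves(children_left, children_right))  # [2, 3, 5, 6]
print(reachable_leaves(ls, rs))  # [1, 5, 6]
```

Node 1 collapses because both of its leaves predict class 0. Node 4 keeps its split because its leaves disagree, and the root therefore stays split as well. In these toy values a parent's counts are the sum of its children's, so a merged node predicts the class its children did. The same reasoning suggests that `prune()` should leave a classifier's `predict` output unchanged, while `predict_proba` can change for samples that reach a merged node.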