@twolodzko
Created August 1, 2018 10:36
Prune sklearn decision tree
# source: https://github.com/scikit-learn/scikit-learn/issues/10810#issuecomment-373164104
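import copy


def prune(tree):
    """Return a copy of a fitted tree with redundant splits collapsed.

    A split is redundant when both of its children are leaves that
    predict the same class; such a split is turned into a leaf.
    """
    tree = copy.deepcopy(tree)  # leave the caller's estimator untouched
    dat = tree.tree_
    nodes = range(0, dat.node_count)
    ls = dat.children_left
    rs = dat.children_right
    # Predicted class per node and per output: argmax over the class values.
    classes = [[list(e).index(max(e)) for e in v] for v in dat.value]
    # A leaf has children_left == children_right (both are -1).
    leaves = [(ls[i] == rs[i]) for i in nodes]
    LEAF = -1
    # Relies on children being stored after their parent, so walking the
    # nodes in reverse handles children first and merges can cascade upward.
    for i in reversed(nodes):
        if leaves[i]:
            continue
        if leaves[ls[i]] and leaves[rs[i]] and classes[ls[i]] == classes[rs[i]]:
            # Detach both children; node i now acts as a leaf.
            ls[i] = rs[i] = LEAF
            leaves[i] = True
    return tree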
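
To see what the pruning does without fitting a model, here is a minimal sketch that runs the same function on a hand-built toy tree. The `toy` object is a hypothetical stand-in made from `SimpleNamespace` and plain lists, not a scikit-learn estimator; it only mimics the four `tree_` attributes the function reads (`node_count`, `children_left`, `children_right`, `value`). The `reachable` helper is also an illustration added here, not part of the original snippet. `prune` is repeated so the block runs on its own.

```python
import copy
from types import SimpleNamespace


def prune(tree):
    # Same logic as the snippet above, repeated so this example is standalone.
    tree = copy.deepcopy(tree)
    dat = tree.tree_
    nodes = range(0, dat.node_count)
    ls = dat.children_left
    rs = dat.children_right
    classes = [[list(e).index(max(e)) for e in v] for v in dat.value]
    leaves = [(ls[i] == rs[i]) for i in nodes]
    LEAF = -1
    for i in reversed(nodes):
        if leaves[i]:
            continue
        if leaves[ls[i]] and leaves[rs[i]] and classes[ls[i]] == classes[rs[i]]:
            ls[i] = rs[i] = LEAF
            leaves[i] = True
    return tree


def reachable(t, i=0):
    """Hypothetical helper: list the node ids still reachable from the root."""
    if t.children_left[i] == t.children_right[i]:
        return [i]
    return [i] + reachable(t, t.children_left[i]) + reachable(t, t.children_right[i])


# Toy tree (assumed layout, 1 output, 2 classes):
#   node 0 splits into node 1 and leaf 4 (class 1)
#   node 1 splits into leaves 2 and 3, which both predict class 0
toy = SimpleNamespace(tree_=SimpleNamespace(
    node_count=5,
    children_left=[1, 2, -1, -1, -1],
    children_right=[4, 3, -1, -1, -1],
    value=[[[3, 2]], [[3, 0]], [[2, 0]], [[1, 0]], [[0, 2]]],
))

pruned = prune(toy)

print(pruned.tree_.children_left)   # [1, -1, -1, -1, -1]
print(pruned.tree_.children_right)  # [4, -1, -1, -1, -1]
print(reachable(toy.tree_))         # [0, 1, 2, 3, 4]
print(reachable(pruned.tree_))      # [0, 1, 4]
```

Node 1 is collapsed because both of its children predict class 0, while the root keeps its split because its children disagree. Note that `node_count` stays at 5: the function only detaches nodes 2 and 3, it does not remove them from the arrays.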