Created
February 15, 2018 21:29
-
-
Save amueller/1f8d3c03305642ab3bad677d9b443b80 to your computer and use it in GitHub Desktop.
Stand-alone matplotlib based tree plotting from https://github.com/scikit-learn/scikit-learn/pull/9251
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numbers import Integral | |
from sklearn.externals import six | |
from sklearn.tree.export import _color_brew, _criterion, _tree | |
def plot_tree(decision_tree, max_depth=None, feature_names=None,
              class_names=None, label='all', filled=False,
              leaves_parallel=False, impurity=True, node_ids=False,
              proportion=False, rotate=False, rounded=False,
              special_characters=False, precision=3, ax=None, fontsize=None):
    """Plot a decision tree with matplotlib.

    The ``samples`` line of each node shows ``tree_.n_node_samples``, an
    unweighted count of training samples. The ``value`` line is derived from
    ``tree_.value``, which reflects any sample weights used during fitting.

    Parameters
    ----------
    decision_tree : fitted decision tree estimator
        The decision tree to plot; its ``tree_`` attribute is rendered.

    max_depth : int, optional (default=None)
        The maximum depth of the representation. Nodes one level below
        ``max_depth`` are drawn as a grey ``(...)`` placeholder and their
        subtrees are omitted. If None, the tree is fully generated.

    feature_names : list of strings, optional (default=None)
        Names of each of the features. If None, features are shown as
        ``X[i]``.

    class_names : list of strings, bool or None, optional (default=None)
        Names of each of the target classes in ascending numerical order.
        Only relevant for classification and not supported for multi-output.
        If ``True``, shows a symbolic representation of the class name
        (``y[i]``).

    label : {'all', 'root', 'none'}, optional (default='all')
        Whether to show informative labels for impurity, etc.
        Options include 'all' to show at every node, 'root' to show only at
        the top root node, or 'none' to not show at any node.

    filled : bool, optional (default=False)
        When set to ``True``, paint nodes to indicate majority class for
        classification, extremity of values for regression, or purity of node
        for multi-output.

    leaves_parallel : bool, optional (default=False)
        Currently ignored by the matplotlib renderer.

    impurity : bool, optional (default=True)
        When set to ``True``, show the impurity at each node.

    node_ids : bool, optional (default=False)
        When set to ``True``, show the ID number on each node.

    proportion : bool, optional (default=False)
        When set to ``True``, change the display of 'values' and/or 'samples'
        to be proportions and percentages respectively.

    rotate : bool, optional (default=False)
        Currently ignored by the matplotlib renderer; the tree is always
        drawn top-down.

    rounded : bool, optional (default=False)
        When set to ``True``, draw node boxes with rounded corners.

    special_characters : bool, optional (default=False)
        Currently ignored by the matplotlib renderer.

    precision : int, optional (default=3)
        Number of digits of precision for floating point in the values of
        impurity, threshold and value attributes of each node. Must be a
        non-negative integer.

    ax : matplotlib axis, optional (default=None)
        Axes to plot to. If None, use current axis.

    fontsize : int, optional (default=None)
        Size of the text font. If None, the default font size is scaled down
        so that the widest node box is at most about one data unit wide,
        which keeps neighbouring nodes from overlapping.

    Raises
    ------
    ValueError
        If ``precision`` is not a non-negative integer.

    Examples
    --------
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.tree import DecisionTreeClassifier
    >>> iris = load_iris()
    >>> clf = DecisionTreeClassifier().fit(iris.data, iris.target)
    >>> plot_tree(clf)  # doctest: +SKIP
    """
    exporter = _MPLTreeExporter(
        max_depth=max_depth, feature_names=feature_names,
        class_names=class_names, label=label, filled=filled,
        leaves_parallel=leaves_parallel, impurity=impurity, node_ids=node_ids,
        proportion=proportion, rotate=rotate, rounded=rounded,
        special_characters=special_characters, precision=precision,
        fontsize=fontsize)
    exporter.export(decision_tree, ax=ax)
class _BaseTreeExporter(object):
    """Shared helpers for tree exporters: node fill colours and node labels.

    Subclasses must set the attributes read here: ``colors``, ``label``,
    ``characters``, ``node_ids``, ``feature_names``, ``precision``,
    ``impurity``, ``proportion`` and ``class_names``.
    """

    def get_color(self, value):
        """Return the ``#rrggbb`` fill colour for a node value.

        If ``self.colors['bounds']`` is None (classification), the hue is
        the majority class's colour. Its intensity grows with the margin
        between the two largest class fractions in ``value``.

        Otherwise (regression or multi-output), the single colour
        ``self.colors['rgb'][0]`` is used. Its intensity is ``value``
        rescaled linearly from ``self.colors['bounds']`` to [0, 1].
        """
        if self.colors['bounds'] is None:
            # Classification tree
            color = list(self.colors['rgb'][np.argmax(value)])
            sorted_values = sorted(value, reverse=True)
            if len(sorted_values) == 1:
                alpha = 0
            else:
                alpha = ((sorted_values[0] - sorted_values[1])
                         / (1 - sorted_values[1]))
        else:
            # Regression tree or multi-output
            color = list(self.colors['rgb'][0])
            low, high = self.colors['bounds']
            if high == low:
                # Degenerate range (e.g. a multi-output tree whose nodes all
                # share one impurity): avoid 0/0 = NaN and draw white.
                alpha = 0
            else:
                alpha = (value - low) / (high - low)
        # Unpack numpy scalars / size-1 arrays. float() on an array with
        # ndim > 0 is deprecated since NumPy 1.25.
        alpha = np.asarray(alpha, dtype=float).item()
        # Clamp so every blended channel stays within 0..255.
        alpha = min(max(alpha, 0.), 1.)
        # Compute the color as alpha against white.
        color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color]
        # Return html color code in #rrggbb format.
        return '#' + ''.join('%02x' % c for c in color)

    def get_fill_color(self, tree, node_id):
        """Return the fill colour for node ``node_id`` of ``tree``.

        On first use, lazily initialises ``self.colors['rgb']``. For
        multi-output trees, and for regression trees with more than one
        distinct value, it also initialises ``self.colors['bounds']``.
        """
        if 'rgb' not in self.colors:
            # Initialize colors and bounds if required
            self.colors['rgb'] = _color_brew(tree.n_classes[0])
            if tree.n_outputs != 1:
                # Find max and min (negated) impurities for multi-output
                self.colors['bounds'] = (np.min(-tree.impurity),
                                         np.max(-tree.impurity))
            elif (tree.n_classes[0] == 1 and
                  len(np.unique(tree.value)) != 1):
                # Find max and min values over all nodes for regression
                self.colors['bounds'] = (np.min(tree.value),
                                         np.max(tree.value))
        if tree.n_outputs != 1:
            # Multi-output: colour node by impurity (lower impurity gives
            # a more saturated colour).
            node_val = -tree.impurity[node_id]
        elif tree.n_classes[0] == 1:
            # Regression: the node's stored value.
            node_val = tree.value[node_id][0, :]
        else:
            # Classification: class fractions at this node.
            node_val = (tree.value[node_id][0, :] /
                        tree.weighted_n_node_samples[node_id])
        return self.get_color(node_val)

    def node_to_str(self, tree, node_id, criterion):
        """Build the text label for node ``node_id`` of ``tree``.

        ``criterion`` names the impurity line. A ``FriedmanMSE`` object is
        shown as ``friedman_mse``, and any other non-string as ``impurity``.
        Line breaks, index brackets and the split operator are taken from
        ``self.characters``.
        """
        # Generate the node content string
        if tree.n_outputs == 1:
            value = tree.value[node_id][0, :]
        else:
            value = tree.value[node_id]

        # Should labels be shown?
        labels = (self.label == 'root' and node_id == 0) or self.label == 'all'

        characters = self.characters
        node_string = characters[-1]

        # Write node ID
        if self.node_ids:
            if labels:
                node_string += 'node '
            node_string += characters[0] + str(node_id) + characters[4]

        # Write decision criteria
        if tree.children_left[node_id] != _tree.TREE_LEAF:
            # Always write node decision criteria, except for leaves
            if self.feature_names is not None:
                feature = self.feature_names[tree.feature[node_id]]
            else:
                feature = "X%s%s%s" % (characters[1],
                                       tree.feature[node_id],
                                       characters[2])
            node_string += '%s %s %s%s' % (feature,
                                           characters[3],
                                           round(tree.threshold[node_id],
                                                 self.precision),
                                           characters[4])

        # Write impurity
        if self.impurity:
            if isinstance(criterion, _criterion.FriedmanMSE):
                criterion = "friedman_mse"
            elif not isinstance(criterion, six.string_types):
                criterion = "impurity"
            if labels:
                node_string += '%s = ' % criterion
            node_string += (str(round(tree.impurity[node_id], self.precision))
                            + characters[4])

        # Write node sample count (unweighted n_node_samples)
        if labels:
            node_string += 'samples = '
        if self.proportion:
            percent = (100. * tree.n_node_samples[node_id] /
                       float(tree.n_node_samples[0]))
            node_string += (str(round(percent, 1)) + '%' +
                            characters[4])
        else:
            node_string += (str(tree.n_node_samples[node_id]) +
                            characters[4])

        # Write node class distribution / regression value
        if self.proportion and tree.n_classes[0] != 1:
            # For classification this will show the proportion of samples
            value = value / tree.weighted_n_node_samples[node_id]
        if labels:
            node_string += 'value = '
        if tree.n_classes[0] == 1:
            # Regression
            value_text = np.around(value, self.precision)
        elif self.proportion:
            # Classification
            value_text = np.around(value, self.precision)
        elif np.all(np.equal(np.mod(value, 1), 0)):
            # Classification without floating-point weights
            value_text = value.astype(int)
        else:
            # Classification with floating-point weights
            value_text = np.around(value, self.precision)
        # Strip numpy's bytes-repr quoting and whitespace from the array text
        value_text = str(value_text.astype('S32')).replace("b'", "'")
        value_text = value_text.replace("' '", ", ").replace("'", "")
        if tree.n_classes[0] == 1 and tree.n_outputs == 1:
            value_text = value_text.replace("[", "").replace("]", "")
        value_text = value_text.replace("\n ", characters[4])
        node_string += value_text + characters[4]

        # Write node majority class
        if (self.class_names is not None and
                tree.n_classes[0] != 1 and
                tree.n_outputs == 1):
            # Only done for single-output classification trees
            if labels:
                node_string += 'class = '
            if self.class_names is not True:
                class_name = self.class_names[np.argmax(value)]
            else:
                class_name = "y%s%s%s" % (characters[1],
                                          np.argmax(value),
                                          characters[2])
            node_string += class_name

        # Clean up any trailing newlines
        if node_string.endswith(characters[4]):
            node_string = node_string[:-len(characters[4])]

        return node_string + characters[5]
class _MPLTreeExporter(_BaseTreeExporter):
    """Draw a fitted decision tree on a matplotlib axis.

    Each tree node becomes a matplotlib ``Annotation`` whose text comes from
    ``node_to_str``; node positions come from the ``buchheim`` layout. Nodes
    one level deeper than ``max_depth`` are drawn as a grey ``(...)`` marker.
    """

    def __init__(self, max_depth=None, feature_names=None,
                 class_names=None, label='all', filled=False,
                 leaves_parallel=False, impurity=True, node_ids=False,
                 proportion=False, rotate=False, rounded=False,
                 special_characters=False, precision=3, fontsize=None):
        self.max_depth = max_depth
        self.feature_names = feature_names
        self.class_names = class_names
        self.label = label
        self.filled = filled
        # NOTE(review): leaves_parallel, rotate and special_characters are
        # stored but not read by the drawing code in this class.
        self.leaves_parallel = leaves_parallel
        self.impurity = impurity
        self.node_ids = node_ids
        self.proportion = proportion
        self.rotate = rotate
        self.rounded = rounded
        self.special_characters = special_characters
        self.precision = precision
        self.fontsize = fontsize
        # Vertical distance (data units) between consecutive tree depths.
        self._scaley = 10

        # validate
        if isinstance(precision, Integral):
            if precision < 0:
                raise ValueError("'precision' should be greater or equal to 0."
                                 " Got {} instead.".format(precision))
        else:
            raise ValueError("'precision' should be an integer. Got {}"
                             " instead.".format(type(precision)))

        # NOTE(review): not read anywhere in this class.
        self.ranks = {'leaves': []}
        # The colors to render each node with
        self.colors = {'bounds': None}

        # Fixed label characters: node-id prefix, index open/close brackets,
        # split operator, line break, label suffix, label prefix.
        self.characters = ['#', '[', ']', '<=', '\n', '', '']

        # Template for each node's text box; recurse() copies it per node.
        self.bbox_args = dict(fc='w')
        if self.rounded:
            self.bbox_args['boxstyle'] = "round"
        self.arrow_args = dict(arrowstyle="<-")

    def _make_tree(self, node_id, et, criterion='entropy', depth=0):
        """Recursively convert ``et`` (a fitted ``tree_``) into ``Tree`` nodes.

        ``criterion`` is forwarded to ``node_to_str`` for the impurity label.
        When ``max_depth`` is set, nodes deeper than ``max_depth + 1`` are
        not built. ``recurse`` never draws them, so leaving them out keeps
        the layout compact.
        """
        name = self.node_to_str(et, node_id, criterion=criterion)
        # Leaves have children_left == children_right (both the leaf marker).
        is_split = et.children_left[node_id] != et.children_right[node_id]
        within_depth = self.max_depth is None or depth <= self.max_depth
        if not (is_split and within_depth):
            return Tree(name, node_id)
        children = [
            self._make_tree(et.children_left[node_id], et, criterion,
                            depth + 1),
            self._make_tree(et.children_right[node_id], et, criterion,
                            depth + 1),
        ]
        return Tree(name, node_id, *children)

    def export(self, decision_tree, ax=None):
        """Draw ``decision_tree`` on ``ax`` (default: the current axes).

        If ``fontsize`` is None, the font of all annotations is scaled down
        so that the widest node box is at most about one data unit wide.
        """
        import matplotlib.pyplot as plt
        from matplotlib.text import Annotation

        if ax is None:
            ax = plt.gca()
        ax.set_axis_off()
        # Use the estimator's real criterion so the impurity label is
        # correct (e.g. "gini", "mse") instead of always "entropy".
        criterion = getattr(decision_tree, 'criterion', 'impurity')
        my_tree = self._make_tree(0, decision_tree.tree_, criterion)
        draw_tree = buchheim(my_tree)

        self._scalex = 1
        self.recurse(draw_tree, decision_tree.tree_, ax)

        anns = [ann for ann in ax.get_children()
                if isinstance(ann, Annotation)]

        # get all the annotated points
        xys = [ann.xyann for ann in anns]

        mins = np.min(xys, axis=0)
        maxs = np.max(xys, axis=0)

        ax.set_xlim(mins[0], maxs[0])
        # Inverted y-limits put the root (depth 0) at the top.
        ax.set_ylim(maxs[1], mins[1])

        if self.fontsize is None:
            # get figure to data transform
            inv = ax.transData.inverted()
            renderer = ax.figure.canvas.get_renderer()
            # update sizes of all bboxes
            for ann in anns:
                ann.update_bbox_position_size(renderer)
            # get max box width (in data units)
            widths = [inv.get_matrix()[0, 0]
                      * ann.get_bbox_patch().get_window_extent().width
                      for ann in anns]
            # Only ever shrink the font, never enlarge it.
            max_width = max(max(widths), 1)
            # width should be around 1 in data coordinates
            size = anns[0].get_fontsize() / max_width
            for ann in anns:
                ann.set_fontsize(size)

    def recurse(self, node, tree, ax, depth=0):
        """Annotate ``node`` (a laid-out ``DrawTree``) and its subtree on ``ax``.

        Nodes within ``max_depth`` are drawn with their text and an arrow
        from their parent. A node beyond ``max_depth`` is drawn as a grey
        ``(...)`` placeholder and is not descended into.
        """
        # Copy the template: the fill colour is written into this dict below,
        # and mutating the shared self.bbox_args would leak it (e.g. the grey
        # of a truncated node) into every node drawn afterwards.
        kwargs = dict(bbox=self.bbox_args.copy(), ha='center', va='center',
                      zorder=100 - 10 * depth)

        if self.fontsize is not None:
            kwargs['fontsize'] = self.fontsize

        xy = (node.x * self._scalex, node.y * self._scaley)

        if self.max_depth is None or depth <= self.max_depth:
            if self.filled:
                kwargs['bbox']['fc'] = self.get_fill_color(tree,
                                                           node.tree.node_id)
            if node.parent is None:
                # root
                ax.annotate(node.tree.node, xy, **kwargs)
            else:
                xy_parent = (node.parent.x * self._scalex,
                             node.parent.y * self._scaley)
                kwargs["arrowprops"] = self.arrow_args
                ax.annotate(node.tree.node, xy_parent, xy, **kwargs)
            for child in node.children:
                self.recurse(child, tree, ax, depth=depth + 1)

        else:
            xy_parent = (node.parent.x * self._scalex, node.parent.y *
                         self._scaley)
            kwargs["arrowprops"] = self.arrow_args
            kwargs['bbox']['fc'] = 'grey'
            ax.annotate("\n (...) \n", xy_parent, xy, **kwargs)
class DrawTree(object):
    """Layout wrapper around a ``Tree`` node for the Buchheim algorithm.

    Holds the mutable layout state (``x``, ``y``, ``mod``, ``thread``,
    ``ancestor``, ``change``, ``shift``) and mirrors the wrapped tree's
    children as ``DrawTree`` instances. ``number`` is this node's 1-based
    position among its siblings.
    """

    def __init__(self, tree, parent=None, depth=0, number=1):
        self.x = -1.
        self.y = depth
        self.tree = tree
        wrapped = []
        for position, child in enumerate(tree.children, start=1):
            wrapped.append(DrawTree(child, self, depth + 1, position))
        self.children = wrapped
        self.parent = parent
        self.thread = None
        self.mod = 0
        self.ancestor = self
        self.shift = 0
        self.change = self.shift
        self._lmost_sibling = None
        # this is the number of the node in its group of siblings 1..n
        self.number = number

    def left(self):
        """Next node on the left contour: the thread if set, else the first
        child, else 0."""
        if self.thread:
            return self.thread
        return self.children[0] if self.children else 0

    def right(self):
        """Next node on the right contour: the thread if set, else the last
        child, else 0."""
        if self.thread:
            return self.thread
        return self.children[-1] if self.children else 0

    def lbrother(self):
        """Return the sibling immediately to the left, or None."""
        previous = None
        if not self.parent:
            return previous
        for sibling in self.parent.children:
            if sibling == self:
                return previous
            previous = sibling
        return previous

    def get_lmost_sibling(self):
        """Return (and cache) the leftmost sibling, or None for a first
        child or the root."""
        if self._lmost_sibling or not self.parent:
            return self._lmost_sibling
        first = self.parent.children[0]
        if self != first:
            self._lmost_sibling = first
        return self._lmost_sibling
    lmost_sibling = property(get_lmost_sibling)

    def __str__(self):
        return "%s: x=%s mod=%s" % (self.tree, self.x, self.mod)

    def __repr__(self):
        return self.__str__()
def buchheim(tree):
    """Lay out ``tree`` using the Buchheim et al. (2002) tree-drawing algorithm.

    Wraps ``tree`` in ``DrawTree`` nodes and runs the first and second
    walks. If any node ended up with a negative x, everything is shifted
    right so the smallest x is 0.

    Returns the root ``DrawTree``, with ``x`` (horizontal position) and
    ``y`` (depth) set on every node.
    """
    draw_tree = firstwalk(DrawTree(tree))
    # Named min_x rather than min to avoid shadowing the builtin.
    min_x = second_walk(draw_tree)
    if min_x < 0:
        third_walk(draw_tree, -min_x)
    return draw_tree
def third_walk(tree, n):
    """Add ``n`` to the x coordinate of every node in ``tree``."""
    pending = [tree]
    while pending:
        node = pending.pop()
        node.x += n
        pending.extend(node.children)
def firstwalk(v, distance=1.):
    """First (post-order) pass of the Buchheim layout.

    Assigns each node a preliminary x and a ``mod`` offset for its subtree.
    Leaves are placed ``distance`` to the right of their left brother, or
    at 0 if they are the leftmost child. Inner nodes are centred over their
    children; when they have a left brother, they are placed ``distance``
    to its right instead and ``mod`` records the offset.

    Returns ``v``.
    """
    if not v.children:
        if v.lmost_sibling:
            v.x = v.lbrother().x + distance
        else:
            v.x = 0.
        return v

    default_ancestor = v.children[0]
    for w in v.children:
        # Propagate the caller's spacing; previously children always used
        # the default distance.
        firstwalk(w, distance)
        default_ancestor = apportion(w, default_ancestor, distance)
    execute_shifts(v)

    midpoint = (v.children[0].x + v.children[-1].x) / 2

    w = v.lbrother()
    if w:
        v.x = w.x + distance
        v.mod = v.x - midpoint
    else:
        v.x = midpoint
    return v
def apportion(v, default_ancestor, distance):
    """Separate ``v``'s subtree from the subtrees of its left siblings.

    Walks the right contour of the left neighbourhood and the left contour
    of ``v``'s subtree in lockstep. Whenever the two come closer than
    ``distance``, ``v``'s subtree is shifted right via ``move_subtree``.
    Afterwards, the shorter contour is threaded onto the longer one.

    Returns the default ancestor to use for the next sibling.
    """
    left_brother = v.lbrother()
    if left_brother is None:
        return default_ancestor

    # Contour pointers, Buchheim notation: "in"/"out" = inner/outer,
    # "right"/"left" = the + / - side.
    in_right = out_right = v
    in_left = left_brother
    out_left = v.lmost_sibling
    # Running sums of mod values along each contour.
    s_in_right = s_out_right = v.mod
    s_in_left = in_left.mod
    s_out_left = out_left.mod

    while in_left.right() and in_right.left():
        in_left = in_left.right()
        in_right = in_right.left()
        out_left = out_left.left()
        out_right = out_right.right()
        out_right.ancestor = v
        shift = (in_left.x + s_in_left) - (in_right.x + s_in_right) + distance
        if shift > 0:
            move_subtree(ancestor(in_left, v, default_ancestor), v, shift)
            s_in_right += shift
            s_out_right += shift
        s_in_left += in_left.mod
        s_in_right += in_right.mod
        s_out_left += out_left.mod
        s_out_right += out_right.mod

    if in_left.right() and not out_right.right():
        out_right.thread = in_left.right()
        out_right.mod += s_in_left - s_out_right
    else:
        if in_right.left() and not out_left.left():
            out_left.thread = in_right.left()
            out_left.mod += s_in_right - s_out_left
        default_ancestor = v
    return default_ancestor
def move_subtree(wl, wr, shift):
    """Shift the subtree rooted at ``wr`` right by ``shift``.

    ``wr``'s ``x`` and ``mod`` move immediately. The ``change``/``shift``
    values recorded on ``wl`` and ``wr`` are consumed later by
    ``execute_shifts``, which spreads the move over the siblings between
    them (their sibling-number difference).
    """
    per_gap = shift / (wr.number - wl.number)
    wr.change -= per_gap
    wr.shift += shift
    wl.change += per_gap
    wr.x += shift
    wr.mod += shift
def execute_shifts(v):
    """Apply the shifts recorded by ``move_subtree`` to ``v``'s children.

    Children are visited right to left. Each child's ``x`` and ``mod`` are
    advanced by the running shift, which accumulates the children's
    ``shift`` and ``change`` values.
    """
    total_shift = 0
    total_change = 0
    for child in reversed(v.children):
        child.x += total_shift
        child.mod += total_shift
        total_change += child.change
        total_shift += child.shift + total_change
def ancestor(vil, v, default_ancestor):
    """Return ``vil.ancestor`` if it is a sibling of ``v``, else
    ``default_ancestor``.

    See the bottom of page 7 of "Improving Walker's Algorithm to Run in
    Linear Time" by Buchheim et al. (2002):
    http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.16.8757&rep=rep1&type=pdf
    """
    candidate = vil.ancestor
    siblings = v.parent.children
    return candidate if candidate in siblings else default_ancestor
def second_walk(v, m=0, depth=0, min=None):
    """Second (pre-order) pass of the Buchheim layout.

    Adds each node's accumulated ancestor ``mod`` sum (starting from ``m``)
    to its ``x``, and sets ``y`` to its depth (starting from ``depth``).
    Returns the smallest resulting ``x``, also compared against the incoming
    ``min`` when it is given.
    """
    smallest = min
    # Explicit pre-order stack of (node, offset, level); children are pushed
    # in reverse so they are visited left to right.
    pending = [(v, m, depth)]
    while pending:
        node, offset, level = pending.pop()
        node.x += offset
        node.y = level
        if smallest is None or node.x < smallest:
            smallest = node.x
        child_offset = offset + node.mod
        for child in reversed(node.children):
            pending.append((child, child_offset, level + 1))
    return smallest
class Tree(object):
    """Minimal n-ary tree used as input to the ``buchheim`` layout.

    ``node`` is the label text, ``width`` its length, ``node_id`` the id of
    the corresponding node in the fitted tree, and ``children`` either the
    tuple of child ``Tree`` objects passed in or an empty list.
    """

    def __init__(self, node="", node_id=-1, *children):
        self.node = node
        self.width = len(node)
        self.node_id = node_id
        self.children = children if children else []
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment