Visualization of scikit-learn decision trees. Hover with the mouse over the paths to see the decision rules and a histogram of the classes that remain at a certain node in the tree. Check out this repository to see how to use it in Python and also directly in a Jupyter notebook.
Last active
January 15, 2018 23:43
-
-
Save MMesch/d34839e664aa5ccde11879e9d3c9cb68 to your computer and use it in GitHub Desktop.
Scikit-learn Decision Tree
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<meta charset="utf-8">
<style>
  /* Fixed-size container; plot_tree() reads its bounding box to size the SVG. */
  .chart {
    width: 800px;
    height: 600px;
    padding: 10px;
    display: block;
  }
</style>
<script src="https://d3js.org/d3.v4.min.js"></script>
<script src="tree.js"></script>
<body>
  <div id="chart1" class="chart"></div>
  <script>
    var chart1 = document.getElementById("chart1");
    // D3 v4 invokes a two-argument callback as (error, data). Handling the
    // error explicitly avoids a TypeError on `data` when the file is missing
    // or is not valid JSON.
    d3.json("iris_tree.json", function(error, data) {
      if (error) {
        console.error("Could not load iris_tree.json:", error);
        return;
      }
      console.log(data['name']);
      plot_tree(d3, data, chart1);
    });
  </script>
</body>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"children": [ | |
{ | |
"impurity": 0.0, | |
"name": "node1", | |
"rule": "petal_width < 0.8 ", | |
"values": [ | |
50.0, | |
0.0, | |
0.0 | |
] | |
}, | |
{ | |
"children": [ | |
{ | |
"children": [ | |
{ | |
"children": [ | |
{ | |
"impurity": 0.0, | |
"name": "node5", | |
"rule": "petal_width < 1.7 ", | |
"values": [ | |
0.0, | |
47.0, | |
0.0 | |
] | |
}, | |
{ | |
"impurity": 0.0, | |
"name": "node6", | |
"rule": "petal_width > 1.7 ", | |
"values": [ | |
0.0, | |
0.0, | |
1.0 | |
] | |
} | |
], | |
"impurity": 0.04079861111111116, | |
"name": "node4", | |
"rule": "petal_length < 4.9 ", | |
"values": [ | |
0.0, | |
47.0, | |
1.0 | |
] | |
}, | |
{ | |
"children": [ | |
{ | |
"impurity": 0.0, | |
"name": "node8", | |
"rule": "petal_width < 1.5 ", | |
"values": [ | |
0.0, | |
0.0, | |
3.0 | |
] | |
}, | |
{ | |
"children": [ | |
{ | |
"impurity": 0.0, | |
"name": "node10", | |
"rule": "petal_length < 5.4 ", | |
"values": [ | |
0.0, | |
2.0, | |
0.0 | |
] | |
}, | |
{ | |
"impurity": 0.0, | |
"name": "node11", | |
"rule": "petal_length > 5.4 ", | |
"values": [ | |
0.0, | |
0.0, | |
1.0 | |
] | |
} | |
], | |
"impurity": 0.0, | |
"name": "node9", | |
"rule": "petal_width > 1.5 ", | |
"values": [ | |
0.0, | |
2.0, | |
1.0 | |
] | |
} | |
], | |
"impurity": 0.04079861111111116, | |
"name": "node7", | |
"rule": "petal_length > 4.9 ", | |
"values": [ | |
0.0, | |
2.0, | |
4.0 | |
] | |
} | |
], | |
"impurity": 0.16803840877914955, | |
"name": "node3", | |
"rule": "petal_width < 1.8 ", | |
"values": [ | |
0.0, | |
49.0, | |
5.0 | |
] | |
}, | |
{ | |
"children": [ | |
{ | |
"children": [ | |
{ | |
"impurity": 0.0, | |
"name": "node14", | |
"rule": "sepal_length < 5.9 ", | |
"values": [ | |
0.0, | |
1.0, | |
0.0 | |
] | |
}, | |
{ | |
"impurity": 0.0, | |
"name": "node15", | |
"rule": "sepal_length > 5.9 ", | |
"values": [ | |
0.0, | |
0.0, | |
2.0 | |
] | |
} | |
], | |
"impurity": 0.4444444444444444, | |
"name": "node13", | |
"rule": "petal_length < 4.9 ", | |
"values": [ | |
0.0, | |
1.0, | |
2.0 | |
] | |
}, | |
{ | |
"impurity": 0.4444444444444444, | |
"name": "node16", | |
"rule": "petal_length > 4.9 ", | |
"values": [ | |
0.0, | |
0.0, | |
43.0 | |
] | |
} | |
], | |
"impurity": 0.16803840877914955, | |
"name": "node12", | |
"rule": "petal_width > 1.8 ", | |
"values": [ | |
0.0, | |
1.0, | |
45.0 | |
] | |
} | |
], | |
"impurity": 0.0, | |
"name": "node2", | |
"rule": "petal_width > 0.8 ", | |
"values": [ | |
0.0, | |
50.0, | |
50.0 | |
] | |
} | |
], | |
"class_names": [ | |
"setosa", | |
"versicolor", | |
"virginica" | |
], | |
"name": "node0", | |
"rule": "root", | |
"values": [ | |
50.0, | |
50.0, | |
50.0 | |
] | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Draw an interactive decision-tree diagram inside `chart`.
 *
 * Every non-root node is drawn as a curved path to its parent. The path
 * colour encodes the majority class and its purity; the stroke width encodes
 * the share of all samples that reach the node. Hovering a path shows the
 * rules of the node and of all its ancestors, plus a per-class histogram for
 * the hovered node. A legend with the root class counts is drawn at the
 * top-left corner.
 *
 * @param {object} d3 - The D3 namespace; the code relies on `select`,
 *     `hierarchy`, `tree`, `scaleLinear` and `hcl` (the demo page loads v4).
 * @param {object} data - Root of the tree. Each node is read for `rule`
 *     (string) and `values` (array of per-class sample counts); nested nodes
 *     are taken from `children` by `d3.hierarchy`. The root additionally
 *     provides `class_names` (array of strings) for the legend.
 * @param {Element} chart - DOM element that receives the SVG; its bounding
 *     box determines the drawing size.
 */
function plot_tree(d3, data, chart) {
  // ---- layout ------------------------------------------------------------
  const positionInfo = chart.getBoundingClientRect();
  const margin = {top: 20, right: 20, bottom: 30, left: 40};
  const svg_width = positionInfo.width - margin.left - margin.right;
  const svg_height = positionInfo.height - margin.top - margin.bottom;

  const svg = d3.select(chart).append("svg")
      .attr("width", svg_width + margin.left + margin.right)
      .attr("height", svg_height + margin.top + margin.bottom)
    .append("g")
      .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

  // d3.tree() assigns x/y to every node of the hierarchy and returns its root.
  const root = d3.tree().size([svg_width, svg_height])(d3.hierarchy(data));

  // Read the class count from the root so that a tree consisting of a single
  // node (no children) does not throw.
  const n_classes = root.data.values.length;
  const n_samples = root.data.values.reduce(add, 0);
  const hue_scale = d3.scaleLinear().domain([0, n_classes]).range([0, 360]);

  // ---- small helpers -----------------------------------------------------
  function add(a, b) {
    return a + b;
  }

  // Fraction of the samples that belong to the majority class.
  function purity(values) {
    return Math.max(...values) / values.reduce(add, 0);
  }

  // Index of the first maximum in `arr`, or -1 for an empty array.
  function indexOfMax(arr) {
    if (arr.length === 0) {
      return -1;
    }
    let maxIndex = 0;
    for (let i = 1; i < arr.length; i++) {
      if (arr[i] > arr[maxIndex]) {
        maxIndex = i;
      }
    }
    return maxIndex;
  }

  // HCL colour for a class distribution: hue selects the majority class,
  // chroma grows and luminance drops as the node gets purer.
  function get_link_color(values) {
    const p = purity(values);
    const hue = (hue_scale(indexOfMax(values)) + 30) % 360;
    return d3.hcl(hue, 100 * p, 120 * (1 - p / 2));
  }

  // Store each element's bounding box on its datum as `d.bbox`.
  // Needs a regular function: `this` is the DOM element.
  function getBB(selection) {
    selection.each(function(d) { d.bbox = this.getBBox(); });
  }

  // Colour of a perfectly pure node of class `i`. The one-hot vector is sized
  // by n_classes rather than hard-coded to three entries.
  function legend_colors(i) {
    const values = new Array(n_classes).fill(0);
    values[i] = 1;
    return get_link_color(values);
  }

  // Stroke width proportional to the share of all samples in `values`;
  // the root share (all samples) maps to 20.
  function path_width(values) {
    const max_path_width = 20;
    return values.reduce(add, 0) * max_path_width / n_samples;
  }

  // ---- tree paths and arrowheads -----------------------------------------
  // Returns the selection of <g> elements, one per non-root node.
  function draw_tree(tree_root) {
    const graph = svg.append('g').attr('id', 'graph');
    const link = graph.selectAll('.g')
        .data(tree_root.descendants().slice(1).reverse())
      .enter().append('g');

    link.append("path")
        .attr("class", "link")
        .attr("d", function(d) {
          // Cubic Bezier from the node up to its parent, with both control
          // points at the vertical midpoint.
          const mid_y = (d.y + d.parent.y) / 2;
          return `M${d.x},${d.y}`
              + `C${d.x},${mid_y}`
              + ` ${d.parent.x},${mid_y}`
              + ` ${d.parent.x},${d.parent.y}`;
        })
        .attr('stroke-linecap', 'round')
        .attr('style', function(d) {
          return 'fill:None;stroke:' + get_link_color(d.data.values)
              + ';stroke-width:' + path_width(d.data.values) + ';';
        });

    // Arrowheads only on paths wide enough for them to be visible.
    link.filter(function(d) { return path_width(d.data.values) > 1; })
      .append("path")
        .attr('class', 'arrowhead')
        .attr('stroke-linecap', 'round')
        .attr("d", function(d) {
          // Local binding; the original assigned to an undeclared `width`.
          const width = path_width(d.data.values);
          return `M${d.x - width / 1.5},${d.y}`
              + `L${d.x},${d.y + width}`
              + `L${d.x + width / 1.5},${d.y}`;
        })
        .style("fill", function(d) { return get_link_color(d.data.values); })
        .style('stroke', 'white')
        .style('stroke-width', function(d) {
          return path_width(d.data.values) / 5;
        });

    return link;
  }

  // ---- rule labels and histogram -----------------------------------------
  // Append a rule label plus a translucent white backdrop to every group in
  // `selection`. `rect_height` is a number or a function of the datum.
  function add_rule_label(selection, rect_height) {
    selection.append("text")
        .attr("dy", ".35em")
        .attr("y", -10)
        .style("text-anchor", "middle")
        // Font size must be set before measuring, otherwise the backdrop is
        // sized for a different font size than the one rendered.
        .attr('font-size', '12')
        .attr('pointer-events', 'none')
        .text(function(d) { return d.data.rule; })
        .call(getBB);

    selection.insert("rect", "text")
        .attr("x", function(d) { return -d.bbox.width / 2; })
        .attr("y", function(d) { return -d.bbox.height; })
        .attr("width", function(d) { return d.bbox.width; })
        .attr("height", rect_height)
        .style("fill", "white")
        .attr('pointer-events', 'none')
        .style('opacity', 0.8);
  }

  // Show the rules for `nodes`: the hovered node first, followed by its
  // ancestors up to the root (the order returned by node.ancestors()).
  function draw_rules(nodes) {
    const node = svg.selectAll(".rule").data(nodes)
      .enter().append("g")
        .attr('class', 'rule')
        .attr('transform', function(d) {
          return "translate(" + d.x + "," + d.y + ")";
        });

    const ancestor_nodes = node.filter(function(d, i) { return i > 0; });
    add_rule_label(ancestor_nodes, function(d) { return d.bbox.height; });

    // The hovered node gets a taller backdrop so it also covers the histogram.
    const hovered_node = node.filter(function(d, i) { return i === 0; });
    add_rule_label(hovered_node, 50);

    const hovered_values = nodes[0].data.values;
    const total_width = 100;
    const bar_width = total_width / n_classes;
    const bar_height = 20;
    const max_value = Math.max(...hovered_values);

    hovered_node.append('g').selectAll('.rect')
        .data(hovered_values)
      .enter().append('rect')
        .attr('x', function(d, i) { return -total_width / 2 + i * bar_width; })
        .attr('y', function(d) {
          return 10 + bar_height - d * bar_height / max_value;
        })
        .attr('width', 0.9 * bar_width)
        .attr('height', function(d) { return d * bar_height / max_value; })
        .attr('fill', function(d, i) { return legend_colors(i); });
  }

  // ---- draw tree and add the mouse events --------------------------------
  draw_tree(root)
      .on("mouseover", function(d) { draw_rules(d.ancestors()); })
      .on("mouseout", function() { svg.selectAll('.rule.rule').remove(); });

  // ---- legend ------------------------------------------------------------
  const rect_width = 20;
  const root_values = root.data.values;
  const max_value = Math.max(...root_values);
  const row_y = function(d, i) { return i * (1.1 * rect_width); };
  // Bar length is scaled so that the largest class spans 2 * rect_width.
  const bar_length = function(i) {
    return root_values[i] * 2 * rect_width / max_value;
  };

  const legend = svg.append('g').attr('id', 'legend')
      .selectAll('.rect')
      .data(root.data.class_names)
    .enter();

  legend.append('rect')
      .attr('x', 0)
      .attr('y', row_y)
      .attr('width', function(d, i) { return bar_length(i); })
      .attr('height', rect_width)
      .attr('fill', function(d, i) { return legend_colors(i); });

  // Sample count, drawn on top of the bar.
  legend.insert('text')
      .attr('x', 2)
      .attr('y', row_y)
      .attr('dy', rect_width / 2 + 4)
      .text(function(d, i) { return root_values[i]; })
      .style('fill', 'black')
      .attr('font-size', '12')
      // 'left' is not a valid text-anchor value; 'start' is the valid
      // equivalent (and the default browsers fell back to).
      .style('text-anchor', "start");

  // Class name, drawn just right of the bar.
  legend.insert('text')
      .attr('x', function(d, i) { return 5 + bar_length(i); })
      .attr('y', row_y)
      .attr('dy', rect_width / 2 + 4)
      .text(function(d) { return d; })
      .style('fill', 'black')
      .attr('font-size', '12')
      .style('text-anchor', "start");
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment