Created February 16, 2021 09:16
Save j-adamczyk/dc82f7b54d49f81cb48ac87329dba95e to your computer and use it in GitHub Desktop.
Plotting decision trees with Graphviz using disk operations
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import tempfile
from typing import List

import graphviz
import matplotlib.image as plt_img
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz
def plot_disk_operations(clf: DecisionTreeClassifier,
                         feature_names: List[str],
                         class_names: List[str]) -> np.ndarray:
    """Render a decision tree with Graphviz, going through files on disk.

    scikit-learn writes the tree as a DOT file. The ``graphviz`` package
    reads it back and renders it to PNG, and matplotlib loads the PNG as an
    array. All intermediate files are created in a private temporary
    directory, which is removed when the function returns or raises.

    Args:
        clf: decision tree classifier to plot; passed to ``export_graphviz``.
        feature_names: feature names passed to ``export_graphviz``.
        class_names: class names passed to ``export_graphviz``.

    Returns:
        The rendered tree image, as returned by ``matplotlib.image.imread``.
    """
    # A private temporary directory avoids overwriting files in the current
    # working directory and collisions between concurrent calls. It is
    # deleted together with everything inside it, even if a step below raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dot_path = os.path.join(tmp_dir, "decision_tree.dot")

        # 1st disk operation: write DOT
        export_graphviz(clf, out_file=dot_path,
                        feature_names=feature_names,
                        class_names=class_names,
                        label="all", filled=True, impurity=False,
                        proportion=True, rounded=True, precision=2)

        # 2nd disk operation: read DOT
        graph = graphviz.Source.from_file(dot_path)

        # 3rd disk operation: write image. render() also saves the DOT source
        # under the bare name "decision_tree" next to the PNG. Rendering into
        # tmp_dir ensures that extra file is cleaned up as well.
        png_path = graph.render("decision_tree", directory=tmp_dir,
                                format="png")

        # 4th disk operation: read image
        image = plt_img.imread(png_path)

    # Remaining disk operations: leaving the `with` block deletes tmp_dir
    # and all files written above.
    return image
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.