Decision Tree Implementation: ID3, C4.5-style, and Gini selection measures with optional pre-/post-pruning, depth/node limits, and Graphviz visualization
from queue import Queue

import graphviz as gv
import numpy as np
from sklearn import datasets
# sklearn.cross_validation has been removed from modern scikit-learn;
# train_test_split now lives in sklearn.model_selection
from sklearn.model_selection import train_test_split

class DecisionTree:
    # Decision tree
    tree = {}
    # Data sets
    x = []
    y = []
    cv_x = []
    cv_y = []
    # Parameters
    selector = None
    pruning = 'off'
    max_depth = 0
    max_node = 0
    # The index counter of nodes in the graph
    node_id = 0

    # Constructor
    def __init__(self, selector='id3', pruning='off', max_depth=0, max_node=0):
        # Check parameters
        if selector == 'id3':
            self.selector = self.__id3__
        elif selector == 'c4.5':
            self.selector = self.__c45__
        elif selector == 'gini':
            self.selector = self.__gini__
        else:
            raise ValueError("No such selection measure '%s'" % selector)
        if pruning not in ['off', 'pre', 'post']:
            raise ValueError("No such pruning strategy '%s'" % pruning)
        # Assign parameters
        self.pruning = pruning
        self.max_depth = max_depth
        self.max_node = max_node

    # Build the decision tree from the training data
    def fit(self, x, y, cv_x=np.array([]), cv_y=np.array([])):
        assert len(x) > 0
        assert len(x[0]) > 0
        assert len(x) == len(y)
        assert len(cv_x) == len(cv_y)
        assert len(cv_x) == 0 or len(x[0]) == len(cv_x[0])
        # Analyze attributes: floats are treated as continuous, strings as discrete
        attributes = []
        num_attr = len(x.transpose())
        for i in range(0, num_attr):
            value = x[0][i]
            if type(value) is float or type(value) is np.float64:
                attributes.append((i, 'continuous'))
            elif type(value) is str or type(value) is np.str_:
                attributes.append((i, 'discrete'))
            else:
                raise NotImplementedError("Unsupported value type '%s'" % type(value))
        # Save data sets
        self.x = x
        self.y = y
        self.cv_x = cv_x
        self.cv_y = cv_y
        # Generate the decision tree breadth-first; node_limit is the
        # remaining node budget (max_node minus the root)
        queue = Queue()
        super_node = {}
        node_limit = self.max_node - 1
        queue.put((super_node, 'root', np.arange(0, len(x)), np.arange(0, len(cv_x)), attributes, 0))
        while not queue.empty():
            parent_node, node_key, rows, cv_rows, attributes, depth = queue.get()
            node_limit -= self.__build__(parent_node, node_key, rows, cv_rows, attributes, depth, node_limit, queue)
        self.tree = super_node['root']
        # Post-pruning: prune after the tree has been built
        if self.pruning == 'post':
            self.tree = self.__post_pruning__(self.tree, np.arange(0, len(cv_x)))

    # Predict with the decision tree
    def predict(self, x):
        node = self.tree
        while node['type'] != 'leaf':
            index = node['index']
            value = node['value']
            if value == 'discrete':
                node = node['children'][x[index]]
            elif value == 'continuous':
                if x[index] <= node['divider']:
                    node = node['less_equal']
                else:
                    node = node['greater']
        return node['tag']

    # Visualize the decision tree with graphviz
    def visualize(self, attr_names):
        graph = gv.Digraph(format='svg')
        self.node_id = 0
        self.__draw__(graph, self.tree, attr_names)
        return graph

    def __build__(self, parent_node, node_key, rows, cv_rows, attributes, depth, node_limit, queue):
        x = self.x[rows]
        y = self.y[rows]
        cv_x = self.cv_x[cv_rows]
        cv_y = self.cv_y[cv_rows]
        # All samples in the data set share the same class
        if len(np.unique(y)) == 1:
            parent_node[node_key] = self.__build_leaf__(np.unique(y)[0])
            return 0
        # Maximum depth reached, or all samples share the same attribute values
        if depth == self.max_depth > 0 or self.__same__(x, [w[0] for w in attributes]):
            parent_node[node_key] = self.__build_leaf__(self.__most__(y))
            return 0
        # Find the best attribute
        best_split_rows = []
        best_va = 0
        best_attr = None
        best_divider = None
        best_dividers = []
        for attr in attributes:
            split_rows = []
            va = 0
            divider = 0
            dividers = []
            if attr[1] == 'continuous':
                va, split_rows, divider = self.__continuous_divide__(rows, attr[0])
            elif attr[1] == 'discrete':
                va, split_rows, dividers = self.__discrete_divide__(rows, attr[0])
            if va > best_va:
                best_va = va
                best_attr = attr
                best_split_rows = split_rows
                if attr[1] == 'continuous':
                    best_divider = divider
                elif attr[1] == 'discrete':
                    best_dividers = dividers
        # No split improves the selection measure
        if best_va == 0:
            parent_node[node_key] = self.__build_leaf__(self.__most__(y))
            return 0
        # Pre-pruning: keep the split only if it improves accuracy on the
        # cross-validation set
        column_index, column_type = best_attr
        split_cv_rows = []
        if self.pruning == 'pre' and len(cv_rows) > 0:
            correct_not_split = np.sum(np.equal(cv_y, self.__most__(y)))
            correct_split = 0
            if best_attr[1] == 'discrete':
                # Split the cv data set
                for div in best_dividers:
                    group = np.where(np.equal(cv_x[:, column_index], div))[0]
                    split_cv_rows.append(cv_rows[group])
                for i in range(0, len(split_cv_rows)):
                    correct_split += np.sum(np.equal(self.cv_y[split_cv_rows[i]], self.__most__(self.y[best_split_rows[i]])))
            elif best_attr[1] == 'continuous':
                # Split the cv data set
                low_rows = cv_rows[np.where(np.less_equal(cv_x[:, column_index], best_divider))]
                high_rows = cv_rows[np.where(np.greater(cv_x[:, column_index], best_divider))]
                split_cv_rows = [low_rows, high_rows]
                correct_split = np.sum(np.equal(self.cv_y[low_rows], self.__most__(self.y[best_split_rows[0]]))) + \
                    np.sum(np.equal(self.cv_y[high_rows], self.__most__(self.y[best_split_rows[1]])))
            if correct_split <= correct_not_split:
                parent_node[node_key] = self.__build_leaf__(self.__most__(y))
                return 0
        else:
            # Create empty cv data sets if pre-pruning is not applied
            if best_attr[1] == 'discrete':
                split_cv_rows = np.empty(len(best_dividers), np.ndarray)
            elif best_attr[1] == 'continuous':
                split_cv_rows = np.empty(2, np.ndarray)
        # Check the node budget
        if self.max_node > 0 and ((best_attr[1] == 'discrete' and len(best_dividers) > node_limit) or
                                  (best_attr[1] == 'continuous' and 2 > node_limit)):
            parent_node[node_key] = self.__build_leaf__(self.__most__(y))
            return 0
        # Generate subtrees
        parent_node[node_key] = {'type': 'internal', 'index': column_index, 'value': column_type}
        if self.pruning == 'post':
            parent_node[node_key]['most'] = self.__most__(y)
        if best_attr[1] == 'discrete':
            parent_node[node_key]['children'] = {}
            # A discrete attribute is consumed by the split; pass a copy so
            # sibling subtrees that share this list are not affected
            sub_attributes = [attr for attr in attributes if attr != best_attr]
            for i in range(0, len(best_split_rows)):
                div = best_dividers[i]
                queue.put((parent_node[node_key]['children'], div, best_split_rows[i], split_cv_rows[i], sub_attributes, depth + 1))
            return len(best_split_rows)
        elif best_attr[1] == 'continuous':
            parent_node[node_key]['divider'] = best_divider
            queue.put((parent_node[node_key], 'less_equal', best_split_rows[0], split_cv_rows[0], attributes, depth + 1))
            queue.put((parent_node[node_key], 'greater', best_split_rows[1], split_cv_rows[1], attributes, depth + 1))
            return 2

    def __discrete_divide__(self, rows, column_index):
        # Retrieve the training data
        x = self.x[rows]
        y = self.y[rows]
        # Get all values of the attribute
        column = x[:, column_index]
        values = np.unique(column)
        total = len(x)
        # Split and calculate the value of the selection measure
        base_va = self.selector(y)
        va = self.selector(y, base_va)
        split_rows = []
        for value in values:
            group = np.where(column == value)[0]
            sub_y = y[group]
            va -= len(group) / total * self.selector(sub_y, base_va)
            split_rows.append(rows[group])
        return va, split_rows, values

    def __continuous_divide__(self, rows, column_index):
        # Retrieve the training data
        x = self.x[rows]
        y = self.y[rows]
        # Candidate dividers are midpoints between consecutive sorted values
        column = x[:, column_index]
        sorted_column = np.sort(column)
        dividers = (sorted_column[0:-1] + sorted_column[1:]) / 2
        total = len(x)
        # Find the best divider
        best_div = 0
        base_va = self.selector(y)
        min_va = base_va
        split_rows = []
        for div in dividers:
            # Split the training data
            low_group = np.where(column <= div)[0]
            high_group = np.where(column > div)[0]
            # Calculate the value of the selection measure
            low_count = len(low_group)
            high_count = len(high_group)
            low_y = y[low_group]
            high_y = y[high_group]
            temp_va = low_count / total * self.selector(low_y, base_va) + high_count / total * self.selector(high_y, base_va)
            # Update the best divider
            if temp_va < min_va:
                min_va = temp_va
                best_div = div
                split_rows = [rows[low_group], rows[high_group]]
        return base_va - min_va, split_rows, best_div
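
    # ------------------------------------------------------------------
    # Selection measures used by __build__ (all facts below are readable
    # from the code itself):
    #   id3  -- information entropy Ent(D) = -sum_k p_k * log2(p_k);
    #           __build__ then maximizes the information gain
    #           Gain(D, a) = Ent(D) - sum_v |D_v| / |D| * Ent(D_v).
    #   c4.5 -- entropy normalized by the parent entropy (the 'base'
    #           argument), so the score becomes Gain(D, a) / Ent(D).
    #           Note this differs from textbook C4.5, which divides the
    #           gain by the split information IV(a) of the attribute.
    #   gini -- Gini impurity Gini(D) = 1 - sum_k p_k^2; the score is
    #           the impurity reduction achieved by the split.
    # ------------------------------------------------------------------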

    # Information entropy (ID3)
    @staticmethod
    def __id3__(y, base=1):
        ent = 0
        tags = np.unique(y)
        total = len(y)
        for tag in tags:
            p = np.sum(y == tag) / total
            ent -= p * np.log2(p)
        return ent

    # Entropy normalized by the parent entropy (C4.5 variant)
    @staticmethod
    def __c45__(y, base=1):
        ent = 0
        tags = np.unique(y)
        total = len(y)
        for tag in tags:
            p = np.sum(y == tag) / total
            ent -= p * np.log2(p)
        return ent / base

    # Gini impurity
    @staticmethod
    def __gini__(y, base=1):
        gini = 1
        tags = np.unique(y)
        total = len(y)
        for tag in tags:
            p = np.sum(y == tag) / total
            gini -= p * p
        return gini

    @staticmethod
    def __build_leaf__(tag):
        return {'type': 'leaf', 'tag': tag}

    # The most frequent value in an array
    @staticmethod
    def __most__(arr):
        (values, counts) = np.unique(arr, return_counts=True)
        ind = np.argmax(counts)
        return values[ind]

    # Whether all rows share the same values on the given columns
    @staticmethod
    def __same__(x, column_indices):
        for index in column_indices:
            if len(np.unique(x[:, index])) > 1:
                return False
        return True

    def __draw__(self, g, node, header):
        if node['type'] == 'internal':
            parent = header[node['index']] + ' [' + str(self.node_id) + ']'
            self.node_id += 1
            g.node(parent)
            if node['value'] == 'discrete':
                for value, sub in node['children'].items():
                    child = self.__draw__(g, sub, header)
                    g.edge(parent, child, value)
            elif node['value'] == 'continuous':
                left = self.__draw__(g, node['less_equal'], header)
                right = self.__draw__(g, node['greater'], header)
                divider = node['divider']
                g.edge(str(parent), str(left), '<=%f' % divider)
                g.edge(str(parent), str(right), '>%f' % divider)
            return parent
        elif node['type'] == 'leaf':
            tag = str(node['tag']) + ' [' + str(self.node_id) + ']'
            self.node_id += 1
            g.node(tag)
            return tag

    def __post_pruning__(self, tree, cv_rows):
        # Pruning only applies to internal nodes
        if tree['type'] == 'leaf':
            return tree
        # Calculate the cv accuracy when the node is not split
        correct_not_split = np.sum(np.equal(self.cv_y[cv_rows], tree['most']))
        correct_split = 0
        pruned = True
        index = tree['index']
        if tree['value'] == 'discrete':  # Discrete attribute
            children = tree['children']
            # Split the cv data set
            for key in children.keys():
                child = children[key]
                sub_cv_rows = cv_rows[np.where(np.equal(self.cv_x[cv_rows, index], key))]
                sub_cv_y = self.cv_y[sub_cv_rows]
                tree['children'][key] = self.__post_pruning__(child, sub_cv_rows)
                if tree['children'][key]['type'] == 'leaf':
                    correct_split += np.sum(np.equal(sub_cv_y, tree['children'][key]['tag']))
                else:
                    pruned = False
        elif tree['value'] == 'continuous':  # Continuous attribute
            low_child = tree['less_equal']
            high_child = tree['greater']
            # Split the cv data set
            divider = tree['divider']
            low_rows = cv_rows[np.where(self.cv_x[cv_rows, index] <= divider)]
            high_rows = cv_rows[np.where(self.cv_x[cv_rows, index] > divider)]
            tree['less_equal'] = self.__post_pruning__(low_child, low_rows)
            tree['greater'] = self.__post_pruning__(high_child, high_rows)
            if tree['less_equal']['type'] == 'leaf' and tree['greater']['type'] == 'leaf':
                correct_split = np.sum(np.equal(self.cv_y[low_rows], tree['less_equal']['tag'])) \
                    + np.sum(np.equal(self.cv_y[high_rows], tree['greater']['tag']))
            else:
                pruned = False
        # Prune when all children are leaves and accuracy is no worse without the split
        if pruned and correct_not_split >= correct_split:
            return self.__build_leaf__(tree['most'])
        return tree
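

# ----------------------------------------------------------------------
# Usage sketch: a minimal example of driving the class above, assuming
# scikit-learn's bundled iris data set (continuous features only, which
# matches the float/np.float64 detection in fit). Names such as model,
# x_train, and the split parameters are illustrative choices.
if __name__ == '__main__':
    iris = datasets.load_iris()
    x_train, x_test, y_train, y_test = train_test_split(
        iris.data, iris.target, test_size=0.3, random_state=42)
    model = DecisionTree(selector='gini', pruning='off')
    model.fit(x_train, y_train)
    predictions = [model.predict(sample) for sample in x_test]
    print('Test accuracy: %.3f' % np.mean(np.equal(predictions, y_test)))
    # Pre-/post-pruning additionally need a held-out validation split
    # passed to fit as cv_x and cv_y.
    # Optional: render the tree (requires the Graphviz system binaries)
    # model.visualize(iris.feature_names).render('iris_tree')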