Last active
November 14, 2017 01:44
-
-
Save zhenghaoz/894f380992be4f0df08740042d22b557 to your computer and use it in GitHub Desktop.
Decision Tree Implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import deque
from queue import Queue

import graphviz as gv
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
class DecisionTree:
    """Decision tree classifier over continuous (float) and discrete (string) attributes.

    After :meth:`fit`, ``self.tree`` holds the tree as nested dicts:

    * leaf: ``{'type': 'leaf', 'tag': label}``
    * internal: ``{'type': 'internal', 'index': column, 'value': kind, 'most': label, ...}``
      where ``most`` is the majority training label at that node and ``kind`` is
      ``'discrete'`` (plus ``'children'``: attribute value -> node) or
      ``'continuous'`` (plus ``'divider'``, ``'less_equal'`` and ``'greater'``).
    """

    def __init__(self, selector='id3', pruning='off', max_depth=0, max_node=0):
        """Configure the tree.

        :param selector: split measure: ``'id3'`` (information gain), ``'c4.5'``
            (gain ratio) or ``'gini'`` (decrease of Gini impurity).
        :param pruning: ``'off'``, ``'pre'`` or ``'post'``; both pruning modes use
            the validation set passed to :meth:`fit`.
        :param max_depth: depth at which nodes are forced to be leaves (the root
            has depth 0); ``0`` or less means unlimited.
        :param max_node: maximum number of nodes, root included; ``0`` or less
            means unlimited.
        :raises ValueError: on an unknown ``selector`` or ``pruning`` value.
        """
        measures = {'id3': self.__id3__, 'c4.5': self.__c45__, 'gini': self.__gini__}
        if selector not in measures:
            raise ValueError("No such selection measure '%s'" % selector)
        if pruning not in ('off', 'pre', 'post'):
            raise ValueError("No such pruning strategy '%s'" % pruning)
        # Impurity function: a split's gain is the parent's impurity minus the
        # size-weighted impurity of its groups.
        self.selector = measures[selector]
        # C4.5 ranks splits by gain ratio (gain / split information) instead of gain.
        self.gain_ratio = selector == 'c4.5'
        self.pruning = pruning
        self.max_depth = max_depth
        self.max_node = max_node
        # Fitted tree and the data sets it was built from (filled in by fit).
        self.tree = {}
        self.x = []
        self.y = []
        self.cv_x = []
        self.cv_y = []
        # Counter that gives graphviz nodes unique labels (reset by visualize).
        self.node_id = 0

    def fit(self, x, y, cv_x=None, cv_y=None):
        """Build the decision tree from training data.

        Attribute kinds are inferred from the first training row: float values
        make a column continuous, string values make it discrete.

        :param x: 2-D numpy array of training samples, one row per sample.
        :param y: numpy array of class labels aligned with ``x``.
        :param cv_x: optional validation samples (same columns as ``x``) used by
            pre-/post-pruning; post-pruning is skipped when none are given.
        :param cv_y: validation labels aligned with ``cv_x``.
        :raises NotImplementedError: if a column's first value is neither a float
            nor a string.
        """
        if cv_x is None:
            cv_x = np.array([])
        if cv_y is None:
            cv_y = np.array([])
        assert len(x) > 0
        assert len(x[0]) > 0
        assert len(x) == len(y)
        assert len(cv_x) == len(cv_y)
        assert len(cv_x) == 0 or len(x[0]) == len(cv_x[0])
        # Classify every column as continuous or discrete
        num_attr = x.shape[1]
        attributes = []
        for i in range(num_attr):
            value = x[0][i]
            if isinstance(value, (float, np.floating)):
                attributes.append((i, 'continuous'))
            elif isinstance(value, str):  # also covers np.str_
                attributes.append((i, 'discrete'))
            else:
                raise NotImplementedError("Unsupported value type '%s'" % type(value))
        if len(cv_x) == 0:
            # Give empty validation data the training columns so that column
            # indexing on it stays valid everywhere.
            cv_x = np.empty((0, num_attr), dtype=x.dtype)
        self.x = x
        self.y = y
        self.cv_x = cv_x
        self.cv_y = cv_y
        # Grow the tree breadth-first; each work item names the slot
        # (parent dict + key) where the new node will be stored.
        super_node = {}
        node_limit = self.max_node - 1  # the root itself uses one node
        queue = deque([(super_node, 'root', np.arange(len(x)), np.arange(len(cv_x)), attributes, 0)])
        while queue:
            parent_node, node_key, rows, cv_rows, node_attributes, depth = queue.popleft()
            node_limit -= self.__build__(parent_node, node_key, rows, cv_rows,
                                         node_attributes, depth, node_limit, queue)
        self.tree = super_node['root']
        # Post-pruning needs validation data; without any it would prune every
        # split (0 correct either way), so it is skipped.
        if self.pruning == 'post' and len(cv_x) > 0:
            self.tree = self.__post_pruning__(self.tree, np.arange(len(cv_x)))

    def predict(self, x):
        """Return the predicted class label of the single sample ``x``.

        A discrete value never seen at a node during training is answered with
        that node's majority training label.
        """
        node = self.tree
        while node['type'] != 'leaf':
            index = node['index']
            if node['value'] == 'discrete':
                children = node['children']
                key = x[index]
                if key not in children and 'most' in node:
                    return node['most']
                node = children[key]
            else:
                if x[index] <= node['divider']:
                    node = node['less_equal']
                else:
                    node = node['greater']
        return node['tag']

    def visualize(self, attr_names):
        """Return a graphviz ``Digraph`` (SVG format) drawing the fitted tree.

        :param attr_names: attribute names indexed by column position.
        """
        graph = gv.Digraph(format='svg')
        self.node_id = 0
        self.__draw__(graph, self.tree, attr_names)
        return graph

    def __build__(self, parent_node, node_key, rows, cv_rows, attributes, depth, node_limit, queue):
        """Grow the node stored at ``parent_node[node_key]``.

        Stores either a leaf, or an internal node whose children are appended
        to ``queue`` as work items.

        :param rows: indices into ``self.x`` / ``self.y`` reaching this node.
        :param cv_rows: indices into ``self.cv_x`` / ``self.cv_y`` reaching this node.
        :param attributes: ``(column, kind)`` pairs still usable on this path.
        :param depth: depth of this node (root is 0).
        :param node_limit: nodes still available under ``max_node``.
        :returns: number of children scheduled (0 when a leaf was stored).
        """
        x = self.x[rows]
        y = self.y[rows]
        cv_x = self.cv_x[cv_rows]
        cv_y = self.cv_y[cv_rows]
        majority = self.__most__(y)
        # Stop on a pure node, at the depth limit, or when no remaining
        # attribute distinguishes the rows.
        if (len(np.unique(y)) == 1
                or 0 < self.max_depth <= depth
                or self.__same__(x, [column for column, _ in attributes])):
            parent_node[node_key] = self.__build_leaf__(majority)
            return 0
        # Find the attribute with the best positive score
        best_va, best_attr, best_split_rows, best_key = 0, None, [], None
        for attr in attributes:
            if attr[1] == 'continuous':
                va, split_rows, key = self.__continuous_divide__(rows, attr[0])
            else:
                va, split_rows, key = self.__discrete_divide__(rows, attr[0])
            if va > best_va:
                best_va, best_attr, best_split_rows, best_key = va, attr, split_rows, key
        # No split improves the measure: make this node a leaf
        if best_attr is None:
            parent_node[node_key] = self.__build_leaf__(majority)
            return 0
        column_index, column_type = best_attr
        # Route validation rows to the children exactly as predict would
        cv_column = cv_x[:, column_index]
        if column_type == 'discrete':
            split_cv_rows = [cv_rows[np.where(cv_column == value)[0]] for value in best_key]
            # Values unseen in training get this node's majority label (see predict)
            fallback_cv_rows = np.setdiff1d(cv_rows, np.concatenate(split_cv_rows))
        else:
            split_cv_rows = [cv_rows[np.where(cv_column <= best_key)[0]],
                             cv_rows[np.where(cv_column > best_key)[0]]]
            fallback_cv_rows = cv_rows[:0]
        # Pre-pruning: keep the split only if it beats a leaf on the validation rows
        if self.pruning == 'pre' and len(cv_rows) > 0:
            correct_not_split = np.sum(cv_y == majority)
            correct_split = np.sum(self.cv_y[fallback_cv_rows] == majority)
            for child_rows, child_cv_rows in zip(best_split_rows, split_cv_rows):
                correct_split += np.sum(self.cv_y[child_cv_rows] == self.__most__(self.y[child_rows]))
            if correct_split <= correct_not_split:
                parent_node[node_key] = self.__build_leaf__(majority)
                return 0
        # Respect the node budget
        if self.max_node > 0 and len(best_split_rows) > node_limit:
            parent_node[node_key] = self.__build_leaf__(majority)
            return 0
        # Create the internal node and schedule its children
        node = {'type': 'internal', 'index': column_index, 'value': column_type, 'most': majority}
        parent_node[node_key] = node
        if column_type == 'discrete':
            # A discrete attribute is used at most once per path. Build a new
            # list so siblings and other pending nodes keep their attributes.
            remaining = [attr for attr in attributes if attr != best_attr]
            node['children'] = {}
            for value, child_rows, child_cv_rows in zip(best_key, best_split_rows, split_cv_rows):
                queue.append((node['children'], value, child_rows, child_cv_rows, remaining, depth + 1))
        else:
            node['divider'] = best_key
            queue.append((node, 'less_equal', best_split_rows[0], split_cv_rows[0], attributes, depth + 1))
            queue.append((node, 'greater', best_split_rows[1], split_cv_rows[1], attributes, depth + 1))
        return len(best_split_rows)

    def __discrete_divide__(self, rows, column_index):
        """Score splitting ``rows`` on every distinct value of a discrete column.

        :returns: ``(score, split_rows, values)``: the split's gain (gain ratio
            for C4.5), the row indices of each group, and the sorted distinct
            values labelling those groups in the same order.
        """
        x = self.x[rows]
        y = self.y[rows]
        column = x[:, column_index]
        values = np.unique(column)
        total = len(rows)
        gain = self.selector(y)
        split_rows = []
        weights = []
        for value in values:
            group = np.where(column == value)[0]
            weight = len(group) / total
            gain -= weight * self.selector(y[group])
            split_rows.append(rows[group])
            weights.append(weight)
        return self.__split_score__(gain, weights), split_rows, values

    def __continuous_divide__(self, rows, column_index):
        """Find the best binary threshold split of a continuous column.

        Candidate thresholds are midpoints between consecutive distinct values,
        so both sides of every candidate are non-empty. The threshold with the
        largest gain wins; for C4.5 the returned score is that split's gain ratio.

        :returns: ``(score, [low_rows, high_rows], divider)`` with low rows
            satisfying ``value <= divider``, or ``(0, [], 0)`` when no threshold
            has positive gain.
        """
        x = self.x[rows]
        y = self.y[rows]
        column = x[:, column_index]
        values = np.unique(column)
        dividers = (values[:-1] + values[1:]) / 2
        total = len(rows)
        base = self.selector(y)
        best_gain, best_div, split_rows, best_weights = 0, 0, [], []
        for div in dividers:
            low_group = np.where(column <= div)[0]
            high_group = np.where(column > div)[0]
            low_weight = len(low_group) / total
            high_weight = len(high_group) / total
            gain = base - (low_weight * self.selector(y[low_group])
                           + high_weight * self.selector(y[high_group]))
            if gain > best_gain:
                best_gain, best_div = gain, div
                split_rows = [rows[low_group], rows[high_group]]
                best_weights = [low_weight, high_weight]
        if not split_rows:
            return 0, [], 0
        return self.__split_score__(best_gain, best_weights), split_rows, best_div

    def __split_score__(self, gain, weights):
        """Turn a split's gain into the score used to rank splits.

        ID3 and Gini rank by gain. C4.5 divides the gain by the split
        information ``-sum(w * log2(w))`` of the group weights ``weights``; a
        split with zero split information (a single group) scores 0.
        """
        if not self.gain_ratio:
            return gain
        w = np.asarray(weights, dtype=float)
        w = w[w > 0]
        split_info = -np.sum(w * np.log2(w))
        return gain / split_info if split_info > 0 else 0

    @staticmethod
    def __id3__(y, base=1):
        """Return the Shannon entropy (in bits) of labels ``y``; ``base`` is ignored."""
        _, counts = np.unique(y, return_counts=True)
        p = counts / len(y)
        return -np.sum(p * np.log2(p))

    @staticmethod
    def __c45__(y, base=1):
        """Return the Shannon entropy (in bits) of labels ``y`` divided by ``base``."""
        _, counts = np.unique(y, return_counts=True)
        p = counts / len(y)
        return -np.sum(p * np.log2(p)) / base

    @staticmethod
    def __gini__(y, base=1):
        """Return the Gini impurity of labels ``y``; ``base`` is ignored."""
        _, counts = np.unique(y, return_counts=True)
        p = counts / len(y)
        return 1 - np.sum(p * p)

    @staticmethod
    def __build_leaf__(tag):
        """Return a leaf node predicting ``tag``."""
        return {'type': 'leaf', 'tag': tag}

    @staticmethod
    def __most__(arr):
        """Return the most frequent value of ``arr`` (smallest value on ties)."""
        values, counts = np.unique(arr, return_counts=True)
        return values[np.argmax(counts)]

    @staticmethod
    def __same__(x, column_indices):
        """Return True when every listed column of ``x`` has at most one distinct value."""
        return all(len(np.unique(x[:, index])) <= 1 for index in column_indices)

    def __draw__(self, g, node, header):
        """Add ``node`` and its subtree to graph ``g`` and return the node's label.

        Each label ends with `` [n]`` taken from ``self.node_id``, so nodes with
        the same text stay distinct.
        """
        if node['type'] == 'internal':
            parent = '%s [%d]' % (header[node['index']], self.node_id)
            self.node_id += 1
            g.node(parent)
            if node['value'] == 'discrete':
                for value, sub in node['children'].items():
                    g.edge(parent, self.__draw__(g, sub, header), value)
            elif node['value'] == 'continuous':
                left = self.__draw__(g, node['less_equal'], header)
                right = self.__draw__(g, node['greater'], header)
                divider = node['divider']
                g.edge(parent, left, '<=%f' % divider)
                g.edge(parent, right, '>%f' % divider)
            return parent
        if node['type'] == 'leaf':
            tag = '%s [%d]' % (node['tag'], self.node_id)
            self.node_id += 1
            g.node(tag)
            return tag
        return None

    def __post_pruning__(self, tree, cv_rows):
        """Prune ``tree`` bottom-up against validation rows ``cv_rows``; return the result.

        After pruning its children, an internal node whose children are all
        leaves is replaced by a leaf of its majority label when that leaf is
        right on at least as many of its validation rows as the split is.
        """
        if tree['type'] == 'leaf':
            return tree
        cv_y = self.cv_y[cv_rows]
        column = self.cv_x[cv_rows, tree['index']]
        correct_not_split = np.sum(cv_y == tree['most'])
        correct_split = 0
        all_leaves = True
        if tree['value'] == 'discrete':
            children = tree['children']
            routed = []
            for key, child in children.items():
                sub_cv_rows = cv_rows[np.where(column == key)[0]]
                routed.append(sub_cv_rows)
                children[key] = self.__post_pruning__(child, sub_cv_rows)
                if children[key]['type'] == 'leaf':
                    correct_split += np.sum(self.cv_y[sub_cv_rows] == children[key]['tag'])
                else:
                    all_leaves = False
            # Values unseen in training are answered with the majority label (see predict)
            unseen = np.setdiff1d(cv_rows, np.concatenate(routed)) if routed else cv_rows
            correct_split += np.sum(self.cv_y[unseen] == tree['most'])
        elif tree['value'] == 'continuous':
            divider = tree['divider']
            low_rows = cv_rows[np.where(column <= divider)[0]]
            high_rows = cv_rows[np.where(column > divider)[0]]
            tree['less_equal'] = self.__post_pruning__(tree['less_equal'], low_rows)
            tree['greater'] = self.__post_pruning__(tree['greater'], high_rows)
            if tree['less_equal']['type'] == 'leaf' and tree['greater']['type'] == 'leaf':
                correct_split = (np.sum(self.cv_y[low_rows] == tree['less_equal']['tag'])
                                 + np.sum(self.cv_y[high_rows] == tree['greater']['tag']))
            else:
                all_leaves = False
        if all_leaves and correct_not_split >= correct_split:
            return self.__build_leaf__(tree['most'])
        return tree
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment