Skip to content

Instantly share code, notes, and snippets.

@dansondergaard
Created October 30, 2013 00:11
Show Gist options
  • Select an option

  • Save dansondergaard/7225034 to your computer and use it in GitHub Desktop.

Select an option

Save dansondergaard/7225034 to your computer and use it in GitHub Desktop.
A simple decision tree for classification.
"""Decision Tree"""
# Public API: build a tree with grow() and query it with classify().
# Everything else in this module is an implementation detail.
__all__ = ['grow', 'classify']
class Node(object):
    """Internal decision node that splits on one dimension of a point.

    Attributes:
        error: the split error recorded for this node.
        dimension: index into a point used for the comparison.
        value: threshold the point's coordinate is compared against.
        left: subtree stored as the left child.
        right: subtree stored as the right child.
    """

    def __init__(self, error, dimension, value, left, right):
        self.error, self.dimension, self.value = error, dimension, value
        self.left, self.right = left, right

    @property
    def depth(self):
        """Height of the subtree rooted here (this node plus deepest child)."""
        deepest_child = max(child.depth for child in (self.left, self.right))
        return deepest_child + 1
class Leaf(object):
    """Terminal node of the tree; holds the category it predicts."""

    def __init__(self, category):
        self.category = category

    # A leaf has no children, so its depth is always one.
    depth = property(lambda self: 1, doc="Depth of a leaf (always 1).")
def compute_partial_error(partition):
    """Return the misclassification rate of a majority vote on *partition*.

    Args:
        partition: a non-empty sequence of ``(point, category)`` pairs.
            Categories must be hashable.

    Returns:
        The fraction (a float in ``[0, 1)``) of pairs whose category differs
        from the most common category in the partition.

    Raises:
        ValueError: if *partition* is empty, since no majority exists.
    """
    # Local import keeps the function self-contained; the module has no
    # top-level import section.
    from collections import Counter

    total = len(partition)
    if total == 0:
        raise ValueError("cannot compute the error of an empty partition")

    # One pass to tally the categories instead of calling list.count() once
    # per distinct category.
    tally = Counter(category for _, category in partition)

    # Whichever category wins a tie, the number of wrongly classified points
    # is the same: everything that is not part of the largest group.
    majority_size = max(tally.values())
    return (total - majority_size) / float(total)
def compute_error(left, right):
    """Return the error of a two-way split.

    The result is the plain (unweighted) sum of the partial errors of the
    *left* and *right* partitions.
    """
    left_error = compute_partial_error(left)
    right_error = compute_partial_error(right)
    return left_error + right_error
def minimize_error(dataset):
    """Find the axis-aligned split of *dataset* with the lowest error.

    For every dimension the dataset is sorted along that dimension and each
    position between two neighbours is evaluated with ``compute_error``.

    Only positions where the two neighbours have *different* coordinates are
    preferred: the returned ``value`` is the first coordinate of the right
    partition, and a point is routed left when its coordinate is strictly
    smaller than ``value``. A cut between two equal coordinates cannot be
    reproduced by that comparison, so such cuts are used only as a fallback
    when the data offers no other choice (every point shares the same
    coordinates).

    Args:
        dataset: a non-empty sequence of ``(point, category)`` pairs, where
            every point is indexable and has the same number of dimensions
            as the first point.

    Returns:
        A tuple ``(error, dimension, value, left, right)`` describing the
        chosen split. For a dataset with fewer than two elements no split
        exists and ``(10 ** 100, 0, 0, [], [])`` is returned.
    """
    # Each candidate is an (error, dimension, value, left, right) tuple.
    best_distinct = None  # best cut between two different coordinates
    best_any = None       # best cut overall, used only as a fallback

    first_point, _ = dataset[0]
    for dimension in range(len(first_point)):
        ordered = sorted(dataset, key=lambda item: item[0][dimension])
        # Index 0 would leave the left partition empty, so start at 1.
        for index in range(1, len(ordered)):
            point, _ = ordered[index]
            left, right = ordered[:index], ordered[index:]
            error = compute_error(left, right)
            candidate = (error, dimension, point[dimension], left, right)

            # Strict comparison keeps the first of several equally good cuts.
            if best_any is None or error < best_any[0]:
                best_any = candidate

            previous_point, _ = ordered[index - 1]
            if previous_point[dimension] == point[dimension]:
                # Equal neighbours: "coordinate < value" would send the left
                # neighbour to the right subtree, so this cut is not usable
                # as a threshold.
                continue
            if best_distinct is None or error < best_distinct[0]:
                best_distinct = candidate

    if best_distinct is not None:
        return best_distinct
    if best_any is not None:
        # All points coincide in every dimension; fall back to a positional
        # cut so callers still receive two non-empty partitions.
        return best_any
    # Fewer than two elements: nothing to split.
    return 10 ** 100, 0, 0, [], []
def grow(dataset):
    """Recursively build a decision tree for *dataset*.

    Args:
        dataset: a sequence of ``(point, category)`` pairs.

    Returns:
        A ``Leaf`` when a majority vote on the dataset makes no errors,
        otherwise a ``Node`` whose children are grown from the two halves of
        the split chosen by ``minimize_error``.
    """
    is_pure = compute_partial_error(dataset) == 0
    if is_pure:
        # Every element agrees, so the first one's category is the answer.
        _, category = dataset[0]
        return Leaf(category)

    split = minimize_error(dataset)
    error, dimension, value, left_part, right_part = split
    left_subtree = grow(left_part)
    right_subtree = grow(right_part)
    return Node(error, dimension, value, left_subtree, right_subtree)
def classify(tree, point):
    """Return the category the decision *tree* assigns to *point*.

    Walks down from the root: at each internal node the point's coordinate
    in the node's dimension is compared with the node's value, going left
    when it is strictly smaller and right otherwise. The walk stops at the
    first object that has a ``category`` attribute.
    """
    current = tree
    while not hasattr(current, 'category'):
        dimension, value = current.dimension, current.value
        current = current.left if point[dimension] < value else current.right
    return current.category
if __name__ == '__main__':
    # Small demo: grow a tree from ten labelled 2-D points and classify two
    # unseen points.
    dataset = [((1, 2), 0), ((2, 3), 0), ((3, 2), 0), ((3, 5), 0),
               ((4, 1), 1), ((4, 4), 1), ((5, 4), 1), ((6, 3), 1),
               ((6, 5), 1), ((2, 7), 1)]
    test_point1 = (5, 3)
    test_point2 = (2, 6)
    tree = grow(dataset)
    # A single pre-formatted argument prints the same text under Python 2
    # (print statement) and Python 3 (print function).
    print('%s %s' % (test_point1, classify(tree, test_point1)))
    print('%s %s' % (test_point2, classify(tree, test_point2)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment