Created
October 30, 2013 00:11
-
-
Save dansondergaard/7225034 to your computer and use it in GitHub Desktop.
A simple decision tree for classification.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Decision Tree""" | |
| __all__ = ['grow', 'classify'] | |
class Node(object):
    """Internal (splitting) node of the decision tree.

    Stores the split error, the index of the feature dimension that is
    tested, the threshold value, and the two child subtrees.
    """

    def __init__(self, error, dimension, value, left, right):
        self.error, self.dimension, self.value = error, dimension, value
        self.left, self.right = left, right

    @property
    def depth(self):
        """Number of levels below and including this node."""
        deepest_child = max(child.depth for child in (self.left, self.right))
        return deepest_child + 1
class Leaf(object):
    """Terminal node of the decision tree; holds the predicted category."""

    def __init__(self, category):
        self.category = category

    # A leaf always counts as exactly one level; read-only, like a getter.
    depth = property(lambda self: 1)
def compute_partial_error(partition):
    """Return the misclassification rate of a majority vote on *partition*.

    Args:
        partition: non-empty sequence of ``(point, category)`` pairs.
            Categories must be hashable.

    Returns:
        The fraction (a float in ``[0, 1)``) of entries whose category
        differs from the most common category in the partition.

    Raises:
        ValueError: if *partition* is empty.
    """
    total = len(partition)
    if total == 0:
        raise ValueError('cannot compute the error of an empty partition')

    # Count every category in a single pass instead of calling
    # list.count() once per distinct category.
    counts = {}
    for _, category in partition:
        counts[category] = counts.get(category, 0) + 1

    # Everything that is not in the majority category is misclassified;
    # ties between categories do not affect the resulting rate.
    misclassified = total - max(counts.values())
    return misclassified / float(total)
def compute_error(left, right):
    """Return the summed majority-vote error rates of both partitions."""
    left_error = compute_partial_error(left)
    right_error = compute_partial_error(right)
    return left_error + right_error
def _best_split(dataset, dimensions, allow_ties):
    """Return the lowest-error ``(error, dimension, value, left, right)``.

    Returns ``None`` when no candidate split exists. With ``allow_ties``
    false, split positions that fall between two points sharing the same
    value in the tested dimension are skipped.
    """
    best = None
    for dimension in range(dimensions):
        ordered = sorted(dataset, key=lambda item: item[0][dimension])
        # Index 0 would leave the left side empty, so start at 1.
        for index in range(1, len(ordered)):
            value = ordered[index][0][dimension]
            if not allow_ties and ordered[index - 1][0][dimension] == value:
                continue
            left, right = ordered[:index], ordered[index:]
            error = compute_error(left, right)
            # Strict comparison keeps the first of several equally good splits.
            if best is None or error < best[0]:
                best = (error, dimension, value, left, right)
    return best


def minimize_error(dataset):
    """Find the axis-aligned split of *dataset* with the lowest error.

    Args:
        dataset: non-empty sequence of ``(point, category)`` pairs, where
            every point is an indexable sequence of equal length.

    Returns:
        A tuple ``(error, dimension, value, left, right)``: ``left`` holds
        the entries sorted before the split position, ``right`` the rest,
        and ``value`` is the coordinate of the first entry of ``right`` in
        ``dimension``. When the dataset has fewer than two entries no split
        exists and ``(10 ** 100, 0, 0, [], [])`` is returned.

    Raises:
        IndexError: if *dataset* is empty.
    """
    first_point, _ = dataset[0]
    dimensions = len(first_point)

    # Prefer splits where every entry of ``left`` is strictly below the
    # threshold: a lookup that tests ``point[dimension] < value`` would
    # otherwise route training points to the opposite side.
    best = _best_split(dataset, dimensions, allow_ties=False)
    if best is None:
        # All points coincide in every dimension; fall back to splitting
        # between equal values so callers still receive two non-empty parts.
        best = _best_split(dataset, dimensions, allow_ties=True)
    if best is None:
        return 10 ** 100, 0, 0, [], []
    return best
def grow(dataset):
    """Recursively build a decision tree from ``(point, category)`` pairs.

    A dataset whose majority vote misclassifies nothing becomes a Leaf
    carrying the category of its first entry; anything else is split with
    minimize_error() and each half is grown into a subtree of a Node.
    """
    is_pure = compute_partial_error(dataset) == 0
    if not is_pure:
        split_error, split_dim, threshold, lower, upper = minimize_error(dataset)
        left_subtree = grow(lower)
        right_subtree = grow(upper)
        return Node(split_error, split_dim, threshold,
                    left_subtree, right_subtree)

    return Leaf(dataset[0][1])
def classify(tree, point):
    """Return the category the tree assigns to *point*.

    Walks down from *tree*: at each node without a ``category`` attribute
    the point goes left when its coordinate in the node's dimension is
    strictly below the node's value, otherwise right. The first node that
    has a ``category`` attribute supplies the answer.
    """
    current = tree
    while not hasattr(current, 'category'):
        goes_left = point[current.dimension] < current.value
        current = current.left if goes_left else current.right
    return current.category
if __name__ == '__main__':
    # Small demo: ten 2-D points, each labelled 0 or 1.
    dataset = [((1, 2), 0), ((2, 3), 0), ((3, 2), 0), ((3, 5), 0),
               ((4, 1), 1), ((4, 4), 1), ((5, 4), 1), ((6, 3), 1),
               ((6, 5), 1), ((2, 7), 1)]
    test_point1 = (5, 3)
    test_point2 = (2, 6)

    tree = grow(dataset)

    # A formatted print() call is valid on Python 2 and 3 alike and emits
    # the same "<point> <category>" text the old print statement produced.
    for test_point in (test_point1, test_point2):
        print('{0} {1}'.format(test_point, classify(tree, test_point)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment