Last active
August 29, 2015 14:07
-
-
Save bufas/831b8225e94cd3fbf2c7 to your computer and use it in GitHub Desktop.
Barebones ID3 Classifier
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Barebones ID3 decision-tree classifier."""
# A ``from __future__`` import must come before every other statement in the
# module (only the docstring, comments and blank lines may precede it);
# placing it after the ``__author__`` assignment is a SyntaxError.
from __future__ import division

__author__ = 'Mathias Bak Bertelsen'
__email__ = '[email protected]'

import math
def segment_set_by_class(es):
    """Group the examples in ``es`` by their class label.

    The label column is taken to be the last index of the first example,
    so ``es`` must be non-empty (an empty sequence raises ``IndexError``).

    Returns a dict mapping each label to the list of examples carrying it.
    """
    label_column = len(es[0]) - 1
    return segment_set_by_attr(es, label_column)
def segment_set_by_attr(es, attr_idx):
    """Partition the examples in ``es`` by the value found at ``attr_idx``.

    Returns a plain dict mapping each distinct value to the list of
    examples holding that value, in their original relative order.
    """
    groups = {}
    for example in es:
        # setdefault creates the bucket on first sight of a value.
        groups.setdefault(example[attr_idx], []).append(example)
    return groups
def entropy(s):
    """Return the entropy of the class labels in example set ``s``.

    The value is ``sum(-p * ln(p))`` over the class proportions ``p``; it
    uses the natural logarithm, so the result is in nats rather than bits.
    ``s`` must be non-empty because the class column is located from its
    first example.
    """
    total = len(s)
    # ``.values()`` works on both Python 2 and 3 (``iteritems`` is 2-only).
    # True division on Python 2 relies on the module-level
    # ``from __future__ import division``.
    proportions = [len(members) / total
                   for members in segment_set_by_class(s).values()]
    return sum(-p * math.log(p) for p in proportions)
def gain(s, attr_idx):
    """Return the information gain from splitting ``s`` on ``attr_idx``.

    This is the entropy of ``s`` minus the size-weighted entropy of the
    subsets produced by partitioning ``s`` on the attribute's values.
    """
    total = len(s)
    before = entropy(s)
    # ``.values()`` works on both Python 2 and 3 (``iteritems`` is 2-only).
    remainder = sum(len(subset) / total * entropy(subset)
                    for subset in segment_set_by_attr(s, attr_idx).values())
    return before - remainder
def grow(train, attributes):
    """Build an ID3 decision tree from the examples in ``train``.

    Parameters:
        train: non-empty sequence of examples; each example is indexable
            and its last column holds the class label.
        attributes: sequence of column names. Only its length is used: the
            first ``len(attributes) - 1`` columns are split candidates.

    Returns:
        A leaf ``('#LEAF#', label)`` or an internal node
        ``(attr_idx, {value: subtree, ...})``.

    A majority-class leaf is returned when the labels are mixed but no
    candidate attribute can partition ``train`` any further (for example,
    identical rows with different labels). Without that stop condition the
    recursion would split on a constant attribute forever.
    """
    # Check if this should be a leaf
    segmented = segment_set_by_class(train)
    if len(segmented) == 1:
        # next(iter(...)) works on Python 2 and 3; keys()[0] is 2-only.
        return '#LEAF#', next(iter(segmented))

    # Pick the attribute with the highest information gain. Attributes
    # with a single distinct value are skipped: splitting on them would
    # hand the very same set to the recursive call.
    best_idx = None
    best_gain = None
    best_parts = None
    for i in range(len(attributes) - 1):
        parts = segment_set_by_attr(train, i)
        if len(parts) < 2:
            continue
        information_gain = gain(train, i)
        # Strict comparison keeps the first attribute on ties.
        if best_gain is None or information_gain > best_gain:
            best_idx, best_gain, best_parts = i, information_gain, parts

    if best_idx is None:
        # Nothing left to split on: fall back to the most common label.
        majority = max(segmented, key=lambda label: len(segmented[label]))
        return '#LEAF#', majority

    # Recurse on subtrees
    # (attr, {val1: subtree, val2: subtree, val3: subtree})
    subtree_representation = {}
    for val, subset in best_parts.items():
        subtree_representation[val] = grow(subset, attributes)
    return best_idx, subtree_representation
def classify(tree, instance):
    """Return the label that ``tree`` assigns to ``instance``.

    ``tree`` is either a leaf ``('#LEAF#', label)`` or an internal node
    ``(attr_idx, {value: subtree})``. The instance is walked down from the
    root, following the branch matching its value at each node's attribute
    index. A value with no matching branch raises ``KeyError``.
    """
    node = tree
    while node[0] != '#LEAF#':
        attr_idx = node[0]
        branches = node[1]
        node = branches[instance[attr_idx]]
    return node[1]
# USAGE

# Create a training set: `a` names the columns (the last one is the class
# label) and `x` holds one tuple per example in that column order.
a = ('outlook', 'temp', 'humidity', 'wind', 'play')
x = [('sunny', 'hot', 'high', 'weak', 'No'),
     ('sunny', 'hot', 'high', 'strong', 'No'),
     ('overcast', 'hot', 'high', 'weak', 'Yes'),
     ('rain', 'mild', 'high', 'weak', 'Yes'),
     ('rain', 'cool', 'normal', 'weak', 'Yes'),
     ('rain', 'cool', 'normal', 'strong', 'No'),
     ('overcast', 'cool', 'normal', 'strong', 'Yes'),
     ('sunny', 'mild', 'high', 'weak', 'No'),
     # NOTE(review): 'cold' is the only row not using 'cool' -- possibly a
     # typo in the data; left unchanged, confirm with the author.
     ('sunny', 'cold', 'normal', 'weak', 'Yes'),
     ('rain', 'mild', 'normal', 'weak', 'Yes'),
     ('sunny', 'mild', 'normal', 'strong', 'Yes'),
     ('overcast', 'mild', 'high', 'strong', 'Yes'),
     ('overcast', 'hot', 'normal', 'weak', 'Yes'),
     ('rain', 'mild', 'high', 'strong', 'No')]

# Train the model
my_tree = grow(x, a)

# Classify some instances. The single-argument call form of print produces
# the same output on Python 2 and Python 3.
print(classify(my_tree, ('sunny', 'cool', 'normal', 'strong')))
print(classify(my_tree, ('overcast', 'mild', 'normal', 'weak')))
print(classify(my_tree, ('rain', 'hot', 'high', 'strong')))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment