#!/usr/bin/env python3
'''
Implementation of ID3. Tested using Python 3.5.
https://gist.github.com/janinge/89bd0662121496f7f88d

Created on 27 Sep 2015
@author: Jan Inge Sande
'''
from math import log2
from operator import itemgetter

def Entropy(distribution):
    '''Shannon entropy (in bits) of a probability distribution.'''
    return sum(-p * log2(p) for p in distribution)
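
# A quick sanity check (doctest-style; illustrative values computed by hand):
#   >>> Entropy([0.5, 0.5])    # a balanced binary split costs one full bit
#   1.0
#   >>> Entropy([1.0])         # a pure node carries no remaining uncertainty
#   0.0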

def Gain(examples, attribute):
    '''Information gain: the entropy reduction from splitting examples
    into the given subsets.'''
    entropy_reduction = Entropy(examples.distribution())
    for a in attribute:
        entropy_reduction -= (len(a) / len(examples)) * Entropy(a.distribution())
    return entropy_reduction

def SplitInformation(examples, attribute):
    '''Entropy of the split itself; penalises many-valued attributes.'''
    return -sum((len(a) / len(examples)) * log2(len(a) / len(examples))
                for a in attribute)

def GainRatio(examples, attribute):
    # Clamp SplitInformation away from zero to avoid division blow-ups.
    return Gain(examples, attribute) / max(SplitInformation(examples, attribute), 0.1)  # TODO: avg.
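
# In textbook notation, for a split of S on attribute A with values v:
#   Gain(S, A)      = Entropy(S) - sum(|Sv| / |S| * Entropy(Sv))
#   SplitInfo(S, A) = -sum(|Sv| / |S| * log2(|Sv| / |S|))
#   GainRatio(S, A) = Gain(S, A) / SplitInfo(S, A)
# The 0.1 clamp above is this script's own guard, not part of the textbook
# definition; the TODO presumably refers to C4.5's rule of only considering
# attributes whose gain is at least the average gain.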

def ID3(examples):
    # If all examples carry the same label...
    if examples.all_equal():
        # ...return a single-node tree with that label.
        return Tree(examples.label(0))
    # If no attributes are left to split on...
    if not examples.not_partitioned():
        # ...return a leaf labelled with the most common value.
        return Tree(examples.most_common_label())
    # Otherwise, split on the attribute with the highest gain ratio.
    gain = {attr: GainRatio(examples, part.values())
            for attr, part in examples.partitions().items()}
    best_attribute = max(gain.items(), key=lambda x: x[1])[0]  # A
    children = []
    for value, subset in examples[best_attribute].items():
        children.append((value, ID3(subset)))
    return Tree(attribute=best_attribute, children=tuple(children))
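
# Sketch of a call (Examples is the wrapper defined below; the data here is
# invented for illustration, with column 1 as the target):
#   >>> tree = ID3(Examples([('sunny', 'yes'), ('sunny', 'yes'),
#   ...                      ('rain', 'no')], 1))
#   >>> tree.attribute    # the root splits on column 0
#   0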

# Types
from collections import namedtuple, Counter


class Tree(namedtuple('Tree', ('label', 'attribute', 'children'))):
    '''Immutable tree node: leaves carry a label, internal nodes carry the
    attribute they split on plus (value, subtree) pairs in children.'''
    def __new__(cls, label=None, attribute=None, children=()):
        return super(Tree, cls).__new__(cls, label, attribute, children)
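
# Being a namedtuple, Tree prints its structure directly:
#   >>> Tree('yes')
#   Tree(label='yes', attribute=None, children=())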

class Examples():
    '''A list of example tuples, the index of the target attribute, and the
    attributes already partitioned on.'''
    def __init__(self, examples_list, target_attribute, partitioned_attributes=()):
        self.examples = list(examples_list)
        self.target = target_attribute
        self.partitioned = partitioned_attributes
        if not 0 <= target_attribute < len(self.examples[0]):
            raise IndexError('Target attribute out of range')

    def __getitem__(self, key):
        '''Partition the examples on attribute key: value -> Examples.'''
        split = {}
        for example in self.examples:
            split.setdefault(example[key], []).append(example)
        return {k: Examples(ex, self.target, self.partitioned + (key,))
                for (k, ex) in split.items()}

    def __len__(self):
        return len(self.examples)

    def __lt__(self, o):
        return len(self.examples) < len(o.examples)

    def __eq__(self, o):
        return len(self.examples) == len(o.examples)

    def attributes(self):
        return tuple(range(len(self.examples[0])))

    def attribute_names(self):
        if hasattr(self.examples[0], '_fields'):
            return self.examples[0]._fields
        return self.attributes()

    def not_partitioned(self):
        '''Attributes still available for splitting (target excluded).'''
        return set(self.attributes()) - set(self.partitioned) - {self.target}

    def partitions(self):
        return {attribute: self[attribute] for attribute in self.not_partitioned()}

    def distribution(self):
        '''Fraction of the examples carrying each target label.'''
        return [len(examples) / len(self)
                for examples in self[self.target].values()]

    def all_equal(self):
        return len(self[self.target]) == 1

    def most_common_label(self):
        # The Counter's values here are Examples objects; most_common()
        # compares them through __lt__ above, i.e. by subset size.
        return Counter(self[self.target]).most_common(1)[0][0]

    def label(self, index):
        return self.examples[index][self.target]
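
# Indexing an Examples object partitions it by a column. A small illustration
# (invented data; column 1 is the target):
#   >>> ex = Examples([('a', 'x'), ('a', 'y'), ('b', 'y')], 1)
#   >>> sorted(ex[0].keys())          # partition on attribute 0
#   ['a', 'b']
#   >>> sorted(ex.distribution())     # P(label = 'x'), P(label = 'y')
#   [0.3333333333333333, 0.6666666666666666]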

def predict(tree, sample):
    '''Walk from the root to a leaf and return its label.'''
    if not tree.children:
        return tree.label
    observation = sample[tree.attribute]
    for value, child in tree.children:
        if observation == value:
            return predict(child, sample)
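
# predict() falls through and returns None when a sample carries an attribute
# value the tree never saw during training; correct_predictions() below then
# simply counts it as a miss. A small illustration:
#   >>> t = Tree(attribute=0, children=(('sunny', Tree('yes')),
#   ...                                 ('rain', Tree('no'))))
#   >>> predict(t, ('rain',))
#   'no'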

def prune(tree, validation):
    '''Reduced-error pruning: consider subtrees in order of decreasing local
    validation accuracy, replace the first one that can be collapsed to a
    leaf (its most common label) without reducing overall validation
    accuracy, and repeat until no replacement qualifies.'''
    def walk(tree, validation):
        if not tree.children or not validation:
            return
        most_common[tree] = validation.most_common_label()
        accuracy[tree] = correct_predictions(tree, validation) / len(validation.examples)
        partitions = validation[tree.attribute]
        for value, child in tree.children:
            walk(child, partitions.get(value))
    most_common = {}
    accuracy = {}
    walk(tree, validation)
    limit = correct_predictions(tree, validation)
    for node, _ in sorted(accuracy.items(), key=itemgetter(1), reverse=True):
        new_tree = rebuild(tree, node, Tree(label=most_common[node]))
        if new_tree != tree and correct_predictions(new_tree, validation) >= limit:
            return prune(new_tree, validation)
    return tree
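
# Tree is an immutable namedtuple, so prune() never mutates its input: each
# candidate pruning is built as a fresh tree by rebuild() and only kept if it
# scores at least as well on the validation set.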

def rebuild(tree, node, leaf):
    '''Return a copy of tree with node replaced by leaf.'''
    if not tree.children:
        return Tree(tree.label)
    subtrees = []
    for value, child in tree.children:
        if child == node:
            subtrees.append((value, leaf))
        else:
            subtrees.append((value, rebuild(child, node, leaf)))
    return Tree(attribute=tree.attribute, children=tuple(subtrees))
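
# Note: rebuild() matches nodes by namedtuple value equality, so structurally
# identical subtrees are replaced together. The pruning loop tolerates this,
# since every rebuilt tree is re-scored before it is accepted.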

# Mostly I/O related functions below this line ----

def build_tree(examples_csv, target_attribute, validation_fraction):
    from random import sample
    examples = import_examples(examples_csv)
    target = int(target_attribute)
    # Hold out the requested fraction of the examples for validation
    # (at least 0, at most all but one example).
    samples = max(min(int(float(validation_fraction) * len(examples)), len(examples) - 1), 0)
    validation = sample(examples, samples)
    # (The set arithmetic also deduplicates identical example rows.)
    examples = Examples(set(examples) - set(validation), target)
    validation = Examples(validation, target)
    print("Imported {0:d} examples for training, and {1:d} for validation."
          .format(len(examples), len(validation)))
    tree = ID3(examples)
    pruned_tree = prune(tree, validation)
    print("Accuracy on training set before pruning:",
          correct_predictions(tree, examples) / len(examples.examples))
    print("Accuracy on training set after pruning:",
          correct_predictions(pruned_tree, examples) / len(examples.examples))
    print("Accuracy on validation set before pruning:",
          correct_predictions(tree, validation) / len(validation.examples))
    print("Accuracy on validation set after pruning:",
          correct_predictions(pruned_tree, validation) / len(validation.examples))
    # zip() wraps each target label in a 1-tuple: one positive case per label.
    print_sens_spec(pruned_tree, validation, zip(validation[validation.target].keys()))
    print("Tree description before pruning:")
    print_rules(tree, examples.target, examples.attribute_names())
    print("Tree description after pruning:")
    print_rules(pruned_tree, examples.target, examples.attribute_names())

def print_sens_spec(tree, validation, cases):
    for case in cases:
        tp, tn, fp, fn = calculate_confusion(tree, validation, case)
        print("Sensitivity (positive case: %s):" % '/'.join(case), tp / (tp + fn))
        print("Specificity (positive case: %s):" % '/'.join(case), tn / (tn + fp))

def print_rules(tree, target, names):
    rules = find_paths(tree)
    rules.sort(key=len)
    for rule in rules:
        rl = ['<' + names[node[0]] + ' = ' + node[1] + '>' for node in rule[:-1]]
        print(' AND '.join(rl), '-->', names[target], '=', rule[-1])

def find_paths(tree):
    '''Collect every root-to-leaf path as ((attribute, value), ..., label).'''
    paths = []
    def walk(tree, path):
        if not tree.children:
            paths.append(path + (tree.label,))
            return
        for value, child in tree.children:
            walk(child, path + ((tree.attribute, value),))
    walk(tree, ())
    return paths
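
# Each path becomes one printed rule. With the tiny tree from the predict()
# example above:
#   >>> find_paths(Tree(attribute=0, children=(('sunny', Tree('yes')),)))
#   [((0, 'sunny'), 'yes')]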

def correct_predictions(tree, examples):
    correct = 0
    for sample in examples.examples:
        if predict(tree, sample) == sample[examples.target]:
            correct += 1
    return correct

def calculate_confusion(tree, examples, positives):
    '''Confusion counts, treating the labels in positives as the positive class.'''
    tp = tn = fp = fn = 0
    for sample in examples.examples:
        prediction = predict(tree, sample)
        actual = sample[examples.target]
        if actual in positives:
            if prediction == actual:
                tp += 1
            else:
                fn += 1
        else:
            if prediction == actual:
                tn += 1
            else:
                fp += 1
    return (tp, tn, fp, fn)

def import_examples(filename):
    from csv import Sniffer, reader
    with open(filename) as fp:
        dialect = Sniffer().sniff(fp.read(2048))
        fp.seek(0)
        has_header = Sniffer().has_header(fp.read(2048))
        fp.seek(0)
        csv = reader(fp, dialect)
        # With a header row, examples become namedtuples ('-' mapped to '_');
        # otherwise they stay plain tuples.
        example = namedtuple('Example', [s.replace('-', '_') for s in next(csv)],
                             rename=True) if has_header else lambda *x: tuple(x)
        examples = []
        for line in csv:  # Missing attributes are denoted by '?'
            examples.append(example(*[x if x != '?' else None for x in line]))
        return examples
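
# Expected input is a plain CSV, e.g. (invented contents):
#   outlook,humidity,play-tennis
#   sunny,high,no
#   rain,normal,yes
# The sniffed header row becomes the namedtuple field names, and '?' cells
# are read back as None.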

if __name__ == '__main__':
    from sys import argv, exit
    if len(argv) != 4:
        print("Usage: ./ID3.py <examples-csv> <target-attribute> <validation-fraction>")
        exit(1)
    build_tree(*argv[1:])
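
# Example invocation (the file name is hypothetical):
#   ./ID3.py weather.csv 2 0.25
# trains on 75% of weather.csv with column 2 as the target, and prunes the
# resulting tree against the remaining 25%.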