Last active
October 19, 2015 20:20
-
-
Save janinge/89bd0662121496f7f88d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
''' | |
Implementation of ID3. Tested using Python 3.5. | |
https://gist.github.com/janinge/89bd0662121496f7f88d
Created on 27. sep. 2015 | |
@author: Jan Inge Sande | |
''' | |
from math import log2 | |
from operator import itemgetter | |
def Entropy(distribution):
    """Return the Shannon entropy (in bits) of a probability distribution.

    distribution -- iterable of probabilities (floats summing to 1).

    Zero probabilities are skipped: by the usual 0*log2(0) == 0 convention
    they contribute nothing, and passing them to log2() would raise
    ValueError (the original crashed on p == 0).
    """
    return sum(-p * log2(p) for p in distribution if p > 0)
def Gain(examples, attribute):
    """Information gain of splitting `examples` by `attribute`.

    attribute -- iterable of example subsets (one per attribute value),
    each supporting len() and distribution().
    """
    gain = Entropy(examples.distribution())
    total = len(examples)
    # Subtract the size-weighted entropy of each subset, in order,
    # to keep the same floating-point accumulation as before.
    for subset in attribute:
        gain -= (len(subset) / total) * Entropy(subset.distribution())
    return gain
def SplitInformation(examples, attribute):
    """Entropy of the split itself: how evenly `attribute` partitions
    `examples` into subsets (used to normalize Gain into GainRatio)."""
    total = len(examples)
    info = 0
    for subset in attribute:
        fraction = len(subset) / total
        info += fraction * log2(fraction)
    return -info
def GainRatio(examples, attribute):
    """Gain normalized by split information (Quinlan's gain ratio).

    The 0.1 floor on the denominator avoids division by (near) zero
    when an attribute has essentially one value.
    """
    denominator = max(SplitInformation(examples, attribute), 0.1)  # TODO: avg.
    return Gain(examples, attribute) / denominator
def ID3(examples):
    """Recursively build a decision tree from an Examples set.

    Returns a Tree: a labelled leaf when all examples agree or no
    attributes remain, otherwise an internal node splitting on the
    attribute with the highest gain ratio.
    """
    # Every example carries the same label -> single-node leaf.
    if examples.all_equal():
        return Tree(examples.label(0))
    # No attributes left to split on -> majority-label leaf.
    if not examples.not_partitioned():
        return Tree(examples.most_common_label())
    # Score each remaining attribute by gain ratio and split on the best.
    scores = {attr: GainRatio(examples, partition.values())
              for attr, partition in examples.partitions().items()}
    winner = max(scores.items(), key=itemgetter(1))[0]  # A
    branches = []
    for value, subset in examples[winner].items():
        branches.append((value, ID3(subset)))
    return Tree(attribute=winner, children=tuple(branches))
# Types | |
from collections import namedtuple, Counter | |
class Tree(namedtuple('Tree', ('label', 'attribute', 'children'))):
    """Immutable decision-tree node.

    label     -- class label (leaves only, otherwise None)
    attribute -- attribute index this node splits on (internal nodes)
    children  -- tuple of (attribute value, subtree) pairs; empty for leaves
    """
    def __new__(cls, label=None, attribute=None, children=()):
        # All fields default to "empty" so leaves can be built as Tree(label).
        return super().__new__(cls, label, attribute, children)
class Examples():
    """A collection of training examples (tuples of attribute values).

    target is the index of the column holding the class label;
    partitioned lists the attribute indices already used for splits
    higher up the tree, so they are excluded from further splitting.
    """
    def __init__(self, examples_list, target_attribute, partitioned_attributes=()):
        self.examples = list(examples_list)        # the example tuples
        self.target = target_attribute             # label column index
        self.partitioned = partitioned_attributes  # attributes already split on
        # NOTE(review): this chained comparison means "target > 0 AND
        # target >= row width", so negative indices slip through unvalidated
        # (they still work via Python's negative indexing) -- confirm intended.
        if 0 < target_attribute >= len(self.examples[0]):
            raise IndexError('Target attribute out of range')
    def __getitem__(self, key):
        """Partition by attribute `key`: return {value: Examples subset},
        recording `key` as partitioned in each subset."""
        split = {}
        for example in self.examples:
            split.setdefault(example[key], []).append(example)
        return { k: Examples(ex, self.target, self.partitioned + (key,)) for (k, ex) in split.items() }
    def __len__(self):
        return len(self.examples)
    def __lt__(self, o):
        # Ordering is by example count only; together with __eq__ this lets
        # Counter.most_common() rank Examples objects (see most_common_label).
        return len(self.examples) < len(o.examples)
    def __eq__(self, o):
        # NOTE(review): defining __eq__ without __hash__ makes instances
        # unhashable; nothing in this file appears to hash them.
        return len(self.examples) == len(o.examples)
    def attributes(self):
        """All column indices, including the target."""
        return tuple(range(len(self.examples[0])))
    def attribute_names(self):
        """Column names when examples are namedtuples, else the indices."""
        if hasattr(self.examples[0], '_fields'):
            return self.examples[0]._fields
        return self.attributes()
    def not_partitioned(self):
        """Attribute indices still available for splitting."""
        return set(self.attributes()) - set(self.partitioned) - set((self.target,))
    def partitions(self):
        """Map each available attribute to the partition it induces."""
        return { attribute : self[attribute] for attribute in self.not_partitioned() }
    def distribution(self):
        """Fraction of examples carrying each target value (entropy input)."""
        dist = []
        for examples in self[self.target].values():
            dist.append(float(len(examples)) / len(self))
        return dist
    def all_equal(self):
        """True when every example has the same target value."""
        return len(self[self.target]) == 1
    def most_common_label(self):
        # Counter over the {label: Examples} mapping treats each Examples
        # subset as that label's "count"; most_common(1) then picks the
        # largest subset via Examples' size-based comparison methods.
        return Counter(self[self.target]).most_common(1)[0][0]
    def label(self, index):
        """Target value of the example at `index`."""
        return self.examples[index][self.target]
def predict(tree, sample):
    """Classify `sample` by walking the tree; return the reached leaf's
    label, or None when an observed value matches no branch."""
    while tree.children:
        observed = sample[tree.attribute]
        for value, subtree in tree.children:
            if value == observed:
                tree = subtree
                break
        else:
            # No branch for this attribute value.
            return None
    return tree.label
def prune(tree, validation):
    """Reduced-error pruning: repeatedly replace an internal node with a
    leaf labelled by its validation-majority class, accepting any
    replacement that does not lower validation accuracy; recurse on the
    pruned tree until no replacement helps, then return the tree.
    """
    def walk(tree, validation):
        # Record, for every internal node actually reached by validation
        # examples, its majority label and local validation accuracy.
        if not tree.children or not validation:
            return
        most_common[tree] = validation.most_common_label()
        accuracy[tree] = float(correct_predictions(tree, validation)) / len(validation.examples)
        partitions = validation[tree.attribute]
        for attribute, children in tree.children:
            # A branch may receive no validation examples; get() yields None
            # and the recursive call bails out on the falsy value.
            walk(children, partitions.get(attribute))
    most_common = {}
    accuracy = {}
    walk(tree, validation)
    limit = correct_predictions(tree, validation)
    # Try pruning the locally most accurate subtrees first; after the first
    # accepted replacement, restart the whole procedure on the new tree.
    for node,_ in sorted(accuracy.items(), key=itemgetter(1), reverse=True):
        new_tree = rebuild(tree, node, Tree(label=most_common[node]))
        if new_tree != tree and correct_predictions(new_tree, validation) >= limit:
            return prune(new_tree, validation)
    return tree
def rebuild(tree, node, leaf):
    """Return a copy of `tree` with every subtree equal to `node` replaced
    by `leaf`.

    NOTE(review): the match uses namedtuple (structural) equality, so all
    structurally identical subtrees are replaced, not just one instance.
    """
    if not tree.children:
        return Tree(tree.label)
    rebuilt = tuple(
        (value, leaf if child == node else rebuild(child, node, leaf))
        for value, child in tree.children)
    return Tree(attribute=tree.attribute, children=rebuilt)
# Mostly I/O related functions bellow this line ---- | |
def build_tree(examples_csv, target_attribute, validation_fraction):
    """Train an ID3 tree from a CSV file, prune it on a random validation
    split, and print accuracies, sensitivity/specificity and the rules.

    examples_csv        -- path to the CSV file of examples
    target_attribute    -- label column index, as a string
    validation_fraction -- held-out fraction of examples, as a string
    """
    from random import sample
    examples = import_examples(examples_csv)
    target = int(target_attribute)
    # Clamp validation size to [0, len-1] so at least one example
    # remains for training.
    samples = max(min(int(float(validation_fraction) * len(examples)), len(examples) - 1), 0)
    validation = sample(examples, samples)
    # Training set = everything not drawn for validation.
    # NOTE(review): the set() difference also removes duplicate rows.
    examples = Examples(set(examples) - set(validation), target)
    validation = Examples(validation, target)
    print("Imported {0:d} examples for training, and {1:d} for validation.".
          format(len(examples), len(validation)))
    tree = ID3(examples)
    pruned_tree = prune(tree, validation)
    print("Accuracy on training set before pruning:",
          float(correct_predictions(tree, examples)) / len(examples.examples))
    print("Accuracy on training set after pruning:",
          float(correct_predictions(pruned_tree, examples)) / len(examples.examples))
    print("Accuracy on validation set before pruning:",
          float(correct_predictions(tree, validation)) / len(validation.examples))
    print("Accuracy on validation set after pruning:",
          float(correct_predictions(pruned_tree, validation)) / len(validation.examples))
    # zip() over a single iterable yields 1-tuples, so every target value
    # is evaluated as its own positive case.
    print_sens_spec(pruned_tree, validation, zip(validation[validation.target].keys()))
    print("Tree description before pruning:")
    print_rules(tree, examples.target, examples.attribute_names())
    print("Tree description after pruning:")
    print_rules(pruned_tree, examples.target, examples.attribute_names())
def print_sens_spec(tree, validation, cases):
    """Print sensitivity and specificity on `validation` for each grouping
    of positive labels in `cases` (each case is a tuple of labels)."""
    for positive in cases:
        tp, tn, fp, fn = calculate_confusion(tree, validation, positive)
        case_name = '/'.join(positive)
        print("Sensitivity (positive case: %s):" % case_name, float(tp) / (tp + fn))
        print("Specificity (positive case: %s):" % case_name, float(tn) / (tn + fp))
def print_rules(tree, target, names):
    """Print each root-to-leaf path as an implication rule, shortest
    rules first; `names` maps attribute indices to display names."""
    rules = find_paths(tree)
    rules.sort(key=len)
    for rule in rules:
        conditions = ['<' + names[attr] + ' = ' + value + '>'
                      for attr, value in rule[:-1]]
        print(' AND '.join(conditions), '-->', names[target], '=', rule[-1])
def find_paths(tree):
    """Collect every root-to-leaf path of the tree, depth-first.

    Each path is a tuple of (attribute index, attribute value) pairs,
    terminated by the leaf's label.
    """
    paths = []
    def descend(node, trail):
        if not node.children:
            paths.append(trail + (node.label,))
            return
        for value, subtree in node.children:
            descend(subtree, trail + ((node.attribute, value),))
    descend(tree, ())
    return paths
def correct_predictions(tree, examples):
    """Count how many of `examples` the tree classifies correctly
    (prediction equals the value in the target column)."""
    target = examples.target
    return sum(1 for sample in examples.examples
               if predict(tree, sample) == sample[target])
def calculate_confusion(tree, examples, positives):
    """Return confusion counts (tp, tn, fp, fn) over `examples`, where any
    label contained in `positives` is treated as the positive class."""
    tp = tn = fp = fn = 0
    for sample in examples.examples:
        predicted = predict(tree, sample)
        actual = sample[examples.target]
        correct = predicted == actual
        if actual in positives:
            if correct:
                tp += 1
            else:
                fn += 1
        elif correct:
            tn += 1
        else:
            fp += 1
    return (tp, tn, fp, fn)
def import_examples(filename):
    """Load examples from a CSV file.

    Dialect and header presence are sniffed from the first 2048 bytes.
    With a header, rows become namedtuples named after the columns
    ('-' mapped to '_'); otherwise plain tuples. Fields holding '?'
    denote missing attributes and are replaced by None.
    """
    from csv import Sniffer, reader
    with open(filename) as fp:
        head = fp.read(2048)
        dialect = Sniffer().sniff(head)
        has_header = Sniffer().has_header(head)
        fp.seek(0)
        csv = reader(fp, dialect)
        if has_header:
            fields = [name.replace('-', '_') for name in next(csv)]
            example = namedtuple('Example', fields, rename=True)
        else:
            example = lambda *x: tuple(x)
        return [example(*[None if cell == '?' else cell for cell in row])
                for row in csv]
if __name__ == '__main__':
    # CLI entry point: ./ID3.py <examples-csv> <target-attribute> <validation-fraction>
    import sys
    if len(sys.argv) != 4:
        print("Usage: ./ID3.py <examples-csv> <target-attribute> <validation-fraction>")
        sys.exit(1)
    build_tree(*sys.argv[1:])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment