"""
Decision tree

Created on Oct 12, 2010
Decision Tree Source Code for Machine Learning in Action Ch. 3
@author: Peter Harrington
"""
from math import log
import operator
# prepare data set
# create the data set and the corresponding feature labels
def createDataSet():
    """Build the small hard-coded sample data set used by this module.

    Returns:
        A ``(dataSet, labels)`` tuple. ``dataSet`` is a list of rows; each
        row holds two 0/1 feature values followed by a ``'yes'``/``'no'``
        class label in the last position. ``labels`` names the two feature
        columns, in column order. Fresh lists are created on every call.
    """
    feature_names = ['no surfacing', 'flippers']
    rows = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return rows, feature_names
# calculate shannon entropy
def calcShannonEnt(dataSet):
    """Return the base-2 Shannon entropy of the class labels in ``dataSet``.

    Args:
        dataSet: Sequence of rows; the last element of each row is taken
            as that row's class label.

    Returns:
        The entropy as a float. An empty ``dataSet`` yields ``0.0``.
    """
    total = len(dataSet)

    # Tally how many rows carry each class label.
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1

    # Entropy is the negated sum of p * log2(p) over all labels.
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy
# split the data set on feature[axis] == value
# this works like a filter: it returns the matching rows with the
# splitting feature column removed
def splitDataSet(dataSet, axis, value):
    """Select the rows whose feature ``axis`` equals ``value``.

    The matching rows are returned with the ``axis`` column removed, since
    that feature has been consumed by the split. The input rows are not
    modified.

    Args:
        dataSet: Sequence of rows (each row a list or tuple).
        axis: Non-negative index of the feature column to split on.
        value: Feature value a row must have at ``axis`` to be kept.

    Returns:
        A new list of new lists; empty when no row matches.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Keep everything left of the split column ...
            reducedFeatVec = list(featVec[:axis])
            # ... and everything right of it.
            reducedFeatVec.extend(featVec[axis + 1:])
            # The original never stored the reduced row, so the function
            # always returned an empty list.
            retDataSet.append(reducedFeatVec)
    return retDataSet
# choose the best feature to split data set, return the best feature
def chooseBestFeatureToSplit(dataSet):
    """Pick the feature column whose split yields the largest information gain.

    Information gain is the entropy of ``dataSet`` minus the size-weighted
    entropy of the subsets produced by splitting on each distinct value of
    the feature.

    Args:
        dataSet: Non-empty sequence of rows; the last element of every row
            is the class label and is not considered as a feature.

    Returns:
        The integer index of the best feature, or ``-1`` when no feature
        gives a strictly positive gain.
    """
    feature_count = len(dataSet[0]) - 1  # last column holds the class label
    parent_entropy = calcShannonEnt(dataSet)
    row_total = float(len(dataSet))

    best_gain, best_index = 0.0, -1
    for index in range(feature_count):
        distinct_values = set(row[index] for row in dataSet)

        # Entropy remaining after the split, weighted by subset size.
        weighted_entropy = 0.0
        for val in distinct_values:
            subset = splitDataSet(dataSet, index, val)
            weight = len(subset) / row_total
            weighted_entropy += weight * calcShannonEnt(subset)

        gain = parent_entropy - weighted_entropy
        if gain > best_gain:
            best_gain, best_index = gain, index
    return best_index
# majority vote: return the most common class label
def majorityCnt(classList):
    """Return the class label that occurs most often in ``classList``.

    When several labels share the highest count, the one that appears
    first in ``classList`` wins (on interpreters whose dicts preserve
    insertion order).

    Args:
        classList: Sequence of hashable class labels.

    Returns:
        The most frequent label.

    Raises:
        IndexError: If ``classList`` is empty.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # dict.items() exists on both Python 2 and 3; the original used
    # iteritems(), which raises AttributeError on Python 3.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
# create the decision tree
def createTree(dataSet, labels):
    """Recursively build a decision tree from ``dataSet`` as nested dicts.

    An internal node is ``{feature_label: {feature_value: subtree, ...}}``;
    a leaf is a bare class label.

    Args:
        dataSet: Non-empty sequence of rows; the last element of every row
            is the class label, the preceding elements are feature values.
        labels: Feature names, one per feature column and in column order.
            This sequence is not modified.

    Returns:
        A nested dict describing the tree, or a single class label when
        no split is made.
    """
    classList = [example[-1] for example in dataSet]

    # Case 1: every row has the same class, so this is a pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # Case 2: only the class column is left, so fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    # Case 3: split on the most informative feature.
    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat < 0:
        # No feature reduces entropy. Using -1 as an index would pick the
        # wrong label and split on the class column, so vote instead.
        return majorityCnt(classList)

    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}

    # The subsets returned by splitDataSet no longer contain the chosen
    # column, so drop its label too; otherwise label indexes in the
    # recursive calls stop lining up with the columns. Build a new
    # sequence so the caller's labels stay untouched.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]

    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
# using the tree classifier to classify the data set
def classify(inputTree, featLabels, testVec):
    """Classify ``testVec`` by walking the decision tree ``inputTree``.

    Args:
        inputTree: Nested dict of the form
            ``{feature_label: {feature_value: subtree_or_class_label}}``.
        featLabels: Feature names, ordered like the values in ``testVec``.
        testVec: Feature values of the sample to classify.

    Returns:
        The class label stored at the leaf that ``testVec`` reaches.

    Raises:
        ValueError: If a feature label in the tree is not in ``featLabels``.
        KeyError: If the tree has no branch for a value in ``testVec``.
    """
    # A node has a single key: the label of the feature it tests.
    # next(iter(...)) works on Python 2 and 3; keys()[0] is Python 2 only.
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    valueOfFeat = secondDict[testVec[featIndex]]

    if isinstance(valueOfFeat, dict):
        # Decision node: keep descending.
        return classify(valueOfFeat, featLabels, testVec)
    # Leaf node: the stored value is the class label. The original lacked
    # this else-branch separation and always returned the raw subtree.
    return valueOfFeat
# persist the decision tree to disk
def storeTree(inputTree, filename):
    """Serialize ``inputTree`` to ``filename`` with pickle.

    Args:
        inputTree: The decision tree (any picklable object) to save.
        filename: Path of the file to create or overwrite.
    """
    import pickle
    # Pickle data is bytes, so the file must be opened in binary mode
    # (text mode raises TypeError on Python 3). The with-block guarantees
    # the handle is flushed and closed.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
# load a persisted decision tree from disk
def grabTree(filename):
    """Load and return a pickled decision tree from ``filename``.

    Args:
        filename: Path of a file containing pickle data.

    Returns:
        The unpickled object.
    """
    import pickle
    # NOTE(security): unpickling can execute arbitrary code. Only load
    # files that come from a trusted source.
    # Pickle data is bytes, so open in binary mode (text mode fails on
    # Python 3). The with-block closes the handle after loading.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
