decision tree
'''
Created on Oct 12, 2010
Decision Tree Source Code for Machine Learning in Action Ch. 3
@author: Peter Harrington
'''
from math import log
import operator
# prepare the data set
# create a toy data set and the corresponding feature labels
def createDataSet():
    # columns: 'no surfacing' (0/1), 'flippers' (0/1), class label ('yes'/'no')
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
# calculate the Shannon entropy of a data set
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # count the occurrences of each class label
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        # probability of this class label
        prob = float(labelCounts[key]) / numEntries
        # accumulate the entropy: H = -sum(p * log2(p))
        shannonEnt -= prob * log(prob, 2)  # log base 2
    return shannonEnt
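# Example (added for illustration; the book walks through this in an
# interactive session). On the toy data above (2 'yes', 3 'no'):
#   >>> myDat, labels = createDataSet()
#   >>> calcShannonEnt(myDat)
#   0.9709505944546686    # -(2/5)*log2(2/5) - (3/5)*log2(3/5)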
# split the data set on feature `axis`, keeping rows where it equals `value`
# works like a filter: it returns the matching rows with that feature removed
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # the selected feature has been consumed, so drop that column
            reducedFeatVec = featVec[:axis]  # chop out the axis used for splitting
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
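# Example (added): splitting the toy set on feature 0 == 1 keeps the three
# rows whose first value is 1 and drops that column:
#   >>> splitDataSet(myDat, 0, 1)
#   [[1, 'yes'], [1, 'yes'], [0, 'no']]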
# choose the best feature to split the data set on; returns its column index
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column holds the class labels
    # entropy of the unsplit data set
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):  # iterate over all the features
        # collect every example's value for this feature
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)  # the distinct values this feature takes
        newEntropy = 0.0
        for value in uniqueVals:
            # split on this feature/value pair and weight each subset's
            # entropy by the fraction of rows that fall into it
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        # information gain, i.e. the reduction in entropy
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:  # keep the best gain seen so far
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature  # returns an integer index
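# Example (added): on the toy set, splitting on 'no surfacing' (index 0)
# yields a gain of about 0.420 versus about 0.171 for 'flippers', so:
#   >>> chooseBestFeatureToSplit(myDat)
#   0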
# majority vote over the class labels that remain in a node
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # sort (label, count) pairs by count, descending; the operator module
    # supplies the sort key (dict.iteritems() was Python 2; items() works in 3)
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    # return the most common class label
    return sortedClassCount[0][0]
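# Example (added):
#   >>> majorityCnt(['yes', 'no', 'no'])
#   'no'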
# recursively build the decision tree
def createTree(dataSet, labels):
    # the class labels of every remaining example
    classList = [example[-1] for example in dataSet]
    # case 1: every example belongs to the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all of the classes are equal
    # case 2: no features left to split on; fall back to a majority vote
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # case 3: split on the most informative feature
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    # the tree is a nested dict: {feature label: {feature value: subtree}}
    myTree = {bestFeatLabel: {}}
    # note: this mutates the caller's list; pass in a copy if you still
    # need the original labels afterwards
    del(labels[bestFeat])
    # the values the chosen feature takes, deduplicated
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        # shallow copy of the labels so subtrees don't clobber each other
        subLabels = labels[:]
        # recurse on the subset of rows that have this feature value
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
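# Example (added): pass a copy of `labels`, since createTree deletes from
# the list it is given and classify() below needs the original order:
#   >>> myDat, labels = createDataSet()
#   >>> myTree = createTree(myDat, labels[:])
#   >>> myTree
#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}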
# use the fitted tree to classify a single test vector
def classify(inputTree, featLabels, testVec):
    firstStr = next(iter(inputTree))  # root feature label (keys()[0] was Python 2)
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        # internal decision node: recurse into the subtree
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        # leaf node: its value is the class label
        classLabel = valueOfFeat
    return classLabel
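# Example (added), continuing from the tree built above:
#   >>> classify(myTree, labels, [1, 0])
#   'no'
#   >>> classify(myTree, labels, [1, 1])
#   'yes'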
# persist the fitted tree to disk with pickle
def storeTree(inputTree, filename):
    import pickle
    with open(filename, 'wb') as fw:  # pickle needs binary mode in Python 3
        pickle.dump(inputTree, fw)

# load a pickled tree back from disk
def grabTree(filename):
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
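# Example (added; the filename is arbitrary):
#   >>> storeTree(myTree, 'classifierStorage.txt')
#   >>> grabTree('classifierStorage.txt')
#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}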