@lttzzlll
Created March 2, 2017 11:54
decision tree
'''
Created on Oct 12, 2010
Decision Tree Source Code for Machine Learning in Action Ch. 3
@author: Peter Harrington
'''
from math import log
import operator
# prepare data set
# create the data set and its corresponding feature labels
def createDataSet():
    # each row: [no surfacing?, flippers?, class label ('yes' = is a fish)]
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    # change to discrete values
    return dataSet, labels
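# A quick interactive check (an illustration added here, not part of the
# book's original listing) of what createDataSet() returns:
#   >>> myDat, labels = createDataSet()
#   >>> myDat
#   [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
#   >>> labels
#   ['no surfacing', 'flippers']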
# calculate the Shannon entropy of a data set
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # count each class label and its occurrences
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        # probability of this class label
        prob = float(labelCounts[key]) / numEntries
        # accumulate the entropy over all class labels; log base 2
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
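# Worked example (an illustration, not part of the original listing): the toy
# data set has 2 'yes' and 3 'no' labels, so its entropy is
# -(2/5)*log2(2/5) - (3/5)*log2(3/5), roughly 0.971 bits:
#   >>> myDat, labels = createDataSet()
#   >>> calcShannonEnt(myDat)
#   0.9709505944546686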
# split the data set on the feature at index `axis` having the given value;
# works like a filter and returns the matching rows with that feature column removed
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # keep everything before the splitting feature...
            reducedFeatVec = featVec[:axis]  # chop out the axis used for splitting
            # ...and everything after it; the chosen feature has been consumed
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
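# Illustration (not part of the original listing): splitting on feature 0 with
# value 1 keeps the three rows whose first column is 1 and drops that column:
#   >>> myDat, labels = createDataSet()
#   >>> splitDataSet(myDat, 0, 1)
#   [[1, 'yes'], [1, 'yes'], [0, 'no']]
#   >>> splitDataSet(myDat, 0, 0)
#   [[1, 'no'], [1, 'no']]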
# choose the best feature to split the data set on; returns the feature index
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column holds the class labels
    # entropy of the unsplit data set
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):  # iterate over all the features
        # collect the values this feature takes across all examples
        featList = [example[i] for example in dataSet]
        # the distinct values of this feature
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            # split on this feature and value, then weight each branch's
            # entropy by the fraction of examples it contains
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        # information gain is the reduction in entropy from this split
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:  # compare this to the best gain so far
            bestInfoGain = infoGain  # if better than the current best, keep it
            bestFeature = i
    return bestFeature  # returns an integer index
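# Illustration (not part of the original listing): on the toy data set,
# splitting on feature 0 ('no surfacing') gives the larger information gain
# (about 0.42 bits vs about 0.17 bits for 'flippers'), so index 0 is returned:
#   >>> myDat, labels = createDataSet()
#   >>> chooseBestFeatureToSplit(myDat)
#   0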
# majority vote over the remaining class labels
def majorityCnt(classList):
    # count how many times each class label appears
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # operator.itemgetter(1) sorts the (label, count) pairs by count, descending
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    # return the most common class label among the remaining examples
    return sortedClassCount[0][0]
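# Illustration (not part of the original listing): with no features left to
# split on, the most frequent remaining label wins the vote:
#   >>> majorityCnt(['yes', 'no', 'no', 'yes', 'no'])
#   'no'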
# build the decision tree recursively
def createTree(dataSet, labels):
    # the class labels of the remaining examples
    classList = [example[-1] for example in dataSet]
    # case 1: all remaining examples belong to the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all of the classes are equal
    # case 2: no features left to split on; fall back to a majority vote
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # case 3: keep splitting
    # choose the best feature to split the data set on
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # look up the name of the best feature
    bestFeatLabel = labels[bestFeat]
    # the tree is a nested dict keyed by feature label, then feature value
    myTree = {bestFeatLabel: {}}
    # remove the chosen feature's label; that feature is consumed by this split
    del(labels[bestFeat])
    # collect the values the chosen feature takes
    featValues = [example[bestFeat] for example in dataSet]
    # the distinct values of the chosen feature
    uniqueVals = set(featValues)
    for value in uniqueVals:
        # shallow copy of the remaining labels, so the recursion
        # doesn't mess up the caller's list
        subLabels = labels[:]
        # recurse on the subset of examples with this feature value,
        # passing the remaining labels along
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
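# Illustration (not part of the original listing): on the toy data set the
# recursion produces a nested dict keyed by feature label, then feature value.
# Note that createTree mutates the labels list, so pass a copy if you still
# need the original afterwards:
#   >>> myDat, labels = createDataSet()
#   >>> createTree(myDat, labels[:])
#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}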
# use the trained tree to classify a single feature vector
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        # decision node: recurse into the matching subtree
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        # leaf node: the stored value is the class label
        classLabel = valueOfFeat
    return classLabel
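# Illustration (not part of the original listing): classify needs the original
# feature labels, in their original order, to locate each feature in testVec:
#   >>> myDat, labels = createDataSet()
#   >>> myTree = createTree(myDat, labels[:])
#   >>> classify(myTree, labels, [1, 0])
#   'no'
#   >>> classify(myTree, labels, [1, 1])
#   'yes'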
# persist the trained tree to disk with pickle
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()
# load a pickled tree back from disk
def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)
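# Minimal end-to-end demo (an addition for convenience, not part of the book's
# original listing): build the tree from the toy data, classify two samples,
# and round-trip the tree through pickle. The file name 'classifierStorage.txt'
# is just an example.
if __name__ == '__main__':
    myDat, labels = createDataSet()
    myTree = createTree(myDat, labels[:])
    print(myTree)
    print(classify(myTree, labels, [1, 0]))   # expected: 'no'
    print(classify(myTree, labels, [1, 1]))   # expected: 'yes'
    storeTree(myTree, 'classifierStorage.txt')
    print(grabTree('classifierStorage.txt'))  # the same tree, loaded back from disk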
lttzzlll commented Mar 2, 2017

decision tree code samples
