-
-
Save lttzzlll/4a6e012f20494babcc3725a75dccbb11 to your computer and use it in GitHub Desktop.
kNN algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def classify0(inX, dataSet, labels, k):
    """Classify one sample by majority vote among its k nearest neighbours.

    Args:
        inX: the test sample to classify; it must have as many features as
            each row of ``dataSet``.
        dataSet: the training data set, one sample per row.
        labels: the class label of each row of ``dataSet``, in row order.
        k: how many of the nearest training samples get a vote. Values larger
            than the number of training samples are clamped to that number.

    Returns:
        The label with the most votes. Ties are broken in favour of the label
        whose nearest representative ranks first in distance order.

    Raises:
        ValueError: if ``k`` is less than 1 or ``dataSet`` has no rows.
    """
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    # Indices of the training rows, ordered from nearest to farthest.
    sortedDistIndicies = euclideanMetric(dataSet, inX)
    # Slicing (rather than indexing 0..k-1) clamps k to the number of training
    # samples instead of raising IndexError when k is too large.
    nearest = sortedDistIndicies[:k]
    if len(nearest) == 0:
        raise ValueError("dataSet must contain at least one sample")
    classCount = {}
    firstSeen = []  # labels in the order they first appear among the neighbours
    for idx in nearest:
        voteIlabel = labels[idx]
        if voteIlabel not in classCount:
            classCount[voteIlabel] = 0
            firstSeen.append(voteIlabel)
        classCount[voteIlabel] += 1
    # max() returns the first maximal element, so equal vote counts resolve by
    # neighbour distance instead of arbitrary dict ordering. This also avoids
    # dict.iteritems(), which exists only on Python 2.
    return max(firstSeen, key=classCount.get)
def euclideanMetric(dataSet, inX):
    """Rank the rows of ``dataSet`` by Euclidean distance to ``inX``.

    Args:
        dataSet: the training data set, one sample per row.
        inX: the testing data sample to measure from.

    Returns:
        An array of row indices into ``dataSet``, ordered from the nearest row
        to the farthest (``argsort`` of the distances).
    """
    # Broadcasting pairs inX with every row, exactly like stacking one copy of
    # inX per training row and subtracting.
    offsets = inX - dataSet
    distances = (offsets ** 2).sum(axis=1) ** 0.5
    return distances.argsort()
def createDataSet():
    """Build a tiny toy training set.

    Returns:
        A ``(group, labels)`` pair: ``group`` is a 4x2 float array of sample
        points and ``labels`` lists each point's class ('A' or 'B') in row order.
    """
    points = [
        [1.0, 1.1],
        [1.0, 1.0],
        [0, 0],
        [0, 0.1],
    ]
    classes = ['A', 'A', 'B', 'B']
    return array(points), classes
def file2matrix(filename, numFeatures=3):
    """Parse a tab-separated data file into a feature matrix and label list.

    Each non-blank line is split on tabs. The first ``numFeatures`` fields are
    parsed as floats and become one matrix row. The last field is parsed as
    the integer class label.

    Args:
        filename: path of the file to read.
        numFeatures: number of leading columns used as features (default 3).

    Returns:
        A ``(returnMat, classLabelVector)`` pair: a float array of shape
        ``(rows, numFeatures)`` and the list of int labels in the same order.
    """
    # Read the file once and close it deterministically. The original opened
    # it twice (once just to count lines) and never closed either handle.
    with open(filename) as fr:
        rows = [line.strip().split('\t') for line in fr]
    # Blank lines (e.g. a trailing empty line) would otherwise crash the parse.
    rows = [fields for fields in rows if fields != ['']]
    returnMat = zeros((len(rows), numFeatures))  # prepare matrix to return
    classLabelVector = []  # prepare labels to return
    for index, fields in enumerate(rows):
        returnMat[index, :] = [float(value) for value in fields[:numFeatures]]
        classLabelVector.append(int(fields[-1]))
    return returnMat, classLabelVector
def autoNorm(dataSet):
    """Min-max normalize every column of ``dataSet``.

    Each value becomes ``(x - min) / (max - min)``, computed per column, so
    non-constant columns land in [0, 1]. Constant columns normalize to 0.

    Args:
        dataSet: 2-D numpy array with one sample per row.

    Returns:
        A ``(normDataSet, ranges, minVals)`` tuple: the normalized float data,
        the per-column ``max - min``, and the per-column minimum. ``ranges``
        is returned unmodified and may contain 0 for constant columns. Guard
        against that before reusing it to normalize new samples.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # Divide by a float copy so integer input is not floor-divided under
    # Python 2. Zero ranges (constant columns) are replaced by 1 so those
    # columns become 0 instead of NaN; NaN would poison every kNN distance.
    divisor = ranges.astype(float)
    divisor[divisor == 0] = 1.0
    # Broadcasting subtracts/divides the per-column values from every row.
    normDataSet = (dataSet - minVals) / divisor
    return normDataSet, ranges, minVals
def datingClassTest(hoRatio=0.50, k=3, filename='datingTestSet2.txt'):
    """Estimate the kNN error rate on the dating data set with a hold-out split.

    The first ``hoRatio`` fraction of the normalized rows is the test set and
    the remaining rows are the training set. The logic matches the handwriting
    digit test method; only the data source differs. Per-sample results and
    the overall error rate are printed.

    Args:
        hoRatio: fraction of rows held out for testing (default 0.50, i.e. 50%).
        k: number of neighbours consulted by ``classify0``.
        filename: tab-separated data file read by ``file2matrix``.

    Returns:
        The error rate as a float, or ``None`` when no rows were held out.
    """
    datingDataMat, datingLabels = file2matrix(filename)  # load data set from file
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    if numTestVecs == 0:
        # Avoid dividing by zero when the hold-out set is empty.
        print("no test vectors were held out; nothing to evaluate")
        return None
    # The training split does not change between iterations; slice it once.
    trainMat = normMat[numTestVecs:m, :]
    trainLabels = datingLabels[numTestVecs:m]
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], trainMat, trainLabels, k)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    errorRate = errorCount / float(numTestVecs)
    print("the total error rate is: %f" % errorRate)
    print(errorCount)
    return errorRate
def img2vector(filename):
    """Flatten a 32x32 text image of digit characters into a 1x1024 row vector.

    Character ``j`` of line ``i`` (both counted from 0) is parsed with
    ``int`` and stored at column ``32 * i + j``. Only the first 32 lines and
    the first 32 characters of each line are read.

    Args:
        filename: path of the text image file.

    Returns:
        A float array of shape ``(1, 1024)``.
    """
    returnVect = zeros((1, 1024))
    # 'with' closes the file even if a character fails to parse; the original
    # leaked the handle.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
def handwritingClassTest(k=3):
    """Estimate the kNN error rate on the handwritten-digit test set.

    The training matrix comes from ``training()``. Every file listed in the
    ``testDigits`` directory is converted with ``img2vector``, classified with
    ``classify0``, and compared with the label parsed from its file name.
    Per-file results and the totals are printed.

    Args:
        k: number of neighbours consulted by ``classify0`` (default 3).

    Returns:
        The error rate as a float, or ``None`` when ``testDigits`` is empty.
    """
    hwLabels, trainingMat = training()
    testFileList = listdir('testDigits')  # iterate through the test set
    mTest = len(testFileList)
    if mTest == 0:
        # Avoid dividing by zero when there is nothing to test.
        print("\nno files found in testDigits; nothing to evaluate")
        return None
    errorCount = 0.0
    for fileNameStr in testFileList:
        classNumStr = extractLabel(fileNameStr)
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    errorRate = errorCount / float(mTest)
    print("\nthe total error rate is: %f" % errorRate)
    return errorRate
def training():
    """Load every file in the ``trainingDigits`` directory as training data.

    Returns:
        A ``(hwLabels, trainingMat)`` pair. ``hwLabels`` holds the label parsed
        from each file name by ``extractLabel``. ``trainingMat`` has one row of
        1024 values per file, produced by ``img2vector``, in the same order as
        ``hwLabels``.
    """
    fileNames = listdir('trainingDigits')  # load the training set
    trainingMat = zeros((len(fileNames), 1024))
    hwLabels = []
    for row, name in enumerate(fileNames):
        hwLabels.append(extractLabel(name))
        trainingMat[row, :] = img2vector('trainingDigits/%s' % name)
    return hwLabels, trainingMat
def extractLabel(fileNameStr):
    """Return the integer class label encoded in a digit file name.

    The label is the text before the first ``_`` within the part of the name
    before the first ``.``, parsed with ``int``. For example, ``'3_17.txt'``
    gives ``3``.
    """
    stem = fileNameStr.partition('.')[0]  # take off .txt
    labelText = stem.partition('_')[0]
    return int(labelText)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
#kNN Algorithm