Created
April 24, 2012 20:04
-
-
Save rodrigosetti/2483186 to your computer and use it in GitHub Desktop.
K Nearest Neighbors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8
from __future__ import division

from math import sqrt
from random import uniform
def euclidian(x, y):
    """Return the Euclidean distance between two points.

    x and y are equal-length sequences of numeric coordinates.  The result
    is the square root of the sum of squared coordinate differences, so two
    empty points are at distance 0.0.

    Raises AssertionError when the points have different dimensions (zip
    would otherwise silently ignore the extra coordinates).
    """
    assert len(x) == len(y), "points must have the same number of dimensions"
    return sqrt(sum((i - j) ** 2 for i, j in zip(x, y)))
def _roulette_index(weights):
    """Pick a random index with probability proportional to its weight.

    Walks the cumulative sum of weights until it exceeds a uniformly drawn
    threshold.  If the walk never exceeds the threshold (the draw landed on
    the upper end-point, or every weight is zero) the last index holding a
    positive weight is returned, or the last index when there is none.
    """
    threshold = uniform(0, sum(weights))
    cumulative = 0
    fallback = len(weights) - 1
    for index, weight in enumerate(weights):
        if weight > 0:
            fallback = index
        cumulative += weight
        if threshold < cumulative:
            return index
    return fallback


def knn(points, k, n=None, dist=euclidian):
    """Calculate K centers, each settled on the mean of its nearest points.

    Despite the name this is a clustering loop: K centers are seeded from
    the data and each one is repeatedly moved to the mean of the n points
    nearest to it, until a whole pass moves no center.

    points -- non-empty indexable sequence of equal-length numeric sequences
    k      -- number of centers to compute (must be > 0)
    n      -- how many nearest points guide each center; when omitted (or
              falsy) it is len(points) // k, with a minimum of 1, so each
              center is guided by roughly an equal share of the data
    dist   -- distance function taking two points

    Returns a list of K centers.  Centers that moved are tuples of
    coordinates.  Raises AssertionError if points is empty or k is not
    positive.
    """
    assert len(points) > 0
    assert k > 0

    n = n if n else max(len(points) // k, 1)

    # Seed the centers with the k-means++ procedure: the first pick is
    # uniform, every later pick is weighted by the squared distance from
    # each point to its nearest already-chosen center.
    centers = []
    weights = (1,) * len(points)
    for _ in range(k):
        centers.append(points[_roulette_index(weights)])
        weights = tuple(
            min(dist(center, point) ** 2 for center in centers)
            for point in points
        )

    # Keep moving the centers until a fixed point is reached.
    has_change = True
    while has_change:
        has_change = False
        for c in range(k):
            center = centers[c]
            # sorted() is stable, so ties keep their order from `points`.
            nearest = sorted(points, key=lambda p: dist(center, p))[:n]
            # Average over the neighbours actually found: when n exceeds
            # len(points) the slice is shorter than n, and dividing by n
            # would drag the center toward the origin.
            count = len(nearest)
            new_center = tuple(sum(axis) / count for axis in zip(*nearest))
            if new_center != center:
                has_change = True
                centers[c] = new_center
    return centers
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment