Created
April 24, 2012 20:52
-
-
Save rodrigosetti/2483607 to your computer and use it in GitHub Desktop.
K-Means algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
from __future__ import division | |
from random import uniform, choice | |
from math import sqrt | |
def euclidean(x, y):
    """Return the Euclidean (straight-line) distance between points x and y.

    Both points are sequences of numbers; they must have the same length,
    which is checked with an assertion.
    """
    assert len(x) == len(y)
    squared_gaps = [(a - b) ** 2 for a, b in zip(x, y)]
    return sqrt(sum(squared_gaps))
def _weighted_index(weights):
    """Return a random index into `weights`, chosen with probability
    proportional to its (non-negative) weight.

    If no weight is positive, an index is chosen uniformly instead.
    """
    total = sum(weights)
    if not total > 0:
        # Nothing to weight by (e.g. every point already coincides with a
        # chosen center): fall back to a uniform pick.
        return choice(range(len(weights)))
    threshold = uniform(0, total)
    cumulative = 0
    fallback = len(weights) - 1
    for index, weight in enumerate(weights):
        if weight > 0:
            fallback = index
            cumulative += weight
            if threshold < cumulative:
                return index
    # uniform() may return exactly `total`, and float rounding can leave the
    # running sum a hair below it; return the last index that carried any
    # weight so a zero-weight entry (an existing center) is never chosen.
    return fallback


def _initial_centers(points, k, dist):
    """Choose k starting centers with the k-means++ seeding procedure.

    The first center is a uniformly random point; each further center is a
    point drawn with probability proportional to its squared distance to the
    nearest center chosen so far.
    """
    first = choice(points)
    centers = [first]
    # Squared distance from each point to its nearest chosen center. Kept as
    # a list (a generator would be exhausted after its first read) and
    # updated incrementally, so each round costs one distance per point.
    weights = [dist(first, point) ** 2 for point in points]
    while len(centers) < k:
        center = points[_weighted_index(weights)]
        centers.append(center)
        weights = [min(weight, dist(center, point) ** 2)
                   for weight, point in zip(weights, points)]
    return centers


def kmeans(points, k, dist=euclidean, max_iterations=None):
    """Cluster `points` into `k` groups and return the cluster centers.

    Centers are seeded with k-means++ and then refined with Lloyd's
    algorithm: assign every point to its nearest center, then move each
    center to the mean of its points. This repeats until no center moves.

    :param points: non-empty sequence of points, each a sequence of numbers
        of the same length.
    :param k: number of centers to return (must be > 0).
    :param dist: distance function ``dist(a, b)`` used both for seeding and
        for assigning points to centers.
    :param max_iterations: optional cap on the number of refinement rounds;
        ``None`` (the default) iterates until the centers stop moving.
    :returns: list of ``k`` centers; each is either a tuple of per-coordinate
        means or, if it was never replaced, one of the input points.
    :raises AssertionError: if ``points`` is empty or ``k`` is not positive.
    """
    assert len(points) > 0
    assert k > 0
    centers = _initial_centers(points, k, dist)

    rounds = 0
    while max_iterations is None or rounds < max_iterations:
        rounds += 1
        # Assignment step: group the points by the index of their nearest center.
        clusters = {}
        for point in points:
            nearest = min(range(k), key=lambda c: dist(point, centers[c]))
            clusters.setdefault(nearest, []).append(point)
        # Update step: move every center to the mean of its points.
        moved = False
        for c in range(k):
            members = clusters.get(c)
            if members:
                # True division: the module imports `division` from __future__.
                new_center = tuple(sum(coords) / len(members)
                                   for coords in zip(*members))
            else:
                # No point chose this center: relocate it to a random point.
                new_center = choice(points)
            if new_center != centers[c]:
                moved = True
                centers[c] = new_center
        if not moved:
            break
    return centers
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment