Skip to content

Instantly share code, notes, and snippets.

@griesmey
Created August 18, 2015 16:31
Show Gist options
  • Save griesmey/c8f144a51bf698cbae71 to your computer and use it in GitHub Desktop.
Save griesmey/c8f144a51bf698cbae71 to your computer and use it in GitHub Desktop.
Implementation of k-means clustering in Python
from math import sqrt
import sys
from collections import defaultdict
from itertools import izip
class Point(object):
    """A mutable point in the 2-D plane with an optional identifier.

    Equality and hashing are deliberately left identity-based: ``KMeans``
    uses centroid points as dictionary keys and then mutates their ``x`` and
    ``y`` in place, which would corrupt the dictionary if the hash depended
    on the coordinates.

    Attributes:
        x: Horizontal coordinate.
        y: Vertical coordinate.
        id: Optional caller-supplied label; ``None`` when not given.
    """

    def __init__(self, x, y, id=None):
        """Create a point at (``x``, ``y``), optionally tagged with ``id``."""
        self.x = x
        self.y = y
        self.id = id

    def __repr__(self):
        """Return a debugging representation showing coordinates and id."""
        return "Point(x={!r}, y={!r}, id={!r})".format(self.x, self.y, self.id)
class KMeans(object):
    """Two-dimensional k-means clustering over objects exposing ``x``/``y``.

    The centroid objects passed in are moved in place as the algorithm
    iterates, and are used as keys of ``assigned``; they must therefore be
    hashable by identity (as plain ``Point`` instances are).

    Attributes:
        centroids: Sequence of centroid objects; mutated by
            ``assign_new_centroids``.
        points: Iterable of the points being clustered.
        assigned: ``defaultdict(list)`` mapping each centroid to the points
            currently nearest to it.
        delta: Distance below which a centroid move counts as "no change".
    """

    def __init__(self, centroids, points):
        """Store the starting centroids and the points to cluster.

        Args:
            centroids: Sequence of objects with writable ``x`` and ``y``.
            points: Iterable of objects with ``x`` and ``y``.
        """
        self.centroids = centroids
        self.points = points
        self.assigned = defaultdict(list)
        self.delta = 0.0001

    def distance(self, p1, p2):
        """Return the Euclidean distance between ``p1`` and ``p2``."""
        # A sum of squares cannot be negative, so no clamping is required
        # before taking the square root.
        return sqrt((p2.x - p1.x) ** 2 + (p2.y - p1.y) ** 2)

    def clear_assignments(self):
        """Forget every centroid-to-points assignment."""
        self.assigned.clear()

    def assign_centroids(self):
        """Append each point to the list of its nearest centroid.

        Ties go to the centroid that appears first in ``self.centroids``.
        Every centroid ends up with an entry in ``self.assigned``, even if
        that entry is an empty list. Existing assignments are not cleared
        here; call ``clear_assignments`` first for a fresh assignment.

        Raises:
            IndexError: If there are points but ``self.centroids`` is empty.
        """
        for point in self.points:
            # Start from +inf so any real distance wins, regardless of scale.
            closest_dist = float("inf")
            closest_cent = self.centroids[0]
            for centroid in self.centroids:
                cur_dist = self.distance(point, centroid)
                if cur_dist < closest_dist:
                    closest_dist = cur_dist
                    closest_cent = centroid
            self.assigned[closest_cent].append(point)

        # Make clusters that attracted no points visible to later steps.
        for centroid in self.centroids:
            self.assigned.setdefault(centroid, [])

    def avg_points(self, points):
        """Return a new ``Point`` at the arithmetic mean of ``points``.

        Args:
            points: Non-empty sized collection of objects with ``x``/``y``.

        Raises:
            ZeroDivisionError: If ``points`` is empty.
        """
        count = float(len(points))
        x_total = sum(point.x for point in points)
        y_total = sum(point.y for point in points)
        return Point(x_total / count, y_total / count)

    def assign_new_centroids(self):
        """Move each centroid to the mean of its assigned points.

        A centroid with no assigned points is left where it is. All
        centroids are updated before convergence is decided.

        Returns:
            True if no centroid moved by ``self.delta`` or more (the
            clustering has converged), otherwise False.
        """
        converged = True
        for centroid, points in self.assigned.items():
            if not points:
                # Nothing to average; an empty cluster keeps its position.
                continue
            avg = self.avg_points(points)
            if self.distance(avg, centroid) >= self.delta:
                converged = False
            centroid.x = avg.x
            centroid.y = avg.y
        return converged

    def run(self, max_iterations=None):
        """Alternate assignment and centroid updates until convergence.

        On return ``self.assigned`` holds the assignment computed in the
        last iteration.

        Args:
            max_iterations: Optional upper bound on the number of
                assign/update rounds. ``None`` (the default) means iterate
                until ``assign_new_centroids`` reports convergence.
        """
        iterations = 0
        while max_iterations is None or iterations < max_iterations:
            # Clear first so a repeated call to run() does not append to the
            # assignments left behind by the previous call.
            self.clear_assignments()
            self.assign_centroids()
            if self.assign_new_centroids():
                return
            iterations += 1

    def plot_graph(self):
        """Plot every cluster's points and its centroid with matplotlib.

        Cluster points are drawn in green, blue and yellow, cycling when
        there are more than three clusters; centroids are drawn in red.
        """
        # Imported lazily so the clustering itself works without matplotlib.
        from itertools import cycle
        from matplotlib import pyplot as plt

        colors = ['go', 'bo', 'yo']
        for color, (centroid, points) in zip(cycle(colors),
                                             self.assigned.items()):
            x_axis = [p.x for p in points]
            y_axis = [p.y for p in points]
            plt.plot(x_axis, y_axis, color)
            plt.plot([centroid.x], [centroid.y], 'ro')
        plt.show()
if __name__ == '__main__':
    # Demo: cluster 100 random points into three groups and plot the result.
    import random

    MAXX = 100
    MAXY = 100

    points = []
    centroids = []

    for i in range(100):
        x = random.randint(0, MAXX)
        y = random.randint(0, MAXY)
        points.append(Point(x, y, i))

    # Seed the centroids at random positions inside the same bounds as the
    # points. Seeding them all at the origin would make them coincide, so
    # the first would capture every point and the others would stay empty.
    for k in range(100, 103):
        new_p = Point(random.randint(0, MAXX), random.randint(0, MAXY), k)
        centroids.append(new_p)

    km = KMeans(centroids, points)
    km.run()
    km.plot_graph()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment