Skip to content

Instantly share code, notes, and snippets.

@adusak
Created June 2, 2015 19:31
Show Gist options
  • Save adusak/f3a99030ae4dbde4ef03 to your computer and use it in GitHub Desktop.
Save adusak/f3a99030ae4dbde4ef03 to your computer and use it in GitHub Desktop.
K-Means cluster finding algorithm
from math import sqrt
from operator import itemgetter
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from python.less_11.regression import generate_points, load_file
def cluster_detection(x, y, k=2, iterations=2):
    """Animate k-means clustering of 2-D points with matplotlib.

    ``k`` centroids are seeded at random positions inside the bounding box of
    the data. Every animation frame assigns each point to its nearest
    centroid, moves each centroid to the mean of the points assigned to it,
    and redraws the scatter plot. The call blocks in ``plt.show()``.

    Args:
        x: Sequence of x coordinates.
        y: Sequence of y coordinates, paired element-wise with ``x``.
        k: Number of centroids (clusters) to fit; must be at least 1.
        iterations: Number of animation frames passed to ``FuncAnimation``.

    Raises:
        ValueError: If ``x`` or ``y`` is empty, or ``k`` is smaller than 1.
    """
    if len(x) == 0 or len(y) == 0:
        raise ValueError("cluster_detection() needs at least one point")
    if k < 1:
        raise ValueError("cluster_detection() needs k >= 1")

    min_x, max_x = min(x), max(x)
    min_y, max_y = min(y), max(y)
    fig = plt.figure()

    centroids = {}
    for index in range(k):
        centroids[index] = {
            # uniform() accepts float bounds and a zero-width range, both of
            # which randint() rejects or silently truncates.
            'coords': (np.random.uniform(min_x, max_x),
                       np.random.uniform(min_y, max_y)),
            'cluster': [],
            'color': tuple(np.random.random(3)),
        }

    def nearest_centroid(px, py):
        """Return the centroid record closest to the point (px, py)."""
        def squared_distance(centroid):
            cx, cy = centroid['coords']
            # The square root is monotonic, so comparing squared distances
            # selects the same centroid without computing it.
            return (px - cx) ** 2 + (py - cy) ** 2

        return min(centroids.values(), key=squared_distance)

    def step(frame):
        """Run one k-means round and redraw the figure."""
        # Assignment step: each point joins the cluster of its nearest centroid.
        for px, py in zip(x, y):
            nearest_centroid(px, py)['cluster'].append((px, py))

        # Update step: move each centroid to the mean of its members.
        for centroid in centroids.values():
            members = centroid['cluster']
            if not members:
                # A centroid that attracted no points keeps its position;
                # unpacking zip(*[]) would raise ValueError.
                continue
            xs, ys = zip(*members)
            centroid['coords'] = (sum(xs) / len(xs), sum(ys) / len(ys))

        plt.clf()
        for centroid in centroids.values():
            members = centroid['cluster']
            # Wrapping the RGB tuple in a list marks it as one colour that is
            # broadcast to every point of the cluster.
            colour = [centroid['color']]
            if members:
                xs, ys = zip(*members)
                # One scatter call per cluster instead of one per point.
                plt.scatter(xs, ys, color=colour)
            cx, cy = centroid['coords']
            plt.scatter(cx, cy, s=200, color=colour)
            # Reset membership so the next frame starts from empty clusters.
            members.clear()

    # Keep a reference until plt.show() returns: an animation object that is
    # garbage-collected stops running.
    anim = animation.FuncAnimation(fig, step, frames=iterations, repeat=False)
    plt.show()
# Driver: load the sample points from disk and animate ten k-means rounds
# with four clusters.
data_path = "exfiles/oldfailthful.txt"
x, y = load_file(data_path)
# Alternative input kept from the original script:
# x, y = generate_points(3, 5, 300)
cluster_detection(x, y, k=4, iterations=10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment