@bistaumanga
Last active December 7, 2020 10:16
K-Means Clustering Implemented in Python with NumPy
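'''Implementation of K-Means Clustering
Requires : python 2.7.x, Numpy 1.7.1+'''
import numpy as np
def kMeans(X, K, maxIters = 10, plot_progress = None):
    # Initialise with K distinct rows of X; sampling with replacement could
    # pick the same point twice and leave a cluster empty
    centroids = X[np.random.choice(np.arange(len(X)), K, replace = False), :]
    for i in range(maxIters):
        # Cluster assignment step: label each point with its nearest centroid
        # (smallest squared Euclidean distance)
        C = np.array([np.argmin([np.dot(x_i-y_k, x_i-y_k) for y_k in centroids]) for x_i in X])
        # Move centroids step: each centroid moves to the mean of its points;
        # an empty cluster keeps its previous centroid, since the mean of no points is undefined
        centroids = np.array([X[C == k].mean(axis = 0) if np.any(C == k) else centroids[k]
                              for k in range(K)])
        if plot_progress is not None: plot_progress(X, C, centroids)
    return centroids, C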
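The cluster assignment step above computes one squared distance at a time inside nested Python loops. A minimal sketch of a vectorised alternative using NumPy broadcasting follows; the helper name `assign_clusters` is hypothetical and not part of the gist.

```python
import numpy as np


def assign_clusters(X, centroids):
    """Hypothetical helper: index of the nearest centroid for each row of X.

    X has shape (n, d) and centroids has shape (K, d). Broadcasting
    X[:, None, :] against centroids[None, :, :] gives an (n, K, d) array of
    differences; summing their squares over the last axis yields the (n, K)
    matrix of squared Euclidean distances.
    """
    diff = X[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    sq_dist = np.sum(diff * diff, axis=2)
    # argmin returns the first minimum, so ties go to the lower cluster index
    return np.argmin(sq_dist, axis=1)


X = np.array([[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9]])
centroids = np.array([[0.0, 0.0], [5.0, 5.0]])
print(assign_clusters(X, centroids))  # [0 0 1 1]
```

It yields the same labels as the list comprehension (including breaking ties toward the lower index), but it builds an n-by-K-by-d intermediate array, so memory use grows with all three dimensions.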
'''Demonstration of K-Means Clustering
Requires : python 2.7.x, Numpy 1.7.1+, Matplotlib 1.2.1+'''
import matplotlib.pyplot as plt
import numpy as np
plt.ion()
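def show(X, C, centroids, keep = False):
    plt.cla()
    # points of clusters 0, 1 and 2 in blue, red and green (only K <= 3 is drawn)
    plt.plot(X[C == 0, 0], X[C == 0, 1], '*b',
             X[C == 1, 0], X[C == 1, 1], '*r',
             X[C == 2, 0], X[C == 2, 1], '*g')
    # centroids as large magenta stars
    plt.plot(centroids[:,0], centroids[:,1], '*m', markersize=20)
    # redraw the figure and wait 0.5 s so each iteration stays visible
    plt.pause(0.5)
    if keep :
        # leave interactive mode and block until the window is closed
        plt.ioff()
        plt.show()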
# generate 3 clusters of 2-D Gaussian data
# data = np.genfromtxt('data1.csv', delimiter=',')
# each covariance matrix must be symmetric positive semidefinite
m1, cov1 = [9, 8], [[1.5, 1], [1, 2]]
m2, cov2 = [5, 13], [[2.5, -1.5], [-1.5, 1.5]]
m3, cov3 = [3, 7], [[0.25, -0.1], [-0.1, 0.5]]
data1 = np.random.multivariate_normal(m1, cov1, 250)
data2 = np.random.multivariate_normal(m2, cov2, 180)
data3 = np.random.multivariate_normal(m3, cov3, 100)
X = np.vstack((data1, data2, data3))
np.random.shuffle(X)
from kMeans import kMeans
centroids, C = kMeans(X, K = 3, plot_progress = show)
show(X, C, centroids, True)
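`np.random.multivariate_normal` expects each covariance matrix to be symmetric and positive semidefinite. For a matrix that is not (for example `[[1.5, 2], [1, 2]]`, which is not symmetric), NumPy warns, and the samples cannot have that exact covariance, since any covariance matrix is symmetric. A small check to run on the covariances before sampling can be sketched as follows; `is_valid_covariance` and its tolerance are illustrative assumptions, not a NumPy API.

```python
import numpy as np


def is_valid_covariance(cov, tol=1e-10):
    """Return True if cov is a square, symmetric, positive semidefinite matrix."""
    cov = np.asarray(cov, dtype=float)
    if cov.ndim != 2 or cov.shape[0] != cov.shape[1]:
        return False
    if not np.allclose(cov, cov.T, atol=tol):
        return False
    # eigvalsh is meant for symmetric matrices and returns real eigenvalues;
    # positive semidefinite means none is negative (allowing for rounding error)
    return bool(np.all(np.linalg.eigvalsh(cov) >= -tol))


candidates = [
    [[1.5, 2], [1, 2]],          # not symmetric
    [[1.5, 2], [2, 2]],          # symmetric, but determinant < 0
    [[1.5, 1], [1, 2]],          # symmetric and positive definite
    [[2.5, -1.5], [-1.5, 1.5]],  # cov2 from the demo
]
for cov in candidates:
    print(cov, is_valid_covariance(cov))
```

For a 2-by-2 matrix, being symmetric with a non-negative diagonal and a non-negative determinant is equivalent to being positive semidefinite, so a failing matrix can usually be repaired by making the off-diagonal entries equal and small enough in magnitude.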
@tvwerkhoven
tvwerkhoven commented May 31, 2018

kMeans:12 always fails when the function is given a list of zeroes as input for X. It also fails on 'real' data a few percent of the time (in my application). This is because C can contain fewer than K distinct clusters, so X[C == k].mean() gives an error for any k that does not appear in C. I solved this by checking whether len(np.unique(C)) < K and, if so, resetting the centroids to a (new) random sample. See https://gist.github.com/tvwerkhoven/4fdc9baad760240741a09292901d3abd for the fix.
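The workaround described in the comment can be sketched as follows. This is an illustrative reimplementation, not the code from the linked gist: the names `kmeans_with_restart` and `_nearest`, the `rng` parameter, and the final re-assignment are assumptions, and `np.random.default_rng` needs NumPy 1.17+, newer than the 1.7.1 the gist targets.

```python
import numpy as np


def _nearest(X, centroids):
    # Index of the closest centroid (squared Euclidean distance) for each row
    diff = X[:, None, :] - centroids[None, :, :]
    return np.argmin(np.sum(diff * diff, axis=2), axis=1)


def kmeans_with_restart(X, K, max_iters=10, rng=None):
    """K-means that re-draws the centroids when a cluster ends up empty.

    Follows the idea in the comment: if len(np.unique(C)) < K, reset the
    centroids to a new random sample of K distinct rows of X.
    """
    rng = np.random.default_rng() if rng is None else rng
    X = np.asarray(X, dtype=float)
    centroids = X[rng.choice(len(X), K, replace=False)]
    for _ in range(max_iters):
        C = _nearest(X, centroids)
        if len(np.unique(C)) < K:
            # At least one cluster is empty: restart from fresh random centroids
            centroids = X[rng.choice(len(X), K, replace=False)]
            continue
        # Every cluster has at least one point, so each mean is well defined
        centroids = np.array([X[C == k].mean(axis=0) for k in range(K)])
    # Labels consistent with the returned centroids
    return centroids, _nearest(X, centroids)


# All-zero input, the failure case reported in the comment: every restart
# still leaves clusters empty, but the result contains no NaN.
centroids, C = kmeans_with_restart(np.zeros((10, 2)), K=3,
                                   rng=np.random.default_rng(0))
print(centroids)
```

For degenerate input such as all zeros, each restart produces empty clusters again, so the loop simply runs out of iterations; because a restart still consumes one of the `max_iters` iterations, it cannot loop forever.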
