Skip to content

Instantly share code, notes, and snippets.

@stormxuwz
Created January 31, 2017 06:21
Show Gist options
  • Save stormxuwz/a7151cb7b56e9c31bd7adf5d11b34d49 to your computer and use it in GitHub Desktop.
Save stormxuwz/a7151cb7b56e9c31bd7adf5d11b34d49 to your computer and use it in GitHub Desktop.
# A simple K-means implementation in Python, built on NumPy
import numpy as np
def calDistance(x, y):
    """Return the squared Euclidean distance between two points.

    No square root is taken, so the result is the *squared* distance. That
    is enough for nearest-center comparisons because squaring preserves
    ordering of non-negative values.

    Args:
        x: Array-like of coordinates (NumPy array, list or tuple).
        y: Array-like of coordinates, broadcastable against ``x``.

    Returns:
        The sum of squared coordinate differences, as a NumPy scalar.
    """
    # np.asarray lets plain lists/tuples be subtracted element-wise.
    diff = np.asarray(x) - np.asarray(y)
    return np.sum(diff ** 2)
def assignClusters(centers, data):
    """Assign every sample to its nearest center.

    Nearness is measured by squared Euclidean distance. When a sample is
    equally close to several centers, the lowest center index wins
    (``np.argmin`` returns the first minimum).

    Args:
        centers: Sequence of center points, each with ``nfeature`` values.
        data: Array-like with shape ``[nsample, nfeature]``.

    Returns:
        Integer array of length ``nsample`` holding, for each sample, the
        index of the closest center.
    """
    points = np.asarray(data, dtype=float)
    centerArray = np.asarray(centers, dtype=float)
    if len(points) == 0:
        # Nothing to assign; avoids indexing a 1-D empty array below.
        return np.empty(0, dtype=np.intp)
    # Broadcast (nsample, 1, nfeature) against (1, ncenter, nfeature) so all
    # sample/center differences are computed in one vectorized step.
    diff = points[:, np.newaxis, :] - centerArray[np.newaxis, :, :]
    distance = np.sum(diff ** 2, axis=2)
    return np.argmin(distance, axis=1)
def updateCenters(data, clusterIndex, K):
    """Compute new cluster centers as the mean of each cluster's members.

    Args:
        data: Array-like with shape ``[nsample, nfeature]``.
        clusterIndex: Array-like of length ``nsample`` giving the cluster
            index assigned to each sample.
        K: The number of clusters.

    Returns:
        A list of ``K`` arrays, one center per cluster. A cluster with no
        members is re-seeded with a randomly chosen sample instead of the
        NaN vector that averaging an empty slice would produce.

    Raises:
        ValueError: If ``data`` contains no samples.
    """
    points = np.asarray(data)
    # A plain list compared with ``== i`` yields a single bool, not a mask,
    # so convert before building the membership mask.
    labels = np.asarray(clusterIndex)
    if len(points) == 0:
        raise ValueError("cannot update centers from empty data")

    newCenters = []
    for i in range(K):
        members = points[labels == i]
        if len(members) > 0:
            newCenters.append(np.mean(members, axis=0))
        else:
            # Empty cluster: restart it at a random sample (copied as float so
            # the returned center never aliases a row of ``data``).
            seed = np.random.randint(len(points))
            newCenters.append(points[seed].astype(float))
    return newCenters
def converge(oldCenters, newCenters):
    """Return True when two collections of centers hold the same points.

    The comparison ignores the order of the centers but respects
    multiplicity, so duplicated centers are not collapsed. Coordinates are
    compared exactly, except that NaN is considered equal to NaN so that a
    degenerate (NaN) center cannot prevent convergence from being detected.

    Args:
        oldCenters: Sequence of center points from the previous iteration.
        newCenters: Sequence of center points from the current iteration.

    Returns:
        True if both collections contain the same centers, False otherwise.
    """
    old = np.asarray(oldCenters, dtype=float)
    new = np.asarray(newCenters, dtype=float)
    if old.shape != new.shape:
        return False
    if old.ndim == 2 and old.size > 0:
        # Sort rows lexicographically (first column is the primary key) so
        # the element-wise comparison below does not depend on center order.
        old = old[np.lexsort(old.T[::-1])]
        new = new[np.lexsort(new.T[::-1])]
    same = (old == new) | (np.isnan(old) & np.isnan(new))
    return bool(np.all(same))
def k_means(data, K, max_iter=300):
    """Cluster ``data`` into ``K`` groups with Lloyd's K-means iteration.

    Initial centers are ``K`` distinct samples drawn at random. Each
    iteration assigns samples to the nearest center and recomputes the
    centers, until the centers stop changing or ``max_iter`` iterations have
    run.

    Args:
        data: Array-like with shape ``[nsample, nfeature]``.
        K: The number of clusters; must satisfy ``1 <= K <= nsample``.
        max_iter: Upper bound on the number of assign/update iterations, a
            safeguard against a loop that never reports convergence.

    Returns:
        A tuple ``(newCenters, clusterIndex)``: the list of ``K`` center
        arrays and the per-sample cluster indices they were computed from.

    Raises:
        ValueError: If ``K`` is outside ``[1, nsample]`` or ``max_iter < 1``.
    """
    data = np.asarray(data)
    N = data.shape[0]
    if K < 1 or K > N:
        raise ValueError(
            "K must be between 1 and the number of samples (%d), got %r" % (N, K)
        )
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1, got %r" % (max_iter,))

    # Initialize: pick K distinct samples as the starting centers (the
    # original drew a hard-coded 3 regardless of K).
    seedRows = np.random.choice(N, size=K, replace=False)
    centers = [data[i, :] for i in seedRows]

    for _ in range(max_iter):
        clusterIndex = assignClusters(centers, data)
        newCenters = updateCenters(data, clusterIndex, K)
        if converge(centers, newCenters):
            break
        centers = newCenters
    return newCenters, clusterIndex
def init_board_gauss(N, k):
    """Generate ``N`` random 2-D points grouped into ``k`` Gaussian blobs.

    Each blob gets a center drawn uniformly from ``[-1, 1)`` per coordinate
    and a standard deviation drawn uniformly from ``[0.03, 0.3)``. Points are
    sampled from the blob's normal distribution and rejected unless both
    coordinates lie strictly inside ``(-1, 1)``. Every blob contributes the
    smallest whole number of points that is at least ``N / k``; the
    concatenated result is then truncated to ``N`` rows, so the last blob may
    end up with fewer points than the others.

    Each blob center is printed to stdout as it is drawn.

    Args:
        N: Total number of points to return.
        k: Number of blobs.

    Returns:
        Array of shape ``(N, 2)`` (``(0, 2)`` when no points are generated).
    """
    n = float(N) / k
    X = []
    for _ in range(k):
        c = (np.random.uniform(-1, 1), np.random.uniform(-1, 1))
        # Parenthesized form prints the same text on Python 2 and Python 3.
        print(c)
        s = np.random.uniform(0.03, 0.3)
        x = []
        while len(x) < n:
            a = np.random.normal(c[0], s)
            b = np.random.normal(c[1], s)
            # Keep drawing until the point falls inside the (-1, 1) square.
            if abs(a) < 1 and abs(b) < 1:
                x.append([a, b])
        X.extend(x)
    # reshape keeps the result two-dimensional even when X is empty.
    return np.array(X, dtype=float).reshape(-1, 2)[:N]
# Demo: generate 1000 synthetic points in 3 blobs, cluster them into 3
# groups, and plot the outcome.
data = init_board_gauss(1000, 3)
import matplotlib.pyplot as plt
# Uncomment to preview the raw, unclustered points:
#plt.plot(data[:,0],data[:,1],"o")
newCenters, clusterIndex = k_means(data, 3)
# Parenthesized form prints the same text on Python 2 and Python 3.
print(newCenters)
# Fitted centers are drawn as red dots; samples are colored by cluster index.
plt.plot([center[0] for center in newCenters], [center[1] for center in newCenters], "ro")
plt.scatter(data[:, 0], data[:, 1], c=clusterIndex)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment