Created
October 28, 2016 18:06
-
-
Save arccoder/e9328b3a9f9fa49d79cd56413aa1a1d8 to your computer and use it in GitHub Desktop.
K-means Algorithm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def kmeans(data, numofclasses, options=None):
    """
    Calculates the clusters using the k-means algorithm and returns cluster labels and centroids.

    :param data: Data to cluster, structured as [Number of observations x Data dimensions (variables)].
                 Any number of dimensions is supported; a 1-D array is treated as a single variable.
                 Integer input is converted to float so centroid means are not truncated.
    :param numofclasses: Number of classes to cluster the data into; must be between 1 and the
                         number of observations.
    :param options: Optional dictionary overriding the defaults:
                    max_iterations - int: Maximum number of iterations (default 100). The older
                                     key "max_iteration" is accepted as well.
                    all_stages - bool: If True, return labels and centroids for every iteration
                                 (default False).
    :return: {"numofiterations": number of iterations used to define the cluster labels,
              "centroids": centroids of the clusters for the final or all iterations,
              "labels": cluster labels for the final or all iterations}
    :raises ValueError: if data is not 1-D or 2-D, or numofclasses is out of range.
    """
    points = np.asarray(data, dtype=float)
    if points.ndim == 1:
        points = points.reshape(-1, 1)
    if points.ndim != 2:
        raise ValueError("data must be shaped [observations x dimensions]")
    num_points = points.shape[0]
    if not 1 <= numofclasses <= num_points:
        raise ValueError("numofclasses must be between 1 and the number of observations")

    max_iteration, all_stages = _kmeans_options(options)

    # Seed with distinct observations: sampling with replacement could pick the same point
    # twice, and the duplicate centroid's cluster would then stay empty for the whole run.
    centroids = points[np.random.choice(num_points, size=numofclasses, replace=False)]
    # NaN never compares close, so the first iteration always runs, even when the seeds sit
    # at the origin (a zero-filled reference would end the loop before labelling anything).
    previous = np.full_like(centroids, np.nan)
    labels = []
    all_centroids = []
    all_labels = []
    iteration = 0
    while iteration < max_iteration and not np.allclose(previous, centroids):
        previous = centroids
        labels = _assign_labels(points, centroids)
        # _update_centroids returns a fresh array, so `previous` and stored stages are not mutated.
        centroids = _update_centroids(points, labels, previous)
        if all_stages:
            all_labels.append(labels)
            all_centroids.append(centroids)
        iteration += 1

    if all_stages:
        return {"numofiterations": iteration, "centroids": all_centroids, "labels": all_labels}
    return {"numofiterations": iteration, "centroids": centroids, "labels": labels}


def _kmeans_options(options):
    """Return (max_iteration, all_stages) from the caller's options dict, applying defaults."""
    max_iteration = 100
    all_stages = False
    if options is not None:
        # "max_iterations" is the documented key; "max_iteration" is still honoured for
        # callers written against the original implementation.
        if "max_iteration" in options:
            max_iteration = options["max_iteration"]
        if "max_iterations" in options:
            max_iteration = options["max_iterations"]
        all_stages = options.get("all_stages", all_stages)
    return max_iteration, all_stages


def _assign_labels(points, centroids):
    """Return, for each point, the index of its nearest centroid (Euclidean distance)."""
    # Broadcast to [observations x classes x dimensions]. sqrt is monotonic, so the argmin
    # over squared distances picks the same nearest centroid without the extra work.
    diff = points[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    return np.argmin(np.square(diff).sum(axis=2), axis=1)


def _update_centroids(points, labels, centroids):
    """Return new centroids: the mean of each cluster's members; empty clusters keep their old centroid."""
    updated = centroids.copy()
    for k in range(updated.shape[0]):
        members = points[labels == k]
        if len(members) > 0:
            updated[k] = members.mean(axis=0)
    return updated
#
# Data generation and running k-means
#
def _run_demo():
    """
    Generate Gaussian blobs, cluster them with kmeans() and save the stages as PNG images.

    Writes data.png (the raw data in black) and <iteration>.png (points colored by label,
    centroids as circles) into the current working directory. Before that, it removes
    the images a previous run of this demo produced.
    """
    import glob
    import os

    import matplotlib.pyplot as plt

    # Delete only the images this demo writes (data.png, 0.png, 1.png, ...), not every PNG
    # that happens to live in the working directory.
    for path in glob.glob("*.png"):
        stem = os.path.splitext(os.path.basename(path))[0]
        if stem == "data" or stem.isdigit():
            os.remove(path)

    # Data generation: 100 samples around each of these cluster means
    means = [(-25, -25), (25, 25), (-25, 25), (25, -25), (0, 0)]
    cov = [[50, 0], [0, 50]]
    colors = ['r', 'g', 'b', 'brown', 'm']
    numofclasses = len(means)
    data = np.vstack([np.random.multivariate_normal(m, cov, 100) for m in means])

    output = kmeans(data, numofclasses, {"all_stages": True})

    # Display original data
    plt.plot(data[:, 0], data[:, 1], 'x', color='black')
    plt.axis('equal')
    plt.savefig('data.png')
    plt.clf()

    # Display the output stages
    for i in range(output["numofiterations"]):
        plt.clf()
        centers = output["centroids"][i]
        labels = output["labels"][i]
        for n in range(numofclasses):
            # Cycle the palette so extra classes never raise IndexError.
            color = colors[n % len(colors)]
            members = data[labels == n]
            plt.plot(members[:, 0], members[:, 1], 'x', color=color)
            plt.plot(centers[n, 0], centers[n, 1], 'o', color=color)
        plt.axis('equal')
        plt.savefig(str(i) + ".png")
    plt.close()


# Run the demo only when executed as a script, so importing kmeans() never deletes files.
if __name__ == "__main__":
    _run_demo()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment