Skip to content

Instantly share code, notes, and snippets.

@zhenghaoz
Created November 19, 2017 06:40
Show Gist options
  • Save zhenghaoz/09508ebaf155b46d3269da991cbcc0cf to your computer and use it in GitHub Desktop.
Save zhenghaoz/09508ebaf155b46d3269da991cbcc0cf to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import ConvexHull
class KMeans:
    """Plain k-means clustering with optional per-iteration 2-D plots.

    After :meth:`fit` the instance exposes:

    * ``X`` -- the data passed to ``fit`` (as an ndarray).
    * ``mean_vec`` -- ``(k, n_features)`` float array of cluster centres.
    * ``y`` -- index of the nearest centre for every sample (float dtype,
      kept that way for backward compatibility).
    * ``clusters`` -- list of ``k`` arrays, the samples of each cluster.
    """

    def __init__(self):
        # Per-instance state. These used to be class attributes, which made
        # the list default shared by every KMeans instance.
        self.mean_vec = np.array([])
        self.X = np.array([])
        self.y = np.array([])
        self.clusters = []

    def fit(self, X, k, visual=False, epsilon=0.0001):
        """Cluster the rows of *X* into *k* groups (Lloyd's algorithm).

        Args:
            X: array-like of shape ``(n_samples, n_features)``.
            k: number of clusters, at least 1.
            visual: when True and the data has exactly two features, plot
                the clusters, their convex hulls and the centres after every
                iteration (blocks on ``plt.show()``).
            epsilon: a centre counts as converged when it moves by at most
                this Euclidean distance in one iteration; fitting stops once
                every centre has converged.

        Raises:
            ValueError: if *X* is empty or *k* < 1.
        """
        X = np.asarray(X)
        if len(X) == 0:
            raise ValueError("X must contain at least one sample")
        if k < 1:
            raise ValueError("k must be a positive integer")
        self.X = X
        self.mean_vec = self._init_means(X, k)
        iteration = 0
        while True:
            self._assign()
            if visual and X.ndim == 2 and X.shape[1] == 2:
                self._plot(iteration)
            moved = self._update_means(epsilon)
            iteration += 1
            if not moved:
                break

    @staticmethod
    def _init_means(X, k):
        """Pick k randomly chosen samples as the initial centres."""
        # Draw distinct row indices so no two centres start from the same
        # sample; a duplicated centre never wins an argmin tie and would stay
        # empty forever. Replacement is only allowed when k exceeds n.
        idx = np.random.choice(len(X), k, replace=k > len(X))
        means = X[idx]
        # With integer data every centroid update would be truncated and the
        # convergence test could fail forever, so centres are kept as floats.
        if not np.issubdtype(means.dtype, np.inexact):
            means = means.astype(float)
        return means

    def _assign(self):
        """Label every sample with its nearest centre and rebuild clusters."""
        # (n, 1, d) - (1, k, d) -> (n, k) matrix of Euclidean distances.
        diffs = self.X[:, np.newaxis, :] - self.mean_vec[np.newaxis, :, :]
        dists = np.linalg.norm(diffs, axis=2)
        # Labels are stored as floats to match what earlier versions produced.
        self.y = np.argmin(dists, axis=1).astype(float)
        self.clusters = [self.X[self.y == i] for i in range(len(self.mean_vec))]

    def _plot(self, iteration):
        """Draw the current clusters, their convex hulls and the centres."""
        for cluster in self.clusters:
            plt.plot(cluster[:, 0], cluster[:, 1], 'o')
            if len(cluster) > 2:
                try:
                    hull = ConvexHull(cluster)
                except RuntimeError:
                    # Qhull rejects degenerate (e.g. collinear) point sets;
                    # scipy's QhullError is a RuntimeError subclass.
                    continue
                # Repeat the first vertex so the outline is drawn closed.
                ring = np.append(hull.vertices, hull.vertices[0])
                plt.plot(cluster[ring, 0], cluster[ring, 1], 'r--', lw=2)
        plt.plot(self.mean_vec[:, 0], self.mean_vec[:, 1], 'r+')
        plt.xlabel('密度')  # "density"
        plt.ylabel('含糖率')  # "sugar content"
        plt.title('第%d次迭代之后' % iteration)  # "after iteration %d"
        plt.show()

    def _update_means(self, epsilon):
        """Move centres to their cluster means; return True if any moved > epsilon."""
        moved = False
        for i, cluster in enumerate(self.clusters):
            if len(cluster) == 0:
                # The mean of an empty cluster is undefined; keep the old centre.
                continue
            next_vec = cluster.mean(axis=0)
            if np.linalg.norm(next_vec - self.mean_vec[i]) > epsilon:
                self.mean_vec[i] = next_vec
                moved = True
        return moved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment