Created
December 1, 2021 04:02
-
-
Save billju/fdcc9d9bdf66c9f6fbd271b2b90e8c9d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Minimal 1-D k-means demonstration.
#
# Runs Lloyd-style iterations on a small integer array and records, for every
# iteration, the current centres and the clusters assigned to them.  The
# recorded history is finally turned into a pandas DataFrame.
import numpy as np
import pandas as pd

k = 2  # number of clusters (K)
max_iter = 10  # maximum number of iterations
data = np.array([1, 2, 3, 4, 11, 12])
# Initialise the K centres by drawing K distinct entries from the data.
# No seed is set, so the starting centres differ from run to run.
last = np.random.choice(data, k, replace=False)


def func(x, y):
    """Return the element-wise absolute difference between *x* and *y*.

    This is the distance measure used both for assigning points to centres
    and for detecting that the centres have stopped moving.
    """
    return np.abs(x - y)


# One row per iteration: the k centres followed by the k cluster arrays.
steps = []
for _ in range(max_iter):
    # For every data point, find the index of its nearest centre
    # (ties go to the lowest index, as argmin returns the first minimum).
    argmin = np.array([func(last, row).argmin() for row in data])
    # Re-assign the points to clusters according to that index.
    clusters = [data[argmin == i] for i in range(k)]
    steps.append([*last, *clusters])
    # Update each centre to the mean of its cluster.  An empty cluster keeps
    # its previous centre: the mean of an empty array is NaN, which would
    # otherwise make the stop test below impossible to satisfy.
    new = np.array([
        c.mean() if c.size else centre
        for c, centre in zip(clusters, last)
    ])
    # Stop condition: no centre moved at all.
    if func(new, last).sum() == 0:
        break
    last = new

center_cols = [f'群心{i}' for i in range(k)]
cluster_cols = [f'群集{i}' for i in range(k)]
# Keep the table in a name so it is not thrown away when run as a script.
history = pd.DataFrame(steps, columns=[*center_cols, *cluster_cols])
# Bare expression kept so a notebook still displays the table as cell output.
history
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment