# k-means & k-means++
# Source: GitHub gist kzinglzy/6c1f2135095a682bc1f3 (created October 28, 2014).
# NOTE: the original gist page warned that this file contains bidirectional
# Unicode text that may be interpreted differently than it appears; review it
# in an editor that reveals hidden Unicode characters.
#!/usr/bin/env python | |
# coding:utf-8 | |
from collections import defaultdict | |
from random import uniform | |
from math import sqrt | |
def point_avg(points):
    """Return the centroid (per-dimension arithmetic mean) of ``points``.

    Works for points of any dimensionality; every point is expected to have
    the same number of coordinates.

    :param points: non-empty sized sequence of coordinate sequences.
    :returns: list of floats, one mean per dimension.
    :raises ValueError: if ``points`` is empty.
    """
    if not points:
        raise ValueError("point_avg() requires at least one point")
    # zip() is lazy on Python 3, so materialize the per-dimension columns
    # before using them.
    columns = list(zip(*points))
    count = float(len(points))
    return [sum(column) / count for column in columns]
def update_centers(data_set, assignments):
    """Recompute cluster centers from the current point-to-cluster labels.

    Points are grouped by the label at the same position in ``assignments``,
    and each group is averaged with ``point_avg``. One center is produced per
    label that actually occurs, in the iteration order of the grouping dict
    (first-appearance order on Python 3.7+). Labels with no points produce
    no center.
    """
    clusters = defaultdict(list)
    for label, member in zip(assignments, data_set):
        clusters[label].append(member)
    return [point_avg(members) for members in clusters.values()]
def distance(a, b):
    """Return the Euclidean distance between points ``a`` and ``b``.

    The number of dimensions is taken from ``a``; ``b`` must provide at least
    that many coordinates.
    """
    squared_total = sum((a[axis] - b[axis]) ** 2 for axis in range(len(a)))
    return sqrt(squared_total)
def assign_points(data_points, centers):
    """Assign every data point to its nearest center (Euclidean distance).

    Returns a list of center *indices*, one per data point. For example,
    ``[0, 0, 1, 2, 2]`` means the 1st and 2nd points are closest to center 0,
    the 3rd point to center 1, and the 4th and 5th points to center 2.

    Ties go to the lowest-index center. If ``centers`` is empty, every point
    is given index 0.
    """
    assignments = []
    for point in data_points:
        # Start from +infinity so the first center always wins the comparison.
        # The previous sentinel ``()`` only worked because Python 2 orders
        # numbers before tuples; Python 3 raises TypeError on ``float < ()``.
        shortest = float("inf")
        shortest_index = 0
        for index, center in enumerate(centers):
            val = distance(point, center)
            if val < shortest:
                shortest = val
                shortest_index = index
        assignments.append(shortest_index)
    return assignments
def generate_k(data_set, k):
    """Generate ``k`` random seed centers inside the data's bounding box.

    Each coordinate of each center is drawn with ``uniform`` between the
    minimum and maximum of that dimension over ``data_set``.

    :param data_set: non-empty sequence of equal-length points.
    :param k: number of centers to generate.
    :returns: list of ``k`` points (lists of floats).
    :raises IndexError: if ``data_set`` is empty.
    """
    dimensions = len(data_set[0])
    # zip() is lazy on Python 3; materialize the columns so they can be indexed.
    columns = list(zip(*data_set))
    # The bounding box does not change between centers, so compute it once
    # instead of rescanning the data for every one of the k centers.
    bounds = [(min(columns[i]), max(columns[i])) for i in range(dimensions)]
    centers = []
    for _ in range(k):
        centers.append([uniform(low, high) for low, high in bounds])
    return centers
def k_means(dataset, k):
    """Cluster ``dataset`` into at most ``k`` groups with the k-means algorithm.

    Steps:
    1. Generate ``k`` random seed centers (``generate_k``).
    2. Assign each point to its nearest center by Euclidean distance
       (``assign_points``).
    3. Move each center to the mean of its assigned points
       (``update_centers``) and repeat step 2 until the assignments stop
       changing.

    Seeds that attract no points produce no center in ``update_centers``, so
    fewer than ``k`` clusters may appear in the result.

    :returns: list of ``(cluster_index, point)`` pairs in ``dataset`` order.
    """
    k_points = generate_k(dataset, k)
    assignments = assign_points(dataset, k_points)
    old_assignments = None
    while assignments != old_assignments:
        new_centers = update_centers(dataset, assignments)
        old_assignments = assignments
        assignments = assign_points(dataset, new_centers)
    # Materialize the pairs: on Python 3 zip() is a one-shot iterator that
    # prints as "<zip object ...>" and is exhausted after a single pass.
    return list(zip(assignments, dataset))
def main():
    """Demo: cluster a small hard-coded 2-D point set into 3 groups and print it."""
    points = [
        [1, 2],
        [2, 1],
        [3, 1],
        [5, 4],
        [5, 5],
        [6, 5],
        [10, 8],
        [7, 9],
        [11, 5],
        [14, 9],
        [14, 14],
    ]
    # list() so the (cluster_index, point) pairs print readably even if
    # k_means hands back a lazy zip object (as it does on Python 3).
    print(list(k_means(points, 3)))
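# NOTE: unfinished k-means++ draft, kept commented out for reference.
# Known gaps before it can be enabled:
#   * The sampling loop never selects a data point. k-means++ picks each next
#     seed from the data with probability proportional to D(x)**2, the squared
#     distance to the nearest existing seed. One way is to walk the points,
#     subtracting each point's D(x)**2 from random_value, and choose the point
#     at which it drops to <= 0. The draft sums D(x), not D(x)**2, and appends
#     random_value itself instead of a point.
#   * Standard k-means++ draws the first seed uniformly from the data points,
#     not from the bounding box as init_feed() does.
#   * `shortest = ()` relies on Python 2 ordering numbers before tuples; on
#     Python 3 use float("inf").
# def k_means_improve(dataset, k):
#     """ Improved k-means algorithm (k-means++ seeding).
#     """
#     def init_feed():
#         """Generate one random seed point.
#         """
#         data = zip(*dataset)
#         feed = []
#         for i in range(dimensions):
#             min_val = min(data[i])
#             max_val = max(data[i])
#             feed.append(uniform(min_val, max_val))
#         return feed
#     def shortest_distance(point, feeds):
#         """ Return the distance from `point` to its nearest seed.
#         """
#         shortest = ()
#         for feed in feeds:
#             _distance = distance(point, feed)
#             shortest = min(_distance, shortest)
#         return shortest
#     dimensions = len(dataset[0])
#     feeds = [init_feed()]
#     while len(feeds) < k:
#         sum_D = 0
#         for point in dataset:
#             sum_D += shortest_distance(point, feeds)
#         random_value = uniform(0, sum_D)  # random value in the range [0, sum_D]
#         while random_value > 0:
#             # random_value -= sum_D
#             # TODO: unfinished - pick the point whose D(x)**2 crosses zero
#             pass
#         feeds.append(random_value)
#     pass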
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment