Created November 24, 2012 13:30
k-means
#!/usr/bin/env python
#coding:utf-8
import os
import random
import math

class KMeans(object):
    def __init__(self):
        # file name -> the words it contains: {'file_name': {word1: 1, word2: 1, ...}}
        self.file_words = {}
        # cluster centers: [{'point': {...}, 'files': [...]}, ...]
        self.centers = []
        # cost of the previous iteration
        self.last_cost = 0.0
    # load the corpus from a directory of text files
    def loadData(self, dir_name):
        i = 0
        for file_name in os.listdir(dir_name):
            # cap the sample size; a full run can be slow, so test with fewer files first
            i += 1
            if i > 3000:
                break
            self.file_words[file_name] = {}
            with open(dir_name + '/' + file_name, 'r') as f:
                content = f.read().strip()
                words = content.split()
                for word in words:
                    self.file_words[file_name][word] = 1.0
            # normalize the vector to unit length
            self.normalize(self.file_words[file_name])
        print len(self.file_words), 'files loaded'
    # normalize a vector to Euclidean length 1, so that document length
    # has no influence on the topic distance
    def normalize(self, vector):
        length = 0.0
        for freq in vector.values():
            length += freq * freq
        length = math.sqrt(length)
        if length == 0.0:
            return  # guard against empty documents
        for word, freq in vector.items():
            vector[word] = freq / length
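    # Example: the vector {'a': 1.0, 'b': 1.0} has length sqrt(2), so it
    # normalizes to {'a': 0.7071..., 'b': 0.7071...}; every document then
    # lies on the unit sphere and distance reflects direction (topic) only.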
    # pick the initial cluster centers
    def initCenters(self, class_num):
        # fully random choice
        random_centers = random.sample(self.file_words.values(), class_num)
        for c in random_centers:
            self.centers.append({'point': c})
        print 'random initial centers:', len(self.centers)
        # alternative: seed with one file from each known category
        # business = sports = it = yule = auto = 0
        # for file_name, words in self.file_words.items():
        #     if file_name.find('business') > -1 and business == 0:
        #         self.centers.append({'point': words})
        #         business += 1
        #     if file_name.find('auto') > -1 and auto == 0:
        #         self.centers.append({'point': words})
        #         auto += 1
        #     if file_name.find('sports') > -1 and sports == 0:
        #         self.centers.append({'point': words})
        #         sports += 1
        #     if file_name.find('it') > -1 and it == 0:
        #         self.centers.append({'point': words})
        #         it += 1
        #     if file_name.find('yule') > -1 and yule == 0:
        #         self.centers.append({'point': words})
        #         yule += 1
    def start(self):
        # optimize by repeating three steps: partition the files, relocate
        # each center, and check how much the cost decreased
        self.last_cost = self.split()
        self.locateCenter()
        i = 0
        while True:
            i += 1
            print 'iteration %d:' % i
            current_cost = self.split()
            self.locateCenter()
            print 'cost (previous - current): %s - %s = %s' % (
                self.last_cost, current_cost, self.last_cost - current_cost)
            if self.last_cost - current_cost <= 1:
                break
            self.last_cost = current_cost
    def output(self):
        for i in range(len(self.centers)):
            print 'group %d:' % (i + 1)
            for s in ['business', 'it', 'sports', 'yule', 'auto']:
                s_count = 0
                for file_name in self.centers[i]['files']:
                    # count files whose name contains this category label
                    if file_name.find(s) != -1:
                        s_count += 1
                print s, '=', s_count
            print '---------------------------------------'
""" 重新获得划分对象后的中心点 """ | |
def locateCenter(self): | |
# 遍历中心点,使用每个中心点所包含的文件重新求中心点 | |
for i in range(len(self.centers)): | |
print '计算第 ',i+1,' 点的新中心点...' | |
files_count = float(len(self.centers[i]['files'])) | |
# 新的中心点,格式为 {word1:0,word2:5...} | |
point = {} | |
# 遍历所有该中心包含的文件 | |
for file_name in self.centers[i]['files']: | |
# 遍历该文件包含的单词 | |
for word,freq in self.file_words[file_name].items(): | |
if word not in point: | |
point[word] = freq | |
else: | |
point[word] += freq | |
# point[word] += self.file_word_freq[file_name][word] | |
for word,freq in point.items(): | |
point[word] = freq/files_count | |
self.centers[i]['point'] = point | |
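    # Example: averaging the unit vectors {'a': 1.0} and {'a': 0.6, 'b': 0.8}
    # gives the centroid {'a': 0.8, 'b': 0.4}; note that the mean of unit
    # vectors is generally no longer unit length itself.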
    def split(self):
        print 'assigning files and computing cost...'
        # clear the previous assignment; every file is re-assigned below
        for i in range(len(self.centers)):
            self.centers[i]['files'] = []
        # the cost can be accumulated during assignment, saving a separate pass
        total_cost = 0.0
        for file_name in self.file_words.keys():
            # find the center closest to the current file
            min_distance = -1
            min_i = 0
            for i in range(len(self.centers)):
                p = self.file_words[file_name]   # {word1: 1, word2: 23, ...}
                c = self.centers[i]['point']     # {word1: 1, word2: 23, ...}
                current_distance = self.distance(c, p)
                if min_distance == -1 or current_distance < min_distance:
                    min_distance = current_distance
                    min_i = i
            # assign the file to its nearest center; its squared distance is its cost
            total_cost += math.pow(min_distance, 2)
            self.centers[min_i]['files'].append(file_name)
        return total_cost
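    # Example: with centers c0 = {'a': 1.0} and c1 = {'b': 1.0}, the file
    # {'a': 1.0} is at distance 0.0 from c0 and sqrt(2) from c1, so it is
    # assigned to c0 and contributes 0.0**2 to the total cost.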
    def distance(self, center, point):
        # Euclidean distance between two sparse vectors: a word missing from
        # one side contributes its full squared weight from the other side
        square_sum = 0.0
        for word in center:
            if word in point:
                diff = center[word] - point[word]
                square_sum += diff * diff
            else:
                square_sum += center[word] * center[word]
        for word in point:
            if word not in center:
                square_sum += point[word] * point[word]
        return math.sqrt(square_sum)
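    # Example: distance({'a': 0.6, 'b': 0.8}, {'b': 1.0, 'c': 1.0})
    # = sqrt(0.6**2 + (0.8 - 1.0)**2 + 1.0**2) = sqrt(1.4) ≈ 1.1832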
if __name__ == '__main__':
    km = KMeans()
    km.loadData('allfiles')
    km.initCenters(5)
    km.start()
    km.output()
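For a quick smoke test without the news corpus, the class above can be driven with a few synthetic documents. This is a minimal sketch under stated assumptions: the directory name demo_docs and the toy file contents are invented for illustration, and it assumes the file above is saved as kmeans.py next to the driver.

# smoke_test.py -- hypothetical driver; 'demo_docs', the toy texts, and the
# module name kmeans are all assumptions, not part of the original gist
import os
from kmeans import KMeans

docs = {
    'business_1.txt': 'stock market bank profit',
    'business_2.txt': 'bank profit trade market',
    'sports_1.txt': 'goal match team score',
    'sports_2.txt': 'team score league goal',
}
if not os.path.exists('demo_docs'):
    os.mkdir('demo_docs')
for name, text in docs.items():
    with open('demo_docs/' + name, 'w') as f:
        f.write(text)

km = KMeans()
km.loadData('demo_docs')
km.initCenters(2)   # two clusters for the two toy topics
km.start()
km.output()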
#!/usr/bin/env python
#coding:utf-8
import os
import random
import math

class KMeans(object):
    """ prepare the data: vectorize the news documents """
    def __init__(self, dir_name):
        self.dir_name = dir_name
        # {'file_name': {word: 3, word: 4}}
        self.file_word_freq = {}
        # {0: {'file_name': ..., 'point': {word1: 1, word2: 2}, 'files': [...]}, 1: ...}
        # each center stores three things: the name of the file it was seeded
        # from, that point's vector, and the files currently assigned to it
        self.centers = {}
        # cost of the previous iteration
        self.last_cost = 0.0
        self.unique_words = {}
        # load every file from the given directory
        for file_name in os.listdir(dir_name):
            with open(dir_name + '/' + file_name, 'r') as f:
                self.file_word_freq[file_name] = {}
                text = f.read()
                # record every word of the article in the hash table
                for word in text.split():
                    # track the vocabulary, i.e. how many dimensions the data has
                    self.unique_words[word] = 1
                    # store the file's vector
                    if word not in self.file_word_freq[file_name]:
                        self.file_word_freq[file_name][word] = 1.0
                    # since we cluster by topic, article length should not
                    # influence the distance, so only presence/absence is
                    # recorded, not the raw frequency:
                    # else:
                    #     self.file_word_freq[file_name][word] += 1.0
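    # Example: the text 'bank bank profit' becomes {'bank': 1.0, 'profit': 1.0};
    # the repeated 'bank' is recorded once, so only presence matters.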
""" 分配初始中心点,计算初始损失,并开始聚类 """ | |
def start(self,class_num): | |
# 从现有的所有文章中,随机选出class_num个点作为初始聚类中心 | |
random_files = random.sample(os.listdir(self.dir_name),class_num) | |
for i in range(class_num): | |
self.centers[i] = {'file_name':random_files[i],'point':self.file_word_freq[random_files[i]]} | |
# 初始划分,并计算初始损失 | |
print 'init center points' | |
self.split() | |
self.locateCenter() | |
self.last_cost = self.costFunction() | |
print 'start optimization' | |
# 开始进行优化,不断的进行三步操作:划分、重新定位中心点、最小化损失 | |
i = 0 | |
while True: | |
i += 1 | |
print '第 ',i,' 次优化:' | |
self.split() | |
self.locateCenter() | |
current_cost = self.costFunction() | |
print '损失(上一次 - 当前):',self.last_cost,' - ',current_cost,' = ',(self.last_cost - current_cost) | |
if math.fabs(self.last_cost - current_cost) <= 1: | |
break | |
else: | |
self.last_cost = current_cost | |
# 迭代优化损失函数,当损失函数与上一次损失相差非常小的时候,停止优化 | |
for center_key in self.centers.keys(): | |
print '第',int(center_key)+1,'组:' | |
for s in ['business','it','sports','yule','auto']: | |
s_count = 0 | |
for file_name in self.centers[center_key]['files']: | |
if file_name.find(s) > 0: | |
s_count += 1 | |
print s,' = ',s_count | |
print '---------------------------------------' | |
""" | |
根据每个聚类的中心点,计算每个对象与这些中心的距离,根据最小距离重新划分每个对象所属的分类 | |
""" | |
def split(self): | |
print '划分对象...' | |
# 先清空上一次中心点表里的对象数据,需要重新划分 | |
for center_key in self.centers.keys(): | |
self.centers[center_key]['files'] = [] | |
# 遍历所有文件 | |
for file_name in self.file_word_freq.keys(): | |
# 对当前文件遍历所有中心点,如果该文件离某个中心点最近,则把该文件加入到该中心点的files中 | |
min_distance = -1 | |
center_key = None | |
for center in self.centers.keys(): | |
p = self.file_word_freq[file_name] # {word1:1,word2:23,word3:0 ...} | |
c = self.centers[center]['point'] # {word1:1,word2:23,word3:0 ...} | |
current_distance = self.distance(c,p) | |
if min_distance == -1 or current_distance < min_distance: | |
min_distance = current_distance | |
center_key = center | |
# 把当前文件放到距离最近的中心点中 | |
self.centers[center_key]['files'].append(file_name) | |
""" 重新获得划分对象后的中心点 """ | |
def locateCenter(self): | |
# 遍历中心点,使用每个中心点所包含的文件重新求中心点 | |
for center_key in self.centers.keys(): | |
print '计算第 ',int(center_key)+1,' 点的新中心点...' | |
files_count = float(len(self.centers[center_key]['files'])) | |
# 新的中心点,格式为 {word1:0,word2:5...} | |
point = {} | |
# 遍历所有该中心包含的文件 | |
for file_name in self.centers[center_key]['files']: | |
# 遍历该文件包含的单词 | |
for word in self.file_word_freq[file_name].keys(): | |
if word not in point: | |
point[word] = self.file_word_freq[file_name][word] | |
else: | |
# 由于不使用词频计算,所以出现的词都是加1 | |
point[word] += 1 | |
# point[word] += self.file_word_freq[file_name][word] | |
for word in point.keys(): | |
point[word] = point[word]/files_count | |
self.centers[center_key]['point'] = point | |
""" 损失函数 """ | |
def costFunction(self): | |
print '开始计算损失函数' | |
total_cost = 0.0 | |
for center_key in self.centers.keys(): | |
print '计算第',int(center_key)+1,'点的损失' | |
c = self.centers[center_key]['point'] | |
for file_name in self.centers[center_key]['files']: | |
p = self.file_word_freq[file_name] | |
# 求距离平方作为损失 | |
total_cost += self.distance(c,p) | |
print '本轮损失为:',total_cost | |
return total_cost | |
""" | |
计算两个新闻向量之间的欧几里得距离,注意数据维度上的值非常稀疏 | |
""" | |
def distance(self,center,point): | |
# return random.random() | |
square_sum = 0.0 | |
for word in center: | |
if word not in point: | |
a,b = center[word],0 | |
square_sum += a*a | |
if word in point: | |
a,b = center[word],point[word] | |
square_sum += (a-b)*(a-b) | |
for word in point: | |
if word not in center: | |
a,b = 0,point[word] | |
square_sum += b*b | |
result = math.sqrt(square_sum) | |
return result | |
if __name__ == '__main__':
    km = KMeans('allfiles')
    km.start(5)
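The cost bookkeeping above is easy to check by hand. The sketch below re-implements the same sparse Euclidean distance with dict.get and evaluates the total cost for two hypothetical clusters; the toy vectors and the names centers and members are invented for illustration.

import math

def distance(center, point):
    # equivalent to the distance() above, written with dict.get over the
    # union of both vectors' words
    words = set(center) | set(point)
    return math.sqrt(sum((center.get(w, 0.0) - point.get(w, 0.0)) ** 2 for w in words))

# two toy clusters: each center with the member vectors assigned to it
centers = {
    0: {'point': {'stock': 1.0}, 'members': [{'stock': 1.0, 'bank': 1.0}]},
    1: {'point': {'goal': 1.0}, 'members': [{'goal': 1.0}]},
}
total = sum(distance(c['point'], m) for c in centers.values() for m in c['members'])
print(total)  # 1.0: cluster 0 differs only in 'bank', cluster 1 matches exactly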
2454 files loaded
random initial centers: 5
assigning files and computing cost...
recomputing center 1...
recomputing center 2...
recomputing center 3...
recomputing center 4...
recomputing center 5...
iteration 1:
assigning files and computing cost...
recomputing center 1...
recomputing center 2...
recomputing center 3...
recomputing center 4...
recomputing center 5...
cost (previous - current): 4375.51247663 - 2034.9234715 = 2340.58900512
iteration 2:
assigning files and computing cost...
recomputing center 1...
recomputing center 2...
recomputing center 3...
recomputing center 4...
recomputing center 5...
cost (previous - current): 2034.9234715 - 2033.91112679 = 1.01234470898
iteration 3:
assigning files and computing cost...
recomputing center 1...
recomputing center 2...
recomputing center 3...
recomputing center 4...
recomputing center 5...
cost (previous - current): 2033.91112679 - 2039.61744347 = -5.70631667332
group 1:
business = 314
it = 68
sports = 15
yule = 10
auto = 79
---------------------------------------
group 2:
business = 10
it = 16
sports = 26
yule = 132
auto = 10
---------------------------------------
group 3:
business = 4
it = 15
sports = 4
yule = 4
auto = 16
---------------------------------------
group 4:
business = 158
it = 226
sports = 388
yule = 281
auto = 338
---------------------------------------
group 5:
business = 14
it = 183
sports = 40
yule = 33
auto = 70
---------------------------------------