Skip to content

Instantly share code, notes, and snippets.

@taojy123
Last active March 26, 2018 06:58
Show Gist options
  • Save taojy123/a82e865a9d90815b02420027ca28abde to your computer and use it in GitHub Desktop.
Save taojy123/a82e865a9d90815b02420027ca28abde to your computer and use it in GitHub Desktop.
K-modes
#!python2
#coding=utf8
import random
class Man(object):
    """A named record whose attributes are held in the list ``values``."""

    def __init__(self, name, values):
        self.name = name
        self.values = values

    def __repr__(self):
        # Only the name is shown, so a list of Man objects prints as [A, B, C].
        return self.name

    def get_d(self, r):
        """Return how many of the first ``len(r)`` attributes differ from ``r``.

        This is the Hamming (simple-matching) distance used by k-modes.
        Position ``idx`` of ``r`` is compared with ``self.values[idx]``.
        """
        return sum(1 for idx, wanted in enumerate(r) if wanted != self.values[idx])
def split_group(data, r0, r1):
    """Partition ``data`` into two groups by nearest centroid.

    An item joins the first group only when ``item.get_d(r0)`` is strictly
    smaller than ``item.get_d(r1)``; ties go to the second group.
    Returns ``[first_group, second_group]``.
    """
    closer_to_r0, others = [], []
    for man in data:
        target = closer_to_r0 if man.get_d(r0) < man.get_d(r1) else others
        target.append(man)
    return [closer_to_r0, others]
def get_center(group):
    """Return the per-attribute mode of the Man objects in ``group``.

    The number of attributes is taken from ``group[0].values``. When several
    values are equally frequent, the one that comes last in the count dict's
    iteration order wins.
    """
    assert len(group) > 0
    center = []
    for col in range(len(group[0].values)):
        counts = {}
        for man in group:
            key = man.values[col]
            counts[key] = counts.get(key, 0) + 1
        # '>=' (not '>') lets later keys with an equal count replace the
        # current best, so the last most-frequent value is kept.
        best, best_count = None, 0
        for value, count in counts.items():
            if count >= best_count:
                best, best_count = value, count
        center.append(best)
    return center
def get_score(group, r):
    """Return the summed distance from every member of ``group`` to ``r``.

    Each member's distance is ``member.get_d(r)``; an empty group scores 0.
    """
    return sum(man.get_d(r) for man in group)
def find_once(data, r0, r1):
flag = 0
last_score = 999
for i in range(10):
print u'----------第', i+1, u'次迭代------------'
print u'两点', r0, r1
group0, group1 = split_group(data, r0, r1)
score0 = get_score(group0, r0)
score1 = get_score(group1, r1)
print u'第一组', group0, u'总距离', score0
print u'第二组', group1, u'总距离', score1
if not group0 or not group1:
print u'含有空组不可用!'
break
score = score0 + score1
if score >= last_score:
flag += 1
if flag >= 3:
print '总距离趋近最小值 结束迭代!'
last_score = score
break
last_score = score
r0 = get_center(group0)
r1 = get_center(group1)
print u'找出新质心', r0, r1
return last_score, r0, r1
def find(data):
result = []
for i in range(10):
print u'\n===================== 第', i+1, u'次求解==========================='
r0 = random.choice(data).values
r1 = random.choice(data).values
print u'生成随机两点', r0, r1, u'开始'
score, r0, r1 = find_once(data, r0, r1)
result.append([score, r0, r1, i])
print '=================================================='
print result
print '=================================================='
result.sort()
return result[0]
DATA = [
Man('A', ['F', 'JG', 'MZ']),
Man('B', ['F', 'CX', 'HH']),
Man('C', ['M', 'CX', 'HH']),
Man('D', ['F', 'CX', 'MZ']),
Man('E', ['F', 'JG', 'HH']),
Man('F', ['M', 'CX', 'YX']),
Man('G', ['M', 'JG', 'YX']),
Man('H', ['M', 'JG', 'HH']),
]
print u"""
初始数据:
姓名 性别 职业 爱好
A F JG MZ
B F CX HH
C M CX HH
D F CX MZ
E F JG HH
F M CX YX
G M JG YX
H M JG HH
F => 女
M => 男
JG => 结构
CX => 程序
MZ => 美妆
HH => 绘画
YX => 游戏
"""
score, r0, r1, i = find(DATA)
print u'最佳结果(最终距离总和最小) 为第', i+1, u'次求解的结果'
print u'总距离', score
print u'两质心', r0, r1
group0, group1 = split_group(DATA, r0, r1)
print u'两组', group0, group1
#!python3
import random
def display_group(group, name_col=None):
    """Render one group of rows as a string for progress output.

    :param group: list of rows (indexable sequences).
    :param name_col: index of the column holding each row's display name.
        When None, the raw ``str(group)`` is returned. Any other index,
        including 0 or negative indices, selects that column, and the
        result looks like ``[A, B, C]``.
    """
    if name_col is None:
        return str(group)
    # str() so that non-string name columns (e.g. integer ids) also render.
    names = [str(row[name_col]) for row in group]
    return '[' + ', '.join(names) + ']'


def display_groups(groups, name_col=None):
    """Render every group with display_group(), one group per line."""
    return '\n'.join(display_group(group, name_col) for group in groups)
def split_group(data, rs):
    """Assign every row of ``data`` to the group of its nearest center.

    Returns one list per center in ``rs``, in the same order. The distance
    is ``get_d(row, center)``; on a tie the center with the lowest index
    wins.
    """
    groups = [[] for _ in range(len(rs))]
    for row in data:
        scored = [(get_d(row, center), idx) for idx, center in enumerate(rs)]
        # Sorting (distance, index) pairs puts the nearest, lowest-index center first.
        nearest = sorted(scored)[0][1]
        groups[nearest].append(row)
    return groups
def get_center(group, col_num=None):
    """Return the center (prototype) of ``group`` over its leading columns.

    :param group: non-empty list of rows.
    :param col_num: number of leading columns to use; defaults to the
        length of the first row.
    :returns: a list with one entry per column.
        - int/float columns (tested with isinstance, as get_d does, so bools
          count as numbers) give their arithmetic mean;
        - str columns give their most frequent value. On ties, the tied value
          whose first occurrence in ``group`` is latest wins.
    :raises ValueError: if ``group`` is empty.
    :raises NotImplementedError: for any other column type. The type is
        taken from the first row.
    """
    if not group:
        raise ValueError('cannot compute the center of an empty group')
    if col_num is None:
        col_num = len(group[0])
    center = []
    for i in range(col_num):
        sample = group[0][i]
        if isinstance(sample, (int, float)):
            center.append(sum(row[i] for row in group) / len(group))
        elif isinstance(sample, str):
            counts = {}
            for row in group:
                counts[row[i]] = counts.get(row[i], 0) + 1
            # Stable ascending sort by count; the last entry is the mode.
            center.append(sorted(counts, key=counts.get)[-1])
        else:
            raise NotImplementedError(
                'unsupported column type: %s' % type(sample).__name__)
    return center
def get_d(r1, r2):
    """Return the mixed-type distance between rows ``r1`` and ``r2``.

    Columns are paired with zip(), so only as many columns as the shorter
    row has are compared. split_group() relies on this when it compares a
    full data row with a shorter center vector. The type of the value in
    ``r1`` picks the metric for each column:
    - int/float: squared difference;
    - str: 0 if equal, else 1.
    Any other type raises NotImplementedError.
    """
    total = 0
    for left, right in zip(r1, r2):
        if isinstance(left, (int, float)):
            total += (left - right) ** 2
            continue
        if not isinstance(left, str):
            raise NotImplementedError()
        if left != right:
            total += 1
    return total


def get_diff(group, r):
    """Return the sum of get_d(row, r) over every row in ``group`` (0 if empty)."""
    return sum(get_d(row, r) for row in group)
def find_once(data, rs, n=10, name_col=None):
    """Run a single clustering trial starting from the initial centers ``rs``.

    The trial alternates two steps for at most ``n`` iterations:
    - assignment (split_group);
    - center update (get_center).
    It stops early once the total difference has failed to decrease on
    three iterations, which need not be consecutive.

    :param data: rows to cluster.
    :param rs: non-empty list of initial center vectors. Their length fixes
        how many leading columns take part in the center update.
    :param n: maximum number of iterations.
    :param name_col: passed to display_group() for progress output.
    :returns: ``(total_diff, rs)``. ``rs`` is a list of lists: the centers
        whose partition has exactly ``total_diff``. If some group becomes
        empty the trial has failed and ``total_diff`` is ``float('inf')``,
        so find() never prefers it.
    """
    assert len(rs) > 0
    col_num = len(rs[0])
    flag = 0  # iterations whose total difference did not decrease
    last_diff = float('inf')
    last_rs = rs
    for i in range(n):
        print('----------第', i+1, '次迭代------------')
        print('类中心向量', rs)
        groups = split_group(data, rs)
        total_diff = 0  # sum of the differences of all groups
        for j, (group, r) in enumerate(zip(groups, rs)):
            diff = get_diff(group, r)
            print('第', j+1, '组', display_group(group, name_col), '总差异', diff)
            total_diff += diff
        print('所有组差异总和', total_diff)
        if not all(groups):
            print('含有空组,此次试算失败!')
            # Do not report the previous iteration's diff with these
            # degenerate centers; mark the whole trial as failed.
            return float('inf'), [list(r) for r in rs]
        if total_diff >= last_diff:
            flag += 1
        # Keep the diff together with the centers that produced it, so the
        # returned pair is consistent even when the loop runs to the end.
        last_diff, last_rs = total_diff, rs
        if flag >= 3:
            print('总差异趋近最小值 结束迭代!')
            break
        rs = [get_center(group, col_num) for group in groups]
        print('找出新质心', rs)
    # Uniform list-of-lists output keeps results from different trials
    # mutually comparable.
    return last_diff, [list(r) for r in last_rs]
def find(data, k=2, col_num=None, n=10, name_col=None, iterations=10):
    """Run ``n`` independent trials of find_once() and return the best one.

    :param data: non-empty list of rows.
    :param k: number of clusters.
    :param col_num: number of leading columns used for clustering. If it is
        falsy, all columns of the first row are used.
    :param n: number of random-start trials.
    :param name_col: column shown by display_group() in progress output.
    :param iterations: maximum iterations per trial (find_once's ``n``).
    :returns: ``(rs, groups, total_diff, i)``. These are the best centers,
        the partition of ``data`` they induce, its total difference, and the
        1-based number of the winning trial.
    :raises ValueError: from random.sample when ``k`` exceeds ``len(data)``.
    """
    assert len(data) > 0
    if not col_num:
        col_num = len(data[0])
    assert col_num > 0
    results = []
    for i in range(n):
        print('\n=====================第', i+1, '次试算===========================')
        # Seed from k distinct rows. A duplicated seed would make two centers
        # identical, and the higher-index one would always end up empty.
        rs = [tuple(row[:col_num]) for row in random.sample(data, k)]
        print('生成随机初始向量', rs, '开始试算')
        total_diff, rs = find_once(data, rs, n=iterations, name_col=name_col)
        results.append((total_diff, rs, i+1))
    # Rank by total_diff only. Comparing the centers on ties is meaningless
    # and can raise TypeError (tuple vs list). The stable sort keeps trial
    # order for ties.
    results.sort(key=lambda trial: trial[0])
    print('==================================================')
    for result in results:
        print(result)
    print('==================================================')
    total_diff, rs, i = results[0]
    groups = split_group(data, rs)
    return rs, groups, total_diff, i
if __name__ == '__main__':
    # Each row: gender, age, height, occupation, hobby, name.
    # Only the first five columns are clustered (col_num=5). The name, in
    # the last column, is only used for display.
    DATA = [
        ['女', 22, 161, '结构', '美妆', 'A'],
        ['女', 31, 168, '程序', '绘画', 'B'],
        ['男', 25, 188, '程序', '绘画', 'C'],
        ['女', 29, 157, '程序', '美妆', 'D'],
        ['女', 38, 165, '结构', '绘画', 'E'],
        ['男', 39, 168, '程序', '游戏', 'F'],
        ['男', 32, 178, '结构', '游戏', 'G'],
        ['男', 24, 173, '结构', '绘画', 'H'],
    ]
    print("""
初始数据:
性别 年龄 身高 职业 爱好 姓名(不参与)
女 22 161 结构 美妆 A
女 31 168 程序 绘画 B
男 25 188 程序 绘画 C
女 29 157 程序 美妆 D
女 38 165 结构 绘画 E
男 39 168 程序 游戏 F
男 32 178 结构 游戏 G
男 24 173 结构 绘画 H
""")
    NAME_COL = -1
    best_rs, best_groups, best_diff, best_trial = find(DATA, col_num=5, name_col=NAME_COL)
    print('最佳结果(最终差异总和最小) 为第', best_trial, '次试算的结果')
    print('总差异', best_diff)
    print('质心', best_rs)
    print('分组')
    print(display_groups(best_groups, NAME_COL))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment