Last active
August 29, 2015 13:56
-
-
Save ruandao/9228082 to your computer and use it in GitHub Desktop.
模拟退火：想用来解决文明盛世的超能饮料问题，不过效率跟遍历没差（看来只适合用来找近似最优解）。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
# encoding:utf-8 | |
import sys | |
import heapq | |
import math,random | |
import time | |
# Simulated annealing:
# 1. start from a random solution;
# 2. derive a random neighbouring solution from the current one;
# 3. always move to a better neighbour; move to a worse one with
#    probability exp(-delta / T), so the search can escape local minima
#    while the temperature T is still high.
def annealing(sa, newF, T=10000.0, cool=0.95, tMin=0.1):
    """Minimise cost by simulated annealing and return the best solution seen.

    Args:
        sa: starting solution, a ``(price, solution)`` tuple; lower price is
            better.
        newF: callable taking a ``(price, solution)`` tuple and returning a
            neighbouring ``(price, solution)`` tuple.
        T: initial temperature.
        cool: multiplicative cooling factor applied after every step; must
            satisfy ``0 < cool < 1`` or the loop would never end.
        tMin: the search stops once the temperature is ``<= tMin``.

    Returns:
        The ``(price, solution)`` tuple with the lowest price encountered
        (``sa`` itself if nothing cheaper was found).

    Raises:
        ValueError: if ``cool`` is not strictly between 0 and 1.
    """
    if not 0 < cool < 1:
        raise ValueError("cool must be in (0, 1), got %r" % (cool,))
    # Force float arithmetic: with an integer T (e.g. a variance computed
    # from integer prices) Python 2 would floor-divide -d / T below.
    T = float(T)
    best = sa
    while T > tMin:
        sb = newF(sa)
        d = sb[0] - sa[0]
        # Better neighbours are always taken; otherwise accept with the
        # Metropolis probability. d >= 0 when exp() is evaluated, so it
        # cannot overflow.
        if d < 0 or math.exp(-d / T) > random.random():
            sa = sb
        # Cool on every step, accepted or not. Cooling only on acceptance
        # lets a run of rejections at low temperature spin indefinitely.
        T *= cool
        if best[0] > sa[0]:
            best = sa
        print("%s %s" % (sa[0], best[0]))
    return best
def gR(max):
    """Return a uniformly random integer in ``range(max)``.

    ``max`` must be at least 1; an empty range raises ``ValueError``.
    The parameter name shadows the built-in ``max`` inside this function;
    it is kept as-is for interface compatibility.
    """
    upper = max  # exclusive upper bound
    return random.randrange(0, upper)
# Global edge table, filled by initDataFromFile:
#     nodes[x][y] == [price, machineName]   for each edge x -> y.
# Only nodes with at least one outgoing edge appear as top-level keys.
nodes = dict()
def _routeExists(keys):
    """Return True if a closed walk through every node of ``keys`` exists.

    Only edges of ``nodes`` whose endpoints are both in ``keys`` are used.
    This holds exactly when those nodes are strongly connected and
    ``keys[0]`` lies on a cycle (for a single node, a self-loop).
    """
    keySet = set(keys)
    succ = dict((k, []) for k in keys)
    pred = dict((k, []) for k in keys)
    for x in keys:
        for y in nodes[x]:
            if y in keySet:
                succ[x].append(y)
                pred[y].append(x)

    def reach(adj):
        # Nodes reachable from keys[0] via a path of at least one edge.
        seen = set()
        stack = list(adj[keys[0]])
        while stack:
            n = stack.pop()
            if n not in seen:
                seen.add(n)
                stack.extend(adj[n])
        return seen

    return reach(succ) == keySet and reach(pred) == keySet


def _tryBuildRoute(keys):
    """Make one randomized attempt at building a closed route over ``keys``.

    Phase 1 grows a tree from a random start node. It repeatedly picks a
    visited node ``p`` and an unvisited node ``node2``, and keeps
    ``p -> node2`` when that edge exists.

    Phase 2 gives every node still lacking an outgoing edge (``needOut``)
    an edge into ``endPoints``, the nodes known to lead back to the start.

    Returns:
        ``(price, g)`` on success, or ``None`` when no pending node has an
        edge into ``endPoints`` (a dead end; the caller should retry).
    """
    lenKey = len(keys)
    start = keys[gR(lenKey)]
    unvisited = [k for k in keys if k != start]
    points = [start]           # visited nodes, in visiting order
    needOut = [start]          # visited nodes without an outgoing edge yet
    endPoints = set([start])   # nodes known to lead back to start
    g = {}
    total = [0]                # boxed so the nested helper can update it

    def connect(x, y):
        """Add edge x -> y to g and update the bookkeeping lists."""
        if x in needOut:
            needOut.remove(x)
        if y not in points:
            needOut.append(y)
            points.append(y)
        edges = g.setdefault(x, {})
        if edges.get(y) is not True:
            total[0] += nodes[x][y][0]
        edges[y] = True

    def markPaths(x, y):
        """Add every node on a simple path x -> ... -> y in g to endPoints."""
        def walk(node, prefix):
            if node in prefix:
                return
            prefix = prefix + [node]
            if node == y:
                endPoints.update(prefix)
                return
            for nxt in g.get(node, {}):
                walk(nxt, prefix)
        walk(x, [])

    while len(points) < lenKey or needOut:
        if unvisited:
            # Phase 1: retry random (visited, unvisited) pairs until an edge
            # exists. The start reaches every node (checked by the caller),
            # so such an edge always exists.
            p = points[gR(len(points))]
            node2 = unvisited[gR(len(unvisited))]
            if node2 not in nodes[p]:
                continue
            unvisited.remove(node2)
        else:
            # Phase 2: link a pending node to a node that leads back to start.
            p = needOut[gR(len(needOut))]
            candidates = list(endPoints)
            node2 = candidates[gR(len(candidates))]
            if node2 not in nodes[p]:
                targets = set()
                for q in needOut:
                    targets.update(nodes[q])
                if targets.isdisjoint(endPoints):
                    return None  # no pending node can ever close the route
                continue
            # Nodes on node2 -> ... -> p will reach start once p -> node2 exists.
            markPaths(node2, p)
        connect(p, node2)
    return (total[0], g)


def getRandom():
    """Return a random feasible solution: a closed route through every node.

    A solution is ``(price, g)``:
    - ``g[x][y] is True`` for every chosen edge ``x -> y`` of ``nodes``;
    - ``price`` is the sum of ``nodes[x][y][0]`` over those edges.

    Every node that has outgoing edges in ``nodes`` gets at least one
    outgoing edge in ``g``. It is reachable from the random start node and
    can reach back to it.

    Raises:
        ValueError: if ``nodes`` is empty, or if no closed route through all
            of its nodes exists. In the latter case every construction
            attempt would fail forever.
    """
    keys = list(nodes.keys())
    if not keys:
        raise ValueError("no nodes loaded")
    if not _routeExists(keys):
        raise ValueError("nodes are not strongly connected; no closed route exists")
    # Retry in a loop rather than recursing, so many dead-end attempts cannot
    # exhaust the interpreter's recursion limit.
    while True:
        result = _tryBuildRoute(keys)
        if result is not None:
            return result
def newF(sa):
    """Return a random neighbouring solution of ``sa``.

    1. Delete one randomly chosen edge ``x -> y`` from a copy of the route
       graph.
    2. If ``y`` is no longer reachable from ``x``, add random edges from
       ``nodes`` until it is. New edges grow outward from the set of nodes
       ``x`` still reaches, and ``x -> y`` is never re-added.

    Any walk that used ``x -> y`` can detour along the new path, so if
    ``sa`` is a valid closed route the neighbour is one too. Only edges for
    which such a detour exists in ``nodes`` are chosen for deletion.

    Args:
        sa: ``(price, g)`` solution as returned by ``getRandom``/``newF``;
            ``g`` is copied and never modified.

    Returns:
        The neighbouring ``(price, g)``, or a fresh ``getRandom()`` solution
        when no edge of ``sa`` can be deleted and repaired.
    """
    (price1, g1) = sa
    # Copy the route graph so the caller's solution is left untouched, and
    # collect its edges as deletion candidates.
    g = {}
    paths = []
    for x in g1:
        g[x] = dict.fromkeys(g1[x], True)
        paths.extend((x, y) for y in g1[x])

    def hasOtherPath(x, y):
        """Return True if ``y`` is reachable from ``x`` through ``nodes``.

        The direct edge ``x -> y`` is not used, and the path passes only
        through nodes that have outgoing edges (the only nodes a route can
        contain).
        """
        if x == y:
            # Deleting a self-loop never affects reachability.
            return True
        seen = set([x])
        stack = [x]
        while stack:
            p = stack.pop()
            for k in nodes.get(p, {}):
                if (p == x and k == y) or k in seen or k not in nodes:
                    continue
                if k == y:
                    return True
                seen.add(k)
                stack.append(k)
        return False

    def addReachable(start, s):
        """Add every node reachable from ``start`` in ``g`` to the set ``s``."""
        stack = [start]
        while stack:
            p = stack.pop()
            for k in g.get(p, {}):
                if k not in s:
                    s.add(k)
                    stack.append(k)

    # Try candidate edges in random order, including the last one, until one
    # can be deleted and repaired.
    path = None
    while paths:
        candidate = paths.pop(gR(len(paths)))
        if hasOtherPath(candidate[0], candidate[1]):
            path = candidate
            break
    if path is None:
        return getRandom()

    # Delete the edge that was actually chosen.
    x, y = path
    del g[x][y]
    total = price1 - nodes[x][y][0]

    # s is exactly the set of nodes x can reach in g. Grow it with random
    # edges from nodes until it contains y. Each new edge p -> p2 also pulls
    # in everything p2 already reaches, so no redundant edges are bought.
    # hasOtherPath() guaranteed some edge always leads out of s toward y.
    s = set([x])
    addReachable(x, s)
    while y not in s:
        p = list(s)[gR(len(s))]
        targets = list(nodes.get(p, {}))
        if not targets:
            continue
        p2 = targets[gR(len(targets))]
        if (p == x and p2 == y) or p2 in s or p2 not in nodes:
            continue
        # p2 is outside s, so p -> p2 cannot already be in g.
        g.setdefault(p, {})[p2] = True
        total += nodes[p][p2][0]
        s.add(p2)
        addReachable(p2, s)
    return (total, g)
def initDataFromFile(f, graph=None):
    """Load edges from ``f`` into ``graph`` (the module-level ``nodes`` by default).

    Each non-blank line holds whitespace-separated fields in the order
    ``machineName from to price``; ``price`` must be an integer and any
    extra fields are ignored. The edge is stored as
    ``graph[from][to] = [price, machineName]``. A later line for the same
    ``from``/``to`` pair overwrites an earlier one.

    Args:
        f: iterable of text lines (e.g. an open file).
        graph: dict to fill; defaults to the module-level ``nodes``.

    Raises:
        IndexError: if a non-blank line has fewer than four fields.
        ValueError: if the price field is not an integer.
    """
    if graph is None:
        graph = nodes
    for raw in f:
        # split() with no argument already splits on tabs and spaces.
        fields = raw.split()
        if not fields:
            continue  # tolerate blank lines, e.g. a trailing newline
        machineName, x, y = fields[0], fields[1], fields[2]
        price = int(fields[3])
        graph.setdefault(x, {})[y] = [price, machineName]
def main(f, samples=100):
    """Load the graph from ``f`` and search for a cheap closed route.

    Steps:
    1. Draw ``samples`` random solutions and keep the cheapest as the
       starting point.
    2. Use the population variance of their prices as the initial annealing
       temperature.
    3. Anneal, then print the variance and the best ``(price, g)`` found.

    Args:
        f: iterable of input lines, passed to ``initDataFromFile``.
        samples: number of random solutions to draw; must be >= 1.

    Raises:
        ValueError: if ``samples`` is less than 1.
    """
    if samples < 1:
        raise ValueError("samples must be >= 1, got %r" % (samples,))
    initDataFromFile(f)
    prices = []
    chose = None
    for _ in range(samples):
        sa = getRandom()
        prices.append(sa[0])
        if chose is None or chose[0] > sa[0]:
            chose = sa
    # Divide by the actual number of samples, in floating point, so that
    # Python 2 integer division cannot truncate the mean or variance (a
    # truncated variance of 0 would skip annealing entirely).
    avg = float(sum(prices)) / len(prices)
    variance = sum((p - avg) ** 2 for p in prices) / len(prices)
    print(variance)
    r = annealing(chose, newF, variance)
    if r[0] < chose[0]:
        chose = r
    print(chose)
if __name__ == '__main__':
    # Usage: script.py INPUT_FILE
    if len(sys.argv) < 2:
        sys.exit("usage: %s INPUT_FILE" % sys.argv[0])
    # "with" closes the file even if main() raises.
    with open(sys.argv[1]) as f:
        main(f)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment