Created
October 7, 2019 17:31
-
-
Save Redchards/e3f44d59aaa2727a42640a232f0058ee to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import cv2 | |
import tsp | |
import pickle as pkl | |
def euclidian_distance(x1, x2):
    """Return the straight-line distance between two 2-D points.

    Only the first two components of each point are read; any further
    components are ignored. The square root is taken with ``np.sqrt``.
    """
    dx = x2[0] - x1[0]
    dy = x2[1] - x1[1]
    return np.sqrt(dx ** 2 + dy ** 2)
if __name__ == '__main__':
    # Load the map. cv2.imread returns None (it does not raise) when the file
    # is missing or unreadable, which would otherwise fail later with an
    # unrelated error on `img.copy()`.
    img = cv2.imread('map2.jpg')
    if img is None:
        raise FileNotFoundError("could not read 'map2.jpg'")

    # Zero channel 2 (red in OpenCV's BGR order), then convert to grayscale
    # for blob detection.
    mp = img.copy()
    mp[:, :, 2] = 0
    mp = cv2.cvtColor(mp, cv2.COLOR_BGR2GRAY)

    # Detect dark blobs (blobColor=0) using thresholds in [0, 127].
    params = cv2.SimpleBlobDetector_Params()
    params.filterByColor = True
    params.blobColor = 0
    params.minThreshold = 0
    params.maxThreshold = 127
    blob_detector = cv2.SimpleBlobDetector_create(params)
    keypoints = blob_detector.detect(mp)
    print(keypoints)
    if not keypoints:
        # Nothing to route; the TSP solver cannot handle an empty node set.
        raise SystemExit('no blobs detected in map2.jpg')

    im_keypoints = cv2.drawKeypoints(mp, keypoints, np.array([]), (0, 0, 255),
                                     cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    cv2.imshow('map', im_keypoints)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    # Full pairwise distance table keyed by (i, j) keypoint indices, as
    # expected by tsp.tsp's `dist` argument.
    coords = [k.pt for k in keypoints]
    coord_dict = {(i, j): euclidian_distance(p, q)
                  for i, p in enumerate(coords)
                  for j, q in enumerate(coords)}
    print(coord_dict)
    print(len(keypoints))

    tsp_res = tsp.tsp(range(len(keypoints)), coord_dict)
    print(tsp_res)
    route = tsp_res[1]

    # cv2 drawing functions need integer pixel coordinates.
    coords = [(int(x), int(y)) for (x, y) in coords]
    for start, end in zip(route, route[1:]):
        cv2.arrowedLine(img, coords[start], coords[end], (0, 0, 255), 2)
    cv2.imshow('route', img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    cv2.imwrite('route.png', img)
    # Close the pickle file deterministically (the original leaked the handle).
    with open('route.pkl', 'wb') as f:
        pkl.dump((coords, route), f)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Code adapted from TSPLib | |
def tsp(nodes, dist=None, solver=None):
    """Solve a TSP exactly as a MIP with lifted MTZ subtour constraints.

    The model uses one binary variable per ordered pair (i, j), so ``dist``
    may be asymmetric.

    Args:
        nodes: sequence of nodes; only ``len(nodes)`` is used when ``dist`` is
            given, otherwise each element is treated as a coordinate vector.
        dist: optional mapping ``(i, j) -> distance`` covering every ordered
            pair of distinct indices in ``range(len(nodes))``. When falsy, the
            Euclidean distance between ``nodes[i]`` and ``nodes[j]`` is used.
        solver: optional PuLP solver object passed to ``LpProblem.solve``.
            Defaults to CBC with the original settings (messages on,
            fracGap=1e-5, maxSeconds=25000).

    Returns:
        ``(objective value, tour)`` where ``tour`` lists ``len(nodes)`` node
        indices in visiting order starting at 0 (the return to 0 is implicit).
    """
    import numpy as np, pandas as pd
    from more_itertools import iterate, take
    import pulp
    from pulp import LpProblem, LpVariable, LpBinary, lpDot, lpSum, value
    n = len(nodes)
    if not dist:
        # Symmetric Euclidean distances computed from the node coordinates.
        dist = {(i, j): np.linalg.norm(np.subtract(nodes[i], nodes[j]))
                for i in range(n) for j in range(i + 1, n)}
        dist.update({(j, i): d for (i, j), d in dist.items()})
    # One row per ordered pair (i, j), i != j, in (NodeI, NodeJ) order.
    a = pd.DataFrame([(i, j, dist[i, j])
                      for i in range(n) for j in range(n) if i != j],
                     columns=['NodeI', 'NodeJ', 'Dist'])
    m = LpProblem()
    # VarIJ is 1 iff the tour goes directly from NodeI to NodeJ.
    a['VarIJ'] = [LpVariable('x%d' % i, cat=LpBinary) for i in a.index]
    # Sorting by (NodeJ, NodeI) lines each row (i, j) up with the variable of
    # the reverse arc (j, i).
    a['VarJI'] = a.sort_values(['NodeJ', 'NodeI']).VarIJ.values
    # Potentials for nodes 1..n-1; node 0's potential is the constant 0.
    u = [0] + [LpVariable('y%d' % i, lowBound=0) for i in range(n - 1)]
    m += lpDot(a.Dist, a.VarIJ)
    for _, v in a.groupby('NodeI'):
        m += lpSum(v.VarIJ) == 1  # out-degree constraint
        m += lpSum(v.VarJI) == 1  # in-degree constraint
    # Row layout: NodeI, NodeJ, Dist, VarIJ, VarJI.
    for _, (i, j, _, vij, vji) in a.query('NodeI!=0 & NodeJ!=0').iterrows():
        # Lifted potential constraint (MTZ).
        m += u[i] + 1 - (n - 1) * (1 - vij) + (n - 3) * vji <= u[j]
    for _, (_, j, _, v0j, vj0) in a.query('NodeI==0').iterrows():
        # Lifted lower bound on the potential of j.
        m += 1 + (1 - v0j) + (n - 3) * vj0 <= u[j]
    for _, (i, _, _, vi0, v0i) in a.query('NodeJ==0').iterrows():
        # Lifted upper bound on the potential of i.
        m += u[i] <= (n - 1) - (1 - vi0) - (n - 3) * v0i
    if solver is None:
        # NOTE(review): fracGap/maxSeconds are the older PuLP spellings; newer
        # PuLP versions name them gapRel/timeLimit — confirm against the
        # installed version, or pass a configured `solver` explicitly.
        solver = pulp.PULP_CBC_CMD(msg=True, fracGap=0.00001, maxSeconds=25000)
    m.solve(solver)
    a['ValIJ'] = a.VarIJ.apply(value)
    # Successor map i -> j for the arcs selected in the solution.
    dc = dict(a[a.ValIJ > 0.5][['NodeI', 'NodeJ']].values)
    return value(m.objective), list(take(n, iterate(lambda k: dc[k], 0)))
def tsp2(pos): | |
import numpy as np | |
from pulp import LpProblem, LpVariable, LpBinary, lpDot, lpSum, value | |
pos = np.array(pos) | |
N = len(pos) | |
m = LpProblem() | |
v = {} | |
for i in range(N): | |
for j in range(i+1,N): | |
v[i,j] = v[j,i] = LpVariable('v%d%d'%(i,j), cat=LpBinary) | |
m += lpDot([np.linalg.norm(pos[i]-pos[j]) for i,j in v | |
if i<j], [x for (i,j),x in v.items() if i<j]) | |
for i in range(N): | |
m+= lpSum(v[i,j] for j in range(N) if i!=j) == 2 | |
for i in range(N): | |
for j in range(i+1,N): | |
for k in range(j+1,N): | |
m += v[i,j]+v[j,k]+v[k,i] <= 2 | |
st = set() | |
while True: | |
m.solve() | |
u = unionfind(N) | |
for i in range(N): | |
for j in range(i+1,N): | |
if value(v[i,j])>0: | |
u.unite(i,j) | |
gg = u.groups() | |
if len(gg) == 1: | |
break | |
for g_ in gg: | |
g = tuple(g_) | |
if g not in st: | |
st.add(g) | |
m += lpSum(v[i,j] for i in range(N) for j in range(i+1,N) | |
if (i in g and j not in g) or | |
(i not in g and j in g)) >= 1 | |
break | |
cn = [0]*N | |
for i in range(N): | |
for j in range(i+1,N): | |
if value(v[i,j])>0: | |
if i or cn[i]==0: | |
cn[i] += j | |
cn[j] += i | |
p,q,r = cn[0],0,[0] | |
while p != 0: | |
r.append(p) | |
q,p = p,cn[p]-q | |
return value(m.objective), r | |
def tsp3(point):
    """Solve a small TSP exactly by brute force over all tours starting at node 0.

    Args:
        point: sequence of 2-D points; only the first two coordinates of each
            point are used.

    Returns:
        ``(length, tour)``: the shortest closed-tour length and the visiting
        order as node indices starting with 0 (the return to 0 is implicit).
        Ties keep the first tour in ``itertools.permutations`` order. An empty
        input gives ``(0.0, [])``.

    Enumerates all (n-1)! orders with n edge lengths each, so it is only
    usable for small inputs.
    """
    from math import sqrt
    from itertools import permutations
    n = len(point)
    if n == 0:
        # The original indexed point[0] here and raised IndexError.
        return 0.0, []
    # Pairwise distances computed once instead of once per permutation.
    dist = [[sqrt((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) for b in point]
            for a in point]
    bst, mn = None, float('inf')
    for d in permutations(range(1, n)):
        order = (0,) + d + (0,)
        s = sum(dist[order[i]][order[i + 1]] for i in range(n))
        if s < mn:
            mn = s
            bst = [0] + list(d)
    return mn, bst
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment