Last active
March 29, 2020 14:50
-
-
Save benallard/da0318b36ea81b940ca64ac2c9e5cba2 to your computer and use it in GitHub Desktop.
Travelling Salesman based on historical data
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
+ABCDE | |
+CBDF | |
+ACDF | |
+BEACD | |
+ABDE | |
ABCDEF? | |
-ABCDE | |
BCDEF? |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import sys | |
# public domain | |
class Entry(object):
    """Running average of the scores recorded for one (from, to) pair.

    ``score`` is the sum of all recorded scores, ``amount`` the number of
    recorded observations, and ``value`` their mean.
    """

    def __init__(self, score):
        self.score = score
        self.amount = 1

    def add(self, score):
        """Record one more observation worth ``score``."""
        self.amount += 1
        self.score += score

    def sub(self, score):
        """Remove one observation worth ``score``.

        Does nothing once no observations remain, so ``amount`` never goes
        negative.
        """
        if self.amount > 0:
            self.amount -= 1
            self.score -= score

    @property
    def value(self):
        """Mean recorded score, or 0 when no observations remain."""
        if self.amount == 0:
            return 0
        # float() forces true division under Python 2 as well; integer
        # division would truncate the average.
        return float(self.score) / self.amount
class Model(object):
    """Learn station-ordering preferences from past routes and predict new ones.

    While learning, each station of a route (and the start marker ``'0'``
    standing before the first station) credits the stations that follow it
    with decaying weights: 8 for the next station, then 4, 2 and 1.
    ``predict`` orders a set of stations greedily, each time picking the
    remaining station with the highest average weight from the previously
    picked one.

    Scores live in ``self.data``, keyed by the string ``frm + to``.  That
    keying (and ``__str__``, which splits keys by character) assumes
    single-character station names; ``'0'`` is reserved as the start marker.
    """

    # "From" station used for the first stop of every route.
    START = '0'
    # Weights credited to the 1st, 2nd, 3rd and 4th station after ``frm``.
    SCORES = (8, 4, 2, 1)

    def __init__(self):
        self.data = dict()

    def learn(self, route):
        """Add the station pairs observed in ``route`` to the model."""
        self._walk(route, True)

    def forget(self, route):
        """Remove the station pairs of a previously learned ``route``."""
        self._walk(route, False)

    def predict(self, stations):
        """Print and return a greedy ordering of ``stations``.

        Starting from the start marker, repeatedly take the remaining station
        with the highest learned score from the previously taken one.  The
        stations are shuffled first so that ties are broken randomly.

        :param stations: iterable of single-character station names.
        :returns: the predicted route as a string.
        """
        remaining = list(stations)
        random.shuffle(remaining)
        prev = self.START
        res = []
        while remaining:
            best = max(remaining, key=lambda to: self.__get(prev, to))
            remaining.remove(best)
            res.append(best)
            # Chain the prediction: the next pick is scored from this one.
            prev = best
        route = ''.join(res)
        print(route)
        return route

    def _walk(self, route, add):
        """Apply ``_process`` to every station of ``route`` and its successors."""
        prev = self.START
        for i, station in enumerate(route):
            self._process(prev, route[i:], add)
            prev = station

    def _process(self, frm, tos, add=True):
        """Credit (or, if ``add`` is false, debit) the first stations of ``tos``
        relative to ``frm``, using the decaying ``SCORES`` weights."""
        update = self.__add if add else self.__sub
        # zip stops after len(SCORES) stations, ignoring the rest of ``tos``.
        for to, score in zip(tos, self.SCORES):
            update(frm, to, score)

    def __get(self, frm, to):
        """Average score of ``to`` following ``frm``; 0 if never seen."""
        entry = self.data.get(frm + to)
        return 0 if entry is None else entry.value

    def __add(self, frm, to, score):
        key = frm + to
        entry = self.data.get(key)
        if entry is None:
            self.data[key] = Entry(score)
        else:
            entry.add(score)

    def __sub(self, frm, to, score):
        # Forgetting a pair that was never learned is a no-op.
        entry = self.data.get(frm + to)
        if entry is not None:
            entry.sub(score)

    def __str__(self):
        """Render the average-score table: one row per "from", one column per "to".

        Every column is 4 characters wide so that headers, values and empty
        cells line up.
        """
        tos = self.__tos()
        lines = ['   ' + ''.join('{:>4}'.format(to) for to in tos)]
        for frm in self.__froms():
            cells = []
            for to in tos:
                entry = self.data.get(frm + to)
                if entry is None:
                    cells.append('    ')
                else:
                    cells.append('{:4.1f}'.format(entry.value))
            lines.append(' ' + frm + ' ' + ''.join(cells))
        return '\n'.join(lines) + '\n'

    def __froms(self):
        """Sorted distinct "from" stations (first character of each key)."""
        return sorted(set(k[0] for k in self.data))

    def __tos(self):
        """Sorted distinct "to" stations (second character of each key)."""
        return sorted(set(k[1] for k in self.data))
def main(filename):
    """Replay the commands in ``filename`` against a fresh Model.

    Each stripped line is interpreted as:
      ``+ROUTE``  learn ROUTE
      ``-ROUTE``  forget ROUTE
      ``STNS?``   print a predicted ordering of STNS
      anything else is echoed unchanged.
    The final score table is printed once the file is exhausted.

    :returns: 0, used as the process exit status.
    """
    model = Model()
    with open(filename) as handle:
        for raw in handle:
            entry = raw.strip()
            head, tail = entry[:1], entry[1:]
            if head == "+":
                model.learn(tail)
            elif head == "-":
                model.forget(tail)
            elif entry.endswith("?"):
                model.predict(entry[:-1])
            else:
                print(entry)
    print(model)
    return 0
if __name__ == "__main__":
    # Report a usage error instead of crashing with IndexError when the
    # input file argument is missing.
    if len(sys.argv) < 2:
        sys.stderr.write("usage: {} FILE\n".format(sys.argv[0]))
        sys.exit(2)
    sys.exit(main(sys.argv[1]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment