# Gist DomNomNom/5735893, created June 8, 2013 17:06.
# GitHub page text: "Save DomNomNom/5735893 to your computer and use it in GitHub Desktop."
# GitHub notice: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review, open
# the file in an editor that reveals hidden Unicode characters.
from collections import deque

import matplotlib.pyplot as plt
import networkx as nx
# Graph of the states visited by the search; astar() adds nodes and edges to it.
G = nx.Graph()

# The three things that can be moved around.
wolf, goat, rose = 'wolf', 'goat', 'rose'
objects = [wolf, goat, rose]

# The three places a thing can be.
boat, shore_l, shore_r = 'boat', 'shore_l', 'shore_r'

# given a position, where we can move to
transitions = {
    shore_l: [boat],
    shore_r: [boat],
    boat: [shore_r, shore_l],
}
class State(object):
    """A state of the world: which object is at which position.

    Every state keeps a link to the state it was derived from
    (``prevstate``), so following those links gives the history of moves.
    Equality and hashing look only at the object positions, never at the
    history, so the same arrangement reached by two routes is one state.
    """

    def __init__(self, prevstate, **objects):
        """Create a state.

        prevstate -- the State this one was derived from, or None for a start
        objects   -- exactly three ``name=position`` keyword pairs
        """
        assert len(objects) == 3
        self.prevstate = prevstate
        self.objects = objects  # note: should be treated as immutable.
        # how many actions did we take to get here
        self.moves = prevstate.moves + 1 if prevstate is not None else 0

    def __hash__(self):
        # Consistent with __eq__: depends on the (object, position) pairs only.
        return hash(frozenset(self.objects.items()))

    def __eq__(self, other):
        # Comparing with something that is not a State (e.g. None) must not
        # blow up with AttributeError; let Python fall back to its default.
        if not isinstance(other, State):
            return NotImplemented
        return self.objects == other.objects

    def __str__(self):
        """Render as ``<left shore> | <boat> | <right shore>``."""
        def names_at(place):
            return ' '.join(obj for obj, pos in self.objects.items() if pos == place)

        return ' | '.join(names_at(place) for place in (shore_l, boat, shore_r))

    def isvalid(self):  # note: the goal state technically is not valid
        """Return True when the boat is not overloaded and nothing gets eaten."""
        return all([
            # only 1 passenger on the boat
            sum(1 for pos in self.objects.values() if pos == boat) <= 1,
            # make sure nothing gets eaten
            self.objects[goat] != self.objects[rose],
            self.objects[goat] != self.objects[wolf],
        ])

    def reachables(self):
        """Yield all states reachable from this state in one move.

        For every object, yield every state where it has moved to a
        neighbouring position.  It can either swap with something that is
        already at the destination, or just move there (swap with None).
        """
        for obj, pos in self.objects.items():
            for destination in transitions[pos]:
                occupants = [
                    obj2 for obj2, pos2 in self.objects.items()
                    if pos2 == destination
                ]
                for swapwith in occupants + [None]:
                    yield self.swap(obj, pos, swapwith, destination)

    def swap(self, obj_a, pos_a, obj_b, pos_b):
        """Return a new state where obj_a is at pos_b and obj_b is at pos_a.

        If obj_b is None it is ignored and only obj_a moves.  This state is
        left untouched and becomes the ``prevstate`` of the returned one.
        """
        newobjects = dict(self.objects)
        newobjects[obj_a] = pos_b
        if obj_b is not None:
            newobjects[obj_b] = pos_a
        return State(self, **newobjects)
def astar(starts, goals):
    """Explore the states reachable from ``starts`` and record them in ``G``.

    Despite the name, the frontier is a plain first-in-first-out queue;
    nothing orders it by cost yet (see the TODO below).

    starts -- iterable of start states; it is copied, not consumed
    goals  -- container of goal states (only membership tests are used)

    Every goal state or valid state that is generated is added to the
    module-level graph ``G`` with an edge from the state it was generated
    from.  Each time a goal is generated it is printed.  The search does
    not stop at the first goal; it keeps going until the frontier is empty.

    Returns the first goal state that was generated (its ``prevstate``
    links lead back to a start), or None if no goal was reached.
    """
    exploredstates = set()
    toexplore = deque(starts)  # private copy so the caller's list survives
    firstgoal = None
    while toexplore:
        state = toexplore.popleft()
        if state in exploredstates:
            continue
        # print('exploring {0}'.format(state))
        exploredstates.add(state)
        for newstate in state.reachables():
            isgoal = newstate in goals
            # the goal itself fails isvalid(), so it is checked separately
            if not isgoal and not newstate.isvalid():
                continue
            G.add_node(newstate)
            G.add_edge(state, newstate)
            if isgoal:
                print(newstate)
                if firstgoal is None:
                    firstgoal = newstate
            elif state.isvalid():
                # an invalid start still gets its edges drawn, but the
                # search does not continue through it
                toexplore.append(newstate)  # TODO: insert ordered
    return firstgoal  # None when no path was found
# Run the search: the goat starts in the boat, the wolf and the rose on the
# left shore; the goal is everything on the right shore.
# Alternative starts that were tried:
#   State(None, wolf=boat, goat=shore_l, rose=shore_l)
#   State(None, wolf=shore_l, goat=shore_l, rose=boat)
best = astar(
    starts=[State(None, **{wolf: shore_l, goat: boat, rose: shore_l})],
    goals={State(None, **dict.fromkeys(objects, shore_r))},
)
print(best)

# Draw the graph of states collected in G and block until the window closes.
nx.draw_spring(G)
plt.show()
print('done')
# GitHub page text: "Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment"