Created
June 3, 2013 06:54
-
-
Save Cairnarvon/5696483 to your computer and use it in GitHub Desktop.
A* is easier than you think.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# coding=utf8 | |
import heapq | |
class Node(object):
    """
    A search-tree node: a path of rules of production and the state it
    leads to.

    h and g are the caller's cost functions. They are evaluated once here
    and their results are cached as self.h (estimated remaining cost from
    state) and self.g (cost of path), with self.f = self.h + self.g used as
    the priority in search's heap.
    """

    def __init__(self, path, state, h, g):
        self.path = path
        self.state = state
        self.h = h(self.state)
        self.g = g(self.path)
        self.f = self.h + self.g

    def __lt__(self, node):
        # Order by f so heapq pops the most promising node first. On ties,
        # prefer the node with the smaller h (believed closer to a goal);
        # this does not affect optimality but usually saves expansions.
        return (self.f, self.h) < (node.f, node.h)


def search(state, rules, h, g, seen=True, is_goal=None):
    """
    Performs an A* search given an initial state, one or more rules of
    production, a heuristic future path-cost function h, and a past
    path-cost function g.

    State must be a hashable type, but other than that, the details are
    up to you. The only things that operate on it are functions and methods
    you supply.

    A rule of production must be an object with two methods:
        can_apply, which takes a state as an argument and returns True or
            False depending on whether the rule applies, and
        apply, which takes a state as an argument and returns a transformed
            state.
    rules may be any iterable of such objects; it is read once up front.

    h is applied to a state and must be admissible, i.e. it must never
    overestimate the cost still needed to reach a goal. Unless is_goal is
    supplied, a state counts as a goal exactly when h returns 0 for it.

    g is applied to a path, which is a list of rules of production.

    If seen is True, the algorithm remembers the cheapest g found so far for
    each state and drops any path that reaches a state no more cheaply, so
    states are not re-explored needlessly. This is usually a good idea, but
    if your problem is such that revisiting states is unlikely anyway, you
    can save some space by not doing it. Note that without it, a search
    with no reachable goal may never end if states can repeat.

    is_goal, if given, is a predicate on states that replaces the default
    h(state) == 0 goal test.

    Returns the list of rules leading from state to the first goal taken
    off the frontier (an empty list if state is itself a goal), or None if
    the frontier runs out. With an admissible h and additive, non-negative
    rule costs (e.g. g=len), the returned path has minimal g.
    """
    # Materialise once: rules is iterated for every expanded node, and a
    # one-shot iterator would otherwise be empty after the first expansion.
    rules = tuple(rules)

    root = Node([], state, h, g)
    frontier = [root]
    # Cheapest g seen per state; only maintained when seen is True.
    best_g = {root.state: root.g} if seen else None

    while frontier:
        node = heapq.heappop(frontier)

        # Stale entry: a cheaper path to this state was pushed after this one.
        if seen and node.g > best_g[node.state]:
            continue

        # Goal test on removal, not on generation: only a popped node is
        # known to have the lowest f on the frontier, which A* needs for
        # optimality. This also covers the case where state is already a goal.
        if (node.h == 0) if is_goal is None else is_goal(node.state):
            return node.path

        for rule in rules:
            if not rule.can_apply(node.state):
                continue
            child = Node(node.path + [rule], rule.apply(node.state), h, g)
            if seen:
                known = best_g.get(child.state)
                if known is not None and known <= child.g:
                    continue  # already reached this state at least as cheaply
                best_g[child.state] = child.g
            heapq.heappush(frontier, child)

    return None
if __name__ == '__main__':
    import random
    import textwrap

    # Demo: solve a randomly shuffled 8-puzzle (0 is the blank) with A*.
    state = [1, 2, 3, 4, 5, 6, 7, 8, 0]
    random.shuffle(state)

    print('''\
┏━━━┳━━━┳━━━┓ ┏━━━┳━━━┳━━━┓
┃{0}┃{1}┃{2}┃ ┃ 1 ┃ 2 ┃ 3 ┃
┣━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━┫
┃{3}┃{4}┃{5}┃ → ┃ 4 ┃ 5 ┃ 6 ┃
┣━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━┫
┃{6}┃{7}┃{8}┃ ┃ 7 ┃ 8 ┃ 0 ┃
┗━━━┻━━━┻━━━┛ ┗━━━┻━━━┻━━━┛'''.format(*[' %d ' % i for i in state]))

    class Rule(object):
        """
        Slide the blank up ('U'), down ('D'), left ('L') or right ('R') on a
        3x3 board stored row-major as a sequence of 9 cells.
        """

        def __init__(self, direction):
            self.direction = direction

        def can_apply(self, state):
            zero = state.index(0)
            # The blank may not leave the board: no U from the top row, no D
            # from the bottom row, no L/R from the left/right column.
            return {'U': zero > 2,
                    'D': zero < 6,
                    'L': zero not in (0, 3, 6),
                    'R': zero not in (2, 5, 8)}[self.direction]

        def apply(self, state):
            state = list(state)
            zero = state.index(0)
            switch = {'U': zero - 3,
                      'D': zero + 3,
                      'L': zero - 1,
                      'R': zero + 1}[self.direction]
            # Swap the blank with the neighbouring tile.
            state[zero], state[switch] = state[switch], 0
            return tuple(state)

    def h(state):
        # Sum of Manhattan distances of tiles 1-8 from their goal squares.
        # Each move shifts exactly one tile by one square, so this never
        # overestimates. The blank is deliberately excluded; counting it
        # would double-count every move. It is 0 only for the goal layout.
        total = 0
        for tile in range(1, 9):
            row, col = divmod(state.index(tile), 3)
            goal_row, goal_col = divmod(tile - 1, 3)
            total += abs(row - goal_row) + abs(col - goal_col)
        return total

    def solvable(state):
        # On a 3-wide board a position can reach the goal iff the number of
        # inversions among the tiles (blank ignored) is even, as in the goal.
        tiles = [tile for tile in state if tile != 0]
        inversions = sum(1 for i, a in enumerate(tiles)
                         for b in tiles[i + 1:] if a > b)
        return inversions % 2 == 0

    start = tuple(state)
    if solvable(start):
        path = search(start, [Rule(d) for d in 'UDLR'], h, len)
    else:
        # Half of all shuffles can never reach the goal; report that
        # directly instead of searching every reachable position first.
        path = None

    if path is None:
        print('{0:^37}'.format('No path found!'))
    else:
        print('{0:^37}'.format('In {0} moves:'.format(len(path))))
        print(textwrap.fill(' → '.join(r.direction for r in path),
                            35, initial_indent=' ', subsequent_indent=' '))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment