Created January 7, 2017 20:20
-
-
Save prophile/5c72af00cd4a392c621fe365ae6d3965 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
try:
    from collections.abc import Iterable, Iterator
except ImportError:
    from collections import Iterable, Iterator
import heapq
class AStar(Iterator):
    """Lazy A* search that yields nodes in the order they are expanded.

    Every call to ``next()`` expands one node and returns ``(node, path)``,
    where ``path`` is the list of transition labels leading from ``start``
    to ``node``.  Each node is expanded (and therefore yielded) at most once;
    nodes must be hashable because expanded nodes are kept in a set.

    Args:
        transition: callable taking a node and returning an iterable of
            ``(neighbour, transition_label, move_cost)`` triples.
        heuristic: callable taking a node and returning its estimated
            remaining cost.
        start: the node the search starts from; it is yielded first.
        zero_score: path length assigned to ``start``; move costs are added
            to it.
        pure_heuristic: if true, rank nodes by ``heuristic(node)`` alone
            (greedy best-first) instead of ``path_length + heuristic(node)``.

    Public state:
        candidate, candidate_path, candidate_path_length, candidate_score:
            the node that will be expanded next and its bookkeeping.
            ``candidate`` is set to ``None`` once the open set is exhausted.
        open_set: min-heap of ``(score, path_length, index, path, node)``
            entries awaiting expansion.
        closed_set: set of nodes that have already been expanded.
    """

    def __init__(self, transition, heuristic, start, zero_score=0, pure_heuristic=False):
        self.transition = transition
        self.heuristic = heuristic
        self.pure_heuristic = pure_heuristic
        # The next node to expand is held outside the heap ("candidate").
        # It always scores no worse than every open-set entry.
        self.candidate = start
        self.candidate_path = []
        self.candidate_path_length = zero_score
        self.candidate_score = self._score(zero_score, start)
        self.open_set = []
        self.closed_set = set()
        # Monotonic tie-breaker so heap comparisons never reach path/node.
        self.index = 0
        # Tracked separately from ``candidate`` so that a node which happens
        # to be ``None`` is not mistaken for "nothing left to expand".
        self._exhausted = False

    def __iter__(self):
        """Return ``self``; an AStar instance is its own iterator."""
        return self

    def _get_candidate(self):
        """Pop the best open-set entry into the candidate slots.

        If the open set is empty, mark the search exhausted and set
        ``candidate`` to ``None``.
        """
        try:
            (score, plength, _, path, node) = heapq.heappop(self.open_set)
        except IndexError:
            self.candidate = None
            self._exhausted = True
        else:
            self.candidate = node
            self.candidate_path = path
            self.candidate_path_length = plength
            self.candidate_score = score

    def _next_index(self):
        """Return a fresh, strictly increasing tie-breaker for heap entries."""
        ix = self.index
        self.index = ix + 1
        return ix

    def _score(self, path_length, node):
        """Return the priority of ``node`` reached via ``path_length`` (lower is better)."""
        if self.pure_heuristic:
            return self.heuristic(node)
        return path_length + self.heuristic(node)

    def next(self):
        """Expand the current candidate and return ``(node, path)`` for it.

        Raises:
            StopIteration: once no unexpanded candidate remains.
        """
        # A node can be pushed once per parent that reaches it, so skip
        # entries for nodes that were already expanded.
        while not self._exhausted and self.candidate in self.closed_set:
            self._get_candidate()
        if self._exhausted:
            raise StopIteration()
        self.closed_set.add(self.candidate)
        transitions = [
            (node, transition, move_cost)
            for node, transition, move_cost in self.transition(self.candidate)
            if node not in self.closed_set
        ]
        prev_node = self.candidate
        # Copy: the fast-descent branch below extends candidate_path in place.
        prev_path = list(self.candidate_path)
        if not transitions:
            # Dead end: just pull in the next candidate from the open set.
            self._get_candidate()
            return prev_node, prev_path
        base_length = self.candidate_path_length
        base_path = self.candidate_path
        # Special-case for fast descent: a lone successor that scores no
        # worse than the current candidate also beats every open-set entry,
        # so it can be explored next without touching the heap.
        if len(transitions) == 1:
            node, transition, move_cost = transitions[0]
            score = self._score(base_length + move_cost, node)
            if score <= self.candidate_score:
                self.candidate = node
                self.candidate_score = score
                self.candidate_path.append(transition)
                self.candidate_path_length += move_cost
                return prev_node, prev_path
        last = len(transitions) - 1
        for ix, (node, transition, move_cost) in enumerate(transitions):
            path_length = base_length + move_cost
            entry = (
                self._score(path_length, node),
                path_length,
                self._next_index(),
                base_path + [transition],
                node,
            )
            if ix == last:
                # Last successor: push it and pop the best entry in a single
                # heap operation; the result becomes the next candidate.
                (
                    self.candidate_score,
                    self.candidate_path_length,
                    _,
                    self.candidate_path,
                    self.candidate,
                ) = heapq.heappushpop(self.open_set, entry)
            else:
                heapq.heappush(self.open_set, entry)
        return prev_node, prev_path

    __next__ = next
def astar(
    transition,
    heuristic,
    start,
    is_final,
    open_set_limits=None,
    **kwargs
):
    """Search from ``start`` until a node satisfying ``is_final`` is expanded.

    Args:
        transition, heuristic, start: forwarded to ``AStar``.
        is_final: predicate called on each node as it is expanded; the first
            node for which it returns true ends the search.
        open_set_limits: optional ``(max_open_set_size, trunc_open_set_size)``
            pair.  Whenever the open set holds more than
            ``max_open_set_size`` entries, it is cut down to the
            ``trunc_open_set_size`` best entries.  This bounds memory, but
            discarded entries are lost for good, so the search may miss a
            path or return a non-optimal one.
        **kwargs: further keyword arguments for ``AStar``
            (``zero_score``, ``pure_heuristic``).

    Returns:
        ``(node, path)`` for the first final node expanded, where ``path`` is
        the list of transition labels leading from ``start`` to ``node``.

    Raises:
        ValueError: if the iteration ends without expanding a final node.
    """
    instance = AStar(transition, heuristic, start, **kwargs)
    if open_set_limits is not None:
        # Loop-invariant: unpack once rather than on every expansion.
        max_open_set_size, trunc_open_set_size = open_set_limits
    for node, path in instance:
        if is_final(node):
            return node, path
        # Apply open set policy
        if open_set_limits is not None and len(instance.open_set) > max_open_set_size:
            # A sorted list satisfies the heap invariant, so sorting and then
            # truncating keeps the best entries and leaves a valid heap.
            instance.open_set.sort()
            del instance.open_set[trunc_open_set_size:]
    raise ValueError("No path found")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.