Created
March 8, 2016 22:59
-
-
Save mumbleskates/b7b3bbd3924b48805087 to your computer and use it in GitHub Desktop.
Pythonic Dijkstra path finding
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
from collections import OrderedDict, namedtuple | |
from functools import total_ordering | |
from heapq import heappush, heappop | |
from itertools import count, zip_longest | |
# Sentinel "start node" for use with with_initial(): a fresh object() that is only ever
# matched by identity (``node is INITIAL_START``), so it cannot be confused with any
# real graph node.
INITIAL_START = object()
class SumTuple(tuple):
    """
    A handy class for storing priority-costs. Acts just like a regular tuple, but addition
    adds together corresponding elements rather than appending.

    Operands of different lengths are padded with zeros, e.g.
    ``SumTuple((1, 2)) + (10,) == SumTuple((11, 2))``. Adding the integer 0 (from either
    side) returns the tuple unchanged, so SumTuple costs can be added to the plain 0 that
    the dijkstra functions use as the starting distance, and can be passed to sum().
    """

    def __add__(self, other):
        # 0 acts as the additive identity.
        if other == 0:
            return self
        if not isinstance(other, tuple):
            raise TypeError("Cannot add '{0}' to '{1}'".format(str(self), str(other)))
        return SumTuple(x + y for x, y in zip_longest(self, other, fillvalue=0))

    def __radd__(self, other):
        # Treat ``other + self`` exactly like ``self + other``.
        return self + other
class MaxFirstSumTuple(tuple):
    """
    Like SumTuple, but the first element in the tuple reduces with max instead of sum. This
    allows an undesirable edge to equally taint any route that passes through it.

    For example ``MaxFirstSumTuple((3, 1)) + (5, 2) == MaxFirstSumTuple((5, 3))``. Shorter
    operands are padded with zeros, and adding the integer 0 (from either side) returns the
    tuple unchanged.
    """

    def __add__(self, other):
        # 0 acts as the additive identity.
        if other == 0:
            return self
        if not isinstance(other, tuple):
            raise TypeError("Cannot add '{0}' to '{1}'".format(str(self), str(other)))
        return MaxFirstSumTuple(self._adder(other))

    def __radd__(self, other):
        # Treat ``other + self`` exactly like ``self + other``.
        return self + other

    def _adder(self, other):
        """Yield the combined elements: the max of the first pair, then sums of the rest."""
        pairs = zip_longest(self, other, fillvalue=0)
        first = next(pairs, None)
        if first is None:
            # Both operands are empty. A bare next() here would leak StopIteration out of
            # this generator, which Python 3.7+ converts into RuntimeError (PEP 479).
            return
        yield max(first)
        for x, y in pairs:
            yield x + y
@total_ordering
class Worker(object):
    """
    Worker is a class for helper objects that can transform the costs of traversing a graph
    on an instance basis. Work costs will be added together directly, so recommended return
    types include int, float, and SumTuple.

    Workers can also be used for the task of computing paths from multiple starting points,
    where the point you begin will affect the cost of your traversal overall (different
    workers beginning at different locations).

    Workers essentially conditionally transform the edge-cost into a summable value. When
    using workers, edgefinder should produce a cost that DESCRIBES the work to be performed
    to traverse the edge, which is passed into the perform_work function as its sole
    parameter. The return value of this function must then be the COST of doing the work
    thus described; for instance, edgefinder should describe the distance between the node
    and the neighbor, and the worker will accept that distance and return the amount of
    time to travel that distance.

    Workers are compared, ordered and hashed by ``name``.
    """

    def __init__(self, name, perform_work):
        """
        :type name: str
        :type perform_work: (Any) -> Any
        """
        self.name = name
        self.perform_work = perform_work

    def __add__(self, other):
        # 0 is the additive identity, so ``0 + worker`` yields the bare worker.
        if other == 0:
            return self
        # ``other`` is a work descriptor: convert it into a cost before recording it, so
        # the very first edge a worker traverses is costed like every later one.
        return WorkPerformed(self.perform_work(other), self)

    def __radd__(self, other):
        return self + other

    def __eq__(self, other):
        if isinstance(other, Worker):
            return self.name == other.name
        # Let Python try the reflected comparison / fall back to identity instead of
        # crashing on e.g. ``worker == 0``.
        return NotImplemented

    def __hash__(self):
        # Equality is by name, so hashing must be by name too (defining __eq__ alone
        # would make instances unhashable).
        return hash(self.name)

    def __lt__(self, other):
        if isinstance(other, Worker):
            return self.name < other.name
        return NotImplemented

    def __str__(self):
        return "Worker({})".format(self.name)

    __repr__ = __str__
class WorkPerformed(namedtuple("WorkPerformed", ("cost", "worker"))):
    """
    Running total of the work done by one Worker along a route.

    ``cost`` is the accumulated cost so far and ``worker`` is the Worker that performs each
    further step. Adding a work descriptor converts it to a cost with
    ``worker.perform_work`` and returns a new WorkPerformed holding the accumulated total;
    adding 0 returns the instance itself.
    """

    def __add__(self, other):
        if other == 0:
            return self
        step_cost = self.worker.perform_work(other)
        return WorkPerformed(self.cost + step_cost, self.worker)

    def __radd__(self, other):
        # ``descriptor + wp`` is handled the same way as ``wp + descriptor``.
        return self + other
def with_initial(initial):
    """
    Build a decorator that seeds a graph search from several costed starting points.

    :param initial: iterable of (start node, worker) tuples. It is handed back verbatim as
        the edge list of the INITIAL_START sentinel node, so each entry's second element is
        the cost of "reaching" that start node.
    :return: a decorator. The decorated edgefinder returns ``initial`` when asked for the
        neighbours of INITIAL_START and defers to the wrapped edgefinder for every other
        node. To populate the traversal with these starts, pass INITIAL_START as the
        starting node. If the initial costs are Workers, the wrapped edgefinder should
        normally return edge costs that are compatible work descriptors.

    NOTE(review): with more than one Worker in ``initial``, bare Worker costs and
    WorkPerformed costs can sit in the same heap, and those two types define no ordering
    against each other — verify multi-worker searches before relying on them.
    """
    def decorator(edgefinder):
        def seeded_edgefinder(node):
            return initial if node is INITIAL_START else edgefinder(node)
        return seeded_edgefinder
    return decorator
def dijkstra(start, destination, edgefinder=lambda node: ((x, 1) for x in node)):
    """
    Find the cheapest route from a single start node to a single destination node.

    :param start: the node to search from
    :param destination: the node to reach (compared with ``==``)
    :param edgefinder: A function that returns an iterable of tuples of
        (neighbor, distance) from the node it is passed. The default treats each node as
        an iterable of its neighbours, every edge costing 1.
    :return: a (total cost, path) tuple as produced by dijkstra_first: path is a nested
        tuple (destination, (previous, (..., (start, ())))), and (None, ()) is returned
        when the destination is unreachable.
    """
    def is_destination(node):
        return node == destination

    return dijkstra_first([start], is_destination, edgefinder)
def dijkstra_first(starts, valid_destination, edgefinder=lambda node: ((x, 1) for x in node)):
    """
    Search outward from every start node and stop at the cheapest valid destination.

    :param starts: iterable of any type, only used as keys. Each start is reached at cost 0.
    :param valid_destination: a predicate function returning true for any node that is a
        suitable destination
    :param edgefinder: A function that returns an iterable of tuples of
        (neighbor, distance) from the node it is passed
    :return: (total cost, path) for the first valid destination settled, where path is a
        nested tuple (destination, (previous, (..., (start, ())))); (None, ()) if no valid
        destination is reachable.
    """
    settled = set()
    tiebreak = count()
    frontier = []

    def candidates():
        # All start nodes are offered at cost 0 before anything is popped off the heap.
        for origin in starts:
            yield 0, None, origin, ()
        while frontier:
            yield heappop(frontier)

    # Heap entries: (cost so far, unique counter so the heap never has to compare nodes,
    # node, path so far as nested (node, (previous, (...,))) tuples).
    for cost, _, node, trail in candidates():
        if node in settled:
            continue
        trail = (node, trail)
        if valid_destination(node):
            return cost, trail
        settled.add(node)
        for neighbor, step in edgefinder(node):
            if neighbor not in settled:
                heappush(frontier, (cost + step, next(tiebreak), neighbor, trail))
    return None, ()  # no path exists
def dijkstra_multiple(starts, valid_destination, num_to_find, edgefinder=lambda node: ((x, 1) for x in node)):
    """
    Find the cheapest paths to up to ``num_to_find`` distinct valid destinations.

    :param starts: iterable of any type, only used as keys. Each start is reached at cost 0.
    :param valid_destination: a predicate function returning true for any node that is a
        suitable destination
    :param num_to_find: the maximum number of destinations to return. The search stops as
        soon as this many have been found; values <= 0 return an empty result immediately.
    :param edgefinder: A function that returns an iterable of tuples of
        (neighbor, distance) from the node it is passed
    :return: the shortest 'num_to_find' paths from any starting node to any valid
        destination. Keys are the endpoint, values are (total cost, path) tuples (path is a
        nested tuple (endpoint, (previous, (..., (start, ()))))), and the whole result is an
        ordered dictionary from least to greatest total cost. Fewer entries are returned
        when fewer valid destinations are reachable.
    """
    results = OrderedDict()
    if num_to_find <= 0:
        return results
    visited = set()
    index = count()
    heap = []

    def process():
        yield from ((0, None, seed, ()) for seed in starts)
        while heap:
            yield heappop(heap)

    # Heap values are: distance value, a unique counter for sorting, the next place to go,
    # and the (path, (so, (far,)))
    for dist, _, node, path in process():
        if node not in visited:
            path = (node, path)
            if valid_destination(node):
                results[node] = (dist, path)
                # Stop the moment the quota is met. Checking here (rather than only in the
                # heap loop) also bounds destinations found among the start nodes.
                if len(results) >= num_to_find:
                    break
            visited.add(node)
            for neighbor, dist_to_neighbor in edgefinder(node):
                if neighbor not in visited:
                    heappush(heap, (dist + dist_to_neighbor, next(index), neighbor, path))
    return results
def dijkstra_set(starts, destinations, edgefinder=lambda node: ((x, 1) for x in node)):
    """
    Find the cheapest path from any start node to each member of a destination set.

    :param starts: an iterable of starting nodes
    :param destinations: a set of destinations (any container supporting ``in`` and
        ``len()``)
    :param edgefinder: A function that returns an iterable of tuples of
        (neighbor, distance) from the node it is passed
    :return: the ordered dictionary produced by dijkstra_multiple, mapping each reachable
        destination to its (total cost, path) tuple
    """
    def is_wanted(node):
        return node in destinations

    wanted_count = len(destinations)
    return dijkstra_multiple(starts, is_wanted, wanted_count, edgefinder)
def dijkstra_full(starts, edgefinder=lambda node: ((x, 1) for x in node)):
    """
    Compute the cheapest path from the nearest start node to every reachable node.

    :param starts: iterable of any type, only used as keys. Each start is reached at cost 0.
    :param edgefinder: A function that returns an iterable of tuples of
        (neighbor, distance) from the node it is passed
    :rtype: OrderedDict
    :return: an ordered dictionary with one entry per reachable node (the start nodes
        included). Keys are the nodes; values are (total cost, path) tuples, where path is
        a nested tuple (node, (previous, (..., (start, ())))) — pass it to convert_path for
        a start-to-node list. Entries are ordered from least to greatest total cost
        (assuming non-negative edge costs).
    """
    visited = set()
    index = count()
    heap = []
    results = OrderedDict()

    def process():
        yield from ((0, None, seed, ()) for seed in starts)
        while heap:
            yield heappop(heap)

    # Heap values are: distance value, a unique counter for sorting, the next place to go,
    # and the (path, (so, (far,)))
    for dist, _, node, path in process():
        if node not in visited:
            path = (node, path)
            results[node] = (dist, path)
            visited.add(node)
            for neighbor, dist_to_neighbor in edgefinder(node):
                if neighbor not in visited:
                    heappush(heap, (dist + dist_to_neighbor, next(index), neighbor, path))
    return results
def convert_path(path):
    """
    Flatten a nested (node, (previous, (..., (start, ())))) path into a list.

    :param path: a path as produced by the dijkstra functions; () for an empty path
    :return: the nodes in travel order, from the start node to the final node
    """
    walked_backwards = []
    link = path
    while link:
        walked_backwards.append(link[0])
        link = link[1]
    return list(reversed(walked_backwards))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment