Created
May 11, 2013 04:40
-
-
Save KJTsanaktsidis/5558922 to your computer and use it in GitHub Desktop.
A graph that can run Dijkstra's algorithm, plus a little extra
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv | |
import functools | |
import sys | |
import heapq | |
class AdjacencyGraph():
    """
    A directed, weighted graph stored as an adjacency list.

    ``adjacency_list`` maps every node name to a list of ``(destname, cost)``
    tuples, one per outgoing link. Two searches are provided:
    ``single_min_cost_search`` (Dijkstra's algorithm) and
    ``multi_min_cost_search`` (a variant that also tracks how many links
    each path uses).
    """

    @functools.total_ordering
    class SearchNode():
        """
        Per-vertex bookkeeping for single_min_cost_search.

        name  -- the vertex name
        dist  -- best known cost from the source (sys.maxsize until reached)
        path  -- name of the previous vertex on the best known path, or None
        known -- True once the vertex has been popped off the heap

        Instances compare on dist only, so they can sit directly on a heap.
        """

        def __init__(self, name, path=None, dist=sys.maxsize, known=False):
            self.name = name
            self.dist = dist
            self.path = path
            self.known = known

        def __eq__(self, other):
            return self.dist == other.dist

        def __lt__(self, other):
            return self.dist < other.dist

    @functools.total_ordering
    class MultiSearchNode():
        """
        Per-vertex bookkeeping for multi_min_cost_search.

        name  -- the vertex name
        dist  -- dict: number of links used -> best known cost with that many
        path  -- dict: number of links used -> previous vertex name (or None)
        known -- True once the vertex has been popped off the heap

        Instances compare on their cheapest entry in dist. A node with no
        entries at all sorts after every node that has at least one, and is
        equal only to other nodes with no entries.
        """

        def __init__(self, name, known=False):
            self.name = name
            self.dist = {}
            self.path = {}
            self.known = known

        @staticmethod
        def _cheapest(node):
            """Return node's smallest cost, or None if it has none yet."""
            return min(node.dist.values()) if node.dist else None

        def __eq__(self, other):
            # None == None covers "both empty"; None == number is False
            return self._cheapest(self) == self._cheapest(other)

        def __lt__(self, other):
            mine = self._cheapest(self)
            theirs = self._cheapest(other)
            if mine is None:
                # an unreached node is never smaller than anything
                return False
            if theirs is None:
                return True
            return mine < theirs

    def __init__(self):
        self.adjacency_list = dict()

    def insert_node(self, name):
        """
        Add a node called name to the adjacency list.
        If name is already present, raise a ValueError
        """
        if name in self.adjacency_list:
            raise ValueError('{} is already present in the graph'.format(name))
        self.adjacency_list[name] = list()

    def insert_link(self, srcname, destname, cost):
        """
        Add a directional link from srcname to destname with the given cost.
        If either node is not in the graph, raise a KeyError
        """
        if destname not in self.adjacency_list:
            raise KeyError('{} not in the graph'.format(destname))
        # a missing srcname raises KeyError by itself on this lookup
        self.adjacency_list[srcname].append((destname, cost))

    def _check_endpoints(self, srcname, destname):
        """Raise a KeyError if srcname or destname is not a node of the graph."""
        # without this check a missing node would look like an unconnected one
        if srcname not in self.adjacency_list:
            raise KeyError('{} not in the graph'.format(srcname))
        if destname not in self.adjacency_list:
            raise KeyError('{} not in the graph'.format(destname))

    @staticmethod
    def _prune_dominated(node):
        """
        Drop entries of a MultiSearchNode that are not worth keeping.

        Walking the entries by increasing link count, an entry survives only
        if it is strictly cheaper than every entry that uses fewer links.
        """
        cur_min = sys.maxsize
        for k in sorted(node.dist):
            if node.dist[k] < cur_min:
                cur_min = node.dist[k]
            else:
                del node.dist[k]
                del node.path[k]

    def single_min_cost_search(self, srcname, destname):
        """
        Search for a min cost path from srcname to destname (Dijkstra).

        Returns (list of node names from srcname to destname, total cost).
        If srcname == destname the result is ([srcname], 0).

        Raise a KeyError if either node is not present, or if destname cannot
        be reached from srcname. The unreachable case also surfaced as a
        KeyError before, so existing handlers keep working.
        """
        self._check_endpoints(srcname, destname)

        # one search node per vertex; only the source starts at distance 0
        all_verts = {name: self.SearchNode(name) for name in self.adjacency_list}
        all_verts[srcname].dist = 0

        # the heap holds references to the same objects as all_verts, so
        # updating a node through all_verts updates it on the heap too
        unvisited_verts = list(all_verts.values())
        heapq.heapify(unvisited_verts)
        while unvisited_verts:
            cur_vert = heapq.heappop(unvisited_verts)
            cur_vert.known = True
            if cur_vert.name == destname:
                # a known vertex is never updated again, so we are done
                break
            if cur_vert.path is None and cur_vert.name != srcname:
                # never relaxed: this vertex and everything left on the
                # heap is unreachable from the source
                break
            for vname, cost in self.adjacency_list[cur_vert.name]:
                v = all_verts[vname]
                if v.known:
                    continue
                if cur_vert.dist + cost < v.dist:
                    v.dist = cur_vert.dist + cost
                    v.path = cur_vert.name
            # keys of nodes on the heap were mutated, so restore the invariant
            heapq.heapify(unvisited_verts)

        dest_vert = all_verts[destname]
        if destname != srcname and dest_vert.path is None:
            raise KeyError('no path from {} to {}'.format(srcname, destname))

        # follow the back-pointers from the destination, then flip the list
        name_list = [destname]
        while name_list[-1] != srcname:
            name_list.append(all_verts[name_list[-1]].path)
        name_list.reverse()
        return name_list, dest_vert.dist

    def multi_min_cost_search(self, srcname, destname):
        """
        Search for a dict of min cost paths from srcname to destname.

        Each key is a number of links, and the value is the
        (list of node names, total cost) found for getting from srcname to
        destname using exactly that many links. An entry is only kept when it
        is strictly cheaper than every entry with fewer links. If destname
        cannot be reached the dict is empty; if srcname == destname it is
        {0: ([srcname], 0)}.

        A vertex is finalised the first time it is popped off the heap and
        is not updated afterwards, so alternatives that would only reach it
        later are not recorded.

        Raise a KeyError if either node is not present.
        """
        self._check_endpoints(srcname, destname)

        all_verts = {name: self.MultiSearchNode(name) for name in self.adjacency_list}
        all_verts[srcname].dist[0] = 0
        all_verts[srcname].path[0] = None

        # the heap holds references to the same objects as all_verts
        unvisited_verts = list(all_verts.values())
        heapq.heapify(unvisited_verts)
        while unvisited_verts:
            cur_vert = heapq.heappop(unvisited_verts)
            if not cur_vert.dist:
                # unreached nodes sort last, so nothing reachable is left
                break
            cur_vert.known = True
            self._prune_dominated(cur_vert)
            if cur_vert.name == destname:
                # the destination is final (and already pruned) once popped
                break
            for vname, cost in self.adjacency_list[cur_vert.name]:
                v = all_verts[vname]
                if v.known:
                    continue
                # every surviving entry of cur_vert extends to v with one more link
                for k, dist in cur_vert.dist.items():
                    if dist + cost < v.dist.get(k + 1, sys.maxsize):
                        v.dist[k + 1] = dist + cost
                        v.path[k + 1] = cur_vert.name
            # keys of nodes on the heap were mutated, so restore the invariant
            heapq.heapify(unvisited_verts)

        # an entry with k links has exactly k back-pointers to follow, the
        # i-th one stored under link count i on the vertex it leads from
        rlist = {}
        for k, total_cost in all_verts[destname].dist.items():
            name_list = [destname]
            for links in range(k, 0, -1):
                name_list.append(all_verts[name_list[-1]].path[links])
            name_list.reverse()
            rlist[k] = (name_list, total_cost)
        return rlist
def graph_from_csv(data_source):
    """
    Generates an AdjacencyGraph from a csv stream.

    Each row is read as src,dest,cost: the two names are stripped of
    surrounding whitespace and cost is converted with int(). Nodes are
    created the first time they are seen, and every row adds a link in both
    directions with the same cost. Blank lines are skipped.

    data_source can be anything csv.reader accepts, i.e. any iterable of
    lines such as an open file or a list of strings.
    """
    reader = csv.reader(data_source)
    graph = AdjacencyGraph()
    for row in reader:
        if not row:
            # csv.reader yields an empty list for a blank line
            continue
        src = row[0].strip()
        dest = row[1].strip()
        cost = int(row[2])
        # insert_node raises on duplicates, so only add unseen names
        for name in (src, dest):
            if name not in graph.adjacency_list:
                graph.insert_node(name)
        graph.insert_link(src, dest, cost)
        graph.insert_link(dest, src, cost)
    return graph
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment