Skip to content

Instantly share code, notes, and snippets.

@gurunars
Last active October 16, 2016 19:14
Show Gist options
  • Save gurunars/e1652a7ea319fdca69d3b0a18d9466fc to your computer and use it in GitHub Desktop.
Save gurunars/e1652a7ea319fdca69d3b0a18d9466fc to your computer and use it in GitHub Desktop.
Dijkstra's algorithm implementation in Python
"""
Glossary used in the functions below:
Node can be virtually any hashable datatype.
:param start: starting node
:param end: ending node
:param weighted_graph: {"node1": {"node2": "weight", ...}, ...}
"""
def _get_shortest_path_tree(weighted_graph, start):
"""
Uses Djikstra algorithm to calculate the shortest path tree.
https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm
Return {"node1": "theBestNextNodeToWardsStart", ...}
"""
unvisited_nodes = set(weighted_graph)
tentative_weights = {}
tentative_parents = {}
for node in weighted_graph:
tentative_weights[node] = float('inf')
tentative_parents[node] = None
# Mark initial node as visited
tentative_weights[start] = 0
unvisited_nodes.discard(start)
current = start
while unvisited_nodes:
edges = weighted_graph[current]
unvisited_neighbours = set(edges).intersection(unvisited_nodes)
for neighbour in unvisited_neighbours:
tentative_parents[neighbour], tentative_weights[neighbour] = min(
(tentative_parents[neighbour], tentative_weights[neighbour]),
(current, tentative_weights[current] + edges[neighbour]),
key=lambda pair: pair[1]
)
unvisited_nodes.discard(current)
# Traversal has finished
if not unvisited_nodes:
break
# The next node should be the unvisited one with the smallest weight
current, smallest_weight = min(
[(node, tentative_weights[node]) for node in unvisited_nodes],
key=lambda pair: pair[1]
)
# This essentially means that the nodes are not reachable
if smallest_weight == float('inf'):
break
return tentative_parents
def get_shortest_path(weighted_graph, start, end):
    """
    Return the shortest path from start to end as a list of nodes.

    :param weighted_graph: {"node1": {"node2": weight, ...}, ...}
    :param start: starting node
    :param end: ending node
    :return: ["START", ... nodes between ..., "END"]; [start] when
        start == end.
    :raises ValueError: if end is not reachable from start.
    """
    tentative_parents = _get_shortest_path_tree(weighted_graph, start)
    # start is the only reachable node without a parent, so a missing parent
    # means "unreachable" for every other node.
    if tentative_parents[end] is None and end != start:
        raise ValueError("Node {} is unreachable from node {}".format(
            end, start
        ))
    path = []
    cursor = end
    # Compare against None explicitly: falsy nodes such as 0 or "" are valid.
    while cursor is not None:
        path.append(cursor)
        cursor = tentative_parents[cursor]
    path.reverse()
    return path
import unittest
from path_finder import get_shortest_path
class PathFinderTest(unittest.TestCase):
    """Tests for get_shortest_path."""

    # Graph shared by the string-node tests; its cheapest A -> G route is
    # A -> C -> E -> G with total weight 3. Tests must not mutate it.
    SIMPLE_GRAPH = {
        "A": {"B": 1, "C": 1, "D": 1},
        "B": {"E": 2},
        "C": {"E": 1, "F": 2},
        "D": {"F": 2},
        "F": {"G": 2},
        "E": {"G": 1},
        "G": {},
    }

    def test_simple_unreachable(self):
        # "H" is a node of the graph but has no incoming edges.
        graph = dict(self.SIMPLE_GRAPH, H={})
        # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
        with self.assertRaisesRegex(
                ValueError, "Node H is unreachable from node A"):
            get_shortest_path(graph, "A", "H")

    def test_simple(self):
        path = get_shortest_path(self.SIMPLE_GRAPH, "A", "G")
        self.assertEqual(["A", "C", "E", "G"], path)

    def test_complex(self):
        graph = {
            1: {2: 2, 7: 1},
            2: {3: 3, 8: 3},
            3: {8: 1, 4: 1},
            4: {9: 1, 5: 2},
            5: {10: 2, 6: 2},
            7: {8: 3, 11: 1},
            11: {8: 1, 12: 2},
            12: {9: 1, 13: 1},
            13: {10: 2, 6: 1},
            8: {9: 2},
            9: {10: 1},
            10: {6: 1},
            6: {},
        }
        path = get_shortest_path(graph, 1, 6)
        self.assertEqual([1, 7, 11, 12, 13, 6], path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment