Last active
August 18, 2021 09:11
-
-
Save f0lie/b9a57be922f02671dd95a18acc71f0ad to your computer and use it in GitHub Desktop.
Python 3: Clean implementations of heap-based (heapq) Dijkstra, Bellman-Ford, and SPFA
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict, deque | |
import heapq | |
from typing import OrderedDict | |
def create_graph(matrix):
    """Convert a weighted adjacency matrix into an adjacency-list mapping.

    ``matrix[row][col] > 0`` is read as an edge ``row -> col`` with that
    weight; zero (or negative) entries mean "no edge".

    Every row index becomes a key, even rows with no outgoing edges, so
    ``len(graph)`` equals the number of rows. The shortest-path functions
    use ``len(graph)`` as the vertex count, so dropping sink vertices here
    would make them index out of range.

    Args:
        matrix: list of rows of numeric edge weights (may be empty).

    Returns:
        ``defaultdict(list)`` mapping ``row -> [[col, weight], ...]``.
    """
    graph = defaultdict(list)
    for row, weights in enumerate(matrix):
        edges = graph[row]  # touch the key so sink vertices are still counted
        for col, weight in enumerate(weights):
            if weight > 0:
                edges.append([col, weight])
    return graph
def get_path(path, source, end):
    """Rebuild the node sequence from ``source`` to ``end``.

    Args:
        path: predecessor list as returned by the shortest-path functions in
            this module; ``path[i]`` is the node before ``i`` on the best
            path, or ``-1`` if there is none.
        source: the start node the predecessor list was computed from.
        end: the destination node.

    Returns:
        ``[source, ..., end]``, or ``[]`` when ``end`` is unreachable (its
        predecessor chain hits ``-1`` before reaching ``source``).

    Raises:
        ValueError: if the predecessor chain loops without reaching ``source``.
    """
    found_path = [end]
    current = end
    while current != source:
        current = path[current]
        if current == -1:
            # Unreachable. Without this check, path[-1] would silently index
            # the last element and produce a bogus (or endless) walk.
            return []
        # A simple path visits each node at most once, so a longer chain loops.
        if len(found_path) >= len(path):
            raise ValueError("predecessor list contains a cycle")
        found_path.append(current)
    found_path.reverse()
    return found_path
def dijsktra(graph, source):
    """Single-source shortest paths using a binary heap (Dijkstra).

    https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm
    https://cs.stackexchange.com/questions/118388/dijkstra-without-decrease-key

    Instead of a decrease-key operation, a fresh ``(distance, node)`` entry
    is pushed every time a node's distance improves. Entries that are stale
    when popped (larger than the stored distance) are skipped.

    Args:
        graph: indexable by node, ``graph[node] -> [(neighbor, weight), ...]``,
            with nodes numbered ``0 .. len(graph) - 1``. Every node must be
            present (including nodes with no outgoing edges). Weights are
            assumed non-negative; use ``bellman_ford``/``spfa`` otherwise.
        source: index of the start node.

    Returns:
        ``(distance, path)``. ``distance[i]`` is the shortest distance from
        ``source`` to ``i`` (``float("inf")`` if unreachable). ``path[i]`` is
        the predecessor of ``i`` on that path (``-1`` for the source and for
        unreachable nodes).
    """
    distance = [float("inf")] * len(graph)
    # Must seed the actual source; seeding index 0 broke every source != 0.
    distance[source] = 0
    path = [-1] * len(graph)
    # Min-heap of (distance, node); heap[0] is the closest unsettled entry.
    heap = [(0, source)]
    while heap:
        dist_from_source, node = heapq.heappop(heap)
        # Stale entry: a shorter distance to `node` was already processed.
        if dist_from_source > distance[node]:
            continue
        for neighbor, weight in graph[node]:
            candidate = dist_from_source + weight
            # Record the edge only if it gives a strictly shorter path.
            if candidate < distance[neighbor]:
                distance[neighbor] = candidate
                path[neighbor] = node
                heapq.heappush(heap, (candidate, neighbor))
    return distance, path
def bellman_ford(graph, source):
    """Single-source shortest paths that tolerate negative edge weights.

    Bellman-Ford can be thought of as brute force: it relaxes every edge,
    repeatedly. After pass ``i``, each distance is at least as good as the
    best path using at most ``i`` edges. A shortest path without negative
    cycles has at most ``V - 1`` edges, so ``V - 1`` passes suffice. If an
    extra pass can still improve something, a negative cycle is reachable.

    Args:
        graph: mapping ``node -> [(neighbor, weight), ...]`` with nodes
            numbered ``0 .. len(graph) - 1``; every node must be a key.
        source: index of the start node.

    Returns:
        ``(distance, path)``. ``distance[i]`` is the shortest distance to
        ``i`` (``float("inf")`` if unreachable). ``path[i]`` is its
        predecessor (``-1`` if none).

    Raises:
        ValueError: if a negative-weight cycle is reachable from ``source``
            (shortest distances are then undefined).
    """
    vertex_count = len(graph)
    distance = [float("inf")] * vertex_count
    distance[source] = 0
    path = [-1] * vertex_count

    def relax_all_edges():
        """Run one pass over every edge; return True if any distance improved."""
        improved = False
        for frm, neighbors in graph.items():
            for to, weight in neighbors:
                if distance[frm] + weight < distance[to]:
                    distance[to] = distance[frm] + weight
                    path[to] = frm
                    improved = True
        return improved

    for _ in range(vertex_count - 1):
        if not relax_all_edges():
            break  # converged: later passes could not change anything
    else:
        # All V - 1 passes ran; any further improvement implies a negative cycle.
        if relax_all_edges():
            raise ValueError("graph contains a negative-weight cycle reachable from source")
    return distance, path
def spfa(graph, source):
    """Shortest Path Faster Algorithm: queue-driven Bellman-Ford.

    https://en.wikipedia.org/wiki/Shortest_Path_Faster_Algorithm

    Only nodes whose distance just improved are queued, and a node already
    waiting in the queue is not added twice.

    Args:
        graph: indexable by node, ``graph[node] -> [(neighbor, weight), ...]``,
            with nodes numbered ``0 .. len(graph) - 1``; every node must be
            present. Negative weights are allowed.
        source: index of the start node.

    Returns:
        ``(distance, path)``. ``distance[i]`` is the shortest distance to
        ``i`` (``float("inf")`` if unreachable). ``path[i]`` is its
        predecessor (``-1`` if none).

    Raises:
        ValueError: if a negative-weight cycle is reachable from ``source``.
            Without this check the loop would never terminate.
    """
    import collections

    vertex_count = len(graph)
    distance = [float("inf")] * vertex_count
    distance[source] = 0
    path = [-1] * vertex_count
    # edge_count[i]: number of edges on the best path to i found so far.
    # Absent negative cycles this never reaches vertex_count.
    edge_count = [0] * vertex_count
    # OrderedDict gives FIFO order plus O(1) membership tests; values unused.
    queue = collections.OrderedDict()
    queue[source] = None
    while queue:
        # last=False pops the oldest entry (FIFO); the default pops LIFO.
        current, _ = queue.popitem(last=False)
        for neighbor, weight in graph[current]:
            candidate = distance[current] + weight
            if candidate < distance[neighbor]:
                distance[neighbor] = candidate
                path[neighbor] = current
                edge_count[neighbor] = edge_count[current] + 1
                if edge_count[neighbor] >= vertex_count:
                    raise ValueError("graph contains a negative-weight cycle reachable from source")
                if neighbor not in queue:
                    queue[neighbor] = None
    return distance, path
def spfa_2(graph, source):
    """SPFA variant using a ``deque`` queue and a dict of distances.

    Compared with ``spfa``, distances are kept in a dict that only contains
    reached nodes (no ``inf`` placeholders). A set tracks which nodes are
    currently queued, so a node is never waiting in the queue twice.

    Args:
        graph: indexable by node, ``graph[node] -> [(neighbor, weight), ...]``,
            with nodes numbered ``0 .. len(graph) - 1``; every node must be
            present. Negative weights are allowed.
        source: index of the start node.

    Returns:
        ``(distance, path)``. ``distance`` is a dict ``node -> shortest
        distance`` containing only nodes reachable from ``source``. ``path[i]``
        is the predecessor of ``i`` (``-1`` if none).

    Raises:
        ValueError: if a negative-weight cycle is reachable from ``source``.
            Without this check the loop would never terminate.
    """
    vertex_count = len(graph)
    distance = {source: 0}
    path = [-1] * vertex_count
    # Number of edges on the best path found so far; >= vertex_count => negative cycle.
    edge_count = {source: 0}
    queue = deque([source])
    queued = {source}
    while queue:
        current = queue.popleft()
        queued.discard(current)
        for neighbor, weight in graph[current]:
            candidate = distance[current] + weight
            if neighbor not in distance or candidate < distance[neighbor]:
                distance[neighbor] = candidate
                path[neighbor] = current
                edge_count[neighbor] = edge_count[current] + 1
                if edge_count[neighbor] >= vertex_count:
                    raise ValueError("graph contains a negative-weight cycle reachable from source")
                # Skip re-enqueueing a node that is already waiting; it will
                # see the improved distance when it is popped.
                if neighbor not in queued:
                    queued.add(neighbor)
                    queue.append(neighbor)
    return distance, path
if __name__ == "__main__":
    # Symmetric (undirected) weighted adjacency matrix; 0 means "no edge".
    # Source of the example:
    # https://www.geeksforgeeks.org/dijkstras-shortest-path-algorithm-greedy-algo-7/
    input_graph = [
        [0, 4, 0, 0, 0, 0, 0, 8, 0],
        [4, 0, 8, 0, 0, 0, 0, 11, 0],
        [0, 8, 0, 7, 0, 4, 0, 0, 2],
        [0, 0, 7, 0, 9, 14, 0, 0, 0],
        [0, 0, 0, 9, 0, 10, 0, 0, 0],
        [0, 0, 4, 14, 10, 0, 2, 0, 0],
        [0, 0, 0, 0, 0, 2, 0, 1, 6],
        [8, 11, 0, 0, 0, 0, 1, 0, 7],
        [0, 0, 2, 0, 0, 0, 6, 7, 0],
    ]
    graph = create_graph(input_graph)

    # Run each algorithm from node 0 and report the route and cost to node 8.
    for algorithm in (dijsktra, bellman_ford, spfa, spfa_2):
        distance, path = algorithm(graph, 0)
        print("Path from 0 to 8", get_path(path, 0, 8))
        print("Distance from 0 to 8:", distance[8])

    # Expected output (identical for all four algorithms):
    # Path from 0 to 8 [0, 1, 2, 8]
    # Distance from 0 to 8: 14
    # Path from 0 to 8 [0, 1, 2, 8]
    # Distance from 0 to 8: 14
    # Path from 0 to 8 [0, 1, 2, 8]
    # Distance from 0 to 8: 14
    # Path from 0 to 8 [0, 1, 2, 8]
    # Distance from 0 to 8: 14
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I wrote these implementations because many of the ones I found elsewhere weren't clear or clean. I have used several of them to solve LeetCode problems such as Network Delay Time, so I know they work.