Created
April 2, 2020 09:55
-
-
Save 270ajay/1c8ae39662833cdae2966c439ccd6161 to your computer and use it in GitHub Desktop.
finding minimum spanning tree (kruskal and prim algorithms)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import heapq | |
'''course: https://www.coursera.org/learn/algorithms-on-graphs#about''' | |
class DisjointSets: | |
'''Finds quickly if two points belong to same set''' | |
def __init__(self, totalNumOfSets): | |
self.parent = [None] * totalNumOfSets | |
self.rank = [None] * totalNumOfSets | |
def makeSet(self, id): | |
self.parent[id] = id | |
self.rank[id] = 0 | |
def find(self, id): | |
while id != self.parent[id]: | |
id = self.parent[id] | |
return id | |
def findFast(self, id): | |
'''path compression heuristic. faster than find() in future calls to this method | |
because we update the parent info''' | |
if id != self.parent[id]: | |
self.parent[id] = self.findFast(self.parent[id]) | |
return self.parent[id] | |
def union(self, id1, id2): | |
'''union by rank heuristic. rank stores the height. | |
update the id of the element that has smaller rank''' | |
rootId1 = self.findFast(id1) | |
rootId2 = self.findFast(id2) | |
if rootId1 == rootId2: | |
return | |
if self.rank[rootId1] > self.rank[rootId2]: | |
self.parent[rootId2] = rootId1 | |
else: | |
self.parent[rootId1] = rootId2 | |
if self.rank[rootId1] == self.rank[rootId2]: | |
self.rank[rootId2] += 1 | |
############################################################################# | |
class PriorityQueue(): | |
'''credits: Eugene Yarmash | |
https://stackoverflow.com/questions/46636656/python-heapq-replace-priority''' | |
REMOVED = "REMOVED" | |
def __init__(self): | |
self.priorityQueue = [] | |
self.nodeInfo = {} | |
def addNodeWithPriority(self, node, priority): | |
'''adds or updates the node in priority queue''' | |
if node in self.nodeInfo: | |
self.removeNode(node) | |
entry = [priority, node] | |
self.nodeInfo[node] = entry | |
heapq.heappush(self.priorityQueue, entry) | |
def removeNode(self, node): | |
removedNode = self.nodeInfo.pop(node) | |
removedNode[1] = PriorityQueue.REMOVED | |
def extractMin(self): | |
while not self.isEmpty(): | |
priority, node = heapq.heappop(self.priorityQueue) | |
if node != PriorityQueue.REMOVED: | |
del self.nodeInfo[node] | |
return node | |
def isEmpty(self): | |
if self.priorityQueue: | |
return False | |
return True | |
######################################################################### | |
def kruskal(adjacencyList, edgeWeights): | |
'''Finds minimum spanning tree i.e., minimum total edge length required to connect all nodes. | |
Repeatedly adds next smallest edge which doesn't create cycle between nodes''' | |
numberOfNodes = len(adjacencyList) | |
solutionSet = set() | |
totalSpan = 0 | |
disjointSet = DisjointSets(totalNumOfSets=numberOfNodes) | |
id = 0 | |
nodeIdMap = {} | |
for node in adjacencyList: | |
nodeIdMap[node] = id | |
disjointSet.makeSet(id) | |
id+=1 | |
for edge in sorted(edgeWeights, key=edgeWeights.get): | |
fromNode = edge[0] | |
toNode = edge[1] | |
if disjointSet.find(nodeIdMap[fromNode]) != disjointSet.find(nodeIdMap[toNode]): | |
solutionSet.add(edge) | |
totalSpan += edgeWeights[edge] | |
disjointSet.union(nodeIdMap[fromNode], nodeIdMap[toNode]) | |
return solutionSet, totalSpan | |
def prim(adjacencyList, edgeWeights, nodeName): | |
'''Finds minimum spanning tree i.e., minimum total edge length required to connect all nodes | |
Repeatedly adds nodes to the solution tree using the smallest edge''' | |
cost = {} | |
parent = {} | |
nodesInSolTree = set() | |
INFINITY = 10000 | |
for node in adjacencyList: | |
cost[node] = INFINITY | |
parent[node] = None | |
cost[nodeName] = 0 | |
del parent[nodeName] | |
priorityQueue = PriorityQueue() | |
for node in adjacencyList: | |
priorityQueue.addNodeWithPriority(node, cost[node]) | |
while not priorityQueue.isEmpty(): | |
fromNode = priorityQueue.extractMin() | |
if fromNode == None: | |
break | |
nodesInSolTree.add(fromNode) | |
for toNode in adjacencyList[fromNode]: | |
if (toNode not in nodesInSolTree) and (cost[toNode] > edgeWeights[(fromNode, toNode)]): | |
cost[toNode] = edgeWeights[(fromNode, toNode)] | |
parent[toNode] = fromNode | |
priorityQueue.addNodeWithPriority(toNode, cost[toNode]) | |
totalSpan = sum(edgeWeights[(parent[toNode], toNode)] for toNode in parent) | |
return parent, totalSpan | |
if __name__ == '__main__': | |
adjacencyList = {'A': {'B', 'D', 'E'}, | |
'B': {'A', 'C', 'E', 'F'}, | |
'C': {'B', 'F'}, | |
'D': {'A', 'E'}, | |
'E': {'A', 'B', 'D', 'F'}, | |
'F': {'B', 'C', 'E'}} | |
edgeWeights = {('A','B'): 4, | |
('A','D'): 2, | |
('A','E'): 1, | |
('B','A'): 4, | |
('B','C'): 8, | |
('B','E'): 5, | |
('B','F'): 6, | |
('C','B'): 8, | |
('C','F'): 1, | |
('D','A'): 2, | |
('D','E'): 3, | |
('E','A'): 1, | |
('E','B'): 5, | |
('E','D'): 3, | |
('E','F'): 9, | |
('F','B'): 6, | |
('F','C'): 1, | |
('F','E'): 9} | |
solutionSet, totalSpan = kruskal(adjacencyList, edgeWeights) | |
print('\nKruskal: Minimum Spanning Tree:', solutionSet) | |
print('Kruskal: Total Span:', totalSpan) | |
parentOfNodesPrim, totalSpanPrim = prim(adjacencyList, edgeWeights, 'A') | |
print('\nPrim: Minimum Spanning Tree:', parentOfNodesPrim) | |
print('Prim: Total Span:', totalSpan) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment