Skip to content

Instantly share code, notes, and snippets.

@james4388
Last active November 8, 2019 21:03
Show Gist options
  • Save james4388/d696749e9c57e24591213fe1660be7d7 to your computer and use it in GitHub Desktop.
Save james4388/d696749e9c57e24591213fe1660be7d7 to your computer and use it in GitHub Desktop.
Disjoint set (union-find)
class DisjointSet(object):
    """Disjoint-set (union-find) structure over the integers ``0 .. size-1``.

    Every node starts in its own singleton set.  ``union`` merges sets using
    union by rank, and ``find`` compresses the path it walks.
    """

    # Class-level placeholders kept for backward compatibility; every
    # instance overwrites them in __init__.
    roots = None   # roots[i] is the parent of node i (a root is its own parent)
    ranks = None   # ranks[i] is the rank stored for node i, used by union()
    count = 0      # number of disjoint sets currently tracked

    def __init__(self, size):
        """Create ``size`` singleton sets, one per node id in ``range(size)``."""
        # Each node is parent of itself
        self.roots = list(range(size))
        self.ranks = [0] * size
        # Derive the count from the nodes actually created so it stays
        # consistent (0, not a negative number) when size is negative.
        self.count = len(self.roots)

    def find(self, x):
        """Return the root of the set containing node ``x``.

        Every node visited on the way up is re-pointed directly at the root
        (path compression).  The walk is iterative, so it does not depend on
        the interpreter's recursion limit.
        """
        roots = self.roots
        # First pass: climb to the root.
        root = x
        while root != roots[root]:
            root = roots[root]
        # Second pass: point every node on the path straight at the root.
        while x != root:
            parent = roots[x]
            roots[x] = root
            x = parent
        return root

    def union(self, a, b):
        """Merge the sets containing ``a`` and ``b``.

        Returns the root of the resulting set.  If both nodes are already in
        the same set nothing changes and that set's root is returned.
        """
        root_a = self.find(a)
        root_b = self.find(b)
        if root_a == root_b:
            # Already in same set
            return root_a
        ranks = self.ranks
        roots = self.roots
        self.count -= 1  # Two sets become one
        rank_a = ranks[root_a]
        rank_b = ranks[root_b]
        if rank_a < rank_b:
            # Swap so that root_a is the root with the larger rank
            root_a, root_b = root_b, root_a
        roots[root_b] = root_a
        if rank_a == rank_b:
            # Equal ranks: the merged tree is one level taller
            ranks[root_a] += 1
        return root_a

    def is_unified(self, a, b):
        """Return True if nodes ``a`` and ``b`` are in the same set."""
        return self.find(a) == self.find(b)

    def subsets(self):
        """Return a dict mapping each root to the list of its members.

        Nodes are visited from the highest id down to 0, so each member list
        is in descending order.
        """
        sets = {}
        for i in reversed(range(len(self.roots))):
            sets.setdefault(self.find(i), []).append(i)
        return sets

    def __repr__(self):
        return str(self.subsets())
import math
def minimum_spanning_tree(edges, max_id, verts=lambda e: e,
                          cost=lambda x: math.sqrt((x[0] - x[1]) ** 2)):
    """Greedily pick the cheapest edges that connect separate components.

    Edges are processed in ascending ``cost`` order; an edge is kept only
    when its two endpoints are not yet in the same set (Kruskal's algorithm
    on top of ``DisjointSet``).

    edges  -- iterable of edge objects
    max_id -- largest vertex id; ``max_id + 1`` nodes are allocated
    verts  -- maps an edge to its pair of vertex ids (default: the edge itself)
    cost   -- sort key for an edge (default: absolute difference of the
              edge's first two items)

    Returns the list of kept edges in the order they were accepted.
    """
    components = DisjointSet(max_id + 1)
    ordered = list(edges)
    ordered.sort(key=cost)
    chosen = []
    for edge in ordered:
        endpoints = verts(edge)
        if graph_connected(components, endpoints):
            # Endpoints already connected: this edge would close a cycle
            continue
        components.union(*endpoints)
        chosen.append(edge)
    return chosen


def graph_connected(components, endpoints):
    """Return True if both endpoints already belong to the same set."""
    return components.is_unified(*endpoints)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment