Skip to content

Instantly share code, notes, and snippets.

@nikhilr612
Last active April 30, 2025 16:36
Show Gist options
  • Save nikhilr612/3119b1a8ff3dfbca1d2b643eae89b3f2 to your computer and use it in GitHub Desktop.
Save nikhilr612/3119b1a8ff3dfbca1d2b643eae89b3f2 to your computer and use it in GitHub Desktop.
A nifty union-find implementation.
# A quick little union-find implementation in Python.
# A very elegant data structure.
import array
class UnionFind:
    """
    Disjoint-set (union-find) data structure over a fixed collection of items.

    Every distinct item is given a dense integer index. `parentMap[i]` is the
    index of the parent of node `i` (a root is its own parent) and
    `rankHeurs[i]` is the rank heuristic used to keep trees shallow when two
    sets are merged.
    """

    def __init__(self, discrete_set):
        """
        Initialize with every distinct item of `discrete_set` in its own
        singleton set.

        `discrete_set` may be any iterable of hashable items; it is traversed
        exactly once, so iterators and generators are accepted. An item that
        occurs more than once is registered only once.
        """
        self.obj_index_map = {}
        for item in discrete_set:
            # setdefault keeps the index an item already has, so repeated
            # items neither consume an index nor inflate the set count.
            self.obj_index_map.setdefault(item, len(self.obj_index_map))
        size = len(self.obj_index_map)
        self.parentMap = array.array('I', range(size))  # each node is its own root
        self.rankHeurs = array.array('I', [1]) * size
        # Stored under a private name: an instance attribute called
        # `setcount` would shadow the `setcount()` method below.
        self._set_count = size  # all singleton sets

    def find(self, item) -> int:
        """
        Find a set identifier for the given item.

        If set identifiers compare equal, then the items are in the same set.
        An identifier is only meaningful until the next `union` call, which
        may change the representative of a set.

        Raises `KeyError` if `item` was not part of the initial collection.
        """
        parents = self.parentMap
        start = self.obj_index_map[item]

        # First pass: climb to the root, which is the set id.
        root = start
        while (up := parents[root]) != root:
            root = up

        # Second pass (path compression): hang every node on the walked path
        # directly below the root so later lookups are shorter.
        node = start
        while (up := parents[node]) != root:
            parents[node] = root
            node = up
        return root

    def union(self, item1, item2) -> None:
        """
        Unify the sets containing element `item1` and element `item2` under
        the set union operation.

        Does nothing when both items already belong to the same set.
        Raises `KeyError` if either item was not part of the initial
        collection.
        """
        low = self.find(item1)
        high = self.find(item2)
        if low == high:
            return  # already in the same set

        low_rank = self.rankHeurs[low]
        high_rank = self.rankHeurs[high]
        # Union by rank: always attach the lower-ranked root below the
        # higher-ranked one so the resulting tree does not get deeper
        # than necessary.
        if low_rank > high_rank:
            low, high = high, low
            low_rank, high_rank = high_rank, low_rank
        self.parentMap[low] = high
        if low_rank == high_rank:
            # Two trees of equal rank were joined: the new root's rank grows.
            self.rankHeurs[high] += 1

        # Two previously disjoint sets became one.
        self._set_count -= 1

    def findall(self, itemset) -> set:
        """
        Find the IDs of sets containing the elements from `itemset`.

        Returns a `set` with one identifier per distinct set touched.
        Raises `KeyError` if any element was not part of the initial
        collection.
        """
        return {self.find(item) for item in itemset}

    def setcount(self) -> int:
        """
        Returns the number of disjoint sets.
        """
        return self._set_count
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment