Last active
April 30, 2025 16:36
-
-
Save nikhilr612/3119b1a8ff3dfbca1d2b643eae89b3f2 to your computer and use it in GitHub Desktop.
A nifty union-find implementation.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A quick little union-find (disjoint-set) implementation in Python.
# A very elegant data structure.
import array | |
class UnionFind:
    """
    Disjoint-set (union-find) data structure over a fixed collection of
    hashable items.

    Each distinct item is mapped to a dense integer index. ``parentMap``
    stores the parent index of every node (a root is its own parent) and
    ``rankHeurs`` stores the rank heuristic used to decide which root
    becomes the parent when two sets are merged.
    """

    def __init__(self, discrete_set):
        """
        Initialize with every distinct item of `discrete_set` in its own
        singleton set.

        `discrete_set` may be any iterable of hashable items (it does not
        need to support ``len``). Repeated items are registered once.
        """
        self.obj_index_map = {}
        for item in discrete_set:
            # Dense indices handed out on first sight: duplicates must not
            # create orphan slots or inflate the set count.
            self.obj_index_map.setdefault(item, len(self.obj_index_map))
        size = len(self.obj_index_map)
        self.parentMap = array.array('I', range(size))  # every node is a root
        self.rankHeurs = array.array('I', (1 for _ in range(size)))
        # Kept under a private name so it cannot shadow the public
        # `setcount` accessor defined on the class.
        self._setcount = size  # all singleton sets

    def find(self, item):
        """
        Find a set identifier for the given item.

        If set identifiers compare equal, then the items are in the same
        set. Identifiers are only meaningful until the next `union`.

        Raises:
            KeyError: if `item` was not part of the initial collection.
        """
        parent = self.parentMap
        start = self.obj_index_map[item]

        # First pass: walk up to the root (the node that is its own parent).
        root = start
        while (up := parent[root]) != root:
            root = up

        # Second pass (path compression): re-point every node on the walked
        # path directly at the root. Stops at the node already attached to it.
        node = start
        while (up := parent[node]) != root:
            parent[node] = root
            node = up
        return root

    def union(self, item1, item2):
        """
        Unify the sets containing element `item1` and element `item2` under
        the set union operation. Does nothing if they already share a set.

        Raises:
            KeyError: if either item was not part of the initial collection.
        """
        x = self.find(item1)
        y = self.find(item2)
        if x == y:
            return  # nothing to do here.

        rx = self.rankHeurs[x]
        ry = self.rankHeurs[y]
        # Do the least work possible: always hang the lower-rank root (x)
        # under the higher-rank root (y).
        if rx > ry:
            x, y = y, x
            rx, ry = ry, rx
        self.parentMap[x] = y
        if rx == ry:
            # Equal ranks: the surviving root's rank goes up by one.
            self.rankHeurs[y] += 1

        # Union was successful: two sets became one.
        self._setcount -= 1

    def findall(self, itemset):
        """
        Find the IDs of the sets containing the elements from `itemset`.

        Returns a `set` of set identifiers, one per distinct disjoint set
        touched by `itemset`.

        Raises:
            KeyError: if any item was not part of the initial collection.
        """
        return {self.find(it) for it in itemset}

    @property
    def setcount(self):
        """
        The number of disjoint sets currently tracked.

        Exposed as a read-only property: previously an instance attribute of
        the same name shadowed the accessor method, so only attribute-style
        access (`uf.setcount`) ever worked. That access still returns an int.
        """
        return self._setcount
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment