Created
July 24, 2020 00:00
-
-
Save charris/62f2dadc0ab597196635e8803eab786a to your computer and use it in GitHub Desktop.
union-find
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
include "stdlib.pxd" | |
cdef class UnionFind:
    """Disjoint-set (union-find) structure over the integers ``0 .. size-1``.

    Every element starts in its own singleton set.  State lives in two
    parallel C arrays of ``capacity`` ints, of which the first ``size``
    slots are in use:

    * ``_dad[i]`` -- parent of ``i``.  A negative value marks ``i`` as the
      root of its set, and ``-_dad[root]`` is the number of elements in
      that set (used for union by size).
    * ``_sib[i]`` -- next element of ``i``'s set in a circular linked
      list, so one set can be enumerated without scanning every element.

    ``_setcount`` is the current number of disjoint sets.
    """
    cdef int *_dad
    cdef int *_sib
    cdef int _size
    cdef int _capacity
    cdef int _setcount

    cdef void _reset(self):
        # Put every in-use element back into its own singleton set.
        cdef int i
        cdef int *dad = self._dad
        cdef int *sib = self._sib
        for i in range(self._size):
            dad[i] = -1     # root of a set containing one element
            sib[i] = i      # circular sibling list of length one
        self._setcount = self._size

    cdef int _is_valid_index(self, int i):
        # Nonzero iff i addresses an in-use slot.
        return 0 <= i < self._size

    cdef int _root(self, int elt):
        # Return the root of elt's set, or -1 if elt is out of range.
        # Path compression re-points every node on the path from elt
        # directly at the root so later lookups are shorter.
        cdef int top
        cdef int nxt
        cdef int *dad = self._dad
        # Reject negative indices too, not just ones past the end.
        if not self._is_valid_index(elt):
            return -1
        # Find the root: the first ancestor with a negative entry.
        top = elt
        while dad[top] >= 0:
            top = dad[top]
        # Compress the path from elt up to (but excluding) the root.
        if elt != top:
            while dad[elt] != top:
                nxt = dad[elt]
                dad[elt] = top
                elt = nxt
        return top

    def __cinit__(self, int size=0, int capacity=100):
        """Allocate storage and start with ``size`` singleton sets.

        :param size: number of elements, which must be non-negative.
        :param capacity: number of slots to allocate.  It is raised to
            ``size`` if smaller.
        :raises ValueError: if ``size`` is negative.
        :raises MemoryError: if the arrays cannot be allocated.
        """
        if size < 0:
            raise ValueError("size must be non-negative")
        if size > capacity:
            capacity = size
        # Here capacity >= size >= 0, so the byte counts below are sane.
        self._size = size
        self._capacity = capacity
        self._dad = <int *>malloc(capacity * sizeof(int))
        self._sib = <int *>malloc(capacity * sizeof(int))
        # malloc(0) may legitimately return NULL, so NULL only signals
        # failure when a non-empty allocation was requested.
        if capacity > 0 and (self._dad == NULL or self._sib == NULL):
            # Release any half-done allocation.  Pointers are nulled so
            # __dealloc__ cannot double-free them.
            free(self._dad)
            free(self._sib)
            self._dad = NULL
            self._sib = NULL
            raise MemoryError("could not allocate UnionFind arrays")
        self._reset()

    def __dealloc__(self):
        # free(NULL) is a no-op, so this is safe after a failed __cinit__.
        free(self._dad)
        free(self._sib)

    def size(self):
        """Return the number of in-use elements."""
        return self._size

    def capacity(self):
        """Return the number of allocated element slots."""
        return self._capacity

    def setcount(self):
        """Return the current number of disjoint sets."""
        return self._setcount

    def clear(self):
        """Drop all elements (size and set count become 0).

        The allocated storage is kept.
        """
        self._size = 0
        self._setcount = 0

    def reset(self):
        """Put every in-use element back into its own singleton set."""
        self._reset()

    def felt(self, int elt1, int elt2):
        """Return True if ``elt1`` and ``elt2`` lie in *different* sets.

        NOTE(review): the method name does not convey the "different
        sets" meaning; behavior is kept as-is for existing callers.
        """
        assert self._is_valid_index(elt1), "elt1 out of range."
        assert self._is_valid_index(elt2), "elt2 out of range."
        return self._root(elt1) != self._root(elt2)

    def union(self, elts):
        """Merge the sets containing every element of ``elts`` into one set.

        :param elts: sequence of element indices (uses ``len`` and
            indexing).  An empty sequence is a no-op.
        """
        cdef int n = len(elts)
        cdef int *dad = self._dad
        cdef int *sib = self._sib
        cdef int r1, r2, tmp, i, elt
        if n == 0:
            return
        elt = elts[0]
        assert self._is_valid_index(elt), "element out of range"
        r1 = self._root(elt)
        for i in range(1, n):
            elt = elts[i]
            assert self._is_valid_index(elt), "element out of range"
            r2 = self._root(elt)
            if r1 == r2:
                continue
            self._setcount -= 1
            # Swapping the successors of two nodes from different
            # circular lists splices them into a single cycle.
            tmp = sib[r1]
            sib[r1] = sib[r2]
            sib[r2] = tmp
            # Union by size: roots store -size, so the more negative
            # entry is the larger set and becomes the parent.
            if dad[r1] < dad[r2]:
                dad[r1] += dad[r2]
                dad[r2] = r1
            else:
                dad[r2] += dad[r1]
                dad[r1] = r2
                # r1 is no longer a root.  Keep tracking the merged root,
                # otherwise later elements are compared and merged
                # against a non-root, which corrupts the structure.
                r1 = r2

    def getset(self, int elt):
        """Return a list of all elements in ``elt``'s set, starting with ``elt``."""
        cdef int *sib = self._sib
        cdef int nxt
        assert self._is_valid_index(elt), "elt is out of range"
        members = [elt]
        # Walk the circular sibling list until it returns to elt.
        nxt = sib[elt]
        while nxt != elt:
            members.append(nxt)
            nxt = sib[nxt]
        return members

    def getallsets(self):
        """Return a list with one member list per disjoint set.

        Sets are ordered by their roots' indices.
        """
        cdef int *dad = self._dad
        cdef int n = self._size
        cdef int i
        sets = []
        for i in range(n):
            if dad[i] < 0:      # each root represents exactly one set
                sets.append(self.getset(i))
        return sets
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment