Skip to content

Instantly share code, notes, and snippets.

@kuba--
Created April 7, 2019 00:44
Show Gist options
  • Save kuba--/7f1e42876972b95e57539ee9d511983a to your computer and use it in GitHub Desktop.
Save kuba--/7f1e42876972b95e57539ee9d511983a to your computer and use it in GitHub Desktop.
Disjoint Sets
package sets
import (
"fmt"
"strings"
)
// DisjointSet is one element of a disjoint-set (union-find) forest.
// Elements that share a root belong to the same set: Find returns that root
// (the set's representative) and Union merges two sets. A freshly created
// &DisjointSet{Elem: x} is a singleton set that represents itself.
type DisjointSet struct {
	// Elem is the caller's payload. It is carried along unchanged and only
	// read by String, which formats it with %v.
	Elem interface{}

	// Parent links toward the set's root and is nil for a root. Callers
	// normally leave it to Find and Union.
	Parent *DisjointSet

	// rank is the union-by-rank bookkeeping that Union compares (and bumps
	// on ties) to decide which root goes under which.
	rank int
}

// Find returns the representative (root) of the set ds belongs to: the
// element reached by following Parent links until one is nil. The result may
// be ds itself. A nil receiver yields nil.
//
// Find compresses the path it walks: every element between ds and the root
// is re-pointed directly at the root, so later lookups are shorter. It uses a
// loop rather than recursion, so its stack use does not grow with the length
// of the Parent chain.
func (ds *DisjointSet) Find() *DisjointSet {
	if ds == nil {
		return nil
	}
	// First pass: climb to the root.
	root := ds
	for root.Parent != nil {
		root = root.Parent
	}
	// Second pass: hang every node strictly below the root directly off it.
	for node := ds; node != root; {
		next := node.Parent
		node.Parent = root
		node = next
	}
	return root
}

// Union merges the set containing ds with the set containing s, using union
// by rank. The root with the smaller rank is attached under the other. On a
// tie, s's root goes under ds's root, whose rank then grows by one.
//
// Union returns one of its two arguments, so calls can be chained
// (a.Union(b).Union(c)). The returned element is always a member of the merged
// set. It is ds when both were already in the same set or when ds's root was
// attached under s's root, and s otherwise.
//
// A nil argument has nothing to merge. Union then returns the other argument
// unchanged, or nil when both are nil.
func (ds *DisjointSet) Union(s *DisjointSet) *DisjointSet {
	switch {
	case ds == nil:
		return s
	case s == nil:
		return ds
	}
	p1, p2 := ds.Find(), s.Find()
	if p1 == p2 {
		return ds
	}
	// Find only returns nodes whose Parent is nil, so linking two roots just
	// means setting the child root's Parent.
	if p1.rank < p2.rank {
		p1.Parent = p2
		return ds
	}
	p2.Parent = p1
	if p1.rank == p2.rank {
		p1.rank++
	}
	return s
}

// String renders the Parent chain from ds up to its root, one element per
// line formatted as "(Elem/rank):". Each line is indented with one more '.'
// than the line before it, and the output ends with one extra newline per
// printed element. A nil receiver yields "". String does not call Find, so it
// never modifies the forest.
func (ds *DisjointSet) String() string {
	var sb strings.Builder
	depth := 0
	for node := ds; node != nil; node = node.Parent {
		sb.WriteString(strings.Repeat(".", depth))
		fmt.Fprintf(&sb, "(%v/%d):\n", node.Elem, node.rank)
		depth++
	}
	// One closing newline per level, preserving the established output format.
	sb.WriteString(strings.Repeat("\n", depth))
	return sb.String()
}
package sets
import "testing"
// TestUnionFind builds the partition {0,1,4,5,7}, {2,3}, {6} with Union and
// checks that Find reports the expected representative for every element.
// Union breaks rank ties in favour of the receiver's root, so the
// representatives are deterministic: 0, 2 and 6.
func TestUnionFind(t *testing.T) {
	ds := make([]*DisjointSet, 8)
	for i := range ds {
		ds[i] = &DisjointSet{Elem: i}
	}
	ds[0].Union(ds[1]).Union(ds[4]).Union(ds[5]).Union(ds[7])
	ds[2].Union(ds[3])

	// wantRoot[i] is the index of the element expected to represent ds[i].
	wantRoot := []int{0, 0, 2, 2, 0, 0, 6, 0}

	// Group by representative via Find rather than the immediate Parent, so
	// the check stays valid however deep the trees get.
	sets := make(map[int][]int)
	for i, s := range ds {
		root := s.Find()
		if root != ds[wantRoot[i]] {
			t.Errorf("ds[%d].Find() = %v, want %d", i, root.Elem, wantRoot[i])
		}
		r := root.Elem.(int)
		sets[r] = append(sets[r], s.Elem.(int))
	}
	if len(sets) != 3 {
		t.Errorf("got %d sets, want 3: %v", len(sets), sets)
	}
	for root, members := range sets {
		t.Logf("set %d: %v", root, members)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment