Skip to content

Instantly share code, notes, and snippets.

@sirupsen
Created September 15, 2013 23:26
Implementation of a general segment tree in Go.
package segment_tree
import "fmt"
// sizeMultiplier is the ratio between the length of the backing node array
// and the number of leaf positions. NewSegmentTree allocates
// size*sizeMultiplier nodes and Update/Query recover size as
// len(tree)/sizeMultiplier.
//
// The tree is stored heap-style from index 1 (children of pos are 2*pos and
// 2*pos+1) and segments are split at the floor of their midpoint. With that
// layout 3*size is NOT enough: a tree over 40 positions already touches
// index 121. 4*size always suffices, because the depth is at most
// ceil(log2(size)), so every index is below 2^(ceil(log2(size))+1) < 4*size.
const sizeMultiplier int = 4
// Comparator merges two values into the single value stored for the segment
// that covers both (for example, a min or max function). Despite its name it
// returns a combined value rather than an ordering. It is presumably expected
// to be associative, since query results are assembled from arbitrary
// groupings of nodes — TODO confirm for custom combiners.
type Comparator func(a, b int) int
// Node is one slot of the segment tree's backing array.
type Node struct {
	set   bool // false until the first update reaches this node
	value int  // combination of every value routed through this node
}

// updateValue folds other into the node. The first value is stored as-is;
// later values are merged with the current one through comparer.
func (n *Node) updateValue(comparer Comparator, other int) {
	if !n.set {
		n.set = true
		n.value = other
		return
	}
	n.value = comparer(n.value, other)
}
// UndefinedNode is the error query reports for a segment Start..End that
// contributes no value. This happens either because none of its nodes has been
// set, or because the segment does not overlap the queried range.
type UndefinedNode struct {
	Start int
	End   int
}

// Error implements the error interface. The message follows Go convention:
// lowercase, with no trailing punctuation or dangling separator.
func (n *UndefinedNode) Error() string {
	return fmt.Sprintf("node for %d..%d is not initialized", n.Start, n.End)
}
// SegmentTree answers range queries over positions 1..size by merging stored
// values with a user-supplied Comparator (e.g. range minimum or maximum).
type SegmentTree struct {
	// tree holds the nodes in heap order: the root is at index 1 and the
	// children of pos are at 2*pos and 2*pos+1, so index 0 is never used.
	tree []Node
	// comparer merges two values into their combined result.
	comparer Comparator
}
// NewSegmentTree returns an empty tree over positions 1..size that merges
// values with comparer. make zero-initializes the node slice, so every Node
// already starts with set == false and no initialization loop is needed.
func NewSegmentTree(size int, comparer Comparator) *SegmentTree {
	return &SegmentTree{
		tree:     make([]Node, size*sizeMultiplier),
		comparer: comparer,
	}
}
// Update records value at 1-based position key. Every node whose segment
// contains key folds value in with the tree's comparer. Repeated updates of
// the same key therefore accumulate (e.g. keep the minimum) instead of
// overwriting. Keys outside 1..size are ignored.
func (s *SegmentTree) Update(key, value int) {
	size := len(s.tree) / sizeMultiplier
	s.update(1, 1, size, key, value)
}
// update folds value into node pos, which covers positions start..end, and
// then into every descendant on the path down to the leaf for key. It does
// nothing when key lies outside start..end.
func (s *SegmentTree) update(pos, start, end, key, value int) {
	if key < start || key > end {
		return
	}
	s.tree[pos].updateValue(s.comparer, value)
	if start == end {
		return // reached the leaf for key
	}
	// key lies in exactly one half; only that child's segment needs updating.
	middle := (start + end) / 2
	if key <= middle {
		s.update(pos*2, start, middle, key, value)
	} else {
		s.update(pos*2+1, middle+1, end, key, value)
	}
}
// Query merges the values stored for positions query_start..query_end
// (inclusive, 1-based) with the tree's comparer. When no position in the
// range has been updated, it returns 0. That includes empty or out-of-bounds
// ranges.
func (s *SegmentTree) Query(query_start, query_end int) int {
	size := len(s.tree) / sizeMultiplier
	// query only errors when nothing in the range is set, and its int result
	// is 0 in that case, so the error is discarded here.
	result, _ := s.query(1, 1, size, query_start, query_end)
	return result
}
// query returns the merged value of the positions that lie both in node
// pos's segment start..end and in query_start..query_end. It reports an
// *UndefinedNode error, with a 0 result, when that overlap is empty or holds
// no set node. When only one half of a partially overlapping segment has data,
// that half's value is returned on its own.
func (s *SegmentTree) query(pos, start, end, query_start, query_end int) (n int, err error) {
	switch {
	case query_start <= start && end <= query_end:
		// Segment entirely inside the query: answer from the stored node.
		node := s.tree[pos]
		if !node.set {
			return node.value, &UndefinedNode{start, end}
		}
		return node.value, nil
	case end < query_start || query_end < start:
		// Segment entirely outside the query.
		return 0, &UndefinedNode{start, end}
	}

	// Partial overlap: ask both halves and merge whatever they found.
	middle := (start + end) / 2
	left, leftErr := s.query(pos*2, start, middle, query_start, query_end)
	right, rightErr := s.query(pos*2+1, middle+1, end, query_start, query_end)
	switch {
	case leftErr != nil && rightErr != nil:
		return 0, &UndefinedNode{start, end}
	case leftErr != nil:
		return right, nil
	case rightErr != nil:
		return left, nil
	default:
		return s.comparer(left, right), nil
	}
}
package segment_tree
import (
"testing"
)
// AssertRange reports a test error unless stree.Query(start, end) equals
// expected. The message names the actual result first and the expected value
// second. It is combiner-neutral because this helper serves both the min and
// the max tests.
func AssertRange(stree *SegmentTree, t *testing.T, start, end, expected int) {
	t.Helper() // attribute failures to the calling test line
	if got := stree.Query(start, end); got != expected {
		t.Errorf("Query(%d, %d) = %d, want %d", start, end, got, expected)
	}
}
// MinimumComparison returns the smaller of a and b. It is the combiner for a
// range-minimum tree.
func MinimumComparison(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
// TestMinSegmentTree loads ten values into a range-minimum tree and checks
// two range minima.
func TestMinSegmentTree(t *testing.T) {
	values := []int{10, 8, 4, 7, 1, 3, 9, 10, 3, 2}
	stree := NewSegmentTree(len(values), MinimumComparison)
	for i, v := range values {
		stree.Update(i+1, v) // positions are 1-based
	}
	AssertRange(stree, t, 1, 3, 4)
	AssertRange(stree, t, 4, 9, 1)
}
// MaximumComparison returns the larger of a and b. It is the combiner for a
// range-maximum tree.
func MaximumComparison(a, b int) int {
	if b >= a {
		return b
	}
	return a
}
// TestMaxSegmentTree loads ten values into a range-maximum tree and checks
// two range maxima.
func TestMaxSegmentTree(t *testing.T) {
	values := []int{10, 8, 4, 7, 1, 3, 9, 10, 3, 2}
	stree := NewSegmentTree(len(values), MaximumComparison)
	for i, v := range values {
		stree.Update(i+1, v) // positions are 1-based
	}
	AssertRange(stree, t, 1, 3, 10)
	AssertRange(stree, t, 4, 9, 10)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment