Created
September 15, 2013 23:26
Implementation of a general segment tree in Go.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package segment_tree | |
import "fmt" | |
// sizeMultiplier is the ratio between the length of the backing slice and the
// number of leaf positions. The tree is 1-indexed (root at 1, children of p at
// 2p and 2p+1) and splits each range at (start+end)/2; with that layout every
// position used is below 4*size, while 3*size is too small (a tree of size 36
// touches position 112 during Update(28, ...)).
const sizeMultiplier int = 4
// Comparator merges two values into the single value a tree node should hold,
// e.g. the smaller of the two for a range-minimum tree.
type Comparator func(a, b int) int

// Node is one slot of the tree. set reports whether value has been written;
// until the first write, value holds the zero value.
type Node struct {
	set   bool
	value int
}

// updateValue folds other into the node: the first write stores other as-is,
// and every later write replaces value with comparer(value, other).
func (n *Node) updateValue(comparer Comparator, other int) {
	if !n.set {
		n.set = true
		n.value = other
		return
	}
	n.value = comparer(n.value, other)
}
// UndefinedNode is the error reported when the tree segment Start..End
// contributes no value to a query, either because none of its positions has
// been set or because it lies outside the queried range.
type UndefinedNode struct {
	Start int
	End   int
}

// Error implements the error interface.
func (n *UndefinedNode) Error() string {
	// Go convention: lowercase, no trailing punctuation (the old message
	// ended in a dangling ": ").
	return fmt.Sprintf("node for %d..%d is not initialized", n.Start, n.End)
}
type SegmentTree struct { | |
tree []Node | |
comparer Comparator | |
} | |
func NewSegmentTree(size int, comparer Comparator) *SegmentTree { | |
seg := new(SegmentTree) | |
seg.tree = make([]Node, size*sizeMultiplier) | |
seg.comparer = comparer | |
for i := 0; i < size*sizeMultiplier; i++ { | |
seg.tree[i].set = false | |
} | |
return seg | |
} | |
func (s *SegmentTree) Update(key, value int) { | |
s.update(1, 1, len(s.tree)/sizeMultiplier, key, value) | |
} | |
func (s *SegmentTree) update(pos, mrange MemoryRange, key, value int) { | |
if key >= start && key <= end { | |
s.tree[pos].updateValue(s.comparer, value) | |
if end != start { | |
middle := (start + end) / 2 | |
s.update(pos*2, start, middle, key, value) | |
s.update(pos*2+1, middle+1, end, key, value) | |
} | |
} | |
} | |
func (s *SegmentTree) Query(query_start, query_end int) int { | |
v, _ := s.query(1, 1, len(s.tree)/sizeMultiplier, query_start, query_end) | |
return v | |
} | |
func (s *SegmentTree) query(pos, start, end, query_start, query_end int) (n int, err error) { | |
if start >= query_start && end <= query_end { | |
if !s.tree[pos].set { | |
err = &UndefinedNode{start, end} | |
} | |
return s.tree[pos].value, err | |
} else if end < query_start || start > query_end { | |
return 0, &UndefinedNode{start, end} | |
} else { | |
middle := (start + end) / 2 | |
left, leftOk := s.query(pos*2, start, middle, query_start, query_end) | |
right, rightOk := s.query(pos*2+1, middle+1, end, query_start, query_end) | |
err = nil | |
if leftOk != nil && rightOk != nil { | |
return 0, &UndefinedNode{start, end} | |
} else if leftOk == nil && rightOk != nil { | |
n = left | |
} else if rightOk == nil && leftOk != nil { | |
n = right | |
} else { | |
n = s.comparer(left, right) | |
} | |
return | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package segment_tree | |
import ( | |
"testing" | |
) | |
func AssertRange(stree *SegmentTree, t *testing.T, start, end, expected int) { | |
res := stree.Query(start, end) | |
if res != expected { | |
t.Errorf("min(%d..%d) ≠ %d, = %d", start, end, res, expected) | |
} | |
} | |
// MinimumComparison is a Comparator that keeps the smaller of two values.
func MinimumComparison(a, b int) int {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}
func TestMinSegmentTree(t *testing.T) { | |
stree := NewSegmentTree(10, MinimumComparison) | |
stree.Update(1, 10) | |
stree.Update(2, 8) | |
stree.Update(3, 4) | |
stree.Update(4, 7) | |
stree.Update(5, 1) | |
stree.Update(6, 3) | |
stree.Update(7, 9) | |
stree.Update(8, 10) | |
stree.Update(9, 3) | |
stree.Update(10, 2) | |
AssertRange(stree, t, 1, 3, 4) | |
AssertRange(stree, t, 4, 9, 1) | |
} | |
// MaximumComparison is a Comparator that keeps the larger of two values.
func MaximumComparison(a, b int) int {
	larger := b
	if a > b {
		larger = a
	}
	return larger
}
func TestMaxSegmentTree(t *testing.T) { | |
stree := NewSegmentTree(10, MaximumComparison) | |
stree.Update(1, 10) | |
stree.Update(2, 8) | |
stree.Update(3, 4) | |
stree.Update(4, 7) | |
stree.Update(5, 1) | |
stree.Update(6, 3) | |
stree.Update(7, 9) | |
stree.Update(8, 10) | |
stree.Update(9, 3) | |
stree.Update(10, 2) | |
AssertRange(stree, t, 1, 3, 10) | |
AssertRange(stree, t, 4, 9, 10) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment