Last active
October 27, 2016 13:38
-
-
Save gsingh93/dc5ebe6c8a1582731918 to your computer and use it in GitHub Desktop.
Segment tree implementation in Rust
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::fmt::Show; | |
use std::default::Default; | |
use std::ops::Add; | |
/// A static segment tree answering inclusive range-sum style queries.
///
/// The tree is built once from a slice and is immutable afterwards. Each node
/// stores the combination (via `+`) of every element in the index range it
/// covers, so a query only has to combine the nodes that tile the requested
/// range.
struct SegmentTree<T> {
    /// Number of elements the tree was built from.
    size: usize,
    /// Node covering the whole index range `0..=size - 1`. For an empty tree
    /// this is a childless placeholder holding `T::default()`.
    root: Node<T>,
}

/// One node of the tree: a leaf (no children) or an internal node (both
/// children present).
struct Node<T> {
    left: Option<Box<Node<T>>>,
    right: Option<Box<Node<T>>>,
    /// Aggregate of every element covered by this node.
    val: T,
}

impl<T: Default + Clone + Add<Output = T>> SegmentTree<T> {
    /// Builds a tree over `elts`.
    ///
    /// An empty slice yields a tree of size 0 on which every `query` returns
    /// `Err`.
    pub fn new(elts: &[T]) -> SegmentTree<T> {
        // Guard the empty case up front: `elts.len() - 1` would underflow.
        let root = if elts.is_empty() {
            Node { left: None, right: None, val: T::default() }
        } else {
            Self::build(elts, 0, elts.len() - 1)
        };
        SegmentTree { size: elts.len(), root }
    }

    /// Returns the aggregate of the elements at indices `start..=end`.
    ///
    /// # Errors
    ///
    /// * `"Out of bounds"` if `end` is not a valid index.
    /// * A descriptive message if `start > end`.
    pub fn query(&self, start: usize, end: usize) -> Result<T, String> {
        if end >= self.size {
            return Err("Out of bounds".to_string());
        }
        if start > end {
            return Err("Start of query range can't be greater than end of range".to_string());
        }
        Ok(self.query_(0, self.size - 1, start, end, &self.root))
    }

    /// Recursively builds the node covering `left..=right` of `elts`.
    fn build(elts: &[T], left: usize, right: usize) -> Node<T> {
        if elts.is_empty() {
            return Node { left: None, right: None, val: T::default() };
        }
        if left == right {
            return Node { left: None, right: None, val: elts[left].clone() };
        }

        // Must match the split used in `query_`.
        let mid = left + (right - left) / 2;
        let lhs = Self::build(elts, left, mid);
        let rhs = Self::build(elts, mid + 1, right);
        let val = lhs.val.clone() + rhs.val.clone();
        Node { left: Some(Box::new(lhs)), right: Some(Box::new(rhs)), val }
    }

    /// Aggregates the part of `start..=end` that lies inside `cur`, where
    /// `cur` covers `left..=right`.
    ///
    /// Callers guarantee that the two ranges overlap.
    fn query_(&self, left: usize, right: usize, start: usize, end: usize, cur: &Node<T>) -> T {
        // `start`/`end` are passed down unclamped, so test for containment
        // rather than equality; otherwise the stored aggregate is almost never
        // reused and the query walks down to every leaf in the range.
        if left == right || (start <= left && right <= end) {
            return cur.val.clone();
        }

        let (cl, cr) = match (&cur.left, &cur.right) {
            (Some(l), Some(r)) => (&**l, &**r),
            _ => unreachable!("internal segment tree node is missing a child"),
        };

        let mid = left + (right - left) / 2;
        if start > mid {
            self.query_(mid + 1, right, start, end, cr)
        } else if end <= mid {
            self.query_(left, mid, start, end, cl)
        } else {
            self.query_(left, mid, start, end, cl) + self.query_(mid + 1, right, start, end, cr)
        }
    }
}
/// Compares tree queries with a naive sum over several ranges of a
/// ten-element input: the full range, a prefix, and an interior range.
#[test]
fn segment_tree_test() {
    let v: Vec<i32> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
    let tree: SegmentTree<i32> = SegmentTree::new(&v[..]);
    check(&tree, &v[..], 0, 9);
    check(&tree, &v[..], 0, 1);
    check(&tree, &v[..], 4, 7);
}
/// Building a tree from an empty slice must not panic, and no index is
/// queryable on the result.
#[test]
fn segment_tree_empty_test() {
    let v: Vec<i32> = Vec::new();
    let tree: SegmentTree<i32> = SegmentTree::new(&v[..]);
    assert!(tree.query(0, 0).is_err());
}
/// A query whose end index is past the last element is rejected.
#[test]
fn segment_tree_out_of_range_test() {
    let v: Vec<i32> = vec![1];
    let tree = SegmentTree::new(&v[..]);
    check(&tree, &v[..], 0, 0);
    assert!(tree.query(0, 1).is_err());
}
/// A query whose start index is greater than its end index is rejected.
#[test]
fn segment_tree_backwards_range_test() {
    let v: Vec<i32> = vec![1, 2, 3];
    let tree = SegmentTree::new(&v[..]);
    check(&tree, &v[..], 0, 2);
    assert!(tree.query(2, 0).is_err());
}
/// Test helper: asserts that `tree.query(start, end)` succeeds and equals the
/// naive aggregate of `elts[start..=end]` computed by `query`.
#[cfg(test)]
fn check<T>(tree: &SegmentTree<T>, elts: &[T], start: usize, end: usize)
where
    T: Default + Add<Output = T> + std::fmt::Debug + Eq + Clone,
{
    assert_eq!(tree.query(start, end).unwrap(), query(elts, start, end));
}
/// Test helper: naive reference implementation that folds `elts[start..=end]`
/// with `+`, starting from `T::default()`.
///
/// Callers must pass `start <= end`; the length `end - start + 1` is computed
/// with plain subtraction.
#[cfg(test)]
fn query<T>(elts: &[T], start: usize, end: usize) -> T
where
    T: Add<Output = T> + std::fmt::Debug + Default + Clone,
{
    elts.iter()
        .skip(start)
        .take(end - start + 1)
        .cloned()
        .fold(T::default(), |acc, x| acc + x)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment