Created
July 8, 2020 07:49
-
-
Save suyash/abee2e135b7321bf1aff36699cf0595e to your computer and use it in GitHub Desktop.
Generic Segment Tree implemented for custom comparator and min functions.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! Generic segment tree parameterized by a custom combine function and an
//! identity-element ("minimum") function.
//!
//! See Chapters 15 and 16 of Blandy and Orendorff, *Programming Rust*.
/// A segment tree over a fixed-length sequence of `T`, answering queries over
/// inclusive index ranges and supporting single-element updates.
///
/// The tree is parameterized by:
/// * a combine function `f` (e.g. `max`, `min`, `+`) that merges the results of
///   two adjacent ranges; it must be associative for queries to be correct;
/// * an identity function `d` returning the neutral element of `f` (e.g.
///   `i32::MIN` for `max`, `i32::MAX` for `min`, `0` for `+`). It is returned for
///   sub-ranges that fall outside a query and used to pre-fill the storage.
pub struct SegmentTree<T, Comparator, Minimum> {
    /// Implicit binary tree: the root is node 0 and node `k` has children
    /// `2k + 1` and `2k + 2`.
    tree: Vec<T>,
    /// Number of elements in the data the tree was built from.
    n: usize,
    /// Associative function combining the values of two adjacent ranges.
    f: Comparator,
    /// Produces the identity element of `f`.
    d: Minimum,
}

impl<T, Comparator, Minimum> SegmentTree<T, Comparator, Minimum>
where
    T: Copy,
    Comparator: Fn(T, T) -> T,
    Minimum: Fn() -> T,
{
    /// Builds a segment tree over `data`, using `f` to combine ranges and `d` as
    /// the identity element of `f`.
    ///
    /// An empty `data` slice is accepted: every query on the resulting tree
    /// returns `d()` and every update is ignored.
    pub fn new(data: &[T], f: Comparator, d: Minimum) -> Self {
        let n = data.len();
        // Splitting at the midpoint gives a tree of depth ceil(log2 n), so it fits in
        // 2 * 2^ceil(log2 n) - 1 slots. Integer arithmetic avoids float rounding in
        // `log2` and keeps the empty case from underflowing.
        let size = if n == 0 { 0 } else { 2 * n.next_power_of_two() - 1 };
        let tree = vec![d(); size];
        let mut ans = SegmentTree { tree, n, f, d };
        if n > 0 {
            ans.build(0, n - 1, 0, data);
        }
        ans
    }

    /// Recursively fills node `k`, which covers `data[i..=j]`, and returns its value.
    fn build(&mut self, i: usize, j: usize, k: usize, data: &[T]) -> T {
        if i == j {
            self.tree[k] = data[i];
        } else {
            let mid = i + (j - i) / 2;
            let left = self.build(i, mid, 2 * k + 1, data);
            let right = self.build(mid + 1, j, 2 * k + 2, data);
            self.tree[k] = (self.f)(left, right);
        }
        self.tree[k]
    }

    /// Returns the combined value of the inclusive range `[start, end]`.
    ///
    /// Indices at or beyond the data length are ignored, so the range is
    /// effectively clamped to the data. A range with `start > end`, or any query
    /// on an empty tree, yields `d()`.
    pub fn query(&self, start: usize, end: usize) -> T {
        if self.n == 0 {
            return (self.d)();
        }
        self.query_node(start, end, 0, self.n - 1, 0)
    }

    /// Combines the part of `[start, end]` that overlaps node `k`, which covers `[i, j]`.
    fn query_node(&self, start: usize, end: usize, i: usize, j: usize, k: usize) -> T {
        if start <= i && end >= j {
            // Node lies entirely inside the query range: use its stored value.
            self.tree[k]
        } else if i > end || j < start {
            // Node is disjoint from the query range: contribute the identity.
            (self.d)()
        } else {
            let mid = i + (j - i) / 2;
            let left = self.query_node(start, end, i, mid, 2 * k + 1);
            let right = self.query_node(start, end, mid + 1, j, 2 * k + 2);
            (self.f)(left, right)
        }
    }

    /// Replaces the element at `index` in the original data with `val` and
    /// recomputes every node whose range contains it.
    ///
    /// An `index` outside the data (including any index on an empty tree) is
    /// silently ignored and leaves the tree unchanged.
    pub fn update(&mut self, val: T, index: usize) {
        if index >= self.n {
            return;
        }
        self.update_node(val, index, 0, self.n - 1, 0);
    }

    /// Sets leaf `index` beneath node `k` (covering `[i, j]`, which contains `index`),
    /// recomputes node `k`, and returns its new value.
    fn update_node(&mut self, val: T, index: usize, i: usize, j: usize, k: usize) -> T {
        if i == j {
            self.tree[k] = val;
        } else {
            let mid = i + (j - i) / 2;
            // Only the child containing `index` changes; the sibling's stored value is reused.
            let (left, right) = if index <= mid {
                (
                    self.update_node(val, index, i, mid, 2 * k + 1),
                    self.tree[2 * k + 2],
                )
            } else {
                (
                    self.tree[2 * k + 1],
                    self.update_node(val, index, mid + 1, j, 2 * k + 2),
                )
            };
            self.tree[k] = (self.f)(left, right);
        }
        self.tree[k]
    }
}
#[cfg(test)]
mod tests {
    use super::SegmentTree;

    /// Asserts that every inclusive range `[start, end]` of a tree built from `data`
    /// agrees with a naive left fold of `f` over the same slice, seeded with `identity`.
    fn assert_matches_naive(data: &[i32], f: fn(i32, i32) -> i32, identity: i32) {
        let tree = SegmentTree::new(data, f, move || identity);
        for start in 0..data.len() {
            for end in start..data.len() {
                let expected = data[start..=end]
                    .iter()
                    .fold(identity, |acc, &x| f(acc, x));
                assert_eq!(
                    tree.query(start, end),
                    expected,
                    "range [{}, {}] of {:?}",
                    start,
                    end,
                    data
                );
            }
        }
    }

    #[test]
    fn test_max_tree() {
        let tree = SegmentTree::new(
            &[1, 2, 3, 4, 5],
            |a, b| if a > b { a } else { b },
            || i32::MIN,
        );
        assert_eq!(tree.query(1, 3), 4);
        assert_eq!(tree.query(0, 0), 1);
        assert_eq!(tree.query(2, 2), 3);
        assert_eq!(tree.query(0, 4), 5);

        let mut tree = SegmentTree::new(
            &[-1, -2, -3, -4, -5],
            |a, b| if a > b { a } else { b },
            || i32::MIN,
        );
        assert_eq!(tree.query(1, 3), -2);
        assert_eq!(tree.query(0, 0), -1);
        assert_eq!(tree.query(2, 2), -3);
        assert_eq!(tree.query(0, 4), -1);
        tree.update(2, 2);
        assert_eq!(tree.query(0, 4), 2);
        assert_eq!(tree.query(3, 4), -4);
        assert_eq!(tree.query(0, 1), -1);
        assert_eq!(tree.query(0, 2), 2);
    }

    #[test]
    fn test_min_tree() {
        let tree = SegmentTree::new(
            &[1, 2, 3, 4, 5],
            |a, b| if a < b { a } else { b },
            || i32::MAX,
        );
        assert_eq!(tree.query(1, 3), 2);
        assert_eq!(tree.query(0, 0), 1);
        assert_eq!(tree.query(2, 2), 3);
        assert_eq!(tree.query(0, 4), 1);

        let mut tree = SegmentTree::new(
            &[-1, -2, -3, -4, -5],
            |a, b| if a < b { a } else { b },
            || i32::MAX,
        );
        assert_eq!(tree.query(1, 3), -4);
        assert_eq!(tree.query(0, 0), -1);
        assert_eq!(tree.query(2, 2), -3);
        assert_eq!(tree.query(0, 4), -5);
        tree.update(-6, 2);
        assert_eq!(tree.query(0, 4), -6);
        assert_eq!(tree.query(3, 4), -5);
        assert_eq!(tree.query(0, 1), -2);
        assert_eq!(tree.query(0, 2), -6);
    }

    #[test]
    fn test_sum_tree() {
        let tree = SegmentTree::new(&[1, 2, 3, 4, 5], |a, b| a + b, || 0);
        assert_eq!(tree.query(1, 3), 9);
        assert_eq!(tree.query(0, 0), 1);
        assert_eq!(tree.query(2, 2), 3);
        assert_eq!(tree.query(0, 4), 15);

        let mut tree = SegmentTree::new(&[-1, -2, -3, -4, -5], |a, b| a + b, || 0);
        assert_eq!(tree.query(1, 3), -9);
        assert_eq!(tree.query(0, 0), -1);
        assert_eq!(tree.query(2, 2), -3);
        assert_eq!(tree.query(0, 4), -15);
        tree.update(2, 2);
        assert_eq!(tree.query(0, 4), -10);
        assert_eq!(tree.query(3, 4), -9);
        assert_eq!(tree.query(0, 1), -3);
        assert_eq!(tree.query(0, 2), -1);
    }

    #[test]
    fn test_all_ranges_match_naive_fold() {
        // Lengths on both sides of powers of two exercise the tree-size computation.
        for len in 1..=17usize {
            let data: Vec<i32> = (0..len).map(|x| (x * 7 % 11) as i32 - 5).collect();
            assert_matches_naive(&data, |a, b| if a > b { a } else { b }, i32::MIN);
            assert_matches_naive(&data, |a, b| if a < b { a } else { b }, i32::MAX);
            assert_matches_naive(&data, |a, b| a + b, 0);
        }
    }

    #[test]
    fn test_updates_match_naive_sum() {
        let mut data = [3, -1, 4, -1, 5, -9, 2, 6, -5];
        let mut tree = SegmentTree::new(&data, |a, b| a + b, || 0);
        for index in 0..data.len() {
            let val = index as i32 * 10 - 40;
            data[index] = val;
            tree.update(val, index);
            for start in 0..data.len() {
                for end in start..data.len() {
                    assert_eq!(
                        tree.query(start, end),
                        data[start..=end].iter().sum::<i32>()
                    );
                }
            }
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment