Skip to content

Instantly share code, notes, and snippets.

@suyash
Created July 8, 2020 07:49
Show Gist options
  • Save suyash/abee2e135b7321bf1aff36699cf0595e to your computer and use it in GitHub Desktop.
Save suyash/abee2e135b7321bf1aff36699cf0595e to your computer and use it in GitHub Desktop.
Generic Segment Tree implemented for custom comparator and min functions.
//! Generic segment tree parameterized by a custom combining function and a function producing its neutral (default) value.
//!
//! See Chapters 15 and 16 of "Programming Rust" by Blandy and Orendorff.
/// Segment tree answering range queries over a fixed-length sequence of values.
///
/// The tree is parameterized by two callables supplied to [`SegmentTree::new`]:
///
/// * `Comparator` (`f`) combines the results of two adjacent sub-ranges into
///   the result for their union (for example `max`, `min` or `+`).
/// * `Minimum` (`d`) produces the value reported for sub-ranges that lie
///   outside a query. It should be neutral with respect to `f` (for example
///   `i32::MIN` for a max tree or `0` for a sum tree).
pub struct SegmentTree<T, Comparator, Minimum> {
    /// Implicit binary tree: the children of node `k` live at `2k + 1` and `2k + 2`.
    tree: Vec<T>,
    /// Number of elements in the slice the tree was built from.
    n: usize,
    /// Combining function applied to the results of the two halves of a range.
    f: Comparator,
    /// Producer of the default value used for empty / non-overlapping ranges.
    d: Minimum,
}

impl<T, Comparator, Minimum> SegmentTree<T, Comparator, Minimum>
where
    T: Copy,
    Comparator: Fn(T, T) -> T,
    Minimum: Fn() -> T,
{
    /// Builds a segment tree over `data`.
    ///
    /// `f` combines two partial results and `d` yields the default value used
    /// to pre-fill the tree and to answer for ranges that do not overlap a
    /// query. An empty `data` slice is accepted; the resulting tree answers
    /// every query with `d()`.
    pub fn new(data: &[T], f: Comparator, d: Minimum) -> Self {
        let n = data.len();
        // Splitting ranges at their midpoint yields a tree whose depth matches
        // a complete binary tree with `n` rounded up to a power of two leaves,
        // i.e. `2 * p - 1` slots. Integer arithmetic avoids the float
        // `log2`/`ceil` round-trip; `0usize.next_power_of_two()` is 1.
        let slots = 2 * n.next_power_of_two() - 1;
        let tree = vec![d(); slots];
        let mut ans = SegmentTree { tree, n, f, d };
        // `n - 1` would underflow for empty input, so only build when there
        // is at least one element.
        if n > 0 {
            ans.new_(0, n - 1, 0, data);
        }
        ans
    }

    /// Recursively fills node `k`, which covers `data[i..=j]`, and returns
    /// the value stored in it.
    fn new_(&mut self, i: usize, j: usize, k: usize, data: &[T]) -> T {
        let value = if i == j {
            data[i]
        } else {
            let mid = i + (j - i) / 2;
            let left = self.new_(i, mid, 2 * k + 1, data);
            let right = self.new_(mid + 1, j, 2 * k + 2, data);
            (self.f)(left, right)
        };
        self.tree[k] = value;
        value
    }

    /// Returns the combined value of the elements in the inclusive index
    /// range `[start, end]`.
    ///
    /// Parts of the range beyond the end of the data are ignored. If the tree
    /// is empty the result is `d()`; if the range overlaps no element the
    /// result is built solely from `d()` values.
    pub fn query(&self, start: usize, end: usize) -> T {
        if self.n == 0 {
            // Nothing to combine; also avoids the `n - 1` underflow.
            return (self.d)();
        }
        self.query_(start, end, 0, self.n - 1, 0)
    }

    /// Answers the query `[start, end]` restricted to node `k`, which covers
    /// the inclusive index range `[i, j]`.
    fn query_(&self, start: usize, end: usize, i: usize, j: usize, k: usize) -> T {
        if start <= i && end >= j {
            // Node range is fully inside the query: use the stored value.
            return self.tree[k];
        }
        if i > end || j < start {
            // Node range and query are disjoint.
            return (self.d)();
        }
        let mid = i + (j - i) / 2;
        let left = self.query_(start, end, i, mid, 2 * k + 1);
        let right = self.query_(start, end, mid + 1, j, 2 * k + 2);
        (self.f)(left, right)
    }

    /// Replaces the element at `index` with `val` and refreshes every node
    /// whose range contains `index`.
    ///
    /// An `index` outside the original data (including any index on an empty
    /// tree) leaves the tree unchanged.
    pub fn update(&mut self, val: T, index: usize) {
        if index < self.n {
            self.update_(val, index, 0, self.n - 1, 0);
        }
    }

    /// Walks from node `k` (covering `[i, j]`, which must contain `index`)
    /// down to the leaf for `index`, then recomputes the nodes on the way up.
    fn update_(&mut self, val: T, index: usize, i: usize, j: usize, k: usize) {
        if i == j {
            self.tree[k] = val;
            return;
        }
        let mid = i + (j - i) / 2;
        let (left, right) = (2 * k + 1, 2 * k + 2);
        // Only the child whose range contains `index` can change.
        if index <= mid {
            self.update_(val, index, i, mid, left);
        } else {
            self.update_(val, index, mid + 1, j, right);
        }
        let combined = (self.f)(self.tree[left], self.tree[right]);
        self.tree[k] = combined;
    }
}
#[cfg(test)]
mod tests {
    use super::SegmentTree;

    /// Range-maximum queries, before and after a point update.
    #[test]
    fn test_max_tree() {
        let tree = SegmentTree::new(
            &[1, 2, 3, 4, 5],
            |a, b| if a > b { a } else { b },
            || i32::MIN,
        );
        assert_eq!(tree.query(1, 3), 4);
        assert_eq!(tree.query(0, 0), 1);
        assert_eq!(tree.query(2, 2), 3);
        assert_eq!(tree.query(0, 4), 5);

        let mut tree = SegmentTree::new(
            &[-1, -2, -3, -4, -5],
            |a, b| if a > b { a } else { b },
            || i32::MIN,
        );
        assert_eq!(tree.query(1, 3), -2);
        assert_eq!(tree.query(0, 0), -1);
        assert_eq!(tree.query(2, 2), -3);
        assert_eq!(tree.query(0, 4), -1);

        tree.update(2, 2);
        assert_eq!(tree.query(0, 4), 2);
        assert_eq!(tree.query(3, 4), -4);
        assert_eq!(tree.query(0, 1), -1);
        assert_eq!(tree.query(0, 2), 2);
    }

    /// Range-minimum queries, before and after a point update.
    #[test]
    fn test_min_tree() {
        let tree = SegmentTree::new(
            &[1, 2, 3, 4, 5],
            |a, b| if a < b { a } else { b },
            || i32::MAX,
        );
        assert_eq!(tree.query(1, 3), 2);
        assert_eq!(tree.query(0, 0), 1);
        assert_eq!(tree.query(2, 2), 3);
        assert_eq!(tree.query(0, 4), 1);

        let mut tree = SegmentTree::new(
            &[-1, -2, -3, -4, -5],
            |a, b| if a < b { a } else { b },
            || i32::MAX,
        );
        assert_eq!(tree.query(1, 3), -4);
        assert_eq!(tree.query(0, 0), -1);
        assert_eq!(tree.query(2, 2), -3);
        assert_eq!(tree.query(0, 4), -5);

        tree.update(-6, 2);
        assert_eq!(tree.query(0, 4), -6);
        assert_eq!(tree.query(3, 4), -5);
        assert_eq!(tree.query(0, 1), -2);
        assert_eq!(tree.query(0, 2), -6);
    }

    /// Range-sum queries, before and after a point update.
    #[test]
    fn test_sum_tree() {
        let tree = SegmentTree::new(&[1, 2, 3, 4, 5], |a, b| a + b, || 0);
        assert_eq!(tree.query(1, 3), 9);
        assert_eq!(tree.query(0, 0), 1);
        assert_eq!(tree.query(2, 2), 3);
        assert_eq!(tree.query(0, 4), 15);

        let mut tree = SegmentTree::new(&[-1, -2, -3, -4, -5], |a, b| a + b, || 0);
        assert_eq!(tree.query(1, 3), -9);
        assert_eq!(tree.query(0, 0), -1);
        assert_eq!(tree.query(2, 2), -3);
        assert_eq!(tree.query(0, 4), -15);

        tree.update(2, 2);
        assert_eq!(tree.query(0, 4), -10);
        assert_eq!(tree.query(3, 4), -9);
        assert_eq!(tree.query(0, 1), -3);
        assert_eq!(tree.query(0, 2), -1);
    }

    /// A one-element tree consists of the root leaf only.
    #[test]
    fn test_single_element() {
        let mut tree = SegmentTree::new(&[42], |a, b| a + b, || 0);
        assert_eq!(tree.query(0, 0), 42);
        tree.update(7, 0);
        assert_eq!(tree.query(0, 0), 7);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment