Skip to content

Instantly share code, notes, and snippets.

@orlp
Last active November 14, 2020 04:17
Show Gist options
  • Save orlp/75989f6ead2f3e9ccd3375a56f8ea41c to your computer and use it in GitHub Desktop.
Save orlp/75989f6ead2f3e9ccd3375a56f8ea41c to your computer and use it in GitHub Desktop.
use std::collections::{HashSet, HashMap};
/// One node in the tracker's doubly-linked list of buckets, ordered by count.
///
/// The bucket's own count is the key it is stored under in
/// `CountTracker::buckets`; the node itself only stores its neighbours.
#[derive(Debug, Clone)]
struct CountBucket {
    /// Key of the neighbouring bucket with a smaller count (`i64::MIN` marks "none").
    prev: i64,
    /// Key of the neighbouring bucket with a larger count (`i64::MAX` marks "none").
    next: i64,
    /// Indices of all counters whose current value equals this bucket's key.
    indices: HashSet<usize>,
}
/// A set of `i64` counters that can report, at any time, an index holding
/// the current minimum value and an index holding the current maximum value.
///
/// Counters sharing a value are grouped into a `CountBucket`; buckets are
/// chained in ascending order of value.
#[derive(Debug, Clone)]
pub struct CountTracker {
    /// Current value of every counter, indexed by counter index.
    counts: Vec<i64>,
    /// Bucket for each value currently held by at least one counter, keyed by that value.
    buckets: HashMap<i64, CountBucket>,
    /// Key of the bucket with the smallest value (head of the chain).
    min_bucket: i64,
    /// Key of the bucket with the largest value (tail of the chain).
    max_bucket: i64,
}
impl CountTracker {
pub fn new(n: usize) -> Self {
let mut buckets = HashMap::new();
buckets.insert(0, CountBucket {
prev: i64::MIN,
next: i64::MAX,
indices: (0..n).collect(),
});
Self {
counts: vec![0; n],
buckets,
min_bucket: 0,
max_bucket: 0,
}
}
fn set_prev(&mut self, to_update: i64, new_prev: i64) {
if let Some(update_bucket) = self.buckets.get_mut(&to_update) {
update_bucket.prev = new_prev;
} else {
self.max_bucket = new_prev;
}
}
fn set_next(&mut self, to_update: i64, new_next: i64) {
if let Some(update_bucket) = self.buckets.get_mut(&to_update) {
update_bucket.next = new_next;
} else {
self.min_bucket = new_next;
}
}
fn ensure_inc_bucket(&mut self, count: i64) {
if self.buckets.contains_key(&(count + 1)) {
return;
}
let next = self.buckets[&count].next;
let new_bucket = CountBucket {
prev: count,
next: next,
indices: HashSet::new(),
};
self.buckets.insert(count + 1, new_bucket);
self.set_prev(next, count + 1);
self.set_next(count, count + 1);
}
fn ensure_dec_bucket(&mut self, count: i64) {
if self.buckets.contains_key(&(count - 1)) {
return;
}
let prev = self.buckets[&count].prev;
let new_bucket = CountBucket {
prev: prev,
next: count,
indices: HashSet::new(),
};
self.buckets.insert(count - 1, new_bucket);
self.set_prev(count, count - 1);
self.set_next(prev, count - 1);
}
fn maybe_collapse_bucket(&mut self, count: i64) {
if self.buckets[&count].indices.len() > 0 {
return;
}
let CountBucket { prev, next, .. } = self.buckets[&count];
self.set_prev(next, prev);
self.set_next(prev, next);
self.buckets.remove(&count);
}
pub fn inc(&mut self, i: usize) {
let count = self.counts[i];
self.ensure_inc_bucket(count);
self.buckets.get_mut(&count).unwrap().indices.remove(&i);
self.buckets.get_mut(&(count + 1)).unwrap().indices.insert(i);
self.maybe_collapse_bucket(count);
self.counts[i] += 1;
}
pub fn dec(&mut self, i: usize) {
let count = self.counts[i];
self.ensure_dec_bucket(count);
self.buckets.get_mut(&count).unwrap().indices.remove(&i);
self.buckets.get_mut(&(count - 1)).unwrap().indices.insert(i);
self.maybe_collapse_bucket(count);
self.counts[i] -= 1;
}
pub fn get(&self, i: usize) -> i64 {
self.counts[i]
}
pub fn min_idx(&self) -> usize {
*self.buckets[&self.min_bucket].indices.iter().next().unwrap()
}
pub fn max_idx(&self) -> usize {
*self.buckets[&self.max_bucket].indices.iter().next().unwrap()
}
pub fn len(&self) -> usize {
self.counts.len()
}
}
#[cfg(test)]
mod tests {
    use super::CountTracker;

    /// Tiny xorshift64 generator.
    ///
    /// It keeps the randomized test reproducible, so a failing seed can be
    /// re-run, and independent of any `rand` API version.
    struct XorShift64(u64);

    impl XorShift64 {
        /// Advances the state and returns the next pseudo-random value.
        /// The state must be non-zero, otherwise it stays zero forever.
        fn next_u64(&mut self) -> u64 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.0 = x;
            x
        }
    }

    /// Applies pseudo-random inc/dec operations and, after every step, checks
    /// each count and the min/max queries against a plain `Vec<i64>` model.
    #[test]
    fn test_count_tracker() {
        for seed in 1..=4u64 {
            // Multiplying by an odd constant keeps every seed non-zero, as xorshift requires.
            let mut rng = XorShift64(seed.wrapping_mul(0x9E37_79B9_7F4A_7C15));
            for n in 1..10usize {
                let mut c = CountTracker::new(n);
                let mut v = vec![0i64; n];
                assert_eq!(c.len(), n);
                for step in 0..1000 {
                    let r = rng.next_u64();
                    let idx = (r % n as u64) as usize;
                    if r >> 63 == 0 {
                        c.inc(idx);
                        v[idx] += 1;
                    } else {
                        c.dec(idx);
                        v[idx] -= 1;
                    }
                    let ctx = (seed, n, step);
                    for (i, &expected) in v.iter().enumerate() {
                        assert_eq!(c.get(i), expected, "(seed, n, step) = {:?}", ctx);
                    }
                    assert_eq!(
                        c.get(c.min_idx()),
                        *v.iter().min().unwrap(),
                        "(seed, n, step) = {:?}",
                        ctx
                    );
                    assert_eq!(
                        c.get(c.max_idx()),
                        *v.iter().max().unwrap(),
                        "(seed, n, step) = {:?}",
                        ctx
                    );
                }
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment