Skip to content

Instantly share code, notes, and snippets.

@leiless
Created August 8, 2023 02:45
Show Gist options
  • Save leiless/7465b08600b45ef0bb1c450f8c1595da to your computer and use it in GitHub Desktop.
Save leiless/7465b08600b45ef0bb1c450f8c1595da to your computer and use it in GitHub Desktop.
Rust: maintain the top K elements on the fly
use std::cmp::Reverse;
use std::collections::BinaryHeap;
/// Offers `val` to `min_heap`, keeping at most the `top_k` largest values seen.
///
/// Values are stored wrapped in [`Reverse`], so the heap's root is the smallest
/// retained value. `val` is admitted when the heap has room, or when it is
/// strictly greater than that smallest value (which it then replaces; ties keep
/// the earlier value). Any excess beyond `top_k` — e.g. if earlier calls used a
/// larger `top_k` — is trimmed by discarding the smallest values.
fn maintain_top_k<T: Ord>(min_heap: &mut BinaryHeap<Reverse<T>>, val: T, top_k: usize) {
    if top_k == 0 {
        // Nothing may be retained.
        min_heap.clear();
        return;
    }

    if min_heap.len() < top_k {
        min_heap.push(Reverse(val));
    } else if let Some(mut smallest) = min_heap.peek_mut() {
        // Overwrite the root in place: `PeekMut` restores heap order when it is
        // dropped, costing one sift instead of a separate pop and push.
        if val > smallest.0 {
            *smallest = Reverse(val);
        }
    }

    while min_heap.len() > top_k {
        min_heap.pop();
    }
}
/// Consumes a heap built by `maintain_top_k` and returns its values, largest first.
///
/// The elements are moved out of their [`Reverse`] wrappers rather than copied,
/// so `T` only needs `Ord` (not `Copy`).
fn maintain_top_k_finalize<T: Ord>(min_heap: BinaryHeap<Reverse<T>>) -> Vec<T> {
    // `into_sorted_vec` is ascending in `Reverse<T>` order, i.e. descending in `T`.
    min_heap
        .into_sorted_vec()
        .into_iter()
        .map(|Reverse(x)| x)
        .collect()
}
/// Demo: streams ten integers through a bounded min-heap and prints the three largest.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let k: usize = 3;
    let mut heap = BinaryHeap::with_capacity(k);

    // Feed every value; the helper keeps only the `k` largest.
    [3, 8, 0, 2, 6, 9, 5, 4, 7, 1]
        .iter()
        .for_each(|&x| maintain_top_k(&mut heap, x, k));

    let largest: Vec<_> = maintain_top_k_finalize(heap);
    println!("{:?}", largest);
    Ok(())
}
@leiless
Copy link
Author

leiless commented Aug 8, 2023

https://doc.rust-lang.org/stable/std/collections/struct.BinaryHeap.html

struct TopK<T>

use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// Tracks the `top_k` largest values pushed so far, using a bounded min-heap.
///
/// The smallest retained value sits at the heap's root, so deciding whether a
/// new value belongs in the top K is a single comparison, and admitting it
/// replaces that root in place.
#[derive(Debug, Clone)]
pub struct TopK<T> {
    // `Reverse` turns std's max-heap into a min-heap: the root is the smallest kept value.
    min_heap: BinaryHeap<Reverse<T>>,
    // Maximum number of values retained.
    top_k: usize,
}

impl<T: Ord> TopK<T> {
    /// Creates an empty tracker that retains at most `top_k` values.
    ///
    /// Storage for `top_k` elements is reserved up front.
    pub fn new(top_k: usize) -> Self {
        Self {
            min_heap: BinaryHeap::with_capacity(top_k),
            top_k,
        }
    }

    /// Offers `val` to the tracker and reports whether it was retained.
    ///
    /// Returns `true` when fewer than `top_k` values were held, or when `val` is
    /// strictly greater than the current smallest retained value (which it then
    /// replaces). Returns `false` otherwise; ties with the current minimum keep
    /// the earlier value, and a tracker with `top_k == 0` never retains anything.
    pub fn push(&mut self, val: T) -> bool {
        let retained = if self.min_heap.len() < self.top_k {
            self.min_heap.push(Reverse(val));
            true
        } else if let Some(mut min_elem) = self.min_heap.peek_mut() {
            // Only reached when the heap is full and non-empty (top_k > 0).
            // Writing through `PeekMut` replaces the root in place; heap order is
            // restored when the guard is dropped.
            if val > min_elem.0 {
                *min_elem = Reverse(val);
                true
            } else {
                false
            }
        } else {
            // top_k == 0: nothing can be retained.
            false
        };

        debug_assert!(self.min_heap.len() <= self.top_k,
                      "{} vs {}", self.min_heap.len(), self.top_k);

        retained
    }

    /// Number of values currently retained (never more than `top_k`).
    pub fn len(&self) -> usize {
        self.min_heap.len()
    }

    /// Returns `true` if no values are retained.
    pub fn is_empty(&self) -> bool {
        self.min_heap.is_empty()
    }

    /// Consumes the tracker and returns the retained values, largest first.
    ///
    /// Unlike `to_sorted_vec`, this moves the values out and needs no `Clone`.
    pub fn into_sorted_vec(self) -> Vec<T> {
        // Ascending in `Reverse<T>` order == descending in `T`.
        self.min_heap
            .into_sorted_vec()
            .into_iter()
            .map(|Reverse(v)| v)
            .collect()
    }
}

impl<T: Ord + Clone> TopK<T> {
    /// Returns a copy of the retained values, largest first, leaving the tracker intact.
    ///
    /// Each retained value is cloned exactly once (the heap itself is not cloned).
    pub fn to_sorted_vec(&self) -> Vec<T> {
        let mut out: Vec<T> = self.min_heap.iter().map(|Reverse(v)| v.clone()).collect();
        out.sort_unstable_by(|a, b| b.cmp(a));
        out
    }
}

/// Demo: pushes ten integers into a `TopK` tracker and prints the three largest.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut tracker = TopK::new(3usize);

    let stream = [3, 8, 0, 2, 6, 9, 5, 4, 7, 1];
    stream.iter().for_each(|&item| {
        // Whether the item was retained is not needed here.
        tracker.push(item);
    });

    println!("{:?}", tracker.to_sorted_vec());

    Ok(())
}
Output:

[9, 8, 7]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment