Skip to content

Instantly share code, notes, and snippets.

@kendru
Created July 10, 2020 10:12
Show Gist options
  • Save kendru/323a1d7b90c7750e0ecde5cf155cf184 to your computer and use it in GitHub Desktop.
Save kendru/323a1d7b90c7750e0ecde5cf155cf184 to your computer and use it in GitHub Desktop.
Simple thread-safe bounded queue
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::sync::mpsc;
use std::thread;
/// Shared state behind every clone of a [`Queue`], guarded by its mutex.
///
/// There is deliberately no `T: Send` bound here: trait bounds belong on the
/// impls that need them, and `Queue` already requires `T: Send` wherever this
/// type is used.
struct QueueInner<T> {
    /// Items waiting to be consumed, oldest first (pushed at the back,
    /// popped from the front).
    items: VecDeque<T>,
    /// Set by `Queue::mark_done`; once set and `items` is empty,
    /// `Queue::next_item` returns `None` instead of waiting.
    is_done: bool,
    /// One wake-up sender per producer blocked on a full queue, in the order
    /// they parked. A consumer sends `()` on the front sender after taking an
    /// item.
    parked_producers: VecDeque<mpsc::Sender<()>>,
}
/// A FIFO work queue that can be shared between threads by cloning it.
///
/// Every clone refers to the same underlying items. When `max_depth` is set,
/// producers block in `add_item` while the queue already holds that many
/// items.
///
/// Design note: ideally the reading and writing halves of the queue would be
/// separate types, so that a single producing thread could be guaranteed.
pub struct Queue<T>
where
    T: Send,
{
    max_depth: Option<usize>,
    inner: Arc<Mutex<QueueInner<T>>>,
}
impl<T: Send> Queue<T> {
pub fn new() -> Queue<T> {
Queue {
inner: Arc::new(Mutex::new(QueueInner {
items: VecDeque::new(),
is_done: false,
parked_producers: VecDeque::new(),
})),
max_depth: None,
}
}
pub fn with_max_depth(max_depth: usize) -> Queue<T> {
let mut queue = Queue::new();
queue.max_depth = Some(max_depth);
queue
}
// NOTICE: This will block the thread until another item is available or no
// more work is left to be done.
pub fn next_item(&self) -> Option<T> {
loop {
let mut inner = self.inner.lock()
.expect("Cannot recover from poisoned queue mutex");
match inner.items.pop_front() {
None => {
if inner.is_done {
return None;
}
}
v => {
if let Some(parked_producer) = inner.parked_producers.pop_front() {
// Resume 1 producer thread
parked_producer.send(()).expect("Could not wake producer");
}
return v
},
}
thread::yield_now();
}
}
pub fn add_item(&self, item: T) {
let mut inner = self.inner.lock()
.expect("Cannot recover from poisoned queue mutex");
if let Some(max_depth) = self.max_depth {
if inner.items.len() >= max_depth as usize {
// Wait for more items to be consumed before continuing
let (tx, rx) = mpsc::channel();
inner.parked_producers.push_back(tx);
drop(inner); // Release mutex while parked
rx.recv().expect("Could not receive wake notification");
// Re-acquire mutex once resumed
inner = self.inner.lock()
.expect("Cannot recover from poisoned queue mutex");
}
}
inner.items.push_back(item);
}
pub fn mark_done(&self) {
let mut inner = self.inner.lock().expect("Cannot recover from poisoned queue mutex");
if inner.parked_producers.len() > 0 {
panic!("Cannot mark a queue as done while there are parked producers");
}
inner.is_done = true;
}
}
impl<T> Clone for Queue<T>
where
T: Send
{
fn clone(&self) -> Self {
Queue {
inner: self.inner.clone(),
max_depth: self.max_depth,
}
}
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU8, Ordering};
    use super::*;

    #[test]
    fn distributes_work() {
        const ITEM_COUNT: usize = 100;
        let queue = Queue::<usize>::new();

        // Each worker drains items until the queue is done, then reports how
        // many items it took.
        let spawn_worker = |worker_queue: Queue<usize>| {
            thread::spawn(move || {
                let mut taken: usize = 0;
                while worker_queue.next_item().is_some() {
                    taken += 1; // Do some work
                }
                taken
            })
        };
        let t1 = spawn_worker(queue.clone());
        let t2 = spawn_worker(queue.clone());

        for i in 0..ITEM_COUNT {
            queue.add_item(i);
        }
        queue.mark_done(); // Allow workers to complete

        // These threads should complete and not hang the main thread, and
        // between them they must have taken every item that was added.
        let taken = t1.join().unwrap() + t2.join().unwrap();
        assert_eq!(ITEM_COUNT, taken);
    }

    #[test]
    fn blocks_producer() {
        let queue = Queue::<()>::with_max_depth(2);
        let added_count = Arc::new(AtomicU8::new(0));
        let (tx, rx) = mpsc::channel();
        let queue_producer = queue.clone();
        let ctr = added_count.clone();
        let producer = thread::spawn(move || {
            queue_producer.add_item(());
            ctr.fetch_add(1, Ordering::SeqCst);
            queue_producer.add_item(());
            ctr.fetch_add(1, Ordering::SeqCst);
            tx.send(()).unwrap();
            // The queue is now full, so this call must block until we consume.
            queue_producer.add_item(());
            ctr.fetch_add(1, Ordering::SeqCst);
            queue_producer.mark_done(); // Allow workers to complete
        });
        rx.recv().unwrap();
        // Give an unblocked producer time to (wrongly) add the third item.
        thread::sleep(std::time::Duration::from_millis(5));
        assert_eq!(2, added_count.load(Ordering::Acquire));
        while queue.next_item().is_some() {}
        producer.join().unwrap();
        assert_eq!(3, added_count.load(Ordering::Acquire));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment