Created
July 10, 2020 10:12
-
-
Save kendru/323a1d7b90c7750e0ecde5cf155cf184 to your computer and use it in GitHub Desktop.
Simple thread-safe bounded queue
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::VecDeque; | |
use std::sync::{Arc, Mutex}; | |
use std::sync::mpsc; | |
use std::thread; | |
/// Mutable state shared by every handle to a `Queue`; it always lives behind
/// the queue's `Mutex`.
///
/// No `T: Send` bound is placed on the struct itself: the `Queue` type and its
/// `impl` blocks already require it, so repeating it here adds nothing.
struct QueueInner<T> {
    /// Buffered items: consumers pop from the front, producers push to the back.
    items: VecDeque<T>,
    /// Set by `Queue::mark_done`; once `items` is empty, consumers stop waiting
    /// and receive `None`.
    is_done: bool,
    /// Wake-up senders for producers blocked in `Queue::add_item` because the
    /// queue was full, in the order they parked.
    parked_producers: VecDeque<mpsc::Sender<()>>,
}
// Ideally, we would separate the reading and writing pieces of the queue | |
// so that we could guarantee a single producing thread. | |
pub struct Queue<T: Send> { | |
inner: Arc<Mutex<QueueInner<T>>>, | |
max_depth: Option<usize>, | |
} | |
impl<T: Send> Queue<T> { | |
pub fn new() -> Queue<T> { | |
Queue { | |
inner: Arc::new(Mutex::new(QueueInner { | |
items: VecDeque::new(), | |
is_done: false, | |
parked_producers: VecDeque::new(), | |
})), | |
max_depth: None, | |
} | |
} | |
pub fn with_max_depth(max_depth: usize) -> Queue<T> { | |
let mut queue = Queue::new(); | |
queue.max_depth = Some(max_depth); | |
queue | |
} | |
// NOTICE: This will block the thread until another item is available or no | |
// more work is left to be done. | |
pub fn next_item(&self) -> Option<T> { | |
loop { | |
let mut inner = self.inner.lock() | |
.expect("Cannot recover from poisoned queue mutex"); | |
match inner.items.pop_front() { | |
None => { | |
if inner.is_done { | |
return None; | |
} | |
} | |
v => { | |
if let Some(parked_producer) = inner.parked_producers.pop_front() { | |
// Resume 1 producer thread | |
parked_producer.send(()).expect("Could not wake producer"); | |
} | |
return v | |
}, | |
} | |
thread::yield_now(); | |
} | |
} | |
pub fn add_item(&self, item: T) { | |
let mut inner = self.inner.lock() | |
.expect("Cannot recover from poisoned queue mutex"); | |
if let Some(max_depth) = self.max_depth { | |
if inner.items.len() >= max_depth as usize { | |
// Wait for more items to be consumed before continuing | |
let (tx, rx) = mpsc::channel(); | |
inner.parked_producers.push_back(tx); | |
drop(inner); // Release mutex while parked | |
rx.recv().expect("Could not receive wake notification"); | |
// Re-acquire mutex once resumed | |
inner = self.inner.lock() | |
.expect("Cannot recover from poisoned queue mutex"); | |
} | |
} | |
inner.items.push_back(item); | |
} | |
pub fn mark_done(&self) { | |
let mut inner = self.inner.lock().expect("Cannot recover from poisoned queue mutex"); | |
if inner.parked_producers.len() > 0 { | |
panic!("Cannot mark a queue as done while there are parked producers"); | |
} | |
inner.is_done = true; | |
} | |
} | |
impl<T> Clone for Queue<T> | |
where | |
T: Send | |
{ | |
fn clone(&self) -> Self { | |
Queue { | |
inner: self.inner.clone(), | |
max_depth: self.max_depth, | |
} | |
} | |
} | |
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
    use super::*;

    /// Items come out in insertion order, then `None` once the queue is done.
    #[test]
    fn preserves_fifo_order() {
        let queue = Queue::new();
        for i in 0..3 {
            queue.add_item(i);
        }
        queue.mark_done();
        assert_eq!(Some(0), queue.next_item());
        assert_eq!(Some(1), queue.next_item());
        assert_eq!(Some(2), queue.next_item());
        assert_eq!(None, queue.next_item());
    }

    /// Two consumers share the work; together they must take every item
    /// exactly once and then exit after `mark_done`.
    #[test]
    fn distributes_work() {
        const ITEM_COUNT: usize = 100;
        let queue = Queue::<()>::new();
        let consumed = Arc::new(AtomicUsize::new(0));
        let workers: Vec<_> = (0..2)
            .map(|_| {
                let worker_queue = queue.clone();
                let worker_consumed = Arc::clone(&consumed);
                thread::spawn(move || {
                    while worker_queue.next_item().is_some() {
                        // Do some work
                        worker_consumed.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();
        for _ in 0..ITEM_COUNT {
            queue.add_item(());
        }
        queue.mark_done(); // Allow workers to complete
        // These threads should complete and not hang the main thread
        for worker in workers {
            worker.join().unwrap();
        }
        assert_eq!(ITEM_COUNT, consumed.load(Ordering::SeqCst));
    }

    /// A producer adding to a full bounded queue stays parked until a consumer
    /// takes an item.
    #[test]
    fn blocks_producer() {
        let queue = Queue::<()>::with_max_depth(2);
        let added_count = Arc::new(AtomicU8::new(0));
        let (tx, rx) = mpsc::channel();
        let queue_producer = queue.clone();
        let ctr = added_count.clone();
        let producer = thread::spawn(move || {
            queue_producer.add_item(());
            ctr.fetch_add(1, Ordering::SeqCst);
            queue_producer.add_item(());
            ctr.fetch_add(1, Ordering::SeqCst);
            tx.send(()).unwrap();
            // The queue is now full, so this call parks until an item is consumed.
            queue_producer.add_item(());
            ctr.fetch_add(1, Ordering::SeqCst);
            queue_producer.mark_done(); // Allow workers to complete
        });
        rx.recv().unwrap();
        // Give a misbehaving producer time to push past the limit.
        thread::sleep(std::time::Duration::from_millis(5));
        assert_eq!(2, added_count.load(Ordering::Acquire));
        while queue.next_item().is_some() {}
        producer.join().unwrap();
        assert_eq!(3, added_count.load(Ordering::Acquire));
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment