Created October 22, 2020 18:02
Save averagesecurityguy/93a55eae51de6649f5d8622a185c548e to your computer and use it in GitHub Desktop.
Simpler, generic Single Producer, Multiple Consumer pattern
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::sync::mpsc; | |
use std::sync; | |
/// A single-producer, multiple-consumer fan-out built on [`std::sync::mpsc`].
///
/// Each value passed to [`Spmc::send`] is delivered to exactly one consumer,
/// chosen round-robin across the receivers created by [`Spmc::add_receiver`].
/// Messages travel as `Option<T>`: `Some(item)` carries data and `None` is the
/// end-of-stream marker sent by [`Spmc::close`].
pub struct Spmc<T> {
    /// Number of round-robin steps taken so far; `count % chans.len()` selects
    /// the next consumer. Wraps on overflow instead of panicking.
    count: usize,
    /// One sending half per consumer, in the order the consumers were added.
    chans: Vec<mpsc::Sender<Option<T>>>,
}

impl<T> Spmc<T> {
    /// Sends `item` to the next consumer in round-robin order.
    ///
    /// If the selected consumer's receiver has been dropped, the item is
    /// offered to the following consumers in rotation instead, so a single
    /// dead consumer does not bring down the producer.
    ///
    /// # Panics
    ///
    /// Panics if no receivers have been added, or if every receiver has been
    /// dropped (the item cannot be delivered anywhere).
    pub fn send(&mut self, item: T) {
        let len = self.chans.len();
        assert!(
            len > 0,
            "Spmc::send called with no receivers; call add_receiver first"
        );

        let start = self.count % len;
        let mut msg = Some(item);
        for offset in 0..len {
            // `start + offset < 2 * len`, so this sum cannot overflow.
            match self.chans[(start + offset) % len].send(msg) {
                Ok(()) => {
                    // Step past every consumer we tried. In the normal case
                    // (offset == 0) this advances by exactly one.
                    self.count = self.count.wrapping_add(offset + 1);
                    return;
                }
                // That receiver is gone; `SendError` hands the message back
                // so we can offer it to the next consumer.
                Err(mpsc::SendError(returned)) => msg = returned,
            }
        }
        panic!("Spmc::send: all {} receivers have been dropped", len);
    }

    /// Registers a new consumer and returns its receiving half.
    ///
    /// The consumer joins the round-robin rotation immediately. It receives
    /// `Some(item)` for each value routed to it, and `None` once
    /// [`Spmc::close`] is called.
    #[must_use = "the returned receiver is the only way to consume items routed to this consumer"]
    pub fn add_receiver(&mut self) -> mpsc::Receiver<Option<T>> {
        let (tx, rx) = mpsc::channel();
        self.chans.push(tx);
        rx
    }

    /// Signals end-of-stream by sending `None` to every consumer.
    ///
    /// Consumers whose receivers have already been dropped are skipped. They
    /// no longer need the signal, so failing to reach them is not an error.
    /// The senders stay registered, so a later [`Spmc::send`] still delivers
    /// items (after the `None` marker).
    pub fn close(&mut self) {
        for tx in &self.chans {
            // Deliberately ignored: a disconnected receiver has nothing left
            // to shut down.
            let _ = tx.send(None);
        }
    }

    /// Creates an `Spmc` with no consumers.
    pub fn new() -> Spmc<T> {
        Spmc {
            count: 0,
            chans: Vec::new(),
        }
    }
}

impl<T> Default for Spmc<T> {
    /// Equivalent to [`Spmc::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_spmc_int() {
        // Send a list of integers through our Spmc. Several consumer threads
        // forward them to one collector channel, and every value must arrive
        // exactly once (arrival order is not preserved).
        let nums_to_send: Vec<i32> = (0..10).collect();
        let mut s = Spmc::new();
        let (send, recv) = sync::mpsc::channel();

        // Create threads to process the values. Each worker forwards `Some`
        // payloads and exits on the `None` end-of-stream marker.
        let threads: Vec<_> = (0..5)
            .map(|_| {
                let rx = s.add_receiver();
                let tx = mpsc::Sender::clone(&send);
                thread::spawn(move || {
                    for val in rx {
                        match val {
                            Some(val) => tx.send(val).unwrap(),
                            None => return,
                        }
                    }
                })
            })
            .collect();
        // Drop our own handle so `recv` ends once every worker's clone is gone.
        drop(send);

        // Send our values to Spmc to be processed by multiple consumers, then
        // close the channels once all data is added.
        for &n in &nums_to_send {
            s.send(n);
        }
        s.close();

        // Collect all of our sent values.
        let mut nums_received: Vec<i32> = recv.iter().collect();

        // Wait for all threads, surfacing any worker panic instead of
        // silently discarding it.
        for handle in threads {
            handle.join().expect("consumer thread panicked");
        }

        // Sort our numbers since they are out of order after being processed.
        nums_received.sort_unstable();
        assert_eq!(nums_to_send, nums_received);
    }

    #[test]
    fn test_spmc_round_robin() {
        // Values are dealt to receivers in the order the receivers were added,
        // and `close` appends a `None` marker to every receiver.
        let mut s = Spmc::new();
        let first = s.add_receiver();
        let second = s.add_receiver();
        for n in 0..5 {
            s.send(n);
        }
        s.close();
        let first: Vec<_> = first.try_iter().collect();
        let second: Vec<_> = second.try_iter().collect();
        assert_eq!(first, vec![Some(0), Some(2), Some(4), None]);
        assert_eq!(second, vec![Some(1), Some(3), None]);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.