Created
October 8, 2023 11:29
-
-
Save DimanNe/b212b36d6e5845c7067ddc6eb8fe1085 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// This module contains helpers for managing lifetime of async functions. | |
/// In particular, it wraps the following two things from tokio: | |
/// * cancellation token to notify "async functions/actors" that they have to stop and | |
/// * tokio::sync::mpsc::Sender object (when it is dropped, parent knows that all its "children" has stopped) | |
/// We use a channel to wait until all children has died (likely, because cancellation_token had told them) | |
/// See more at https://tokio.rs/tokio/topics/shutdown | |
/// Usage: | |
/// ``` | |
/// async fn my_task(..., pill: Pill) { | |
/// loop { | |
/// tokio::select! { | |
/// _ = pill.received() => { | |
/// log::info!("Exiting..."); | |
/// return; | |
/// } | |
/// ... | |
/// } | |
/// ``` | |
pub struct Pill { | |
cancellation_token: tokio_util::sync::CancellationToken, | |
child_died_signal: tokio::sync::mpsc::Sender<()>, | |
name: String, | |
} | |
impl Drop for Pill { | |
fn drop(&mut self) { | |
if self.name.is_empty() == false { | |
log::info!("Children {} finished", self.name); | |
} | |
} | |
} | |
impl Pill { | |
pub fn received(&self) -> impl std::future::Future<Output = ()> + '_ { self.cancellation_token.cancelled() } | |
} | |
/// This struct is created and used by parent. | |
/// Example: | |
/// ``` | |
/// let mut ck = poison_pill::ChildrenStopper::from_existing_cancellation_token(ct); | |
/// tokio::spawn(write_tetra_events(client, | |
/// tetra_slot, | |
/// ck.register_child("write_tetra_events"))); | |
/// ck.stop_and_wait().await; | |
/// ``` | |
pub struct ChildrenStopper { | |
cancellation_token: tokio_util::sync::CancellationToken, | |
child_died_slot: tokio::sync::mpsc::Receiver<()>, | |
child_died_signal: tokio::sync::mpsc::Sender<()>, | |
children_names: std::collections::HashMap<String, u64>, | |
} | |
impl ChildrenStopper { | |
pub fn new() -> Self { Self::from_existing_cancellation_token(tokio_util::sync::CancellationToken::new()) } | |
pub fn from_existing_cancellation_token(ct: tokio_util::sync::CancellationToken) -> Self { | |
let (signal, slot) = tokio::sync::mpsc::channel(1); | |
ChildrenStopper { cancellation_token: ct, | |
child_died_slot: slot, | |
child_died_signal: signal, | |
children_names: Default::default(), } | |
} | |
pub fn register_child(&mut self, name: &str) -> Pill { | |
*self.children_names.entry(name.into()).or_insert(0) += 1; | |
Pill { cancellation_token: self.cancellation_token.clone(), | |
child_died_signal: self.child_died_signal.clone(), | |
name: name.into(), } | |
} | |
pub async fn stop_and_wait(mut self) { | |
drop(self.child_died_signal); // drop our sender first because the recv() call otherwise sleeps forever. | |
let total: u64 = self.children_names.values().sum(); | |
log::info!("Waiting for {total} children to finish..."); | |
loop { | |
tokio::select! { | |
// _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { | |
// println!("Printing a message every second..."); | |
// } | |
_ = self.child_died_slot.recv() => { | |
log::info!("All {total} children finished"); | |
break; | |
} | |
} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment