Last active
October 28, 2025 15:38
-
-
Save Frando/3cd6ad0a0b493ad59076f25665215f1d to your computer and use it in GitHub Desktop.
WatchableMap
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use std::{ | |
| collections::BTreeMap, | |
| sync::{Arc, RwLock, RwLockReadGuard, Weak}, | |
| }; | |
| use n0_future::{Either, Stream, TryStreamExt}; | |
| use tokio::sync::broadcast::{self, error::RecvError}; | |
| use tokio_stream::{StreamExt, wrappers::BroadcastStream}; | |
/// Change events broadcast by a [`WatchableMap`] and observed through a [`MapWatcher`].
///
/// `PartialEq`/`Eq` are derived so that events (e.g. from `MapWatcher::stream_events`)
/// can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event<K, V> {
    /// An entry was inserted, or an existing key was updated to a different value.
    Insert(K, V),
    /// An entry was removed.
    Remove(K),
}
| /// A map that can be watched. | |
| #[derive(Debug)] | |
| pub struct WatchableMap<K, V> { | |
| map: Arc<RwLock<BTreeMap<K, V>>>, | |
| events: broadcast::Sender<Event<K, V>>, | |
| } | |
| impl<K: std::hash::Hash + Eq + Ord + Clone + Send + 'static, V: Send + Clone + Eq + 'static> | |
| WatchableMap<K, V> | |
| { | |
| /// Creates a new watchable map. | |
| /// | |
| /// `events_cap` is the capacity of the event channel on which watchers receive live updates. | |
| pub fn new(events_cap: usize) -> Self { | |
| let (events, _) = broadcast::channel(events_cap); | |
| Self { | |
| map: Default::default(), | |
| events, | |
| } | |
| } | |
| /// Creates a new watcher for this map. | |
| pub fn watch(&self) -> MapWatcher<K, V> { | |
| MapWatcher { | |
| map: Arc::downgrade(&self.map), | |
| events: self.events.downgrade(), | |
| } | |
| } | |
| /// Removes all entries from the map. | |
| pub fn clear(&mut self) { | |
| let mut map = self.map.write().expect("poisoned"); | |
| let old = std::mem::replace(&mut *map, BTreeMap::default()); | |
| drop(map); | |
| for (k, _v) in old.into_iter() { | |
| let event = Event::Remove(k); | |
| self.events.send(event).ok(); | |
| } | |
| } | |
| /// Inserts a new item. | |
| /// | |
| /// Note that this takes `&mut self` even though this type has interior mutability. This is so that we can | |
| /// handle out a read guard in [`Self::read`] without having to fear deadlocks. | |
| pub fn insert(&mut self, key: K, value: V) -> Option<V> { | |
| let previous = self | |
| .map | |
| .write() | |
| .expect("poisoned") | |
| .insert(key.clone(), value.clone()); | |
| let changed = previous.as_ref().map(|p| p != &value).unwrap_or(true); | |
| if changed { | |
| let event = Event::Insert(key, value); | |
| self.events.send(event).ok(); | |
| } | |
| previous | |
| } | |
| /// Removes an item. | |
| /// | |
| /// Note that this takes `&mut self` even though this type has interior mutability. This is so that we can | |
| /// handle out a read guard in [`Self::read`] without having to fear deadlocks. | |
| pub fn remove(&mut self, key: K) -> Option<V> { | |
| let previous = self.map.write().expect("poisoned").remove(&key); | |
| let event = Event::Remove(key); | |
| self.events.send(event).ok(); | |
| previous | |
| } | |
| /// Returns a read guard for the current state. | |
| pub fn read(&self) -> RwLockReadGuard<'_, BTreeMap<K, V>> { | |
| self.map.read().expect("poisoned") | |
| } | |
| } | |
| /// A watcher for a map. | |
| #[derive(Debug)] | |
| pub struct MapWatcher<K, V> { | |
| map: Weak<RwLock<BTreeMap<K, V>>>, | |
| events: broadcast::WeakSender<Event<K, V>>, | |
| } | |
| impl<K: std::hash::Hash + Eq + Clone + Send + 'static, V: Send + Clone + 'static> MapWatcher<K, V> { | |
| /// Returns a clone of the current state. | |
| pub fn cloned(&self) -> Result<BTreeMap<K, V>, Disconnected> { | |
| match self.map.upgrade() { | |
| None => Err(Disconnected), | |
| Some(map) => Ok(map.read().expect("poisoned").clone()), | |
| } | |
| } | |
| /// Takes a closure that gets to view the current state. | |
| /// | |
| /// Cheaper than [`Self::cloned`] because it does not clone the current state. | |
| pub fn peek<R>(&self, f: impl Fn(&BTreeMap<K, V>) -> R) -> Result<R, Disconnected> { | |
| match self.map.upgrade() { | |
| None => Err(Disconnected), | |
| Some(map) => { | |
| let map = map.read().expect("poisoned"); | |
| Ok(f(&map)) | |
| } | |
| } | |
| } | |
| /// Returns a future that completes once the map is updated. | |
| pub async fn updated(&self) -> Result<(), Disconnected> { | |
| match self.events.upgrade() { | |
| None => Err(Disconnected), | |
| Some(sender) => match sender.subscribe().recv().await { | |
| Ok(_) => Ok(()), | |
| Err(RecvError::Lagged(_)) => Ok(()), | |
| Err(RecvError::Closed) => Err(Disconnected), | |
| }, | |
| } | |
| } | |
| /// Returns a stream of all current items and items inserted in the future. | |
| pub fn stream(&self) -> impl Stream<Item = Result<(K, V), Lagged>> { | |
| match self.map.upgrade() { | |
| None => Either::Left(n0_future::stream::empty()), | |
| Some(map) => { | |
| let updates = self.stream_updates_only(); | |
| let map = map.read().expect("poisoned"); | |
| let current = map.clone().into_iter().map(Result::Ok); | |
| Either::Right(n0_future::stream::iter(current).chain(updates)) | |
| } | |
| } | |
| } | |
| /// Returns a stream of all items inserted after this call. | |
| pub fn stream_updates_only(&self) -> impl Stream<Item = Result<(K, V), Lagged>> { | |
| self.events().filter_map(|event| match event { | |
| Ok(Event::Insert(k, v)) => Some(Ok((k, v))), | |
| Ok(Event::Remove(_)) => None, | |
| Err(err) => Some(Err(err)), | |
| }) | |
| } | |
| /// Returns a stream of insert and remove events. | |
| pub fn stream_events(&self) -> impl Stream<Item = Result<Event<K, V>, Lagged>> { | |
| match self.events.upgrade() { | |
| None => Either::Left(n0_future::stream::empty()), | |
| Some(sender) => { | |
| let stream = BroadcastStream::new(sender.subscribe()).map_err(|_| Lagged); | |
| Either::Right(stream) | |
| } | |
| } | |
| } | |
| } | |
| /// Error when the stream from [`Endpoint::watch_connections`] was not consumed fast enough. | |
| #[derive(Debug, Clone, derive_more::Display, Eq, PartialEq)] | |
| #[display("Lagged")] | |
| pub struct Lagged; | |
| /// Returned when a watchable was dropped. | |
| #[derive(Debug, Clone, derive_more::Display, Eq, PartialEq)] | |
| #[display("Disconnected")] | |
| pub struct Disconnected; | |
| impl std::error::Error for Lagged {} | |
#[cfg(test)]
mod tests {
    use super::WatchableMap;
    use tokio_stream::StreamExt;

    /// A snapshot stream yields existing entries followed by later inserts, and an
    /// updates-only stream yields just the later inserts. Both end once the map is dropped.
    #[tokio::test]
    async fn watch_map() {
        let mut map = WatchableMap::new(32);
        map.insert("a", 1);

        let watcher = map.watch();
        let snapshot = watcher.cloned().unwrap();
        let snapshot_entries: Vec<_> = snapshot.into_iter().collect();
        assert_eq!(snapshot_entries, vec![("a", 1)]);

        let full_stream = watcher.stream();
        let update_stream = watcher.stream_updates_only();

        map.insert("b", 2);
        // Dropping the map closes the event channel, which terminates both streams.
        drop(map);

        let full_items = full_stream.collect::<Vec<_>>().await;
        assert_eq!(full_items, vec![Ok(("a", 1)), Ok(("b", 2))]);

        let update_items = update_stream.collect::<Vec<_>>().await;
        assert_eq!(update_items, vec![Ok(("b", 2))]);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment