Created August 8, 2025 04:43
-
-
Save CGamesPlay/967f06d3421f19c1bffc4aebb35b5ea3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! Proof of concept of a tokio::sync::watch-style channel with WeakSender support | |
//! Copyright Ryan Patterson - MIT license | |
use std::sync::{ | |
atomic::{AtomicUsize, Ordering}, | |
Arc, RwLock, RwLockReadGuard, TryLockError, Weak, | |
}; | |
use tokio::sync::Notify; | |
/// Thread-safe reference-counted reader-writer lock with change notification. | |
/// | |
/// Provides thread-safe access to a shared resource like `Arc<RwLock<T>>`, but additionally | |
/// provides methods to observe when other tasks update the value. | |
/// | |
/// [WatchWeak] objects created with [WatchHandle::downgrade] are additionally able to receive a | |
/// notification when the last WatchHandle is dropped. | |
pub struct WatchHandle<T> { | |
shared: Arc<WatchShared<T>>, | |
seen_version: AtomicUsize, | |
} | |
/// Weak reference to a [WatchHandle]. | |
pub struct WatchWeak<T> { | |
shared: Weak<WatchShared<T>>, | |
notify: Arc<Notify>, | |
seen_version: AtomicUsize, | |
} | |
struct WatchShared<T> { | |
data: RwLock<T>, | |
version: AtomicUsize, | |
notify: Arc<Notify>, | |
} | |
/// Read guard returned by [WatchHandle::borrow] and [WatchHandle::borrow_and_update].
///
/// The read lock is held until this guard is dropped.
pub struct WatchRef<'a, T>(RwLockReadGuard<'a, T>);
pub struct OwnedWatchRef<'a, T>( | |
#[allow(dead_code, reason = "must maintain strong reference")] Arc<WatchShared<T>>, | |
RwLockReadGuard<'a, T>, | |
); | |
#[derive(Debug, thiserror::Error)] | |
#[error("WatchHandle was dropped")] | |
pub struct WatchWeakError<T>(pub T); | |
impl<T> WatchHandle<T> { | |
pub fn new(data: T) -> Self { | |
Self { | |
shared: Arc::new(WatchShared { | |
data: RwLock::new(data), | |
version: AtomicUsize::new(0), | |
notify: Arc::new(Notify::new()), | |
}), | |
seen_version: AtomicUsize::new(0), | |
} | |
} | |
pub fn ptr_eq(&self, other: &WatchHandle<T>) -> bool { | |
Arc::ptr_eq(&self.shared, &other.shared) | |
} | |
pub fn strong_count(&self) -> usize { | |
Arc::strong_count(&self.shared) | |
} | |
pub fn weak_count(&self) -> usize { | |
Arc::weak_count(&self.shared) | |
} | |
pub fn downgrade(&self) -> WatchWeak<T> { | |
WatchWeak { | |
shared: Arc::downgrade(&self.shared), | |
notify: self.shared.notify.clone(), | |
seen_version: self.seen_version.load(Ordering::Relaxed).into(), | |
} | |
} | |
/// Return a reference to the current value. | |
/// | |
/// This method does not mark the current value as seen, use [Self::borrow_and_update] to do that in | |
/// an atomic fashion. | |
/// | |
/// The returned guard holds a read lock on the value. | |
pub fn borrow(&self) -> WatchRef<T> { | |
WatchRef(self.shared.data.read().unwrap()) | |
} | |
/// Return a reference to the current value and mark it as seen. | |
/// | |
/// The returned guard holds a read lock on the value. | |
pub fn borrow_and_update(&self) -> WatchRef<T> { | |
let guard = self.shared.data.read().unwrap(); | |
let latest_version = self.shared.version.load(Ordering::Relaxed); | |
self.seen_version.store(latest_version, Ordering::Relaxed); | |
WatchRef(guard) | |
} | |
/// Replace the current value with another and notify. | |
/// | |
/// The new value will be marked seen. | |
/// | |
/// Returns the previous value. | |
pub fn set(&self, value: T) -> T { | |
self.shared.set(value, &self.seen_version) | |
} | |
/// Update the current value in-place and notify. | |
/// | |
/// The new value will be marked seen. | |
pub fn update(&self, modify: impl FnOnce(&mut T)) { | |
self.shared.update(modify, &self.seen_version); | |
} | |
/// Update the current value in-place, conditionally. | |
/// | |
/// The closure must return `true` if the value has actually been modified, in order to send | |
/// change notifications. The value will be marked seen regardless. | |
pub fn maybe_update(&self, modify: impl FnOnce(&mut T) -> bool) { | |
self.shared.maybe_update(modify, &self.seen_version); | |
} | |
/// Checks if the value has changed. | |
/// | |
/// The value is marked as "unchanged" by [Self::borrow_and_update], [Self::changed], and [Self::mark_unchanged]. | |
pub fn has_changed(&self) -> bool { | |
let latest_version = self.shared.version.load(Ordering::Relaxed); | |
let seen_version = self.seen_version.load(Ordering::Relaxed); | |
seen_version != latest_version | |
} | |
/// Mark the current value as unseen by this handle. | |
/// | |
/// **Note**: This method never causes a pending async change notification on any | |
/// handles, including this one. | |
pub fn mark_changed(&self) { | |
let latest_version = self.shared.version.load(Ordering::Relaxed); | |
self.seen_version | |
.store(latest_version.wrapping_sub(1), Ordering::Relaxed); | |
} | |
/// Mark the current value as seen by this handle. | |
pub fn mark_unchanged(&self) { | |
let latest_version = self.shared.version.load(Ordering::Relaxed); | |
self.seen_version.store(latest_version, Ordering::Relaxed); | |
} | |
/// Waits for the next unseen value. | |
/// | |
/// This method ignores changes that originate from this handle, including | |
/// [Self::mark_changed]. | |
/// | |
/// If you are using this method in a loop, it is recommended to combine it with | |
/// [Self::borrow_and_update] instead of [Self::borrow]. This avoid a race condition where the | |
/// value is changed after `changed` resolves but before `borrow` starts, which would result in | |
/// the next call to `changed` immediately resolving, and seeing the same value in the | |
/// subsequent `borrow`. | |
#[allow(clippy::future_not_send, reason = "Send when T is Send")] | |
pub async fn changed(&self) { | |
let mut latest_version; | |
loop { | |
latest_version = self.shared.version.load(Ordering::Relaxed); | |
let seen_version = self.seen_version.load(Ordering::Relaxed); | |
if seen_version != latest_version { | |
break; | |
} | |
self.shared.notify.notified().await; | |
} | |
self.seen_version.store(latest_version, Ordering::Relaxed); | |
} | |
/// Waits for the value to match a predicate. | |
/// | |
/// This method immediately calls the closure on the current value and every subsequent | |
/// value until it returns true. Once it returns true, that value is marked seen and the | |
/// future resolves. | |
#[allow(clippy::future_not_send, reason = "Send when T is Send")] | |
pub async fn wait_for(&self, mut f: impl FnMut(&T) -> bool) { | |
let mut latest_version; | |
loop { | |
{ | |
let guard = self.shared.data.read().unwrap(); | |
latest_version = self.shared.version.load(Ordering::Relaxed); | |
let seen_version = self.seen_version.load(Ordering::Relaxed); | |
if seen_version != latest_version && f(&*guard) { | |
break; | |
} | |
} | |
self.shared.notify.notified().await; | |
} | |
self.seen_version.store(latest_version, Ordering::Relaxed); | |
} | |
} | |
impl<T> Clone for WatchHandle<T> { | |
fn clone(&self) -> Self { | |
WatchHandle { | |
shared: self.shared.clone(), | |
seen_version: self.seen_version.load(Ordering::Relaxed).into(), | |
} | |
} | |
} | |
impl<T> std::fmt::Debug for WatchHandle<T> | |
where | |
T: std::fmt::Debug, | |
{ | |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
f.debug_struct("WatchHandle") | |
.field("shared", &self.shared) | |
.field("seen_version", &self.seen_version) | |
.finish() | |
} | |
} | |
impl<T> WatchWeak<T> { | |
delegate::delegate! { | |
to self.shared { | |
pub fn strong_count(&self) -> usize; | |
pub fn weak_count(&self) -> usize; | |
} | |
} | |
pub fn upgrade(&self) -> Option<WatchHandle<T>> { | |
self.shared.upgrade().map(|shared| WatchHandle { | |
shared, | |
seen_version: self.seen_version.load(Ordering::Relaxed).into(), | |
}) | |
} | |
pub fn borrow(&self) -> Option<OwnedWatchRef<T>> { | |
self.shared.upgrade().map(OwnedWatchRef::new) | |
} | |
/// Return a reference to the current value and mark it as seen. | |
/// | |
/// The returned guard holds a read lock on the value. | |
pub fn borrow_and_update(&self) -> Option<OwnedWatchRef<T>> { | |
match self.shared.upgrade() { | |
None => None, | |
Some(arc) => { | |
let guard = OwnedWatchRef::new(arc); | |
let latest_version = guard.0.version.load(Ordering::Relaxed); | |
self.seen_version.store(latest_version, Ordering::Relaxed); | |
Some(guard) | |
} | |
} | |
} | |
pub fn set(&self, value: T) -> Result<T, WatchWeakError<T>> { | |
match self.shared.upgrade() { | |
Some(shared) => Ok(shared.set(value, &self.seen_version)), | |
None => Err(WatchWeakError(value)), | |
} | |
} | |
/// Update the current value in-place and notify. | |
/// | |
/// The new value will be marked seen. | |
/// | |
/// Returns true if there were any strong references at the time of call. | |
pub fn update(&self, modify: impl FnOnce(&mut T)) -> bool { | |
match self.shared.upgrade() { | |
Some(shared) => { | |
shared.update(modify, &self.seen_version); | |
true | |
} | |
None => false, | |
} | |
} | |
/// Update the current value in-place, conditionally. | |
/// | |
/// The closure must return `true` if the value has actually been modified, in order to send | |
/// change notifications. The value will be marked seen regardless. | |
/// | |
/// Returns true if there were any strong references at the time of call. | |
pub fn maybe_update(&self, modify: impl FnOnce(&mut T) -> bool) -> bool { | |
match self.shared.upgrade() { | |
Some(shared) => { | |
shared.maybe_update(modify, &self.seen_version); | |
true | |
} | |
None => false, | |
} | |
} | |
/// Checks if the value has changed. | |
/// | |
/// The value is marked as "unchanged" by [Self::borrow_and_update], [Self::changed], and [Self::mark_unchanged]. | |
pub fn has_changed(&self) -> Option<bool> { | |
self.shared.upgrade().map(|shared| { | |
let latest_version = shared.version.load(Ordering::Relaxed); | |
let seen_version = self.seen_version.load(Ordering::Relaxed); | |
seen_version != latest_version | |
}) | |
} | |
/// Mark the current value as unseen. | |
pub fn mark_changed(&self) { | |
let Some(shared) = self.shared.upgrade() else { | |
return; | |
}; | |
let latest_version = shared.version.load(Ordering::Relaxed); | |
self.seen_version | |
.store(latest_version.wrapping_sub(1), Ordering::Relaxed); | |
} | |
/// Mark the current value as seen. | |
pub fn mark_unchanged(&self) { | |
let Some(shared) = self.shared.upgrade() else { | |
return; | |
}; | |
let latest_version = shared.version.load(Ordering::Relaxed); | |
self.seen_version.store(latest_version, Ordering::Relaxed); | |
} | |
/// Waits for the value to change, then marks the value as seen. | |
/// | |
/// This method ignores changes that originate from this handle, including | |
/// [Self::mark_changed]. It also fails if the last strong reference to the handle is | |
/// dropped. | |
/// | |
/// If you are using this method in a loop, it is recommended to combine it with | |
/// [Self::borrow_and_update] instead of [Self::borrow]. This avoid a race condition where the | |
/// value is changed after `changed` resolves but before `borrow` starts, which would result in | |
/// the next call to `changed` immediately resolving, and seeing the same value in the | |
/// subsequent `borrow`. | |
#[allow(clippy::future_not_send, reason = "Send when T is Send")] | |
pub async fn changed(&self) -> Result<(), WatchWeakError<()>> { | |
let mut latest_version; | |
loop { | |
let Some(shared) = self.shared.upgrade() else { | |
return Err(WatchWeakError(())); | |
}; | |
latest_version = shared.version.load(Ordering::Relaxed); | |
drop(shared); | |
let seen_version = self.seen_version.load(Ordering::Relaxed); | |
if seen_version != latest_version { | |
break; | |
} | |
self.notify.notified().await; | |
} | |
self.seen_version.store(latest_version, Ordering::Relaxed); | |
Ok(()) | |
} | |
/// Waits for the value to match a predicate. | |
/// | |
/// This method immediately calls the closure on the current value and every subsequent | |
/// value until it returns true. Once it returns true, that value is marked seen and the | |
/// future resolves. | |
#[allow(clippy::future_not_send, reason = "Send when T is Send")] | |
pub async fn wait_for(&self, mut f: impl FnMut(&T) -> bool) -> Result<(), WatchWeakError<()>> { | |
let mut latest_version; | |
loop { | |
{ | |
let Some(shared) = self.shared.upgrade() else { | |
return Err(WatchWeakError(())); | |
}; | |
let guard = shared.data.read().unwrap(); | |
latest_version = shared.version.load(Ordering::Relaxed); | |
let seen_version = self.seen_version.load(Ordering::Relaxed); | |
if seen_version != latest_version && f(&*guard) { | |
break; | |
} | |
} | |
self.notify.notified().await; | |
} | |
self.seen_version.store(latest_version, Ordering::Relaxed); | |
Ok(()) | |
} | |
} | |
impl<T> Clone for WatchWeak<T> { | |
fn clone(&self) -> Self { | |
WatchWeak { | |
shared: self.shared.clone(), | |
notify: self.notify.clone(), | |
seen_version: self.seen_version.load(Ordering::Relaxed).into(), | |
} | |
} | |
} | |
impl<T> std::fmt::Debug for WatchWeak<T> { | |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
write!(f, "(WatchWeak)") | |
} | |
} | |
macro_rules! mark_changed_and_seen { | |
($self: ident, $seen_version: ident) => { | |
$seen_version.store( | |
$self | |
.version | |
.fetch_add(1, Ordering::Relaxed) | |
.wrapping_add(1), | |
Ordering::Relaxed, | |
); | |
$self.notify.notify_waiters(); | |
}; | |
} | |
impl<T> WatchShared<T> { | |
fn set(&self, mut value: T, seen_version: &AtomicUsize) -> T { | |
let mut guard = self.data.write().unwrap(); | |
std::mem::swap(&mut value, &mut *guard); | |
mark_changed_and_seen!(self, seen_version); | |
value | |
} | |
pub fn update(&self, modify: impl FnOnce(&mut T), seen_version: &AtomicUsize) { | |
let mut guard = self.data.write().unwrap(); | |
modify(&mut *guard); | |
mark_changed_and_seen!(self, seen_version); | |
} | |
pub fn maybe_update(&self, modify: impl FnOnce(&mut T) -> bool, seen_version: &AtomicUsize) { | |
let mut guard = self.data.write().unwrap(); | |
let modified = modify(&mut *guard); | |
if modified { | |
mark_changed_and_seen!(self, seen_version); | |
} | |
} | |
} | |
impl<T> Drop for WatchShared<T> { | |
fn drop(&mut self) { | |
self.notify.notify_waiters(); | |
} | |
} | |
impl<T> std::fmt::Debug for WatchShared<T> | |
where | |
T: std::fmt::Debug, | |
{ | |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
let mut d = f.debug_struct("WatchHandle"); | |
match self.data.try_read() { | |
Ok(guard) => { | |
d.field("data", &&*guard); | |
} | |
Err(TryLockError::Poisoned(err)) => { | |
d.field("data", &&**err.get_ref()); | |
} | |
Err(TryLockError::WouldBlock) => { | |
d.field("data", &format_args!("<locked>")); | |
} | |
} | |
d.field("poisoned", &self.data.is_poisoned()); | |
d.field("version", &self.version); | |
d.finish_non_exhaustive() | |
} | |
} | |
impl<T> std::ops::Deref for WatchRef<'_, T> { | |
type Target = T; | |
fn deref(&self) -> &Self::Target { | |
&self.0 | |
} | |
} | |
impl<'a, T> OwnedWatchRef<'a, T> { | |
fn new(arc: Arc<WatchShared<T>>) -> Self { | |
let guard = arc.data.read().unwrap(); | |
// SAFETY: arc cannot be dropped while guard is held | |
let guard: RwLockReadGuard<'a, T> = unsafe { std::mem::transmute(guard) }; | |
Self(arc, guard) | |
} | |
} | |
impl<T> std::ops::Deref for OwnedWatchRef<'_, T> { | |
type Target = T; | |
fn deref(&self) -> &Self::Target { | |
&self.1 | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use std::task::Poll;

    #[test]
    fn test_watch_handle_ref_counts() {
        let handle = WatchHandle::new(());
        let weak = handle.downgrade();
        assert_eq!(handle.strong_count(), 1);
        assert_eq!(weak.strong_count(), 1);
        assert_eq!(handle.weak_count(), 1);
        assert_eq!(weak.weak_count(), 1);
        let third = weak.upgrade().unwrap();
        drop(weak);
        assert_eq!(handle.weak_count(), 0);
        assert!(third.ptr_eq(&handle));
    }

    #[test]
    fn test_watch_handle_set() {
        let handle = WatchHandle::new(0);
        assert_eq!(*handle.borrow(), 0);
        assert!(!handle.has_changed());
        handle.set(1);
        assert_eq!(*handle.borrow(), 1);
        assert!(!handle.has_changed());
        handle.update(|x| {
            *x += 1;
        });
        assert_eq!(*handle.borrow(), 2);
        assert!(!handle.has_changed());
    }

    #[test]
    fn test_watch_handle_maybe_update() {
        let handle = WatchHandle::new(0);
        let clone = handle.clone();
        handle.maybe_update(|x| {
            *x += 1;
            true
        });
        assert!(!handle.has_changed());
        assert!(clone.has_changed());
        clone.mark_unchanged();
        handle.maybe_update(|_| false);
        assert!(!handle.has_changed());
        assert!(!clone.has_changed());
    }

    #[tokio::test]
    async fn test_watch_handle_changed() {
        let handle = WatchHandle::new(0);
        let other = handle.clone();
        let handle_changed = handle.changed();
        let other_changed = other.changed();
        futures::pin_mut!(handle_changed, other_changed);
        assert!(!handle.has_changed());
        assert!(!other.has_changed());
        assert_eq!(futures::poll!(&mut handle_changed), Poll::Pending);
        assert_eq!(futures::poll!(&mut other_changed), Poll::Pending);
        handle.set(1);
        assert!(!handle.has_changed());
        assert!(other.has_changed());
        assert_eq!(futures::poll!(&mut handle_changed), Poll::Pending);
        assert_eq!(futures::poll!(&mut other_changed), Poll::Ready(()));
        other.mark_unchanged();
        let other_changed = other.changed();
        futures::pin_mut!(other_changed);
        assert!(!handle.has_changed());
        assert!(!other.has_changed());
        assert_eq!(futures::poll!(&mut handle_changed), Poll::Pending);
        assert_eq!(futures::poll!(other_changed), Poll::Pending);
        other.mark_changed();
        assert!(other.has_changed());
        other.borrow();
        assert!(other.has_changed());
        other.borrow_and_update();
        assert!(!other.has_changed());
    }

    #[tokio::test]
    async fn test_watch_handle_wait_for() {
        let handle = WatchHandle::new(0);
        let other = handle.clone();
        let wait_for = other.wait_for(|x| *x == 2);
        futures::pin_mut!(wait_for);
        handle.set(1);
        assert_eq!(futures::poll!(&mut wait_for), Poll::Pending);
        handle.set(2);
        assert_eq!(futures::poll!(&mut wait_for), Poll::Ready(()));
    }

    #[test]
    fn test_watch_weak_set() {
        let handle = WatchHandle::new(0);
        let weak = handle.downgrade();
        assert_eq!(*weak.borrow().unwrap(), 0);
        assert!(!weak.has_changed().unwrap());
        weak.set(1).unwrap();
        assert_eq!(*weak.borrow().unwrap(), 1);
        assert!(!weak.has_changed().unwrap());
        assert!(weak.update(|x| {
            *x += 1;
        }));
        assert_eq!(*weak.borrow().unwrap(), 2);
        assert!(!weak.has_changed().unwrap());
        weak.mark_changed();
        assert!(weak.has_changed().unwrap());
        weak.mark_unchanged();
        assert!(!weak.has_changed().unwrap());
        weak.mark_changed();
        assert!(weak.has_changed().unwrap());
        weak.borrow_and_update();
        assert!(!weak.has_changed().unwrap());
    }

    #[tokio::test]
    async fn test_watch_weak_changed() {
        let strong = WatchHandle::new(0);
        let handle = strong.downgrade();
        let handle_changed = handle.changed();
        futures::pin_mut!(handle_changed);
        assert_eq!(handle.has_changed(), Some(false));
        assert_matches!(futures::poll!(&mut handle_changed), Poll::Pending);
        handle.set(1).unwrap();
        assert_eq!(handle.has_changed(), Some(false));
        assert_matches!(futures::poll!(&mut handle_changed), Poll::Pending);
        assert_eq!(handle.has_changed(), Some(false));
        assert_matches!(futures::poll!(&mut handle_changed), Poll::Pending);
        drop(strong);
        assert_matches!(
            futures::poll!(&mut handle_changed),
            Poll::Ready(Err(WatchWeakError(())))
        );
    }

    #[tokio::test]
    async fn test_watch_weak_wait_for() {
        let handle = WatchHandle::new(0);
        let weak = handle.downgrade();
        let fut = weak.wait_for(|x| *x == 2);
        futures::pin_mut!(fut);
        handle.set(1);
        assert_matches!(futures::poll!(&mut fut), Poll::Pending);
        handle.set(2);
        assert_matches!(futures::poll!(&mut fut), Poll::Ready(Ok(())));
        let fut = weak.wait_for(|_| false);
        futures::pin_mut!(fut);
        handle.set(1);
        assert_matches!(futures::poll!(&mut fut), Poll::Pending);
        drop(handle);
        assert_matches!(
            futures::poll!(&mut fut),
            Poll::Ready(Err(WatchWeakError(())))
        );
    }

    /// A guard from `WatchWeak::borrow` may outlive every `WatchHandle`. It must keep the value
    /// readable, and tearing it down must release the lock before the shared state goes away.
    /// Run under Miri (`cargo miri test`) to catch drop-order mistakes in `OwnedWatchRef`.
    #[test]
    fn test_owned_watch_ref_outlives_strong_handle() {
        let handle = WatchHandle::new(String::from("value"));
        let weak = handle.downgrade();
        let guard = weak.borrow().unwrap();
        drop(handle);
        // The guard is now the only strong reference.
        assert_eq!(weak.strong_count(), 1);
        assert_eq!(guard.as_str(), "value");
        drop(guard);
        assert_eq!(weak.strong_count(), 0);
        assert!(weak.upgrade().is_none());
        assert!(weak.borrow().is_none());
        assert!(weak.set(String::from("other")).is_err());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.