Skip to content

Instantly share code, notes, and snippets.

@CGamesPlay
Created August 8, 2025 04:43
Show Gist options
  • Save CGamesPlay/967f06d3421f19c1bffc4aebb35b5ea3 to your computer and use it in GitHub Desktop.
Save CGamesPlay/967f06d3421f19c1bffc4aebb35b5ea3 to your computer and use it in GitHub Desktop.
//! Proof of concept of a tokio::sync::watch-style channel with WeakSender support
//! Copyright Ryan Patterson - MIT license
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc, RwLock, RwLockReadGuard, TryLockError, Weak,
};
use tokio::sync::Notify;
/// Thread-safe reference-counted reader-writer lock with change notification.
///
/// Provides thread-safe access to a shared resource like `Arc<RwLock<T>>`, but additionally
/// provides methods to observe when other tasks update the value.
///
/// [WatchWeak] objects created with [WatchHandle::downgrade] are additionally able to receive a
/// notification when the last WatchHandle is dropped.
///
/// Each handle (including each clone) tracks separately which version of the value it has
/// already seen.
pub struct WatchHandle<T> {
    // Strong reference to the shared state; clones of this handle share the same allocation.
    shared: Arc<WatchShared<T>>,
    // Version of `shared` that this handle last observed or wrote; compared against
    // `shared.version` to decide whether the value has changed for this handle.
    seen_version: AtomicUsize,
}
/// Weak reference to a [WatchHandle].
///
/// Does not keep the value alive. Once the shared value has been dropped, the value-accessing
/// methods report it (`None`, `Err`, or `false`) instead of acting.
pub struct WatchWeak<T> {
    // Non-owning pointer to the shared state; upgraded temporarily by each operation.
    shared: Weak<WatchShared<T>>,
    // A separate strong reference to the notifier, so waiters can still be woken by the
    // `notify_waiters` call in `WatchShared`'s `Drop` after the shared state is gone.
    notify: Arc<Notify>,
    // Version last observed or written through this weak handle.
    seen_version: AtomicUsize,
}
/// State shared by every [WatchHandle] and [WatchWeak] that refer to the same value.
struct WatchShared<T> {
    // The watched value.
    data: RwLock<T>,
    // Incremented (with wraparound) on every write that sends a change notification. In this
    // file it is only incremented while the write lock on `data` is held.
    version: AtomicUsize,
    // Woken after each such write, and when this struct is dropped.
    notify: Arc<Notify>,
}
/// Read guard returned by [WatchHandle::borrow] and [WatchHandle::borrow_and_update].
///
/// Holds a read lock on the value until dropped.
pub struct WatchRef<'a, T>(RwLockReadGuard<'a, T>);
/// Read guard returned by [WatchWeak::borrow] and [WatchWeak::borrow_and_update].
///
/// Holds a strong reference to the shared state next to the read lock, so the value is not
/// freed just because every [WatchHandle] was dropped.
///
/// NOTE(review): tuple-struct fields are dropped in declaration order, so the `Arc` (field 0)
/// is released *before* the guard (field 1). If this is the last strong reference, the
/// `RwLock` is freed and the guard then unlocks freed memory (use-after-free). The guard must
/// be dropped first — e.g. declare it as field 0, or wrap it in `ManuallyDrop` and release it
/// in a `Drop` impl — which also requires updating the constructor and every `.0`/`.1` access.
pub struct OwnedWatchRef<'a, T>(
    #[allow(dead_code, reason = "must maintain strong reference")] Arc<WatchShared<T>>,
    RwLockReadGuard<'a, T>,
);
/// Error returned by [WatchWeak] operations when the shared value has already been dropped.
///
/// Carries back the value the operation would have consumed (the argument of
/// [WatchWeak::set]); operations with nothing to return use `()`.
#[derive(Debug, thiserror::Error)]
#[error("WatchHandle was dropped")]
pub struct WatchWeakError<T>(pub T);
impl<T> WatchHandle<T> {
    /// Create a new handle that owns `data`.
    ///
    /// The initial value starts out as seen by this handle.
    pub fn new(data: T) -> Self {
        Self {
            shared: Arc::new(WatchShared {
                data: RwLock::new(data),
                version: AtomicUsize::new(0),
                notify: Arc::new(Notify::new()),
            }),
            seen_version: AtomicUsize::new(0),
        }
    }

    /// Returns `true` if `self` and `other` refer to the same shared value.
    pub fn ptr_eq(&self, other: &WatchHandle<T>) -> bool {
        Arc::ptr_eq(&self.shared, &other.shared)
    }

    /// Number of strong references to the shared value (see [Arc::strong_count]).
    pub fn strong_count(&self) -> usize {
        Arc::strong_count(&self.shared)
    }

    /// Number of weak references to the shared value (see [Arc::weak_count]).
    pub fn weak_count(&self) -> usize {
        Arc::weak_count(&self.shared)
    }

    /// Create a [WatchWeak] referring to the same value.
    ///
    /// The weak handle starts with this handle's seen version and tracks it independently
    /// afterwards.
    pub fn downgrade(&self) -> WatchWeak<T> {
        WatchWeak {
            shared: Arc::downgrade(&self.shared),
            notify: self.shared.notify.clone(),
            seen_version: self.seen_version.load(Ordering::Relaxed).into(),
        }
    }

    /// Return a reference to the current value.
    ///
    /// This method does not mark the current value as seen, use [Self::borrow_and_update] to do
    /// that in an atomic fashion.
    ///
    /// The returned guard holds a read lock on the value.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn borrow(&self) -> WatchRef<'_, T> {
        WatchRef(self.shared.data.read().unwrap())
    }

    /// Return a reference to the current value and mark it as seen.
    ///
    /// The returned guard holds a read lock on the value.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn borrow_and_update(&self) -> WatchRef<'_, T> {
        let guard = self.shared.data.read().unwrap();
        // Writers bump the version while holding the write lock, so the version read under
        // our read lock is the one belonging to the value `guard` points at.
        let latest_version = self.shared.version.load(Ordering::Relaxed);
        self.seen_version.store(latest_version, Ordering::Relaxed);
        WatchRef(guard)
    }

    /// Replace the current value with another and notify.
    ///
    /// The new value will be marked seen.
    ///
    /// Returns the previous value.
    pub fn set(&self, value: T) -> T {
        self.shared.set(value, &self.seen_version)
    }

    /// Update the current value in-place and notify.
    ///
    /// The new value will be marked seen.
    pub fn update(&self, modify: impl FnOnce(&mut T)) {
        self.shared.update(modify, &self.seen_version);
    }

    /// Update the current value in-place, conditionally.
    ///
    /// The closure must return `true` if the value has actually been modified, in order to send
    /// change notifications. The value will be marked seen regardless.
    pub fn maybe_update(&self, modify: impl FnOnce(&mut T) -> bool) {
        self.shared.maybe_update(modify, &self.seen_version);
    }

    /// Checks if the value has changed since this handle last saw it.
    ///
    /// The value is marked as "unchanged" by writes through this handle,
    /// [Self::borrow_and_update], [Self::changed], [Self::wait_for], and [Self::mark_unchanged].
    pub fn has_changed(&self) -> bool {
        let latest_version = self.shared.version.load(Ordering::Relaxed);
        let seen_version = self.seen_version.load(Ordering::Relaxed);
        seen_version != latest_version
    }

    /// Mark the current value as unseen by this handle.
    ///
    /// **Note**: this does not notify anyone. Futures already waiting in [Self::changed] or
    /// [Self::wait_for] on any handle (including this one) are not woken by it; only later
    /// checks observe the change.
    pub fn mark_changed(&self) {
        let latest_version = self.shared.version.load(Ordering::Relaxed);
        self.seen_version
            .store(latest_version.wrapping_sub(1), Ordering::Relaxed);
    }

    /// Mark the current value as seen by this handle.
    pub fn mark_unchanged(&self) {
        let latest_version = self.shared.version.load(Ordering::Relaxed);
        self.seen_version.store(latest_version, Ordering::Relaxed);
    }

    /// Waits for the next unseen value, then marks it as seen.
    ///
    /// Writes made through this handle mark their own value as seen, so they do not make this
    /// future resolve. [Self::mark_changed] does not wake a future that is already waiting,
    /// but a call to `changed` made after it resolves immediately (unless this handle writes
    /// in between).
    ///
    /// If you are using this method in a loop, it is recommended to combine it with
    /// [Self::borrow_and_update] instead of [Self::borrow]. This avoids a race condition where the
    /// value is changed after `changed` resolves but before `borrow` starts, which would result in
    /// the next call to `changed` immediately resolving, and seeing the same value in the
    /// subsequent `borrow`.
    #[allow(clippy::future_not_send, reason = "Send when T is Send")]
    pub async fn changed(&self) {
        let latest_version = loop {
            // Register interest *before* checking the version: a `Notified` future receives
            // `notify_waiters` wakeups from the moment it is created, so a write landing
            // between the check and the `.await` cannot be lost.
            let notified = self.shared.notify.notified();
            let current = self.shared.version.load(Ordering::Relaxed);
            if self.seen_version.load(Ordering::Relaxed) != current {
                break current;
            }
            notified.await;
        };
        self.seen_version.store(latest_version, Ordering::Relaxed);
    }

    /// Waits for an unseen value that matches a predicate, then marks it as seen.
    ///
    /// The closure is called on the current value if this handle has not seen it yet, and then
    /// on each subsequent unseen value, until it returns `true`. Values written through this
    /// handle are already marked seen and are therefore skipped.
    ///
    /// The closure runs while the read lock is held, so it must not try to write the value.
    #[allow(clippy::future_not_send, reason = "Send when T is Send")]
    pub async fn wait_for(&self, mut f: impl FnMut(&T) -> bool) {
        let latest_version = loop {
            // Created before the check for the same lost-wakeup reason as in `changed`.
            let notified = self.shared.notify.notified();
            {
                let guard = self.shared.data.read().unwrap();
                let current = self.shared.version.load(Ordering::Relaxed);
                if self.seen_version.load(Ordering::Relaxed) != current && f(&*guard) {
                    break current;
                }
            }
            notified.await;
        };
        self.seen_version.store(latest_version, Ordering::Relaxed);
    }
}
impl<T> Clone for WatchHandle<T> {
fn clone(&self) -> Self {
WatchHandle {
shared: self.shared.clone(),
seen_version: self.seen_version.load(Ordering::Relaxed).into(),
}
}
}
impl<T> std::fmt::Debug for WatchHandle<T>
where
    T: std::fmt::Debug,
{
    /// Formats the shared state followed by this handle's seen version.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { shared, seen_version } = self;
        let mut out = f.debug_struct("WatchHandle");
        out.field("shared", shared);
        out.field("seen_version", seen_version);
        out.finish()
    }
}
impl<T> WatchWeak<T> {
    // `strong_count` and `weak_count` forward to the same-named methods on the underlying
    // `std::sync::Weak`.
    delegate::delegate! {
        to self.shared {
            pub fn strong_count(&self) -> usize;
            pub fn weak_count(&self) -> usize;
        }
    }

    /// Upgrade to a strong [WatchHandle], or `None` if the shared value has been dropped.
    ///
    /// The new handle starts with this weak handle's seen version.
    pub fn upgrade(&self) -> Option<WatchHandle<T>> {
        self.shared.upgrade().map(|shared| WatchHandle {
            shared,
            seen_version: self.seen_version.load(Ordering::Relaxed).into(),
        })
    }

    /// Return a reference to the current value without marking it as seen, or `None` if the
    /// shared value has been dropped.
    ///
    /// The returned guard holds a read lock on the value and a strong reference to it.
    pub fn borrow(&self) -> Option<OwnedWatchRef<'_, T>> {
        self.shared.upgrade().map(OwnedWatchRef::new)
    }

    /// Return a reference to the current value and mark it as seen, or `None` if the shared
    /// value has been dropped.
    ///
    /// The returned guard holds a read lock on the value and a strong reference to it.
    pub fn borrow_and_update(&self) -> Option<OwnedWatchRef<'_, T>> {
        let shared = self.shared.upgrade()?;
        let guard = OwnedWatchRef::new(shared);
        // Read the version while the read lock is held so it matches the borrowed value.
        let latest_version = guard.0.version.load(Ordering::Relaxed);
        self.seen_version.store(latest_version, Ordering::Relaxed);
        Some(guard)
    }

    /// Replace the current value with another and notify.
    ///
    /// The new value will be marked seen. Returns the previous value, or hands `value` back
    /// inside the error if the shared value has been dropped.
    pub fn set(&self, value: T) -> Result<T, WatchWeakError<T>> {
        match self.shared.upgrade() {
            Some(shared) => Ok(shared.set(value, &self.seen_version)),
            None => Err(WatchWeakError(value)),
        }
    }

    /// Update the current value in-place and notify.
    ///
    /// The new value will be marked seen.
    ///
    /// Returns true if there were any strong references at the time of call.
    pub fn update(&self, modify: impl FnOnce(&mut T)) -> bool {
        match self.shared.upgrade() {
            Some(shared) => {
                shared.update(modify, &self.seen_version);
                true
            }
            None => false,
        }
    }

    /// Update the current value in-place, conditionally.
    ///
    /// The closure must return `true` if the value has actually been modified, in order to send
    /// change notifications. The value will be marked seen regardless.
    ///
    /// Returns true if there were any strong references at the time of call.
    pub fn maybe_update(&self, modify: impl FnOnce(&mut T) -> bool) -> bool {
        match self.shared.upgrade() {
            Some(shared) => {
                shared.maybe_update(modify, &self.seen_version);
                true
            }
            None => false,
        }
    }

    /// Checks if the value has changed since this weak handle last saw it.
    ///
    /// Returns `None` if the shared value has been dropped. The value is marked as "unchanged"
    /// by writes through this handle, [Self::borrow_and_update], [Self::changed],
    /// [Self::wait_for], and [Self::mark_unchanged].
    pub fn has_changed(&self) -> Option<bool> {
        self.shared.upgrade().map(|shared| {
            let latest_version = shared.version.load(Ordering::Relaxed);
            let seen_version = self.seen_version.load(Ordering::Relaxed);
            seen_version != latest_version
        })
    }

    /// Mark the current value as unseen. Does nothing if the shared value has been dropped.
    ///
    /// Like [WatchHandle::mark_changed], this does not wake any waiting futures.
    pub fn mark_changed(&self) {
        let Some(shared) = self.shared.upgrade() else {
            return;
        };
        let latest_version = shared.version.load(Ordering::Relaxed);
        self.seen_version
            .store(latest_version.wrapping_sub(1), Ordering::Relaxed);
    }

    /// Mark the current value as seen. Does nothing if the shared value has been dropped.
    pub fn mark_unchanged(&self) {
        let Some(shared) = self.shared.upgrade() else {
            return;
        };
        let latest_version = shared.version.load(Ordering::Relaxed);
        self.seen_version.store(latest_version, Ordering::Relaxed);
    }

    /// Waits for the value to change, then marks the value as seen.
    ///
    /// Writes made through this weak handle mark their own value as seen, so they do not make
    /// this future resolve; neither does [Self::mark_changed] for a future that is already
    /// waiting. Returns `Err` once the shared value has been dropped, whether that happens
    /// before or while waiting.
    ///
    /// If you are using this method in a loop, it is recommended to combine it with
    /// [Self::borrow_and_update] instead of [Self::borrow]. This avoids a race condition where the
    /// value is changed after `changed` resolves but before `borrow` starts, which would result in
    /// the next call to `changed` immediately resolving, and seeing the same value in the
    /// subsequent `borrow`.
    #[allow(clippy::future_not_send, reason = "Send when T is Send")]
    pub async fn changed(&self) -> Result<(), WatchWeakError<()>> {
        let latest_version = loop {
            // Register for wakeups *before* upgrading and checking the version. A `Notified`
            // future receives `notify_waiters` calls from the moment it is created, so a write
            // — or the final drop of the shared state, which also calls `notify_waiters` —
            // that happens between the check and the `.await` is not lost. Without this, a
            // missed drop notification would leave this future pending forever.
            let notified = self.notify.notified();
            let Some(shared) = self.shared.upgrade() else {
                return Err(WatchWeakError(()));
            };
            let current = shared.version.load(Ordering::Relaxed);
            // Don't keep the value alive while sleeping. If this was the last strong
            // reference, the resulting drop wakes `notified` and the next iteration errors.
            drop(shared);
            if self.seen_version.load(Ordering::Relaxed) != current {
                break current;
            }
            notified.await;
        };
        self.seen_version.store(latest_version, Ordering::Relaxed);
        Ok(())
    }

    /// Waits for an unseen value that matches a predicate, then marks it as seen.
    ///
    /// The closure is called on the current value if this handle has not seen it yet, and then
    /// on each subsequent unseen value, until it returns `true`. Values written through this
    /// handle are already marked seen and are therefore skipped. Returns `Err` once the shared
    /// value has been dropped.
    ///
    /// The closure runs while the read lock is held, so it must not try to write the value.
    #[allow(clippy::future_not_send, reason = "Send when T is Send")]
    pub async fn wait_for(&self, mut f: impl FnMut(&T) -> bool) -> Result<(), WatchWeakError<()>> {
        let latest_version = loop {
            // Created before upgrading for the same lost-wakeup reason as in `changed`.
            let notified = self.notify.notified();
            {
                let Some(shared) = self.shared.upgrade() else {
                    return Err(WatchWeakError(()));
                };
                let guard = shared.data.read().unwrap();
                let current = shared.version.load(Ordering::Relaxed);
                if self.seen_version.load(Ordering::Relaxed) != current && f(&*guard) {
                    break current;
                }
                // `guard` is dropped before `shared` here (reverse declaration order), and
                // both are released before sleeping.
            }
            notified.await;
        };
        self.seen_version.store(latest_version, Ordering::Relaxed);
        Ok(())
    }
}
impl<T> Clone for WatchWeak<T> {
fn clone(&self) -> Self {
WatchWeak {
shared: self.shared.clone(),
notify: self.notify.clone(),
seen_version: self.seen_version.load(Ordering::Relaxed).into(),
}
}
}
impl<T> std::fmt::Debug for WatchWeak<T> {
    /// Prints a fixed placeholder; neither the value nor the seen version is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("(WatchWeak)")
    }
}
/// Bump `$shared.version` (wrapping on overflow), record the new version in `$seen`, and wake
/// every task waiting on `$shared.notify`.
///
/// Expands to statements. Every call site in this file invokes it while holding the write lock
/// on `$shared.data`.
macro_rules! mark_changed_and_seen {
    ($shared:ident, $seen:ident) => {
        let new_version = $shared
            .version
            .fetch_add(1, Ordering::Relaxed)
            .wrapping_add(1);
        $seen.store(new_version, Ordering::Relaxed);
        $shared.notify.notify_waiters();
    };
}
impl<T> WatchShared<T> {
fn set(&self, mut value: T, seen_version: &AtomicUsize) -> T {
let mut guard = self.data.write().unwrap();
std::mem::swap(&mut value, &mut *guard);
mark_changed_and_seen!(self, seen_version);
value
}
pub fn update(&self, modify: impl FnOnce(&mut T), seen_version: &AtomicUsize) {
let mut guard = self.data.write().unwrap();
modify(&mut *guard);
mark_changed_and_seen!(self, seen_version);
}
pub fn maybe_update(&self, modify: impl FnOnce(&mut T) -> bool, seen_version: &AtomicUsize) {
let mut guard = self.data.write().unwrap();
let modified = modify(&mut *guard);
if modified {
mark_changed_and_seen!(self, seen_version);
}
}
}
impl<T> Drop for WatchShared<T> {
fn drop(&mut self) {
self.notify.notify_waiters();
}
}
impl<T> std::fmt::Debug for WatchShared<T>
where
    T: std::fmt::Debug,
{
    /// Formats the value, poisoning state, and version; the notifier is omitted.
    ///
    /// Never blocks: if the lock cannot be acquired immediately the value is shown as
    /// `<locked>`. A poisoned lock still yields its guard, so the value is shown in that case.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Named after this type (it is embedded inside `WatchHandle`'s own Debug output).
        let mut d = f.debug_struct("WatchShared");
        match self.data.try_read() {
            Ok(guard) => {
                d.field("data", &&*guard);
            }
            Err(TryLockError::Poisoned(err)) => {
                d.field("data", &&**err.get_ref());
            }
            Err(TryLockError::WouldBlock) => {
                d.field("data", &format_args!("<locked>"));
            }
        }
        d.field("poisoned", &self.data.is_poisoned());
        d.field("version", &self.version);
        // `notify` is not shown, hence `finish_non_exhaustive`.
        d.finish_non_exhaustive()
    }
}
impl<T> std::ops::Deref for WatchRef<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'a, T> OwnedWatchRef<'a, T> {
    /// Acquire a read lock on `arc`'s value and bundle the guard with the `Arc` that owns it.
    ///
    /// The lifetime `'a` is not tied to any input; it is chosen by the caller.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    fn new(arc: Arc<WatchShared<T>>) -> Self {
        let guard = arc.data.read().unwrap();
        // SAFETY: arc cannot be dropped while guard is held
        //
        // NOTE(review): this claim does not hold at drop time. `Self` is laid out as
        // `(Arc, RwLockReadGuard)` and tuple-struct fields drop in declaration order, so the
        // `Arc` is released before the guard. If it is the last strong reference, the `RwLock`
        // is freed and the guard then unlocks freed memory. The guard needs to be dropped
        // first (reorder the fields, or hold it in `ManuallyDrop` and release it in `Drop`).
        let guard: RwLockReadGuard<'a, T> = unsafe { std::mem::transmute(guard) };
        Self(arc, guard)
    }
}
impl<T> std::ops::Deref for OwnedWatchRef<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.1
}
}
// Unit tests. The async tests drive futures by hand with `futures::poll!`, so each step of the
// notification protocol is checked deterministically on a single task.
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use std::task::Poll;

    // Strong/weak counts agree between a handle and its weak reference, and `upgrade` yields
    // a handle to the same shared value.
    #[test]
    fn test_watch_handle_ref_counts() {
        let handle = WatchHandle::new(());
        let weak = handle.downgrade();
        assert_eq!(handle.strong_count(), 1);
        assert_eq!(weak.strong_count(), 1);
        assert_eq!(handle.weak_count(), 1);
        assert_eq!(weak.weak_count(), 1);
        let third = weak.upgrade().unwrap();
        drop(weak);
        assert_eq!(handle.weak_count(), 0);
        assert!(third.ptr_eq(&handle));
    }

    // Writes through a handle are visible via `borrow` and do not count as a change for that
    // same handle.
    #[test]
    fn test_watch_handle_set() {
        let handle = WatchHandle::new(0);
        assert_eq!(*handle.borrow(), 0);
        assert!(!handle.has_changed());
        handle.set(1);
        assert_eq!(*handle.borrow(), 1);
        assert!(!handle.has_changed());
        handle.update(|x| {
            *x += 1;
        });
        assert_eq!(*handle.borrow(), 2);
        assert!(!handle.has_changed());
    }

    // `maybe_update` returning `true` is a change for other handles only; returning `false`
    // changes nothing.
    #[test]
    fn test_watch_handle_maybe_update() {
        let handle = WatchHandle::new(0);
        let clone = handle.clone();
        handle.maybe_update(|x| {
            *x += 1;
            true
        });
        assert!(!handle.has_changed());
        assert!(clone.has_changed());
        clone.mark_unchanged();
        handle.maybe_update(|_| false);
        assert!(!handle.has_changed());
        assert!(!clone.has_changed());
    }

    // A pending `changed` resolves for another handle's write but not for the handle's own
    // write; `borrow` leaves the change flag alone while `borrow_and_update` clears it.
    #[tokio::test]
    async fn test_watch_handle_changed() {
        let handle = WatchHandle::new(0);
        let other = handle.clone();
        let handle_changed = handle.changed();
        let other_changed = other.changed();
        futures::pin_mut!(handle_changed, other_changed);
        assert!(!handle.has_changed());
        assert!(!other.has_changed());
        assert_eq!(futures::poll!(&mut handle_changed), Poll::Pending);
        assert_eq!(futures::poll!(&mut other_changed), Poll::Pending);
        handle.set(1);
        assert!(!handle.has_changed());
        assert!(other.has_changed());
        assert_eq!(futures::poll!(&mut handle_changed), Poll::Pending);
        assert_eq!(futures::poll!(&mut other_changed), Poll::Ready(()));
        other.mark_unchanged();
        let other_changed = other.changed();
        futures::pin_mut!(other_changed);
        assert!(!handle.has_changed());
        assert!(!other.has_changed());
        assert_eq!(futures::poll!(&mut handle_changed), Poll::Pending);
        assert_eq!(futures::poll!(other_changed), Poll::Pending);
        other.mark_changed();
        assert!(other.has_changed());
        other.borrow();
        assert!(other.has_changed());
        other.borrow_and_update();
        assert!(!other.has_changed());
    }

    // `wait_for` stays pending until a value satisfying the predicate is written.
    #[tokio::test]
    async fn test_watch_handle_wait_for() {
        let handle = WatchHandle::new(0);
        let other = handle.clone();
        let wait_for = other.wait_for(|x| *x == 2);
        futures::pin_mut!(wait_for);
        handle.set(1);
        assert_eq!(futures::poll!(&mut wait_for), Poll::Pending);
        handle.set(2);
        assert_eq!(futures::poll!(&mut wait_for), Poll::Ready(()));
    }

    // The same write / seen-tracking behaviour as the strong handle, driven through a
    // `WatchWeak`.
    #[test]
    fn test_watch_weak_set() {
        let handle = WatchHandle::new(0);
        let weak = handle.downgrade();
        assert_eq!(*weak.borrow().unwrap(), 0);
        assert!(!weak.has_changed().unwrap());
        weak.set(1).unwrap();
        assert_eq!(*weak.borrow().unwrap(), 1);
        assert!(!weak.has_changed().unwrap());
        assert!(weak.update(|x| {
            *x += 1;
        }));
        assert_eq!(*weak.borrow().unwrap(), 2);
        assert!(!weak.has_changed().unwrap());
        weak.mark_changed();
        assert!(weak.has_changed().unwrap());
        weak.mark_unchanged();
        assert!(!weak.has_changed().unwrap());
        weak.mark_changed();
        assert!(weak.has_changed().unwrap());
        weak.borrow_and_update();
        assert!(!weak.has_changed().unwrap());
    }

    // A weak handle's own writes do not resolve its pending `changed`; dropping the last
    // strong handle makes it resolve with `Err`.
    #[tokio::test]
    async fn test_watch_weak_changed() {
        let strong = WatchHandle::new(0);
        let handle = strong.downgrade();
        let handle_changed = handle.changed();
        futures::pin_mut!(handle_changed);
        assert_eq!(handle.has_changed(), Some(false));
        assert_matches!(futures::poll!(&mut handle_changed), Poll::Pending);
        handle.set(1).unwrap();
        assert_eq!(handle.has_changed(), Some(false));
        assert_matches!(futures::poll!(&mut handle_changed), Poll::Pending);
        assert_eq!(handle.has_changed(), Some(false));
        assert_matches!(futures::poll!(&mut handle_changed), Poll::Pending);
        drop(strong);
        assert_matches!(
            futures::poll!(&mut handle_changed),
            Poll::Ready(Err(WatchWeakError(())))
        );
    }

    // `wait_for` on a weak handle resolves on a matching value, and with `Err` once the strong
    // handle is dropped.
    #[tokio::test]
    async fn test_watch_weak_wait_for() {
        let handle = WatchHandle::new(0);
        let weak = handle.downgrade();
        let fut = weak.wait_for(|x| *x == 2);
        futures::pin_mut!(fut);
        handle.set(1);
        assert_matches!(futures::poll!(&mut fut), Poll::Pending);
        handle.set(2);
        assert_matches!(futures::poll!(&mut fut), Poll::Ready(Ok(())));
        let fut = weak.wait_for(|_| false);
        futures::pin_mut!(fut);
        handle.set(1);
        assert_matches!(futures::poll!(&mut fut), Poll::Pending);
        drop(handle);
        assert_matches!(
            futures::poll!(&mut fut),
            Poll::Ready(Err(WatchWeakError(())))
        );
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment