Last active
May 23, 2022 18:56
-
-
Save Lucretiel/65a5929c4de56613cc501dde8ab540d1 to your computer and use it in GitHub Desktop.
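A paired smart pointer. The value owned by a pair of `Joint` handles is dropped as soon as *either* `Joint` is dropped (or, if the surviving `Joint` still has outstanding locks, as soon as the last of those locks is released).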
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
marker::PhantomData, | |
mem::MaybeUninit, | |
ops::Deref, | |
process::abort, | |
ptr::NonNull, | |
sync::atomic::{AtomicU32, Ordering}, | |
}; | |
/// Upper bound on the shared handle count (equal to `i32::MAX`).
///
/// Creating a handle while the count is above this bound aborts the process
/// rather than risking a wrap-around of the counter.
const MAX_REFCOUNT: u32 = u32::MAX / 2;
/// Heap allocation shared by both `Joint`s of a pair and by every lock
/// taken from them.
struct JointContainer<T> {
    /// The shared value. It is initialized when the pair is created and is
    /// dropped in place once the count transitions into the "being dropped"
    /// state; after that point it must not be read.
    value: MaybeUninit<T>,
    /// Combined handle count and state machine.
    ///
    /// Normal states:
    /// * `2..` — that many handles currently exist. When the count falls to
    ///   `1`, dropping the value begins.
    ///
    /// Special states:
    /// * `1` — the value is being dropped. The dropping thread must re-check
    ///   the count once it has finished.
    /// * `0` — the value has been dropped. No new handles may be created;
    ///   exactly one handle remains, and the container is freed when that
    ///   handle is dropped.
    count: AtomicU32,
}
#[repr(transparent)] | |
pub struct Joint<T> { | |
container: NonNull<JointContainer<T>>, | |
phantom: PhantomData<JointContainer<T>>, | |
} | |
unsafe impl<T: Send + Sync> Send for Joint<T> {} | |
unsafe impl<T: Send + Sync> Sync for Joint<T> {} | |
impl<T> Joint<T> { | |
// Note that, while it's guaranteed that the container exists, it's not | |
// guaranteed that the value is in an initialized state. | |
fn container(&self) -> &JointContainer<T> { | |
unsafe { self.container.as_ref() } | |
} | |
pub fn new(value: T) -> (Self, Self) { | |
let container = Box::new(JointContainer { | |
value: MaybeUninit::new(value), | |
count: AtomicU32::new(2), | |
}); | |
let container = NonNull::new(Box::into_raw(container)).expect("box is definitely non null"); | |
( | |
Joint { | |
container, | |
phantom: PhantomData, | |
}, | |
Joint { | |
container, | |
phantom: PhantomData, | |
}, | |
) | |
} | |
pub fn lock(&self) -> Option<JointLock<'_, T>> { | |
// Increasing the reference count can always be done with Relaxed– New | |
// references to an object can only be formed from an existing | |
// reference, and passing an existing reference from one thread to | |
// another must already provide any required synchronization. | |
let mut current = self.container().count.load(Ordering::Relaxed); | |
loop { | |
// We can only lock this if *both* handles currently exist. | |
// TODO: prevent the distribution of new locks after the other | |
// handle has dropped (currently, if this handle has some | |
// outstanding locks, it may create more). In general we're not | |
// worried because the typical usage pattern is that each joint | |
// will only ever make 1 lock at a time. | |
current = match current { | |
0 | 1 => break None, | |
n if n > MAX_REFCOUNT => abort(), | |
n => match self.container().count.compare_exchange_weak( | |
n, | |
n + 1, | |
Ordering::Relaxed, | |
Ordering::Relaxed, | |
) { | |
Ok(_) => { | |
break Some(JointLock { | |
container: self.container, | |
lifetime: PhantomData, | |
}) | |
} | |
Err(n) => n, | |
}, | |
} | |
} | |
} | |
} | |
impl<T> Drop for Joint<T> { | |
fn drop(&mut self) { | |
let mut current = self.container().count.load(Ordering::Acquire); | |
// Note that all of the failures in the compare-exchanges here are | |
// Acquire ordering, because failures could indicate that the other | |
// handle dropped, meaning that we need to acquire its changes before | |
// we start dropping or deallocating anything. Additionally, note that | |
// we *usually* don't need to release anything here, because `Joint` | |
// isn't itself capable of writing to `value` (only JointLock can do | |
// that, and it *does* release on drop.) | |
loop { | |
current = match current { | |
// The handle has been fully dropped, this is the last | |
// remaining handle in existence | |
0 => { | |
drop(unsafe { Box::from_raw(self.container.as_ptr()) }); | |
return; | |
} | |
n => match self.container().count.compare_exchange_weak( | |
n, | |
n - 1, | |
Ordering::Acquire, | |
Ordering::Acquire, | |
) { | |
// All failures, spurious or otherwise, need to be retried. | |
// There's no "fast escape" case because we always need to | |
// ensure that n - 1 was stored. | |
Err(n) => n, | |
// Another thread is in the middle of dropping the value. | |
// We stored a 0, so it will also take care of deallocating | |
// the container. | |
Ok(1) => return, | |
// This is the second to last handle in existence, which | |
// means it's time to drop the value. Don't need to release | |
// anything until after the drop is finished; other threads | |
// won't be touching the value while we're in this state. | |
Ok(2) => { | |
unsafe { (*self.container.as_ptr()).value.assume_init_drop() }; | |
loop { | |
// At this point we need to release store the 0, to | |
// ensure our drop propagates to other threads. We | |
// did the drop, so there's no other changes we | |
// might need to acquire. If we find there's already | |
// a zero, the last handle dropped, so we handle | |
// deallocating. | |
match self.container().count.compare_exchange_weak( | |
1, | |
0, | |
// Don't need to acquire in this case because we also did the | |
// drop ourselves. | |
Ordering::Release, | |
Ordering::Relaxed, | |
) { | |
// We stored a zero; the other Joint will be responsible | |
// for deallocating the container | |
Ok(_) => return, | |
// There was already a 0; the last handle dropped while we | |
// were dropping the value. Deallocate. | |
// | |
// There's no risk of another thread loading this same 0, because | |
// we know the only other reference in existence is the other Joint. | |
// we stored a 1, so it can never create more locks; either it will | |
// store a 0 (detected here) or we'll store a 0 that it will load. | |
Err(0) => { | |
drop(unsafe { Box::from_raw(self.container.as_ptr()) }); | |
return; | |
} | |
// Spurious failure; retry | |
Err(1) => continue, | |
// It's never possible for the count to transition from 1 to | |
// any value other than 0 or 1 | |
Err(_) => unreachable!(), | |
} | |
} | |
} | |
// There are plenty of handles in existence; the decrement we | |
// performed is the only thing that needed to happen. | |
Ok(_) => return, | |
}, | |
} | |
} | |
} | |
} | |
#[repr(transparent)] | |
pub struct JointLock<'a, T> { | |
container: NonNull<JointContainer<T>>, | |
lifetime: PhantomData<&'a Joint<T>>, | |
} | |
unsafe impl<T: Send + Sync> Send for JointLock<'_, T> {} | |
unsafe impl<T: Send + Sync> Sync for JointLock<'_, T> {} | |
impl<T> JointLock<'_, T> { | |
fn container(&self) -> &JointContainer<T> { | |
unsafe { self.container.as_ref() } | |
} | |
} | |
impl<T> Deref for JointLock<'_, T> { | |
type Target = T; | |
fn deref(&self) -> &Self::Target { | |
// Safety: if a JointLock exists, it's guaranteed that the value will | |
// be alive for at least the duration of the lock | |
unsafe { self.container().value.assume_init_ref() } | |
} | |
} | |
impl<T> Clone for JointLock<'_, T> { | |
fn clone(&self) -> Self { | |
// Safety: if a jointlock exists, it's guaranteed that the value will | |
// continue to exist, so we can do a simple count increment. | |
// TODO: this could just be a fetch_add, but we need to guard against overflow | |
let old_count = self.container().count.fetch_add(1, Ordering::Relaxed); | |
if old_count > MAX_REFCOUNT { | |
abort() | |
} | |
JointLock { | |
container: self.container, | |
lifetime: PhantomData, | |
} | |
} | |
fn clone_from(&mut self, source: &Self) { | |
if self.container != source.container { | |
*self = JointLock::clone(source) | |
} | |
} | |
} | |
impl<T> Drop for JointLock<'_, T> { | |
fn drop(&mut self) { | |
// The logic here can be a little simpler than Joint, because we're | |
// guaranteed that there's at least one other handle in existence (our | |
// parent), and that it definitely won't be dropped before we're done | |
// being dropped. | |
// - Need to acquire any changes made by other threads before dropping | |
// - Need to release any changes made by *this* thread so that it | |
// can be dropped by another thread | |
match self.container().count.fetch_sub(1, Ordering::AcqRel) { | |
// The count must be at LEAST 2: one for us and one for our parent | |
0 | 1 => unreachable!(), | |
// If the count was 2, it means that this was the last lock. We've | |
// already stored the decrement, which means we've taken | |
// responsibility for attempting to drop (and that future attempts | |
// to lock will now fail) | |
2 => unsafe { (*self.container.as_ptr()).value.assume_init_drop() }, | |
// If the count is higher than two, the value is still alive | |
_ => {} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment