Last active
December 21, 2023 01:21
-
-
Save rlee287/68cfd829ade065e88171fd32de5eb7cb to your computer and use it in GitHub Desktop.
Proposed OnceStorage for `once_cell::race`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use core::sync::atomic::{Ordering, AtomicU32};
use core::mem::MaybeUninit;
use core::cell::UnsafeCell;
use core::hint::spin_loop;
use core::convert::Infallible;

/// A thread-safe cell which can only be written to once.
pub struct OnceStorage<T> {
    /// The actual storage of the stored object
    data_holder: UnsafeCell<MaybeUninit<T>>,
    /// Tracks whether the OnceStorage has been initialized
    /// 0 -> no
    /// 1 -> write in progress (because writing to data_holder is not atomic)
    /// 2 -> init
    /// This value can only ever increase in increments of 1
    is_init: AtomicU32,
    // BUGFIX: `cfg(debug)` is not a built-in cfg predicate and is never set by
    // rustc, so this field and every assertion guarded by it were silently
    // compiled out. `debug_assertions` is the cfg that is true in debug builds.
    #[cfg(debug_assertions)]
    #[cfg_attr(debug_assertions, doc(hidden))]
    /// Helper counter to assert that the critical section is, in fact, only entered by one thread at a time
    critical_section_ctr: AtomicU32
}

impl<T> OnceStorage<T> {
    /// Creates a new empty cell.
    #[inline]
    pub const fn new() -> Self {
        Self {
            data_holder: UnsafeCell::new(MaybeUninit::uninit()),
            is_init: AtomicU32::new(0),
            // BUGFIX: this field was never initialized here; that only
            // compiled before because the bogus `cfg(debug)` removed the
            // field from the struct entirely.
            #[cfg(debug_assertions)]
            critical_section_ctr: AtomicU32::new(0)
        }
    }

    /// Gets a reference to the underlying value, or `None` if the cell has
    /// not been fully initialized yet.
    pub fn get(&self) -> Option<&T> {
        // Acquire pairs with the Release half of the writer's final swap(2),
        // making the stored value visible before we dereference it.
        let state_snapshot = self.is_init.load(Ordering::Acquire);
        if state_snapshot == 2 {
            #[cfg(debug_assertions)]
            assert_eq!(self.critical_section_ctr.load(Ordering::SeqCst), 0);
            // SAFETY: 2 -> value is init and nobody is trying to change it
            unsafe {
                let mut_ptr = self.data_holder.get();
                Some((&*mut_ptr as &MaybeUninit<T>).assume_init_ref())
            }
        } else {
            debug_assert!(state_snapshot <= 1);
            None
        }
    }

    /// Forcibly sets the value of the cell and returns a mutable reference
    /// to the new value.
    ///
    /// SAFETY: The internal state must be set to the intermediate state before
    /// this is called. If the internal value was already set, then
    /// this function overwrites it without dropping it.
    unsafe fn force_set(&self, value: T) -> &mut T {
        #[cfg(debug_assertions)]
        assert_eq!(self.critical_section_ctr.fetch_add(1, Ordering::SeqCst), 0);
        let value_ref: &mut T;
        unsafe {
            let mut_ptr = self.data_holder.get();
            value_ref = (&mut *mut_ptr as &mut MaybeUninit<T>).write(value);
        }
        // BUGFIX: `fetch_sub` returns the *previous* value, which is 1 here
        // because of the `fetch_add` above. The old assertion compared against
        // 0 and would have panicked on every call had the cfg been active.
        #[cfg(debug_assertions)]
        assert_eq!(self.critical_section_ctr.fetch_sub(1, Ordering::SeqCst), 1);
        value_ref
    }

    /// Sets the contents of this cell to value.
    ///
    /// Returns `Ok(())` if the cell was empty and `Err(value)` if it was full.
    pub fn set(&self, value: T) -> Result<(), T> {
        // Indicate that we are now trying to set the value
        // If someone else is also trying, back off and let them go through
        // On success we wish to release the new value, already knowing it
        // On failure we don't need an ordering, as the failure already forms
        // a happens-before relationship between their set and our check
        if self.is_init.compare_exchange(0, 1, Ordering::Release, Ordering::Relaxed).is_err() {
            return Err(value);
        }
        // SAFETY: state==1 -> nobody else is touching the UnsafeCell -> we can safely obtain &mut
        unsafe {
            self.force_set(value);
        }
        // Indicate that we have successfully written the value
        if self.is_init.swap(2, Ordering::AcqRel) != 1 {
            unreachable!("Concurrent modification to self.data_holder despite state signalling")
        }
        Ok(())
    }

    /// Gets the contents of the cell, initializing it with `f` if the cell was
    /// empty.
    ///
    /// If several threads concurrently run `get_or_init`, more than one `f` can
    /// be called. However, all threads will return the same value, produced by
    /// some `f`. If this instance of `f` finishes while another instance is
    /// writing its value, then this function will spinloop until that instance
    /// finishes writing the new value before returning a reference to the
    /// initialized value.
    pub fn get_or_init_spin<F>(&self, f: F) -> &T
    where
        F: FnOnce() -> T,
    {
        // Wrap the infallible closure so we can reuse the fallible path.
        let fn_wrap = || {
            Ok::<T, Infallible> (f())
        };
        self.get_or_try_init_spin(fn_wrap).unwrap()
    }

    /// Gets the contents of the cell, initializing it with `f` if
    /// the cell was empty. If the cell was empty and `f` failed, an
    /// error is returned.
    ///
    /// If several threads concurrently run `get_or_init`, more than one `f` can
    /// be called. However, all threads will return the same value, produced by
    /// some `f`. If this instance of `f` finishes while another instance is
    /// writing its value, then this function will spinloop until that instance
    /// finishes writing the new value before returning a reference to the
    /// initialized value.
    pub fn get_or_try_init_spin<F, E>(&self, f: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<T, E>
    {
        let mut state_snapshot = self.is_init.load(Ordering::Acquire);
        if state_snapshot == 0 {
            // Run the (possibly fallible) initializer before claiming the
            // write slot, so a failure leaves the cell untouched.
            let f_value = f()?;
            // Indicate that we are now trying to set the value
            // If someone else is also trying, break out and wait for the other write to go through
            // On success we wish to release the new value, without needing to acquire it again
            // On failure we need to acquire the actual value and loop again
            match self.is_init.compare_exchange(0, 1, Ordering::Release, Ordering::Acquire) {
                Ok(_) => {
                    // SAFETY: state==1 -> nobody else is touching the UnsafeCell -> we can safely obtain &mut
                    let new_ref = unsafe {self.force_set(f_value)};
                    // Indicate that we have successfully written the value
                    if self.is_init.swap(2, Ordering::AcqRel) != 1 {
                        unreachable!("Concurrent modification to self.data_holder despite state signalling")
                    }
                    return Ok(new_ref as &T);
                },
                Err(new_state) => {
                    // Someone beat us to it; our f_value is dropped and we
                    // fall through to wait for / read their value.
                    state_snapshot = new_state;
                    debug_assert!(state_snapshot==1 || state_snapshot==2);
                }
            }
        }
        while state_snapshot == 1 {
            // 1 -> someone else is currently writing
            // Writes (should be) fast so we won't be spinning for long
            state_snapshot = self.is_init.load(Ordering::Acquire);
            spin_loop();
        }
        debug_assert_eq!(state_snapshot, 2);
        // SAFETY: state==2 -> value is fully written and immutable from now on
        unsafe {
            let mut_ptr = self.data_holder.get();
            return Ok((&*mut_ptr as &MaybeUninit<T>).assume_init_ref());
        }
    }

    /// ``` compile_fail
    /// # use tiva_c_secure_bootloader_rustlib::OnceStorage;
    /// #
    /// // Ensure that OnceStorage<T> is invariant over T lifetime subtypes
    /// let heap_object = std::vec::Vec::from([1,2,3,4]);
    /// let once_storage = OnceStorage::new();
    /// once_storage.set(&heap_object).unwrap();
    /// drop(heap_object);
    /// // The stored reference is no longer live because vec is dropped
    /// // The following line should fail to compile
    /// let _ref = once_storage.get();
    /// ```
    fn _dummy() {}
}
impl<T> Drop for OnceStorage<T> { | |
fn drop(&mut self) { | |
let state = self.is_init.load(Ordering::Acquire); | |
// &mut self -> nobody else can try to init -> value can't be 1 | |
// If we somehow do, then we leak the set value, which is safer than | |
// incorrectly freeing it | |
debug_assert_ne!(state, 1); | |
if state == 2 { | |
unsafe { | |
let mut_ptr = self.data_holder.get(); | |
(&mut *mut_ptr as &mut MaybeUninit<T>).assume_init_drop(); | |
} | |
} | |
} | |
} | |
// SAFETY: OnceStorage only ever hands out &T across threads, so sharing it
// requires T: Sync; T: Send is also required because the thread that drops
// the OnceStorage (and therefore the stored T) may differ from the thread
// that stored it. This matches the bounds std uses for its once-cells.
unsafe impl<T: Send+Sync> Sync for OnceStorage<T> {}
#[cfg(test)]
mod tests {
    use super::*;
    extern crate std;

    /// Single-threaded smoke test: store an owned Vec, read it back, drop.
    #[test]
    fn test_should_compile_static() {
        let payload = std::vec::Vec::from([1, 2, 3, 4]);
        let storage = OnceStorage::new();
        storage.set(payload).unwrap();
        let stored = storage.get();
        assert_eq!(stored.unwrap(), &[1, 2, 3, 4]);
        drop(storage);
    }

    /// Hammer `set` from many threads at once; exactly one must win.
    #[test]
    fn test_init_only_once() {
        const THREAD_COUNT: usize = 20;
        use std::sync::Barrier;
        use std::sync::atomic::AtomicU32;

        let successful_sets = AtomicU32::new(0);
        let start_line = Barrier::new(THREAD_COUNT + 1);
        let storage = OnceStorage::new();
        std::thread::scope(|scope| {
            // Launch the contending writers...
            for _ in 0..THREAD_COUNT {
                scope.spawn(|| {
                    start_line.wait();
                    let attempt =
                        storage.set(std::vec::Vec::from([std::string::String::from("abcd")]));
                    if attempt.is_ok() {
                        successful_sets.fetch_add(1, Ordering::Relaxed);
                    }
                });
            }
            // ...and release them all at the same instant
            start_line.wait();
        });
        // Make the relaxed counter increments visible to this thread
        std::sync::atomic::fence(Ordering::Acquire);
        // Exactly one thread may have initialized the cell
        assert_eq!(successful_sets.load(Ordering::Acquire), 1);
        // Read through the stored Vec so that Miri can catch invalid accesses
        let stored = storage.get().unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0], "abcd");
    }

    /// Borrowed (non-'static) contents are fine while the borrow is live.
    #[test]
    fn test_should_compile_nonstatic() {
        let backing = std::vec::Vec::from([1, 2, 3, 4]);
        let storage = OnceStorage::new();
        storage.set(&backing).unwrap();
        let _stored = storage.get();
        drop(storage);
        drop(backing);
    }
}
jyn514
commented
Dec 21, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment