Skip to content

Instantly share code, notes, and snippets.

@wyfo
Last active July 10, 2023 05:11
Show Gist options
  • Save wyfo/a9ce6b6c23914688b94ea48ab2badd48 to your computer and use it in GitHub Desktop.
Save wyfo/a9ce6b6c23914688b94ea48ab2badd48 to your computer and use it in GitHub Desktop.
`AtomicWaker` optimized implementation
#![no_std]
use core::{
cell::UnsafeCell,
mem::MaybeUninit,
sync::atomic::{AtomicU8, Ordering},
task::Waker,
};
/// State: no waker is stored and nobody is accessing the slot.
const EMPTY: u8 = 0b000;
/// State: a `register` call has exclusive access to the slot and is (re)writing the waker.
const REGISTERING: u8 = 0b001;
/// State: the slot holds an initialized waker.
const REGISTERED: u8 = 0b010;
/// Flag OR-ed into the state by `take_or_wake`.
///
/// On `REGISTERED` it hands the stored waker to that caller; on `REGISTERING` it tells the
/// in-progress `register` to wake its waker instead of storing it; on `EMPTY` the caller that
/// set it clears it again.
const TAKING_FLAG: u8 = 0b100;

/// A slot for at most one [`Waker`], filled by [`register`](AtomicWaker::register) and emptied
/// by [`take_or_wake`](AtomicWaker::take_or_wake) / [`wake`](AtomicWaker::wake), possibly from
/// different threads.
///
/// Access to the slot is arbitrated by a single atomic state word (`EMPTY`, `REGISTERING`,
/// `REGISTERED`, optionally combined with `TAKING_FLAG`).
pub struct AtomicWaker {
    /// Holds an initialized waker whenever `state` is `REGISTERED`; otherwise only the thread
    /// that currently owns the slot (per `state`) may touch it.
    waker: UnsafeCell<MaybeUninit<Waker>>,
    state: AtomicU8,
}

// SAFETY: the slot only ever contains a `Waker`, which is `Send`.
unsafe impl Send for AtomicWaker {}
// SAFETY: every access to `waker` through `&self` is serialized by `state`: only the caller that
// moved it to `REGISTERING` (in `register`) or set `TAKING_FLAG` on `REGISTERED` (in
// `take_or_wake`) touches the slot, and `Waker` is `Send + Sync`.
unsafe impl Sync for AtomicWaker {}

impl AtomicWaker {
    /// Creates an empty `AtomicWaker`.
    #[inline]
    pub const fn new() -> Self {
        Self {
            waker: UnsafeCell::new(MaybeUninit::uninit()),
            state: AtomicU8::new(EMPTY),
        }
    }

    /// Stores a clone of `waker` so that a later [`wake`](Self::wake) or
    /// [`take_or_wake`](Self::take_or_wake) reaches it.
    ///
    /// If a waker is already stored and [`Waker::will_wake`] reports it equivalent to `waker`,
    /// it is kept and no clone is made. If another `register` call or a wake is in progress,
    /// `waker` is woken immediately (`wake_by_ref`) instead of being stored. If a wake arrives
    /// while this call is storing the waker, this call wakes the new waker itself.
    ///
    /// If `Waker::clone` panics, the slot stays in the registering state: later `register`
    /// calls wake their waker immediately and `take_or_wake` returns `None`.
    #[inline]
    pub fn register(&self, waker: &Waker) {
        let mut state = self.state.load(Ordering::Relaxed);
        let prev = loop {
            if state != EMPTY && state != REGISTERED {
                // Another `register` owns the slot, or a wake is in progress: wake the caller's
                // waker directly so its task gets polled again.
                waker.wake_by_ref();
                return;
            }
            match self.state.compare_exchange_weak(
                state,
                REGISTERING,
                Ordering::AcqRel,
                Ordering::Relaxed,
            ) {
                Ok(prev) => break prev,
                Err(actual) => state = actual,
            }
        };

        // SAFETY: moving `state` to `REGISTERING` gives this call exclusive access to the slot;
        // the acquire half of the CAS orders this access after the previous owner released it.
        let slot = unsafe { &mut *self.waker.get() };
        let replaced = if prev == EMPTY {
            slot.write(waker.clone());
            None
        } else {
            // SAFETY: the state was `REGISTERED`, so the slot holds an initialized waker.
            let current = unsafe { slot.assume_init_mut() };
            if waker.will_wake(current) {
                None
            } else {
                // The clone runs before the slot is touched, so a panicking clone leaves the
                // old waker intact; the old waker is dropped later, outside the critical section.
                Some(core::mem::replace(current, waker.clone()))
            }
        };

        if let Err(actual) = self.state.compare_exchange(
            REGISTERING,
            REGISTERED,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            // The only change possible while we hold `REGISTERING` is a concurrent
            // `take_or_wake` setting `TAKING_FLAG`; it returned `None` and relies on us to wake.
            debug_assert_eq!(actual, REGISTERING | TAKING_FLAG);
            // SAFETY: we still own the slot and have just initialized it.
            let waker = unsafe { (*self.waker.get()).assume_init_read() };
            // Release the slot before running user code in `wake`, so a panic there cannot
            // leave it locked. An RMW (not a plain store) keeps the release sequences of
            // concurrent `take_or_wake` calls intact, so the next `register` that acquires
            // `EMPTY` also sees the writes those callers made before waking.
            self.state.swap(EMPTY, Ordering::Release);
            waker.wake();
        }
        // Dropped only after the slot is released, for the same panic-safety reason.
        drop(replaced);
    }

    /// Takes the stored waker out of the slot, leaving it empty.
    ///
    /// Returns `None` when no waker is stored, when a `register` call is in progress (that
    /// call is told to wake its waker itself), or when another `take_or_wake` is already
    /// handling the slot.
    #[inline]
    pub fn take_or_wake(&self) -> Option<Waker> {
        // Always an RMW, never a plain load. RMWs on `state` are totally ordered with
        // `register`'s CASes, so either this call observes the stored waker, or the later
        // `register` acquires from it and sees everything the caller wrote beforehand.
        // A plain load could return a stale `EMPTY`, and the wakeup would be lost.
        match self.state.fetch_or(TAKING_FLAG, Ordering::AcqRel) {
            REGISTERED => {
                // SAFETY: setting `TAKING_FLAG` on `REGISTERED` gives this call exclusive access
                // to the initialized waker; the acquire half synchronizes with the release CAS
                // in `register` that published it.
                let waker = unsafe { (*self.waker.get()).assume_init_read() };
                // RMW rather than a store, to preserve concurrent wakers' release sequences.
                self.state.swap(EMPTY, Ordering::Release);
                Some(waker)
            }
            EMPTY => {
                // Nothing stored: clear the flag this call just set (again as an RMW).
                self.state.fetch_and(!TAKING_FLAG, Ordering::Release);
                None
            }
            // `REGISTERING`: the registering call sees the flag and wakes its waker itself.
            // Flag already set: another caller is handling the slot.
            _ => None,
        }
    }

    /// Wakes and consumes the stored waker, if any.
    ///
    /// Does nothing observable when no waker is stored; see
    /// [`take_or_wake`](Self::take_or_wake) for the concurrent cases.
    #[inline]
    pub fn wake(&self) {
        if let Some(waker) = self.take_or_wake() {
            waker.wake();
        }
    }
}

impl Drop for AtomicWaker {
    /// Drops the stored waker, if any; `MaybeUninit` would otherwise leak it.
    fn drop(&mut self) {
        if *self.state.get_mut() == REGISTERED {
            // SAFETY: `&mut self` rules out concurrent access, and `REGISTERED` means the slot
            // holds an initialized waker.
            unsafe { self.waker.get_mut().assume_init_drop() };
        }
    }
}

impl Default for AtomicWaker {
    /// Same as [`AtomicWaker::new`].
    fn default() -> Self {
        Self::new()
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment