Skip to content

Instantly share code, notes, and snippets.

@kprotty
Created January 9, 2021 01:09
Show Gist options
  • Save kprotty/a0ce34836862745a6a5b51253935f0a1 to your computer and use it in GitHub Desktop.
Save kprotty/a0ce34836862745a6a5b51253935f0a1 to your computer and use it in GitHub Desktop.
use std::{
pin::Pin,
future::Future,
cell::UnsafeCell,
task::{Waker, RawWaker, RawWakerVTable, Poll, Context},
sync::atomic::{AtomicUsize, Ordering},
};
/// Type-erased entry points that a concrete task type supplies for its [`Task`] header.
///
/// Every entry receives a pointer to the `Task` header of the task it operates on.
struct VTable {
    /// Polls the task; invoked from the thread started by `schedule`.
    resume: unsafe fn(*const Task),
    /// Called when a waker for the task is cloned.
    waker_clone: unsafe fn(*const Task),
    /// Called when a waker fires; the flag is `false` for `wake` (the waker is
    /// consumed) and `true` for `wake_by_ref`.
    waker_wake: unsafe fn(*const Task, bool),
    /// Called when a waker for the task is dropped.
    waker_drop: unsafe fn(*const Task),
}
/// Type-erased header shared by every task.
///
/// A concrete task type embeds this as its first field (`FutureTask` is
/// `#[repr(C)]` for exactly this reason), so a `*const Task` can be cast back to
/// the full task inside the [`VTable`] entries.
struct Task {
    /// Entry points of the concrete task type that owns this header.
    vtable: &'static VTable,
}
impl Task {
    /// The `RawWakerVTable` behind every [`Waker`] created for a task.
    ///
    /// Each entry recovers the `*const Task` header from the waker's data pointer
    /// and forwards to the matching entry of the task's own [`VTable`].
    const VTABLE: RawWakerVTable = RawWakerVTable::new(
        Self::waker_clone,
        Self::waker_wake,
        Self::waker_wake_by_ref,
        Self::waker_drop,
    );

    /// Returns a new, owning [`Waker`] for this task.
    ///
    /// A reference is acquired through `vtable.waker_clone` before the waker is
    /// built. The reference that the returned waker releases when dropped, or
    /// hands over when consumed by `wake`, is therefore one it actually owns.
    ///
    /// # Safety
    ///
    /// `self` must be the header of a live task whose `vtable` entries accept a
    /// pointer to it.
    ///
    /// NOTE(review): the pointer is derived from `&self`, so its provenance only
    /// covers the `Task` header. Vtable entries that reach the enclosing task
    /// struct through it may be rejected by Miri's aliasing models. Prefer
    /// building wakers from a pointer to the whole allocation.
    unsafe fn get_waker(&self) -> Waker {
        let task: *const Task = self;
        (self.vtable.waker_clone)(task);
        Waker::from_raw(RawWaker::new(task as *const (), &Self::VTABLE))
    }

    /// `RawWakerVTable::clone`: lets the task account for the new waker, then
    /// returns a raw waker for the same task.
    unsafe fn waker_clone(ptr: *const ()) -> RawWaker {
        // Work through the raw pointer (no intermediate `&Task`) so the pointer
        // forwarded to the vtable keeps the provenance of the whole task allocation.
        let task = ptr as *const Task;
        ((*task).vtable.waker_clone)(task);
        RawWaker::new(ptr, &Self::VTABLE)
    }

    /// `RawWakerVTable::wake`: wakes the task, consuming the waker's reference.
    unsafe fn waker_wake(ptr: *const ()) {
        let task = ptr as *const Task;
        ((*task).vtable.waker_wake)(task, false)
    }

    /// `RawWakerVTable::wake_by_ref`: wakes the task; the waker stays alive.
    unsafe fn waker_wake_by_ref(ptr: *const ()) {
        let task = ptr as *const Task;
        ((*task).vtable.waker_wake)(task, true)
    }

    /// `RawWakerVTable::drop`: releases the waker's reference.
    unsafe fn waker_drop(ptr: *const ()) {
        let task = ptr as *const Task;
        ((*task).vtable.waker_drop)(task)
    }
}
// A `FutureTask` moves through these values in its `state` field:
//
//   WAITING    --wake-->           SCHEDULING  (a run is handed to `schedule`)
//   SCHEDULING --resume starts-->  RUNNING
//   RUNNING    --wake-->           NOTIFIED    (poll again instead of rescheduling)
//   NOTIFIED   --poll Pending-->   RUNNING
//   RUNNING    --poll Pending-->   WAITING
//
// Wakes that arrive in SCHEDULING or NOTIFIED are absorbed. Once the future
// completes the state is left at RUNNING (or NOTIFIED), so it is never polled again.

/// Idle: neither scheduled nor being polled.
const STATE_WAITING: usize = 0;
/// A run has been scheduled but has not started polling yet.
const STATE_SCHEDULING: usize = STATE_WAITING + 1;
/// Currently being polled by `resume`.
const STATE_RUNNING: usize = STATE_SCHEDULING + 1;
/// Woken while being polled; `resume` polls again instead of parking.
const STATE_NOTIFIED: usize = STATE_RUNNING + 1;
/// Heap-allocated task that drives a single future `F`.
///
/// `#[repr(C)]` pins `task` at offset 0. That is what allows the `*const Task`
/// passed to the [`VTable`] entries to be cast back to `*const FutureTask<F>`.
#[repr(C)]
struct FutureTask<F> {
    /// Type-erased header; must remain the first field.
    task: Task,
    /// One of the `STATE_*` constants.
    state: AtomicUsize,
    /// The future being driven; only `resume` polls it.
    future: UnsafeCell<F>,
    /// Outstanding references (wakers, plus one held by a scheduled run).
    /// The allocation is freed by the `waker_drop` that brings it to zero.
    ref_count: AtomicUsize,
}
impl<F: Future> FutureTask<F> {
const VTABLE: VTable = VTable {
resume: Self::resume,
waker_clone: Self::waker_clone,
waker_wake: Self::waker_wake,
waker_drop: Self::waker_drop,
};
fn spawn(future: F) {
let fut_task = Box::new(Self {
task: Task { vtable: &Self::VTABLE },
state: AtomicUsize::new(STATE_WAITING),
future: UnsafeCell::new(future),
ref_count: AtomicUsize::new(0),
});
(unsafe { fut_task.task.get_waker() })
.clone()
.wake()
}
unsafe fn resume(task: *const Task) {
let fut_task = &*(task as *const Self);
let waker = fut_task.task.get_waker();
let mut ctx = Context::from_waker(&waker);
let mut state = fut_task.state.swap(STATE_RUNNING, Ordering::Relaxed);
assert_eq!(state, STATE_SCHEDULING);
'resume: loop {
let future = Pin::new_unchecked(&mut *fut_task.future.get());
match future.poll(&mut ctx) {
Poll::Ready(_) => break,
Poll::Pending => {},
}
state = fut_task.state.load(Ordering::Relaxed);
loop {
let new_state = match state {
STATE_WAITING => unreachable!(),
STATE_RUNNING => STATE_WAITING,
STATE_SCHEDULING => unreachable!(),
STATE_NOTIFIED => STATE_RUNNING,
_ => unreachable!(),
};
if let Err(e) = fut_task.state.compare_exchange_weak(
state,
new_state,
Ordering::Relaxed,
Ordering::Relaxed,
) {
state = e;
continue;
}
match new_state {
STATE_WAITING => break 'resume,
STATE_RUNNING => continue 'resume,
_ => unreachable!(),
}
}
}
Self::waker_drop(task)
}
unsafe fn waker_clone(task: *const Task) {
let fut_task = &*(task as *const Self);
fut_task.ref_count.fetch_add(1, Ordering::Relaxed);
}
unsafe fn waker_drop(task: *const Task) {
let fut_task = &*(task as *const Self);
if fut_task.ref_count.fetch_sub(1, Ordering::Relaxed) == 1 {
std::mem::drop(Box::from_raw(task as *mut Self));
}
}
unsafe fn waker_wake(task: *const Task, by_ref: bool) {
let fut_task = &*(task as *const Self);
let mut state = fut_task.state.load(Ordering::Relaxed);
loop {
let new_state = match state {
STATE_WAITING => STATE_SCHEDULING,
STATE_SCHEDULING => break,
STATE_RUNNING => STATE_NOTIFIED,
STATE_NOTIFIED => break,
_ => unreachable!(),
};
if let Err(e) = fut_task.state.compare_exchange_weak(
state,
new_state,
Ordering::Relaxed,
Ordering::Relaxed,
) {
state = e;
continue;
}
match new_state {
STATE_SCHEDULING => return schedule(task),
STATE_NOTIFIED => break,
_ => unreachable!(),
}
}
if !by_ref {
Self::waker_drop(task);
}
}
}
/// Starts a run of `task` by calling its `resume` entry on a new, detached OS thread.
///
/// Callers in this file hand one task reference to the run; `resume` releases it
/// when the run ends.
fn schedule(task: *const Task) {
    // TODO: replace with calling into executor
    // Raw pointers are not `Send`, so the address crosses the thread boundary as
    // a plain integer.
    let task_addr = task as usize;
    std::thread::spawn(move || {
        let task = task_addr as *const Task;
        // SAFETY: the reference transferred by the caller keeps the task alive
        // until `resume` releases it. The vtable is read through the raw pointer
        // (no intermediate `&Task`), so `resume` gets a pointer that still covers
        // the whole task allocation rather than just its `Task` header.
        unsafe { ((*task).vtable.resume)(task) }
    });
}
/// Runs `future` to completion in the background; its output is discarded.
///
/// The future is moved into a heap-allocated task that is polled on threads
/// started by `schedule`, so it must be `Send` and own all of its data (`'static`).
pub fn spawn<F>(future: F)
where
    F: Future + Send + 'static,
{
    FutureTask::<F>::spawn(future);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment