Created January 9, 2021 01:09
-
-
Save kprotty/a0ce34836862745a6a5b51253935f0a1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{
    cell::UnsafeCell,
    future::Future,
    mem::ManuallyDrop,
    pin::Pin,
    sync::atomic::{fence, AtomicUsize, Ordering},
    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
};
struct VTable { | |
resume: unsafe fn(*const Task), | |
waker_clone: unsafe fn(*const Task), | |
waker_wake: unsafe fn(*const Task, bool), | |
waker_drop: unsafe fn(*const Task), | |
} | |
struct Task { | |
vtable: &'static VTable, | |
} | |
impl Task { | |
const VTABLE: RawWakerVTable = RawWakerVTable::new( | |
Self::waker_clone, | |
Self::waker_wake, | |
Self::waker_wake_by_ref, | |
Self::waker_drop, | |
); | |
unsafe fn get_waker(&self) -> Waker { | |
let ptr = self as *const Task as *const (); | |
let raw_waker = RawWaker::new(ptr, &Self::VTABLE); | |
Waker::from_raw(raw_waker) | |
} | |
unsafe fn waker_clone(ptr: *const ()) -> RawWaker { | |
let task = &*(ptr as *const Task); | |
(task.vtable.waker_clone)(task); | |
RawWaker::new(ptr, &Self::VTABLE) | |
} | |
unsafe fn waker_wake(ptr: *const ()) { | |
let task = &*(ptr as *const Task); | |
(task.vtable.waker_wake)(task, false) | |
} | |
unsafe fn waker_wake_by_ref(ptr: *const ()) { | |
let task = &*(ptr as *const Task); | |
(task.vtable.waker_wake)(task, true) | |
} | |
unsafe fn waker_drop(ptr: *const ()) { | |
let task = &*(ptr as *const Task); | |
(task.vtable.waker_drop)(task) | |
} | |
} | |
// Lifecycle states stored in `FutureTask::state`.

/// Idle: neither scheduled nor being polled; a wake moves it to `STATE_SCHEDULING`.
const STATE_WAITING: usize = 0;
/// Handed to `schedule` but not yet picked up by `resume`.
const STATE_SCHEDULING: usize = STATE_WAITING + 1;
/// Being polled by `resume`; a wake moves it to `STATE_NOTIFIED`.
const STATE_RUNNING: usize = STATE_SCHEDULING + 1;
/// Woken while being polled; `resume` polls again instead of going idle.
const STATE_NOTIFIED: usize = STATE_RUNNING + 1;
#[repr(C)] | |
struct FutureTask<F> { | |
task: Task, | |
state: AtomicUsize, | |
future: UnsafeCell<F>, | |
ref_count: AtomicUsize, | |
} | |
impl<F: Future> FutureTask<F> { | |
const VTABLE: VTable = VTable { | |
resume: Self::resume, | |
waker_clone: Self::waker_clone, | |
waker_wake: Self::waker_wake, | |
waker_drop: Self::waker_drop, | |
}; | |
fn spawn(future: F) { | |
let fut_task = Box::new(Self { | |
task: Task { vtable: &Self::VTABLE }, | |
state: AtomicUsize::new(STATE_WAITING), | |
future: UnsafeCell::new(future), | |
ref_count: AtomicUsize::new(0), | |
}); | |
(unsafe { fut_task.task.get_waker() }) | |
.clone() | |
.wake() | |
} | |
unsafe fn resume(task: *const Task) { | |
let fut_task = &*(task as *const Self); | |
let waker = fut_task.task.get_waker(); | |
let mut ctx = Context::from_waker(&waker); | |
let mut state = fut_task.state.swap(STATE_RUNNING, Ordering::Relaxed); | |
assert_eq!(state, STATE_SCHEDULING); | |
'resume: loop { | |
let future = Pin::new_unchecked(&mut *fut_task.future.get()); | |
match future.poll(&mut ctx) { | |
Poll::Ready(_) => break, | |
Poll::Pending => {}, | |
} | |
state = fut_task.state.load(Ordering::Relaxed); | |
loop { | |
let new_state = match state { | |
STATE_WAITING => unreachable!(), | |
STATE_RUNNING => STATE_WAITING, | |
STATE_SCHEDULING => unreachable!(), | |
STATE_NOTIFIED => STATE_RUNNING, | |
_ => unreachable!(), | |
}; | |
if let Err(e) = fut_task.state.compare_exchange_weak( | |
state, | |
new_state, | |
Ordering::Relaxed, | |
Ordering::Relaxed, | |
) { | |
state = e; | |
continue; | |
} | |
match new_state { | |
STATE_WAITING => break 'resume, | |
STATE_RUNNING => continue 'resume, | |
_ => unreachable!(), | |
} | |
} | |
} | |
Self::waker_drop(task) | |
} | |
unsafe fn waker_clone(task: *const Task) { | |
let fut_task = &*(task as *const Self); | |
fut_task.ref_count.fetch_add(1, Ordering::Relaxed); | |
} | |
unsafe fn waker_drop(task: *const Task) { | |
let fut_task = &*(task as *const Self); | |
if fut_task.ref_count.fetch_sub(1, Ordering::Relaxed) == 1 { | |
std::mem::drop(Box::from_raw(task as *mut Self)); | |
} | |
} | |
unsafe fn waker_wake(task: *const Task, by_ref: bool) { | |
let fut_task = &*(task as *const Self); | |
let mut state = fut_task.state.load(Ordering::Relaxed); | |
loop { | |
let new_state = match state { | |
STATE_WAITING => STATE_SCHEDULING, | |
STATE_SCHEDULING => break, | |
STATE_RUNNING => STATE_NOTIFIED, | |
STATE_NOTIFIED => break, | |
_ => unreachable!(), | |
}; | |
if let Err(e) = fut_task.state.compare_exchange_weak( | |
state, | |
new_state, | |
Ordering::Relaxed, | |
Ordering::Relaxed, | |
) { | |
state = e; | |
continue; | |
} | |
match new_state { | |
STATE_SCHEDULING => return schedule(task), | |
STATE_NOTIFIED => break, | |
_ => unreachable!(), | |
} | |
} | |
if !by_ref { | |
Self::waker_drop(task); | |
} | |
} | |
} | |
fn schedule(task: *const Task) { | |
// TODO: replace with calling into executor | |
let task_ptr = task as usize; | |
std::thread::spawn(move || unsafe { | |
let task = &*(task_ptr as *const Task); | |
(task.vtable.resume)(task) | |
}); | |
} | |
pub fn spawn<F: Future + Send + 'static>(future: F) { | |
FutureTask::spawn(future) | |
} | |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.