Created
April 26, 2017 03:17
-
-
Save cramertj/979448f963d81d38e71fe0c171b1e070 to your computer and use it in GitHub Desktop.
Rust Fibers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package]
name = "fiber_test"
version = "0.1.0"
authors = ["Taylor Cramer <[email protected]>"]
[target.'cfg(windows)'.dependencies]
winapi = "0.2"
kernel32-sys = "0.2"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate winapi; | |
extern crate kernel32; | |
use std::cell::{Cell, UnsafeCell}; | |
use std::marker::PhantomData; | |
use std::mem; | |
use std::ptr; | |
use winapi::basetsd::SIZE_T; | |
use winapi::minwindef::LPVOID; | |
use winapi::winbase::LPFIBER_START_ROUTINE; | |
use kernel32::{ | |
ConvertFiberToThread, | |
ConvertThreadToFiber, | |
CreateFiber, | |
SwitchToFiber, | |
DeleteFiber, | |
}; | |
trait FiberFunc<Input, Output>: FnOnce(&Yielder<Input, Output>, Input) + Sized {} | |
impl<Input, Output, FN> FiberFunc<Input, Output> for FN | |
where FN: FnOnce(&Yielder<Input, Output>, Input) + Sized | |
{} | |
struct FiberData<Input, Output, FN> { | |
func: Option<FN>, | |
// msg is boxed so that it won't move around, allowing the fibers to write back and forth | |
msg: Box<UnsafeCell<FiberMsg<Input, Output>>>, | |
} | |
enum FiberMsg<Input, Output> { | |
None, | |
PreInit, | |
Init { | |
root_fiber: LPVOID, | |
arg: Input, | |
}, | |
ToFiber(Input), | |
FromFiber(Output), | |
Paused, | |
Done, | |
} | |
impl<Input, Output> FiberMsg<Input, Output> { | |
#[inline(always)] | |
fn take(&mut self) -> FiberMsg<Input, Output> { | |
unsafe { | |
let old_self = ptr::read(self); | |
ptr::write(self, FiberMsg::None); | |
old_self | |
} | |
} | |
} | |
struct Yielder<Input, Output> { | |
root_fiber: LPVOID, | |
msg_ptr: *mut FiberMsg<Input, Output>, | |
} | |
impl<Input, Output> Yielder<Input, Output> { | |
fn suspend(&self, val: Output) -> Input { | |
unsafe { | |
*self.msg_ptr = FiberMsg::FromFiber(val); | |
// Yeild out to the root fiber | |
SwitchToFiber(self.root_fiber); | |
// We've resumed! | |
if let FiberMsg::ToFiber(arg) = (*self.msg_ptr).take() { | |
arg | |
} else { | |
panic!("huuuuuhhhh????") | |
} | |
} | |
} | |
} | |
/// Guard that panics when it is dropped.
///
/// `fiber_fn` creates one on entry and `mem::forget`s it on its normal exit
/// path. The guard is therefore only dropped while a panic is already unwinding
/// out of the fiber. Panicking again during unwinding aborts the process, which
/// stops the original panic from unwinding across the `extern "system"` FFI
/// boundary.
struct PanicGuard;

impl Drop for PanicGuard {
    fn drop(&mut self) {
        panic!("Panic across FFI boundary.");
    }
}
unsafe extern "system" fn fiber_fn<Input, Output, FN> (fiber_data_void: LPVOID) | |
where FN: FiberFunc<Input, Output> | |
{ | |
let panic_guard = PanicGuard; | |
let fiber_data_ptr = fiber_data_void as *mut FiberData<Input, Output, FN>; | |
let func = (*fiber_data_ptr).func.take().unwrap(); | |
let msg_ptr = (*fiber_data_ptr).msg.get(); | |
if let FiberMsg::Init { root_fiber, arg } = (*msg_ptr).take() { | |
(func)(&Yielder { | |
root_fiber, | |
msg_ptr, | |
}, arg); | |
// We're done-- go back | |
*msg_ptr = FiberMsg::Done; | |
SwitchToFiber(root_fiber); | |
} else { | |
panic!("First message to fiber should be FiberMsg::Init. This is a bug."); | |
} | |
mem::forget(panic_guard); | |
} | |
const STACK_SIZE: SIZE_T = 2048; | |
const NULL: LPVOID = 0 as LPVOID; | |
struct Fiber<'a, Input: 'a, Output: 'a, FN: FiberFunc<Input, Output>> { | |
fiber_ptr: LPVOID, // NULL if it hasn't run yet | |
fiber_data: FiberData<Input, Output, FN>, | |
// Invariant in Input ptr lifetime, variant in Output ptr lifetime | |
phantom: PhantomData<(&'a (), *mut Input, *const Output)> | |
} | |
struct ThreadFiberInfo { | |
fiber_count: Cell<usize>, | |
thread_fiber: Cell<LPVOID>, | |
} | |
thread_local!(static THREAD_FIBER_INFO: ThreadFiberInfo = ThreadFiberInfo { | |
fiber_count: Cell::new(0), | |
thread_fiber: Cell::new(NULL), | |
}); | |
impl<'a, Input, Output, FN> Fiber<'a, Input, Output, FN> | |
where | |
Input: 'a, | |
Output: 'a, | |
FN: FiberFunc<Input, Output> | |
{ | |
fn new(func: FN) -> Self { | |
THREAD_FIBER_INFO.with(|tfi| { | |
if tfi.thread_fiber.get() == NULL { | |
let thread_fiber: LPVOID = unsafe { ConvertThreadToFiber(NULL) }; | |
assert!(thread_fiber != NULL); | |
tfi.thread_fiber.set(thread_fiber); | |
} | |
tfi.fiber_count.set(tfi.fiber_count.get() + 1); | |
Fiber { | |
fiber_ptr: NULL, | |
fiber_data: FiberData { | |
func: Some(func), | |
msg: Box::new(UnsafeCell::new(FiberMsg::PreInit)), | |
}, | |
phantom: PhantomData, | |
} | |
}) | |
} | |
fn resume(&mut self, arg: Input) -> Option<Output> { | |
THREAD_FIBER_INFO.with(|tfi| { | |
{ | |
// We can guarantee no aliasing of msg_ptr up until the fiber switch | |
let msg_ptr: &mut FiberMsg<Input, Output> = unsafe { | |
&mut *self.fiber_data.msg.get() | |
}; | |
match msg_ptr.take() { | |
FiberMsg::None | | |
FiberMsg::Init { .. } | | |
FiberMsg::ToFiber(_) | | |
FiberMsg::FromFiber(_) => unreachable!(), | |
FiberMsg::PreInit => { | |
debug_assert!(self.fiber_ptr == NULL); | |
let start_func: LPFIBER_START_ROUTINE = Some(fiber_fn::<Input, Output, FN>); | |
*msg_ptr = FiberMsg::Init { | |
root_fiber: tfi.thread_fiber.get(), | |
arg: arg, | |
}; | |
let lp_parameter: LPVOID = &mut self.fiber_data as *mut _ as LPVOID; | |
self.fiber_ptr = unsafe { CreateFiber(STACK_SIZE, start_func, lp_parameter) }; | |
assert!(self.fiber_ptr != NULL); | |
}, | |
FiberMsg::Paused => { | |
*msg_ptr = FiberMsg::ToFiber(arg); | |
} | |
FiberMsg::Done => { return None; } | |
} | |
} | |
unsafe { SwitchToFiber(self.fiber_ptr); } | |
// Once again, we can guarantee no aliasing of msg_ptr after the fiber switch | |
let msg_ptr: &mut FiberMsg<Input, Output> = unsafe { | |
&mut *self.fiber_data.msg.get() | |
}; | |
match msg_ptr.take() { | |
FiberMsg::None | | |
FiberMsg::PreInit | | |
FiberMsg::Init{ .. } | | |
FiberMsg::ToFiber(_) | | |
FiberMsg::Paused => unreachable!(), | |
FiberMsg::FromFiber(ret) => { | |
*msg_ptr = FiberMsg::Paused; | |
Some(ret) | |
}, | |
FiberMsg::Done => { | |
*msg_ptr = FiberMsg::Done; | |
None | |
}, | |
} | |
}) | |
} | |
} | |
impl<'a, Input, Output, FN> Drop for Fiber<'a, Input, Output, FN> | |
where | |
Input: 'a, | |
Output: 'a, | |
FN: FiberFunc<Input, Output> | |
{ | |
fn drop(&mut self) { | |
THREAD_FIBER_INFO.with(|tfi| { | |
unsafe { DeleteFiber(self.fiber_ptr); } | |
let fiber_count = tfi.fiber_count.get() - 1; | |
tfi.fiber_count.set(fiber_count); | |
if fiber_count == 0 { | |
let thread_again: bool = unsafe { ConvertFiberToThread() } != 0 ; | |
assert!(thread_again); | |
} | |
}); | |
} | |
} | |
fn main() { | |
let mut fiber = Fiber::new(move |yielder: &Yielder<_, _>, _| { | |
for i in 0.. { yielder.suspend(i); } | |
}); | |
println!("{:?} {:?} {:?} {:?} {:?}", | |
fiber.resume(()), | |
fiber.resume(()), | |
fiber.resume(()), | |
fiber.resume(()), | |
fiber.resume(())); | |
println!("Done"); | |
} | |
#[cfg(test)]
mod test {
    use super::*;

    /// A counting fiber yields consecutive integers, one per `resume`.
    #[test]
    fn count() {
        let mut fiber = Fiber::new(move |yielder: &Yielder<_, _>, _| {
            for i in 0.. {
                yielder.suspend(i);
            }
        });
        let observed: Vec<_> = (0..5).map(|_| fiber.resume(()).unwrap()).collect();
        assert_eq!([0, 1, 2, 3, 4], observed[..]);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment