Skip to content

Instantly share code, notes, and snippets.

@cramertj
Created April 26, 2017 03:17
Show Gist options
  • Save cramertj/979448f963d81d38e71fe0c171b1e070 to your computer and use it in GitHub Desktop.
Rust Fibers
# Cargo manifest for the fiber experiment.
[package]
name = "fiber_test"
version = "0.1.0"
authors = ["Taylor Cramer <[email protected]>"]
# Win32 bindings for the fiber API (kernel32-sys is imported as `kernel32`).
# They are only declared for Windows targets, while the Rust source uses them
# without any cfg gate — presumably this crate is Windows-only; confirm.
[target.'cfg(windows)'.dependencies]
winapi = "0.2"
kernel32-sys = "0.2"
extern crate winapi;
extern crate kernel32;
use std::cell::{Cell, UnsafeCell};
use std::marker::PhantomData;
use std::mem;
use std::ptr;
use winapi::basetsd::SIZE_T;
use winapi::minwindef::LPVOID;
use winapi::winbase::LPFIBER_START_ROUTINE;
use kernel32::{
ConvertFiberToThread,
ConvertThreadToFiber,
CreateFiber,
SwitchToFiber,
DeleteFiber,
};
/// A one-shot closure usable as a fiber body.
///
/// The closure receives a [`Yielder`] for handing values back to whoever
/// resumed it, plus the first input. The trait is blanket-implemented for every
/// matching `FnOnce`, so in practice it is just a named alias for that bound.
trait FiberFunc<Input, Output>: Sized + FnOnce(&Yielder<Input, Output>, Input) {}

impl<Input, Output, F> FiberFunc<Input, Output> for F
where
    F: FnOnce(&Yielder<Input, Output>, Input),
{
}
/// Start-up data for one fiber.
///
/// On the first `resume`, a raw pointer to this struct is passed to
/// `CreateFiber` as the start routine's parameter, and `fiber_fn` reads it.
/// Fields drop in declaration order (`func`, then `msg`).
struct FiberData<Input, Output, FN> {
    /// The fiber body; `fiber_fn` moves it out (`take`) when the fiber first runs.
    func: Option<FN>,
    // msg is boxed so that it won't move around, allowing the fibers to write back and forth
    msg: Box<UnsafeCell<FiberMsg<Input, Output>>>,
}
/// Message slot shared by a `Fiber` and the closure running on its OS fiber.
///
/// The two sides hand control to each other with `SwitchToFiber`. Each side
/// writes a message, switches, and the other side `take`s it.
enum FiberMsg<Input, Output> {
    /// Placeholder left behind by `FiberMsg::take`.
    None,
    /// Initial state set by `Fiber::new`: the OS fiber has not been created yet.
    PreInit,
    /// First message to the fiber: the fiber to switch back to, and the first input.
    Init {
        root_fiber: LPVOID,
        arg: Input,
    },
    /// Input delivered to a suspended fiber by a later `resume`.
    ToFiber(Input),
    /// Value yielded by the fiber through `Yielder::suspend`.
    FromFiber(Output),
    /// Set by `resume` after it receives a yielded value: the fiber waits for the next input.
    Paused,
    /// The fiber's closure has returned.
    Done,
}
impl<Input, Output> FiberMsg<Input, Output> {
#[inline(always)]
fn take(&mut self) -> FiberMsg<Input, Output> {
unsafe {
let old_self = ptr::read(self);
ptr::write(self, FiberMsg::None);
old_self
}
}
}
/// Handle given to a fiber body so it can yield values back to its resumer.
struct Yielder<Input, Output> {
    /// Fiber to switch to when yielding. It comes from `FiberMsg::Init`, which
    /// `resume` fills with the thread's own fiber.
    root_fiber: LPVOID,
    /// Points at the boxed message slot owned by the `Fiber`'s `FiberData`.
    msg_ptr: *mut FiberMsg<Input, Output>,
}
impl<Input, Output> Yielder<Input, Output> {
    /// Yields `val` to the caller of `Fiber::resume` and blocks until the
    /// fiber is resumed again. Returns the input passed to that next `resume`.
    fn suspend(&self, val: Output) -> Input {
        // SAFETY: `msg_ptr` points into the `Fiber`'s boxed message slot. Only
        // one side of the switch runs at a time, so these accesses do not overlap.
        let reply = unsafe {
            *self.msg_ptr = FiberMsg::FromFiber(val);
            // Yield out to the root fiber; execution continues here on the next resume.
            SwitchToFiber(self.root_fiber);
            (*self.msg_ptr).take()
        };
        match reply {
            FiberMsg::ToFiber(arg) => arg,
            _ => panic!("huuuuuhhhh????"),
        }
    }
}
/// Drop guard that panics when dropped.
///
/// If a panic is already unwinding out of `fiber_fn`, this second panic makes
/// Rust abort the process. The unwind therefore cannot cross the
/// `extern "system"` boundary. `fiber_fn` calls `mem::forget` on the guard at
/// the end of its body.
struct PanicGuard;
impl Drop for PanicGuard {
    fn drop(&mut self) {
        panic!("Panic accross FFI boundary.");
    }
}
/// Start routine handed to `CreateFiber`; runs the user closure on the new fiber.
///
/// `fiber_data_void` is the `FiberData` pointer that `Fiber::resume` passes as
/// the fiber parameter. Before any user code runs, the closure is moved out
/// and the message-slot pointer is copied. After that, the `FiberData` itself
/// is not read again.
///
/// # Safety
/// `fiber_data_void` must point to a live `FiberData<Input, Output, FN>` with
/// exactly these type parameters. Its `func` must still be `Some`, and its
/// message slot must hold `FiberMsg::Init`.
unsafe extern "system" fn fiber_fn<Input, Output, FN> (fiber_data_void: LPVOID)
where FN: FiberFunc<Input, Output>
{
    // Any unwind out of this function drops the guard, which panics again and aborts.
    let panic_guard = PanicGuard;
    let fiber_data_ptr = fiber_data_void as *mut FiberData<Input, Output, FN>;
    let func = (*fiber_data_ptr).func.take().unwrap();
    let msg_ptr = (*fiber_data_ptr).msg.get();
    if let FiberMsg::Init { root_fiber, arg } = (*msg_ptr).take() {
        (func)(&Yielder {
            root_fiber,
            msg_ptr,
        }, arg);
        // We're done-- go back
        *msg_ptr = FiberMsg::Done;
        // `resume` returns early for a fiber in the `Done` state and never
        // switches back here, so in practice this call does not return.
        // NOTE(review): per the Win32 docs, returning from a fiber start routine
        // exits the thread, so control must never fall off the end — confirm.
        SwitchToFiber(root_fiber);
    } else {
        panic!("First message to fiber should be FiberMsg::Init. This is a bug.");
    }
    mem::forget(panic_guard);
}
const STACK_SIZE: SIZE_T = 2048;
const NULL: LPVOID = 0 as LPVOID;
/// A resumable computation that runs on its own Win32 fiber.
///
/// `new` only stores the closure. The OS fiber is created by the first `resume`.
struct Fiber<'a, Input: 'a, Output: 'a, FN: FiberFunc<Input, Output>> {
    fiber_ptr: LPVOID, // NULL if it hasn't run yet
    /// Start-up data. The first `resume` passes its address to `CreateFiber`.
    fiber_data: FiberData<Input, Output, FN>,
    // Invariant in Input ptr lifetime, variant in Output ptr lifetime
    // The raw-pointer fields also make `Fiber` neither `Send` nor `Sync`. That
    // keeps it on the thread whose thread-local fiber bookkeeping it uses.
    phantom: PhantomData<(&'a (), *mut Input, *const Output)>
}
/// Per-thread fiber bookkeeping.
struct ThreadFiberInfo {
    /// Number of `Fiber`s currently alive on this thread.
    fiber_count: Cell<usize>,
    /// Handle returned by `ConvertThreadToFiber`, or `NULL` before the first
    /// `Fiber::new` on this thread.
    thread_fiber: Cell<LPVOID>,
}

thread_local! {
    static THREAD_FIBER_INFO: ThreadFiberInfo = ThreadFiberInfo {
        thread_fiber: Cell::new(NULL),
        fiber_count: Cell::new(0),
    };
}
impl<'a, Input, Output, FN> Fiber<'a, Input, Output, FN>
where
Input: 'a,
Output: 'a,
FN: FiberFunc<Input, Output>
{
fn new(func: FN) -> Self {
THREAD_FIBER_INFO.with(|tfi| {
if tfi.thread_fiber.get() == NULL {
let thread_fiber: LPVOID = unsafe { ConvertThreadToFiber(NULL) };
assert!(thread_fiber != NULL);
tfi.thread_fiber.set(thread_fiber);
}
tfi.fiber_count.set(tfi.fiber_count.get() + 1);
Fiber {
fiber_ptr: NULL,
fiber_data: FiberData {
func: Some(func),
msg: Box::new(UnsafeCell::new(FiberMsg::PreInit)),
},
phantom: PhantomData,
}
})
}
fn resume(&mut self, arg: Input) -> Option<Output> {
THREAD_FIBER_INFO.with(|tfi| {
{
// We can guarantee no aliasing of msg_ptr up until the fiber switch
let msg_ptr: &mut FiberMsg<Input, Output> = unsafe {
&mut *self.fiber_data.msg.get()
};
match msg_ptr.take() {
FiberMsg::None |
FiberMsg::Init { .. } |
FiberMsg::ToFiber(_) |
FiberMsg::FromFiber(_) => unreachable!(),
FiberMsg::PreInit => {
debug_assert!(self.fiber_ptr == NULL);
let start_func: LPFIBER_START_ROUTINE = Some(fiber_fn::<Input, Output, FN>);
*msg_ptr = FiberMsg::Init {
root_fiber: tfi.thread_fiber.get(),
arg: arg,
};
let lp_parameter: LPVOID = &mut self.fiber_data as *mut _ as LPVOID;
self.fiber_ptr = unsafe { CreateFiber(STACK_SIZE, start_func, lp_parameter) };
assert!(self.fiber_ptr != NULL);
},
FiberMsg::Paused => {
*msg_ptr = FiberMsg::ToFiber(arg);
}
FiberMsg::Done => { return None; }
}
}
unsafe { SwitchToFiber(self.fiber_ptr); }
// Once again, we can guarantee no aliasing of msg_ptr after the fiber switch
let msg_ptr: &mut FiberMsg<Input, Output> = unsafe {
&mut *self.fiber_data.msg.get()
};
match msg_ptr.take() {
FiberMsg::None |
FiberMsg::PreInit |
FiberMsg::Init{ .. } |
FiberMsg::ToFiber(_) |
FiberMsg::Paused => unreachable!(),
FiberMsg::FromFiber(ret) => {
*msg_ptr = FiberMsg::Paused;
Some(ret)
},
FiberMsg::Done => {
*msg_ptr = FiberMsg::Done;
None
},
}
})
}
}
impl<'a, Input, Output, FN> Drop for Fiber<'a, Input, Output, FN>
where
Input: 'a,
Output: 'a,
FN: FiberFunc<Input, Output>
{
fn drop(&mut self) {
THREAD_FIBER_INFO.with(|tfi| {
unsafe { DeleteFiber(self.fiber_ptr); }
let fiber_count = tfi.fiber_count.get() - 1;
tfi.fiber_count.set(fiber_count);
if fiber_count == 0 {
let thread_again: bool = unsafe { ConvertFiberToThread() } != 0 ;
assert!(thread_again);
}
});
}
}
/// Demo: resume a counting fiber five times and print what it yields.
fn main() {
    let mut counter = Fiber::new(move |yielder: &Yielder<_, _>, _| {
        for n in 0.. {
            yielder.suspend(n);
        }
    });
    let (a, b, c, d, e) = (
        counter.resume(()),
        counter.resume(()),
        counter.resume(()),
        counter.resume(()),
        counter.resume(()),
    );
    println!("{:?} {:?} {:?} {:?} {:?}", a, b, c, d, e);
    println!("Done");
}
#[cfg(test)]
mod test {
    use super::*;

    /// A counting fiber yields 0, 1, 2, ... on successive resumes.
    #[test]
    fn count() {
        let mut fiber = Fiber::new(move |yielder: &Yielder<_, _>, _| {
            for i in 0.. {
                yielder.suspend(i);
            }
        });
        let observed: Vec<_> = (0..5).map(|_| fiber.resume(()).unwrap()).collect();
        assert_eq!(observed, [0, 1, 2, 3, 4]);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment