Skip to content

Instantly share code, notes, and snippets.

@sug0
Last active December 19, 2024 12:12
Show Gist options
  • Save sug0/8693a43ae678fd83215807068b71146d to your computer and use it in GitHub Desktop.
Save sug0/8693a43ae678fd83215807068b71146d to your computer and use it in GitHub Desktop.
CPS Rust state machines
use std::marker::PhantomData;
use std::panic;
use std::sync::Once;
/// Demo entry point: runs the example state machine three times in a row.
///
/// Each run starts at `beginning_state` with a unit context and prints a banner
/// before the run and a blank line after it.
fn main() {
    const ROUNDS: usize = 3;

    for _round in 0..ROUNDS {
        println!("RUNNING STATE MACHINE");
        println!("=====================");
        run_state_machine((), beginning_state);
        println!();
    }
}
/// Panic payload a state unwinds with to request a clean stop of the machine.
///
/// The runner treats an unwind carrying this payload as normal termination rather
/// than as an error.
struct AbortExecution;

/// A single state: it operates on the shared context and returns its successor.
type State<Ctx> = fn(&mut Ctx) -> NextState<Ctx>;

/// The successor returned by a [`State`].
///
/// A bare `fn(&mut Ctx) -> fn(&mut Ctx) -> ...` would be an infinitely recursive
/// type; wrapping it in a named struct breaks that recursion, because a nominal type
/// may mention itself inside a function-pointer field. Storing the pointer with its
/// real type (instead of erasing it to `fn()` and transmuting it back) means:
/// - no `unsafe` is needed to recover the state function, and
/// - variance over `Ctx` is derived from the actual signature (invariant, because of
///   `&mut Ctx`), so safe code cannot coerce a `NextState<A>` into a
///   `NextState<B>` and obtain a state function of the wrong type.
#[repr(transparent)]
struct NextState<Ctx> {
    inner: State<Ctx>,
}

impl<Ctx> NextState<Ctx> {
    /// Returns the state function to execute next.
    fn get(&self) -> State<Ctx> {
        self.inner
    }

    /// Requests that the state machine stop.
    ///
    /// # Panics
    /// Always: it unwinds with an [`AbortExecution`] payload and never returns a
    /// value. The `Self` return type only lets a state write `return NextState::bail()`.
    fn bail() -> Self {
        panic::panic_any(AbortExecution)
    }

    /// Wraps `state` as the successor of the current state.
    fn from_fn(state: State<Ctx>) -> Self {
        Self { inner: state }
    }
}
/// Wraps the panic hook `next` so that [`AbortExecution`] unwinds are not reported.
///
/// The returned hook returns silently when the panic payload is an `AbortExecution`
/// (a requested stop of the state machine). For every other payload it forwards the
/// `PanicHookInfo` unchanged to `next`.
fn abort_execution_panic_middleware(
    next: Box<dyn Fn(&panic::PanicHookInfo<'_>) + Sync + Send + 'static>,
) -> Box<dyn Fn(&panic::PanicHookInfo<'_>) + Sync + Send + 'static> {
    Box::new(move |info: &panic::PanicHookInfo<'_>| {
        // A requested stop is not a failure: print nothing.
        if info.payload().is::<AbortExecution>() {
            return;
        }
        next(info);
    })
}
/// Runs a state machine that starts at `init_state` and owns `context`.
///
/// The machine steps until some state unwinds:
/// - An unwind carrying [`AbortExecution`] (raised by `NextState::bail`) is the normal
///   way to stop. It is swallowed and this function returns.
/// - Any other panic payload is re-raised unchanged with `panic::resume_unwind`.
///
/// The first call also installs, process-wide, a panic hook built by
/// `abort_execution_panic_middleware` around the previously installed hook.
fn run_state_machine<Ctx: panic::UnwindSafe>(context: Ctx, init_state: State<Ctx>) {
    // A `static` inside a generic fn is a single item, so this `Once` is shared by
    // every instantiation and the hook chain is built exactly one time.
    static HOOK_INSTALLED: Once = Once::new();
    HOOK_INSTALLED.call_once(|| {
        panic::set_hook(abort_execution_panic_middleware(panic::take_hook()));
    });

    // `run_state_machine_inner` never returns normally, so `catch_unwind` can only
    // come back with an unwind payload.
    let outcome = panic::catch_unwind(move || run_state_machine_inner(context, init_state));

    if let Err(payload) = outcome {
        if !payload.is::<AbortExecution>() {
            // A genuine panic: propagate it to our caller.
            panic::resume_unwind(payload);
        }
        // Otherwise a state called `NextState::bail`, which is a clean stop.
    }
}
/// Drives the machine forever. Each state is called with the context, and the
/// state it returns becomes the current one.
///
/// This never returns normally. The only way out is an unwind from inside a state
/// (for example via `NextState::bail`). A plain `loop`, rather than recursion, keeps
/// stack usage constant however many transitions occur.
fn run_state_machine_inner<Ctx>(mut context: Ctx, state: State<Ctx>) -> ! {
    let mut current = state;
    loop {
        let successor = current(&mut context);
        current = successor.get();
    }
}
/// First state of the demo machine: announces the start, then moves to `middle_state`.
fn beginning_state(_context: &mut ()) -> NextState<()> {
    println!("begin work");
    NextState::<()>::from_fn(middle_state)
}
/// Second state of the demo machine: reports progress, then moves to `end_state`.
fn middle_state(_context: &mut ()) -> NextState<()> {
    println!("do work...");
    NextState::<()>::from_fn(end_state)
}
/// Final state of the demo machine: prints a farewell and stops the machine.
///
/// Stopping is done with `NextState::bail`, which unwinds instead of returning.
fn end_state(_context: &mut ()) -> NextState<()> {
    println!("the work is done. alright, cya");
    NextState::<()>::bail()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment