Created
January 28, 2026 11:08
-
-
Save Amanieu/98a3853001ea73e9d5ce46cbfe9e501a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| //! Miri support for stack switching. | |
| //! | |
| //! This works differently from other backends in that Miri doesn't actually use | |
| //! the provided stack directly. Instead, we encode a `FiberId` as the stack | |
| //! pointer value and use that to switch to a given fiber. | |
| //! | |
| //! The stack is only used to hold the initial state for the fiber. The parent | |
| //! link holds the parent fiber's `FiberId`, encoded as a `StackPointer` by | |
| //! `encode_fiber`; see `init_stack` for the exact slot layout. | |
| //! | |
| //! ```text | |
| //! +--------------+ <- Stack base | |
| //! | Initial func | | |
| //! +--------------+ | |
| //! | Parent link | | |
| //! +--------------+ | |
| //! | | | |
| //! ~ Initial obj ~ | |
| //! | | | |
| //! +--------------+ <- Initial stack pointer | |
| //! ``` | |
| use core::mem; | |
/// Identifier of a Miri fiber, as returned by `miri_fiber_create`.
type FiberId = usize;

// Fiber primitives with no Rust definition in this crate; presumably these are
// implemented as Miri shims. NOTE(review): the signatures must match those
// shims exactly — confirm against the Miri version in use.
extern "Rust" {
    /// Returns the ID of the fiber that is currently running.
    fn miri_fiber_current() -> FiberId;

    /// Creates a fiber that runs `body(data, payload)`. As used in this file,
    /// `payload` is the value passed by the first switch into the new fiber.
    fn miri_fiber_create(
        body: unsafe extern "Rust" fn(*mut (), *mut u8) -> !,
        data: *mut (),
    ) -> FiberId;

    /// Suspends the current fiber and resumes `fiber`, handing it `value`.
    /// As used in this file, the return value is whatever the switch that
    /// later resumes the current fiber passes in.
    fn miri_fiber_switch(fiber: FiberId, value: *mut u8) -> *mut u8;

    /// Switches to `fiber`, handing it `value`; never returns to the caller
    /// (presumably the current fiber is torn down).
    fn miri_fiber_exit_to(fiber: FiberId, value: *mut u8) -> !;
}
| use super::{allocate_obj_on_stack, push}; | |
| use crate::coroutine::adjusted_stack_base; | |
| use crate::stack::{Stack, StackPointer}; | |
| use crate::trap::TrapHandlerRegs; | |
| use crate::unwind::{InitialFunc, StackCallFunc, TrapHandler}; | |
| use crate::util::EncodedValue; | |
/// One machine word, the unit in which the fiber's initial state is stored.
pub type StackWord = usize;

/// Required stack alignment in bytes: a single `StackWord`. Miri never runs
/// code on these stacks, which only hold the fiber's initial state.
pub const STACK_ALIGNMENT: usize = mem::size_of::<StackWord>();

/// Byte offset below the stack base of the parent-stack slot.
/// NOTE(review): this backend's `on_stack` does not switch stacks, so this is
/// presumably unused here — confirm against the callers.
pub const PARENT_STACK_OFFSET: usize = 0;

/// Byte offset below the stack base of the parent link: the second word,
/// directly under the initial function.
pub const PARENT_LINK_OFFSET: usize = 2 * mem::size_of::<StackWord>();
| #[derive(Clone, Copy)] | |
| struct Payload { | |
| fiber: FiberId, | |
| arg: EncodedValue, | |
| exit: bool, | |
| } | |
| /// Encodes a `FiberId` as a `StackPointer`. | |
| #[inline] | |
| fn encode_fiber(fiber: FiberId) -> StackPointer { | |
| StackPointer::new(fiber + 1).unwrap() | |
| } | |
| /// Decodes a `FiberId` from a `StackPointer`. | |
| #[inline] | |
| fn decode_fiber(sp: StackPointer) -> FiberId { | |
| sp.get() - 1 | |
| } | |
| /// Main body for fibers. | |
| unsafe fn fiber_body(data: *mut (), payload: *mut u8) -> ! { | |
| // Fetch the initial function, initial object and parent link from the | |
| // stack base. | |
| let stack_base = data as *mut StackWord; | |
| let initial_func: InitialFunc<u8> = *stack_base.cast(); | |
| let parent_link: &mut StackPointer = &mut *stack_base.add(1); | |
| let obj = stack_base.add(2).cast(); | |
| // The payload for switching from a parent to a child is just the | |
| // EncodedValue. | |
| initial_func(payload as EncodedValue, parent_link, obj) | |
| } | |
| #[inline] | |
| pub unsafe fn init_stack<T>(stack: &impl Stack, func: InitialFunc<T>, obj: T) -> StackPointer { | |
| let stack_base = adjusted_stack_base(stack).get(); | |
| let mut sp = stack_base; | |
| // Initial function. | |
| push(&mut sp, Some(func as StackWord)); | |
| // Placeholder for parent link. | |
| push(&mut sp, None); | |
| // Allocate space on the stack for the initial object, rounding to | |
| // STACK_ALIGNMENT. | |
| allocate_obj_on_stack(&mut sp, mem::size_of::<StackWord>() * 2, obj); | |
| // The stack is aligned to STACK_ALIGNMENT at this point. | |
| debug_assert_eq!(sp % STACK_ALIGNMENT, 0); | |
| // Create a Miri fiber and pass it the stack base. | |
| let fiber = miri_fiber_create(fiber_body, stack_base as *mut ()); | |
| // Return the fiber ID encoded as the stack pointer value. | |
| encode_fiber(fiber) | |
| } | |
| #[inline] | |
| pub unsafe fn switch_and_link( | |
| arg: EncodedValue, | |
| sp: StackPointer, | |
| stack_base: StackPointer, | |
| ) -> (EncodedValue, Option<StackPointer>) { | |
| // Write the current fiber to the parent link on the child stack. | |
| *(stack_base.get() as *mut StackWord).add(1).cast() = encode_fiber(miri_fiber_current()); | |
| // Switch to the child fiber. | |
| let result = miri_fiber_switch(decode_fiber(sp), arg as *mut u8); | |
| // Read the returned value and new child fiber from the returned payload. | |
| let payload = *(result as *const Payload); | |
| (payload.arg, (!exit).then(encode_fiber(payload.fiber))) | |
| // TODO: Call miri_destroy_fiber to free the fiber here if done | |
| } | |
| #[inline(always)] | |
| pub unsafe fn switch_yield(arg: EncodedValue, parent_link: *mut StackPointer) -> EncodedValue { | |
| let payload = Payload { | |
| fiber: miri_fiber_current(), | |
| arg, | |
| exit: false, | |
| }; | |
| let fiber = decode_fiber(*parent_link); | |
| miri_fiber_switch(fiber, &payload as *const Payload as *mut u8) as EncodedValue | |
| } | |
| #[inline(always)] | |
| pub unsafe fn switch_and_reset(arg: EncodedValue, parent_link: *mut StackPointer) -> ! { | |
| let payload = Payload { | |
| fiber: miri_fiber_current(), | |
| arg, | |
| exit: true, | |
| }; | |
| let fiber = decode_fiber(*parent_link); | |
| miri_fiber_switch(fiber, &payload as *const Payload as *mut u8); | |
| unreachable!() | |
| } | |
| #[inline] | |
| pub unsafe fn drop_initial_obj( | |
| _stack_base: StackPointer, | |
| stack_ptr: StackPointer, | |
| drop_fn: unsafe fn(ptr: *mut u8), | |
| ) { | |
| // TODO: This is currently wrong | |
| let ptr = stack_ptr.get() as *mut u8; | |
| drop_fn(ptr); | |
| } | |
| pub unsafe fn setup_trap_trampoline<T>( | |
| stack_base: StackPointer, | |
| val: T, | |
| handler: TrapHandler<T>, | |
| ) -> TrapHandlerRegs { | |
| panic!("Trap handling is not supported in Miri"); | |
| } | |
| /// This function executes a function on the given stack. The argument is passed | |
| /// through to the called function. | |
| #[inline] | |
| pub unsafe fn on_stack(arg: *mut u8, stack: impl Stack, f: StackCallFunc) { | |
| // Miri doesn't use an in-memory stack, so just call the function directly. | |
| f(arg) | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment