Skip to content

Instantly share code, notes, and snippets.

@Amanieu
Created January 28, 2026 11:08
Show Gist options
  • Select an option

  • Save Amanieu/98a3853001ea73e9d5ce46cbfe9e501a to your computer and use it in GitHub Desktop.

Select an option

Save Amanieu/98a3853001ea73e9d5ce46cbfe9e501a to your computer and use it in GitHub Desktop.
//! Miri support for stack switching.
//!
//! This works differently from other backends in that Miri doesn't actually use
//! the provided stack directly. Instead, we encode a `FiberId` as the stack
//! pointer value and use that to switch to a given fiber.
//!
//! The stack is only used to hold the fiber's initial state; the fiber's
//! actual call frames are managed by Miri. The parent link holds the `FiberId`
//! of the parent fiber, encoded as a `StackPointer` in the same way.
//!
//! ```text
//! +--------------+ <- Stack base
//! | Initial func |
//! +--------------+
//! | Parent link  |
//! +--------------+
//! |              |
//! ~ Initial obj  ~
//! |              |
//! +--------------+ <- Initial stack pointer
//! ```
use core::mem;
/// Identifier of a Miri fiber, as returned by `miri_fiber_create` and
/// `miri_fiber_current`.
type FiberId = usize;
// Fiber primitives implemented by the Miri interpreter.
// NOTE(review): these symbols are presumably provided by Miri's shims rather
// than by any linked library, so this backend only links/runs under Miri —
// confirm. (Rust 2024 would require this block to be `unsafe extern "Rust"`.)
extern "Rust" {
    /// Creates a fiber that, when first switched to, runs `body(data, payload)`
    /// (this is what `init_stack`/`fiber_body` rely on), and returns its ID.
    fn miri_fiber_create(
        body: unsafe extern "Rust" fn(*mut (), *mut u8) -> !,
        data: *mut (),
    ) -> FiberId;
    /// Returns the ID of the fiber that is currently running.
    fn miri_fiber_current() -> FiberId;
    /// Suspends the current fiber and switches to `target`, handing it
    /// `payload`. As used in this file, the return value is the payload passed
    /// by whichever fiber later switches back to the caller.
    fn miri_fiber_switch(target: FiberId, payload: *mut u8) -> *mut u8;
    /// Switches to `target` with `payload` and never returns.
    /// NOTE(review): not used in this file; presumably terminates the current
    /// fiber — confirm against Miri's shim.
    fn miri_fiber_exit_to(target: FiberId, payload: *mut u8) -> !;
}
use super::{allocate_obj_on_stack, push};
use crate::coroutine::adjusted_stack_base;
use crate::stack::{Stack, StackPointer};
use crate::trap::TrapHandlerRegs;
use crate::unwind::{InitialFunc, StackCallFunc, TrapHandler};
use crate::util::EncodedValue;
/// Alignment of stack pointers in this backend: a single `StackWord`.
/// `init_stack` checks the initial stack pointer against it.
pub const STACK_ALIGNMENT: usize = mem::size_of::<StackWord>();
/// NOTE(review): not referenced in this file; presumably required by shared
/// crate code that expects every backend to define it — confirm.
pub const PARENT_STACK_OFFSET: usize = 0;
/// Byte offset of the parent-link slot relative to the stack base: two words.
/// `init_stack` reserves the parent link as the second word it pushes, after
/// the initial function.
pub const PARENT_LINK_OFFSET: usize = mem::size_of::<StackWord>() * 2;
/// One machine word as stored in the fiber's initial-state header.
pub type StackWord = usize;
/// Message a child fiber hands back to its parent, by pointer, through
/// `miri_fiber_switch` whenever it yields or finishes.
#[derive(Clone, Copy)]
struct Payload {
    /// `true` when the child has finished and must not be resumed again;
    /// `false` when it merely yielded.
    exit: bool,
    /// ID of the child fiber sending this payload.
    fiber: FiberId,
    /// The value being yielded or returned by the child.
    arg: EncodedValue,
}
/// Turns a `FiberId` into the `StackPointer` value handed out to callers.
///
/// The ID is shifted up by one before construction; presumably `StackPointer`
/// cannot hold zero (its constructor's result is unwrapped), so this lets
/// fiber 0 be represented too.
#[inline]
fn encode_fiber(fiber: FiberId) -> StackPointer {
    let shifted = fiber + 1;
    StackPointer::new(shifted).unwrap()
}
/// Recovers the `FiberId` stored in a `StackPointer`, undoing the `+ 1`
/// applied by `encode_fiber`.
#[inline]
fn decode_fiber(sp: StackPointer) -> FiberId {
    let raw = sp.get();
    raw - 1
}
// Header slots that `init_stack` writes at the top of a fiber's stack, given as
// word indices counted downwards from the stack base: slot `n` is the word at
// `stack_base - n * size_of::<StackWord>()`. `PARENT_LINK_SLOT` must agree with
// `PARENT_LINK_OFFSET` (two words).

/// Slot holding the fiber's initial function.
const INITIAL_FUNC_SLOT: usize = 1;
/// Slot holding the parent link (the parent's `FiberId`, encoded as a
/// `StackPointer`).
const PARENT_LINK_SLOT: usize = 2;
/// Slot holding the address of the initial object.
const OBJ_PTR_SLOT: usize = 3;

/// Main body for fibers.
///
/// `data` is the stack base that [`init_stack`] passed to `miri_fiber_create`.
/// The initial function, the parent link and the initial object's address are
/// read from the header slots just below it.
///
/// `payload` is what the first [`switch_and_link`] passed to
/// `miri_fiber_switch`, which is just the `EncodedValue` argument.
unsafe fn fiber_body(data: *mut (), payload: *mut u8) -> ! {
    let stack_base = data as *mut StackWord;
    let initial_func: InitialFunc<u8> = *stack_base.sub(INITIAL_FUNC_SLOT).cast();
    let parent_link: &mut StackPointer =
        &mut *stack_base.sub(PARENT_LINK_SLOT).cast::<StackPointer>();
    let obj = (*stack_base.sub(OBJ_PTR_SLOT) as *mut u8).cast();
    initial_func(payload as EncodedValue, parent_link, obj)
}

/// Writes the initial state for a new fiber onto `stack` and creates the fiber.
///
/// Below the (adjusted) stack base, this pushes:
///
/// - slot 1: `func`;
/// - slot 2: the parent link, filled in by [`switch_and_link`] before the
///   fiber first runs;
/// - slot 3: the address of the initial object;
///
/// followed by `obj` itself.
///
/// The returned `StackPointer` is not a real address. It is the new fiber's ID,
/// encoded with `encode_fiber`.
#[inline]
pub unsafe fn init_stack<T>(stack: &impl Stack, func: InitialFunc<T>, obj: T) -> StackPointer {
    let stack_base = adjusted_stack_base(stack).get();
    let mut sp = stack_base;

    // Initial function.
    push(&mut sp, Some(func as StackWord));

    // Placeholder for the parent link. The header slots assume `push` moves
    // `sp` down one word at a time. Check that the parent link landed where
    // `PARENT_LINK_SLOT` (and `PARENT_LINK_OFFSET`) expect it.
    push(&mut sp, None);
    debug_assert_eq!(stack_base - sp, PARENT_LINK_SLOT * mem::size_of::<StackWord>());

    // Placeholder for the initial object's address. `allocate_obj_on_stack`
    // decides the final position (it may insert padding), so the address is
    // only known afterwards.
    push(&mut sp, None);
    let obj_ptr_slot = sp as *mut StackWord;

    // Allocate space on the stack for the initial object, rounding to
    // STACK_ALIGNMENT. The second argument is the number of header bytes
    // pushed so far.
    allocate_obj_on_stack(&mut sp, mem::size_of::<StackWord>() * OBJ_PTR_SLOT, obj);

    // The stack is aligned to STACK_ALIGNMENT at this point.
    debug_assert_eq!(sp % STACK_ALIGNMENT, 0);

    // `sp` now points at the initial object; record it for `fiber_body`.
    obj_ptr_slot.write(sp);

    // Create a Miri fiber and pass it the stack base.
    let fiber = miri_fiber_create(fiber_body, stack_base as *mut ());

    // Return the fiber ID encoded as the stack pointer value.
    encode_fiber(fiber)
}

/// Switches from the current fiber into the child fiber encoded in `sp`,
/// passing it `arg`.
///
/// The current fiber is first recorded in the child's parent-link slot (below
/// `stack_base`), so the child knows where to switch back to.
///
/// Returns the value the child yielded or returned. The `Option` is:
/// - `Some(sp)` for the child if it merely yielded and can be resumed;
/// - `None` if it finished.
#[inline]
pub unsafe fn switch_and_link(
    arg: EncodedValue,
    sp: StackPointer,
    stack_base: StackPointer,
) -> (EncodedValue, Option<StackPointer>) {
    // Write the current fiber to the parent link on the child stack. The slot
    // may still be uninitialized before the first switch, hence `write`.
    let parent_link = (stack_base.get() as *mut StackWord)
        .sub(PARENT_LINK_SLOT)
        .cast::<StackPointer>();
    parent_link.write(encode_fiber(miri_fiber_current()));

    // Switch to the child fiber.
    let result = miri_fiber_switch(decode_fiber(sp), arg as *mut u8);

    // The child switches back with a pointer to a `Payload` on its own stack
    // (see `switch_yield`/`switch_and_reset`). Both use a plain switch, so the
    // child is suspended rather than gone while we read it.
    let payload = *(result as *const Payload);

    // TODO: Call miri_destroy_fiber to free the fiber here if done
    (payload.arg, (!payload.exit).then(|| encode_fiber(payload.fiber)))
}
/// Suspends the current fiber and hands control back to its parent.
///
/// The parent fiber is taken from `parent_link`. The parent receives a pointer
/// to a `Payload` that carries `arg` and marks this as a yield, not an exit.
/// The `Payload` lives in this frame, which stays suspended until resumed.
///
/// Returns the value passed in by whoever resumes this fiber.
#[inline(always)]
pub unsafe fn switch_yield(arg: EncodedValue, parent_link: *mut StackPointer) -> EncodedValue {
    let yielded = Payload {
        fiber: miri_fiber_current(),
        arg,
        exit: false,
    };
    let payload_ptr: *const Payload = &yielded;
    let parent = decode_fiber(*parent_link);
    let resume_value = miri_fiber_switch(parent, payload_ptr as *mut u8);
    resume_value as EncodedValue
}
/// Hands control back to the parent fiber for the last time, reporting that
/// this fiber has finished with result `arg`.
///
/// The parent receives a pointer to a `Payload` with `exit` set, so it should
/// never switch back here.
///
/// NOTE(review): presumably a plain `miri_fiber_switch` (rather than
/// `miri_fiber_exit_to`) is used so the `Payload` on this fiber's stack stays
/// valid while the parent reads it — confirm.
#[inline(always)]
pub unsafe fn switch_and_reset(arg: EncodedValue, parent_link: *mut StackPointer) -> ! {
    let finished = Payload {
        fiber: miri_fiber_current(),
        arg,
        exit: true,
    };
    let payload_ptr: *const Payload = &finished;
    let _ = miri_fiber_switch(decode_fiber(*parent_link), payload_ptr as *mut u8);
    // A fiber that reported `exit: true` is never resumed.
    unreachable!()
}
/// Passes the fiber's initial object to `drop_fn`.
///
/// TODO(review): this is known to be wrong (see the TODO in the body). In this
/// backend, the `StackPointer` that `init_stack` returns is an encoded
/// `FiberId` (`encode_fiber`), not a stack address. If `stack_ptr` is that
/// value (presumably so, though the caller is not visible here), casting it to
/// a pointer hands `drop_fn` a bogus address.
///
/// `init_stack` writes the object onto the stack itself, so locating it will
/// likely need `_stack_base` — confirm against the layout `init_stack` uses.
#[inline]
pub unsafe fn drop_initial_obj(
    _stack_base: StackPointer,
    stack_ptr: StackPointer,
    drop_fn: unsafe fn(ptr: *mut u8),
) {
    // TODO: This is currently wrong
    let ptr = stack_ptr.get() as *mut u8;
    drop_fn(ptr);
}
/// Trap handling is not supported when running under Miri.
///
/// The parameters exist only to match the signature other backends provide.
/// They are deliberately unused; the `_` prefixes silence `unused_variables`.
///
/// # Panics
///
/// Always panics.
pub unsafe fn setup_trap_trampoline<T>(
    _stack_base: StackPointer,
    _val: T,
    _handler: TrapHandler<T>,
) -> TrapHandlerRegs {
    panic!("Trap handling is not supported in Miri");
}
/// Calls `f(arg)`.
///
/// Under Miri there is no in-memory stack to switch to. `f` therefore runs
/// directly on the current (Miri-managed) stack, and `_stack` is unused apart
/// from being dropped when this function returns. The `_` prefix silences
/// `unused_variables`.
#[inline]
pub unsafe fn on_stack(arg: *mut u8, _stack: impl Stack, f: StackCallFunc) {
    f(arg)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment