Skip to content

Instantly share code, notes, and snippets.

@sminez
Last active February 25, 2025 16:12
Show Gist options
  • Save sminez/9cea7cfcf8ac7fc508508fa0789404ef to your computer and use it in GitHub Desktop.
Save sminez/9cea7cfcf8ac7fc508508fa0789404ef to your computer and use it in GitHub Desktop.
(ab)using Rust async/await to drive a little state machine
//! This is a little exploration of (ab)using async/await syntax + a dummy Waker to simplify
//! writing sans-io state machine code.
use std::{
cell::UnsafeCell,
future::Future,
io::{self, Cursor, ErrorKind, Read},
pin::{pin, Pin},
sync::Arc,
task::{Context, Poll, Wake, Waker},
};
use tokio::io::{AsyncRead, AsyncReadExt};
// ["Hello", "世界"] in 9p wire format.
//
// In 9p, data items of larger or variable lengths are represented by a two-byte field
// specifying a count, n, followed by n items of data. Counts are little-endian integers.
// - Strings are represented this way with n bytes of UTF-8 data and no trailing null byte.
// - Arrays are represented as a two-byte count of *elements* (not bytes) followed by the
//   encoded form of each element; this is how `Vec<T>::read` decodes them.
//
// Layout: 02 00 (2 elements) | 05 00 "Hello" | 06 00 "世界" (UTF-8: e4 b8 96 e7 95 8c)
const HELLO_WORLD: [u8; 17] = *b"\x02\x00\x05\x00Hello\x06\x00\xe4\xb8\x96\xe7\x95\x8c";
#[tokio::main]
async fn main() -> io::Result<()> {
    // Decode the same fixture twice: once through a blocking reader and once through an
    // async one, printing what each driver produced.
    println!(">> reading using std::io::Read");
    let mut blocking_src = Cursor::new(HELLO_WORLD.to_vec());
    let from_sync: Vec<String> = read_9p_sync_from_bytes(&mut blocking_src)?;
    println!(" got val: {from_sync:?}\n");

    println!(">> reading using tokio::io::AsyncRead");
    let mut async_src = Cursor::new(HELLO_WORLD.to_vec());
    let from_async: Vec<String> = read_9p_async_from_bytes(&mut async_src).await?;
    println!(" got val: {from_async:?}");

    Ok(())
}
/// Decode a [Read9p] value, satisfying each byte request with a blocking [Read].
///
/// The state machine returned by `T::read` is polled until it completes. Whenever it yields
/// `Pending`, exactly the requested number of bytes is read from `r` and handed back through
/// the shared [State].
///
/// # Errors
/// Propagates any IO error from `r` (including `UnexpectedEof` from `read_exact` when the input
/// is too short) and any decoding error produced by `T::read`.
fn read_9p_sync_from_bytes<T, R>(r: &mut R) -> io::Result<T>
where
    T: Read9p,
    R: Read,
{
    let state = Arc::new(State::default());
    let waker = Waker::from(Arc::clone(&state));
    let mut cx = Context::from_waker(&waker);

    // SAFETY: the state machine is only ever polled with `cx`, whose waker wraps `state`, and
    // every `Pending` below is answered with exactly the requested bytes before re-polling.
    let machine = unsafe { T::read() };
    let mut machine = pin!(machine);

    loop {
        if let Poll::Ready(result) = machine.as_mut().poll(&mut cx) {
            return result;
        }

        // SAFETY: the state machine is suspended between polls, so nothing else accesses
        // `state` while the driver reads and writes it.
        let wanted = unsafe { state.requested_bytes() };
        println!("{n} bytes requested", n = wanted);

        let mut chunk = vec![0u8; wanted];
        r.read_exact(&mut chunk)?;

        // SAFETY: as above — the state machine is still suspended.
        unsafe { state.set_bytes(chunk) };
    }
}
/// Decode a [Read9p] value, satisfying each byte request with an [AsyncRead].
///
/// The state machine returned by `T::read` is polled until it completes. Whenever it yields
/// `Pending`, exactly the requested number of bytes is read from `r` (asynchronously) and
/// handed back through the shared [State].
///
/// # Errors
/// Propagates any IO error from `r` and any decoding error produced by `T::read`.
async fn read_9p_async_from_bytes<T, R>(r: &mut R) -> io::Result<T>
where
    T: Read9p,
    R: AsyncRead + Unpin,
{
    let state = Arc::new(State::default());
    let waker = Waker::from(Arc::clone(&state));

    // SAFETY: the state machine is only ever polled with a context whose waker wraps `state`,
    // and every `Pending` below is answered with exactly the requested bytes before re-polling.
    let machine = unsafe { T::read() };
    let mut machine = pin!(machine);

    loop {
        // `Context` is `!Send`. Polling inside this `if let` statement drops the temporary
        // `Context` before the `.await` below. A `match` whose arm awaited would keep the
        // scrutinee temporary alive across that await and make this whole future `!Send`.
        if let Poll::Ready(result) = machine.as_mut().poll(&mut Context::from_waker(&waker)) {
            return result;
        }

        // SAFETY: the state machine is suspended between polls, so nothing else accesses
        // `state` while the driver reads and writes it.
        let n = unsafe { state.requested_bytes() };
        println!("{n} bytes requested");

        let mut buf = vec![0u8; n];
        r.read_exact(&mut buf).await?;

        // SAFETY: as above — the state machine is still suspended.
        unsafe { state.set_bytes(buf) };
    }
}
/// Shared state between a [Read9p] impl and a parent read loop that is performing IO.
///
/// A [RequestBytes] future records how many bytes it needs with `set_requested` and returns
/// `Pending`. The driving loop reads that many bytes and hands them over with `set_bytes`. On
/// its next poll the future collects them with `take_bytes`.
#[derive(Default, Debug)]
struct State {
    inner: UnsafeCell<StateInner>,
}

// SAFETY: `Waker::from(Arc<State>)` requires `State: Send + Sync`. `State` does no
// synchronisation of its own; soundness rests on its `unsafe` accessors never being called
// concurrently (see their `# Safety` sections). The `Wake` impl is a no-op and never touches
// `inner`, so a waker clone that ends up on another thread cannot race with the driver.
unsafe impl Send for State {}
unsafe impl Sync for State {}

impl State {
    /// Number of bytes most recently recorded by `set_requested` (0 before any request).
    ///
    /// # Safety
    /// Must not be called concurrently with any other accessor on the same `State`.
    unsafe fn requested_bytes(&self) -> usize {
        unsafe { (*self.inner.get()).n }
    }

    /// Record that `n` bytes are needed before the state machine can make progress.
    ///
    /// # Safety
    /// Must not be called concurrently with any other accessor on the same `State`.
    unsafe fn set_requested(&self, n: usize) {
        unsafe { (*self.inner.get()).n = n };
    }

    /// Hand bytes read by the driver to the waiting state machine, replacing any bytes that
    /// have not been taken yet.
    ///
    /// # Safety
    /// Must not be called concurrently with any other accessor on the same `State`.
    unsafe fn set_bytes(&self, buf: Vec<u8>) {
        unsafe { (*self.inner.get()).buf = Some(buf) };
    }

    /// Take the bytes supplied by the driver, leaving none pending.
    ///
    /// # Panics
    /// Panics if no bytes are pending, which indicates a driver bug. This is checked rather
    /// than assumed, because `unwrap_unchecked` here would turn that bug into undefined
    /// behaviour.
    ///
    /// # Safety
    /// Must not be called concurrently with any other accessor on the same `State`.
    unsafe fn take_bytes(&self) -> Vec<u8> {
        let pending = unsafe { (*self.inner.get()).buf.take() };
        pending.expect("take_bytes called before the driver supplied the requested bytes")
    }
}

/// Mutable contents of [State].
#[derive(Default, Debug)]
struct StateInner {
    /// Number of bytes most recently requested by a [RequestBytes] future.
    n: usize,
    /// Bytes supplied by the driver and not yet taken by the state machine.
    buf: Option<Vec<u8>>,
}

// The driver loops re-poll unconditionally after performing IO, so waking is a no-op.
impl Wake for State {
    fn wake(self: Arc<Self>) {}
    fn wake_by_ref(self: &Arc<Self>) {}
}
/// Helper future for awaiting that returns `Pending` exactly once, so control goes back to the
/// driving poll loop, which performs the IO.
///
/// On its first poll it records `n` as the requested byte count in the [State] behind the
/// waker and returns `Pending`. On its second poll it returns the bytes the driver supplied.
struct RequestBytes {
    polled: bool,
    n: usize,
}

impl RequestBytes {
    /// Recover the [State] wrapped by `ctx`'s waker.
    ///
    /// # Safety
    /// `ctx`'s waker must have been created with `Waker::from(Arc<State>)`.
    /// NOTE(review): this relies on std's `From<Arc<W>> for Waker` using the `Arc`'s data
    /// pointer (`Arc::into_raw`) as `Waker::data()`; confirm this still holds on toolchain bumps.
    unsafe fn state<'a>(ctx: &'a Context<'_>) -> &'a State {
        // Every `State` accessor takes `&self`, and the driver loop holds its own `Arc<State>`
        // to the same value, so a shared reference is all that is needed. Creating a `&mut`
        // here would assert exclusive access to `Arc`-shared data.
        unsafe { &*(ctx.waker().data() as *const State) }
    }
}

impl Future for RequestBytes {
    type Output = Vec<u8>;

    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Vec<u8>> {
        // SAFETY: `RequestBytes` is only awaited (via `request_bytes!`) inside `Read9p::read`
        // futures, which the driver loops poll exclusively with a waker wrapping `State`.
        let state = unsafe { Self::state(ctx) };
        if self.polled {
            // SAFETY: the driver calls `set_bytes` before re-polling after a `Pending`.
            Poll::Ready(unsafe { state.take_bytes() })
        } else {
            self.polled = true;
            // SAFETY: the driver is blocked inside `poll`, so there is no concurrent access.
            unsafe { state.set_requested(self.n) };
            Poll::Pending
        }
    }
}
/// Ask the parent poll loop for `$n` bytes and yield to it so it can perform the IO.
///
/// Evaluates to the `Vec<u8>` that the loop supplies. Must only be awaited inside a
/// `Read9p::read` future that is driven by one of the driver loops.
macro_rules! request_bytes {
    ($n:expr) => {{
        let request = RequestBytes { n: $n, polled: false };
        request.await
    }};
}
/// A type that can be decoded from 9p wire format by a sans-io state machine.
trait Read9p: Sized {
    /// Build the decoding state machine for `Self`.
    ///
    /// This is not a normal async function: the returned future performs no IO itself. Each
    /// time it needs input it records a byte count in the shared `State` (via
    /// [request_bytes]) and returns `Pending`. A driver loop such as `read_9p_sync_from_bytes`
    /// or `read_9p_async_from_bytes` then performs the IO and polls it again.
    ///
    /// # Safety
    /// Callers must only poll the returned future with a `Context` whose waker was created
    /// from an `Arc<State>`. After every `Poll::Pending` they must supply exactly the
    /// requested number of bytes through that `State` before polling again.
    ///
    /// Implementations must only await the [request_bytes] macro or nested `Read9p::read`
    /// calls, which are themselves bound by this rule. Any other await point could return
    /// `Pending` without having requested bytes, which the driver loops cannot service.
    unsafe fn read() -> impl Future<Output = io::Result<Self>> + Send;
}
impl Read9p for u16 {
    /// Decode a little-endian `u16` from the next two bytes of input.
    async unsafe fn read() -> io::Result<u16> {
        const WIDTH: usize = size_of::<u16>();
        let raw = request_bytes!(WIDTH);
        // `raw[..WIDTH]` has exactly WIDTH elements, so the array conversion cannot fail.
        let bytes: [u8; WIDTH] = raw[..WIDTH].try_into().unwrap();
        Ok(u16::from_le_bytes(bytes))
    }
}
impl Read9p for String {
    /// Decode a 9p string: a little-endian `u16` byte count followed by that many bytes of
    /// UTF-8 (no trailing null).
    ///
    /// # Errors
    /// Returns an `ErrorKind::InvalidData` error if the bytes are not valid UTF-8. The original
    /// `FromUtf8Error` is kept as the error's inner value, so it stays reachable through
    /// `get_ref`/`into_inner`, and the Display text is unchanged.
    async unsafe fn read() -> io::Result<Self> {
        // SAFETY: the only awaits here are a nested `Read9p::read` and `request_bytes!`, as the
        // trait contract requires; our caller upholds the driver side of the contract.
        let len = usize::from(unsafe { u16::read() }.await?);
        let buf = request_bytes!(len);
        String::from_utf8(buf).map_err(|e| io::Error::new(ErrorKind::InvalidData, e))
    }
}
impl<T: Read9p + Send> Read9p for Vec<T> {
    /// Decode a 9p array: a little-endian `u16` element count followed by that many encoded
    /// `T` values, read in order.
    async unsafe fn read() -> io::Result<Self> {
        // SAFETY: every `unsafe` call below is a nested `Read9p::read`, whose only awaits are
        // `request_bytes!`; our own caller upholds the driver side of the contract.
        let count = usize::from(unsafe { u16::read() }.await?);
        let mut items = Vec::with_capacity(count);
        while items.len() < count {
            items.push(unsafe { T::read() }.await?);
        }
        Ok(items)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment