Last active
February 25, 2025 16:12
-
-
Save sminez/9cea7cfcf8ac7fc508508fa0789404ef to your computer and use it in GitHub Desktop.
(ab)using rust async/await to drive a little state machine
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! This is a little exploration of (ab)using async/await syntax + a dummy Waker to simplify | |
//! writing sans-io state machine code. | |
use std::{ | |
cell::UnsafeCell, | |
future::Future, | |
io::{self, Cursor, ErrorKind, Read}, | |
pin::{pin, Pin}, | |
sync::Arc, | |
task::{Context, Poll, Wake, Waker}, | |
}; | |
use tokio::io::{AsyncRead, AsyncReadExt}; | |
// The list ["Hello", "世界"] encoded in 9p wire format.
//
// 9p encodes variable-length data as a two-byte little-endian count followed by the payload:
// - A string is its UTF-8 byte length followed by the bytes, with no trailing null byte.
// - An array, as decoded by this file, is its element count followed by the encoded form of
//   each element.
const HELLO_WORLD: [u8; 17] = [
    // Array header: 2 elements.
    0x02, 0x00,
    // "Hello": length 5, then the UTF-8 bytes.
    0x05, 0x00, 0x48, 0x65, 0x6c, 0x6c, 0x6f,
    // "世界": length 6, then the UTF-8 bytes.
    0x06, 0x00, 0xe4, 0xb8, 0x96, 0xe7, 0x95, 0x8c,
];
#[tokio::main] | |
async fn main() -> io::Result<()> { | |
println!(">> reading using std::io::Read"); | |
let v: Vec<String> = read_9p_sync_from_bytes(&mut Cursor::new(HELLO_WORLD.to_vec()))?; | |
println!(" got val: {v:?}\n"); | |
println!(">> reading using tokio::io::AsyncRead"); | |
let v: Vec<String> = read_9p_async_from_bytes(&mut Cursor::new(HELLO_WORLD.to_vec())).await?; | |
println!(" got val: {v:?}"); | |
Ok(()) | |
} | |
/// Read a [Read9p] value using an implementation of [Read] to perform the required IO. | |
fn read_9p_sync_from_bytes<T, R>(r: &mut R) -> io::Result<T> | |
where | |
T: Read9p, | |
R: Read, | |
{ | |
let s = Arc::new(State::default()); | |
let waker = Waker::from(s.clone()); | |
let mut ctx = Context::from_waker(&waker); | |
// SAFETY: assumes the impl of Read9p is a valid future for us to poll | |
let mut fut = unsafe { pin!(T::read()) }; | |
loop { | |
match fut.as_mut().poll(&mut ctx) { | |
Poll::Ready(val) => return val, | |
Poll::Pending => unsafe { | |
let n = s.requested_bytes(); | |
println!("{n} bytes requested"); | |
let mut buf = vec![0; n]; | |
r.read_exact(&mut buf)?; | |
s.set_bytes(buf); | |
}, | |
} | |
} | |
} | |
/// Read a [Read9p] value using an implementation of [AsyncRead] to perform the required IO. | |
async fn read_9p_async_from_bytes<T, R>(r: &mut R) -> io::Result<T> | |
where | |
T: Read9p, | |
R: AsyncRead + Unpin, | |
{ | |
let s = Arc::new(State::default()); | |
let waker = Waker::from(s.clone()); | |
// SAFETY: assumes the impl of Read9p is a valid future for us to poll | |
let mut fut = unsafe { pin!(T::read()) }; | |
loop { | |
match fut.as_mut().poll(&mut Context::from_waker(&waker)) { | |
Poll::Ready(val) => return val, | |
Poll::Pending => unsafe { | |
let n = s.requested_bytes(); | |
println!("{n} bytes requested"); | |
let mut buf = vec![0; n]; | |
r.read_exact(&mut buf).await?; | |
s.set_bytes(buf); | |
}, | |
} | |
} | |
} | |
/// Shared state between a [Read9p] impl and a parent read loop that is performing IO. | |
#[derive(Default, Debug)] | |
struct State { | |
inner: UnsafeCell<StateInner>, | |
} | |
unsafe impl Send for State {} | |
unsafe impl Sync for State {} | |
impl State { | |
unsafe fn requested_bytes(&self) -> usize { | |
unsafe { (*self.inner.get()).n } | |
} | |
unsafe fn set_requested(&self, n: usize) { | |
unsafe { (*self.inner.get()).n = n }; | |
} | |
unsafe fn set_bytes(&self, buf: Vec<u8>) { | |
unsafe { (*self.inner.get()).buf = Some(buf) }; | |
} | |
unsafe fn take_bytes(&self) -> Vec<u8> { | |
unsafe { (*self.inner.get()).buf.take().unwrap_unchecked() } | |
} | |
} | |
/// The data carried by [State]: one outstanding request and, at most, one reply.
#[derive(Debug, Default)]
struct StateInner {
    /// Number of bytes the decoder most recently asked for.
    n: usize,
    /// Bytes supplied by the poll loop, or `None` when no reply is waiting to be taken.
    buf: Option<Vec<u8>>,
}
impl Wake for State { | |
fn wake(self: Arc<Self>) {} | |
fn wake_by_ref(self: &Arc<Self>) {} | |
} | |
/// A future that yields `Pending` exactly once so that control returns to the poll loop, which
/// performs the IO, and then resolves to the bytes the poll loop supplied.
struct RequestBytes {
    /// Whether the first poll, which files the request, has already happened.
    polled: bool,
    /// Number of bytes to request from the poll loop.
    n: usize,
}
impl Future for RequestBytes { | |
type Output = Vec<u8>; | |
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Vec<u8>> { | |
if self.polled { | |
// SAFETY: we can only poll this future using a waker wrapping State | |
let data = unsafe { | |
(ctx.waker().data() as *mut () as *mut State) | |
.as_mut() | |
.unwrap_unchecked() | |
.take_bytes() | |
}; | |
Poll::Ready(data) | |
} else { | |
self.polled = true; | |
// SAFETY: we can only poll this future using a waker wrapping State | |
unsafe { | |
(ctx.waker().data() as *mut () as *mut State) | |
.as_mut() | |
.unwrap_unchecked() | |
.set_requested(self.n); | |
}; | |
Poll::Pending | |
} | |
} | |
} | |
/// Ask the parent poll loop for `$n` bytes, suspend until it has performed the IO, and evaluate
/// to the `Vec<u8>` it supplied.
///
/// Must be used inside an `async` context that is driven by a waker wrapping [State].
macro_rules! request_bytes {
    ($n:expr) => {{
        let pending = RequestBytes {
            n: $n,
            polled: false,
        };
        pending.await
    }};
}
/// A type that can be decoded from the 9p wire format without performing IO itself.
trait Read9p: Sized {
    /// Build the sans-io state machine that decodes a value of this type.
    ///
    /// This is not an ordinary async function: the returned future is meant to be driven by a
    /// concrete poll loop that supplies the bytes it asks for.
    ///
    /// # Safety
    /// Implementations of `read` need to ensure that the only await points they contain are
    /// from calls to the [request_bytes] macro (directly, or through a nested call to `read`).
    unsafe fn read() -> impl Future<Output = io::Result<Self>> + Send;
}
impl Read9p for u16 { | |
async unsafe fn read() -> io::Result<u16> { | |
// SAFETY: we only await as part of request_bytes or a nested call to read | |
let n = size_of::<u16>(); | |
let buf = request_bytes!(n); | |
let data = buf[0..n].try_into().unwrap(); | |
Ok(u16::from_le_bytes(data)) | |
} | |
} | |
impl Read9p for String { | |
async unsafe fn read() -> io::Result<Self> { | |
// SAFETY: we only await as part of request_bytes or a nested call to read | |
unsafe { | |
let len = u16::read().await? as usize; | |
let buf = request_bytes!(len); | |
String::from_utf8(buf) | |
.map_err(|e| io::Error::new(ErrorKind::InvalidData, e.to_string())) | |
} | |
} | |
} | |
impl<T: Read9p + Send> Read9p for Vec<T> { | |
async unsafe fn read() -> io::Result<Self> { | |
// SAFETY: we only await as part of request_bytes or a nested call to read | |
unsafe { | |
let len = u16::read().await? as usize; | |
let mut buf = Vec::with_capacity(len); | |
for _ in 0..len { | |
buf.push(T::read().await?); | |
} | |
Ok(buf) | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment