Skip to content

Instantly share code, notes, and snippets.

@heyrutvik
Last active January 18, 2023 16:44
Show Gist options
  • Save heyrutvik/42e7992f7d7d282f4124c57a226d6b68 to your computer and use it in GitHub Desktop.
Save heyrutvik/42e7992f7d7d282f4124c57a226d6b68 to your computer and use it in GitHub Desktop.
The dance of polling and waking in Rust!
// Cargo.toml
//
// [dependencies]
// futures = "0.3.25"
use std::future::Future;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::sync::mpsc::{Receiver, sync_channel, SyncSender};
use std::task::{Context, Poll, Waker};
use std::thread;
use std::time::Duration;
use futures::FutureExt;
use futures::future::BoxFuture;
use futures::task::{ArcWake, waker_ref};
/// State shared between a [`Timer`] future and the background thread that
/// completes it.
struct SharedState {
    /// Waker of the task that last polled the timer while it was still
    /// pending; taken by the timer thread when it fires.
    waker: Option<Waker>,
    /// Set to `true` by the timer thread once the sleep has elapsed.
    completed: bool,
}
/// A future that resolves after a fixed delay, driven by a helper thread.
struct Timer {
    /// Completion flag and pending waker, shared with the helper thread.
    shared_state: Arc<Mutex<SharedState>>,
}
impl Future for Timer {
    type Output = ();

    /// Returns `Ready` once the helper thread has marked the timer completed.
    ///
    /// While still pending, the current task's waker is cloned into the shared
    /// state (overwriting any earlier one) so the helper thread can wake the
    /// task that polled most recently.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut guard = self.shared_state.lock().unwrap();
        if !guard.completed {
            guard.waker = Some(cx.waker().clone());
            return Poll::Pending;
        }
        Poll::Ready(())
    }
}
impl Timer {
pub fn new(duration: Duration) -> Self {
let shared_state = Arc::new(Mutex::new(SharedState {
completed: false,
waker: None,
}));
let thread_shared_state = shared_state.clone();
thread::spawn(move || {
thread::sleep(duration);
let mut shared_state = thread_shared_state.lock().unwrap();
shared_state.completed = true;
if let Some(waker) = shared_state.waker.take() {
waker.wake();
}
});
Timer { shared_state }
}
}
/// A spawned unit of work: its future plus a handle for re-queuing itself.
struct Task {
    /// Sender cloned from the spawner; used by the waker to push this task
    /// back onto the executor's ready queue.
    task_sender: SyncSender<Arc<Task>>,
    /// The future being driven. The executor takes it out while polling and
    /// only puts it back if the poll returned `Pending`.
    future: Mutex<Option<BoxFuture<'static, ()>>>,
}
impl ArcWake for Task {
    /// Reschedules the task by sending a new `Arc` handle to it back onto the
    /// executor's ready queue.
    ///
    /// # Panics
    /// Panics if the executor's `Receiver` has been dropped. A full queue does
    /// not cause a panic, because `SyncSender::send` blocks until space frees
    /// up; it only returns an error on disconnect.
    fn wake_by_ref(arc_self: &Arc<Self>) {
        let cloned = Arc::clone(arc_self);
        arc_self
            .task_sender
            .send(cloned)
            .expect("executor's task queue has been closed");
    }
}
/// Handle for submitting new top-level futures to the executor.
struct Spawner {
    task_sender: SyncSender<Arc<Task>>,
}

impl Spawner {
    /// Boxes `future` into a new [`Task`] and enqueues it for its first poll.
    ///
    /// # Panics
    /// Panics if the executor's `Receiver` has been dropped. If the queue is
    /// full, `SyncSender::send` blocks until there is room instead of failing.
    fn spawn(&self, future: impl Future<Output = ()> + 'static + Send) {
        let task = Arc::new(Task {
            future: Mutex::new(Some(future.boxed())),
            task_sender: self.task_sender.clone(),
        });
        self.task_sender
            .send(task)
            .expect("executor's task queue has been closed");
    }
}
struct Executor {
ready_queue: Receiver<Arc<Task>>,
}
impl Executor {
fn run(&self) {
while let Ok(task) = self.ready_queue.recv() {
let mut future_slot = task.future.lock().unwrap();
if let Some(mut future) = future_slot.take() {
let waker = waker_ref(&task);
let context = &mut Context::from_waker(waker.deref());
if future.as_mut().poll(context).is_pending() {
*future_slot = Some(future);
}
}
}
}
}
fn main() {
    const MAX_QUEUED_TASKS: usize = 10_000;
    let (task_sender, ready_queue) = sync_channel(MAX_QUEUED_TASKS);
    let spawner = Spawner { task_sender };
    let executor = Executor { ready_queue };

    spawner.spawn(async {
        println!("spawned!");
        Timer::new(Duration::from_secs(3)).await;
        println!("completed!");
    });

    // Release the spawner's sender. The executor's loop ends once the spawned
    // task has finished and the last sender (held by the task) is gone.
    drop(spawner);
    executor.run();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment