Last active
January 18, 2023 16:44
-
-
Save heyrutvik/42e7992f7d7d282f4124c57a226d6b68 to your computer and use it in GitHub Desktop.
The dance of polling and waking in Rust!
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Cargo.toml | |
// | |
// [dependencies] | |
// futures = "0.3.25" | |
use std::future::Future; | |
use std::ops::Deref; | |
use std::pin::Pin; | |
use std::sync::{Arc, Mutex}; | |
use std::sync::mpsc::{Receiver, sync_channel, SyncSender}; | |
use std::task::{Context, Poll, Waker}; | |
use std::thread; | |
use std::time::Duration; | |
use futures::FutureExt; | |
use futures::future::BoxFuture; | |
use futures::task::{ArcWake, waker_ref}; | |
/// State shared between a [`Timer`] future and the background thread that completes it.
struct SharedState {
    /// Set to `true` by the timer thread once the requested duration has elapsed.
    completed: bool,
    /// Waker registered by the most recent `poll` that returned `Pending`;
    /// taken and invoked by the timer thread on completion.
    waker: Option<Waker>,
}

/// A future that resolves to `()` once a fixed duration has elapsed.
///
/// Each `Timer` spawns one OS thread that sleeps for the duration, marks the
/// shared state as completed and wakes the most recently registered waker.
struct Timer {
    shared_state: Arc<Mutex<SharedState>>,
}

impl Future for Timer {
    type Output = ();

    /// Returns `Ready(())` once the timer thread has fired. Otherwise it records
    /// the caller's waker so the timer thread can wake it, and returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut shared_state = self.shared_state.lock().unwrap();
        if shared_state.completed {
            return Poll::Ready(());
        }

        // The future may be polled from a different task than last time, so the
        // stored waker has to be refreshed. Skip the clone when the existing one
        // would already wake the same task.
        let already_registered =
            matches!(&shared_state.waker, Some(waker) if waker.will_wake(cx.waker()));
        if !already_registered {
            shared_state.waker = Some(cx.waker().clone());
        }
        Poll::Pending
    }
}

impl Timer {
    /// Creates a timer that completes after `duration`.
    ///
    /// Spawns a detached thread that sleeps for `duration`, then sets `completed`
    /// and wakes the task that last polled this timer, if any.
    pub fn new(duration: Duration) -> Self {
        let shared_state = Arc::new(Mutex::new(SharedState {
            completed: false,
            waker: None,
        }));

        let thread_shared_state = Arc::clone(&shared_state);
        thread::spawn(move || {
            thread::sleep(duration);
            // Mark completion and take the waker while holding the lock, but call
            // the waker only after the guard is dropped. `wake` runs arbitrary
            // executor code (with this file's executor, a send on a bounded
            // channel that can block). That code must not run while this thread
            // holds the lock that `poll` needs.
            let waker = {
                let mut state = thread_shared_state.lock().unwrap();
                state.completed = true;
                state.waker.take()
            };
            if let Some(waker) = waker {
                waker.wake();
            }
        });

        Timer { shared_state }
    }
}
struct Task { | |
future: Mutex<Option<BoxFuture<'static, ()>>>, | |
task_sender: SyncSender<Arc<Task>>, | |
} | |
impl ArcWake for Task { | |
fn wake_by_ref(arc_self: &Arc<Self>) { | |
let cloned = arc_self.clone(); | |
arc_self | |
.task_sender | |
.send(cloned) | |
.expect("too many tasks queued"); | |
} | |
} | |
struct Spawner { | |
task_sender: SyncSender<Arc<Task>>, | |
} | |
impl Spawner { | |
fn spawn(&self, future: impl Future<Output = ()> + 'static + Send) { | |
let future = future.boxed(); | |
let task = Arc::new(Task { | |
future: Mutex::new(Some(future)), | |
task_sender: self.task_sender.clone(), | |
}); | |
self.task_sender.send(task).expect("too many tasks queued"); | |
} | |
} | |
struct Executor { | |
ready_queue: Receiver<Arc<Task>>, | |
} | |
impl Executor { | |
fn run(&self) { | |
while let Ok(task) = self.ready_queue.recv() { | |
let mut future_slot = task.future.lock().unwrap(); | |
if let Some(mut future) = future_slot.take() { | |
let waker = waker_ref(&task); | |
let context = &mut Context::from_waker(waker.deref()); | |
if future.as_mut().poll(context).is_pending() { | |
*future_slot = Some(future); | |
} | |
} | |
} | |
} | |
} | |
fn main() { | |
const MAX_QUEUED_TASKS: usize = 10_000; | |
let (task_sender, ready_queue) = sync_channel(MAX_QUEUED_TASKS); | |
let executor = Executor { ready_queue }; | |
let spawner = Spawner { task_sender }; | |
spawner.spawn(async { | |
println!("spawned!"); | |
Timer::new(Duration::new(3, 0)).await; | |
println!("completed!"); | |
}); | |
drop(spawner); | |
executor.run(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment