Last active
May 18, 2025 11:47
-
-
Save sonthonaxrk/b64bfff39e85393fab74871e11a1d5b7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::thread::sleep; | |
use std::time::Duration; | |
mod scoped_thread {
    //! A minimal scoped-thread API: threads spawned through a [`ThreadScope`] may
    //! borrow data that outlives the call to [`scope`], because every such thread
    //! is joined before [`scope`] returns (or unwinds).

    use std::marker::PhantomData;
    use std::thread::{spawn, JoinHandle};

    /// Handle for spawning threads that may borrow data living for `'thread`.
    ///
    /// There is deliberately no public constructor and no `Default` impl: the only
    /// way to obtain one is the `&mut` reference handed out by [`scope`], which
    /// owns the value and therefore guarantees its `Drop` (the join) runs. Calling
    /// `std::mem::forget` on that reference forgets only the reference.
    pub struct ThreadScope<'thread> {
        _marker: PhantomData<&'thread ()>,
        joins: Vec<JoinHandle<()>>,
    }

    impl<'thread> ThreadScope<'thread> {
        /// Spawns `f` on a new OS thread. `f` may borrow anything that lives for
        /// `'thread`; the thread is joined when the owning scope is dropped.
        ///
        /// # Panics
        ///
        /// Panics if the OS fails to create the thread (as `std::thread::spawn` does).
        pub fn spawn<F>(&mut self, f: F)
        where
            F: FnOnce() + Send + 'thread,
        {
            let func: Box<dyn FnOnce() + Send + 'thread> = Box::new(f);
            // SAFETY: only the trait object's lifetime bound changes; layout and
            // vtable are identical. Extending `'thread` to `'static` is sound
            // because the `JoinHandle` is stored in `self.joins`, and `Drop` joins
            // every stored handle (without detaching any of them, even if one of
            // the threads panicked) before the `ThreadScope` owned by `scope`
            // goes away, and `scope` cannot outlive `'thread`.
            let closure: Box<dyn FnOnce() + Send + 'static> =
                unsafe { std::mem::transmute(func) };
            self.joins.push(spawn(closure));
        }
    }

    impl<'thread> Drop for ThreadScope<'thread> {
        /// Joins every spawned thread, then panics if any of them panicked.
        fn drop(&mut self) {
            // Join *all* handles before panicking. Panicking mid-loop (as a bare
            // `join().unwrap()` would) drops the rest of the drain, which detaches
            // those threads while they may still hold `'thread` borrows, a
            // use-after-free once the caller's data is freed.
            let mut any_panicked = false;
            for handle in self.joins.drain(..) {
                // `|=` on bool does not short-circuit, so every handle is joined.
                any_panicked |= handle.join().is_err();
            }
            // If we are already unwinding (the closure passed to `scope` panicked),
            // a second panic here would abort the process; let the first propagate.
            if any_panicked && !std::thread::panicking() {
                panic!("a scoped thread panicked");
            }
        }
    }

    /// Runs `f` with a fresh [`ThreadScope`] and joins every thread it spawned
    /// before returning `f`'s result.
    ///
    /// `'thread` is a lifetime parameter of this function, so it outlives the
    /// call; since all threads are joined before the call ends, nothing they
    /// borrow can be freed while they run.
    ///
    /// # Panics
    ///
    /// Panics (after all threads have been joined) if any spawned thread
    /// panicked. A panic from `f` itself also propagates, again only after the
    /// threads have been joined.
    pub fn scope<'thread, F, R>(f: F) -> R
    where
        F: for<'a> FnOnce(&'a mut ThreadScope<'thread>) -> R,
    {
        let mut thread_scope = ThreadScope {
            _marker: PhantomData,
            joins: Vec::new(),
        };
        let result = f(&mut thread_scope);
        // Join explicitly before handing back the result; if `f` unwinds instead,
        // the same join happens through `Drop`.
        drop(thread_scope);
        result
    }
}
fn main() { | |
let val = String::from("hello"); | |
let val_ref = val.as_str(); | |
scoped_thread::scope(|s| { | |
s.spawn(|| { | |
sleep(Duration::from_millis(1)); | |
}); | |
s.spawn(|| { | |
loop { | |
println!("{:?}", &val_ref); | |
sleep(Duration::from_millis(1)); | |
} | |
}); | |
// This is safe because it's forgetting the reference not | |
// the owned value | |
std::mem::forget(s); | |
}); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment