Skip to content

Instantly share code, notes, and snippets.

@oliver-giersch
Created March 2, 2019 22:16
Show Gist options
  • Save oliver-giersch/8878d769e47bd97b96aa6833f01d91eb to your computer and use it in GitHub Desktop.
Save oliver-giersch/8878d769e47bd97b96aa6833f01d91eb to your computer and use it in GitHub Desktop.
#![feature(thread_spawn_unchecked)]
use std::any::Any;
use std::io;
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::sync::Mutex;
use std::thread::Thread;
use std::thread::{self, JoinHandle};
/// Creates a scope in which threads that borrow non-`'static` data can be spawned.
///
/// `f` receives a [`Scope`] through which threads may be spawned; those threads
/// may borrow anything that outlives `'env`. Every thread spawned through the
/// scope is joined before this function returns. If `f` panics, the `Scope` is
/// dropped during unwinding and its `Drop` impl joins the remaining threads
/// before the panic continues to propagate.
///
/// # Errors
///
/// Returns `Err` with the panic payloads of every thread that panicked and was
/// *not* already joined inside `f`. Threads joined explicitly through a
/// `ScopedJoinHandle` report their result there instead.
pub fn scope<'env, F, R>(f: F) -> Result<R, Box<[Box<dyn Any + Send + 'static>]>>
where
    // `Scope<'env>` is spelled out on purpose. With the lifetime elided, Fn-sugar
    // elision turns it into a fresh higher-ranked lifetime, and closures passed
    // to `Scope::spawn` could then not be shown to outlive it.
    F: FnOnce(&Scope<'env>) -> R,
{
    let mut scope = Scope::new();
    // Executes the user closure (may panic; see the doc comment above).
    let res = f(&scope);
    // Join every thread that has not been joined inside `f`, keeping the panic
    // payloads. `scope` is owned here, so `get_mut` needs no locking. A poisoned
    // mutex is recovered rather than unwrapped so a panic cannot skip the joins.
    let panics: Box<[Box<dyn Any + Send + 'static>]> = scope
        .joins
        .get_mut()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .iter_mut()
        .filter_map(|handle| handle.join_mut().err())
        .collect();

    if panics.is_empty() {
        Ok(res)
    } else {
        Err(panics)
    }
}
/// Handle to a thread scope, passed to the closure given to [`scope`].
///
/// Threads spawned through it may borrow data living for `'env`. All of them
/// are joined before [`scope`] returns, or by `Drop` when the scope is torn
/// down during unwinding.
///
/// `Scope` is `Send + Sync` without any `unsafe impl`. The stored trait objects
/// are bounded by `Send`, and `JoinHandle<T>` is `Send` for every `T`.
pub struct Scope<'env> {
    /// One boxed `Option<JoinHandle<_>>` slot per spawned thread. Each box's heap
    /// address is handed out to the thread's `ScopedJoinHandle`; the boxes stay
    /// in place when the `Vec` reallocates.
    joins: Mutex<Vec<Box<dyn Join + Send + 'env>>>,
    _marker: PhantomData<&'env ()>,
}
impl<'env> Scope<'env> {
    /// Returns a builder for configuring a thread before spawning it in this scope.
    pub fn builder(&self) -> ScopedThreadBuilder<'env, '_> {
        ScopedThreadBuilder(self, thread::Builder::new())
    }

    /// Spawns a thread running `f` that may borrow data living for `'env`.
    ///
    /// The thread is joined when the enclosing scope ends, unless it is joined
    /// earlier through the returned handle.
    ///
    /// # Panics
    ///
    /// Panics if the operating system fails to create the thread, just like
    /// `std::thread::spawn`. Use [`Scope::builder`] to handle that error instead.
    pub fn spawn<'scope, F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
    where
        F: FnOnce() -> T + Send + 'env,
        T: Send + 'env,
        'env: 'scope,
    {
        self.builder()
            .spawn(f)
            .expect("failed to spawn thread")
    }

    /// Creates an empty scope with no registered threads.
    fn new() -> Self {
        Self {
            joins: Mutex::new(Vec::new()),
            _marker: PhantomData,
        }
    }
}
impl<'env> Drop for Scope<'env> {
    /// Joins every thread that is still pending in this scope.
    ///
    /// This is what keeps spawned threads from outliving the `'env` data they
    /// borrow. It matters in particular when the scope closure panics and the
    /// `Scope` is dropped during unwinding.
    fn drop(&mut self) {
        // `&mut self` grants exclusive access, so `get_mut` avoids locking. A
        // poisoned mutex is recovered instead of unwrapped: panicking here would
        // skip the joins, and during unwinding it would abort the process.
        let joins = self
            .joins
            .get_mut()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        for handle in joins.iter_mut() {
            // Results are deliberately ignored. On the normal path `scope` has
            // already joined (and reported) every thread. When unwinding from a
            // panic in the scope closure, there is nowhere to report them.
            let _ = handle.join_mut();
        }
    }
}
/// Configures a thread before spawning it within a [`Scope`].
///
/// Obtained from [`Scope::builder`].
pub struct ScopedThreadBuilder<'env, 'scope>(
    /// The scope the spawned thread will be registered with.
    &'scope Scope<'env>,
    /// The underlying standard-library builder holding the thread configuration.
    thread::Builder,
);
impl<'env, 'scope> ScopedThreadBuilder<'env, 'scope>
where
    'env: 'scope,
{
    /// Sets the name of the thread to be spawned (see `std::thread::Builder::name`).
    pub fn name(mut self, name: String) -> Self {
        self.1 = self.1.name(name);
        self
    }

    /// Sets the stack size, in bytes, of the thread to be spawned
    /// (see `std::thread::Builder::stack_size`).
    pub fn stack_size(mut self, size: usize) -> Self {
        self.1 = self.1.stack_size(size);
        self
    }

    /// Spawns a thread running `f` and registers it with the scope.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if the operating system fails to create the thread.
    /// Nothing is registered with the scope in that case.
    pub fn spawn<F, T>(self, f: F) -> io::Result<ScopedJoinHandle<'scope, T>>
    where
        F: FnOnce() -> T + Send + 'env,
        T: Send + 'env,
    {
        let ScopedThreadBuilder(scope, builder) = self;

        // Lock and reserve room *before* spawning. Once the thread is running,
        // dropping its `JoinHandle` would detach it while it may still borrow
        // `'env` data. So nothing between the spawn and the `push` below may
        // panic: no `unwrap` on a poisoned lock, no reallocating push.
        let mut joins = scope
            .joins
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        joins.reserve(1);

        // SAFETY: `spawn_unchecked` requires the thread not to outlive the
        // references captured by `f` or contained in `T`. Both are bounded by
        // `'env`. The handle is registered in `scope.joins` below, and the scope
        // joins every registered thread before it is dropped (in `scope` and in
        // its `Drop` impl). `'env` outlives the scope.
        let join = unsafe { builder.spawn_unchecked(f) }?;

        let mut boxed = Box::new(Some(join));
        // NOTE(review): `handle` is derived from `boxed`, which is then moved into
        // the `Vec`. Under Miri's Stacked Borrows model, moving/retagging a `Box`
        // may invalidate raw pointers derived from it. Consider storing raw
        // pointers (`Box::into_raw`) in `joins` instead; verify with
        // `cargo miri test`.
        let handle = NonNull::from(&mut *boxed);
        // Cannot reallocate (capacity reserved above), hence cannot panic.
        joins.push(boxed);

        Ok(ScopedJoinHandle {
            handle,
            _marker: PhantomData,
        })
    }
}
/// An owned permission to join a thread spawned within a `Scope`.
///
/// The handle points at the boxed `Option<JoinHandle<T>>` slot that the scope
/// keeps in its list of joins. The `'scope` lifetime ties it to the borrow of
/// that scope. Because it holds a raw pointer, it is neither `Send` nor `Sync`.
pub struct ScopedJoinHandle<'scope, T> {
    handle: NonNull<Option<JoinHandle<T>>>,
    _marker: PhantomData<&'scope mut Option<JoinHandle<T>>>,
}

impl<'scope, T> ScopedJoinHandle<'scope, T> {
    /// Waits for the thread to finish and returns its result.
    ///
    /// The slot is emptied, so the scope will not join this thread again.
    ///
    /// # Errors
    ///
    /// Returns `Err` with the panic payload if the thread panicked.
    pub fn join(self) -> thread::Result<T> {
        let mut slot = self.handle;
        // SAFETY: the slot is a heap allocation owned by the scope's list of joins.
        // That list is only freed when the scope is dropped, which the `'scope`
        // borrow rules out while this handle exists. The scope empties slots only
        // after the scope closure has returned (in `scope` and `Drop`). This
        // handle is `!Send` and consumed here, so no other access to the slot
        // overlaps this one.
        let join_handle = unsafe { slot.as_mut() }
            .take()
            .expect("scoped thread handle already taken");
        join_handle.join()
    }

    /// Returns the underlying `Thread` handle.
    pub fn thread(&self) -> &Thread {
        // SAFETY: the slot is live for `'scope` (see `join`). Only a shared
        // reference is created, and it is tied to `&self`. While it exists,
        // `join` (which takes `self` by value) cannot run.
        let join_handle = unsafe { self.handle.as_ref() }
            .as_ref()
            .expect("scoped thread handle already taken");
        join_handle.thread()
    }
}
/// Type-erased "join this thread if it is still pending" operation.
///
/// Lets a scope keep the handles of threads with different return types in a
/// single collection.
trait Join {
    /// Joins the underlying thread if the slot still holds a handle, discarding
    /// its return value.
    ///
    /// An already-emptied slot yields `Ok(())`. A thread that panicked yields
    /// `Err` with the panic payload.
    fn join_mut(&mut self) -> thread::Result<()>;
}

impl<T> Join for Option<JoinHandle<T>> {
    fn join_mut(&mut self) -> thread::Result<()> {
        match self.take() {
            // Drop the thread's return value; only success/panic matters here.
            Some(pending) => pending.join().map(drop),
            None => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// A spawned thread mutably borrows a local; the write is visible after `scope`.
    #[test]
    fn simple() {
        let mut a = 0;
        scope(|s| {
            s.spawn(|| a = 1);
        })
        .unwrap();
        assert_eq!(a, 1);
    }

    /// Several threads share one borrowed mutex; all of them are implicitly joined.
    #[test]
    fn multiple_writers() {
        let count = Mutex::new(0);
        scope(|s| {
            for _ in 0..5 {
                s.spawn(|| *count.lock().unwrap() += 1);
            }
        })
        .unwrap();
        assert_eq!(*count.lock().unwrap(), 5);
    }

    /// Threads joined explicitly inside the scope have finished before the assert runs.
    #[test]
    fn manual_join() {
        let count = Mutex::new(0);
        scope(|s| {
            let handles: Vec<_> = (0..5)
                .map(|_| s.spawn(|| *count.lock().unwrap() += 1))
                .collect();
            handles.into_iter().for_each(|h| {
                let _ = h.join();
            });
            assert_eq!(*count.lock().unwrap(), 5);
        })
        .unwrap();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment