Created
March 2, 2019 22:16
-
-
Save oliver-giersch/8878d769e47bd97b96aa6833f01d91eb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#![feature(thread_spawn_unchecked)] | |
use std::any::Any; | |
use std::io; | |
use std::marker::PhantomData; | |
use std::ptr::NonNull; | |
use std::sync::Mutex; | |
use std::thread::Thread; | |
use std::thread::{self, JoinHandle}; | |
pub fn scope<'env, F, R>(f: F) -> Result<R, Box<[Box<dyn Any + Send + 'static>]>> | |
where | |
F: for<'scope> FnOnce(&'scope Scope) -> R + 'env, | |
{ | |
let scope = Scope::new(); | |
// executes specified closure (may panic) | |
let res = f(&scope); | |
// join all threads that have not been joined as part of the scope closure | |
let panics = scope | |
.joins | |
.lock() | |
.unwrap() | |
.iter_mut() | |
.filter_map(|handle| handle.join_mut().err()) | |
.collect::<Vec<_>>() | |
.into_boxed_slice(); | |
if panics.len() == 0 { | |
Ok(res) | |
} else { | |
Err(panics) | |
} | |
} | |
unsafe impl<'env> Send for Scope<'env> {} | |
unsafe impl<'env> Sync for Scope<'env> {} | |
pub struct Scope<'env> { | |
joins: Mutex<Vec<Box<dyn Join + 'env>>>, | |
_marker: PhantomData<&'env ()>, | |
} | |
impl<'env> Scope<'env> { | |
pub fn builder(&self) -> ScopedThreadBuilder<'env, '_> { | |
ScopedThreadBuilder(self, thread::Builder::new()) | |
} | |
pub fn spawn<'scope, F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T> | |
where | |
F: FnOnce() -> T + Send + 'env, | |
T: Send + 'env, | |
'env: 'scope, | |
{ | |
self.builder().spawn(f).unwrap() | |
} | |
fn new() -> Self { | |
Self { | |
joins: Mutex::new(Vec::new()), | |
_marker: PhantomData, | |
} | |
} | |
} | |
impl<'env> Drop for Scope<'env> { | |
fn drop(&mut self) { | |
for handle in &mut *self.joins.lock().unwrap() { | |
let _res = handle.join_mut(); | |
} | |
} | |
} | |
pub struct ScopedThreadBuilder<'env, 'scope>(&'scope Scope<'env>, thread::Builder); | |
impl<'env, 'scope> ScopedThreadBuilder<'env, 'scope> | |
where | |
'env: 'scope, | |
{ | |
pub fn name(mut self, name: String) -> Self { | |
self.1 = self.1.name(name); | |
self | |
} | |
pub fn spawn<F, T>(self, f: F) -> io::Result<ScopedJoinHandle<'scope, T>> | |
where | |
F: FnOnce() -> T + Send + 'env, | |
T: Send + 'env | |
{ | |
let ScopedThreadBuilder(scope, builder) = self; | |
let res = unsafe { builder.spawn_unchecked(f) }; | |
res.map(|join| { | |
let mut boxed = Box::new(Some(join)); | |
let handle = NonNull::from(&mut *boxed); | |
scope.joins.lock().unwrap().push(boxed); | |
ScopedJoinHandle { | |
handle, | |
_marker: PhantomData, | |
} | |
}) | |
} | |
} | |
/// Permission to join one thread that was spawned in a [`Scope`].
///
/// Points at the `Option<JoinHandle<T>>` slot the scope keeps for that thread. Because of the
/// raw `NonNull` field the handle is neither `Send` nor `Sync`. Together with the `'scope`
/// lifetime, this keeps it from being used concurrently with the scope's own final joins.
pub struct ScopedJoinHandle<'scope, T> {
    handle: NonNull<Option<JoinHandle<T>>>,
    _marker: PhantomData<&'scope mut Option<JoinHandle<T>>>,
}

impl<'scope, T> ScopedJoinHandle<'scope, T> {
    /// Blocks until the thread finishes and returns its result, or its panic payload.
    ///
    /// The scope's slot is emptied, so the thread is not joined a second time later.
    pub fn join(self) -> thread::Result<T> {
        let ScopedJoinHandle { handle: mut slot, .. } = self;
        // SAFETY: the slot is a boxed allocation owned by the scope for at least `'scope`, and
        // nothing else touches it until the scope joins its remaining threads.
        let join_handle = unsafe { slot.as_mut() }.take().unwrap();
        join_handle.join()
    }

    /// Returns the [`Thread`] handle of the spawned thread.
    pub fn thread(&self) -> &Thread {
        // SAFETY: same slot invariant as in `join`; the slot is still `Some` here because
        // `join` consumes `self`.
        let slot = unsafe { self.handle.as_ref() };
        slot.as_ref().map(JoinHandle::thread).unwrap()
    }
}
/// Object-safe "join this thread if not yet joined" operation. It lets a scope keep the handles
/// of threads with different result types in one collection.
trait Join {
    /// Joins the underlying thread if it has not been joined yet, discarding its return value.
    ///
    /// Returns the panic payload if the thread panicked, and `Ok(())` otherwise, including when
    /// there was nothing left to join.
    fn join_mut(&mut self) -> thread::Result<()>;
}

impl<T> Join for Option<JoinHandle<T>> {
    fn join_mut(&mut self) -> thread::Result<()> {
        match self.take() {
            // The thread's return value is dropped here, on the joining thread.
            Some(handle) => handle.join().map(drop),
            // Already joined earlier, e.g. through a `ScopedJoinHandle`.
            None => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// A scoped thread can write to a local borrowed mutably from the caller.
    #[test]
    fn simple() {
        let mut a = 0;
        scope(|scope| {
            scope.spawn(|| {
                a = 1;
            });
        })
        .unwrap();
        assert_eq!(a, 1);
    }

    /// Several scoped threads share a borrowed `Mutex`; all are joined when the scope ends.
    #[test]
    fn multiple_writers() {
        let count = Mutex::new(0);
        scope(|scope| {
            scope.spawn(|| *count.lock().unwrap() += 1);
            scope.spawn(|| *count.lock().unwrap() += 1);
            scope.spawn(|| *count.lock().unwrap() += 1);
            scope.spawn(|| *count.lock().unwrap() += 1);
            scope.spawn(|| *count.lock().unwrap() += 1);
        })
        .unwrap();
        assert_eq!(*count.lock().unwrap(), 5);
    }

    /// Threads joined explicitly inside the closure are finished before the closure returns.
    #[test]
    fn manual_join() {
        let count = Mutex::new(0);
        scope(|scope| {
            let handles = (0..5)
                .map(|_| scope.spawn(|| *count.lock().unwrap() += 1))
                .collect::<Vec<_>>();
            for handle in handles {
                let _ = handle.join();
            }
            assert_eq!(*count.lock().unwrap(), 5);
        })
        .unwrap();
    }

    /// Panics of threads that were not joined explicitly are collected into the `Err` result.
    #[test]
    fn unjoined_panics_are_collected() {
        let res = scope(|scope| {
            scope.spawn(|| -> i32 { panic!("first") });
            scope.spawn(|| 0);
            scope.spawn(|| -> i32 { panic!("second") });
        });
        let panics = res.expect_err("two threads panicked");
        assert_eq!(panics.len(), 2);
    }

    /// A panic observed through an explicit `join` is not reported a second time by `scope`.
    #[test]
    fn joined_panic_is_reported_by_handle_only() {
        let res = scope(|scope| {
            let handle = scope.spawn(|| -> i32 { panic!("boom") });
            assert!(handle.join().is_err());
        });
        assert!(res.is_ok());
    }

    /// The builder's thread name is applied to the spawned thread.
    #[test]
    fn builder_sets_thread_name() {
        scope(|scope| {
            let handle = scope
                .builder()
                .name("worker".to_string())
                .spawn(|| 7)
                .unwrap();
            assert_eq!(handle.thread().name(), Some("worker"));
            assert_eq!(handle.join().unwrap(), 7);
        })
        .unwrap();
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment