Skip to content

Instantly share code, notes, and snippets.

@agronholm
Last active May 1, 2018 12:07
Show Gist options
  • Save agronholm/1abbeccf3a922084a7748e85b58c2d8b to your computer and use it in GitHub Desktop.
Save agronholm/1abbeccf3a922084a7748e85b58c2d8b to your computer and use it in GitHub Desktop.
Asyncio nursery
import asyncio
import inspect
from asyncio import CancelledError, Future, Task, ensure_future, gather, get_event_loop, wait
from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor
from functools import partial, wraps
from time import sleep
from traceback import format_exception
from typing import Awaitable, Callable, Dict, List, Optional, Set, TypeVar, Union
T_Retval = TypeVar('T_Retval')
class MultiError(Exception):
    """
    Raised when multiple exceptions have been raised in tasks running in parallel.

    The exception message consists of ``msg`` followed by the formatted traceback of every
    bundled exception, with the tracebacks delimited by :attr:`separator`.

    :ivar exceptions: the exceptions that were raised
    :vartype exceptions: List[Exception]
    """

    separator = '----------------------------\n'
    template = '{msg}:\n{separator}{tracebacks}'

    def __init__(self, exceptions: List[Exception], msg: str) -> None:
        """
        Build the combined message and remember the original exceptions.

        :param exceptions: the exceptions to bundle together
        :param msg: a short summary placed on the first line of the message
        """
        # format_exception() returns lines that already end in a newline, so they have to be
        # concatenated as-is; joining them with '\n' would double-space every traceback.
        tracebacks = self.separator.join(
            ''.join(format_exception(type(exc), exc, exc.__traceback__))
            for exc in exceptions)
        super().__init__(self.template.format(msg=msg, separator=self.separator,
                                              tracebacks=tracebacks))
        self.exceptions = exceptions
class NurseryException(Exception):
    """The superclass for all nursery related exceptions."""
class NurseryClosed(NurseryException):
    """Raised when trying to start a new task but the nursery has been closed."""

    def __init__(self) -> None:
        # The message is fixed, so callers raise this exception without any arguments
        super().__init__('This nursery has already been closed')
class Nursery:
    """
    Async context manager that supervises a group of concurrently running tasks.

    Tasks are started with :meth:`start_task` (callables returning an awaitable) or with the
    ``start_in_*`` / ``call_in_*`` methods (plain callables run in an executor). When a task
    raises an exception, the nursery is closed and the host task plus all tracked tasks are
    cancelled. Leaving the ``async with`` block shuts down every executor registered with the
    nursery (including ones supplied by the caller), waits for all tasks to finish and then
    re-raises the collected errors, wrapped in a :class:`MultiError` if there are several.

    :param executors: optional mapping of names to executors; a ``threadpool`` and a
        ``processpool`` entry are created if missing
    :raises TypeError: if the supplied ``threadpool`` / ``processpool`` entry is not a
        ``ThreadPoolExecutor`` / ``ProcessPoolExecutor`` respectively
    """

    __slots__ = '_closed', '_host_task', '_futures', '_executors'

    def __init__(self, executors: Optional[Dict[str, Executor]] = None) -> None:
        self._closed = False
        self._host_task = None  # type: Optional[Task]
        self._futures = set()  # type: Set[Future]
        # Work on a copy so that registering the default executors below never alters the
        # dict that the caller passed in
        self._executors = dict(executors) if executors else {}  # type: Dict[str, Executor]

        if 'threadpool' not in self._executors:
            self.add_executor(ThreadPoolExecutor(), 'threadpool')
        elif not isinstance(self._executors['threadpool'], ThreadPoolExecutor):
            raise TypeError('the "threadpool" executor must be a ThreadPoolExecutor')

        if 'processpool' not in self._executors:
            self.add_executor(ProcessPoolExecutor(), 'processpool')
        elif not isinstance(self._executors['processpool'], ProcessPoolExecutor):
            raise TypeError('the "processpool" executor must be a ProcessPoolExecutor')

    def _check_closed(self) -> None:
        """Raise :class:`NurseryClosed` if the nursery no longer accepts new work."""
        if self._closed:
            raise NurseryClosed

    def _task_finished(self, task: Task) -> None:
        """
        Done callback attached to every tracked future.

        A failed task stays in ``_futures`` so that ``__aexit__`` can collect its exception;
        tasks that succeeded or were cancelled are dropped from the set.
        """
        # If a task raised an exception (other than CancelledError), cancel all other tasks
        if not task.cancelled() and task.exception():
            if not self._closed:
                self._closed = True
                # There is no host task if the nursery was used without "async with"
                if self._host_task is not None:
                    self._host_task.cancel()

                for other in self._futures:
                    other.cancel()
        else:
            self._futures.discard(task)

    def add_executor(self, executor: Executor, name: str) -> None:
        """
        Register an executor under the given name.

        Registering the same executor object again under the same name is a no-op.

        :param executor: the executor to register
        :param name: the name the executor is looked up by in :meth:`start_in_executor`
        :raises NurseryClosed: if the nursery has been closed
        :raises NurseryException: if a different executor already uses that name
        """
        self._check_closed()
        if self._executors.setdefault(name, executor) is not executor:
            raise NurseryException(
                'This nursery already has an executor by the name of {!r}'.format(name))

    async def call_in_thread(self, func: Callable[..., T_Retval], *args, **kwargs) -> T_Retval:
        """Run ``func`` in the ``threadpool`` executor and return its result."""
        return await self.start_in_executor('threadpool', func, *args, **kwargs)

    async def call_in_subprocess(self, func: Callable[..., T_Retval], *args, **kwargs) -> T_Retval:
        """Run ``func`` in the ``processpool`` executor and return its result."""
        return await self.start_in_executor('processpool', func, *args, **kwargs)

    async def call_in_executor(self, executor: str, func: Callable[..., T_Retval], *args,
                               **kwargs) -> T_Retval:
        """Run ``func`` in the named executor and return its result."""
        return await self.start_in_executor(executor, func, *args, **kwargs)

    def start_in_thread(self, func: Callable, *args, **kwargs) -> Future:
        """Schedule ``func`` in the ``threadpool`` executor and return the tracked future."""
        return self.start_in_executor('threadpool', func, *args, **kwargs)

    def start_in_subprocess(self, func: Callable, *args, **kwargs) -> Future:
        """Schedule ``func`` in the ``processpool`` executor and return the tracked future."""
        return self.start_in_executor('processpool', func, *args, **kwargs)

    def start_in_executor(self, executor: str, func: Callable, *args, **kwargs) -> Future:
        """
        Schedule ``func(*args, **kwargs)`` in the named executor.

        :param executor: name of a registered executor
        :param func: the callable to run
        :return: a future tracked by this nursery
        :raises NurseryClosed: if the nursery has been closed
        :raises NurseryException: if no executor has been registered under that name
        """
        # Report a closed nursery as such instead of complaining about a missing executor
        # (the executors are detached from the nursery when it exits)
        self._check_closed()
        try:
            executor_ = self._executors[executor]
        except KeyError:
            raise NurseryException(
                'No such executor in this nursery: {}'.format(executor)) from None

        # run_in_executor() only forwards positional arguments, so bind the keyword ones here
        func = partial(func, **kwargs) if kwargs else func
        return self.start_task(get_event_loop().run_in_executor, executor_, func, *args)

    def start_task(self, func: Callable[..., Awaitable], *args, **kwargs) -> Future:
        """
        Call ``func(*args, **kwargs)`` and track the resulting awaitable as a nursery task.

        :param func: a callable returning an awaitable (e.g. a coroutine function)
        :return: the future wrapping the awaitable
        :raises NurseryClosed: if the nursery has been closed
        """
        self._check_closed()
        future = ensure_future(func(*args, **kwargs))
        self._futures.add(future)
        future.add_done_callback(self._task_finished)
        return future

    async def __aenter__(self) -> 'Nursery':
        self._check_closed()
        # Task.current_task() was removed in Python 3.9; asyncio.current_task() only exists
        # on Python 3.7 and later
        try:
            from asyncio import current_task
        except ImportError:
            current_task = Task.current_task

        self._host_task = current_task()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        exceptions = [] if exc_val is None else [exc_val]
        if self._closed:
            # A subtask raised an exception; the cancellation it triggered in the host task
            # is not reported as an error of its own
            if isinstance(exc_val, CancelledError):
                exceptions.clear()
        elif exc_val is not None:
            # The host task raised an exception or was cancelled externally
            for task in self._futures:
                task.cancel()

        self._closed = True
        cancel_exception = None
        while True:
            try:
                if self._executors:
                    # Detach the executors before shutting them down so that a retry after a
                    # cancellation does not shut them down a second time
                    executors, self._executors = self._executors, {}
                    with ThreadPoolExecutor(max_workers=len(executors)) as temp_pool:
                        loop = get_event_loop()
                        coros = [loop.run_in_executor(temp_pool, executor.shutdown)
                                 for executor in executors.values()]
                        await gather(*coros)

                if self._futures:
                    # Wait on a snapshot: the done callbacks keep removing finished tasks
                    # from self._futures, which therefore has to remain a real set
                    done, _pending = await wait(set(self._futures))
                    exceptions.extend([f.exception() for f in done
                                       if not f.cancelled() and f.exception()])
                    self._futures.clear()

                break
            except CancelledError as e:
                # The host task was cancelled while waiting: cancel whatever is still running
                # and go around again to wait until those tasks have actually finished
                cancel_exception = e
                for future in self._futures:
                    future.cancel()

        if len(exceptions) > 1:
            raise MultiError(exceptions, 'multiple tasks failed')
        elif len(exceptions) == 1:
            raise exceptions[0]
        elif cancel_exception:
            raise cancel_exception
class TestNursery:
    """Tests for :class:`Nursery` covering success, failure and cancellation paths."""

    @staticmethod
    def sync_sum(x, y, delay=0.1):
        """Blocking helper: optionally sleep, then return ``x + y``."""
        if delay:
            sleep(delay)

        return x + y

    @staticmethod
    def error(text, delay=0.1):
        """Blocking helper: optionally sleep, then raise ``Exception(text)``."""
        if delay:
            sleep(delay)

        raise Exception(text)

    @staticmethod
    async def async_sum(x, y, delay=0.2):
        """Coroutine helper: sleep, then return ``x + y``."""
        await asyncio.sleep(delay)
        return x + y

    @staticmethod
    async def async_error(text, delay=0.1):
        """Coroutine helper that raises ``Exception(text)`` even if its sleep is interrupted."""
        try:
            if delay:
                await asyncio.sleep(delay)
        finally:
            # Raising from "finally" replaces any in-flight exception (including a
            # cancellation) with this one
            raise Exception(text)

    @pytest.mark.asyncio
    async def test_success(self):
        """All three kinds of tasks finish and expose their results after the block."""
        async with Nursery() as nursery:
            f1 = nursery.start_task(self.async_sum, 1, 2)
            f2 = nursery.start_in_thread(self.sync_sum, 1, 2)
            f3 = nursery.start_in_subprocess(self.sync_sum, 1, 2)

        for f in f1, f2, f3:
            assert f.result() == 3

    @pytest.mark.asyncio
    async def test_host_exception(self):
        """An exception raised in the host block propagates and cancels every task."""
        with pytest.raises(Exception) as exc:
            async with Nursery() as nursery:
                f1 = nursery.start_task(self.async_sum, 1, 2)
                f2 = nursery.start_in_thread(self.sync_sum, 1, 2)
                f3 = nursery.start_in_subprocess(self.sync_sum, 1, 2)
                raise Exception('dummy error')

        exc.match('dummy error')
        for f in f1, f2, f3:
            assert f.cancelled()

    @pytest.mark.asyncio
    async def test_task_cancelled(self):
        """Cancelling a single task leaves the other tasks unaffected."""
        async with Nursery() as nursery:
            f1 = nursery.start_task(self.async_sum, 1, 2)
            f2 = nursery.start_in_thread(self.sync_sum, 1, 2)
            f3 = nursery.start_in_subprocess(self.sync_sum, 1, 2)
            f2.cancel()

        assert f2.cancelled()
        for f in f1, f3:
            assert f.result() == 3

    @pytest.mark.asyncio
    async def test_host_cancelled_before_aexit(self):
        """A CancelledError raised inside the host block cancels every task."""
        with pytest.raises(asyncio.CancelledError):
            async with Nursery() as nursery:
                f1 = nursery.start_task(self.async_sum, 1, 2)
                f2 = nursery.start_in_thread(self.sync_sum, 1, 2)
                f3 = nursery.start_in_subprocess(self.sync_sum, 1, 2)
                raise asyncio.CancelledError

        for f in f1, f2, f3:
            assert f.cancelled()

    @pytest.mark.asyncio
    async def test_host_cancelled_during_aexit(self, event_loop):
        """A cancellation scheduled via call_soon() reaches the host while the nursery exits."""
        # asyncio.Task.current_task() was removed in Python 3.9; asyncio.current_task() only
        # exists on Python 3.7 and later
        current_task = getattr(asyncio, 'current_task', None) or asyncio.Task.current_task
        with pytest.raises(asyncio.CancelledError):
            async with Nursery() as nursery:
                f1 = nursery.start_task(self.async_sum, 1, 2)
                f2 = nursery.start_in_thread(self.sync_sum, 1, 2)
                f3 = nursery.start_in_subprocess(self.sync_sum, 1, 2)
                event_loop.call_soon(current_task().cancel)

        for f in f1, f2, f3:
            assert f.cancelled()

    @pytest.mark.asyncio
    async def test_multi_error(self):
        """Failures of two tasks are reported together as a MultiError."""
        with pytest.raises(MultiError) as exc:
            async with Nursery() as nursery:
                f1 = nursery.start_task(self.async_error, 'async', delay=2)
                f2 = nursery.start_in_thread(self.error, 'thread', delay=1)
                f3 = nursery.start_in_subprocess(self.error, 'process')

        assert len(exc.value.exceptions) == 2
        assert f1.exception() in exc.value.exceptions
        assert f2.cancelled()
        assert f3.exception() in exc.value.exceptions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment