Last active
May 1, 2018 12:07
-
-
Save agronholm/1abbeccf3a922084a7748e85b58c2d8b to your computer and use it in GitHub Desktop.
Asyncio nursery
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect
from asyncio import (
    CancelledError, Future, Task, ensure_future, gather, get_event_loop, wait)
from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor
from functools import partial, wraps
from traceback import format_exception
from typing import Awaitable, Callable, Dict, List, Set, TypeVar, Union
# Return type of the callable passed to the Nursery.call_in_* coroutines; the
# coroutine's result has the same type as the callable's.
T_Retval = TypeVar('T_Retval')
class MultiError(Exception):
    """
    Raised when multiple exceptions have been raised in tasks running in parallel.

    The exception message is ``msg`` followed by a colon and the formatted traceback of
    every exception. The tracebacks are separated by :attr:`separator` lines.

    :ivar exceptions: the exceptions that were raised
    :vartype exceptions: List[Exception]
    """

    separator = '----------------------------\n'
    template = '{msg}:\n{separator}{tracebacks}'

    def __init__(self, exceptions: List[Exception], msg: str) -> None:
        # format_exception() returns lines that already end with a newline, so they
        # are concatenated as-is. Joining them with '\n' would double-space the output.
        tracebacks = self.separator.join(
            ''.join(format_exception(type(exc), exc, exc.__traceback__))
            for exc in exceptions)
        super().__init__(self.template.format(msg=msg, separator=self.separator,
                                              tracebacks=tracebacks))
        self.exceptions = exceptions
class NurseryException(Exception):
    """Common base class of every error raised by a nursery itself."""

    pass
class NurseryClosed(NurseryException):
    """Signals an attempt to start a task in a nursery that no longer accepts them."""

    def __init__(self) -> None:
        message = 'This nursery has already been closed'
        super().__init__(message)
class Nursery:
    """
    Supervise a group of concurrently running tasks and executor jobs.

    Use as ``async with Nursery() as nursery:``. Coroutine-based tasks are started with
    :meth:`start_task`. Blocking callables are run in an executor, either with the
    ``start_in_*`` methods (which return a future) or with the ``call_in_*`` coroutines
    (which await the result).

    When the ``async with`` block exits, the nursery stops accepting new tasks, shuts
    down all of its executors and waits for every task to finish. If any task raises an
    exception (other than being cancelled), every other task is cancelled. The host task
    is cancelled too if it is still inside the ``async with`` body. The collected
    exceptions are re-raised from ``__aexit__``: a single one as-is, several wrapped in
    :class:`MultiError`.

    :param executors: additional executors keyed by name (the dictionary is copied).
        The names ``"threadpool"`` and ``"processpool"`` are reserved for a
        :class:`ThreadPoolExecutor` and a :class:`ProcessPoolExecutor`. A default
        instance is created for whichever is missing. All executors, including
        caller-supplied ones, are shut down when the nursery exits.
    :raises TypeError: if a reserved name maps to an executor of the wrong type
    """

    __slots__ = ('_closed', '_host_task', '_futures', '_executors', '_cancelling',
                 '_host_cancel_pending')

    def __init__(self, executors: Dict[str, Executor] = None) -> None:
        self._closed = False  # True once no new tasks may be started
        self._cancelling = False  # True once the nursery has cancelled its tasks
        # True while a cancellation sent to the host by _task_finished is undelivered
        self._host_cancel_pending = False
        self._host_task = None  # type: Task
        self._futures = set()  # type: Set[Future]
        # Copy so that add_executor() never mutates the caller's dictionary
        self._executors = dict(executors) if executors else {}  # type: Dict[str, Executor]

        # Validate the reserved names before creating any default executor, so that a
        # TypeError does not leave a freshly created pool behind.
        if ('threadpool' in self._executors
                and not isinstance(self._executors['threadpool'], ThreadPoolExecutor)):
            raise TypeError('the "threadpool" executor must be a ThreadPoolExecutor')
        if ('processpool' in self._executors
                and not isinstance(self._executors['processpool'], ProcessPoolExecutor)):
            raise TypeError('the "processpool" executor must be a ProcessPoolExecutor')

        if 'threadpool' not in self._executors:
            self.add_executor(ThreadPoolExecutor(), 'threadpool')
        if 'processpool' not in self._executors:
            self.add_executor(ProcessPoolExecutor(), 'processpool')

    def _check_closed(self) -> None:
        """Raise :class:`NurseryClosed` if the nursery no longer accepts new tasks."""
        if self._closed:
            raise NurseryClosed

    def _cancel_tasks(self) -> None:
        """Cancel every task still tracked by the nursery."""
        self._cancelling = True
        for future in list(self._futures):
            future.cancel()

    def _task_finished(self, future: Future) -> None:
        """
        Done-callback attached to every started task.

        Tasks that succeeded or were cancelled are forgotten. A failed task is kept in
        ``_futures`` so that ``__aexit__`` can collect its exception. The first failure
        closes the nursery and cancels all other tasks. It also cancels the host task,
        but only while the host is still inside the ``async with`` body.
        """
        if future.cancelled() or future.exception() is None:
            self._futures.discard(future)
        elif not self._cancelling:
            # _closed is still False here only before __aexit__ has started (it sets
            # the flag), i.e. while the host has not yet left the body.
            if not self._closed and self._host_task is not None:
                self._host_cancel_pending = True
                self._host_task.cancel()
            self._closed = True
            self._cancel_tasks()

    def add_executor(self, executor: Executor, name: str) -> None:
        """
        Register ``executor`` under ``name`` for use with :meth:`start_in_executor`.

        :raises NurseryClosed: if the nursery has been closed
        :raises NurseryException: if an executor is already registered under ``name``
        """
        self._check_closed()
        if self._executors.setdefault(name, executor) is not executor:
            raise NurseryException(
                'This nursery already has an executor by the name of {!r}'.format(name))

    async def call_in_thread(self, func: Callable[..., T_Retval], *args, **kwargs) -> T_Retval:
        """Run ``func(*args, **kwargs)`` in the thread pool and return its result."""
        return await self.start_in_executor('threadpool', func, *args, **kwargs)

    async def call_in_subprocess(self, func: Callable[..., T_Retval], *args, **kwargs) -> T_Retval:
        """Run ``func(*args, **kwargs)`` in the process pool and return its result."""
        return await self.start_in_executor('processpool', func, *args, **kwargs)

    async def call_in_executor(self, executor: str, func: Callable[..., T_Retval], *args,
                               **kwargs) -> T_Retval:
        """Run ``func(*args, **kwargs)`` in the named executor and return its result."""
        return await self.start_in_executor(executor, func, *args, **kwargs)

    def start_in_thread(self, func: Callable, *args, **kwargs) -> Future:
        """Start ``func(*args, **kwargs)`` in the thread pool; return its future."""
        return self.start_in_executor('threadpool', func, *args, **kwargs)

    def start_in_subprocess(self, func: Callable, *args, **kwargs) -> Future:
        """Start ``func(*args, **kwargs)`` in the process pool; return its future."""
        return self.start_in_executor('processpool', func, *args, **kwargs)

    def start_in_executor(self, executor: str, func: Callable, *args, **kwargs) -> Future:
        """
        Start ``func(*args, **kwargs)`` in the executor registered as ``executor``.

        :return: a future supervised by this nursery
        :raises NurseryClosed: if the nursery has been closed
        :raises NurseryException: if no executor is registered under that name
        """
        # Check first: once the nursery has exited, its executors are gone.
        self._check_closed()
        try:
            executor_ = self._executors[executor]
        except KeyError:
            raise NurseryException(
                'No such executor in this nursery: {}'.format(executor)) from None

        # run_in_executor() only forwards positional arguments
        if kwargs:
            func = partial(func, **kwargs)
        return self.start_task(get_event_loop().run_in_executor, executor_, func, *args)

    def start_task(self, func: Callable[..., Awaitable], *args, **kwargs) -> Future:
        """
        Start ``func(*args, **kwargs)`` (which must return an awaitable) as a task.

        :return: the future/task wrapping the awaitable, supervised by this nursery
        :raises NurseryClosed: if the nursery has been closed
        """
        self._check_closed()
        future = ensure_future(func(*args, **kwargs))
        self._futures.add(future)
        future.add_done_callback(self._task_finished)
        return future

    async def __aenter__(self) -> 'Nursery':
        import asyncio  # only needed for the current_task lookup below

        self._check_closed()
        # asyncio.current_task() exists since Python 3.7; Task.current_task() was
        # removed in Python 3.9.
        current_task = getattr(asyncio, 'current_task', None) or Task.current_task
        self._host_task = current_task()
        return self

    @staticmethod
    async def _shutdown_executors(executors: Dict[str, Executor]) -> None:
        """Shut down all of ``executors`` in parallel without blocking the event loop."""
        loop = get_event_loop()
        # Executor.shutdown() blocks until the executor's jobs are done, so each call
        # runs in a throwaway thread.
        temp_pool = ThreadPoolExecutor(max_workers=len(executors))
        try:
            await gather(*[loop.run_in_executor(temp_pool, executor.shutdown)
                           for executor in executors.values()])
        finally:
            # Don't wait: if we were cancelled, the shutdown calls may still be running,
            # and waiting for them here would block the event loop.
            temp_pool.shutdown(wait=False)

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        exceptions = []  # type: List[BaseException]
        if exc_val is not None and self._host_cancel_pending:
            # A task failed and _task_finished cancelled the host body, which has now
            # unwound. Drop that self-inflicted cancellation; keep anything else.
            self._host_cancel_pending = False
            if not isinstance(exc_val, CancelledError):
                exceptions.append(exc_val)
        elif exc_val is not None:
            # The host body raised or was cancelled externally: take down all tasks
            exceptions.append(exc_val)
            self._cancel_tasks()

        self._closed = True
        cancel_exception = None
        while True:
            try:
                if self._executors:
                    # Swapped out first so that a retry after cancellation skips this
                    executors, self._executors = self._executors, {}
                    await self._shutdown_executors(executors)
                if self._futures:
                    await wait(set(self._futures))
                break
            except CancelledError as exc:
                if self._host_cancel_pending:
                    # Late delivery of the cancellation _task_finished sent to the host
                    # while it was in the body; not an external cancellation.
                    self._host_cancel_pending = False
                else:
                    cancel_exception = exc
                self._cancel_tasks()

        # Only failed tasks remain tracked (see _task_finished); filter defensively
        exceptions.extend(future.exception() for future in self._futures
                          if future.done() and not future.cancelled()
                          and future.exception() is not None)
        if len(exceptions) > 1:
            raise MultiError(exceptions, 'multiple tasks failed')
        elif len(exceptions) == 1:
            raise exceptions[0]
        elif cancel_exception is not None:
            raise cancel_exception
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TestNursery:
    """Tests for :class:`Nursery`, run through pytest's ``asyncio`` marker."""

    @staticmethod
    def sync_sum(x, y, delay=0.1):
        """Blocking helper: sleep ``delay`` seconds, then return ``x + y``."""
        if delay:
            sleep(delay)
        return x + y

    @staticmethod
    def error(text, delay=0.1):
        """Blocking helper: sleep ``delay`` seconds, then raise ``Exception(text)``."""
        if delay:
            sleep(delay)
        raise Exception(text)

    @staticmethod
    async def async_sum(x, y, delay=0.2):
        """Async helper: sleep ``delay`` seconds, then return ``x + y``."""
        await asyncio.sleep(delay)
        return x + y

    @staticmethod
    async def async_error(text, delay=0.1):
        """Async helper that raises ``Exception(text)``.

        The raise is in ``finally`` so that the task fails even when it is cancelled
        during the sleep.
        """
        try:
            if delay:
                await asyncio.sleep(delay)
        finally:
            raise Exception(text)

    @pytest.mark.asyncio
    async def test_success(self):
        async with Nursery() as nursery:
            f1 = nursery.start_task(self.async_sum, 1, 2)
            f2 = nursery.start_in_thread(self.sync_sum, 1, 2)
            f3 = nursery.start_in_subprocess(self.sync_sum, 1, 2)

        for f in f1, f2, f3:
            assert f.result() == 3

    @pytest.mark.asyncio
    async def test_host_exception(self):
        with pytest.raises(Exception) as exc:
            async with Nursery() as nursery:
                f1 = nursery.start_task(self.async_sum, 1, 2)
                f2 = nursery.start_in_thread(self.sync_sum, 1, 2)
                f3 = nursery.start_in_subprocess(self.sync_sum, 1, 2)
                raise Exception('dummy error')

        exc.match('dummy error')
        for f in f1, f2, f3:
            assert f.cancelled()

    @pytest.mark.asyncio
    async def test_task_cancelled(self):
        async with Nursery() as nursery:
            f1 = nursery.start_task(self.async_sum, 1, 2)
            f2 = nursery.start_in_thread(self.sync_sum, 1, 2)
            f3 = nursery.start_in_subprocess(self.sync_sum, 1, 2)
            f2.cancel()

        assert f2.cancelled()
        for f in f1, f3:
            assert f.result() == 3

    @pytest.mark.asyncio
    async def test_host_cancelled_before_aexit(self):
        with pytest.raises(asyncio.CancelledError):
            async with Nursery() as nursery:
                f1 = nursery.start_task(self.async_sum, 1, 2)
                f2 = nursery.start_in_thread(self.sync_sum, 1, 2)
                f3 = nursery.start_in_subprocess(self.sync_sum, 1, 2)
                raise asyncio.CancelledError

        for f in f1, f2, f3:
            assert f.cancelled()

    @pytest.mark.asyncio
    async def test_host_cancelled_during_aexit(self, event_loop):
        # asyncio.current_task() exists since Python 3.7; Task.current_task() was
        # removed in Python 3.9.
        current_task = getattr(asyncio, 'current_task', None) or asyncio.Task.current_task
        with pytest.raises(asyncio.CancelledError):
            async with Nursery() as nursery:
                f1 = nursery.start_task(self.async_sum, 1, 2)
                f2 = nursery.start_in_thread(self.sync_sum, 1, 2)
                f3 = nursery.start_in_subprocess(self.sync_sum, 1, 2)
                # The cancellation lands at the first await inside __aexit__
                event_loop.call_soon(current_task().cancel)

        for f in f1, f2, f3:
            assert f.cancelled()

    @pytest.mark.asyncio
    async def test_multi_error(self):
        with pytest.raises(MultiError) as exc:
            async with Nursery() as nursery:
                f1 = nursery.start_task(self.async_error, 'async', delay=2)
                f2 = nursery.start_in_thread(self.error, 'thread', delay=1)
                f3 = nursery.start_in_subprocess(self.error, 'process')

        # f3 fails first and cancels the others. f1 still raises from its finally
        # block; f2's future is cancelled before its thread raises.
        assert len(exc.value.exceptions) == 2
        assert f1.exception() in exc.value.exceptions
        assert f2.cancelled()
        assert f3.exception() in exc.value.exceptions
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment