Last active
March 21, 2024 15:56
-
-
Save arthur-tacca/0f407b901c41ee9ff6733499f4989c84 to your computer and use it in GitHub Desktop.
Possible modifications to asyncio.TaskGroup class
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Adapted with permission from the EdgeDB project;
# license: PSFL.

__all__ = ("TaskGroup",)

from . import events
from . import exceptions
from . import tasks
class TaskGroup: | |
"""Asynchronous context manager for managing groups of tasks. | |
Example use: | |
async with asyncio.TaskGroup() as group: | |
task1 = group.create_task(some_coroutine(...)) | |
task2 = group.create_task(other_coroutine(...)) | |
print("Both tasks have completed now.") | |
All tasks are awaited when the context manager exits. | |
Any exceptions other than `asyncio.CancelledError` raised within | |
a task will cancel all remaining tasks and wait for them to exit. | |
The exceptions are then combined and raised as an `ExceptionGroup`. | |
""" | |
def __init__(self): | |
self._entered = False | |
self._exiting = False | |
self._aborting = False | |
self._loop = None | |
self._parent_task = None | |
self._parent_cancel_requested = False | |
self._tasks = set() | |
self._errors = [] | |
self._base_error = None | |
self._on_completed_fut = None | |
def __repr__(self): | |
info = [''] | |
if self._tasks: | |
info.append(f'tasks={len(self._tasks)}') | |
if self._errors: | |
info.append(f'errors={len(self._errors)}') | |
if self._aborting: | |
info.append('cancelling') | |
elif self._entered: | |
info.append('entered') | |
info_str = ' '.join(info) | |
return f'<TaskGroup{info_str}>' | |
async def __aenter__(self): | |
if self._entered: | |
raise RuntimeError( | |
f"TaskGroup {self!r} has already been entered") | |
if self._loop is None: | |
self._loop = events.get_running_loop() | |
self._parent_task = tasks.current_task(self._loop) | |
if self._parent_task is None: | |
raise RuntimeError( | |
f'TaskGroup {self!r} cannot determine the parent task') | |
self._entered = True | |
return self | |
async def __aexit__(self, et, exc, tb): | |
self._exiting = True | |
uncancel_called = False | |
if (exc is not None and | |
self._is_base_error(exc) and | |
self._base_error is None): | |
self._base_error = exc | |
propagate_cancellation_error = \ | |
exc if et is exceptions.CancelledError else None | |
if self._parent_cancel_requested: | |
# If this flag is set we *must* call uncancel(). | |
uncancel_called = True | |
if self._parent_task.uncancel() == 0: | |
# If there are no pending cancellations left, | |
# don't propagate CancelledError. | |
propagate_cancellation_error = None | |
if et is not None: | |
if not self._aborting: | |
# Our parent task is being cancelled: | |
# | |
# async with TaskGroup() as g: | |
# g.create_task(...) | |
# await ... # <- CancelledError | |
# | |
# or there's an exception in "async with": | |
# | |
# async with TaskGroup() as g: | |
# g.create_task(...) | |
# 1 / 0 | |
# | |
self._abort() | |
# We use while-loop here because "self._on_completed_fut" | |
# can be cancelled multiple times if our parent task | |
# is being cancelled repeatedly (or even once, when | |
# our own cancellation is already in progress) | |
while self._tasks: | |
if self._on_completed_fut is None: | |
self._on_completed_fut = self._loop.create_future() | |
try: | |
await self._on_completed_fut | |
except exceptions.CancelledError as ex: | |
if not self._aborting: | |
# Our parent task is being cancelled: | |
# | |
# async def wrapper(): | |
# async with TaskGroup() as g: | |
# g.create_task(foo) | |
# | |
# "wrapper" is being cancelled while "foo" is | |
# still running. | |
propagate_cancellation_error = ex | |
self._abort() | |
self._on_completed_fut = None | |
assert not self._tasks | |
if self._base_error is not None: | |
raise self._base_error | |
if self._parent_cancel_requested and not uncancel_called: | |
# If this flag is set we *must* call uncancel(). | |
uncancel_called = True | |
if self._parent_task.uncancel() == 0: | |
# If there are no pending cancellations left, | |
# don't propagate CancelledError. | |
propagate_cancellation_error = None | |
# Propagate CancelledError if there is one, except if there | |
# are other errors -- those have priority. | |
if propagate_cancellation_error and not self._errors: | |
raise propagate_cancellation_error | |
if et is not None and et is not exceptions.CancelledError: | |
self._errors.append(exc) | |
if self._errors: | |
if self._parent_task.cancelling() > 0: | |
self._parent_task.uncancel() | |
self._parent_task.cancel() | |
# Exceptions are heavy objects that can have object | |
# cycles (bad for GC); let's not keep a reference to | |
# a bunch of them. | |
try: | |
me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors) | |
raise me from None | |
finally: | |
self._errors = None | |
def create_task(self, coro, *, name=None, context=None): | |
"""Create a new task in this group and return it. | |
Similar to `asyncio.create_task`. | |
""" | |
if not self._entered: | |
raise RuntimeError(f"TaskGroup {self!r} has not been entered") | |
if self._exiting and not self._tasks: | |
raise RuntimeError(f"TaskGroup {self!r} is finished") | |
if self._aborting: | |
raise RuntimeError(f"TaskGroup {self!r} is shutting down") | |
if context is None: | |
task = self._loop.create_task(coro) | |
else: | |
task = self._loop.create_task(coro, context=context) | |
tasks._set_task_name(task, name) | |
# optimization: Immediately call the done callback if the task is | |
# already done (e.g. if the coro was able to complete eagerly), | |
# and skip scheduling a done callback | |
if task.done(): | |
self._on_task_done(task) | |
else: | |
self._tasks.add(task) | |
task.add_done_callback(self._on_task_done) | |
return task | |
# Since Python 3.8 Tasks propagate all exceptions correctly, | |
# except for KeyboardInterrupt and SystemExit which are | |
# still considered special. | |
def _is_base_error(self, exc: BaseException) -> bool: | |
assert isinstance(exc, BaseException) | |
return isinstance(exc, (SystemExit, KeyboardInterrupt)) | |
def _abort(self): | |
self._aborting = True | |
for t in self._tasks: | |
if not t.done(): | |
t.cancel() | |
def _on_task_done(self, task): | |
self._tasks.discard(task) | |
if self._on_completed_fut is not None and not self._tasks: | |
if not self._on_completed_fut.done(): | |
self._on_completed_fut.set_result(True) | |
if task.cancelled(): | |
return | |
exc = task.exception() | |
if exc is None: | |
return | |
self._errors.append(exc) | |
if self._is_base_error(exc) and self._base_error is None: | |
self._base_error = exc | |
if self._parent_task.done(): | |
# Not sure if this case is possible, but we want to handle | |
# it anyways. | |
self._loop.call_exception_handler({ | |
'message': f'Task {task!r} has errored out but its parent ' | |
f'task {self._parent_task} is already completed', | |
'exception': exc, | |
'task': task, | |
}) | |
return | |
if not self._aborting and not self._parent_cancel_requested: | |
# If parent task *is not* being cancelled, it means that we want | |
# to manually cancel it to abort whatever is being run right now | |
# in the TaskGroup. But we want to mark parent task as | |
# "not cancelled" later in __aexit__. Example situation that | |
# we need to handle: | |
# | |
# async def foo(): | |
# try: | |
# async with TaskGroup() as g: | |
# g.create_task(crash_soon()) | |
# await something # <- this needs to be canceled | |
# # by the TaskGroup, e.g. | |
# # foo() needs to be cancelled | |
# except Exception: | |
# # Ignore any exceptions raised in the TaskGroup | |
# pass | |
# await something_else # this line has to be called | |
# # after TaskGroup is finished. | |
self._abort() | |
self._parent_cancel_requested = True | |
self._parent_task.cancel() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment