@ktbarrett
Created August 29, 2024 17:55
from typing import (
    Any,
    Awaitable,
    Callable,
    Coroutine,
    Iterator,
    Mapping,
    TypeVar,
    overload,
)

from cocotb.task import Task

T = TypeVar("T")


class TaskManager(Mapping[str, Task[Any]]):
    r"""Associative container of names to child Tasks.

    Tasks created with this object are not allowed to leak;
    child Tasks will be cancelled if the TaskManager is garbage collected.

    Typically this is used as an async context manager,
    forcing all child Tasks to finish before the ``async with`` block exits.

    .. code-block:: python3

        inputs = range(100)

        async with TaskManager() as tm:

            @tm.fork
            async def drive_input():
                for inp in inputs:
                    await driver.send(inp)

            @tm.fork
            async def monitor_output():
                for inp in inputs:
                    expected_output = model(inp)
                    actual_output = await monitor.recv()
                    assert actual_output == expected_output

        # Only continues executing here after all child Tasks finish.

    If a child Task raises an exception, all other child Tasks are cancelled and the exception is propagated.

    .. code-block:: python3

        try:
            async with TaskManager() as tm:

                @tm.fork
                async def fails():
                    ...
                    raise KeyError("Oops!")

        except KeyError as e:
            print(f"Caught KeyError here! {e}")

    After a TaskManager exits, you can inspect the outcome of the child Tasks by indexing into the TaskManager with the Task's name.
    Names in a TaskManager must be unique.
    Attempting to register a Task under a name that is already taken raises a :exc:`ValueError`.

    .. code-block:: python3

        async with TaskManager() as tm:

            @tm.fork
            async def foo():
                ...
                return 8

            @tm.fork
            async def bar():
                ...
                return "wow"

        assert tm["foo"].result() == 8
        assert tm["bar"].result() == "wow"

    .. note::
        After a TaskManager used as an async context manager exits, no more Tasks can be added to it.

    There are several methods for creating Tasks from various objects: :meth:`add`, :meth:`start_soon`, :meth:`start`, :meth:`fork`, and :meth:`fork_soon`.
    All of these methods return :class:`~cocotb.task.Task` objects that can be ``await``\ ed, inspected, :meth:`~cocotb.task.Task.cancel`\ ed, etc.
    There are also :meth:`join_next`, :meth:`join_any`, :meth:`join_all`, and :meth:`cancel_all`, which can be used to control the state of the entire group of Tasks.

    .. code-block:: python3

        async with TaskManager() as tm:
            # start A and B at the same time
            @tm.fork
            async def do_A():
                ...

            @tm.fork
            async def do_B():
                ...

            # wait for B to finish, then start C
            await do_B

            @tm.fork
            async def do_C():
                ...

            # wait for C to finish, then cancel A
            await do_C
            do_A.cancel()
    """

    def add(self, name: str, aw: Awaitable[T]) -> Task[T]:
        # wraps Awaitable in Task
        # adds to internal map using given name
        ...

    @overload
    def start_soon(self, name: str, coro: Coroutine[Any, Any, T]) -> Task[T]:
        """Creates a Task and registers it with the TaskManager using the given name.

        The Task will start soon.
        """

    @overload
    def start_soon(self, coro: Coroutine[Any, Any, T]) -> Task[T]:
        """Creates a Task and registers it with the TaskManager using ``coro.__name__``.

        The Task will start soon.
        """

    @overload
    async def start(self, name: str, coro: Coroutine[Any, Any, T]) -> Task[T]:
        """Creates a Task and registers it with the TaskManager using the given name.

        Current execution pauses until the Task has started.
        """

    @overload
    async def start(self, coro: Coroutine[Any, Any, T]) -> Task[T]:
        """Creates a Task and registers it with the TaskManager using ``coro.__name__``.

        Current execution pauses until the Task has started.
        """

    def fork(self, coro_func: Callable[[], Coroutine[Any, Any, T]]) -> Task[T]:
        """Creates a Task from ``coro_func()`` and registers it with the TaskManager using ``coro_func.__name__``.

        Current execution pauses until the Task has started.
        """

    def __getitem__(self, key: str) -> Task[Any]:
        """Returns the child Task registered under the given name."""

    def __iter__(self) -> Iterator[str]:
        # A Mapping iterates over its keys, i.e. the child Task names.
        ...

    def __len__(self) -> int:
        ...

    # inherits __contains__, keys(), values(), items(), get(), __eq__, and __ne__ from Mapping

    async def join_next(self) -> None:
        """Blocks until the next currently unfinished child Task finishes.

        Returns immediately if all child Tasks are finished.
        """

    async def join_any(self) -> None:
        """Blocks until at least one child Task finishes."""

    async def join_all(self) -> None:
        """Blocks until all child Tasks finish."""

    def cancel_all(self) -> None:
        """Cancels all living child Tasks."""

    # Can be used as an async context manager.
    async def __aenter__(self) -> "TaskManager":
        # set parent Task ???
        # Return self so that ``async with TaskManager() as tm:`` binds ``tm``.
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self.join_all()

    # If not used as an async context manager, clean up (with a warning)
    # when the TaskManager is garbage collected. No Tasks leaking here!
    def __del__(self) -> None:
        # emit warning
        self.cancel_all()
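The container, context-manager, and failure semantics described in the class docstring can be prototyped without a simulator. The sketch below is not cocotb code: it assumes `asyncio` as a stand-in scheduler, uses `asyncio.Task` in place of `cocotb.task.Task`, and implements only `add`, `fork`, the `Mapping` protocol, and the context-manager methods. `AsyncioTaskManager` is a hypothetical name, and raising `RuntimeError` when adding after exit is an assumption (the docstring only says it is not allowed).

```python
import asyncio
from collections.abc import Mapping
from typing import Any, Awaitable, Callable, Coroutine, Dict, Iterator, Optional, TypeVar

T = TypeVar("T")


class AsyncioTaskManager(Mapping):
    """Hypothetical asyncio stand-in for TaskManager: names -> child Tasks."""

    def __init__(self) -> None:
        self._tasks: Dict[str, "asyncio.Future[Any]"] = {}
        self._closed = False

    def add(self, name: str, aw: Awaitable[T]) -> "asyncio.Future[T]":
        # Refuse additions after exit (assumed RuntimeError) and duplicate names.
        error: Optional[Exception] = None
        if self._closed:
            error = RuntimeError("cannot add Tasks after the TaskManager has exited")
        elif name in self._tasks:
            error = ValueError(f"a Task named {name!r} is already registered")
        if error is not None:
            if asyncio.iscoroutine(aw):
                aw.close()  # avoid a "coroutine was never awaited" warning
            raise error
        task = asyncio.ensure_future(aw)  # wraps the awaitable in a Task
        self._tasks[name] = task
        return task

    def fork(self, coro_func: Callable[[], Coroutine[Any, Any, T]]) -> "asyncio.Future[T]":
        # Usable as a decorator: the decorated name is rebound to the Task.
        return self.add(coro_func.__name__, coro_func())

    # Mapping protocol: iteration yields names, indexing yields Tasks.
    def __getitem__(self, key: str) -> "asyncio.Future[Any]":
        return self._tasks[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._tasks)

    def __len__(self) -> int:
        return len(self._tasks)

    async def __aenter__(self) -> "AsyncioTaskManager":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        try:
            while True:
                # Re-scan every round so Tasks forked while waiting are joined too.
                tasks = list(self._tasks.values())
                pending = [t for t in tasks if not t.done()]
                failed = [t for t in tasks
                          if t.done() and not t.cancelled() and t.exception() is not None]
                if failed:
                    for t in pending:
                        t.cancel()
                    # Let the cancelled siblings unwind, then re-raise the
                    # exception of the first failed Task in registration order.
                    await asyncio.gather(*pending, return_exceptions=True)
                    raise failed[0].exception()
                if not pending:
                    return
                await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        finally:
            self._closed = True


async def example() -> AsyncioTaskManager:
    # Mirrors the foo/bar example from the TaskManager docstring.
    async with AsyncioTaskManager() as tm:

        @tm.fork
        async def foo() -> int:
            await asyncio.sleep(0)
            return 8

        @tm.fork
        async def bar() -> str:
            return "wow"

    return tm
```

Python 3.11+ ships `asyncio.TaskGroup`, which likewise cancels the remaining tasks when one fails, but it reports failures wrapped in an `ExceptionGroup` instead of re-raising the child's exception directly as sketched here, and it does not index its children by name.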
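The four group operations differ mainly in what they wait for. The following sketches one reading of their docstrings, again assuming `asyncio` as a stand-in and written as free functions over a name-to-Task mapping instead of as methods (both assumptions). `join_next` waits only on Tasks that are still running, while `join_any` is already satisfied if any Task has finished; returning immediately from `join_any` when there are no Tasks at all is an assumption, since the design leaves that case open.

```python
import asyncio
from typing import Any, Mapping


async def join_next(tasks: Mapping[str, "asyncio.Task[Any]"]) -> None:
    # Wait for the next currently unfinished Task; return at once if none remain.
    unfinished = [t for t in tasks.values() if not t.done()]
    if unfinished:
        await asyncio.wait(unfinished, return_when=asyncio.FIRST_COMPLETED)


async def join_any(tasks: Mapping[str, "asyncio.Task[Any]"]) -> None:
    # Satisfied by any Task that has already finished, so it may not wait at all.
    # Assumption: with no Tasks it returns immediately instead of blocking forever.
    if not tasks or any(t.done() for t in tasks.values()):
        return
    await asyncio.wait(list(tasks.values()), return_when=asyncio.FIRST_COMPLETED)


async def join_all(tasks: Mapping[str, "asyncio.Task[Any]"]) -> None:
    # Loop so that Tasks added while we were waiting are joined as well.
    while True:
        unfinished = [t for t in tasks.values() if not t.done()]
        if not unfinished:
            return
        await asyncio.wait(unfinished)  # default return_when is ALL_COMPLETED


def cancel_all(tasks: Mapping[str, "asyncio.Task[Any]"]) -> None:
    # Only living Tasks are cancelled; finished ones keep their results.
    for t in tasks.values():
        if not t.done():
            t.cancel()


async def main() -> None:
    tasks = {
        "fast": asyncio.create_task(asyncio.sleep(0, result="fast")),
        "slow": asyncio.create_task(asyncio.sleep(0.05, result="slow")),
    }
    await join_next(tasks)  # returns once "fast" finishes
    assert tasks["fast"].done() and not tasks["slow"].done()
    await join_any(tasks)  # "fast" has already finished, so no waiting
    assert not tasks["slow"].done()
    await join_next(tasks)  # now waits for "slow"
    assert tasks["slow"].result() == "slow"

    tasks["sleeper"] = asyncio.create_task(asyncio.sleep(10))
    cancel_all(tasks)
    await join_all(tasks)
    assert tasks["sleeper"].cancelled()
    assert tasks["fast"].result() == "fast"  # untouched by cancel_all


if __name__ == "__main__":
    asyncio.run(main())
```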
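The `__del__` fallback relies on the finalizer running when the last reference to the manager disappears. Below is a minimal sketch under the same `asyncio` assumption; `LeakGuard` is a hypothetical name, and `ResourceWarning` is an assumed choice for the unspecified warning (it is ignored by the default warning filters, so enable it with `-W default` or `-X dev` to see it).

```python
import asyncio
import gc
import warnings
from typing import Any, Coroutine, Dict, Tuple


class LeakGuard:
    """Hypothetical minimal Task owner that cancels its children when collected."""

    def __init__(self) -> None:
        self._tasks: Dict[str, "asyncio.Task[Any]"] = {}

    def add(self, name: str, coro: Coroutine[Any, Any, Any]) -> "asyncio.Task[Any]":
        task = asyncio.create_task(coro)
        self._tasks[name] = task
        return task

    def __del__(self) -> None:
        living = [name for name, t in self._tasks.items() if not t.done()]
        if living:
            # Assumed warning category; the design only says "emit warning".
            warnings.warn(
                f"{type(self).__name__} collected with unfinished Tasks {living}; cancelling them",
                ResourceWarning,
                stacklevel=2,
            )
            for name in living:
                self._tasks[name].cancel()


async def main() -> Tuple["asyncio.Task[Any]", list]:
    guard = LeakGuard()
    task = guard.add("sleeper", asyncio.sleep(10))
    await asyncio.sleep(0)  # let the child Task start
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        del guard  # drop the only reference; CPython finalizes it right here
        gc.collect()  # in case the interpreter does not finalize immediately
    await asyncio.sleep(0)  # let the cancellation be delivered to the child
    return task, caught
```

In CPython the finalizer runs as soon as the reference count drops to zero, but an object caught in a reference cycle is only finalized by the cyclic garbage collector, and other interpreters give no timing guarantee, so this remains a safety net rather than a substitute for `async with`.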