Last active
April 10, 2023 19:48
-
-
Save ktbarrett/178cfed4f9963642eaf9ad27a3e32e16 to your computer and use it in GitHub Desktop.
Join blocks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
async def test(dut):
    """Example: fork two named tasks inside a TaskManager and wait for both."""
    with TaskManager() as tm:
        # ``tm.fork`` calls the decorated coroutine function immediately and
        # registers the resulting task under the function's name.
        @tm.fork
        async def stimulate():
            pass  # stimulate an interface

        @tm.fork
        async def analyze():
            pass  # analyze an output

        # Join before leaving the ``with`` block: exiting while a task is
        # still running raises RuntimeError.
        await tm.join_all()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
async def test(dut):
    """Example: wait for whichever awaitable finishes first (edge vs. timeout)."""
    # Positional awaitables are registered under their index: 0 is the
    # RisingEdge, 1 is the ClockCycles timeout.
    # NOTE(review): ClockCycles is given no cycle count here; cocotb's
    # ClockCycles normally requires one -- confirm the intended timeout length.
    tm = await join_any(
        RisingEdge(dut.valid),
        ClockCycles(dut.clk),
    )
    assert tm[0].done()  # otherwise timeout
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contextlib import AbstractContextManager
from typing import (
    Any,
    Awaitable,
    Callable,
    Coroutine,
    Dict,
    Generator,
    Iterator,
    List,
    Mapping,
    Optional,
    TypeVar,
    overload,
)

from cocotbext.compat import Event, Task, fork
# Type of the IDs under which TaskManager stores its tasks.
K = TypeVar("K")
# Type of ``self`` in TaskManager.__enter__, so the ``with`` target keeps the
# caller's concrete type.
Self = TypeVar("Self")
class TaskManager(Mapping[K, Task], AbstractContextManager):
    """Mapping of IDs to forked tasks, with helpers to join or cancel them.

    Each awaitable passed to :meth:`add` (or each coroutine function decorated
    with :meth:`fork`) is wrapped, forked as a task and stored under its ID.
    Every wrapped task sets an internal event when it finishes; the ``join_*``
    coroutines wait on that event and re-raise an exception captured from a
    task that failed during the wait.

    Used as a context manager, leaving the ``with`` block while any managed
    task is still running raises :class:`RuntimeError`.
    """

    def __init__(self) -> None:
        self._tasks: Dict[K, Task] = {}
        self._error_squashers: Dict[K, Task] = {}
        self._event = Event()
        # Most recent exception captured by ``_waiter``. Initialised here so
        # the attribute always exists, even before the first join call.
        self._exception: Optional[Exception] = None

    async def _waiter(self, awaitable: Awaitable) -> Any:
        """Await *awaitable*, record any ``Exception`` and signal completion.

        Returns the awaited result, or ``None`` when an ``Exception`` was
        caught (the exception is stored, not propagated). The internal event
        is set in every case so that pending joins wake up.
        """
        try:
            return await awaitable
        except Exception as e:
            self._exception = e
        finally:
            self._event.set()

    async def _error_squasher(self, task: Task) -> None:
        """Await *task* and discard any ``Exception`` it raises."""
        try:
            await task
        except Exception:
            # Deliberately ignored: failures are reported through the
            # ``join_*`` methods instead.
            pass

    def add(self, id: K, awaitable: Awaitable[Any]) -> Task:
        """Fork *awaitable* as a task and register it under *id*.

        Returns the forked task.

        Raises:
            ValueError: if *id* is already registered. The check happens
                before anything is forked, so a rejected call does not leave
                an untracked task running.
        """
        if id in self._tasks:
            raise ValueError("Duplicate IDs in TaskManager")
        task = fork(self._waiter(awaitable))
        self._tasks[id] = task
        self._error_squashers[id] = fork(self._error_squasher(task))
        return task

    @overload
    def fork(self, __coro: Callable[[], Coroutine[Any, Any, Any]]) -> Task:
        ...

    @overload
    def fork(
        self, __name: str
    ) -> Callable[[Callable[[], Coroutine[Any, Any, Any]]], Task]:
        ...

    def fork(self, __coro):  # type: ignore
        """Decorator that calls a coroutine function and registers the result.

        ``@tm.fork`` registers the task under the function's ``__name__``;
        ``@tm.fork("name")`` registers it under the given string instead.
        In both forms the decorated name is bound to the resulting task.
        """
        if isinstance(__coro, str):
            name = __coro

            def decorator(coro: Callable[[], Coroutine[Any, Any, Any]]) -> Task:
                return self.add(name, coro())

            return decorator
        else:
            return self.add(__coro.__name__, __coro())

    async def _join_next(self) -> None:
        """Wait until some task signals completion; re-raise its exception.

        The stored exception is reset on entry, so only a failure captured
        during this wait is reported.
        """
        self._exception = None
        self._event.clear()
        await self._event.wait()
        exc = self._exception
        # Compare against None explicitly: an exception type may define
        # ``__bool__``/``__len__`` and evaluate as falsy.
        if exc is not None:
            raise exc

    async def join_next(self) -> None:
        """Wait for the next task to finish, unless all are already done."""
        if not all(task.done() for task in self._tasks.values()):
            await self._join_next()

    async def join_any(self) -> None:
        """Wait until at least one task is done.

        Returns immediately when a task has already finished, or when no
        tasks are registered (nothing could ever set the event, so waiting
        would never return; ``join_next`` and ``join_all`` also return
        immediately in that case).
        """
        if not self._tasks:
            return
        if any(task.done() for task in self._tasks.values()):
            return
        await self._join_next()

    async def join_all(self) -> None:
        """Wait until every registered task is done."""
        while not all(task.done() for task in self._tasks.values()):
            await self._join_next()

    def cancel_all(self) -> None:
        """Cancel every task that is not done yet.

        The error-squashing helpers are cancelled before the tasks they
        watch, matching the registration bookkeeping in ``add``.
        """
        for group in (self._error_squashers, self._tasks):
            for task in group.values():
                if not task.done():
                    task.cancel()

    def __getitem__(self, id: K) -> Task:
        """Return the task registered under *id*."""
        return self._tasks[id]

    def __iter__(self) -> Iterator[K]:
        """Iterate over the registered IDs."""
        return iter(self._tasks)

    def __len__(self) -> int:
        """Return the number of registered tasks."""
        return len(self._tasks)

    def __enter__(self: Self) -> Self:
        """Return the manager itself as the ``with`` target."""
        return self

    def __exit__(self, *exc_info: Any) -> None:
        """Raise ``RuntimeError`` if any managed task is still running."""
        running_ids: List[K] = [
            task_id for task_id, task in self._tasks.items() if not task.done()
        ]
        if running_ids:
            running_ids_str = ", ".join(repr(task_id) for task_id in running_ids)
            raise RuntimeError(
                f"TaskManager exited with still running Tasks: {running_ids_str}"
            )
class _join_base(TaskManager):
    """TaskManager pre-populated from constructor arguments.

    Positional awaitables are registered under their index (``0``, ``1``, ...)
    and keyword awaitables under their keyword name.
    """

    # The annotations previously read ``Awaitables[Any]``, an undefined name
    # that raised NameError as soon as this class body was executed.
    def __init__(self, *args: Awaitable[Any], **kwargs: Awaitable[Any]) -> None:
        super().__init__()
        for i, arg in enumerate(args):
            super().add(i, arg)
        for name, arg in kwargs.items():
            super().add(name, arg)
class join_any(_join_base):
    """Awaitable that completes once any of the given awaitables is done.

    ``await join_any(a, b, name=c)`` evaluates to this manager, so the
    individual tasks can be inspected by index or keyword afterwards.
    """

    # ``Generator`` must be imported from ``typing``: the annotation is
    # evaluated when the method is defined.
    def __await__(self) -> Generator[Any, Any, TaskManager]:
        # Delegate to the inherited coroutine, then hand back the manager
        # itself as the result of the ``await`` expression.
        yield from super().join_any().__await__()
        return self
class join_all(_join_base):
    """Awaitable that completes once all of the given awaitables are done.

    ``await join_all(a, b, name=c)`` evaluates to this manager, so the
    individual tasks can be inspected by index or keyword afterwards.
    """

    # ``Generator`` must be imported from ``typing``: the annotation is
    # evaluated when the method is defined.
    def __await__(self) -> Generator[Any, Any, TaskManager]:
        # Delegate to the inherited coroutine, then hand back the manager
        # itself as the result of the ``await`` expression.
        yield from super().join_all().__await__()
        return self
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment