-
-
Save jonpovey/81d908cdff730dc790529794b7e0d667 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
async with join_any() as tm: | |
@tm.fork | |
async def block1() -> None: | |
... | |
tm.add("test", Combine(...)) | |
# join_any here | |
await tm.join_all() # wait for them all to finish now |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
async with join_any() as tm: | |
@tm.fork | |
async def block1() -> None: | |
... | |
@tm.fork | |
async def block2() -> None: | |
... | |
# task manager is collected, which kills remaining tasks |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
tasks = await gather(coro_fails(), coro2(), coro3()) | |
tasks[0].cancelled() # True | |
tasks[0].result() # throws exception |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Coroutine,
    Dict,
    KeysView,
    List,
    TypeVar,
    ValuesView,
)

import cocotb
from cocotb.triggers import Event

T = TypeVar("T")
class TaskManager:
    """Track a group of cocotb tasks and wait for some or all of them.

    Tasks are registered under a caller-chosen id with :meth:`add`, or with
    :meth:`fork` when used as a decorator. Each task sets a shared
    :class:`~cocotb.triggers.Event` when it ends, and :meth:`join_next`,
    :meth:`join_any` and :meth:`join_all` wait on that event.

    :meth:`disable` cancels every unfinished task. It is also called from
    ``__del__`` when the manager is garbage-collected.
    """

    def __init__(self) -> None:
        self._tasks: Dict[object, cocotb.Task[Any]] = {}
        # NOTE(review): populated by add() but never read in this module.
        self._joins: Dict[object, cocotb.Join] = {}
        # Set by every task's waiter when that task ends.
        self._event = Event()

    def add(self, id: object, awaitable: Awaitable[T]) -> cocotb.Task[T]:
        """Start *awaitable* as a new task registered under *id*.

        Args:
            id: Key used by ``tm[id]`` and :meth:`ids`. If a task with the
                same id already exists, it is replaced in the registry but
                NOT cancelled, so :meth:`disable` no longer reaches it.
            awaitable: The work to run concurrently.

        Returns:
            The started task. Its result is the awaitable's result.
        """
        # Pass only the Event into the waiter, never ``self``. If a pending
        # task's coroutine referenced this manager, the scheduler would keep
        # it alive until every task had finished, and __del__ could never
        # cancel anything.
        task = cocotb.start_soon(self._waiter(awaitable, self._event))
        self._tasks[id] = task
        self._joins[id] = task.join()
        return task

    @staticmethod
    async def _waiter(awaitable: Awaitable[T], event: Event) -> T:
        """Await *awaitable*, return its result, and set *event* however it ends."""
        try:
            return await awaitable
        finally:
            # Runs on success, exception and cancellation alike, so waiters
            # in join_next() are always woken.
            event.set()

    def fork(self, f: Callable[[], Coroutine[Any, Any, T]]) -> cocotb.Task[T]:
        """Start coroutine function *f* as a task keyed by ``f.__name__``.

        Intended as a decorator (``@tm.fork``). The decorated name is rebound
        to the returned task.
        """
        return self.add(f.__name__, f())

    async def join_next(self) -> None:
        """Wait until a task ends after this call is made.

        Completions signalled before the call are discarded by the
        ``clear()``. Only task waiters set the event in this module, so if
        no task is still running this waits forever.
        """
        self._event.clear()
        await self._event.wait()

    async def join_any(self) -> None:
        """Return once at least one registered task has finished.

        Returns immediately if a task is already done. It also returns
        immediately if no tasks are registered, because waiting would
        otherwise never end.
        """
        if not self._tasks or any(task.done() for task in self.tasks()):
            return
        await self.join_next()

    async def join_all(self) -> None:
        """Return once every registered task has finished."""
        while any(not task.done() for task in self.tasks()):
            await self.join_next()

    def disable(self) -> None:
        """Cancel every registered task that has not finished yet."""
        for task in self._tasks.values():
            if not task.done():
                task.cancel()

    def ids(self) -> KeysView[object]:
        """Return a live view of the registered task ids."""
        return self._tasks.keys()

    def tasks(self) -> ValuesView[cocotb.Task[Any]]:
        """Return a live view of the registered tasks."""
        return self._tasks.values()

    def __del__(self) -> None:
        # Best-effort cleanup: cancel whatever is still running when the
        # manager is collected.
        self.disable()

    def __getitem__(self, id: object) -> cocotb.Task[Any]:
        """Return the task registered under *id* (raises KeyError if absent)."""
        return self._tasks[id]
@contextlib.asynccontextmanager
async def join_any(*awaitables: Awaitable[Any]) -> AsyncIterator[TaskManager]:
    """Yield a TaskManager; on normal exit, wait until any of its tasks finishes.

    Each of *awaitables* is started immediately, keyed by its positional
    index (0, 1, ...). More tasks can be added inside the block via
    ``tm.add`` / ``@tm.fork``.

    On normal exit, tasks that are still running are left running. The
    caller may keep using ``tm``, e.g. ``await tm.join_all()``. If the block
    raises, every unfinished task is cancelled before the exception
    propagates, so none are orphaned.
    """
    tm = TaskManager()
    for i, awaitable in enumerate(awaitables):
        tm.add(i, awaitable)
    try:
        yield tm
    except BaseException:
        tm.disable()
        raise
    await tm.join_any()
@contextlib.asynccontextmanager
async def join_all(*awaitables: Awaitable[Any]) -> AsyncIterator[TaskManager]:
    """Yield a TaskManager; on normal exit, wait until all of its tasks finish.

    Each of *awaitables* (optional) is started immediately, keyed by its
    positional index (0, 1, ...). More tasks can be added inside the block
    via ``tm.add`` / ``@tm.fork``.

    If the block raises, every unfinished task is cancelled before the
    exception propagates, so none are orphaned.
    """
    tm = TaskManager()
    for i, awaitable in enumerate(awaitables):
        tm.add(i, awaitable)
    try:
        yield tm
    except BaseException:
        tm.disable()
        raise
    await tm.join_all()
async def gather(*awaitables: Awaitable[T]) -> List[cocotb.Task[T]]:
    """Run *awaitables* concurrently and wait for all of them to finish.

    Returns:
        One task per awaitable, in the same order as *awaitables*. Each
        task's outcome (``result()``, ``done()``, ...) can be inspected
        afterwards. An empty call returns ``[]``.
    """
    tm = TaskManager()
    for i, awaitable in enumerate(awaitables):
        tm.add(i, awaitable)
    await tm.join_all()
    # Tasks are keyed by positional index, so look them up by index, not by
    # the Task objects themselves.
    return [tm[i] for i in range(len(awaitables))]
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.