Skip to content

Instantly share code, notes, and snippets.

@jonpovey
Forked from ktbarrett/example1.py
Created March 22, 2022 00:15
Show Gist options
  • Save jonpovey/81d908cdff730dc790529794b7e0d667 to your computer and use it in GitHub Desktop.
Save jonpovey/81d908cdff730dc790529794b7e0d667 to your computer and use it in GitHub Desktop.
# Example 1: wait until the first task finishes, then wait for the rest.
# NOTE(review): top-level ``async with``/``await`` is only valid inside an
# ``async def`` (e.g. a cocotb test); ``Combine`` is presumably
# cocotb.triggers.Combine, which this snippet does not import.
async with join_any() as tm:
    # Start block1() as a task tracked under the id "block1" (its __name__).
    @tm.fork
    async def block1() -> None:
        ...

    # Track an arbitrary awaitable under an explicit id.
    tm.add("test", Combine(...))
    # join_any here: leaving the block waits until at least one task is done.
await tm.join_all()  # wait for them all to finish now
# Example 2: race two blocks; whichever is still running afterwards is meant
# to be cancelled.
async with join_any() as tm:
    @tm.fork
    async def block1() -> None:
        ...

    @tm.fork
    async def block2() -> None:
        ...
    # Leaving the block waits until at least one of block1/block2 is done.
# task manager is collected, which kills remaining tasks
# NOTE(review): this relies on TaskManager.__del__ calling disable(). When
# (or whether) the manager is collected is up to the interpreter, and ``tm``
# is still bound at this point. Call ``tm.disable()`` explicitly if the
# cancellation must happen here.
# Example 3: run several coroutines concurrently and get their tasks back,
# in argument order.
tasks = await gather(coro_fails(), coro2(), coro3())
tasks[0].cancelled()  # True  (NOTE(review): a task that raised rather than being cancelled may report False; confirm against cocotb's Task.cancelled())
tasks[0].result()  # throws exception (presumably re-raises what coro_fails() raised)
import contextlib
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Coroutine,
    Dict,
    KeysView,
    List,
    TypeVar,
    ValuesView,
)

import cocotb
from cocotb.triggers import Event

T = TypeVar("T")
class TaskManager:
    """Track a group of cocotb tasks and wait for any or all of them to finish.

    Each task is registered under a caller-chosen id. A single shared Event is
    set whenever any tracked task finishes, normally or by exception. This is
    what ``join_next``, ``join_any`` and ``join_all`` wait on. Tasks still
    running when the manager is finalized are cancelled (see ``__del__``).
    """

    def __init__(self) -> None:
        # id -> task started by add(); dict order is registration order.
        self._tasks: Dict[object, cocotb.Task[Any]] = {}
        # Set by every task's wrapper when it finishes; cleared by join_next().
        self._event = Event()

    def add(self, id: object, awaitable: "Awaitable[T]") -> "cocotb.Task[T]":
        """Start *awaitable* as a new task tracked under *id* and return the task.

        The task's result is the awaitable's result. Registering a second task
        under an existing *id* replaces the registry entry. The older task
        keeps running but is no longer tracked.
        """
        # _waiter is a staticmethod given only the event. The running task
        # must not hold a reference to ``self``; otherwise the manager could
        # never be collected (and __del__ never run) while a task is alive.
        task = cocotb.start_soon(self._waiter(self._event, awaitable))
        self._tasks[id] = task
        return task

    @staticmethod
    async def _waiter(event: "Event", awaitable: "Awaitable[T]") -> "T":
        """Await *awaitable*, return its result, and always set *event* afterwards."""
        # NOTE(review): join_all() assumes the task reports done() by the time
        # a waiter woken by this set() resumes. Confirm with cocotb's scheduler.
        try:
            return await awaitable
        finally:
            event.set()

    def fork(self, f: "Callable[[], Coroutine[Any, Any, T]]") -> "cocotb.Task[T]":
        """Decorator: start ``f()`` as a task tracked under ``f.__name__``.

        Returns the task, so the decorated name ends up bound to it.
        """
        return self.add(f.__name__, f())

    async def join_next(self) -> None:
        """Wait until some tracked task finishes after this call.

        Does not check whether any task is still running. If none is, this
        waits until a task is added and finishes.
        """
        self._event.clear()
        await self._event.wait()

    async def join_any(self) -> None:
        """Return once at least one tracked task is done.

        Returns immediately if no task has been added. Otherwise it would
        wait for an event that can never be set.
        """
        tasks = self.tasks()
        if not tasks or any(task.done() for task in tasks):
            return
        await self.join_next()

    async def join_all(self) -> None:
        """Return once every tracked task is done (immediately if there are none)."""
        while any(not task.done() for task in self.tasks()):
            await self.join_next()

    def disable(self) -> None:
        """Cancel every tracked task that has not finished yet."""
        for task in self._tasks.values():
            if not task.done():
                task.cancel()

    def ids(self) -> KeysView[object]:
        """Return a live view of the ids of all tracked tasks."""
        return self._tasks.keys()

    def tasks(self) -> "ValuesView[cocotb.Task[Any]]":
        """Return a live view of all tracked tasks, in registration order."""
        return self._tasks.values()

    def __del__(self) -> None:
        # Best-effort cleanup: cancel whatever is still running at finalization.
        # Guarded because __del__ also runs on instances whose __init__ never
        # assigned ``_tasks``.
        if hasattr(self, "_tasks"):
            self.disable()

    def __getitem__(self, id: object) -> "cocotb.Task[Any]":
        """Return the task tracked under *id* (KeyError if unknown)."""
        return self._tasks[id]
@contextlib.asynccontextmanager
async def join_any(*awaitables: Awaitable[Any]) -> "AsyncIterator[TaskManager]":
    """Yield a fresh TaskManager; on normal exit, wait until one of its tasks is done.

    Any *awaitables* passed in are started immediately and tracked under
    their positional index (0, 1, ...). Tasks still running on exit are not
    cancelled here.
    """
    tm = TaskManager()
    for index, awaitable in enumerate(awaitables):
        tm.add(index, awaitable)
    yield tm
    # Not reached if the body raised: the exception propagates out of the yield.
    await tm.join_any()
@contextlib.asynccontextmanager
async def join_all(*awaitables: Awaitable[Any]) -> "AsyncIterator[TaskManager]":
    """Yield a fresh TaskManager; on normal exit, wait until all of its tasks are done.

    Any *awaitables* passed in are started immediately and tracked under
    their positional index (0, 1, ...). With none, this behaves exactly as
    before.
    """
    tm = TaskManager()
    for index, awaitable in enumerate(awaitables):
        tm.add(index, awaitable)
    yield tm
    # Not reached if the body raised: the exception propagates out of the yield.
    await tm.join_all()
async def gather(*awaitables: "Awaitable[T]") -> "List[cocotb.Task[T]]":
    """Run all *awaitables* concurrently and wait until every one has finished.

    Returns the tasks themselves (not their results), in argument order, so
    callers can inspect each one, e.g. with ``result()``.
    """
    tm = TaskManager()
    for index, awaitable in enumerate(awaitables):
        tm.add(index, awaitable)
    await tm.join_all()
    # Look tasks up by their ids (0..n-1). The previous code used the task
    # objects themselves as keys, which raises KeyError.
    return [tm[index] for index in range(len(awaitables))]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment