-
-
Save arthur-tacca/6c676a21d0dcc0582edb50c9c2aa3e3c to your computer and use it in GitHub Desktop.
Trio: results-gathering nursery wrapper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Original idea by smurfix: https://gist.github.com/smurfix/0130817fa5ba6d3bb4a0f00e4d93cf86
# aioresult variant: https://gist.github.com/arthur-tacca/5c717ae68ac037e72ae45fd1e9ca1345
from collections import deque | |
import math | |
import trio | |
class StreamResultsNursery:
    """A Trio nursery wrapper that yields task return values as a stream.

    Tasks are started with :meth:`start_soon` or :meth:`start`, like on a
    regular nursery, but at most ``max_running_tasks`` of them run at once; the
    rest wait on an internal ``trio.CapacityLimiter``. Each task's return value
    is queued as soon as the task finishes, and the wrapper itself is an async
    iterator over those values, in completion order::

        async with StreamResultsNursery(max_running_tasks=3) as nursery:
            nursery.start_soon(work, arg)
            async for result in nursery:
                ...

    Iteration ends once every started task has finished and every result has
    been consumed; after that, starting new tasks raises ``RuntimeError``. Tasks
    started from inside the ``async for`` body are picked up by the same loop.

    If a task raises, no result is recorded for it and the exception propagates
    into the underlying nursery, as with any trio nursery.
    """

    def __init__(self, max_running_tasks=math.inf):
        self._nursery = trio.open_nursery()
        self._results = deque()  # Finished results not yet handed to a consumer
        self._unfinished_tasks_count = 0  # Includes both running and waiting to run
        self._capacity_limiter = trio.CapacityLimiter(max_running_tasks)
        self._nm = None  # The real nursery object, set by __aenter__
        self._parking_lot = trio.lowlevel.ParkingLot()  # Consumers waiting in __anext__
        self._loop_finished = False  # Set once __anext__ has raised StopAsyncIteration

    @property
    def cancel_scope(self):
        """The cancel scope of the underlying nursery (valid only after entry)."""
        return self._nm.cancel_scope

    @property
    def max_running_tasks(self):
        """Maximum number of tasks allowed to run at once; may be reassigned."""
        return self._capacity_limiter.total_tokens

    @max_running_tasks.setter
    def max_running_tasks(self, value):
        self._capacity_limiter.total_tokens = value

    @property
    def running_tasks_count(self):
        """Number of tasks currently holding a slot in the capacity limiter."""
        return self._capacity_limiter.borrowed_tokens

    async def __aenter__(self):
        self._nm = await self._nursery.__aenter__()
        return self

    async def __aexit__(self, *exc):
        # Delegate entirely to the wrapped nursery's exit logic.
        return await self._nursery.__aexit__(*exc)

    async def _wrap(self, p, a, task_status=trio.TASK_STATUS_IGNORED):
        # Run p(*a) once a limiter slot is free and queue its return value.
        try:
            async with self._capacity_limiter:
                # Report "started" only once a slot is held, so start() waits
                # for capacity instead of returning while the task is queued.
                task_status.started()
                self._results.append(await p(*a))
        finally:
            self._unfinished_tasks_count -= 1
            # Wake *every* parked consumer; each re-checks the queue and the
            # count. A plain unpark() wakes only one, so with several tasks
            # iterating concurrently the others could stay parked forever once
            # the final task finishes.
            self._parking_lot.unpark_all()

    def _check_can_start(self):
        # Shared precondition check for start_soon() and start().
        if self._nm is None:
            raise RuntimeError("Enter context manager before starting tasks")
        if self._loop_finished:
            raise RuntimeError("Loop over results has already completed")

    def start_soon(self, p, *a):
        """Schedule ``p(*a)``; its return value will be produced by iteration.

        Raises:
            RuntimeError: if called before entering the context manager or after
                iteration over results has completed. Errors raised by the
                underlying nursery's ``start_soon`` also propagate.
        """
        self._check_can_start()
        self._nm.start_soon(self._wrap, p, a)
        # Count only once scheduling succeeded, so a rejected task cannot leave
        # the count permanently inflated (which would hang iteration). This is
        # still before _wrap runs: start_soon never runs the task synchronously.
        self._unfinished_tasks_count += 1

    async def start(self, p, *a):
        """Start ``p(*a)`` and wait until it has acquired a running slot.

        Unlike ``trio.Nursery.start``, ``p`` is not given a ``task_status``;
        this returns ``None`` as soon as the capacity limiter admits the task.

        Raises:
            RuntimeError: if called before entering the context manager or after
                iteration over results has completed.
        """
        self._check_can_start()
        # Count before awaiting: the task may run to completion (and decrement)
        # before nursery.start() returns, and a concurrent consumer must not
        # see a zero count while this task is still pending.
        self._unfinished_tasks_count += 1
        await self._nm.start(self._wrap, p, a)

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next finished result, waiting for one if necessary.

        Raises:
            StopAsyncIteration: once all started tasks have finished and every
                result has been returned.
        """
        await trio.lowlevel.checkpoint()  # Ensure this function is always a checkpoint
        while not self._results and self._unfinished_tasks_count:
            await self._parking_lot.park()  # Need to wait for a result to be produced
        if self._results:
            return self._results.popleft()
        self._loop_finished = True
        raise StopAsyncIteration  # All tasks done and all results retrieved
if __name__ == "__main__":
    import random

    async def rand():
        """Sleep for a random fraction of a second and return the duration."""
        sleep_length = random.random()
        try:
            print(f"Starting: {sleep_length}")
            await trio.sleep(sleep_length)
            print(f"Finished: {sleep_length}")
            return sleep_length
        finally:
            # Printed even if the task is cancelled, showing which tasks ended.
            print(f"Done: {sleep_length}")

    async def main(count):
        """Run *count* random sleeps, at most 3 at a time, printing each result.

        When the last of the initial results arrives, one extra task is started
        to show that work can be added while the results are being iterated.
        """
        async with StreamResultsNursery(max_running_tasks=3) as nursery:
            for task_index in range(count):
                print(f"Starting task {task_index}")
                nursery.start_soon(rand)
            received = 0
            async for result in nursery:
                received += 1
                print(f"Got {received}: {result}\n")
                if received == count:
                    print("starting extra task")
                    nursery.start_soon(rand)

    trio.run(main, 10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment