Created
February 18, 2025 23:13
-
-
Save krisselden/80925b75328120f1915d28a884a689a4 to your computer and use it in GitHub Desktop.
Structured concurrency
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections.abc import Generator, Collection, Iterable | |
from contextlib import contextmanager | |
from typing import Any, Awaitable, Callable, ContextManager, Mapping, TypeVar | |
from anyio import ( | |
create_memory_object_stream, | |
to_thread, | |
from_thread, | |
Semaphore, | |
create_task_group, | |
) | |
from anyio.streams.memory import MemoryObjectSendStream | |
# Type variable for the items handed to `process`.
InputType = TypeVar("InputType")
# Type variable for each processed result; constrained to a string-keyed mapping.
OutputType = TypeVar("OutputType", bound=Mapping[str, Any])
@contextmanager
def simple_progress(expected: int) -> Generator[Iterable[int], None, None]:
    """Minimal no-op progress reporter.

    Yields an iterable producing one index per expected item (0 .. expected-1),
    so callers can drive a loop once per anticipated result without any actual
    progress display.
    """
    indices: Iterable[int] = range(expected)
    yield indices
async def process_items_to_output(
    items: Collection[InputType],
    process: Callable[[InputType], Awaitable[OutputType]],
    output: Callable[[Generator[OutputType, None, None]], Collection[OutputType]],
    progress: Callable[[int], ContextManager[Iterable[int]]] = simple_progress,
    concurrency: int = 50,
    name: object = None,
) -> Collection[OutputType]:
    """
    Concurrently process items as a task group while streaming the results as they
    are finished to be synchronously output in another thread.

    This allows you to stream to a file using blocking file I/O operations in the
    output function while still processing items concurrently.

    Parameters:
        items: the inputs to process; ``len(items)`` drives the progress bar,
            so the number of expected results equals the number of items.
        process: async callable applied to each item; at most ``concurrency``
            calls run at once (gated by a semaphore).
        output: synchronous consumer run in a worker thread; it receives a
            generator that yields results in *completion* order (not input
            order) and returns the final collection.
        progress: factory returning a context manager that yields an iterable
            of ``len(items)`` steps; defaults to ``simple_progress`` (no-op).
        concurrency: maximum number of in-flight ``process`` calls.
        name: passed through to ``tg.start_soon`` as the task name.

    Returns:
        Whatever ``output`` returns.

    Raises:
        Whatever the task group raises if any ``process`` call fails
        (the worker thread's receive then ends early instead of hanging —
        see the comment at the ``start_soon`` call below).
    """
    # Unbuffered memory stream: each send rendezvouses with a receive,
    # so results flow to the output thread as soon as they are consumed.
    send_stream, receive_stream = create_memory_object_stream[OutputType]()
    # Caps the number of `process` coroutines running at the same time.
    semaphore = Semaphore(concurrency)
    async def process_item(
        item: InputType, send_stream: MemoryObjectSendStream
    ) -> None:
        # Each task owns its own stream clone; `async with` closes it even if
        # `process` raises, which is what lets the receiving side unblock.
        async with semaphore, send_stream:
            result = await process(item)
            await send_stream.send(result)
    def process_output_stream() -> Collection[OutputType]:
        # Runs in a worker thread (via to_thread.run_sync), so blocking I/O
        # inside `output` does not stall the event loop.
        def generator() -> Generator[OutputType, None, None]:
            with progress(len(items)) as bar:
                for _ in bar:
                    # receive result on the event loop
                    result = from_thread.run(receive_stream.receive)
                    yield result
        return output(generator())
    async with create_task_group() as tg:
        async with send_stream:
            for item in items:
                tg.start_soon(
                    process_item,
                    item,
                    # since we call start_soon
                    # we need to clone the stream for each task
                    # so it can close independently
                    # otherwise, the stream will close when we exit
                    # the context manager. If we stayed in the context
                    # manager to await receive_stream.receive()
                    # we would hang indefinitely on errors since we would
                    # never get to the expected number of results
                    # never exiting the task group context manager
                    # to get the errors from the failed tasks
                    send_stream.clone(),
                    name=name,
                )
        # Closing receive_stream on exit ensures the stream is torn down even
        # if the output thread stops consuming early.
        async with receive_stream:
            return await to_thread.run_sync(process_output_stream)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment