Last active
October 22, 2024 16:12
-
-
Save wyfo/c3c8d27e0eb3feb3c3ca9583be8369db to your computer and use it in GitHub Desktop.
Python asyncio select implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import collections.abc | |
import inspect | |
import random | |
import typing | |
@typing.overload | |
def select[T]( | |
awaitables: collections.abc.Sequence[collections.abc.Awaitable[T]], | |
*, | |
ordered: bool = False, | |
) -> collections.abc.Awaitable[tuple[T, int]]: | |
... | |
@typing.overload | |
def select[T]( | |
awaitables: collections.abc.Sequence[collections.abc.Awaitable[T]], | |
*, | |
ordered: bool = False, | |
sync: typing.Literal[True], | |
) -> tuple[T, int] | None: | |
... | |
def select[T]( | |
awaitables: collections.abc.Sequence[collections.abc.Awaitable[T]], | |
*, | |
ordered: bool = False, | |
sync: bool = False, | |
) -> collections.abc.Awaitable[tuple[T, int]] | tuple[T, int] | None: | |
count = len(awaitables) | |
# fast path for simple case | |
if sync and count == 1: | |
try: | |
next(_future_iter(awaitables[0])) | |
except StopIteration as err: | |
return err.value, 0 | |
else: | |
return None | |
ordered = ordered or count < 2 | |
future_iters = [_future_iter(a) for a in awaitables] | |
futures: list[asyncio.Future | None] = [None] * count | |
for i in range(count) if ordered else random.sample(range(count), count): | |
try: | |
futures[i] = next(future_iters[i]) | |
except StopIteration as err: | |
# close non-awaited coroutine to avoid warning | |
for a, f in zip(awaitables, futures): | |
if f is None and asyncio.iscoroutine(a): | |
a.close() | |
return err.value, i | |
return None if sync else select_loop(future_iters, futures, ordered) # type: ignore | |
def _future_iter[T]( | |
awaitable: collections.abc.Awaitable[T], | |
) -> collections.abc.Iterator[asyncio.Future[T]]: | |
try: | |
__await__ = awaitable.__await__ | |
except AttributeError: | |
if inspect.isawaitable(awaitable): | |
# coroutine decorated with `@types.coroutine` don't have `__await__` method | |
__await__ = awaitable.__iter__ # type: ignore | |
else: | |
error = f"expected a sequence of awaitables, found an item of type {type(awaitable).__name__}" | |
raise TypeError(error) | |
return __await__() | |
async def select_loop[T]( | |
future_iters: list[collections.abc.Iterator[asyncio.Future[T]]], | |
futures: list[asyncio.Future], | |
ordered: bool, | |
) -> tuple[T, int]: | |
# `set_result` captures `future` **by reference**, so there is no need to add | |
# a new callback if `future` is replaced by a new one | |
def set_result(_): | |
if not future.done(): | |
future.set_result(None) | |
future: asyncio.Future = asyncio.Future() | |
for fut in futures: | |
fut.add_done_callback(set_result) | |
count = len(future_iters) | |
while True: | |
await future | |
for i in range(count) if ordered else random.sample(range(count), count): | |
if futures[i].done(): | |
try: | |
fut = next(future_iters[i]) | |
except StopIteration as err: | |
return err.value, i | |
else: | |
fut.add_done_callback(set_result) | |
futures[i] = fut | |
future = asyncio.Future() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment