Skip to content

Instantly share code, notes, and snippets.

@wyfo
Last active October 22, 2024 16:12
Show Gist options
  • Save wyfo/c3c8d27e0eb3feb3c3ca9583be8369db to your computer and use it in GitHub Desktop.
Save wyfo/c3c8d27e0eb3feb3c3ca9583be8369db to your computer and use it in GitHub Desktop.
Python asyncio select implementation
import asyncio
import collections.abc
import inspect
import random
import typing
@typing.overload
def select[T](
awaitables: collections.abc.Sequence[collections.abc.Awaitable[T]],
*,
ordered: bool = False,
) -> collections.abc.Awaitable[tuple[T, int]]:
...
@typing.overload
def select[T](
awaitables: collections.abc.Sequence[collections.abc.Awaitable[T]],
*,
ordered: bool = False,
sync: typing.Literal[True],
) -> tuple[T, int] | None:
...
def select[T](
awaitables: collections.abc.Sequence[collections.abc.Awaitable[T]],
*,
ordered: bool = False,
sync: bool = False,
) -> collections.abc.Awaitable[tuple[T, int]] | tuple[T, int] | None:
count = len(awaitables)
# fast path for simple case
if sync and count == 1:
try:
next(_future_iter(awaitables[0]))
except StopIteration as err:
return err.value, 0
else:
return None
ordered = ordered or count < 2
future_iters = [_future_iter(a) for a in awaitables]
futures: list[asyncio.Future | None] = [None] * count
for i in range(count) if ordered else random.sample(range(count), count):
try:
futures[i] = next(future_iters[i])
except StopIteration as err:
# close non-awaited coroutine to avoid warning
for a, f in zip(awaitables, futures):
if f is None and asyncio.iscoroutine(a):
a.close()
return err.value, i
return None if sync else select_loop(future_iters, futures, ordered) # type: ignore
def _future_iter[T](
awaitable: collections.abc.Awaitable[T],
) -> collections.abc.Iterator[asyncio.Future[T]]:
try:
__await__ = awaitable.__await__
except AttributeError:
if inspect.isawaitable(awaitable):
# coroutine decorated with `@types.coroutine` don't have `__await__` method
__await__ = awaitable.__iter__ # type: ignore
else:
error = f"expected a sequence of awaitables, found an item of type {type(awaitable).__name__}"
raise TypeError(error)
return __await__()
async def select_loop[T](
future_iters: list[collections.abc.Iterator[asyncio.Future[T]]],
futures: list[asyncio.Future],
ordered: bool,
) -> tuple[T, int]:
# `set_result` captures `future` **by reference**, so there is no need to add
# a new callback if `future` is replaced by a new one
def set_result(_):
if not future.done():
future.set_result(None)
future: asyncio.Future = asyncio.Future()
for fut in futures:
fut.add_done_callback(set_result)
count = len(future_iters)
while True:
await future
for i in range(count) if ordered else random.sample(range(count), count):
if futures[i].done():
try:
fut = next(future_iters[i])
except StopIteration as err:
return err.value, i
else:
fut.add_done_callback(set_result)
futures[i] = fut
future = asyncio.Future()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment