Last active July 25, 2022 09:48
-
-
Save valsteen/ea51e3259e65295890bed6813161bbf4 to your computer and use it in GitHub Desktop.
How to limit concurrency with Python asyncio?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from typing import Awaitable, Callable, Coroutine, Iterator | |
from asyncio_pool import AioPool | |
import pytest as pytest | |
from more_itertools import peekable | |
""" | |
Different approaches to "How to limit concurrency with Python asyncio?" | |
https://stackoverflow.com/questions/48483348/how-to-limit-concurrency-with-python-asyncio/48484593#48484593 | |
test_gather_with_concurrency demonstrates if those methods work as expected | |
Problem statement: | |
Define a function whose signature is: | |
``` | |
async def gather_with_concurrency( | |
concurrency: int, coroutines: Iterator[Coroutine] | |
): | |
``` | |
and fulfils those invariants: | |
- when returning from the function, all coroutines are completed | |
- the coroutines are executed concurrently with a maximum concurrency of `concurrency` | |
- a slower coroutine does not prevent coroutines that follow to be scheduled, as long as maximum concurrency | |
is not reached | |
""" | |
async def gather_with_concurrency_adam(
    concurrency: int, coroutines: Iterator[Coroutine]
):
    """Schedule *coroutines* lazily, keeping at most *concurrency* running at once.

    The iterator is consumed one item at a time: a new task is only created
    once a running one has finished. Returns after every task has completed;
    exceptions raised by the coroutines propagate through ``asyncio.gather``.
    """
    semaphore = asyncio.Semaphore(concurrency)
    tasks = []
    for coroutine in coroutines:
        # Take a slot *before* scheduling and hand it back only when the task
        # finishes. Wrapping just ``create_task`` in ``async with semaphore``
        # would release the slot right after scheduling, so the number of
        # running tasks would never be bounded.
        await semaphore.acquire()
        task = asyncio.create_task(coroutine)
        # Runs on success, failure and cancellation alike.
        task.add_done_callback(lambda _finished: semaphore.release())
        tasks.append(task)
    await asyncio.gather(*tasks)
async def gather_with_concurrency_emin(concurrency: int, coros: Iterator[Coroutine]):
    """Await every coroutine in *coros*, letting at most *concurrency* run at once.

    Each coroutine is wrapped so that it has to hold a slot of a shared
    semaphore while it runs; all wrappers are then handed to
    ``asyncio.gather``, whose result list (in input order) is returned.
    """
    slots = asyncio.Semaphore(concurrency)

    async def run_in_slot(coro: Coroutine):
        async with slots:
            outcome = await coro
        return outcome

    wrapped = [run_in_slot(coro) for coro in coros]
    return await asyncio.gather(*wrapped)
async def gather_with_concurrency_aiopool(concurrency: int, coros: Iterator[Coroutine]):
    """Run *coros* through an ``AioPool`` sized to *concurrency*.

    Adapted from https://stackoverflow.com/a/57381896/34871. Returns ``None``;
    whatever ``pool.map`` produces is discarded.
    """
    pool = AioPool(size=concurrency)
    pending = peekable(coros)
    # A peekable is truthy exactly while it still has items, so an empty input
    # skips the pool entirely. Testing the peeked *value* instead only worked
    # because coroutine objects happen to be truthy.
    if pending:
        # Identity mapper: the pool is handed each coroutine unchanged.
        await pool.map(lambda coro: coro, pending)
async def gather_so_anwser(concurrency: int, coroutines: Iterator[Coroutine]):
    """Run *coroutines* with at most *concurrency* tasks in flight.

    Adapted from https://stackoverflow.com/a/48484593/34871 (part 1). Only
    the in-flight tasks are kept, so the iterator is consumed lazily. Every
    coroutine is run to completion; afterwards, the first exception raised by
    one of them (in completion order) is re-raised instead of being silently
    dropped.
    """
    pending = set()
    first_error = None

    def collect_errors(finished):
        # Retrieving the exception also marks it as handled, which avoids the
        # "Task exception was never retrieved" log message.
        nonlocal first_error
        for task in finished:
            if task.cancelled():
                continue
            error = task.exception()
            if error is not None and first_error is None:
                first_error = error

    for coroutine in coroutines:
        if len(pending) >= concurrency:
            # All slots are busy: wait until at least one task finishes.
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )
            collect_errors(done)
        pending.add(asyncio.create_task(coroutine))
    if pending:
        done, _ = await asyncio.wait(pending)
        collect_errors(done)
    if first_error is not None:
        raise first_error
async def gather_so_anwser_part2(concurrency: int, coros: Iterator[Coroutine]):
    """Drain *coros* through a queue served by *concurrency* worker tasks.

    Adapted from https://stackoverflow.com/a/48484593/34871 (part 2).
    Every coroutine is run to completion even if some of them fail; once the
    queue has been drained, the first exception raised by a coroutine (in the
    order failures occurred) is re-raised.
    """
    queue = asyncio.Queue()
    errors = []

    async def worker():
        while True:
            coro = await queue.get()
            try:
                await coro
            except Exception as exc:
                # Keep the worker alive and surface the error later. Letting
                # it escape would kill the worker before task_done(), and
                # queue.join() would then never return.
                errors.append(exc)
            finally:
                queue.task_done()

    workers = [asyncio.create_task(worker()) for _ in range(concurrency)]
    try:
        for coro in coros:
            await queue.put(coro)
        await queue.join()  # wait for all queued coroutines to be processed
    finally:
        # Always stop the workers, even if producing or joining was interrupted.
        for task in workers:
            task.cancel()
        await asyncio.gather(*workers, return_exceptions=True)
    if errors:
        raise errors[0]
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "concurrency,size", ((1, 1), (10, 1), (10, 101), (10, 0), (10, 10))
)
@pytest.mark.parametrize(
    "method",
    (
        gather_with_concurrency_adam,
        gather_with_concurrency_emin,
        gather_with_concurrency_aiopool,
        gather_so_anwser,
        gather_so_anwser_part2,
    ),
)
async def test_gather_with_concurrency(
    concurrency: int,
    size: int,
    method: Callable[[int, Iterator[Coroutine]], Awaitable[None]],
):
    """Check a gather implementation against the problem-statement invariants.

    - every coroutine runs to completion exactly once before ``method`` returns;
    - the peak number of simultaneously running coroutines equals
      ``min(size, concurrency)``: the limit is respected and actually reached;
    - when real concurrency is possible, completion order differs from
      submission order, i.e. a slow coroutine does not hold back later ones.
    """
    done = []
    pending = set()
    max_concurrency = 0

    async def getter(i):
        nonlocal max_concurrency
        pending.add(i)
        max_concurrency = max(len(pending), max_concurrency)
        # Reverse-order completion, to assess that concurrency is happening:
        # earlier coroutines sleep longer. For i >= 10 the delay is
        # non-positive, so those coroutines finish almost immediately.
        await asyncio.sleep(1 - i / 10.0)
        pending.remove(i)
        done.append(i)

    await method(concurrency, (getter(i) for i in range(size)))

    assert len(pending) == 0
    # An empty `pending` cannot detect a coroutine that never started, so also
    # check that each one completed exactly once.
    assert sorted(done) == list(range(size)), "not every coroutine completed"
    expected_concurrency = min(size, concurrency)
    assert (
        max_concurrency == expected_concurrency
    ), "expected maximum concurrency {}, got {} instead".format(
        expected_concurrency, max_concurrency
    )
    # With a single slot, completion is necessarily sequential, so only check
    # for out-of-order completion when more than one coroutine can run at once.
    if size > 1 and concurrency > 1:
        assert done != sorted(done)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.