Created July 8, 2023 01:52
Save clbarnes/d61311ca49cf6aa57a8c4022a8c1301b to your computer and use it in GitHub Desktop.
httpx.AsyncClient subclass with semaphore-based rate limiting
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import datetime as dt | |
from functools import wraps | |
from typing import Union | |
from httpx import AsyncClient | |
# unless you keep a strong reference to a running task, it can be dropped during execution | |
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task | |
_background_tasks = set() | |
class RateLimitedClient(AsyncClient):
    """httpx.AsyncClient with a rate limit.

    Every request sent through ``send`` first takes a permit from an
    ``asyncio.Semaphore`` that starts with ``count`` permits. Each permit is
    handed back ``interval`` seconds after it was taken, whether the request
    succeeded, failed or is still running. So, give or take event-loop
    scheduling, at most ``count`` requests are started per ``interval``.
    """

    def __init__(
        self, interval: Union[dt.timedelta, float], count: int = 1, **kwargs
    ):
        """
        Parameters
        ----------
        interval : Union[dt.timedelta, float]
            Length of interval.
            If a float is given, seconds are assumed. Must not be negative.
        count : int, optional
            Number of requests which can be sent in any given interval (default 1).
            Must be at least 1.
        **kwargs
            Forwarded unchanged to ``httpx.AsyncClient.__init__``.

        Raises
        ------
        ValueError
            If ``interval`` is negative (or NaN) or ``count`` is less than 1.
        """
        if isinstance(interval, dt.timedelta):
            interval = interval.total_seconds()
        # ``asyncio.sleep`` returns at once for a non-positive delay, so a
        # negative interval would silently disable the limit. Written as
        # ``not >=`` so that NaN is rejected too.
        if not interval >= 0:
            raise ValueError(f"interval must be non-negative, got {interval!r}")
        # A semaphore with zero permits would block every send() forever.
        if count < 1:
            raise ValueError(f"count must be at least 1, got {count!r}")
        self.interval = interval
        self.semaphore = asyncio.Semaphore(count)
        super().__init__(**kwargs)

    def _schedule_semaphore_release(self):
        """Return one permit to ``self.semaphore`` after ``self.interval`` seconds.

        Must be called from inside a running event loop (it uses
        ``asyncio.create_task``).
        """
        wait = asyncio.create_task(asyncio.sleep(self.interval))
        # Keep the sleep task alive until it finishes; the loop itself only
        # holds a weak reference to it.
        _background_tasks.add(wait)

        def wait_cb(task):
            # Done callbacks also fire on cancellation, so the permit is
            # returned even if the sleep is cancelled.
            self.semaphore.release()
            _background_tasks.discard(task)

        wait.add_done_callback(wait_cb)

    @wraps(AsyncClient.send)
    async def send(self, *args, **kwargs):
        # functools.wraps copies AsyncClient.send's name, docstring and
        # annotations onto this method, hence comments rather than a docstring.
        await self.semaphore.acquire()
        # Start the release timer immediately after taking the permit and
        # before anything that can raise. Otherwise an exception from
        # super().send -- including one raised synchronously while building
        # the coroutine, e.g. a TypeError for bad arguments -- would leak the
        # permit permanently and eventually deadlock the client.
        self._schedule_semaphore_release()
        return await super().send(*args, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.