Created
February 5, 2021 00:09
-
-
Save StephenBrown2/c59463a9d8903f8e2daa0ee1d1debcc3 to your computer and use it in GitHub Desktop.
httpx ThrottleTransport
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import datetime | |
import httpx | |
class ThrottleTransport(httpx.HTTPTransport):
    """
    A custom httpx transport that adds some basic rate-throttling functionality.

    Before each request the transport checks how many requests were started
    within the last ``per_duration``. If that count has reached
    ``num_requests``, the calling thread sleeps in short increments until
    enough old entries have aged out of the window.

    The request history is stored in a class-level registry keyed by
    ``(num_requests, per_duration)``, so every instance created with the same
    settings shares one history list and therefore one throttle budget.
    No locking is performed around that shared list.

    Args:
        num_requests (int): The number of requests that will be allowed before throttling.
            Must be at least 1.
        per_duration (datetime.timedelta): The time period to allow requests in before throttling.
        kwargs: You can give any other args for httpx.HTTPTransport, which will be passed through.

    Raises:
        ValueError: If ``num_requests`` is less than 1 (no request could ever
            be admitted, so ``request`` would block forever).

    Example:
        Create a ThrottleTransport and pass it to an httpx.Client for use. For example,
        to allow only 20 requests per 10 seconds::

            transport = ThrottleTransport(20, datetime.timedelta(seconds=10))
            client = httpx.Client(transport=transport)

    NOTE(review): newer httpx releases appear to dispatch through
    ``handle_request`` rather than ``request`` on transports -- confirm
    against the installed httpx version that this override is actually called.
    """

    # Registry of shared history lists, keyed by (num_requests, per_duration).
    __queue = {}

    def __init__(self, num_requests: int, per_duration: datetime.timedelta, **kwargs) -> None:
        if num_requests < 1:
            raise ValueError("num_requests must be at least 1")
        # setdefault returns the already-registered list for these settings,
        # which is what makes the throttle shared between instances.
        self.history = ThrottleTransport.__queue.setdefault((num_requests, per_duration), [])
        self.max_in_history = num_requests
        self.cutoff = per_duration
        super().__init__(**kwargs)

    def request(self, *args, **kwargs):
        """Wait until the rate limit allows it, then delegate to the parent transport.

        All positional and keyword arguments are forwarded unchanged to
        ``super().request`` and its return value is returned as-is.
        """
        while True:
            now = datetime.datetime.now()
            expiry = now - self.cutoff
            # Expire old entries IN PLACE (slice assignment). Rebinding
            # self.history to a new list would detach this instance from the
            # list held in the class-level registry and silently stop sharing
            # the throttle with other instances.
            self.history[:] = [timestamp for timestamp in self.history if timestamp > expiry]
            if len(self.history) < self.max_in_history:
                break
            # Still at the limit after pruning: sleep for a bit and re-check.
            time.sleep(0.1)
        self.history.append(now)
        return super().request(*args, **kwargs)
# Demo: issue 100 GET requests through a client limited to 8 requests per
# 15 seconds, printing a line before and after each request.
demo_transport = ThrottleTransport(8, datetime.timedelta(seconds=15))
demo_url = "https://httpbin.org/get"
demo_request_count = 100

with httpx.Client(transport=demo_transport) as client:
    for i in range(demo_request_count):
        print(f"Getting request {i+1}")
        # The response is kept in `r` but is not inspected here.
        r = client.get(demo_url)
        print(f"Got request {i+1}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment