Last active
October 2, 2018 23:53
-
-
Save njsmith/7ea44ec07e901cb78ebe1dd8dd846cb9 to your computer and use it in GitHub Desktop.
Rate-limited async iterator decorator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TokenBucket:
    """Token-bucket rate limiter that uses trio's clock and sleep.

    Tokens accrue continuously at ``tokens_per_second`` and are capped at
    ``max_tokens``, which therefore bounds the largest possible burst.
    Each call to :meth:`take_token` consumes one token, sleeping until one
    is available if necessary.

    Args:
        tokens_per_second: Refill rate, in tokens per second.
        max_tokens: Bucket capacity (maximum burst size). Must be >= 1,
            otherwise a whole token could never accumulate.
        starting_tokens: Number of tokens in the bucket at creation time.

    Raises:
        ValueError: If ``max_tokens`` is less than 1.
    """

    def __init__(self, tokens_per_second, max_tokens, starting_tokens):
        if max_tokens < 1:
            # With the cap enforced, a capacity below one token would make
            # take_token() wait forever; fail loudly instead.
            raise ValueError("max_tokens must be at least 1")
        self._tokens_per_second = tokens_per_second
        self._max_tokens = max_tokens
        # floating point, because tracking partial tokens makes the code
        # simpler.
        self._tokens = starting_tokens
        self._last_update_time = trio.current_time()

    def _update(self):
        """Credit the tokens earned since the last update, up to the cap."""
        now = trio.current_time()
        elapsed = now - self._last_update_time
        # Clamp to the bucket capacity so an idle period cannot bank an
        # unbounded burst.
        self._tokens = min(
            self._max_tokens,
            self._tokens + elapsed * self._tokens_per_second,
        )
        self._last_update_time = now

    async def take_token(self):
        """Consume one token, sleeping until one is available."""
        # The loop lets us handle multiple simultaneous calls to take_token.
        # This class isn't really designed to handle this -- it's a bit
        # inefficient in this case -- but at least it doesn't break.
        while True:
            self._update()
            if self._tokens >= 1:
                self._tokens -= 1
                return
            # Time until the fractional balance reaches one whole token.
            next_token_after = (1 - self._tokens) / self._tokens_per_second
            await trio.sleep(next_token_after)
async def rate_limit(aiterable, tokens_per_second, max_tokens, starting_tokens):
    """Re-yield the items of an async iterable, throttled by a token bucket.

    This is an async generator: consume it with ``async for``. One token is
    taken from a fresh ``TokenBucket`` before each item is yielded.

    Args:
        aiterable: The async iterable to wrap.
        tokens_per_second: Refill rate passed to ``TokenBucket``.
        max_tokens: Bucket capacity passed to ``TokenBucket``.
        starting_tokens: Initial token count passed to ``TokenBucket``.

    Yields:
        Each value produced by ``aiterable``, in order.
    """
    token_bucket = TokenBucket(tokens_per_second, max_tokens, starting_tokens)
    async for value in aiterable:
        # Wait for a token *after* fetching the item, so the limit applies
        # to how fast items are handed to the consumer.
        await token_bucket.take_token()
        yield value
# Usage (inside an async function, since ``async for`` is not allowed at
# module level):
#
# max 10 requests/second on average, bursting up to 5 at once, and no burst
# at the beginning
#
#     async for url in rate_limit(urls_aiter, 10, 5, 1):
#         ...
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment