Last active
July 8, 2021 22:02
-
-
Save luzfcb/bdfa294261c17c395cea9c14beb2c8ff to your computer and use it in GitHub Desktop.
Rate limit control using requests, Celery and Python.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import functools | |
| import inspect | |
| import random | |
| # | |
| from celery.utils.log import get_task_logger | |
| # https://github.com/jazzband/django-redis#raw-client-access | |
| from django_redis import get_redis_connection | |
| logger = get_task_logger(__name__) | |
| # Based on https://stackoverflow.com/a/66161773/2975300 | |
| # and https://gist.github.com/Vigrond/2bbea9be6413415e5479998e79a1b11a | |
def parse_rate(rate):
    """Split a throttle rate string into ``(num_requests, seconds)``.

    Accepted forms are ``"<int>/<unit>"`` (e.g. ``"100/s"``) and
    ``"<int>/<int><unit>"`` (e.g. ``"90/5d"``), where ``<unit>`` is one of
    ``s``, ``m``, ``h`` or ``d`` (case-insensitive).
    (Adapted from Django REST Framework.)

    :param rate: the rate string to parse
    :type rate: str
    :return: allowed number of requests and the window length in seconds
    :rtype: tuple[int, int]
    :raises ValueError: if ``rate`` contains no ``/`` or a number is invalid
    :raises KeyError: if the unit letter is not one of s/m/h/d
    """
    if '/' not in rate:
        raise ValueError(
            'rate argument should be in '
            'the format "<int>/<char>" or <int>/<int><char> '
            'like 100/s or 90/5d'
        )

    count_text, window = rate.lower().split('/')
    allowed = int(count_text)

    # "5d" carries an explicit multiplier; a bare "d" means one unit.
    multiplier = int(window[:-1]) if len(window) > 1 else 1
    unit_seconds = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
    seconds = unit_seconds[window[-1]] * multiplier

    return allowed, seconds
def throttle_task(
    rate,
    jitter=(1, 10),
    key=None,
    redis_connection_name='default'
):
    """A decorator for throttling Celery tasks to a given rate.

    Usage::

        @celery_app.task(bind=True, max_retries=5)
        @throttle_task('4/s', key='url', jitter=(1, 10))
        def foobar_task(self, url):
            do_stuff(url)

    The result is that the task is throttled to 4 runs per second per ``url``
    value, with a random retry delay between 1 and 10 seconds.

    The decorated function must be a bound task whose first parameter is
    named ``self``. Without ``key`` the throttle applies to the task as a
    whole. With ``key`` it applies to the (task, value of that argument)
    pair. A throttled run is re-queued via ``task.retry`` and does not
    consume one of the task's retries.

    :param rate: The maximum rate that you want your task to run. Takes the
        form of '1/m', or '10/2h' or similar.
    :type rate: str
    :param jitter: Range ``(low, high)`` of the random delay, in seconds,
        used when a throttled task is re-queued.
    :type jitter: tuple[float, float]
    :param key: Name of a parameter of the decorated function whose value is
        used as part of the throttle key in Redis. This gives per-argument
        throttles. The parameter may have a default value.
    :type key: str | None
    :param redis_connection_name: django-redis connection alias passed on to
        ``is_rate_okay``.
    :type redis_connection_name: str
    :return: The decorator.
    :rtype: Callable
    :raises KeyError: (when the task runs) if ``key`` does not name a
        parameter of the decorated function.
    """
    def decorator_func(func):
        """Wrap ``func`` so that every run is checked against the throttle."""
        # The signature never changes, so inspect it once here instead of on
        # every task invocation.
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound_args = sig.bind(*args, **kwargs)
            # bind() only records explicitly passed arguments. Fill in
            # defaults so `key` can name a parameter the caller omitted.
            bound_args.apply_defaults()
            task = bound_args.arguments['self']

            key_value = None
            if key:
                try:
                    key_value = bound_args.arguments[key]
                except KeyError:
                    raise KeyError(
                        f"Unknown parameter '{key}' in throttle_task "
                        f'decorator of function {task.name}. '
                        f'`key` parameter must match a parameter '
                        f"name from function signature: '{sig}'"
                    ) from None

            if is_rate_okay(task, rate, key=key_value, redis_connection_name=redis_connection_name):
                # All set. Run the task.
                return func(*args, **kwargs)

            logger.info(
                'Throttling task %s (%s) via decorator.',
                task.name,
                task.request.id,
            )
            # The retry count gets auto-incremented on retry. Decrement it
            # first so that waiting for the throttle doesn't use up
            # max_retries.
            task.request.retries -= 1
            raise task.retry(countdown=random.uniform(*jitter))

        return wrapper
    return decorator_func
def is_rate_okay(task, rate='1/s', key=None, redis_connection_name='default'):
    """Count one run of ``task`` and report whether it fits a Redis rate limit.

    Can be used directly or via the ``throttle_task`` decorator.

    This is a fixed-window counter (see also
    https://www.figma.com/blog/an-alternative-approach-to-rate-limiting/).
    Every call increments a counter in Redis. The counter's key expires
    ``duration`` seconds after the first call of a window, which resets it.

    Example with a rate of ``5/m``:

    * 1st call: ``celery_throttle:<task name>`` becomes 1 and gets a 60 s
      TTL, so the run is allowed.
    * 2nd-5th calls in the same window: the counter becomes 2..5, so the
      runs are allowed.
    * 6th and later calls: the counter is above 5, so they are throttled.
    * Once the TTL runs out the key disappears and the next call starts a
      new window.

    :param task: The task being checked. Its ``name`` is part of the key.
    :type task: Task
    :param rate: How many times the task may run per period, e.g. ``1/s``,
        ``2/h`` or ``10/5m`` (parsed by ``parse_rate``).
    :type rate: str
    :param key: Optional extra discriminator appended to the Redis key.
        Typically this is the value of one of the task's arguments. ``None``
        means one throttle shared by all invocations of the task.
    :param redis_connection_name: django-redis connection alias to use.
    :type redis_connection_name: str
    :return: ``True`` if the task may run now, ``False`` if it should be
        throttled.
    :rtype: bool
    """
    redis_key = f'celery_throttle:{task.name}'
    if key is not None:
        # Compare with None rather than truthiness, so values such as 0 or
        # False get their own bucket instead of sharing the global one.
        redis_key += ':' + str(key)

    r = get_redis_connection(redis_connection_name)
    num_tasks, duration = parse_rate(rate)

    # Send INCR and TTL in one MULTI/EXEC transaction (redis-py pipelines are
    # transactional by default). Concurrent workers therefore each see a
    # distinct count. A separate GET followed by INCR could also recreate an
    # expired key without any TTL, which would throttle the task forever.
    with r.pipeline() as pipe:
        pipe.incr(redis_key)
        pipe.ttl(redis_key)
        count, ttl = pipe.execute()

    if ttl is None or ttl < 0:
        # The key has no expiry: either INCR just created it (first call of
        # a window) or it lost its TTL earlier. Start the window now.
        # NOTE(review): current redis-py reports "no expiry" as -1; the None
        # check covers older clients that may return None instead.
        r.expire(redis_key, duration)

    # The counter already includes this call, so this allows exactly
    # `num_tasks` runs per window.
    return count <= num_tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import random | |
| # https://django-constance.readthedocs.io/en/latest/ | |
| from constance import config as django_constance_config | |
| from celery.utils.log import get_task_logger | |
| from myapp import celery_app | |
| from .celery_extras import is_rate_okay, throttle_task | |
| from .crawlers import extract_content_from_url | |
| logger = get_task_logger(__name__) | |
| # use is_rate_okay | |
| @celery_app.task(bind=True, max_retries=8, default_retry_delay=3.8) | |
| def task_extract_content_from_url(self, url): | |
| rate_limit_per_minute = django_constance_config.EXTRACT_CONTENT_RATE_LIMIT | |
| rate = f'{rate_limit_per_minute}/m' | |
| jitter = random.uniform(0.5, 3.0) | |
| if is_rate_okay(task=self, rate=rate): | |
| try: | |
| extract_content_from_url(url) | |
| except Exception as exc: | |
| logger.info( | |
| f'task_extract_content_from_url: ' | |
| f'An exception has occurred ({exc.args[0]})' | |
| f'In {rate_limit_reset} seconds, the task ' | |
| f'will try to query again: {shipped_orders_resource_url}' | |
| ) | |
| raise self.retry(exc=exc, countdown=jitter) | |
| else: | |
| logger.info( | |
| f'task_extract_content_from_url: ' | |
| f'The global rate limit ({rate}) has been reached. ' | |
| f'In {rate_limit_reset} seconds, the task ' | |
| f'will try to query again: {shipped_orders_resource_url}' | |
| ) | |
| self.request.retries = self.request.retries - 1 # Don't count this as against max_retries. | |
| # This essentially causes the task | |
| # to try to run indefinitely until is_rate_okay return True | |
| return self.retry(countdown=jitter) | |
@celery_app.task(bind=True, max_retries=5)
@throttle_task('100/m', jitter=(0.5, 3.0))
def task_extract_content_from_url_v2(self, url):
    """Extract content from ``url``, throttled to 100 runs per minute.

    Throttling is done by the ``throttle_task`` decorator. When the limit is
    hit, the task is re-queued after a random 0.5-3.0 s countdown without
    using up one of its ``max_retries``.

    Exceptions from ``extract_content_from_url`` are retried with a random
    0.5-3.0 s countdown, up to ``max_retries`` times.
    """
    try:
        extract_content_from_url(url)
    except Exception as exc:
        countdown = random.uniform(0.5, 3.0)
        # %r instead of exc.args[0]: some exceptions carry no args at all.
        logger.info(
            'task_extract_content_from_url_v2: '
            'An exception has occurred (%r). '
            'In %.1f seconds, the task will try to query again: %s',
            exc,
            countdown,
            url,
        )
        raise self.retry(exc=exc, countdown=countdown)
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Another implementation, linked from a comment on this Stack Overflow answer: https://stackoverflow.com/questions/29854102/celery-rate-limit-on-tasks-with-the-same-parameters/66161773#comment120428734_66161773
https://github.com/freelawproject/courtlistener/blob/969d912c1bc85fa26597815fcb02e7e492eb1bba/cl/lib/celery_utils.py#L123