Skip to content

Instantly share code, notes, and snippets.

@luzfcb
Last active July 8, 2021 22:02
Show Gist options
  • Select an option

  • Save luzfcb/bdfa294261c17c395cea9c14beb2c8ff to your computer and use it in GitHub Desktop.

Select an option

Save luzfcb/bdfa294261c17c395cea9c14beb2c8ff to your computer and use it in GitHub Desktop.
Rate limit control using requests, Celery and Python.
import functools
import inspect
import random
#
from celery.utils.log import get_task_logger
# https://github.com/jazzband/django-redis#raw-client-access
from django_redis import get_redis_connection
# Module-level logger obtained from Celery's get_task_logger(), used by the
# throttling helpers below.
logger = get_task_logger(__name__)
# Based on https://stackoverflow.com/a/66161773/2975300
# and https://gist.github.com/Vigrond/2bbea9be6413415e5479998e79a1b11a
def parse_rate(rate):
    """
    Given the request rate string, return a two tuple of:
    <allowed number of requests>, <period of time in seconds>

    (Adapted from Django Rest Framework.)

    Accepted forms are ``"<int>/<unit>"`` and ``"<int>/<int><unit>"`` where
    ``<unit>`` is one of ``s``, ``m``, ``h`` or ``d`` (case-insensitive),
    e.g. ``"100/s"``, ``"10/2h"`` or ``"90/5d"``.

    :param rate: str
    :rtype: tuple[int, int]
    :raises ValueError: if ``rate`` is not in one of the accepted forms, uses an
        unknown unit, or describes a period that is not positive.
    """
    error_message = (
        'rate argument should be in '
        'the format "<int>/<char>" or <int>/<int><char> '
        'like 100/s or 90/5d'
    )
    duration_bases = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}

    num, separator, period = rate.strip().lower().partition('/')
    period = period.strip()
    if not separator or not period:
        # Covers both "100" (no slash) and "100/" (no period at all).
        raise ValueError(f'{error_message}; got {rate!r}')

    # The last character is the unit; anything before it is an optional
    # multiplier, as in "5d" or "10s".
    duration_unit = period[-1]
    multiplier_text = period[:-1]
    try:
        num_requests = int(num)
        duration_multiplier = int(multiplier_text) if multiplier_text else 1
        duration_base = duration_bases[duration_unit]
    except (ValueError, KeyError) as err:
        raise ValueError(f'{error_message}; got {rate!r}') from err

    if duration_multiplier <= 0:
        # A zero/negative window would become an invalid Redis TTL later on.
        raise ValueError(f'rate period must be positive; got {rate!r}')

    return num_requests, duration_base * duration_multiplier
def throttle_task(
    rate,
    jitter=(1, 10),
    key=None,
    redis_connection_name='default'
):
    """A decorator for throttling tasks to a given rate.

    Usage::

        @celery_app.task(bind=True, max_retries=5)
        @throttle_task('4/s', key='url', jitter=(1, 10))
        def foobar_task(self, url):
            do_stuff(url)

    The result is that you'll throttle the task to 4 runs per second
    per url parameter, with a random retry jitter between 1-10s.

    The key parameter is optional, but corresponds to a parameter
    in your task. If the key parameter is not given, it'll just
    throttle the task to the rate given. If it is provided,
    then the throttle will apply to the (task, key) dyad. The keyed
    parameter may have a default value; the default is used when the
    caller omits it.

    :param rate: The maximum rate that you want your task to run. Takes the
        form of '1/m', or '10/2h' or similar.
    :type rate: str
    :param jitter: A tuple of the range of backoff times you want for throttled
        tasks. If the task is throttled, it will wait a random amount of time
        between these values before being tried again.
    :type jitter: tuple[float, float]
    :param key: The name of an argument whose value should be used as part of
        the throttle key in redis. This allows you to create per-argument
        throttles by simply passing the name of the argument you wish to key on.
    :type key: str
    :param redis_connection_name: Name of the django-redis connection, as passed
        to ``get_redis_connection``, that stores the throttle counters.
    :type redis_connection_name: str
    :return: The decorator
    :rtype: Callable
    :raises KeyError: when the decorated task is called and ``key`` does not
        name a parameter of the task.
    """
    def decorator_func(func):
        """Wrap ``func`` so every call is checked against the throttle first.

        :param func: A bound task function whose signature has a ``self``
            parameter (the task instance).
        :type func: Callable
        :return: The wrapped function.
        :rtype: Callable
        """
        # The signature is fixed once the function exists, so inspect it once
        # here rather than on every task invocation.
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Bind the call's arguments to parameter names to get the task
            # itself and the value of the parameter referenced by key.
            bound_args = sig.bind(*args, **kwargs)
            # Without this, a keyed parameter with a default value that the
            # caller omitted would be reported as an "Unknown parameter".
            bound_args.apply_defaults()
            task = bound_args.arguments['self']
            key_value = None
            if key:
                try:
                    key_value = bound_args.arguments[key]
                except KeyError:
                    raise KeyError(
                        f"Unknown parameter '{key}' in throttle_task "
                        f'decorator of function {task.name}. '
                        f'`key` parameter must match a parameter '
                        f"name from function signature: '{sig}'"
                    ) from None
            proceed = is_rate_okay(
                task,
                rate,
                key=key_value,
                redis_connection_name=redis_connection_name,
            )
            if not proceed:
                logger.info(
                    'Throttling task %s (%s) via decorator.',
                    task.name,
                    task.request.id,
                )
                # Decrement the number of times the task has retried. If you
                # fail to do this, it gets auto-incremented, and you'll expend
                # retries during the backoff.
                task.request.retries = task.request.retries - 1
                return task.retry(countdown=random.uniform(*jitter))
            # All set. Run the task.
            return func(*args, **kwargs)
        return wrapper
    return decorator_func
def is_rate_okay(task, rate='1/s', key=None, redis_connection_name='default'):
    """Keep a global throttle for tasks.

    Can be used via the `throttle_task` decorator above.

    This is a fixed-window counter, one of the approaches discussed here:
    https://www.figma.com/blog/an-alternative-approach-to-rate-limiting/
    Basically, you keep track of the number of requests and use the key
    expiration as a reset of the counter.

    So you have a rate of 5/m, and your first task comes in. You create a key:
        celery_throttle:task_name = 1
        celery_throttle:task_name.expires = 60
    Another task comes in a few seconds later:
        celery_throttle:task_name = 2
        Do not update the ttl, it now has 58s remaining
    And so forth, until:
        celery_throttle:task_name = 6
        (10s remaining)
    We're over the threshold. Re-queue the task for later. 10s later:
        Key expires b/c no more ttl.
    Another task comes in:
        celery_throttle:task_name = 1
        celery_throttle:task_name.expires = 60
    And so forth.

    The counter is created (with its TTL) and incremented inside a single
    MULTI/EXEC transaction, so concurrent workers cannot reset each other's
    window and the key can never be left without an expiry.

    :param task: The task that is being checked
    :type task: Task
    :param rate: How many times the task can be run during the time period.
        Something like, 1/s, 2/h or similar.
    :type rate: str
    :param key: If not None, add this to the key placed in Redis for the item.
        Typically, this will correspond to the value of an argument passed to
        the throttled task.
    :type key: Any
    :param redis_connection_name: Name of the django-redis connection to use.
    :type redis_connection_name: str
    :return: True if the task may run now, False if it should be throttled.
    :rtype: bool
    """
    num_tasks, duration = parse_rate(rate)

    redis_key = f'celery_throttle:{task.name}'
    # Compare against None so that falsy values such as 0 or '' still get
    # their own bucket instead of sharing the task-wide one.
    if key is not None:
        redis_key = f'{redis_key}:{key!s}'

    r = get_redis_connection(redis_connection_name)
    with r.pipeline() as pipe:
        # NX: only start a new window if none exists; EX: attach the TTL in the
        # same command so there is no gap where the key could live forever.
        pipe.set(redis_key, 0, ex=duration, nx=True)
        pipe.incr(redis_key)
        _, count = pipe.execute()

    # Increment-then-compare admits exactly num_tasks runs per window.
    return count <= num_tasks
import random
# https://django-constance.readthedocs.io/en/latest/
from constance import config as django_constance_config
from celery.utils.log import get_task_logger
from myapp import celery_app
from .celery_extras import is_rate_okay, throttle_task
from .crawlers import extract_content_from_url
# Module-level logger obtained from Celery's get_task_logger(), used by the
# tasks below.
logger = get_task_logger(__name__)
# use is_rate_okay
@celery_app.task(bind=True, max_retries=8, default_retry_delay=3.8)
def task_extract_content_from_url(self, url):
    """Extract content from ``url`` under a global per-minute rate limit.

    The limit is read from the django-constance setting
    ``EXTRACT_CONTENT_RATE_LIMIT`` each time the task runs and checked with
    ``is_rate_okay``. A throttled run is re-queued after a random 0.5-3s delay
    without consuming one of ``max_retries``. A failure in
    ``extract_content_from_url`` is retried after the same kind of delay and
    does count toward ``max_retries``.

    :param url: The URL to extract content from.
    """
    rate_limit_per_minute = django_constance_config.EXTRACT_CONTENT_RATE_LIMIT
    rate = f'{rate_limit_per_minute}/m'
    # Random back-off so throttled/failed tasks don't all come back at once.
    countdown = random.uniform(0.5, 3.0)

    if not is_rate_okay(task=self, rate=rate):
        logger.info(
            'task_extract_content_from_url: '
            'The global rate limit (%s) has been reached. '
            'In %.1f seconds, the task will try to query again: %s',
            rate,
            countdown,
            url,
        )
        # Don't count this as against max_retries. This essentially causes the
        # task to try to run indefinitely until is_rate_okay returns True.
        self.request.retries = self.request.retries - 1
        return self.retry(countdown=countdown)

    try:
        extract_content_from_url(url)
    except Exception as exc:
        # %r works even for exceptions raised without any args.
        logger.info(
            'task_extract_content_from_url: '
            'An exception has occurred (%r). '
            'In %.1f seconds, the task will try to query again: %s',
            exc,
            countdown,
            url,
        )
        raise self.retry(exc=exc, countdown=countdown)
@celery_app.task(bind=True, max_retries=5)
@throttle_task('100/m', jitter=(0.5, 3.0))
def task_extract_content_from_url_v2(self, url):
    """Extract content from ``url``, throttled to 100 runs per minute.

    Throttling is handled by the ``throttle_task`` decorator. A failure in
    ``extract_content_from_url`` is retried after a random 0.5-3s delay and
    counts toward ``max_retries``.

    :param url: The URL to extract content from.
    """
    try:
        extract_content_from_url(url)
    except Exception as exc:
        countdown = random.uniform(0.5, 3.0)
        # %r works even for exceptions raised without any args.
        logger.info(
            'task_extract_content_from_url_v2: '
            'An exception has occurred (%r). '
            'In %.1f seconds, the task will try to query again: %s',
            exc,
            countdown,
            url,
        )
        raise self.retry(exc=exc, countdown=countdown)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.