Created
April 23, 2025 14:29
-
-
Save kbhalerao/c11b0d90ced8d49389112abef4e3d408 to your computer and use it in GitHub Desktop.
Some decorators to make Channels Background Worker consumers a bit more resilient to failure
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import inspect
import logging
import time
import traceback
from functools import wraps
from typing import Any, Callable

import redis
from channels.consumer import AsyncConsumer, SyncConsumer
from channels.layers import get_channel_layer
from django.conf import settings
from django.utils.timezone import now
def get_redis_con(conn='default'):
    """
    Build a synchronous Redis client from a Django Channels layer config.

    Only the first entry of ``settings.CHANNEL_LAYERS[conn]['CONFIG']['hosts']``
    is used. Supported entry shapes:

    * URL string, e.g. ``"redis://localhost:6379/0"`` -> ``redis.Redis.from_url``
    * dict with an ``"address"`` URL -> ``redis.Redis.from_url``
    * 2-item ``(host, port)`` tuple or list -> ``redis.Redis(host, port, db=0)``

    :param conn: alias of the channel layer in ``settings.CHANNEL_LAYERS``.
    :return: a ``redis.Redis`` client created with ``decode_responses=True``.
    :raises TypeError: if the host entry has none of the shapes above.
    """
    host_entry = settings.CHANNEL_LAYERS[conn]['CONFIG']['hosts'][0]

    # Dispatch on the entry's type rather than relying on tuple unpacking to
    # raise ValueError: a 2-character string or a 2-key dict would otherwise
    # unpack "successfully" into nonsense host/port values.
    if isinstance(host_entry, str):
        return redis.Redis.from_url(url=host_entry, decode_responses=True)
    if isinstance(host_entry, dict) and 'address' in host_entry:
        return redis.Redis.from_url(url=host_entry['address'], decode_responses=True)
    if isinstance(host_entry, (tuple, list)) and len(host_entry) == 2:
        host, port = host_entry
        return redis.Redis(host=host, port=port, db=0, decode_responses=True)
    raise TypeError(
        f"Unsupported channel layer host entry for {conn!r}: {host_entry!r}; "
        "expected a URL string, a {'address': url} dict or a (host, port) pair"
    )
def get_redis_prefix(conn='default', default='asgi'):
    """
    Return the key prefix configured for a channel layer.

    Lookup order inside ``settings.CHANNEL_LAYERS[conn]['CONFIG']``:
    ``'prefix'``, then ``'CHANNEL_NAME_PREFIX'``, then ``default``.

    :param conn: alias of the channel layer in ``settings.CHANNEL_LAYERS``.
    :param default: value returned when neither key is configured; ``"asgi"``
        matches channels_redis's own default prefix.
    :return: the prefix string.
    """
    config = settings.CHANNEL_LAYERS[conn].get('CONFIG', {})
    # A missing dict key raises KeyError (not ValueError), so check membership
    # explicitly to make the fallback chain actually reachable.
    if 'prefix' in config:
        return config['prefix']
    return config.get('CHANNEL_NAME_PREFIX', default)
# Module-wide logger shared by every decorator in this file.
logger = logging.getLogger(__name__)

# One shared synchronous Redis client, created at import time from the
# 'default' channel layer settings; job_counter writes its counters through it.
redis_connection = get_redis_con(conn='default')
def _handle_task_error(
    func_name: str,
    error: Exception,
    channel_name: str,
    message: dict,
    retries: int,
    max_retries: int,
    delay: int,
) -> bool:
    """
    Log a failed background-task attempt and decide whether to retry it.

    :param func_name: name of the handler that failed (used in log messages).
    :param error: the exception the handler raised.
    :param channel_name: channel the task would be re-sent to (log only).
    :param message: the task's message dict; accepted for interface
        compatibility but not read here.
    :param retries: number of retries already performed for this message.
    :param max_retries: retry budget; once ``retries`` reaches it the task is
        abandoned.
    :param delay: seconds the caller will wait before re-sending (log only).
    :return: ``True`` if the caller should requeue the task, ``False`` if the
        retry budget is exhausted.
    """
    logger.error("Error occurred in function %s: %s", func_name, error)
    if retries >= max_retries:
        # Only the final failure gets the (truncated) traceback, so that
        # intermediate retries stay at one log line each.
        tb_lines = traceback.format_tb(error.__traceback__)
        logger.error("Maximum retries reached. Task failed.")
        logger.error("Traceback (last 10 lines): \n%s", "".join(tb_lines[-10:]))
        return False
    # Announce the requeue only when one is actually going to happen.
    logger.info(
        "Requeuing task on channel layer '%s' again in %s seconds...",
        channel_name,
        delay,
    )
    return True
def _examine_task(func, *args): | |
"""Helper function to extract channel name and message from task.""" | |
assert args[0].scope['type'] == 'channel', "These decorators can only be used on background tasks" | |
channel_name = args[0].scope['channel'] | |
message = args[1] if len(args) > 1 else {} | |
current_retries = message.get('retries', 0) | |
if current_retries > 0: | |
logger.info(f"Retrying task {func.__name__} attempt {current_retries + 1}...") | |
return channel_name, current_retries, message | |
def requeue_task(max_retries: int = 3, delay: int = 10) -> Callable:
    """
    Retry a failing background-worker handler by re-sending its message.

    Works with both ``async def`` and plain ``def`` handlers. When the handler
    raises, the failure is logged through ``_handle_task_error``. If the retry
    budget is not yet spent, the wrapper waits ``delay`` seconds. It then
    stores the incremented count in ``message["retries"]`` and sends the same
    message dict back to the consumer's own channel (``scope['channel']``).
    Once the budget is exhausted the error is logged and swallowed. A failed
    call always returns ``None``.

    Note that the sync variant waits with ``time.sleep``, so the worker thread
    is blocked for ``delay`` seconds before the message is re-sent.

    Args:
        max_retries: maximum number of retry attempts (default: 3).
        delay: seconds to wait before each retry (default: 10).

    Returns:
        A decorator that adds the retry behaviour to a consumer handler.

    Raises:
        Whatever ``_examine_task`` raises when the consumer is not a
        background worker (its scope ``type`` is not ``'channel'``).
    """

    def decorator(func: Callable) -> Callable:

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            channel, attempt, payload = _examine_task(func, *args)
            try:
                result = await func(*args, **kwargs)
            except Exception as exc:
                retry_allowed = _handle_task_error(
                    func.__name__, exc, channel, payload, attempt, max_retries, delay
                )
                if retry_allowed:
                    await asyncio.sleep(delay)
                    payload["retries"] = attempt + 1
                    layer = get_channel_layer()
                    await layer.send(channel, payload)
                return None
            return result

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            channel, attempt, payload = _examine_task(func, *args)
            try:
                result = func(*args, **kwargs)
            except Exception as exc:
                retry_allowed = _handle_task_error(
                    func.__name__, exc, channel, payload, attempt, max_retries, delay
                )
                if retry_allowed:
                    time.sleep(delay)
                    payload["retries"] = attempt + 1
                    layer = get_channel_layer()
                    # Sync handlers have no event loop of their own, so run
                    # the coroutine to completion in a fresh one.
                    asyncio.run(layer.send(channel, payload))
                return None
            return result

        if inspect.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    return decorator
def run_continously(interval: int) -> Callable:
    """
    Run a background-worker handler periodically, like a cron job.

    Each cycle of the wrapped handler:

    1. Sets ``message["last_run"] = now()`` and calls the handler.
    2. On success sets ``message["last_run_successful"] = True``.
    3. If the run took less than ``interval`` seconds, sleeps for the rest of
       the interval (rounded down, plus one second) and starts the next cycle
       in the same worker.
    4. Otherwise re-sends the message to the consumer's own channel
       (``scope['channel']``) so the next run comes from the channel layer,
       then returns. In the sent copy, ``last_run`` is an ISO-8601 string.

    If the handler (or the re-send) raises, the error is logged with its
    traceback and the scheduler stops; nothing is re-sent.

    :param interval: minimum number of seconds between two successive starts.
    :return: a decorator usable on sync or async consumer handlers.
    """

    def decorator(func: Callable) -> Callable:
        is_async = inspect.iscoroutinefunction(func)

        def _pause_seconds(message: dict) -> int:
            """Whole seconds to sleep before the next run, or 0 if this run overran ``interval``."""
            # Measure once so the check and the sleep length agree.
            elapsed = (now() - message["last_run"]).total_seconds()
            if elapsed < interval:
                # int() rounds down; +1 guarantees the next start is not early.
                return int(interval - elapsed) + 1
            return 0

        def _serializable(message: dict) -> dict:
            """Shallow copy of ``message`` with ``last_run`` converted to an ISO-8601 string."""
            # Channel layers serialise messages (channels_redis uses msgpack),
            # and datetime objects are not encodable. The handler's own dict
            # keeps the datetime object.
            payload = dict(message)
            last_run = payload.get("last_run")
            if last_run is not None and hasattr(last_run, "isoformat"):
                payload["last_run"] = last_run.isoformat()
            return payload

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            channel_name, _, message = _examine_task(func, *args)
            while True:
                try:
                    message["last_run"] = now()
                    await func(*args, **kwargs)
                    message["last_run_successful"] = True
                    pause = _pause_seconds(message)
                    if not pause:
                        # The run took at least ``interval``: hand the next
                        # cycle back to the channel layer instead of looping.
                        await get_channel_layer().send(channel_name, _serializable(message))
                        return None
                    await asyncio.sleep(pause)
                except Exception as e:
                    logger.exception(
                        "Error occurred in function %s: %s. Exiting scheduler", func.__name__, e
                    )
                    return None

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            channel_name, _, message = _examine_task(func, *args)
            while True:
                try:
                    message["last_run"] = now()
                    func(*args, **kwargs)
                    message["last_run_successful"] = True
                    pause = _pause_seconds(message)
                    if not pause:
                        # The run took at least ``interval``: hand the next
                        # cycle back to the channel layer instead of looping.
                        asyncio.run(get_channel_layer().send(channel_name, _serializable(message)))
                        return None
                    time.sleep(pause)
                except Exception as e:
                    logger.exception(
                        "Error occurred in function %s: %s. Exiting scheduler", func.__name__, e
                    )
                    return None

        return async_wrapper if is_async else sync_wrapper

    return decorator


# Correctly spelled alias; the original name is kept for existing callers.
run_continuously = run_continously
def rate_limiter(min_time_between_runs=1) -> Callable:
    """
    Limit how frequently a consumer handler can run.

    The time of the last run is stored on the consumer instance in the
    attribute ``last_run__<handler name>``. If the handler is invoked again
    less than ``min_time_between_runs`` seconds after that, the call sleeps
    for the remainder before running.

    :param min_time_between_runs: minimum spacing between runs, in seconds.
    :return: a decorator usable on sync or async consumer handlers.
    """

    def decorator(func: Callable) -> Callable:
        is_async = inspect.iscoroutinefunction(func)
        attr_name = f"last_run__{func.__name__}"

        def _pending_delay(consumer) -> float:
            """Seconds ``consumer`` must still wait before ``func`` may run again (0 if none)."""
            last_run = getattr(consumer, attr_name, None)
            if last_run is None:
                return 0.0
            # Compute once: re-reading now() for the sleep length could give a
            # negative value, which time.sleep rejects with ValueError.
            remaining = min_time_between_runs - (now() - last_run).total_seconds()
            if remaining <= 0:
                return 0.0
            logger.info(
                "Task %s is rate limited. Sleeping for %s seconds", func.__name__, remaining
            )
            return remaining

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            consumer = args[0]
            delay = _pending_delay(consumer)
            if delay:
                await asyncio.sleep(delay)
            setattr(consumer, attr_name, now())
            return await func(*args, **kwargs)

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            consumer = args[0]
            delay = _pending_delay(consumer)
            if delay:
                time.sleep(delay)
            setattr(consumer, attr_name, now())
            return func(*args, **kwargs)

        return async_wrapper if is_async else sync_wrapper

    return decorator
def _make_keys(func, *args):
    """
    Build the three Redis keys ``job_counter`` uses for one handler.

    All keys share the base ``<prefix>_bkgnd_<channel>_<handler name>`` and
    are returned in the order ``(success, failed, pending)``.
    """
    channel, _retries, _message = _examine_task(func, *args)
    base = f"{get_redis_prefix()}_bkgnd_{channel}_{func.__name__}"
    return tuple(f"{base}_{suffix}" for suffix in ("success", "failed", "pending"))
def job_counter(func: Callable) -> Callable:
    """
    Record per-handler success / failure / pending counts in Redis.

    For every call, the three keys from ``_make_keys`` are updated:

    * ``..._pending`` is incremented when the handler starts and decremented
      when it finishes, whether it succeeded or failed.
    * ``..._success`` is incremented after the handler returns normally.
    * ``..._failed`` is a list: ``str(exception)`` is pushed on the left, and
      the list is trimmed to indices 0..20, keeping the 21 most recent entries.

    Exceptions raised by the handler are re-raised unchanged.

    :param func: sync or async consumer handler to wrap.
    :return: the wrapped handler.
    """

    def _record_failure(key_failed: str, exc: Exception) -> None:
        """Push the error text onto the failure list and cap its length."""
        redis_connection.lpush(key_failed, str(exc))
        redis_connection.ltrim(key_failed, 0, 20)

    # NOTE(review): redis_connection is a synchronous redis.Redis client, so in
    # async_wrapper these calls briefly block the event loop.

    @wraps(func)
    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
        key_success, key_failed, key_pending = _make_keys(func, *args)
        # Increment outside ``try`` so ``finally`` only ever decrements a
        # counter that was actually incremented.
        redis_connection.incr(key_pending)
        try:
            result = await func(*args, **kwargs)
        except Exception as exc:
            _record_failure(key_failed, exc)
            raise
        else:
            # Kept out of ``try`` so a bookkeeping error is not recorded as a
            # handler failure.
            redis_connection.incr(key_success)
            return result
        finally:
            redis_connection.decr(key_pending)

    @wraps(func)
    def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
        key_success, key_failed, key_pending = _make_keys(func, *args)
        redis_connection.incr(key_pending)
        try:
            result = func(*args, **kwargs)
        except Exception as exc:
            _record_failure(key_failed, exc)
            raise
        else:
            redis_connection.incr(key_success)
            return result
        finally:
            redis_connection.decr(key_pending)

    return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper
class PrintSyncConsumer(SyncConsumer):
    """
    Example consumer that pretends to do some work and then fails, so that
    ``requeue_task`` re-queues it (up to 3 retries, 2 seconds apart).

    Trigger it with a message such as
    ``{"type": "test.print", "name": "...", "email": "..."}`` sent to the
    channel name this consumer is mapped to (routing is not shown here).
    """

    @requeue_task(max_retries=3, delay=2)
    def test_print(self, message):
        print("Received: " + message["name"])
        print("Received: " + message["email"])
        # Raise explicitly: a bare ``assert False`` is stripped under
        # ``python -O``, and the demo handler would then silently succeed.
        raise RuntimeError("Simulated failure to demonstrate requeue_task")
class RateLimitedAsyncExecutor(AsyncConsumer):
    """
    Example async consumer combining ``rate_limiter`` and ``job_counter``.

    ``test.print`` runs at most once every 5 seconds per consumer instance,
    and ``job_counter`` records success, failure and pending counts in Redis.
    Send several messages in quick succession and watch the throttling, e.g.::

        channel_layer = get_channel_layer()
        asyncio.run(channel_layer.send("athrottle", {
            "type": "test.print",
            "name": "test",
            "email": "email",
        }))

    (``"athrottle"`` is presumed to be the channel name routed to this
    consumer; routing is configured elsewhere.)
    """

    @job_counter
    @rate_limiter(min_time_between_runs=5)
    async def test_print(self, message):
        received = "Received: " + message["name"]
        print(received)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment