@kbhalerao
Created April 23, 2025 14:29
Some decorators to make Channels Background Worker consumers a bit more resilient to failure
import asyncio
import inspect
import logging
import time
import traceback
from functools import wraps
from typing import Callable, Any

from asgiref.sync import async_to_sync
from channels.consumer import AsyncConsumer, SyncConsumer
from channels.layers import get_channel_layer
from django.utils.timezone import now
from django.conf import settings
import redis


def get_redis_con(conn='default'):
    """
    Provides a redis connection for the given channel layer alias.
    Accepts either a (host, port) tuple or a URL string in the CHANNEL_LAYERS config.
    """
    try:
        host, port = settings.CHANNEL_LAYERS[conn]['CONFIG']['hosts'][0]
        return redis.Redis(host=host, port=port, db=0, decode_responses=True)
    except ValueError:
        # Unpacking fails when the host entry is a single URL string
        url = settings.CHANNEL_LAYERS[conn]['CONFIG']['hosts'][0]
        return redis.Redis.from_url(url=url, decode_responses=True)


def get_redis_prefix(conn='default'):
    """Return the channel-name prefix configured for the given channel layer alias."""
    try:
        return settings.CHANNEL_LAYERS[conn]['CONFIG']['prefix']
    except KeyError:
        return settings.CHANNEL_LAYERS[conn]['CONFIG']['CHANNEL_NAME_PREFIX']
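
# Illustrative examples (not part of the original gist) of the CHANNEL_LAYERS shapes
# the two helpers above expect in settings.py -- either a (host, port) tuple or a URL:
#
#   CHANNEL_LAYERS = {
#       "default": {
#           "BACKEND": "channels_redis.core.RedisChannelLayer",
#           "CONFIG": {
#               "hosts": [("127.0.0.1", 6379)],        # parsed by the tuple branch
#               # "hosts": ["redis://127.0.0.1:6379"], # parsed by the URL branch
#               "prefix": "myapp",                     # read by get_redis_prefix()
#           },
#       },
#   }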
logger = logging.getLogger(__name__)
redis_connection = get_redis_con()


def _handle_task_error(
        func_name: str,
        error: Exception,
        channel_name: str,
        message: dict,
        retries: int,
        max_retries: int,
        delay: int
) -> bool:
    """Log error information and report whether the task should be requeued."""
    tb_lines = traceback.format_tb(error.__traceback__)
    logger.error(f"Error occurred in function {func_name}: {error}")
    if retries >= max_retries:
        logger.error("Maximum retries reached. Task failed.")
        logger.error("Traceback (last 10 lines): \n" + "".join(tb_lines[-10:]))
        return False
    logger.info(f"Requeuing task on channel layer '{channel_name}' in {delay} seconds...")
    return True


def _examine_task(func, *args):
    """Extract the channel name, retry count and message from a background task invocation."""
    assert args[0].scope['type'] == 'channel', "These decorators can only be used on background tasks"
    channel_name = args[0].scope['channel']
    message = args[1] if len(args) > 1 else {}
    current_retries = message.get('retries', 0)
    if current_retries > 0:
        logger.info(f"Retrying task {func.__name__}, attempt {current_retries + 1}...")
    return channel_name, current_retries, message


def requeue_task(max_retries: int = 3, delay: int = 10) -> Callable:
    """
    A decorator for wrapping background worker functions (both sync and async).
    Provides the ability to retry the function a specified number of times.

    Args:
        max_retries: Maximum number of retry attempts (default: 3)
        delay: Delay in seconds before retrying (default: 10)

    Returns:
        Decorated function with retry capability

    Raises:
        AssertionError: If the decorated function is not a background task consumer
    """
    def decorator(func: Callable) -> Callable:
        is_async = inspect.iscoroutinefunction(func)

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            channel_name, current_retries, message = _examine_task(func, *args)
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                should_retry = _handle_task_error(
                    func.__name__, e, channel_name, message,
                    current_retries, max_retries, delay
                )
                if should_retry:
                    await asyncio.sleep(delay)
                    message["retries"] = current_retries + 1
                    await get_channel_layer().send(channel_name, message)
                return None

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            channel_name, current_retries, message = _examine_task(func, *args)
            try:
                return func(*args, **kwargs)
            except Exception as e:
                should_retry = _handle_task_error(
                    func.__name__, e, channel_name, message,
                    current_retries, max_retries, delay
                )
                if should_retry:
                    time.sleep(delay)
                    message["retries"] = current_retries + 1
                    async_to_sync(get_channel_layer().send)(channel_name, message)
                return None

        return async_wrapper if is_async else sync_wrapper
    return decorator


def run_continuously(interval: int) -> Callable:
    """
    A decorator to run a function periodically. Useful for setting up a periodic task like a cron job.
    Please note that if the function passed to it raises an exception, the scheduler stops running
    immediately.

    :param interval: Minimum interval in seconds between two successive runs of the function
    :return: Decorated function that reschedules itself
    """
    def decorator(func: Callable) -> Callable:
        is_async = inspect.iscoroutinefunction(func)

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            channel_name, _, message = _examine_task(func, *args)
            while True:
                try:
                    last_run = now()
                    # Store an ISO string so the message stays serializable on the channel layer
                    message["last_run"] = last_run.isoformat()
                    await func(*args, **kwargs)
                    message["last_run_successful"] = True
                    elapsed = (now() - last_run).total_seconds()
                    if elapsed < interval:
                        await asyncio.sleep(int(interval - elapsed) + 1)
                    else:
                        # The run took longer than the interval; requeue so the worker can pick it up again
                        await get_channel_layer().send(channel_name, message)
                        break
                except Exception as e:
                    logger.error(f"Error occurred in function {func.__name__}: {e}. Exiting scheduler")
                    break

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            channel_name, _, message = _examine_task(func, *args)
            while True:
                try:
                    last_run = now()
                    message["last_run"] = last_run.isoformat()
                    func(*args, **kwargs)
                    message["last_run_successful"] = True
                    elapsed = (now() - last_run).total_seconds()
                    if elapsed < interval:
                        time.sleep(int(interval - elapsed) + 1)
                    else:
                        async_to_sync(get_channel_layer().send)(channel_name, message)
                        break
                except Exception as e:
                    logger.error(f"Error occurred in function {func.__name__}: {e}. Exiting scheduler")
                    break

        return async_wrapper if is_async else sync_wrapper
    return decorator
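
# Illustrative usage of @run_continuously (not from the original gist; the consumer and
# method names below are hypothetical). Kick it off by sending one message to the channel;
# the decorator then re-runs the handler roughly every `interval` seconds:
#
#   class HeartbeatConsumer(SyncConsumer):
#       @run_continuously(interval=60)
#       def heartbeat(self, message):
#           logger.info("Still alive at %s", message.get("last_run"))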


def rate_limiter(min_time_between_runs=1) -> Callable:
    """
    Limits how frequently a task can be run. Stores the last-run
    timestamp on the consumer instance.

    :param min_time_between_runs: minimum gap between runs, in seconds
    :return: Decorated function that sleeps until the gap has elapsed
    """
    def decorator(func: Callable) -> Callable:
        is_async = inspect.iscoroutinefunction(func)

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            consumer = args[0]
            last_run = getattr(consumer, f"last_run__{func.__name__}", None)
            if last_run is not None:
                wait = min_time_between_runs - (now() - last_run).total_seconds()
                if wait > 0:
                    logger.info(f"Task {func.__name__} is rate limited. Sleeping for {wait} seconds")
                    await asyncio.sleep(wait)
            setattr(consumer, f"last_run__{func.__name__}", now())
            return await func(*args, **kwargs)

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            consumer = args[0]
            last_run = getattr(consumer, f"last_run__{func.__name__}", None)
            if last_run is not None:
                wait = min_time_between_runs - (now() - last_run).total_seconds()
                if wait > 0:
                    logger.info(f"Task {func.__name__} is rate limited. Sleeping for {wait} seconds")
                    time.sleep(wait)
            setattr(consumer, f"last_run__{func.__name__}", now())
            return func(*args, **kwargs)

        return async_wrapper if is_async else sync_wrapper
    return decorator


def _make_keys(func, *args):
    """Build the Redis keys used to track success, failure and pending counts for a task."""
    channel_name, _, _ = _examine_task(func, *args)
    key_base = f"{get_redis_prefix()}_bkgnd_{channel_name}_{func.__name__}"
    key_success = f"{key_base}_success"
    key_failed = f"{key_base}_failed"
    key_pending = f"{key_base}_pending"
    return key_success, key_failed, key_pending


def job_counter(func: Callable) -> Callable:
    """
    Decorator to keep track of how many times a task has succeeded, failed,
    and how many tasks are pending.

    :param func: the task handler to instrument
    :return: Decorated function that updates the Redis counters
    """
    @wraps(func)
    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
        key_success, key_failed, key_pending = _make_keys(func, *args)
        try:
            redis_connection.incr(key_pending)
            result = await func(*args, **kwargs)
            redis_connection.incr(key_success)
            return result
        except Exception as e:
            # Keep only the most recent error messages
            redis_connection.lpush(key_failed, str(e))
            redis_connection.ltrim(key_failed, 0, 20)
            raise
        finally:
            redis_connection.decr(key_pending)

    @wraps(func)
    def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
        key_success, key_failed, key_pending = _make_keys(func, *args)
        try:
            redis_connection.incr(key_pending)
            result = func(*args, **kwargs)
            redis_connection.incr(key_success)
            return result
        except Exception as e:
            redis_connection.lpush(key_failed, str(e))
            redis_connection.ltrim(key_failed, 0, 20)
            raise
        finally:
            redis_connection.decr(key_pending)

    return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper
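
# Illustrative helper (not part of the original gist) showing how the counters written by
# @job_counter can be read back for monitoring. The key layout matches _make_keys above;
# channel_name and func_name are whatever your consumers use.
def job_stats(channel_name: str, func_name: str) -> dict:
    """Return success/pending counts and recent failure messages for a tracked task."""
    key_base = f"{get_redis_prefix()}_bkgnd_{channel_name}_{func_name}"
    return {
        "success": int(redis_connection.get(f"{key_base}_success") or 0),
        "pending": int(redis_connection.get(f"{key_base}_pending") or 0),
        "recent_failures": redis_connection.lrange(f"{key_base}_failed", 0, 20),
    }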


class PrintSyncConsumer(SyncConsumer):
    """
    An example consumer that simply pretends to do some work
    but fails and is re-queued.
    """
    @requeue_task(max_retries=3, delay=2)
    def test_print(self, message):
        print("Received: " + message["name"])
        print("Received: " + message["email"])
        assert False  # fail deliberately so the task gets requeued


class RateLimitedAsyncExecutor(AsyncConsumer):
    """
    Send multiple messages to kick it off, and watch it limit
    the execution rate. Also uses @job_counter to
    track the number of successes and failures.

        asyncio.run(channel_layer.send("athrottle", {
            "type": "test.print",
            "name": "test",
            "email": "email",
        }))
    """
    @job_counter
    @rate_limiter(min_time_between_runs=5)
    async def test_print(self, message):
        print("Received: " + message["name"])