Skip to content

Instantly share code, notes, and snippets.

@cofob
Last active December 3, 2023 04:31
Show Gist options
Save cofob/6b89f4d62efb4eddb26a3c8853128c84 to your computer and use it in GitHub Desktop.
Taskiq async log collection
"""Taskiq log handler."""
import asyncio
from logging import Handler, LogRecord
from sys import stderr
from traceback import print_exc
from typing import Callable
class TaskiqHandler(Handler):
    """Logging handler that routes records to a callback keyed by asyncio task name.

    For each record the handler looks up the asyncio task that is currently
    running in this thread:

    * If there is such a task and its name starts with ``prefix``, the callback
      is called with the task name minus ``prefix`` and the record.
    * If there is such a task but its name does not start with ``prefix``, the
      record is ignored.
    * If there is no current task, the record is forwarded with an empty name
      only when ``sync_callback`` is enabled. This covers both "no running
      event loop" and "loop running but code is outside any task".
    """

    def __init__(self, prefix: str, callback: Callable[[str, LogRecord], None], sync_callback: bool = False) -> None:
        """Store the routing configuration.

        Args:
            prefix (str): Task-name prefix selecting which asyncio tasks' records
                are forwarded, e.g. ``"task-"``.
            callback (Callable[[str, LogRecord], None]): Invoked as
                ``callback(task_name_without_prefix, record)``.
            sync_callback (bool): When True, records emitted while no asyncio task
                is current are forwarded as well, with ``""`` as the name.
        """
        super().__init__()
        self._callback = callback
        self._sync_callback = sync_callback
        self._prefix = prefix
        # Cached so emit() does not recompute len(prefix) for every record.
        self._prefix_len = len(prefix)

    def _run_callback(self, name: str, record: LogRecord) -> None:
        """Invoke the callback; report any exception on stderr instead of raising.

        The exception is not propagated, so a failing callback does not
        break the code that emitted the log record.
        """
        try:
            self._callback(name, record)
        except Exception:
            print("Exception in TaskiqHandler while processing record:", file=stderr)
            print_exc(file=stderr)

    @staticmethod
    def _running_task() -> "asyncio.Task[object] | None":
        """Return the asyncio task currently running in this thread, or None.

        ``asyncio.current_task()`` raises RuntimeError when the thread has no
        running event loop; that situation is reported as None as well.
        """
        try:
            return asyncio.current_task()
        except RuntimeError:
            return None

    def emit(self, record: LogRecord) -> None:
        """Forward ``record`` to the callback when it belongs to a selected task."""
        active = self._running_task()
        if active is not None:
            task_name = active.get_name()
            if task_name.startswith(self._prefix):
                self._run_callback(task_name[self._prefix_len :], record)
        elif self._sync_callback:
            self._run_callback("", record)
"""Taskiq receiver that can collect async task logs."""
from asyncio import create_task
from concurrent.futures import Executor
from logging import Formatter, LogRecord, getLogger, root
from typing import Any, Callable
from taskiq import AsyncBroker, TaskiqMessage, TaskiqResult
from taskiq.receiver import Receiver
from log_handler import TaskiqHandler
# Module-level logger; ReceiverWithLog._get_task_logs uses it to report, at DEBUG
# level, how many records were collected for a task.
logger = getLogger(__name__)
class ReceiverWithLog(Receiver):
    """Taskiq receiver that collects the log records emitted by each task.

    On construction a ``TaskiqHandler`` is attached to the root logger. Each task
    runs inside an asyncio task named ``"task-<task_id>"``. Records emitted while
    that asyncio task is current are buffered under the task id. Once the task
    finishes, they are formatted and stored in the result's ``log`` attribute.

    NOTE(review): attribution relies on which asyncio task is current when a
    record is emitted. Records from other threads, or from other asyncio tasks
    the task spawns, are presumably not captured. Confirm this is acceptable
    for sync tasks.
    """

    # Prefix of the asyncio task names created by run_task. The handler strips it
    # to recover the task id, so both places must share this single value.
    _task_name_prefix = "task-"

    def __init__(
        self,
        broker: AsyncBroker,
        executor: Executor | None = None,
        validate_params: bool = True,
        max_async_tasks: int | None = None,
        max_prefetch: int = 0,
        propagate_exceptions: bool = True,
    ) -> None:
        """Create the receiver and start collecting task logs.

        All arguments are forwarded unchanged to ``Receiver.__init__``.
        """
        # Buffered records per task id. Created before the handler is attached,
        # since the handler can invoke _log_callback as soon as it is registered.
        self._log_data: dict[str, list[LogRecord]] = {}
        self._logging_handler = TaskiqHandler(self._task_name_prefix, self._log_callback)
        self._logging_formatter = Formatter(
            fmt="[%(asctime)s] [%(name)s] [%(levelname)s] > %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
        )
        # NOTE(review): the handler is added to the root logger and never removed,
        # so every ReceiverWithLog instance in the process registers its own handler.
        root.addHandler(self._logging_handler)
        super().__init__(broker, executor, validate_params, max_async_tasks, max_prefetch, propagate_exceptions)

    def _log_callback(self, name: str, log: LogRecord) -> None:
        """Buffer ``log`` under the task id ``name`` supplied by the handler."""
        self._log_data.setdefault(name, []).append(log)

    async def run_task(
        self, target: Callable[..., Any], message: TaskiqMessage, retry_count: int = 0
    ) -> TaskiqResult[Any]:
        """Run the task in a named asyncio task and attach its collected logs.

        Args:
            target: Task function, forwarded to ``Receiver.run_task``.
            message: Incoming message. Its ``task_id`` names the asyncio task and
                keys the buffered records.
            retry_count: Accepted for signature compatibility; it is not
                forwarded to ``Receiver.run_task``.

        Returns:
            The result of ``Receiver.run_task``, with ``log`` set to the formatted
            records, or ``None`` if the task logged nothing.
        """
        # Naming the asyncio task is what lets the handler attribute records to it.
        task = create_task(
            super().run_task(target, message),
            name=f"{self._task_name_prefix}{message.task_id}",
        )
        try:
            result = await task
        except BaseException:
            # Failed or cancelled run: discard its buffered records so _log_data
            # does not grow without bound, then propagate unchanged.
            self._log_data.pop(message.task_id, None)
            raise
        result.log = self._get_task_logs(message.task_id)
        return result

    def _get_task_logs(self, task_id: str) -> str | None:
        """Remove and format the records buffered for ``task_id``.

        Returns:
            The formatted records joined by newlines, or ``None`` if none were
            collected.
        """
        logs = self._log_data.pop(task_id, None)
        if not logs:
            return None
        formatted_logs = "\n".join(self._logging_formatter.format(log) for log in logs)
        logger.debug("Collected %d log entries for task %s", len(logs), task_id)
        return formatted_logs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment