Last active: December 3, 2023 04:31
-
-
Save cofob/6b89f4d62efb4eddb26a3c8853128c84 to your computer and use it in GitHub Desktop.
Taskiq async log collection
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Taskiq log handler.""" | |
import asyncio | |
from logging import Handler, LogRecord | |
from sys import stderr | |
from traceback import print_exc | |
from typing import Callable | |
class TaskiqHandler(Handler):
    """Logging handler that routes records by the current asyncio task's name.

    For every record it looks up the running asyncio task. When that task's
    name starts with ``prefix``, the callback receives the rest of the name
    (prefix removed) together with the record. Records from tasks whose
    names do not start with ``prefix`` are dropped. Records emitted with no
    current task are dropped too, unless ``sync_callback`` is true, in which
    case they are passed on with an empty name. Used to collect async task
    logs.
    """

    def __init__(
        self,
        prefix: str,
        callback: Callable[[str, LogRecord], None],
        sync_callback: bool = False,
    ) -> None:
        """Initialize.

        Args:
            prefix (str): Task-name prefix to match, for example "task-".
            callback (Callable[[str, LogRecord], None]): Called as
                ``callback(name, record)`` where ``name`` is the task name
                with ``prefix`` removed.
            sync_callback (bool): If True, the callback is also invoked, with
                an empty name, for records emitted where no asyncio task is
                current.
        """
        super().__init__()
        self._callback = callback
        self._sync_callback = sync_callback
        self._prefix = prefix
        # Cached so emit() can slice the prefix off without recomputing it.
        self._prefix_len = len(prefix)

    def _run_callback(self, name: str, record: LogRecord) -> None:
        """Invoke the callback, reporting any exception on stderr instead of raising."""
        try:
            self._callback(name, record)
        except Exception:
            # A failing callback must not break the code that issued the log
            # call, so the traceback is printed rather than propagated.
            print("Exception in TaskiqHandler while processing record:", file=stderr)
            print_exc(file=stderr)

    def emit(self, record: LogRecord) -> None:
        """Forward *record* to the callback according to the current task name."""
        try:
            active = asyncio.current_task()
        except RuntimeError:
            # current_task() raises when no event loop is running in this thread.
            active = None

        if active is not None:
            task_name = active.get_name()
            if task_name.startswith(self._prefix):
                self._run_callback(task_name[self._prefix_len :], record)
        elif self._sync_callback:
            self._run_callback("", record)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Taskiq receiver that can collect async task logs.""" | |
from asyncio import create_task | |
from concurrent.futures import Executor | |
from logging import Formatter, LogRecord, getLogger, root | |
from typing import Any, Callable | |
from taskiq import AsyncBroker, TaskiqMessage, TaskiqResult | |
from taskiq.receiver import Receiver | |
from log_handler import TaskiqHandler | |
# Logger for this module, named after the module itself (``__name__``).
logger = getLogger(name=__name__)
class ReceiverWithLog(Receiver):
    """Taskiq receiver that can collect async task logs.

    Each task is executed inside its own asyncio task named
    ``f"{_TASK_NAME_PREFIX}{task_id}"``. A ``TaskiqHandler`` attached to the
    root logger matches that name prefix, so records emitted while the task
    runs are buffered per task id. When the task finishes, they are formatted
    into one string and stored on the result's ``log`` attribute.

    Note:
        The handler is added to the root logger in ``__init__`` and is never
        removed by this class. It is created without ``sync_callback``, so
        records emitted where no asyncio task is current (e.g. from worker
        threads) are not collected.
    """

    # Prefix shared by the log handler and the names of per-task asyncio
    # tasks; both must agree for records to be attributed to a task.
    _TASK_NAME_PREFIX = "task-"

    def __init__(
        self,
        broker: AsyncBroker,
        executor: Executor | None = None,
        validate_params: bool = True,
        max_async_tasks: int | None = None,
        max_prefetch: int = 0,
        propagate_exceptions: bool = True,
    ) -> None:
        """Initialize the receiver and start capturing task logs.

        All arguments are forwarded unchanged to ``Receiver.__init__``.
        """
        self._logging_handler = TaskiqHandler(self._TASK_NAME_PREFIX, self._log_callback)
        self._logging_formatter = Formatter(
            fmt="[%(asctime)s] [%(name)s] [%(levelname)s] > %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        # Buffered records keyed by task id: filled by _log_callback,
        # drained by _get_task_logs (or discarded by run_task on failure).
        self._log_data: dict[str, list[LogRecord]] = {}
        root.addHandler(self._logging_handler)
        super().__init__(broker, executor, validate_params, max_async_tasks, max_prefetch, propagate_exceptions)

    def _log_callback(self, name: str, log: LogRecord) -> None:
        """Buffer *log* under the task id *name* (invoked by the log handler)."""
        self._log_data.setdefault(name, []).append(log)

    async def run_task(
        self, target: Callable[..., Any], message: TaskiqMessage, retry_count: int = 0
    ) -> TaskiqResult[Any]:
        """Run a task and attach the logs it emitted to its result.

        Args:
            target: Callable implementing the task.
            message: Incoming task message; its ``task_id`` keys the log buffer.
            retry_count: Accepted but not used by this method.

        Returns:
            The result of ``Receiver.run_task``, with ``log`` set to the
            formatted records, or ``None`` if nothing was logged.

        Raises:
            Whatever awaiting the base implementation raises (including
            cancellation). The task's buffered records are discarded first,
            so they do not accumulate in memory.
        """
        # NOTE(review): retry_count is not forwarded to Receiver.run_task;
        # confirm whether the base class is expected to receive it.
        # Run in a separately named asyncio task so the handler can attribute
        # records to this task id.
        task = create_task(
            super().run_task(target, message),
            name=f"{self._TASK_NAME_PREFIX}{message.task_id}",
        )
        try:
            result = await task
        except BaseException:
            # Without this, a failed or cancelled task would leave its records
            # in self._log_data forever.
            self._log_data.pop(message.task_id, None)
            raise
        result.log = self._get_task_logs(message.task_id)
        return result

    def _get_task_logs(self, task_id: str) -> str | None:
        """Remove and format the records buffered for *task_id*.

        Returns:
            The formatted records joined by newlines, with surrounding
            whitespace stripped, or ``None`` if nothing was buffered.
        """
        logs = self._log_data.pop(task_id, [])
        if not logs:
            return None
        formatted_logs = "\n".join(self._logging_formatter.format(log) for log in logs).strip()
        logger.debug("Collected %d log entries for task %s", len(logs), task_id)
        return formatted_logs
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.