Last active
June 9, 2024 05:30
-
-
Save gustabot42/9301b81b2eafd4de065bc5b3c3b72f94 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import asyncio | |
import time | |
from collections import deque | |
from dataclasses import dataclass | |
from dataclasses import field | |
from typing import ClassVar | |
from typing import Generic | |
from typing import TypeVar | |
from pydantic import BaseModel | |
from result import Err | |
from result import Ok | |
from hermes.utils import coalesce | |
from hermes.utils.asyncio import EventResult | |
from hermes.utils.asyncio import delay_awaitable | |
from hermes.utils.logging import logging | |
# Element type handled by an ActionBulk; restricted to pydantic models.
ActionType = TypeVar("ActionType", bound=BaseModel)
@dataclass
class Buffer:
    """One-shot accumulator of values sharing a single result event.

    Values are appended until the buffer is *expended*. ``expend`` hands the
    collected values to its caller exactly once and marks the buffer as
    exhausted. After that, ``append`` is ignored and ``len`` reports 0.
    """

    # Pending values; ``None`` once the buffer has been expended.
    _values: deque | None = field(default_factory=deque)
    # Shared by everything appended to this buffer; ActionBulk.action sets it
    # with the outcome of processing the batch.
    event_result: EventResult = field(default_factory=EventResult, init=False)

    def __len__(self) -> int:
        """Return the number of pending values (0 once exhausted)."""
        if self._values is None:
            return 0
        return len(self._values)

    def append(self, value) -> None:
        """Add ``value`` to the pending values; silently ignored once exhausted."""
        if self._values is not None:
            self._values.append(value)

    def expend(self) -> deque | None:
        """Take the pending values and mark the buffer as exhausted.

        Returns:
            The collected values on the first call, ``None`` on every later
            call. Several tasks may be scheduled for the same buffer, and this
            ensures only the first of them acts on its contents.
        """
        values, self._values = self._values, None
        return values

    def is_exhausted(self) -> bool:
        """Return True once ``expend`` has been called."""
        return self._values is None
@dataclass
class ActionBulk(Generic[ActionType]):
    """Batch registered objects and run ``perform_action`` on each batch.

    Subclasses define the three class-level settings below and override
    ``perform_action``. Each ``register`` call appends one object to the
    current buffer and returns that buffer's ``EventResult``. The result is
    set to ``Ok(None)`` when the batch's action succeeds, or to
    ``Err(str(error))`` when it raises.

    A buffer's action is scheduled ``ACTION_DELAY_SEC`` after the buffer is
    created. It is scheduled again with no delay as soon as the buffer holds
    ``BUFFER_SIZE`` objects. Whichever run comes second finds the buffer
    expended and does nothing.
    """

    # Number of objects that triggers an immediate flush of the current buffer.
    BUFFER_SIZE: ClassVar[int]
    # Delay before a new buffer is flushed; also the duration above which an
    # action is logged as slow.
    ACTION_DELAY_SEC: ClassVar[float]
    # Not read by this class. NOTE(review): presumably a wait timeout for
    # callers of the returned EventResult — confirm.
    AWAIT_TIMEOUT_SEC: ClassVar[float]

    # Buffer currently accepting objects. None before the first registration
    # and right after a full buffer has been flushed.
    _buffer: Buffer | None = None
    # Strong references to in-flight tasks so asyncio does not drop them
    # mid-execution; each task discards itself from the set when done.
    _background_tasks: set = field(default_factory=set)

    def _create_task(self, buffer: Buffer, delay: float | None = None) -> None:
        """Schedule ``self.action(buffer)`` in a tracked background task.

        Args:
            buffer: The buffer to process.
            delay: Seconds to wait first. When None, it presumably falls back
                to ``ACTION_DELAY_SEC`` via ``coalesce`` (first non-None
                value) — confirm in hermes.utils.
        """
        delay = coalesce(delay, self.ACTION_DELAY_SEC)
        task = asyncio.create_task(delay_awaitable(delay, self.action, buffer))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def register(self, obj: ActionType) -> EventResult:
        """Add ``obj`` to the current batch and return that batch's result event.

        A new buffer, with its delayed flush task, is started when there is no
        open buffer or the open one was already expended. If ``obj`` fills the
        buffer to ``BUFFER_SIZE``, the buffer is flushed immediately and
        detached, so the next registration starts a fresh batch.
        """
        buffer = self._buffer
        if buffer is None or buffer.is_exhausted():
            buffer = self._buffer = Buffer()
            self._create_task(buffer)
        buffer.append(obj)
        if len(buffer) >= self.BUFFER_SIZE:
            # Flush now rather than waiting for the delayed task; the delayed
            # task will later find the buffer expended and return early.
            self._create_task(buffer, delay=0)
            self._buffer = None
        return buffer.event_result

    async def action(self, buffer: Buffer) -> None:
        """Run ``perform_action`` on the buffer's values and publish the outcome.

        Does nothing if the buffer was already expended by an earlier run.
        Otherwise:

        - on success, sets ``Ok(None)``;
        - if ``perform_action`` raises an ``Exception``, logs it and sets
          ``Err(str(error))``;
        - if ``perform_action`` is cancelled, leaves the event unset.

        In every case, logs an error when the action ran longer than
        ``ACTION_DELAY_SEC``.
        """
        values = buffer.expend()
        if values is None:
            return
        # Monotonic clock: elapsed time is immune to wall-clock adjustments.
        tic = time.monotonic()
        try:
            await self.perform_action(values)
        except Exception as e:
            error_msg = str(e)
            msg = f"{self.__class__} action error: {error_msg}"
            logging.error(msg)
            buffer.event_result.set(Err(error_msg))
        else:
            # Only a successful run publishes Ok, so it never competes with
            # the Err set above.
            buffer.event_result.set(Ok(None))
        finally:
            elapsed = time.monotonic() - tic
            if elapsed > self.ACTION_DELAY_SEC:
                msg = f"{self.__class__} action delay: {elapsed}s"
                logging.error(msg)

    async def wait(self) -> None:
        """Await every background task tracked when this call starts.

        Tasks scheduled after the call begins are not awaited.
        """
        await asyncio.gather(*self._background_tasks)

    @staticmethod
    async def perform_action(objs: deque[ActionType]) -> None:
        """Process one batch of objects; subclasses must override this.

        Ensure the action is idempotent to handle multiple invocations for the
        same objects, as it can't be cancelled and may be repeated due to
        latency issues.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import pytest | |
from pydantic import BaseModel | |
from result import Err | |
from result import Ok | |
from hermes.db.action_bulk import ActionBulk | |
class ActionModel(BaseModel):
    """Minimal pydantic payload registered with the test bulks."""

    value: int
class PassBulk(ActionBulk[ActionModel]):
    """Test bulk whose action always succeeds without doing anything."""

    # Test settings, shared by all bulks in this module.
    BUFFER_SIZE = 2
    ACTION_DELAY_SEC = 0.1
    AWAIT_TIMEOUT_SEC = 0.2

    @staticmethod
    async def perform_action(objs: list[ActionModel]) -> None:
        # Deliberately a no-op: the batch is accepted and nothing happens.
        return None
class ErrorBulk(ActionBulk[ActionModel]):
    """Test bulk whose action always raises ``ValueError("error message")``."""

    # Test settings, shared by all bulks in this module.
    BUFFER_SIZE = 2
    ACTION_DELAY_SEC = 0.1
    AWAIT_TIMEOUT_SEC = 0.2

    # Ruff treats the next line as a file-level exemption: every test bulk's
    # perform_action intentionally ignores its ``objs`` argument.
    # ruff: noqa: ARG004

    @staticmethod
    async def perform_action(objs: list[ActionModel]) -> None:
        error_text = "error message"
        raise ValueError(error_text)
class TimeoutBulk(ActionBulk[ActionModel]):
    """Test bulk whose action takes longer than the timeout test is willing to wait."""

    # Test settings, shared by all bulks in this module.
    BUFFER_SIZE = 2
    ACTION_DELAY_SEC = 0.1
    AWAIT_TIMEOUT_SEC = 0.2

    @staticmethod
    async def perform_action(objs: list[ActionModel]) -> None:
        # 0.5 s exceeds the 0.3 s wait_for used by the timeout test.
        await asyncio.sleep(0.5)
@pytest.mark.asyncio()
async def test_pass_register_and_wait_for():
    """A successful batch resolves the registration's event to Ok(None)."""
    bulk = PassBulk()
    result_event = await bulk.register(ActionModel(value=1))
    # The action only runs after the flush delay, so nothing is set yet.
    assert result_event.is_set() is False
    outcome = await result_event.wait_for()
    assert outcome == Ok(None)
    assert result_event.is_set() is True
@pytest.mark.asyncio()
async def test_error_register_and_wait_for():
    """A failing batch resolves the registration's event to Err(<message>)."""
    bulk = ErrorBulk()
    result_event = await bulk.register(ActionModel(value=1))
    # The action only runs after the flush delay, so nothing is set yet.
    assert result_event.is_set() is False
    outcome = await result_event.wait_for()
    assert outcome == Err("error message")
    assert result_event.is_set() is True
@pytest.mark.asyncio()
async def test_timeout_register_and_wait_for():
    """Waiting less than the action's duration yields Err("TimeoutError").

    The event stays unset after the timed-out wait. The bulk's background task
    is then drained so it does not outlive the test's event loop.
    """
    bulk = TimeoutBulk()
    event_result = await bulk.register(ActionModel(value=1))
    assert event_result.is_set() is False
    assert await event_result.wait_for(0.3) == Err("TimeoutError")
    assert event_result.is_set() is False
    # The action is still sleeping at this point. Await it so no task is left
    # pending (and destroyed) when the loop is closed after the test.
    await bulk.wait()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment