Created
November 13, 2024 22:11
-
-
Save joshuadavidthomas/4526473cba996e99d695b2d7ed0619bd to your computer and use it in GitHub Desktop.
Class decorator that auto-generates a synchronous counterpart for each async method, using `asgiref.sync.async_to_sync`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import functools | |
import inspect | |
import re | |
from collections.abc import Awaitable | |
from collections.abc import Sequence | |
from typing import Any | |
from typing import Callable | |
from typing import Generic | |
from typing import ParamSpec | |
from typing import Protocol | |
from typing import TypeVar | |
from typing import overload | |
from asgiref.sync import async_to_sync | |
T = TypeVar("T") | |
P = ParamSpec("P") | |
R = TypeVar("R") | |
AsyncMethod = Callable[P, Awaitable[R]] | |
SyncMethod = Callable[P, R] | |
Decorator = Callable[[type[T]], type[T]] | |
class HasAsyncMethods(Protocol):
    """Structural type whose attribute lookups are typed as async callables."""

    def __getattr__(self, name: str) -> AsyncMethod[..., Any]:
        """Typing hook only: any attribute is declared to be an async method."""


class HasSyncMethods(Protocol):
    """Structural type whose attribute lookups are typed as sync callables."""

    def __getattr__(self, name: str) -> SyncMethod[..., Any]:
        """Typing hook only: any attribute is declared to be a sync method."""


# Type variable restricted to types that structurally satisfy HasAsyncMethods.
T_WithSync = TypeVar("T_WithSync", bound=HasAsyncMethods)
class SyncMethodsDecorator(Generic[T]):
    """Class decorator that creates sync versions of async methods.

    For each selected coroutine function on the decorated class, a blocking
    wrapper is attached under the name obtained by removing ``prefix``. The
    wrapper runs the coroutine through ``asgiref.sync.async_to_sync``.
    Attributes that already exist on the class are never overwritten.

    Naming (with the default prefix ``"a"``):

    * public  ``aname``      -> ``name``
    * private ``_aname``     -> ``_name``      (only when ``include_private``)
    * dunder  ``__aenter__`` -> ``__enter__``  (only when ``include_dunder``)

    The prefix is matched literally. With the default ``"a"``, a coroutine
    named ``add`` would yield ``dd``. Exclude such names via
    ``ignore_methods``/``ignore_pattern``, or choose a more specific prefix.
    """

    def __init__(
        self,
        prefix: str = "a",
        ignore_methods: Sequence[str] | None = None,
        ignore_pattern: str | None = None,
        include_private: bool = False,
        include_dunder: bool = True,
        inherit: bool = False,
    ) -> None:
        """Store the conversion options.

        Args:
            prefix: Literal prefix that marks an async method name.
            ignore_methods: Exact async method names to skip.
            ignore_pattern: Regex; names it matches (``re.match``, i.e. at the
                start of the name) are skipped.
            include_private: Also convert non-dunder names with a leading
                underscore.
            include_dunder: Also convert dunder methods such as ``__aenter__``.
            inherit: Also convert coroutine functions inherited from base
                classes; otherwise only those defined directly on the class.
        """
        self.prefix = prefix
        self.ignored = set(ignore_methods or [])
        self.ignore_re = re.compile(ignore_pattern) if ignore_pattern else None
        self.include_private = include_private
        self.include_dunder = include_dunder
        self.inherit = inherit

    def __call__(self, cls: type[T]) -> type[T]:
        """Attach sync wrappers to *cls* in place and return *cls*."""
        for method_info in self._filter_methods(self._get_async_methods(cls)):
            self._add_sync_method(cls, method_info)
        return cls

    def _get_async_methods(
        self, cls: type[T]
    ) -> Sequence[tuple[str, AsyncMethod[..., Any]]]:
        """Return ``(name, method)`` pairs for the coroutine functions on *cls*.

        Unless ``inherit`` is set, only names in the class's own namespace are
        kept. Membership in ``vars(cls)`` is used instead of a ``__qualname__``
        prefix test. The prefix test would also match methods of a base class
        whose name merely starts with this class's name, e.g.
        ``ClientBase.aget`` for ``class Client(ClientBase)``.
        """
        members = inspect.getmembers(cls, predicate=inspect.iscoroutinefunction)
        if self.inherit:
            return members
        own_namespace = vars(cls)
        return [(name, method) for name, method in members if name in own_namespace]

    def _filter_methods(
        self, methods: Sequence[tuple[str, AsyncMethod[..., Any]]]
    ) -> Sequence[tuple[str, AsyncMethod[..., Any]]]:
        """Keep only the methods selected by the configured options."""
        return [(name, method) for name, method in methods if self._is_selected(name)]

    def _is_selected(self, name: str) -> bool:
        """Return True if the async method *name* should get a sync twin."""
        if name in self.ignored:
            return False
        if self.ignore_re is not None and self.ignore_re.match(name):
            return False
        if name.startswith("__") and name.endswith("__"):
            # Dunders are selected by the flag alone; their prefix is checked
            # when the sync name is derived in ``_sync_name``.
            return self.include_dunder
        if name.startswith("_") and not self.include_private:
            return False
        return name.removeprefix("_").startswith(self.prefix)

    def _sync_name(self, async_name: str) -> str | None:
        """Return the sync name for *async_name*, or None if it has none.

        None is returned when stripping the prefix would leave nothing
        (e.g. a method named exactly ``a``). It is also returned for a
        dunder whose inner name lacks the prefix.
        """
        if async_name.startswith("__") and async_name.endswith("__"):
            inner = async_name[2:-2]
            if not inner.startswith(self.prefix):
                return None
            # Strip the whole prefix, not just its first character.
            stem = inner[len(self.prefix):]
            return f"__{stem}__" if stem else None
        if async_name.startswith("_"):
            stem = async_name[1:][len(self.prefix):]
            return f"_{stem}" if stem else None
        return async_name[len(self.prefix):] or None

    def _add_sync_method(
        self, cls: type[T], method_info: tuple[str, AsyncMethod[..., Any]]
    ) -> None:
        """Add a sync version of the async method to the class.

        Nothing is added when the name has no sync counterpart, or when *cls*
        (or a base) already has an attribute of that name. Static and class
        methods get a twin with the same binding kind.
        """
        async_name, async_method = method_info
        sync_name = self._sync_name(async_name)
        if sync_name is None or hasattr(cls, sync_name):
            return

        # ``getmembers`` has already unwrapped staticmethods and bound
        # classmethods to ``cls``. Look at the raw attribute to recover the
        # descriptor type, so the sync twin is bound the same way.
        raw = inspect.getattr_static(cls, async_name, None)
        if isinstance(raw, staticmethod):
            wrapper, descriptor = self._sync_static(raw.__func__), staticmethod
        elif isinstance(raw, classmethod):
            wrapper, descriptor = self._sync_classmethod(raw.__func__), classmethod
        else:
            wrapper, descriptor = self._sync_instance_method(async_method), None

        # Set the docstring before wrapping in a descriptor, which copies it.
        wrapper.__doc__ = (
            f"Synchronous version of {async_name}.\n\n{async_method.__doc__ or ''}"
        )
        setattr(cls, sync_name, wrapper if descriptor is None else descriptor(wrapper))

    @staticmethod
    def _sync_instance_method(
        async_method: AsyncMethod[..., Any],
    ) -> Callable[..., Any]:
        """Wrap an async instance method in a blocking instance method."""

        @functools.wraps(async_method)
        def sync_method(self: Any, *args: Any, **kwargs: Any) -> Any:
            bound_method = async_method.__get__(self, self.__class__)
            return async_to_sync(bound_method)(*args, **kwargs)

        return sync_method

    @staticmethod
    def _sync_static(func: AsyncMethod[..., Any]) -> Callable[..., Any]:
        """Wrap the function behind an async staticmethod (no binding)."""

        @functools.wraps(func)
        def sync_static(*args: Any, **kwargs: Any) -> Any:
            return async_to_sync(func)(*args, **kwargs)

        return sync_static

    @staticmethod
    def _sync_classmethod(func: AsyncMethod[..., Any]) -> Callable[..., Any]:
        """Wrap the function behind an async classmethod, binding the caller's class."""

        @functools.wraps(func)
        def sync_classmethod(klass: Any, *args: Any, **kwargs: Any) -> Any:
            return async_to_sync(func.__get__(klass, type(klass)))(*args, **kwargs)

        return sync_classmethod
@overload
def with_sync_methods(cls: type[T]) -> type[T]: ...


@overload
def with_sync_methods(
    cls: None = None,
    *,
    prefix: str = "a",
    ignore_methods: Sequence[str] | None = None,
    ignore_pattern: str | None = None,
    include_private: bool = False,
    include_dunder: bool = True,
    inherit: bool = False,
) -> Callable[[type[T]], type[T]]: ...


def with_sync_methods(
    cls: type[T] | None = None,
    *,
    prefix: str = "a",
    ignore_methods: Sequence[str] | None = None,
    ignore_pattern: str | None = None,
    include_private: bool = False,
    include_dunder: bool = True,
    inherit: bool = False,
) -> type[T] | Callable[[type[T]], type[T]]:
    """
    Class decorator that creates sync versions of async methods.

    Usable bare (``@with_sync_methods``) or with keyword options
    (``@with_sync_methods(prefix="async_")``).

    Args:
        cls: The class to decorate. Passed implicitly in the bare form.
        prefix: Literal prefix marking async method names (``aget`` -> ``get``).
        ignore_methods: Exact async method names to skip.
        ignore_pattern: Regex; names it matches at their start are skipped.
        include_private: Also convert non-dunder names with a leading underscore.
        include_dunder: Also convert dunder coroutines (``__aenter__`` -> ``__enter__``).
        inherit: Also convert coroutine functions inherited from base classes.

    Returns:
        The decorated class when *cls* is given. Otherwise, a configured
        decorator waiting for a class.
    """
    configured = SyncMethodsDecorator[T](
        prefix=prefix,
        ignore_methods=ignore_methods,
        ignore_pattern=ignore_pattern,
        include_private=include_private,
        include_dunder=include_dunder,
        inherit=inherit,
    )
    return configured if cls is None else configured(cls)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from typing import Any | |
from sync import with_sync_methods | |
class TestSyncMethodsDecorator:
    """Behavioural tests for the ``with_sync_methods`` class decorator."""

    def test_basic_conversion(self):
        @with_sync_methods
        class TestClass:
            async def amethod(self) -> str:
                return "test"

        obj = TestClass()
        assert hasattr(obj, "method")
        assert obj.method() == "test"

    def test_custom_prefix(self):
        @with_sync_methods(prefix="async_")
        class TestClass:
            async def async_method(self) -> str:
                return "test"

        obj = TestClass()
        assert hasattr(obj, "method")
        assert obj.method() == "test"

    def test_ignore_methods(self):
        @with_sync_methods(ignore_methods=["aignored"])
        class TestClass:
            async def amethod(self) -> str:
                return "test"

            async def aignored(self) -> str:
                return "ignored"

        obj = TestClass()
        assert hasattr(obj, "method")
        assert obj.method() == "test"
        assert not hasattr(obj, "ignored")

    def test_ignore_pattern(self):
        @with_sync_methods(ignore_pattern=r"a.*_skip$")
        class TestClass:
            async def amethod(self) -> str:
                return "test"

            async def amethod_skip(self) -> str:
                return "skipped"

        obj = TestClass()
        assert hasattr(obj, "method")
        assert obj.method() == "test"
        assert not hasattr(obj, "method_skip")

    def test_private_methods(self):
        @with_sync_methods(include_private=True)
        class TestClass:
            async def amethod(self) -> str:
                return "public"

            async def _amethod(self) -> str:
                return "private"

        obj = TestClass()
        assert hasattr(obj, "method")
        assert hasattr(obj, "_method")
        assert obj.method() == "public"
        assert obj._method() == "private"

    def test_private_methods_excluded_by_default(self):
        @with_sync_methods
        class TestClass:
            async def _amethod(self) -> str:
                return "private"

        assert not hasattr(TestClass, "_method")

    def test_dunder_methods(self):
        @with_sync_methods
        class TestClass:
            async def __aenter__(self) -> str:
                return "entered"

            async def __aexit__(self, *args: Any) -> None:
                pass

        obj = TestClass()
        assert hasattr(obj, "__enter__")
        assert hasattr(obj, "__exit__")
        assert obj.__enter__() == "entered"
        # The generated dunders live on the class, so the ``with`` statement
        # (which looks them up on the type) can use them.
        with TestClass() as entered:
            assert entered == "entered"

    def test_dunder_methods_excluded(self):
        @with_sync_methods(include_dunder=False)
        class TestClass:
            async def __aenter__(self) -> str:
                return "entered"

        assert not hasattr(TestClass, "__enter__")

    def test_inheritance(self):
        class BaseClass:
            async def amethod(self) -> str:
                return "base"

        @with_sync_methods(inherit=True)
        class ChildClass(BaseClass):
            async def achild_method(self) -> str:
                return "child"

        obj = ChildClass()
        assert hasattr(obj, "method")
        assert hasattr(obj, "child_method")
        assert obj.method() == "base"
        assert obj.child_method() == "child"

    def test_no_inheritance_by_default(self):
        class BaseClass:
            async def abase_method(self) -> str:
                return "base"

        @with_sync_methods
        class ChildClass(BaseClass):
            async def achild_method(self) -> str:
                return "child"

        assert hasattr(ChildClass, "child_method")
        assert not hasattr(ChildClass, "base_method")

    def test_method_documentation(self):
        @with_sync_methods
        class TestClass:
            async def amethod(self) -> str:
                """Test docstring."""
                return "test"

        assert "Synchronous version of amethod" in TestClass.method.__doc__
        assert "Test docstring" in TestClass.method.__doc__

    def test_no_async_methods(self):
        @with_sync_methods
        class TestClass:
            def method(self) -> str:
                return "test"

        # Nothing new may be added to a class without coroutine methods.
        public_names = {name for name in vars(TestClass) if not name.startswith("_")}
        assert public_names == {"method"}
        assert TestClass().method() == "test"

    def test_existing_sync_method(self):
        @with_sync_methods
        class TestClass:
            async def amethod(self) -> str:
                return "async"

            def method(self) -> str:
                return "sync"

        obj = TestClass()
        assert obj.method() == "sync"

    def test_decorator_with_args(self):
        decorator = with_sync_methods(prefix="async_")
        assert callable(decorator)

        @decorator
        class TestClass:
            async def async_method(self) -> str:
                return "test"

        obj = TestClass()
        assert hasattr(obj, "method")
        assert obj.method() == "test"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Type checkers hate this one trick... but it works!