Skip to content

Instantly share code, notes, and snippets.

@joshuadavidthomas
Created November 13, 2024 22:11
Show Gist options
  • Save joshuadavidthomas/4526473cba996e99d695b2d7ed0619bd to your computer and use it in GitHub Desktop.
Save joshuadavidthomas/4526473cba996e99d695b2d7ed0619bd to your computer and use it in GitHub Desktop.
Decorator that auto-generates corresponding sync methods from async methods, using `asgiref.sync.async_to_sync`
from __future__ import annotations
import functools
import inspect
import re
from collections.abc import Awaitable
from collections.abc import Sequence
from typing import Any
from typing import Callable
from typing import Generic
from typing import ParamSpec
from typing import Protocol
from typing import TypeVar
from typing import overload
from asgiref.sync import async_to_sync
T = TypeVar("T")
P = ParamSpec("P")
R = TypeVar("R")
AsyncMethod = Callable[P, Awaitable[R]]
SyncMethod = Callable[P, R]
Decorator = Callable[[type[T]], type[T]]
class HasAsyncMethods(Protocol):
    """Structural type: any attribute lookup yields an async callable."""

    def __getattr__(self, name: str) -> AsyncMethod[..., Any]: ...
class HasSyncMethods(Protocol):
    """Structural type: any attribute lookup yields a plain (sync) callable."""

    def __getattr__(self, name: str) -> SyncMethod[..., Any]: ...
# Type variable bounded by the HasAsyncMethods protocol.
# NOTE(review): not referenced anywhere in the visible source; confirm it is
# still needed by callers before relying on it.
T_WithSync = TypeVar("T_WithSync", bound=HasAsyncMethods)
class SyncMethodsDecorator(Generic[T]):
    """Class decorator that creates sync versions of async methods.

    Each selected async method gets a synchronous sibling whose name is the
    async name with ``prefix`` removed (with the default ``"a"``:
    ``amethod`` -> ``method``, ``_amethod`` -> ``_method``,
    ``__aenter__`` -> ``__enter__``).  Calling the sibling runs the coroutine
    to completion via ``asgiref.sync.async_to_sync``.  A sync name that already
    resolves on the class (own or inherited attribute) is never overwritten.

    Args:
        prefix: Prefix marking async methods; stripped to form the sync name.
        ignore_methods: Exact async method names to skip.
        ignore_pattern: Regular expression; async names it matches
            (``re.match`` semantics, i.e. anchored at the start) are skipped.
        include_private: Also convert single-underscore async methods.
        include_dunder: Also convert ``__dunder__`` async methods.
        inherit: Also convert async methods inherited from base classes.  When
            False, only methods in the decorated class's own namespace count.
    """

    def __init__(
        self,
        prefix: str = "a",
        ignore_methods: Sequence[str] | None = None,
        ignore_pattern: str | None = None,
        include_private: bool = False,
        include_dunder: bool = True,
        inherit: bool = False,
    ) -> None:
        self.prefix = prefix
        self.ignored = set(ignore_methods or [])
        self.ignore_re = re.compile(ignore_pattern) if ignore_pattern else None
        self.include_private = include_private
        self.include_dunder = include_dunder
        self.inherit = inherit

    def __call__(self, cls: type[T]) -> type[T]:
        """Attach sync versions to *cls* in place and return the same class."""
        for method_info in self._filter_methods(self._get_async_methods(cls)):
            self._add_sync_method(cls, method_info)
        return cls

    def _get_async_methods(
        self, cls: type[T]
    ) -> Sequence[tuple[str, AsyncMethod[..., Any]]]:
        """Return ``(name, coroutine function)`` pairs found on *cls*.

        With ``inherit=False`` only names present in ``cls.__dict__`` are kept.
        A ``__qualname__`` prefix test would wrongly accept methods of a base
        class whose name merely starts with this class's name (``Foo`` vs
        ``FooBase``).
        """
        async_methods = inspect.getmembers(
            cls, predicate=inspect.iscoroutinefunction
        )
        if self.inherit:
            return async_methods
        own_namespace = vars(cls)
        return [
            (name, method)
            for name, method in async_methods
            if name in own_namespace
        ]

    def _filter_methods(
        self, methods: Sequence[tuple[str, AsyncMethod[..., Any]]]
    ) -> Sequence[tuple[str, AsyncMethod[..., Any]]]:
        """Drop methods excluded by the ignore/private/dunder/prefix settings."""
        filtered: list[tuple[str, AsyncMethod[..., Any]]] = []
        for name, method in methods:
            is_dunder = name.startswith("__") and name.endswith("__")
            is_private = name.startswith("_") and not is_dunder
            if name in self.ignored:
                continue
            if self.ignore_re and self.ignore_re.match(name):
                continue
            if is_private and not self.include_private:
                continue
            if is_dunder and not self.include_dunder:
                continue
            # Dunders are prefix-checked inside the name (see _sync_name).
            if not is_dunder and not name.removeprefix("_").startswith(self.prefix):
                continue
            filtered.append((name, method))
        return filtered

    def _sync_name(self, async_name: str) -> str | None:
        """Return the sync attribute name for *async_name*, or None to skip it.

        Returns None when a dunder lacks the prefix (e.g. ``async def
        __call__``) or when stripping the prefix leaves an empty stem.  The
        empty-stem case (e.g. ``async def a`` with prefix ``"a"``) would
        otherwise produce an attribute named ``""``.
        """
        prefix_len = len(self.prefix)
        if async_name.startswith("__") and async_name.endswith("__"):
            inner = async_name[2:-2]
            if not inner.startswith(self.prefix):
                return None
            # Slice by the real prefix length; a hard-coded [1:] breaks
            # multi-character prefixes such as "async_".
            stem = inner[prefix_len:]
            return f"__{stem}__" if stem else None
        if async_name.startswith("_"):
            stem = async_name[1:][prefix_len:]
            return f"_{stem}" if stem else None
        return async_name[prefix_len:] or None

    def _add_sync_method(
        self, cls: type[T], method_info: tuple[str, AsyncMethod[..., Any]]
    ) -> None:
        """Add a sync version of the async method to the class."""
        async_name, async_method = method_info
        sync_name = self._sync_name(async_name)
        # Never clobber an attribute the class already has (own or inherited).
        if sync_name is None or hasattr(cls, sync_name):
            return

        # async_method is captured by this closure (one per call), so no
        # default-argument trick is needed; that trick also leaked an
        # overridable ``_async_method`` keyword into the public signature.
        @functools.wraps(async_method)
        def sync_method(self: Any, *args: Any, **kwargs: Any) -> Any:
            # Bind via the descriptor protocol so the coroutine receives the
            # same ``self`` a normal attribute lookup would give it.
            bound_method = async_method.__get__(self, self.__class__)
            return async_to_sync(bound_method)(*args, **kwargs)

        # functools.wraps copied the async method's identity (and already set
        # __wrapped__); advertise the name the method is actually stored under.
        sync_method.__name__ = sync_name
        sync_method.__qualname__ = f"{cls.__qualname__}.{sync_name}"
        sync_method.__doc__ = (
            f"Synchronous version of {async_name}.\n\n{async_method.__doc__ or ''}"
        )
        setattr(cls, sync_name, sync_method)
@overload
def with_sync_methods(cls: type[T]) -> type[T]: ...


@overload
def with_sync_methods(
    cls: None = None,
    *,
    prefix: str = "a",
    ignore_methods: Sequence[str] | None = None,
    ignore_pattern: str | None = None,
    include_private: bool = False,
    include_dunder: bool = True,
    inherit: bool = False,
) -> Callable[[type[T]], type[T]]: ...


def with_sync_methods(
    cls: type[T] | None = None,
    *,
    prefix: str = "a",
    ignore_methods: Sequence[str] | None = None,
    ignore_pattern: str | None = None,
    include_private: bool = False,
    include_dunder: bool = True,
    inherit: bool = False,
) -> type[T] | Callable[[type[T]], type[T]]:
    """
    Class decorator that creates sync versions of async methods.

    Works both bare (``@with_sync_methods``) and with keyword options
    (``@with_sync_methods(prefix="async_")``).  Given a class, it is decorated
    and returned.  Given no class, a configured ``SyncMethodsDecorator`` is
    returned for later application.  The keyword options are forwarded
    unchanged to ``SyncMethodsDecorator``.
    """
    configured = SyncMethodsDecorator[T](
        prefix=prefix,
        ignore_methods=ignore_methods,
        ignore_pattern=ignore_pattern,
        include_private=include_private,
        include_dunder=include_dunder,
        inherit=inherit,
    )
    return configured if cls is None else configured(cls)
from __future__ import annotations
from typing import Any
from sync import with_sync_methods
class TestSyncMethodsDecorator:
    """Behavioral tests for the ``with_sync_methods`` class decorator."""

    def test_basic_conversion(self):
        @with_sync_methods
        class TestClass:
            async def amethod(self) -> str:
                return "test"

        obj = TestClass()
        assert hasattr(obj, "method")
        assert obj.method() == "test"

    def test_custom_prefix(self):
        @with_sync_methods(prefix="async_")
        class TestClass:
            async def async_method(self) -> str:
                return "test"

        obj = TestClass()
        assert hasattr(obj, "method")
        assert obj.method() == "test"

    def test_ignore_methods(self):
        @with_sync_methods(ignore_methods=["aignored"])
        class TestClass:
            async def amethod(self) -> str:
                return "test"

            async def aignored(self) -> str:
                return "ignored"

        obj = TestClass()
        assert hasattr(obj, "method")
        assert obj.method() == "test"
        assert not hasattr(obj, "ignored")

    def test_ignore_pattern(self):
        @with_sync_methods(ignore_pattern=r"a.*_skip$")
        class TestClass:
            async def amethod(self) -> str:
                return "test"

            async def amethod_skip(self) -> str:
                return "skipped"

        obj = TestClass()
        assert hasattr(obj, "method")
        assert not hasattr(obj, "method_skip")

    def test_private_methods(self):
        @with_sync_methods(include_private=True)
        class TestClass:
            async def amethod(self) -> str:
                return "public"

            async def _amethod(self) -> str:
                return "private"

        obj = TestClass()
        assert hasattr(obj, "method")
        assert hasattr(obj, "_method")
        assert obj._method() == "private"

    def test_dunder_methods(self):
        @with_sync_methods
        class TestClass:
            async def __aenter__(self) -> str:
                return "entered"

            async def __aexit__(self, *args: Any) -> None:
                pass

        obj = TestClass()
        assert hasattr(obj, "__enter__")
        assert obj.__enter__() == "entered"
        # Both halves of the protocol must be generated for ``with`` to work.
        assert hasattr(obj, "__exit__")
        assert obj.__exit__(None, None, None) is None
        with obj as value:
            assert value == "entered"

    def test_inheritance(self):
        class BaseClass:
            async def amethod(self) -> str:
                return "base"

        @with_sync_methods(inherit=True)
        class ChildClass(BaseClass):
            async def achild_method(self) -> str:
                return "child"

        obj = ChildClass()
        assert hasattr(obj, "method")
        assert hasattr(obj, "child_method")
        assert obj.method() == "base"
        assert obj.child_method() == "child"

    def test_inherited_methods_skipped_by_default(self):
        class BaseClass:
            async def abase_method(self) -> str:
                return "base"

        @with_sync_methods
        class ChildClass(BaseClass):
            async def achild_method(self) -> str:
                return "child"

        obj = ChildClass()
        assert not hasattr(obj, "base_method")
        assert obj.child_method() == "child"

    def test_method_documentation(self):
        @with_sync_methods
        class TestClass:
            async def amethod(self) -> str:
                """Test docstring."""
                return "test"

        assert "Synchronous version of amethod" in TestClass.method.__doc__
        assert "Test docstring" in TestClass.method.__doc__

    def test_no_async_methods(self):
        class TestClass:
            def method(self) -> str:
                return "test"

        # Without async methods the decorator must leave the class untouched.
        # The previous ``hasattr(obj, "sync_method")`` check was vacuous,
        # because no code path ever creates an attribute with that name.
        namespace_before = dict(vars(TestClass))
        assert with_sync_methods(TestClass) is TestClass
        assert dict(vars(TestClass)) == namespace_before
        assert TestClass().method() == "test"

    def test_existing_sync_method(self):
        @with_sync_methods
        class TestClass:
            async def amethod(self) -> str:
                return "async"

            def method(self) -> str:
                return "sync"

        obj = TestClass()
        assert obj.method() == "sync"

    def test_decorator_with_args(self):
        decorator = with_sync_methods(prefix="async_")
        assert callable(decorator)

        @decorator
        class TestClass:
            async def async_method(self) -> str:
                return "test"

        obj = TestClass()
        assert hasattr(obj, "method")
        assert obj.method() == "test"
@joshuadavidthomas
Copy link
Author

Type checkers hate this one trick... but it works!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment