Last active
March 9, 2025 15:18
-
-
Save ddelange/643fbb791b398783c04d1ceb90102163 to your computer and use it in GitHub Desktop.
Make a sync function async
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from functools import wraps, partial | |
def run_in_executor(fn=None, *, executor=None):
    """Turn a blocking callable into a coroutine function.

    The returned coroutine function runs ``fn(*args, **kwargs)`` in
    ``executor`` and awaits the result. With ``executor=None`` the running
    event loop's default executor is used (``loop.run_in_executor(None, ...)``).

    Usable bare (``@run_in_executor``) or called
    (``@run_in_executor(executor=ThreadPoolExecutor(1))``).

    Args:
        fn: Function to decorate. When omitted, a decorator bound to
            ``executor`` is returned instead.
        executor: Executor pool to execute fn in. An object exposing a
            ``coro_apply`` method (e.g. aioprocessing.pool.AioPool) is driven
            through that method rather than ``loop.run_in_executor``.

    Returns:
        A coroutine function wrapping ``fn``, or a decorator when ``fn`` is None.
    """
    if fn is None:
        # Bracketed form: remember the executor and wait for the function.
        return partial(run_in_executor, executor=executor)

    @wraps(fn)
    async def wrapped(*args, **kwargs):
        """Run the wrapped callable in the executor and await its result."""
        call = partial(fn, *args, **kwargs)
        if not hasattr(executor, "coro_apply"):
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(executor, call)
        # aioprocessing.pool.AioPool-style pools hand back an awaitable here.
        return await executor.coro_apply(call)

    return wrapped
async def examples():
    """Demonstrate the supported ways of applying run_in_executor.

    Each case decorates a small blocking function and awaits it. The last case
    needs the optional ``aioprocessing[dill]>=2`` package.
    """
    # without brackets: runs in the event loop's default executor
    @run_in_executor
    def test1():
        print(1)

    await test1()

    # with brackets (no arguments): equivalent to the bare form
    @run_in_executor()
    def test2():
        print(2)

    await test2()

    # with explicit ThreadPoolExecutor. The with-block shuts the pool down
    # (joining its worker threads) once the example is done, instead of
    # leaking the threads until interpreter exit.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(4) as pool:

        @run_in_executor(executor=pool)
        def test3():
            print(3)

        await test3()

    # with explicit AioPool
    # pip install 'aioprocessing[dill]>=2'
    from aioprocessing.pool import AioPool

    # NOTE(review): this pool is never closed either; confirm AioPool's
    # shutdown API (close/join) before adding cleanup here.
    @run_in_executor(executor=AioPool(4, maxtasksperchild=4))
    def test4():
        print(4)

    await test4()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
from functools import wraps, partial | |
from typing import Any, Awaitable, Callable, Optional | |
from executor import run_in_executor # from above | |
def async_proxy(
    fn: Optional[Callable] = None,
    *,
    predicate: Optional[Callable] = None,
    executor: Any = None,
) -> Callable:
    """Decorate a class or a function that constructs an object, wrapping in AsyncProxy.

    The decorated callable becomes a coroutine function: awaiting it runs the
    original constructor via run_in_executor and returns the new object
    wrapped in an AsyncProxy that shares ``predicate`` and ``executor``.

    Examples:
        >>> open = async_proxy(open)
        ... async with await open("/tmp/tmp", "wb") as fp:
        ...     await fp.write(b"...")
        3
        >>> [line async for line in await open("/tmp/tmp")]
        ['...']
        >>> from io import BytesIO
        ... ABytesIO = async_proxy(BytesIO)
        ... await (await ABytesIO()).write(b"...")
        3
        >>> @async_proxy
        ... class ABytesIO(BytesIO):
        ...     pass
        ... await (await ABytesIO()).write(b"...")
        3
        >>> @async_proxy(executor=ThreadPoolExecutor(1))
        ... class ABytesIO(BytesIO):
        ...     pass
        >>> import botocore, boto3, smart_open
        ... client = boto3.session.Session().client("s3", config=botocore.client.Config(max_pool_connections=64, tcp_keepalive=True, retries={"max_attempts": 6, "mode": "adaptive"}))
        ... open = async_proxy(partial(smart_open.open, transport_params={"client": client}))
        ... async with await open("s3://bucket/tmp", "wb") as fp:
        ...     await fp.write(b"...")
        3

    Args:
        fn: Function or class to decorate. When omitted (bracketed use such as
            ``@async_proxy(executor=...)``), a decorator is returned.
        predicate: Function passed to filter() to determine which methods to proxy.
        executor: Executor to be passed to run_in_executor.

    Returns:
        Decorated class or constructor (a coroutine function returning an
        AsyncProxy), or a decorator when ``fn`` is None.
    """
    if fn is None:
        # Bracketed use: bind the options and wait for the class/function.
        # Previously unreachable in practice because ``fn`` had no default.
        return partial(async_proxy, predicate=predicate, executor=executor)

    @run_in_executor(executor=executor)
    @wraps(fn)
    def wrapped(*args, **kwargs) -> "AsyncProxy":
        """Construct the object in the executor and wrap it in an AsyncProxy."""
        return AsyncProxy(fn(*args, **kwargs), predicate=predicate, executor=executor)

    return wrapped
class AsyncProxy(object):
    """A threaded proxy where the user-facing methods are coroutines.

    On construction, every member of ``obj`` accepted by ``predicate`` is
    stored on the proxy as a coroutine function that runs the original member
    through ``run_in_executor``. Attribute reads the proxy cannot satisfy fall
    back to ``obj``, and attribute writes are forwarded to ``obj``.
    ``async with`` and ``async for`` run ``obj``'s synchronous
    ``__enter__``/``__exit__`` and ``next`` in the executor.

    Examples:
        >>> fp = AsyncProxy(open("/tmp/tmp", "wb"))
        ... await fp.write(b"...")
        3
        >>> await fp.close()
        >>> async with AsyncProxy(open("/tmp/tmp", "rb")) as fp:
        ...     async for line in fp:
        ...         print(line)
        b'...'
        >>> import gzip
        ... gzip = AsyncProxy(gzip)
        ... await gzip.decompress(b"")
        b''
    """

    # Private end-of-iteration marker passed as next()'s default. A fresh
    # object cannot collide with a value the wrapped iterator yields; the
    # previous ``...`` marker silently ended iteration at a yielded Ellipsis.
    __EXHAUSTED = object()

    def __init__(
        self,
        obj: object,
        predicate: Optional[Callable] = None,
        executor: Any = None,
    ):
        """Create an object with threaded proxies for (a subset of) existing methods.

        Args:
            obj: Object or class instance to proxy.
            predicate: Function passed to filter() over inspect.getmembers(obj),
                i.e. called with ``(name, value)`` tuples, to determine which
                members to proxy. Defaults to every callable member whose name
                does not start with ``"__"``.
            executor: Executor to be passed to run_in_executor.
        """

        def proxy_public_callables(member):
            # member is a (name, value) pair from inspect.getmembers
            name, value = member
            return not name.startswith("__") and callable(value)

        # __setattr__ below forwards writes to the wrapped object, so all of
        # the proxy's own state must be written with object.__setattr__.
        # A plain ``self.x = ...`` here would reach for the not-yet-set
        # wrapped object and recurse through __getattr__ into RecursionError.
        object.__setattr__(self, "_AsyncProxy__obj", obj)
        object.__setattr__(self, "predicate", predicate or proxy_public_callables)
        object.__setattr__(self, "executor", executor)
        object.__setattr__(
            self, "_AsyncProxy__threaded", partial(run_in_executor, executor=executor)
        )
        for name, routine in filter(self.predicate, inspect.getmembers(obj)):
            object.__setattr__(self, name, self.__threaded(routine))

    def __getattr__(self, attr):
        """Fall back to the wrapped object's attribute.

        Only invoked after normal lookup on the proxy itself has failed.
        """
        # Fetch the wrapped object via object.__getattribute__ so that a
        # missing _AsyncProxy__obj (e.g. an instance created without running
        # __init__) raises AttributeError instead of recursing forever.
        return getattr(object.__getattribute__(self, "_AsyncProxy__obj"), attr)

    def __setattr__(self, attr, val):
        """Proxy setattr to the original object (__init__ circumvents this)."""
        return setattr(self.__obj, attr, val)

    def __aiter__(self):
        """Return a proxy around the wrapped object's synchronous iterator."""
        return self.__class__(self.__iter__(), predicate=self.predicate, executor=self.executor)

    async def __anext__(self):
        """Fetch the next item of the wrapped iterator in the executor."""
        # Exhaustion is signalled with a sentinel instead of letting
        # StopIteration escape, because StopIteration cannot be raised into a
        # Future (https://stackoverflow.com/a/61774972/5511061).
        item = await self.__threaded(next)(self.__obj, self.__EXHAUSTED)
        if item is self.__EXHAUSTED:
            raise StopAsyncIteration  # https://peps.python.org/pep-0492/#example-2
        return item

    async def __aenter__(self):
        """Run the wrapped __enter__ in the executor and proxy its result.

        The result is wrapped with the same predicate and executor as this
        proxy (previously they were dropped, so the body of ``async with``
        silently ran in the default executor).
        """
        entered = await self.__threaded(self.__enter__)()
        return self.__class__(entered, predicate=self.predicate, executor=self.executor)

    async def __aexit__(self, *args, **kwargs):
        """Run the wrapped __exit__ in the executor and return its result."""
        return await self.__threaded(self.__exit__)(*args, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment