Created January 11, 2023 18:13
Save nkhitrov/efc82508a0a5b9862cb7a481fcfbdd35 to your computer and use it in GitHub Desktop.
Async Celery task with rodi dependency-injection (DI) example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import functools | |
from typing import Any, Callable, Optional | |
from celery import shared_task | |
from rodi import Container, GetServiceContext | |
from redis.asyncio.cluster import RedisCluster | |
from sqlalchemy.ext.asyncio import AsyncSession | |
# Dependency-injection container used by ``async_task`` to build each task's services.
container = Container()
# TODO: register the services ``async_task`` resolves before building the provider:
# AppSettings, DatabaseEngines, RedisCluster, AsyncSession and OrderService.
# For example (placeholders -- fill in real factories/implementations):
#   container.add_scoped_by_factory(<async session factory>)
#   container.add_singleton(...)
# NOTE(review): AsyncSession is presumably registered as *scoped* so that each
# GetServiceContext (one per task run) gets its own session -- confirm.
celery_deps_provider = container.build_provider()
def async_task(lock_name: str | None = None) -> Any:
    """Build a decorator that runs an ``async def`` task body from a synchronous Celery task.

    For every call, a fresh rodi ``GetServiceContext`` is opened on
    ``celery_deps_provider``. ``AppSettings``, ``DatabaseEngines``, ``RedisCluster``,
    ``AsyncSession`` and ``OrderService`` are resolved inside it. The task body is then
    awaited with ``order_service=`` and ``settings=`` keyword arguments. The session,
    the database engines and the Redis connection are released afterwards, whether the
    body succeeded or raised.

    Args:
        lock_name: Accepted for API compatibility; not used by this implementation.

    Returns:
        A decorator that turns an async task function into a synchronous callable
        suitable for ``@shared_task``. The callable runs the task via ``asyncio.run``.
        If it is called with ``return_coroutine=True``, it instead returns the
        un-awaited coroutine, for callers that already run an event loop.
    """

    def decorator(task: Any) -> Callable[..., Any]:
        @functools.wraps(task)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Strip the control flag up front so it is never forwarded to the task body.
            return_coroutine = kwargs.pop("return_coroutine", False)

            async def task_controller() -> Any:
                with GetServiceContext(provider=celery_deps_provider) as context:  # type: ignore
                    settings = celery_deps_provider.get(AppSettings, context=context)
                    engines: DatabaseEngines = celery_deps_provider.get(DatabaseEngines, context=context)
                    redis_cluster: RedisCluster = celery_deps_provider.get(RedisCluster, context=context)
                    session: AsyncSession = celery_deps_provider.get(AsyncSession, context=context)
                    # The task body requires ``order_service``; resolve it from the same
                    # scope as the other per-run dependencies.
                    order_service = celery_deps_provider.get(OrderService, context=context)
                    try:
                        # Inside the ``try`` so that a failed Redis handshake still
                        # releases the resources resolved above.
                        await initialize_redis_connection(redis_cluster)
                        return await task(*args, order_service=order_service, settings=settings, **kwargs)
                    finally:
                        await _release_task_resources(session, engines, redis_cluster)

            if return_coroutine:
                return task_controller()
            try:
                return asyncio.run(task_controller())
            except Exception as error:
                # NOTE(review): errors are logged and swallowed, so Celery records the
                # task as successful and ``autoretry_for`` will never fire. Re-raise here
                # if failures should be visible to Celery.
                # ``logger`` is not defined in this snippet; the ``error=`` keyword suggests
                # a structlog-style logger -- confirm.
                logger.exception("unhandled celery async task error", error=error)

        return wrapper

    return decorator


async def _release_task_resources(
    session: AsyncSession,
    engines: DatabaseEngines,
    redis_cluster: RedisCluster,
) -> None:
    """Close the per-run session, then disconnect the engines and the Redis cluster.

    The engine and Redis shutdowns always run, even if closing the session raises.
    A session-close error still propagates. Failures from the two concurrent
    shutdowns are logged rather than raised, so they do not mask the task's result.
    """
    try:
        await session.close()
    finally:
        # ``asyncio.wait`` rejects bare coroutines on Python 3.11+, and it never
        # retrieves their exceptions; ``gather`` wraps the coroutines in tasks
        # itself and, with ``return_exceptions=True``, returns failures instead
        # of raising them.
        outcomes = await asyncio.gather(
            engines.disconnect(),
            close_redis_connection(redis_cluster),
            return_exceptions=True,
        )
        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                logger.error("celery async task cleanup error", error=outcome)
@shared_task(name="some_task_name")
@async_task()
async def some_task_controller(*, order_service: OrderService, settings: AppSettings) -> None:
    """Example Celery task whose keyword-only dependencies are injected by ``async_task``.

    Placeholder: the actual task logic using ``order_service`` and ``settings`` goes here.
    """
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.