Skip to content

Instantly share code, notes, and snippets.

@AndBondStyle
Created October 2, 2024 04:47
Show Gist options
  • Save AndBondStyle/24ff8cbc2c23465477a1952da452be66 to your computer and use it in GitHub Desktop.
Save AndBondStyle/24ff8cbc2c23465477a1952da452be66 to your computer and use it in GitHub Desktop.
import asyncio
import os
from contextlib import asynccontextmanager
import sqlalchemy as sa
from dependency_injector import providers
from dependency_injector.containers import DeclarativeContainer
from dependency_injector.wiring import Provide, inject
from fastapi import Depends, FastAPI
from sqlalchemy.ext.asyncio import (
async_scoped_session,
async_sessionmaker,
create_async_engine,
)
async def init_db_engine():
    """Create the async SQLAlchemy engine, yield it, then dispose of it.

    Async-generator initializer intended for ``providers.Resource``. The engine
    is built from the ``POSTGRES_DSN`` environment variable with SQL echo
    enabled. Everything after the ``yield`` is teardown.

    Yields:
        The ``AsyncEngine`` returned by ``create_async_engine``.

    Raises:
        KeyError: if ``POSTGRES_DSN`` is not set in the environment.
    """
    engine = create_async_engine(os.environ["POSTGRES_DSN"], echo=True)
    print("engine start")
    try:
        yield engine
    finally:
        # Run teardown even when the generator is closed (aclose) or has an
        # exception thrown into it, so pooled connections are always released.
        print("engine stop")
        await engine.dispose()
# Wiring marker for the container's "db_scoped_session" provider. Until this
# module is wired (container.wire(modules=[__name__]) in lifespan) it is only a
# placeholder. init_session awaits the injected value, so after wiring it is
# presumably an awaitable that resolves to the async_scoped_session registry.
# TODO confirm against dependency_injector's module-attribute wiring behavior.
session_factory = Provide["db_scoped_session"]
async def init_session():
    """Yield a task-scoped ``AsyncSession`` and clean it up afterwards.

    ``session_factory`` is awaited to obtain the container's
    ``async_scoped_session``. This assumes the module has been wired, as the
    ``Provide["db_scoped_session"]`` marker implies. Calling that registry
    returns the session bound to the current asyncio task
    (``scopefunc=asyncio.current_task``).

    The session is closed when the ``async with`` block exits. The registry
    entry for the current task is then removed, so finished tasks and their
    sessions are not kept alive by the registry.

    Yields:
        The ``AsyncSession`` for the current task.
    """
    scoped = await session_factory
    session = scoped()
    try:
        async with session:
            print("session before")
            yield session
            print("session after")
    finally:
        # The scoped registry is a dict keyed on scopefunc() (the current
        # task). Without remove(), one entry per request task would
        # accumulate forever. remove() is a no-op close if the session is
        # already closed, and does nothing when there is no entry for this task.
        await scoped.remove()
class Container(DeclarativeContainer):
    """Dependency-injection container for the database layer.

    Provider chain: engine -> async_sessionmaker -> task-scoped session
    registry. ``db_session`` is the provider injected into ``test2``.
    """

    # Async-generator resource built by init_db_engine (create, yield, dispose).
    db_engine = providers.Resource(init_db_engine)
    # Resource that calls async_sessionmaker with the engine as its first
    # positional argument (``bind``).
    db_session_factory = providers.Resource(async_sessionmaker, db_engine)
    # Singleton async_scoped_session registry. Sessions inside it are keyed by
    # asyncio.current_task, so each task gets its own AsyncSession.
    db_scoped_session = providers.ThreadSafeSingleton(
        async_scoped_session,
        session_factory=db_session_factory,
        scopefunc=asyncio.current_task,
    )
    # Which provider to use?
    # NOTE(review): Callable calls init_session() on each injection. Because
    # init_session is an async generator function, that call returns an async
    # generator object, not an AsyncSession, so ``db.execute`` in test2 would
    # presumably fail. Confirm.
    db_session = providers.Callable(init_session)
    # Object(init_session) would inject the init_session function itself, uncalled.
    # db_session = providers.Object(init_session)
    # NOTE(review): a Resource is initialised once and its yielded value reused
    # until shutdown_resources(), so concurrent requests would likely share a
    # single session. Confirm before switching to this.
    # db_session = providers.Resource(init_session)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create, wire and tear down the DI container for the app's lifetime.

    On startup:

    - a ``Container`` is created and this module is wired, which resolves
      the ``Provide`` markers used by ``session_factory`` and the routes;
    - ``db_scoped_session`` is awaited, which initialises its dependencies
      (the engine and sessionmaker resources) before the first request.

    On exit the container's resources are shut down and the module is
    unwired. This also happens if initialisation fails or an exception is
    raised at the yield point.

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    container = Container()
    container.wire(modules=[__name__])
    try:
        await container.db_scoped_session()
        yield
    finally:
        # Dispose engine/session resources even on error paths, then drop the
        # wiring so a later lifespan (e.g. in tests) starts from a clean state.
        await container.shutdown_resources()
        container.unwire()
# The application instance. The lifespan context manager above creates, wires
# and tears down the DI container.
app = FastAPI(lifespan=lifespan)
# GET /test1: query the server version through a plain FastAPI generator
# dependency (init_session), bypassing the container's db_session provider.
# (Comments rather than a docstring, so the OpenAPI description is unchanged.)
@app.get("/test1")
async def test1(db=Depends(init_session)):
    version_query = sa.text("select version()")
    query_result = await db.execute(version_query)
    return query_result.scalar()
# GET /test2: query the server version using whatever Container.db_session
# injects once the module is wired.
# NOTE(review): with providers.Callable(init_session), ``db`` is an async
# generator object rather than an AsyncSession, so ``execute`` is presumably
# unavailable. Confirm which provider is intended.
@app.get("/test2")
@inject
async def test2(db=Depends(Provide[Container.db_session])):
    version_query = sa.text("select version()")
    query_result = await db.execute(version_query)
    return query_result.scalar()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment