Note (or weird behavior) on the lifespan of `asynccontextmanager` in background tasks


related: How to extend the lifespan of an object under asynccontextmanager into a background task in FastAPI?

!!! note There is some confusing behavior in this example. If you are not aware of it, please don't use this code directly in your project.

Idea

This app contains a simplified chat management system. You can think of it as an LLM chatbot: you request the get_chat endpoint to get a streaming response from the LLM. To make this somewhat asynchronous, the client can retrieve the response across several follow-up requests (one sentence per request) rather than waiting for the whole LLM response to finish. In practice, the LLM response is returned as an async generator, which is iterated later in a background task, so the original request finishes before the async generator does.

Meanwhile, I tried to save/load the chat sessions in an object-oriented way. I added save and load methods to interact with the cache, and a get_chat_from_cache async context manager to fetch the chat session object from the cache (asynccontextmanager = a wrapper that converts an async generator into something that can be used with async with), so that saving/loading is managed automatically. Additionally, after the response generator is drained, we also need to save the full response to the cache. That part is integrated into the async generator that the background task consumes.

However, as you can see, the lifespan of the chat session object under the asynccontextmanager ends before the initial response of the get_chat endpoint is returned, yet the object is still used later when the async generator is iterated, which happens outside the context. No error is raised; the later changes are simply never cached.
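
A minimal, FastAPI-free sketch of the same behavior (the names managed and stream are made up for illustration, not taken from the app below): the context manager "saves" a copy of the object on exit, but the generator that still holds a reference to the object is only drained afterwards, so the mutation it makes never reaches the saved copy.

```python
import asyncio
import copy
from contextlib import asynccontextmanager

cache = {}

@asynccontextmanager
async def managed(obj: dict):
    yield obj
    cache["saved"] = copy.deepcopy(obj)  # the "save" happens here, on context exit

async def main():
    async with managed({"history": []}) as obj:
        async def stream():
            yield "chunk"
            obj["history"].append("chunk")  # runs only when the generator is drained
        gen = stream()
    # the context has already exited (and "saved"); now drain the generator,
    # just like the background task does after the response is returned
    async for _ in gen:
        pass
    print(obj["history"])    # ['chunk']       -- the live object did get the message
    print(cache["saved"])    # {'history': []} -- but the saved copy missed it

asyncio.run(main())
```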

Now I can think of two possible ways to address this issue:

  • Re-enter get_chat_from_cache with the same sess_id inside the async generator, creating a new object whose own context exit saves the changes (see the sketch after this list).
  • Manually save to the cache inside the generator. In my opinion this is clearly somewhat of a hack.
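
For the first option, here is a rough sketch of how Chat.chat could be rewritten. This is an assumption of mine, not code from the gist; it reuses get_chat_from_cache and the other names defined in the code below.

```python
# hypothetical replacement for Chat.chat (option 1): re-enter the cache context
# with the same sess_id once the response is complete, so that the inner
# context's own exit saves the appended message
async def chat(self, msg: str, cache: dict):
    self.add_message(msg)

    async def get_chat():
        resp = []
        for i in range(random.randint(3, 5)):
            await asyncio.sleep(0.5)
            chunk = f"resp{i}:{random.randbytes(2).hex()};"
            resp.append(chunk)
            yield chunk
        # a fresh object is loaded from the cache here; exiting this inner
        # `async with` saves it again, even though the request's context is gone
        async with get_chat_from_cache(sess_id=self.id, cache=cache) as chat:
            chat.add_message("".join(resp))

    return get_chat()
```

The trade-off is that the object yielded to the endpoint and the object saved here are different instances, so any in-memory state of the original self is not reflected in the saved copy.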

Run the test

# install fastapi, httpx, pytest
# My environment: fastapi==0.115.5
pip install -r requirements.txt
# run the test
# if you uncomment the hack, this would pass
pytest -svv test_app.py
from fastapi import FastAPI, Depends, BackgroundTasks, Request
from typing import Annotated, AsyncIterator
from pydantic import BaseModel, Field
from uuid import uuid4
from contextlib import asynccontextmanager
import random
import asyncio
app = FastAPI()


class Chat(BaseModel):
    """
    This is an over-simplified Chat History Manager that could be used in e.g. a LangChain-like system.
    There is an additional `total` field because history entries are serialized and cached on their own,
    and we don't want to load all of them when deserializing a session from cache/database.
    """

    id: str = Field(default_factory=lambda: uuid4().hex)
    meta: str = "some meta information"
    history: list[str] = []
    total: int = 0
    uncached: int = 0

    def add_message(self, msg: str):
        self.history.append(msg)
        self.total += 1
        self.uncached += 1

    async def save(self, cache: dict):
        # cache history entries that are not cached yet
        for imsg in range(-self.uncached, 0):
            cache[f"msg:{self.id}:{self.total + imsg}"] = self.history[imsg]
        self.uncached = 0
        # cache everything except history
        cache[f"sess:{self.id}"] = self.model_dump(exclude={"history"})
        print(f"saved: {self}")

    @classmethod
    async def load(cls, sess_id: str, cache: dict, max_read: int = 30):
        sess_key = f"sess:{sess_id}"
        obj = cls.model_validate(cache.get(sess_key))
        for imsg in range(max(0, obj.total - max_read), obj.total):
            obj.history.append(cache.get(f"msg:{obj.id}:{imsg}"))
        print(f"loaded: {obj}")
        return obj

    async def chat(self, msg: str, cache: dict):
        """Add the user message to history and return an async generator that streams the response."""
        self.add_message(msg)

        async def get_chat():
            resp = []
            for i in range(random.randint(3, 5)):
                # simulate long network IO
                await asyncio.sleep(0.5)
                chunk = f"resp{i}:{random.randbytes(2).hex()};"
                resp.append(chunk)
                yield chunk
            self.add_message("".join(resp))
            # NOTE to make the message cache work properly, we have to manually save this:
            # await self.save(cache)

        return get_chat()

# use a simple dict to mimic an actual cache, e.g. Redis
cache = {}


async def get_cache():
    return cache

# didn't figure out how to make Chat a dependable
# I have read https://fastapi.tiangolo.com/advanced/advanced-dependencies/#parameterized-dependencies but still have no clue
# the problem is: `sess_id` is passed by the user, not something we can fix in advance like that tutorial shows.
# As an alternative, I used this async context manager.
# Theoretically this would automatically save the Chat object after exiting the `async with` block.
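# (aside, not from the original gist: a dependency function can itself declare
# the `sess_id` path parameter and FastAPI will resolve it from the request
# path, roughly:
#
#     async def chat_dep(sess_id: str, cache: Annotated[dict, Depends(get_cache)]):
#         async with get_chat_from_cache(sess_id=sess_id, cache=cache) as chat:
#             yield chat
#
# whether the cleanup of such a yield-dependency runs before or after background
# tasks depends on the FastAPI version, so this alone does not fix the caching issue.)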
@asynccontextmanager
async def get_chat_from_cache(sess_id: str, cache: dict):
    """
    get the object from cache (possibly create one), yield it, then save it back to cache
    """
    sess_key = f"sess:{sess_id}"
    if sess_key not in cache:
        obj = Chat()
        obj.id = sess_id
        await obj.save(cache)
    else:
        obj = await Chat.load(sess_id, cache)
    yield obj
    await obj.save(cache)


async def task(sess_id: str, task_id: int, resp_gen: AsyncIterator[str], cache: dict):
    """Drain the response generator and cache each chunk (runs as a background task)."""
    async for chunk in resp_gen:
        # do something with chunk, e.g. stream it to the client via a websocket
        await asyncio.sleep(0.5)
        cache[f"chunk:{sess_id}:{task_id}"] = chunk
        task_id += 1

@app.get("/{sess_id}/{task_id}/{prompt}")
async def get_chat(
req: Request,
sess_id: str,
task_id: int,
prompt: str,
background_task: BackgroundTasks,
cache: Annotated[dict, Depends(get_cache)],
):
print(f"req incoming: {req.url}")
async with get_chat_from_cache(sess_id=sess_id, cache=cache) as chat:
resp_gen = await chat.chat(f"prompt:{prompt}", cache=cache)
background_task.add_task(
task, sess_id=sess_id, task_id=task_id, resp_gen=resp_gen, cache=cache
)
return "success"
@app.get("/{sess_id}")
async def get_sess(
req: Request, sess_id: str, cache: Annotated[dict, Depends(get_cache)]
):
print(f"req incoming: {req.url}")
return (await Chat.load(sess_id=sess_id, cache=cache)).model_dump()
from fastapi.testclient import TestClient
from app import app


def test_app():
    with TestClient(app) as client:
        resp = client.get('/test/1/test_input')
        assert resp.status_code == 200

        resp = client.get('/test')
        assert resp.status_code == 200
        # check if the response message is added
        assert len(resp.json()['history']) == 2