LangChain with FastAPI stream example
"""This is an example of how to use async langchain with fastapi and return a streaming response. | |
The latest version of Langchain has improved its compatibility with asynchronous FastAPI, | |
making it easier to implement streaming functionality in your applications. | |
""" | |
import asyncio | |
import os | |
from typing import AsyncIterable, Awaitable | |
import uvicorn | |
from dotenv import load_dotenv | |
from fastapi import FastAPI | |
from fastapi.responses import StreamingResponse | |
from langchain.callbacks import AsyncIteratorCallbackHandler | |
from langchain.chat_models import ChatOpenAI | |
from langchain.schema import HumanMessage | |
from pydantic import BaseModel | |
# Two ways to load env variables | |
# 1.load env variables from .env file | |
load_dotenv() | |
# 2.manually set env variables | |
if "OPENAI_API_KEY" not in os.environ: | |
os.environ["OPENAI_API_KEY"] = "" | |
app = FastAPI() | |
async def send_message(message: str) -> AsyncIterable[str]: | |
callback = AsyncIteratorCallbackHandler() | |
model = ChatOpenAI( | |
streaming=True, | |
verbose=True, | |
callbacks=[callback], | |
) | |
async def wrap_done(fn: Awaitable, event: asyncio.Event): | |
"""Wrap an awaitable with a event to signal when it's done or an exception is raised.""" | |
try: | |
await fn | |
except Exception as e: | |
# TODO: handle exception | |
print(f"Caught exception: {e}") | |
finally: | |
# Signal the aiter to stop. | |
event.set() | |
# Begin a task that runs in the background. | |
task = asyncio.create_task(wrap_done( | |
model.agenerate(messages=[[HumanMessage(content=message)]]), | |
callback.done), | |
) | |
async for token in callback.aiter(): | |
# Use server-sent-events to stream the response | |
yield f"data: {token}\n\n" | |
await task | |
class StreamRequest(BaseModel): | |
"""Request body for streaming.""" | |
message: str | |
@app.post("/stream") | |
def stream(body: StreamRequest): | |
return StreamingResponse(send_message(body.message), media_type="text/event-stream") | |
if __name__ == "__main__": | |
uvicorn.run(host="0.0.0.0", port=8000, app=app) |
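Side note: the docstring above mentions that newer LangChain releases make async streaming easier. As a rough sketch of that (my own addition, not part of the gist, and it assumes a LangChain release newer than the 0.0.181 pinned below, where chat models expose .astream()), the callback handler and background task can be replaced by a plain async generator:

# Sketch of a drop-in alternative to send_message() above; assumes a newer LangChain
# where ChatOpenAI exposes .astream() (in the newest releases the import lives in the
# langchain_openai package instead).
from typing import AsyncIterable

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage


async def send_message_astream(message: str) -> AsyncIterable[str]:
    model = ChatOpenAI(streaming=True)
    # .astream() yields message chunks as the model produces them, so no callback
    # handler or wrap_done() task is needed.
    async for chunk in model.astream([HumanMessage(content=message)]):
        yield f"data: {chunk.content}\n\n"

The /stream route stays the same, just returning StreamingResponse(send_message_astream(body.message), media_type="text/event-stream").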
"""This is an example of how to use async langchain with fastapi and return a streaming response.""" | |
import os | |
from typing import Any, Optional, Awaitable, Callable, Union | |
import uvicorn | |
from dotenv import load_dotenv | |
from fastapi import FastAPI | |
from fastapi.responses import StreamingResponse | |
from langchain.callbacks.base import AsyncCallbackHandler | |
from langchain.callbacks.manager import AsyncCallbackManager | |
from langchain.chat_models import ChatOpenAI | |
from langchain.schema import HumanMessage | |
from pydantic import BaseModel | |
from starlette.types import Send | |
# two ways to load env variables | |
# 1.load env variables from .env file | |
load_dotenv() | |
# 2.manually set env variables | |
if "OPENAI_API_KEY" not in os.environ: | |
os.environ["OPENAI_API_KEY"] = "" | |
app = FastAPI() | |
Sender = Callable[[Union[str, bytes]], Awaitable[None]] | |
class AsyncStreamCallbackHandler(AsyncCallbackHandler): | |
"""Callback handler for streaming, inheritance from AsyncCallbackHandler.""" | |
def __init__(self, send: Sender): | |
super().__init__() | |
self.send = send | |
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: | |
"""Rewrite on_llm_new_token to send token to client.""" | |
await self.send(f"data: {token}\n\n") | |
class ChatOpenAIStreamingResponse(StreamingResponse): | |
"""Streaming response for openai chat model, inheritance from StreamingResponse.""" | |
def __init__( | |
self, | |
generate: Callable[[Sender], Awaitable[None]], | |
status_code: int = 200, | |
media_type: Optional[str] = None, | |
) -> None: | |
super().__init__(content=iter(()), status_code=status_code, media_type=media_type) | |
self.generate = generate | |
async def stream_response(self, send: Send) -> None: | |
"""Rewrite stream_response to send response to client.""" | |
await send( | |
{ | |
"type": "http.response.start", | |
"status": self.status_code, | |
"headers": self.raw_headers, | |
} | |
) | |
async def send_chunk(chunk: Union[str, bytes]): | |
if not isinstance(chunk, bytes): | |
chunk = chunk.encode(self.charset) | |
await send({"type": "http.response.body", "body": chunk, "more_body": True}) | |
# send body to client | |
await self.generate(send_chunk) | |
# send empty body to client to close connection | |
await send({"type": "http.response.body", "body": b"", "more_body": False}) | |
def send_message(message: str) -> Callable[[Sender], Awaitable[None]]: | |
async def generate(send: Sender): | |
model = ChatOpenAI( | |
streaming=True, | |
verbose=True, | |
callback_manager=AsyncCallbackManager([AsyncStreamCallbackHandler(send)]), | |
) | |
await model.agenerate(messages=[[HumanMessage(content=message)]]) | |
return generate | |
class StreamRequest(BaseModel): | |
"""Request body for streaming.""" | |
message: str | |
@app.post("/stream") | |
def stream(body: StreamRequest): | |
return ChatOpenAIStreamingResponse(send_message(body.message), media_type="text/event-stream") | |
if __name__ == "__main__": | |
uvicorn.run(host="0.0.0.0", port=8000, app=app) |
aiohttp==3.8.4 ; python_full_version >= "3.8.1" and python_version < "3.12"
aiosignal==1.3.1 ; python_full_version >= "3.8.1" and python_version < "3.12"
anyio==3.7.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
async-timeout==4.0.2 ; python_full_version >= "3.8.1" and python_version < "3.12"
attrs==23.1.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
certifi==2023.5.7 ; python_full_version >= "3.8.1" and python_version < "3.12"
charset-normalizer==3.1.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
click==8.1.3 ; python_full_version >= "3.8.1" and python_version < "3.12"
colorama==0.4.6 ; python_full_version >= "3.8.1" and python_version < "3.12" and platform_system == "Windows"
dataclasses-json==0.5.7 ; python_full_version >= "3.8.1" and python_version < "3.12"
exceptiongroup==1.1.1 ; python_full_version >= "3.8.1" and python_version < "3.11"
fastapi==0.95.2 ; python_full_version >= "3.8.1" and python_version < "3.12"
frozenlist==1.3.3 ; python_full_version >= "3.8.1" and python_version < "3.12"
greenlet==2.0.2 ; python_full_version >= "3.8.1" and python_version < "3.12" and (platform_machine == "win32" or platform_machine == "WIN32" or platform_machine == "AMD64" or platform_machine == "amd64" or platform_machine == "x86_64" or platform_machine == "ppc64le" or platform_machine == "aarch64")
h11==0.14.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
idna==3.4 ; python_full_version >= "3.8.1" and python_version < "3.12"
langchain==0.0.181 ; python_full_version >= "3.8.1" and python_version < "3.12"
marshmallow-enum==1.5.1 ; python_full_version >= "3.8.1" and python_version < "3.12"
marshmallow==3.19.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
multidict==6.0.4 ; python_full_version >= "3.8.1" and python_version < "3.12"
mypy-extensions==1.0.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
numexpr==2.8.4 ; python_full_version >= "3.8.1" and python_version < "3.12"
numpy==1.24.3 ; python_full_version >= "3.8.1" and python_version < "3.12"
openai==0.27.7 ; python_full_version >= "3.8.1" and python_version < "3.12"
openapi-schema-pydantic==1.2.4 ; python_full_version >= "3.8.1" and python_version < "3.12"
packaging==23.1 ; python_full_version >= "3.8.1" and python_version < "3.12"
pydantic==1.10.8 ; python_full_version >= "3.8.1" and python_version < "3.12"
python-dotenv==1.0.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
pyyaml==6.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
requests==2.31.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
sniffio==1.3.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
sqlalchemy==2.0.15 ; python_full_version >= "3.8.1" and python_version < "3.12"
starlette==0.27.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
tenacity==8.2.2 ; python_full_version >= "3.8.1" and python_version < "3.12"
tqdm==4.65.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
typing-extensions==4.6.2 ; python_full_version >= "3.8.1" and python_version < "3.12"
typing-inspect==0.9.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
urllib3==2.0.2 ; python_full_version >= "3.8.1" and python_version < "3.12"
uvicorn==0.22.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
yarl==1.9.2 ; python_full_version >= "3.8.1" and python_version < "3.12"
#!/usr/bin/env sh
# This script is used to test the /stream endpoint.
curl "http://127.0.0.1:8000/stream" -X POST -d '{"message": "hello!"}' -H 'Content-Type: application/json'
My way: use LlamaCpp, llm.stream(), and yield.
import asyncio
import os
from typing import AsyncIterable, Awaitable

import uvicorn
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from langchain.llms import LlamaCpp
from langchain.cache import InMemoryCache
from langchain.globals import set_llm_cache

set_llm_cache(InMemoryCache())

app = FastAPI()

model_path = "/project/llama_data/Llama-2-7b-chat-hf/ggml-model-q4_0.gguf"
llm = LlamaCpp(
    model_path=model_path,
    n_gpu_layers=40,
    n_batch=512,
    temperature=0.1,
    verbose=True,
    n_ctx=512,
)


async def request_qa_stream(question: str):
    # Note: llm.stream() is synchronous, so this loop blocks the event loop while generating.
    for text in llm.stream(question):
        yield text


def request_qa(question: str):
    result = llm(question)
    return result


class QARequest(BaseModel):
    question: str


# curl "http://127.0.0.1:8000/qa/stream" -X POST -d '{"question":"who are you"}' -H 'Content-Type: application/json'
@app.post("/qa/stream")
def qa_stream(body: QARequest):  # renamed from qa to avoid shadowing the handler below
    return StreamingResponse(request_qa_stream(body.question), media_type="text/event-stream")


# curl "http://127.0.0.1:8000/qa" -X POST -d '{"question":"who are you"}' -H 'Content-Type: application/json'
@app.post("/qa")
def qa(body: QARequest):
    return request_qa(body.question)


if __name__ == "__main__":
    uvicorn.run(host="0.0.0.0", port=8000, app=app)
Or like this, using agenerate and an async iterator callback:
from langchain.callbacks import AsyncIteratorCallbackHandler


async def request_qa_stream(question: str) -> AsyncIterable[str]:
    callback = AsyncIteratorCallbackHandler()
    # A plain list of handlers works here; passing a single handler to CallbackManager would fail.
    llm.callbacks = [callback]

    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        """Wrap an awaitable with an event to signal when it's done or an exception is raised."""
        try:
            await fn
        except Exception as e:
            # TODO: handle exception
            print(f"Caught exception: {e}")
        finally:
            # Signal the aiter to stop.
            event.set()

    # Begin a task that runs in the background.
    task = asyncio.create_task(wrap_done(
        llm.agenerate([question]),
        callback.done),
    )

    async for token in callback.aiter():
        # Use server-sent events to stream the response.
        yield f"{token}"

    await task
The second approach might handle concurrent requests better, I guess, but I have never tested it under concurrent load.
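One caveat on the first variant: the for loop over llm.stream() is synchronous, so it blocks the event loop while a response is being generated, which hurts concurrent requests. A rough, untested sketch (my own addition, reusing the llm instance defined above) that moves the blocking loop into a worker thread:

import asyncio
from typing import AsyncIterable


async def request_qa_stream_threaded(question: str) -> AsyncIterable[str]:
    """Run the blocking llm.stream() loop in a worker thread and relay chunks via a queue."""
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    done = object()  # sentinel marking the end of the stream

    def produce() -> None:
        try:
            # llm is the LlamaCpp instance defined above; llm.stream() yields text chunks.
            for text in llm.stream(question):
                loop.call_soon_threadsafe(queue.put_nowait, text)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, done)

    producer = loop.run_in_executor(None, produce)
    while True:
        chunk = await queue.get()
        if chunk is done:
            break
        yield chunk
    await producer  # re-raise any exception from the worker thread

The /qa/stream route would then return StreamingResponse(request_qa_stream_threaded(body.question), media_type="text/event-stream").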
Hi all!
I wanted to share a custom stream response that I implemented in my FastAPI application recently.
I created this solution to manage streaming data.
You could use LangChain's Stream/Event APIs instead, but I'm doing special things with the handlers, which is why I need this.
Here are some examples:
FastAPI
@router.get("/myExample")
async def mySpecialAPI(
    session_id: UUID,
    input="Hello",
) -> StreamResponse:
    # Note: Don't write await, we need a coroutine
    invoke = chain.ainvoke(..)
    callback = MyCallback(..)
    return StreamResponse(invoke, callback)
Custom Stream Response
from __future__ import annotations

import asyncio
import typing
from typing import Any, AsyncIterable, Coroutine

from fastapi.responses import StreamingResponse as FastApiStreamingResponse
from starlette.background import BackgroundTask


class StreamResponse(FastApiStreamingResponse):
    def __init__(
        self,
        invoke: Coroutine,
        callback: MyCustomAsyncIteratorCallbackHandler,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = "text/event-stream",
        background: BackgroundTask | None = None,
    ) -> None:
        super().__init__(
            content=StreamResponse.send_message(callback, invoke),
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )

    @staticmethod
    async def send_message(
        callback: AsyncIteratorCallbackHandler, invoke: Coroutine
    ) -> AsyncIterable[str]:
        asyncio.create_task(invoke)
        async for token in callback.aiter():
            yield token
My Custom Callback Handler
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator, List, Optional

from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema import LLMResult


class MyCustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
    """Callback handler that returns an async iterator."""

    # Note: the queue items can be a BaseModel rather than str
    queue: asyncio.Queue[Optional[str]]

    # Pass your params as you want
    def __init__(self) -> None:
        self.queue = asyncio.Queue()

    async def on_llm_new_token(
        self,
        token: str,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        self.queue.put_nowait(token)

    async def on_llm_end(
        self,
        response: LLMResult,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        self.queue.put_nowait(None)

    # Note: etc. for errors

    async def aiter(self) -> AsyncIterator[str]:
        while True:
            token = await self.queue.get()
            if isinstance(token, str):
                yield token  # Note: or a BaseModel.model_dump_json(), etc.
            elif token is None:
                self.queue.task_done()
                break
https://gist.github.com/YanSte/7be29bc93f21b010f64936fa334a185f
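The "etc. for errors" note above could be filled in like this; a small sketch of my own (not from YanSte's gist) that ends the stream when the LLM call fails so aiter() does not wait forever:

    # Inside MyCustomAsyncIteratorCallbackHandler (hypothetical addition, not in the
    # original handler): push the sentinel on errors so aiter() stops instead of hanging.
    async def on_llm_error(
        self,
        error: BaseException,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        # A None sentinel tells aiter() to stop; an error payload could be enqueued instead.
        self.queue.put_nowait(None)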
@ninely I'm getting this error when trying to run with LLMChain and LlamaCpp. Can anyone help?
/opt/homebrew/anaconda3/lib/python3.10/site-packages/langchain/llms/llamacpp.py:352: RuntimeWarning: coroutine 'AsyncCallbackManagerForLLMRun.on_llm_new_token' was never awaited
  run_manager.on_llm_new_token(
RuntimeWarning: Enable tracemalloc to get the object allocation traceback