Skip to content

Instantly share code, notes, and snippets.

@ninely
Last active October 27, 2024 09:14
Show Gist options
  • Save ninely/88485b2e265d852d3feb8bd115065b1a to your computer and use it in GitHub Desktop.
Save ninely/88485b2e265d852d3feb8bd115065b1a to your computer and use it in GitHub Desktop.
LangChain with FastAPI streaming example
"""This is an example of how to use async langchain with fastapi and return a streaming response.
The latest version of Langchain has improved its compatibility with asynchronous FastAPI,
making it easier to implement streaming functionality in your applications.
"""
import asyncio
import os
from typing import AsyncIterable, Awaitable
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
from pydantic import BaseModel
# Environment variables can come from two places:
# 1. a local .env file, loaded here via python-dotenv;
load_dotenv()
# 2. or an explicit fallback set in code. Replace the empty placeholder with a
#    real key if you are not using a .env file.
os.environ.setdefault("OPENAI_API_KEY", "")

app = FastAPI()
async def send_message(message: str) -> AsyncIterable[str]:
    """Stream a chat completion for *message* as Server-Sent Events.

    The model call runs in a background task. Tokens arrive through an
    ``AsyncIteratorCallbackHandler`` and are yielded here as SSE ``data:`` frames.

    Args:
        message: The user's prompt, sent as a single ``HumanMessage``.

    Yields:
        One complete SSE event (terminated by a blank line) per streamed token.
    """
    import logging

    callback = AsyncIteratorCallbackHandler()
    model = ChatOpenAI(
        streaming=True,
        verbose=True,
        callbacks=[callback],
    )

    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        """Await *fn*, then set *event* whether it finished or raised."""
        try:
            await fn
        except Exception:
            # Keep the stream alive for the client, but record the full
            # traceback in the server log instead of a bare print.
            logging.getLogger(__name__).exception("Streaming generation failed")
        finally:
            # Signal the aiter to stop; otherwise it would wait forever.
            event.set()

    def to_sse(token: str) -> str:
        # SSE requires a "data:" prefix on every line of a multi-line payload.
        # A bare newline inside the token would end the field, and the text
        # after it would be parsed as an unknown field and dropped.
        lines = token.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        return "".join(f"data: {line}\n" for line in lines) + "\n"

    # Begin a task that runs in the background.
    task = asyncio.create_task(wrap_done(
        model.agenerate(messages=[[HumanMessage(content=message)]]),
        callback.done),
    )

    try:
        async for token in callback.aiter():
            # Use server-sent-events to stream the response
            yield to_sse(token)
        await task
    finally:
        # If the client disconnects, the generator is closed early. Cancel
        # the model call so it does not keep running in the background.
        if not task.done():
            task.cancel()
class StreamRequest(BaseModel):
    """JSON body accepted by POST /stream."""

    # Prompt text forwarded to the chat model.
    message: str
@app.post("/stream")
def stream(body: StreamRequest):
    """Answer ``body.message`` as a text/event-stream response."""
    events = send_message(body.message)
    return StreamingResponse(events, media_type="text/event-stream")
if __name__ == "__main__":
    # Serve the app on every interface, port 8000.
    uvicorn.run(app=app, host="0.0.0.0", port=8000)
"""This is an example of how to use async langchain with fastapi and return a streaming response."""
import os
from typing import Any, Optional, Awaitable, Callable, Union
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.manager import AsyncCallbackManager
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
from pydantic import BaseModel
from starlette.types import Send
# Environment variables can come from two places:
# 1. a local .env file, loaded here via python-dotenv;
load_dotenv()
# 2. or an explicit fallback set in code (replace the empty placeholder).
os.environ.setdefault("OPENAI_API_KEY", "")

app = FastAPI()

# An async callable that writes one chunk (text or raw bytes) to the client.
Sender = Callable[[Union[str, bytes]], Awaitable[None]]
class AsyncStreamCallbackHandler(AsyncCallbackHandler):
    """Async callback handler that pushes each new LLM token to the client.

    Every token is written immediately through ``send`` as one Server-Sent
    Events frame.
    """

    def __init__(self, send: Sender):
        super().__init__()
        # Coroutine function that writes one chunk to the HTTP response.
        self.send = send

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Send *token* to the client as a single SSE event."""
        # SSE requires a "data:" prefix on every line of a multi-line payload.
        # An embedded newline would otherwise end the field, and the rest of
        # the token would be discarded by the client.
        lines = token.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        await self.send("".join(f"data: {line}\n" for line in lines) + "\n")
class ChatOpenAIStreamingResponse(StreamingResponse):
    """StreamingResponse whose body is pushed by a callback instead of pulled.

    Rather than iterating over ``content``, the response hands a
    ``send_chunk`` coroutine to ``generate``. ``generate`` then writes chunks
    as they become available, for example from an LLM callback handler.
    """

    def __init__(
        self,
        generate: Callable[[Sender], Awaitable[None]],
        status_code: int = 200,
        media_type: Optional[str] = None,
    ) -> None:
        # The parent class needs some content iterable. An empty one is enough
        # because the stream_response() override below never reads it.
        super().__init__(content=iter(()), status_code=status_code, media_type=media_type)
        self.generate = generate

    async def stream_response(self, send: Send) -> None:
        """Send the headers, let ``generate`` push the body, then end the response."""
        start_message = {
            "type": "http.response.start",
            "status": self.status_code,
            "headers": self.raw_headers,
        }
        await send(start_message)

        async def send_chunk(chunk: Union[str, bytes]):
            body = chunk if isinstance(chunk, bytes) else chunk.encode(self.charset)
            await send({"type": "http.response.body", "body": body, "more_body": True})

        # All body chunks are written by generate() through send_chunk.
        await self.generate(send_chunk)

        # A final empty body with more_body=False closes the response.
        await send({"type": "http.response.body", "body": b"", "more_body": False})
def send_message(message: str) -> Callable[[Sender], Awaitable[None]]:
    """Build a body generator that streams the model's reply to *message*.

    Args:
        message: The user's prompt, sent as a single ``HumanMessage``.

    Returns:
        A coroutine function for ``ChatOpenAIStreamingResponse``. It receives
        a ``send`` callable and runs the model, forwarding every token through
        an ``AsyncStreamCallbackHandler``.
    """

    async def generate(send: Sender):
        model = ChatOpenAI(
            streaming=True,
            verbose=True,
            # Handlers go in ``callbacks``, as in the iterator-based example.
            # ``callback_manager`` is the deprecated spelling.
            callbacks=[AsyncStreamCallbackHandler(send)],
        )
        await model.agenerate(messages=[[HumanMessage(content=message)]])

    return generate
class StreamRequest(BaseModel):
    """JSON body accepted by POST /stream."""

    # Prompt text forwarded to the chat model.
    message: str
@app.post("/stream")
def stream(body: StreamRequest):
    """Answer ``body.message`` through the push-style streaming response."""
    generate = send_message(body.message)
    return ChatOpenAIStreamingResponse(generate, media_type="text/event-stream")
if __name__ == "__main__":
    # Serve the app on every interface, port 8000.
    uvicorn.run(app=app, host="0.0.0.0", port=8000)
aiohttp==3.8.4 ; python_full_version >= "3.8.1" and python_version < "3.12"
aiosignal==1.3.1 ; python_full_version >= "3.8.1" and python_version < "3.12"
anyio==3.7.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
async-timeout==4.0.2 ; python_full_version >= "3.8.1" and python_version < "3.12"
attrs==23.1.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
certifi==2023.5.7 ; python_full_version >= "3.8.1" and python_version < "3.12"
charset-normalizer==3.1.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
click==8.1.3 ; python_full_version >= "3.8.1" and python_version < "3.12"
colorama==0.4.6 ; python_full_version >= "3.8.1" and python_version < "3.12" and platform_system == "Windows"
dataclasses-json==0.5.7 ; python_full_version >= "3.8.1" and python_version < "3.12"
exceptiongroup==1.1.1 ; python_full_version >= "3.8.1" and python_version < "3.11"
fastapi==0.95.2 ; python_full_version >= "3.8.1" and python_version < "3.12"
frozenlist==1.3.3 ; python_full_version >= "3.8.1" and python_version < "3.12"
greenlet==2.0.2 ; python_full_version >= "3.8.1" and python_version < "3.12" and (platform_machine == "win32" or platform_machine == "WIN32" or platform_machine == "AMD64" or platform_machine == "amd64" or platform_machine == "x86_64" or platform_machine == "ppc64le" or platform_machine == "aarch64")
h11==0.14.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
idna==3.4 ; python_full_version >= "3.8.1" and python_version < "3.12"
langchain==0.0.181 ; python_full_version >= "3.8.1" and python_version < "3.12"
marshmallow-enum==1.5.1 ; python_full_version >= "3.8.1" and python_version < "3.12"
marshmallow==3.19.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
multidict==6.0.4 ; python_full_version >= "3.8.1" and python_version < "3.12"
mypy-extensions==1.0.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
numexpr==2.8.4 ; python_full_version >= "3.8.1" and python_version < "3.12"
numpy==1.24.3 ; python_full_version >= "3.8.1" and python_version < "3.12"
openai==0.27.7 ; python_full_version >= "3.8.1" and python_version < "3.12"
openapi-schema-pydantic==1.2.4 ; python_full_version >= "3.8.1" and python_version < "3.12"
packaging==23.1 ; python_full_version >= "3.8.1" and python_version < "3.12"
pydantic==1.10.8 ; python_full_version >= "3.8.1" and python_version < "3.12"
python-dotenv==1.0.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
pyyaml==6.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
requests==2.31.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
sniffio==1.3.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
sqlalchemy==2.0.15 ; python_full_version >= "3.8.1" and python_version < "3.12"
starlette==0.27.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
tenacity==8.2.2 ; python_full_version >= "3.8.1" and python_version < "3.12"
tqdm==4.65.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
typing-extensions==4.6.2 ; python_full_version >= "3.8.1" and python_version < "3.12"
typing-inspect==0.9.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
urllib3==2.0.2 ; python_full_version >= "3.8.1" and python_version < "3.12"
uvicorn==0.22.0 ; python_full_version >= "3.8.1" and python_version < "3.12"
yarl==1.9.2 ; python_full_version >= "3.8.1" and python_version < "3.12"
#!/usr/bin/env sh
# Smoke-test the /stream endpoint of the example server.
# -N (--no-buffer) makes curl print each SSE event as soon as it arrives,
# so the response can actually be seen streaming in.
curl -N "http://127.0.0.1:8000/stream" -X POST -d '{"message": "hello!"}' -H 'Content-Type: application/json'
@ninely
Copy link
Author

ninely commented Aug 12, 2023

Here's an example of a Flask-SocketIO server that sends a stream of messages to the client. @faridelya Maybe you can give it a try.

async def async_generator():
    """Yield SSE frames for the tokens produced by ``llm_chain.arun``.

    ``llm_chain`` is defined elsewhere; ``...`` stands for the chain's inputs.
    """
    # 1. Use the iterator callback
    callback = AsyncIteratorCallbackHandler()

    async def run_chain():
        try:
            # The handler must be attached to the run, or aiter() never
            # receives a single token.
            await llm_chain.arun(..., callbacks=[callback])
        finally:
            # Stop aiter() even if the chain raises; otherwise it blocks forever.
            callback.done.set()

    # 2. Begin a task that runs in the background.
    task = asyncio.create_task(run_chain())
    # 3. Read data
    async for token in callback.aiter():
        yield f"data: {token}\n\n"
    await task

@socketio.on('start')
def handle_start():
    """Stream async_generator() output to the client over Socket.IO.

    The handler is synchronous, so the async generator is driven by a
    dedicated event loop running in a background thread.
    """
    def run_loop(target_loop):
        asyncio.set_event_loop(target_loop)
        try:
            target_loop.run_until_complete(async_emit())
        finally:
            # Each 'start' event creates a new loop. Close it when the stream
            # ends so loops don't accumulate over the server's lifetime.
            target_loop.close()

    async def async_emit():
        async for data in async_generator():
            socketio.emit('response', {'data': data})

    loop = asyncio.new_event_loop()
    t = threading.Thread(target=run_loop, args=(loop,))
    t.start()

@faridelya
Copy link

faridelya commented Aug 12, 2023

Thanks, everyone. I figured out how to use this, but it returns the final output instead of streaming it.

Usage

 from langchain.callbacks.manager import CallbackManager
 callback_manager = CallbackManager([AsyncIteratorCallbackHandler()])
# NOTE(review): AsyncIteratorCallbackHandler is an async handler, but it is
# driven here by the synchronous llm_chain.run() below. Its coroutines are
# never awaited (compare the "coroutine ... was never awaited" warnings later
# in this thread), so nothing is streamed. To stream, call arun/acall and read
# tokens from the handler's aiter().


# The callback_manager is passed to the model constructor
llm = LlamaCpp(
    model_path=model_path,
    max_tokens=2024,
    n_gpu_layers=n_gpu_layers,
    n_batch=n_batch,
    callback_manager=callback_manager,
    verbose=False,
)

# run() is synchronous and returns only the final text.
response = llm_chain.run(question)
print(response)   # or return it; it's up to you

I'm posting this because I had trouble with the usage, so someone may find it useful.
Thanks

@faridelya
Copy link

faridelya commented Aug 12, 2023

Here's an example of a Flask-SocketIO server that sends a stream of messages to the client. @faridelya Maybe u can give it a try.

async def async_generator():
    """Yield SSE frames for tokens from llm_chain.arun (quoted example)."""
    # 1. Use the iterator callback
    callback = AsyncIteratorCallbackHandler()

    # NOTE(review): `callback` is never passed to llm_chain.arun(...), and
    # nothing here sets callback.done, so aiter() below receives no tokens
    # and never finishes.
    # 2. Begin a task that runs in the background.
    task = asyncio.create_task(
        llm_chain.arun(...),
    )
    # 3. Read data
    async for token in callback.aiter():
        yield f"data: {token}\n\n"
    await task

@socketio.on('start')
def handle_start():
    """Drive async_generator() on a fresh event loop in a background thread."""
    def run_loop(target_loop):
        asyncio.set_event_loop(target_loop)
        # NOTE(review): the loop is never closed after the stream ends.
        target_loop.run_until_complete(async_emit())

    async def async_emit():
        async for data in async_generator():
            socketio.emit('response', {'data': data})

    loop = asyncio.new_event_loop()
    t = threading.Thread(target=run_loop, args=(loop,))
    t.start()

Sorry @ninely, I did not figure out how to use this. I know you may have a busy schedule, but if possible, please check this example, which shows streaming; however, there the stream is only printed to stdout.

I want streaming the way the ChatGPT API supports it: iterate over the response with a for loop, then return or emit each chunk in real time.
Anyway Thanks for your kind response

@ninely
Copy link
Author

ninely commented Aug 13, 2023

@faridelya I have just created a PR to enable LlamaCpp to support async stream response. If it gets approved, the following example can be used to implement real-time stream response, not just stdout.

import asyncio
import os
from typing import AsyncIterable, Awaitable, Callable, Union, Any

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.callbacks.base import AsyncCallbackHandler
from pydantic import BaseModel

from langchain.llms import LlamaCpp
from langchain import PromptTemplate, LLMChain
from langchain.callbacks.manager import CallbackManager

# Load env variables from a .env file (send_message reads MODEL_PATH from the environment).
load_dotenv()

app = FastAPI()


template = """Question: {question}

Answer: Let's work this out in a step by step way to be sure we have the right answer."""

prompt = PromptTemplate(input_variables=["question"], template=template)


# An async callable that writes one chunk (text or raw bytes) to the client.
Sender = Callable[[Union[str, bytes]], Awaitable[None]]


class AsyncStreamCallbackHandler(AsyncCallbackHandler):
    """Async callback handler that pushes each new LLM token to the client.

    It is defined for the push-style variant; send_message below uses
    AsyncIteratorCallbackHandler instead.
    """

    def __init__(self, send: Sender):
        super().__init__()
        # Coroutine function that writes one chunk to the HTTP response.
        self.send = send

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Send *token* to the client as a single SSE event."""
        # SSE requires a "data:" prefix on every line of a multi-line payload.
        # An embedded newline would otherwise end the field, and the rest of
        # the token would be discarded by the client.
        lines = token.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        await self.send("".join(f"data: {line}\n" for line in lines) + "\n")


async def send_message(message: str) -> AsyncIterable[str]:
    """Stream a LlamaCpp answer to *message* as Server-Sent Events.

    Args:
        message: The question to ask. When it is empty, the demo question
            below is used instead, so the {"message": ""} curl test still works.

    Yields:
        One SSE event per generated token.
    """
    import logging

    # Callbacks support token-wise streaming
    callback = AsyncIteratorCallbackHandler()

    # Make sure the model path is correct for your system!
    llm = LlamaCpp(
        model_path=os.environ["MODEL_PATH"],    # replace with your model path
        # ``callbacks`` replaces the deprecated ``callback_manager`` argument.
        callbacks=[callback],
        verbose=True,
        streaming=True,
    )

    llm_chain = LLMChain(prompt=prompt, llm=llm)

    # Answer the caller's message; the hard-coded question is only a fallback.
    question = message or "What NFL team won the Super Bowl in the year Justin Bieber was born?"

    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        """Await *fn*, then set *event* whether it finished or raised."""
        try:
            await fn
        except Exception:
            logging.getLogger(__name__).exception("Streaming generation failed")
        finally:
            # Signal the aiter to stop.
            event.set()

    def to_sse(token: str) -> str:
        # Every line of a multi-line SSE payload needs its own "data:" prefix.
        lines = token.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        return "".join(f"data: {line}\n" for line in lines) + "\n"

    # Begin a task that runs in the background.
    task = asyncio.create_task(wrap_done(llm_chain.arun(question), callback.done))

    try:
        async for token in callback.aiter():
            # Use server-sent-events to stream the response
            yield to_sse(token)
        await task
    finally:
        # If the client disconnected, stop generating in the background.
        if not task.done():
            task.cancel()


class StreamRequest(BaseModel):
    """JSON body accepted by POST /stream."""

    # Prompt text forwarded to the chain.
    message: str


@app.post("/stream")
def stream(body: StreamRequest):
    """Answer ``body.message`` as a text/event-stream response."""
    token_stream = send_message(body.message)
    return StreamingResponse(token_stream, media_type="text/event-stream")


if __name__ == "__main__":
    # Serve the app on every interface, port 8000.
    uvicorn.run(app=app, host="0.0.0.0", port=8000)

test
curl -N http://127.0.0.1:8000/stream -X POST -d '{"message": ""}' -H 'Content-Type: application/json'

@Ludobico
Copy link

How do I fix this error?

RuntimeWarning: coroutine 'AsyncCallbackManagerForLLMRun.on_llm_new_token' was never awaited
  run_manager.on_llm_new_token(

Can anyone help? Thanks.

@iiitmahesh
Copy link

iiitmahesh commented Aug 28, 2023

@Ludobico I have the same issue. Any update? @ninely @faridelya Can you please help?

RuntimeWarning: coroutine 'AsyncCallbackHandler.on_chain_end' was never awaited
  getattr(handler, event_name)(*args, **kwargs)
RuntimeWarning: Enable tracemalloc to get the object allocation traceback

Caught exception: object dict can't be used in 'await' expression

My Code:

async def send_message(question):
    """Stream the agent's final answer to *question* token by token."""
    callback = AsyncFinalIteratorCallbackHandler()
    memory_key = "chat_history"
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=system_message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)]
    )
    retriever = db.as_retriever(
        search_type="mmr",
        search_kwargs={'k': 4, 'lambda_mult': 0.25},
        return_source_documents=True
    )
    tool = create_retriever_tool(
        retriever,
        "VectorDB_Query_Store",
        "Searches and returns documents database."
    )
    tools = [tool]
    llm = ChatOpenAI(model="gpt-4-0613", temperature=0, openai_api_key=openai_api_key_gpt4,
                     streaming=True, callbacks=[callback])
    memory = AgentTokenBufferMemory(memory_key=memory_key, llm=llm)
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
        return_intermediate_steps=True,
        callbacks=[callback]
    )

    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        """Wrap an awaitable with an event to signal when it's done or an exception is raised."""
        try:
            await fn
        except Exception as e:
            # TODO: handle exception
            print(f"Caught exception: {e}")
        finally:
            # Signal the aiter to stop.
            event.set()

    # Use the async entry point (acall), so the async handler is actually
    # awaited. Calling agent_executor(...) synchronously returns a plain dict,
    # which cannot be awaited, and leaves the handler coroutines un-awaited.
    # wrap_done guarantees callback.done is set even if the run fails.
    task = asyncio.create_task(
        wrap_done(agent_executor.acall({"input": question}), callback.done)
    )

    async for token in callback.aiter():
        # Use server-sent-events to stream the response
        yield token
    await task



@app.post("/stream")
def stream(body: StreamRequest):
    return StreamingResponse(send_message(body.message), media_type="text/event-stream")

@ninely
Copy link
Author

ninely commented Aug 28, 2023

@Ludobico @iiitmahesh AsyncCallbackHandler needs to run in asynchronous methods, such as arun, acall.

@iiitmahesh
Copy link

@ninely Do you have an example of an agent or AgentExecutor with a streaming API? Any help would be appreciated. Thanks

@ninely
Copy link
Author

ninely commented Aug 30, 2023

@iiitmahesh Just like this:

# create_task() takes the coroutine as its only positional argument, so the
# done event goes to wrap_done (from the gist), not to create_task itself.
# Keep the task so it can be awaited after the token loop.
task = asyncio.create_task(wrap_done(agent_executor.arun(your_params), callback.done))

@acliyanarachchi
Copy link

acliyanarachchi commented Sep 12, 2023

@ninely

I did the implementation as you have mentioned. It worked perfectly with one issue.

If initialize_agent() is called once during setup and the agent is reused to send messages, streaming works only for the first message and not afterwards. What could be the problem?

self.callback_handler is an attribute set in initialize() and used in advanced_chat(). Since AsyncFinalIteratorCallbackHandler() is created only once in the application, streaming works only once.
If I move the advanced_chat() code into initialize(), streaming works, but I can't do that: many tools are initialized in initialize(), and I don't want them re-initialized for every message.

-- Sample code ----

`def initialize(self):
self.callback_handler = AsyncFinalIteratorCallbackHandler()

      llm = ChatOpenAI(streaming=True, temperature=0, callbacks=[self.callback_handler],
                       model_name="gpt-4")

      memory = AgentMemory().get("ConversationSummaryBufferMemory", llm)
      agent = initialize_agent( .... )

 async def advanced_chat(self, agent, text) -> AsyncGenerator[str, None]:  
       run = asyncio.create_task(self.wrap_done(
            agent.arun(input=text),
            self.callback_handler.done))

       async for token in self.callback_handler.aiter():
            yield token

        await run`

@ninely
Copy link
Author

ninely commented Sep 13, 2023

@acliyanarachchi Before reusing the callback, you need to call callback.done.clear(). By the way, if the same callback is shared by two LLM runs in parallel, this won't work as expected.

@gingergenius
Copy link

What method should be used for an LLM or a ChatLLM, not a chain or agent?

@acliyanarachchi
Copy link

@acliyanarachchi Before reusing, you need to execute callback.done.clear(). By the way, If used by two LLM runs in parallel this won't work as expected.

Thanks @ninely. This works perfectly except the error @iiitmahesh mentioned above - RuntimeWarning: coroutine 'AsyncCallbackHandler.on_chain_end'
It doesn't break the flow, but it delays it.

Btw, Do you have any examples of how to get VertexAI integrated into initialize_agent() with streaming?

@ninely
Copy link
Author

ninely commented Sep 22, 2023

@acliyanarachchi Sorry, I don't know much about this and haven't researched it yet.

@bindlam
Copy link

bindlam commented Sep 25, 2023

I am new to FastAPI; can someone please help? The code worked fine when I was using Flask and deploying it. We are moving to production and want to use FastAPI now. The generate_answer_from_LLM function fails at llm.predict(usrPrompt) with "Exception in getTChatResponse: Resource not found", so I added streaming and a callback_manager, but now I get the error below.

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import time
import os

import langchain
from langchain.cache import InMemoryCache
from langchain.callbacks import get_openai_callback
from flask_cors import CORS, cross_origin

from langchain.llms import AzureOpenAI

import json
from langchain.cache import SQLiteCache
from typing import List
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

ENV_VARS = {
    "OPENAI_API_TYPE": "azure",
    "OPENAI_API_VERSION": "2022-12-01",
    # NOTE(review): for Azure this presumably must be the full endpoint URL,
    # e.g. https://<resource-name>.openai.azure.com/ - confirm. An incomplete
    # base URL is a likely cause of "Resource not found".
    "OPENAI_API_BASE": "openai.azure.com/",
    "OPENAI_API_KEY": "***",  # redacted
}
os.environ.update(ENV_VARS)

app = FastAPI()


class Message(BaseModel):
    content: str


class InputJson(BaseModel):
    messages: List[Message]
    temperature: int
    Document_query: bool


from langchain.callbacks.manager import CallbackManager

# The model is only called synchronously (llm.predict below), so it gets no
# AsyncIteratorCallbackHandler. An async handler's coroutines would never be
# awaited ("coroutine ... was never awaited"). For real streaming, use
# agenerate with the handler and a StreamingResponse, as in the gist above.
llm = AzureOpenAI(
    deployment_name="text-davinci-003",
    temperature=0,
    streaming=True,
    max_tokens=1000)

langchain.llm_cache = SQLiteCache(database_path="./langchain.db")


def generate_answer_from_LLM(usrPrompt):
    """Run the LLM on *usrPrompt* and return its text, logging the latency."""
    print("entered dragon")
    print(usrPrompt)
    with get_openai_callback() as cb:
        start_time = time.time()
        response = llm.predict(usrPrompt)
        end_time = time.time()
        print(f"Time taken: {end_time-start_time:0.2f} sec")
    return response


@app.post("/getTChatResponse")
def getTChatResponse(data: InputJson):
    """Answer the first message in *data*.

    This is a plain ``def`` so FastAPI runs the blocking llm.predict call in a
    worker thread instead of stalling the event loop.
    """
    try:
        user_prompt = data.messages[0].content
        print("user_prompt--->", user_prompt)
        response = generate_answer_from_LLM(user_prompt)
        print(response)
        return response
    except Exception as e:
        print("Exception in getTChatResponse:", e)
        # FastAPI does not treat a (body, status) tuple like Flask does; it
        # would serialize the tuple as a JSON list with status 200.
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.get("/")
async def read_root():
    return {"message": f"Welcome "}


if __name__ == '__main__':
    import uvicorn
    # reload=True requires the app as an import string ("module:app"). With
    # an app object, uvicorn refuses to start, so reload is not enabled here.
    uvicorn.run(app, host="0.0.0.0", port=3010)

Error I am getting:

C:\Users\Anaconda3\envs\tech_day_1\lib\site-packages\langchain\callbacks\manager.py:115: RuntimeWarning: coroutine 'AsyncIteratorCallbackHandler.on_llm_start' was never awaited
getattr(handler, event_name)(*args, **kwargs)
RuntimeWarning: Enable tracemalloc to get the object allocation traceback
C:\Users\Anaconda3\envs\tech_day_1\lib\site-packages\langchain\callbacks\manager.py:115: RuntimeWarning: coroutine 'AsyncIteratorCallbackHandler.on_llm_error' was never awaited
getattr(handler, event_name)(*args, **kwargs)
RuntimeWarning: Enable tracemalloc to get the object allocation traceback

@gingergenius
Copy link

Async methods like agenerate are intended for generating multiple responses simultaneously independently of each other, right? Can we do streaming without those?

@ninely
Copy link
Author

ninely commented Sep 26, 2023

@gingergenius You can use StreamingStdOutCallbackHandler without using async.

@RoderickVM
Copy link

@ninely I just wanted to say THANK YOU! You brought an end to lots of hours of frustration with your async def wrap_done(fn: Awaitable, event: asyncio.Event) solution, after trying this from @Coding-Crashkurse. and this from @jamescalam

Would you mind explaining why Awaitable and asyncio.Event are necessary?

@ninely
Copy link
Author

ninely commented Oct 7, 2023

@RoderickVM The purpose of wrap_done is to interrupt the main process that is blocked on aiter when an exception occurs during the execution of fn. For specifics, you can look at the implementation of aiter in AsyncIteratorCallbackHandler.

@avikhandakar-dev
Copy link

avikhandakar-dev commented Oct 12, 2023

@ninely Can you please help me? I'm using a SequentialChain to join multiple prompts. Everything works as expected, but when streaming, only the first chain's output is streamed.
Here is my code:
`import asyncio
from langchain.chat_models import ChatOpenAI
from dotenv import load_dotenv
import os
from langchain.chains import LLMChain, SequentialChain
from langchain.prompts import PromptTemplate
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler

load_dotenv()
openai_api_key = os.environ.get('OPENAI_API_KEY')
llm = ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo-16k",
                 openai_api_key=openai_api_key, streaming=True, callbacks=[])

capital_template = """
Where is the capital of {country}?
"""
capital_prompt_template = PromptTemplate(
    input_variables=["country"], template=capital_template)
capital_chain = LLMChain(
    llm=llm, prompt=capital_prompt_template, output_key="capital")

about_template = """Tell me about {capital} in 5 words
"""
about_prompt_template = PromptTemplate(
    input_variables=["capital"], template=about_template)
about_chain = LLMChain(
    llm=llm, prompt=about_prompt_template, output_key="about")

test_chain = SequentialChain(
    chains=[capital_chain, about_chain],
    input_variables=["country"],
    output_variables=["capital", "about"],
    verbose=True)


class TestAgent:
    """Runs test_chain for one country and streams the tokens of every step."""

    def __init__(self, country):
        # The pasted "init" lost its double underscores. Without a real
        # __init__, TestAgent("England") raises TypeError.
        self.country = country

    async def run(self, stream_it):
        """Run the whole chain, then mark the stream as finished."""
        try:
            res = await test_chain.acall({"country": self.country}, callbacks=[stream_it])
            print(res)
        finally:
            # The chain is over (or failed): let aiter() exit for good.
            stream_it.done.set()

    async def create_gen(self, stream_it: AsyncIteratorCallbackHandler):
        """Yield every token produced by every LLM call in the chain."""
        task = asyncio.create_task(self.run(stream_it))
        # AsyncIteratorCallbackHandler appears to set ``done`` when each LLM
        # call ends, which is why only the first chain streamed. Re-enter
        # aiter() until run() itself has completed.
        while True:
            async for token in stream_it.aiter():
                print(token)
                yield token
            if task.done():
                break
            stream_it.done.clear()
        await task
`

Here is FastApi code:
`
@router.get("/test")
async def test():
    """Stream test_chain's output for England as an event stream."""
    handler = AsyncIteratorCallbackHandler()
    token_gen = TestAgent("England").create_gen(handler)
    return StreamingResponse(token_gen, media_type="text/event-stream")

`

How can I stream the full output? It only streams: The capital of England is London.

@FrancescoSaverioZuppichini
Copy link

FrancescoSaverioZuppichini commented Oct 17, 2023

crazy how hard it is, really

@pietz
Copy link

pietz commented Oct 24, 2023

This has helped me a lot. Thank you!

@hteeyeoh
Copy link

I'm trying to add streaming with AsyncIteratorCallbackHandler(), as shown above, to my LLMChain that does summarization. I split the document into several chunks and run:
async def _summarize_docs():
    """Summarize every chunk in order, streaming tokens to ``callback``."""
    # Run the chunks one after another so their tokens don't interleave.
    # Pass the handler via ``callbacks=``: acall()'s second positional
    # parameter is ``return_only_outputs``, not a callback.
    for doc in docs:
        await chain.acall(doc.page_content, callbacks=[callback])

# wrap_done needs an awaitable. The list comprehension awaited every chunk
# before the task was created, then handed wrap_done a plain list.
task = asyncio.create_task(wrap_done(_summarize_docs(), callback.done))
# NOTE(review): AsyncIteratorCallbackHandler may end aiter() after the first
# LLM call; if only the first chunk streams, keep reading until `task` is done.
It just doesn't work. Any suggestions?

@robertoronderosjr
Copy link

Trying to implement streaming using AsyncIteratorCallbackHandler() shown above into my LLMChain where doing summarization. Split the document into several chunk and run using: task = asyncio.create_task(wrap_done([await chain.acall(doc.page_content, callback) for doc in docs], callback.done),) It just not working. Any suggestion for this?

same -- The "new way" above does not work in newer versions of langchain, in my case I get this error:
NotImplementedError: AsyncIteratorCallbackHandler does not implement on_chat_model_start

@neokd
Copy link

neokd commented Dec 4, 2023

@ninely I'm getting this error when running LLMChain with LlamaCpp. Can anyone help?

/opt/homebrew/anaconda3/lib/python3.10/site-packages/langchain/llms/llamacpp.py:352: RuntimeWarning: coroutine 'AsyncCallbackManagerForLLMRun.on_llm_new_token' was never awaited
  run_manager.on_llm_new_token(
RuntimeWarning: Enable tracemalloc to get the object allocation traceback

@aicodex
Copy link

aicodex commented Jan 8, 2024

My way: use LlamaCpp, llm.stream() and yield.

import asyncio
import os
from typing import AsyncIterable, Awaitable

import uvicorn
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from langchain.llms import LlamaCpp
from langchain.cache import InMemoryCache
from langchain.globals import set_llm_cache

# Serve repeated prompts from an in-memory LLM cache.
set_llm_cache(InMemoryCache())

app = FastAPI()

# Local GGUF model file; adjust for your machine.
model_path = "/project/llama_data/Llama-2-7b-chat-hf/ggml-model-q4_0.gguf"

# A single model instance, loaded at import time and shared by all requests.
llm = LlamaCpp(
    model_path=model_path,
    n_ctx=512,
    n_batch=512,
    n_gpu_layers=40,
    temperature=0.1,
    verbose=True,
)

def request_qa_stream(question: str):
    """Yield the model's answer to *question* chunk by chunk.

    ``llm.stream`` is a blocking, synchronous iterator, so this is a plain
    generator. Starlette's StreamingResponse iterates sync generators in a
    worker thread. An ``async def`` wrapper around the same loop would block
    the event loop for the whole generation.
    """
    for text in llm.stream(question):
        yield text

def request_qa(question: str):
    """Run the shared model on *question* and return the full completion."""
    return llm(question)

class QARequest(BaseModel):
    """Request body: the question to ask the model."""

    question: str

#curl "http://127.0.0.1:8000/qa/stream" -X POST -d '{"question":"who are you"}' -H 'Content-Type: application/json'
@app.post("/qa/stream")
def qa_stream(body: QARequest):
    """Stream the answer to ``body.question`` (see the curl example above)."""
    # Named qa_stream so the non-streaming ``qa`` handler defined below no
    # longer shadows this function at module level.
    return StreamingResponse(request_qa_stream(body.question), media_type="text/event-stream")
    
#curl "http://127.0.0.1:8000/qa" -X POST -d '{"question":"who are you"}' -H 'Content-Type: application/json'
@app.post("/qa")
def qa(body: QARequest):
    """Return the complete (non-streamed) answer to ``body.question``."""
    answer = request_qa(body.question)
    return answer

if __name__ == "__main__":
    # Serve the app on every interface, port 8000.
    uvicorn.run(app=app, host="0.0.0.0", port=8000)

or like this

## use agenerate and callback
async def request_qa_stream(question: str) -> AsyncIterable[str]:
    """Stream the shared ``llm``'s answer to *question* token by token.

    A fresh AsyncIteratorCallbackHandler is passed to this call only, so
    concurrent requests never share or overwrite each other's handler.
    """
    import logging

    callback = AsyncIteratorCallbackHandler()

    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        """Await *fn*, then set *event* whether it finished or raised."""
        try:
            await fn
        except Exception:
            logging.getLogger(__name__).exception("LLM generation failed")
        finally:
            # Signal the aiter to stop.
            event.set()

    # Begin a task that runs in the background. The handler is passed per call
    # rather than assigned to ``llm.callbacks``: ``llm`` is shared by all
    # requests, and CallbackManager expects a list of handlers, not one.
    task = asyncio.create_task(wrap_done(
        llm.agenerate([question], callbacks=[callback]),
        callback.done),
    )

    try:
        async for token in callback.aiter():
            yield token
        await task
    finally:
        # If the client disconnected, stop generating in the background.
        if not task.done():
            task.cancel()

The second way might handle concurrent requests, but I haven't tested that.

@YanSte
Copy link

YanSte commented Apr 19, 2024

Hi all !

I wanted to share with you a Custom Stream Response that I implemented in my FastAPI application recently.

I created this solution to manage streaming data.

You could use LangChain's stream/event APIs instead, but I'm doing special things with the handlers, which is why I need this.

Here examples:

Fast API

@router.get("/myExample")
async def mySpecialAPI(
    session_id: UUID,
    input="Hello",
) -> StreamResponse:
    """Stream the chain's output for this request through StreamResponse."""
    # Note: Don't write await; we need a coroutine
    # NOTE(review): `..` is a placeholder for the real arguments (not valid
    # Python). Presumably `callback` must also reach chain.ainvoke (e.g. via
    # its config callbacks) to receive tokens - confirm.
    invoke = chain.ainvoke(..)
    callback = MyCallback(..)
    return StreamResponse(invoke, callback)

Custom Stream Response

from __future__ import annotations
import asyncio
import typing
from typing import Any, AsyncIterable, Coroutine
from fastapi.responses import StreamingResponse as FastApiStreamingResponse
from starlette.background import BackgroundTask

class StreamResponse(FastApiStreamingResponse):
    """Event-stream response that runs *invoke* and streams callback tokens.

    *invoke* is an un-awaited coroutine (e.g. ``chain.ainvoke(...)``). It runs
    as a background task while tokens are read from *callback*'s ``aiter()``.
    """

    def __init__(
        self,
        invoke: Coroutine,
        callback: MyCustomAsyncIteratorCallbackHandler,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = "text/event-stream",
        background: BackgroundTask | None = None,
    ) -> None:
        super().__init__(
            content=StreamResponse.send_message(callback, invoke),
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )

    @staticmethod
    async def send_message(
        callback: AsyncIteratorCallbackHandler, invoke: Coroutine
    ) -> AsyncIterable[str]:
        """Run *invoke* as a task and yield every token *callback* receives."""
        # Keep a reference: the event loop holds only weak references to
        # tasks, so an unreferenced task can be garbage-collected mid-run.
        task = asyncio.create_task(invoke)
        try:
            # NOTE(review): aiter() only ends when the handler signals
            # completion; make sure it also does so on errors, or this waits.
            async for token in callback.aiter():
                yield token
            # Surface the chain's exception instead of leaving it unretrieved.
            await task
        finally:
            # Client disconnected (generator closed early): stop the chain.
            if not task.done():
                task.cancel()

My Custom Callbackhandler

from __future__ import annotations
import asyncio
from typing import Any, AsyncIterator, List

class MyCustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
    """Callback handler that exposes streamed LLM tokens as an async iterator.

    Tokens are pushed onto ``queue``; ``None`` is the end-of-stream sentinel.
    """

    # Note: items can be a BaseModel rather than str
    queue: asyncio.Queue[Optional[str]]

    # Pass your params as you want
    def __init__(self) -> None:
        self.queue = asyncio.Queue()

    async def on_llm_new_token(
        self,
        token: str,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        self.queue.put_nowait(token)

    async def on_llm_end(
        self,
        response: LLMResult,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        self.queue.put_nowait(None)

    async def on_llm_error(
        self,
        error: BaseException,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """End the stream when the LLM fails, so aiter() does not wait forever."""
        self.queue.put_nowait(None)

    async def aiter(self) -> AsyncIterator[str]:
        while True:
            token = await self.queue.get()
            # Match every get() with task_done(), so queue.join() would work.
            self.queue.task_done()

            if token is None:
                break
            if isinstance(token, str):
                yield token  # Note: or a BaseModel.model_dump_json() etc..

https://gist.github.com/YanSte/7be29bc93f21b010f64936fa334a185f

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment