# The goal of this file is to provide a FastAPI application for handling
# chat requests and generating AI-powered responses using conversation chains.
# The application uses the LangChain library, which includes a ChatOpenAI model
# for natural language processing.
# The `StreamingConversationChain` class is responsible for creating and storing
# conversation memories and generating responses. It utilizes the `ChatOpenAI` model
# and a callback handler to stream responses as they're generated.
# The application defines a `ChatRequest` model for handling chat requests,
# which includes the conversation ID and the user's message.
# The `/chat` endpoint is used to receive chat requests and generate responses.
# It utilizes the `StreamingConversationChain` instance to generate the responses and
# sends them back as a streaming response using the `StreamingResponse` class.
# Please note that the implementation relies on certain dependencies and imports,
# which are not included in the provided code snippet.
# Ensure that all necessary packages are installed and imported
# correctly before running the application.
#
# Install dependencies:
# pip install fastapi uvicorn[standard] python-dotenv langchain openai
#
# Example of usage:
# uvicorn main:app --reload
#
# Example of request:
#
# curl --no-buffer \
#     -X POST \
#     -H 'accept: text/event-stream' \
#     -H 'Content-Type: application/json' \
#     -d '{
#         "conversation_id": "cat-conversation",
#         "message": "what'\''s their size?"
#     }' \
#     http://localhost:8000/chat
#
# Cheers,
# @jvelezmagic
import asyncio
from functools import lru_cache
from typing import AsyncGenerator

from fastapi import Depends, FastAPI
from fastapi.responses import StreamingResponse
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from pydantic import BaseModel, BaseSettings


class Settings(BaseSettings):
    """
    Settings class for this application.
    Utilizes the BaseSettings from pydantic for environment variables.
    """

    openai_api_key: str

    class Config:
        env_file = ".env"


@lru_cache()
def get_settings():
    """Function to get and cache settings.
    The settings are cached to avoid repeated disk I/O.
    """
    return Settings()


class StreamingConversationChain:
    """
    Class for handling streaming conversation chains.
    It creates and stores memory for each conversation,
    and generates responses using the ChatOpenAI model from LangChain.
    """

    def __init__(self, openai_api_key: str, temperature: float = 0.0):
        self.memories = {}
        self.openai_api_key = openai_api_key
        self.temperature = temperature

    async def generate_response(
        self, conversation_id: str, message: str
    ) -> AsyncGenerator[str, None]:
        """
        Asynchronous function to generate a response for a conversation.
        It creates a new conversation chain for each message and uses a
        callback handler to stream responses as they're generated.
        :param conversation_id: The ID of the conversation.
        :param message: The message from the user.
        """
        callback_handler = AsyncIteratorCallbackHandler()
        llm = ChatOpenAI(
            callbacks=[callback_handler],
            streaming=True,
            temperature=self.temperature,
            openai_api_key=self.openai_api_key,
        )

        memory = self.memories.get(conversation_id)
        if memory is None:
            memory = ConversationBufferMemory(return_messages=True)
            self.memories[conversation_id] = memory

        chain = ConversationChain(
            memory=memory,
            prompt=CHAT_PROMPT_TEMPLATE,
            llm=llm,
        )

        run = asyncio.create_task(chain.arun(input=message))

        async for token in callback_handler.aiter():
            yield token

        await run


class ChatRequest(BaseModel):
    """Request model for chat requests.
    Includes the conversation ID and the message from the user.
    """

    conversation_id: str
    message: str


CHAT_PROMPT_TEMPLATE = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(
            "You're a AI that knows everything about cats."
        ),
        MessagesPlaceholder(variable_name="history"),
        HumanMessagePromptTemplate.from_template("{input}"),
    ]
)

app = FastAPI(dependencies=[Depends(get_settings)])

streaming_conversation_chain = StreamingConversationChain(
    openai_api_key=get_settings().openai_api_key
)


@app.post("/chat", response_class=StreamingResponse)
async def generate_response(data: ChatRequest) -> StreamingResponse:
    """Endpoint for chat requests.
    It uses the StreamingConversationChain instance to generate responses,
    and then sends these responses as a streaming response.
    :param data: The request data.
    """
    return StreamingResponse(
        streaming_conversation_chain.generate_response(
            data.conversation_id, data.message
        ),
        media_type="text/event-stream",
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app)
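Not part of the gist itself, but for programmatic use, here is a small client sketch that consumes the `/chat` stream with `httpx` (assumes the server above is running locally and `httpx` is installed):

```python
# Editor's sketch: async client for the /chat streaming endpoint shown above.
import asyncio

import httpx


async def main() -> None:
    payload = {"conversation_id": "cat-conversation", "message": "what's their size?"}
    async with httpx.AsyncClient(timeout=None) as client:
        # stream() keeps the connection open and lets us read chunks as they arrive
        async with client.stream(
            "POST", "http://localhost:8000/chat", json=payload
        ) as response:
            async for chunk in response.aiter_text():
                print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())
```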
re: the Replicate class ... looking at the source, it doesn't appear to support async as far as I can tell.
Kinda frustrating that you can't just drop in whatever LLM you want to use ... this is a feature that an abstraction library like LangChain should support, imo.
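Until the SDK grows async support, one workaround is to wrap the blocking call yourself. A minimal sketch, assuming a hypothetical `blocking_generate` function standing in for whatever synchronous client you have (this is not Replicate's actual API):

```python
import asyncio
from typing import AsyncGenerator


def blocking_generate(prompt: str) -> str:
    # placeholder for a synchronous SDK call that returns the full completion
    return f"echo: {prompt}"


async def pseudo_stream(prompt: str, chunk_size: int = 20) -> AsyncGenerator[str, None]:
    # run the blocking call in a worker thread so the event loop stays responsive
    text = await asyncio.to_thread(blocking_generate, prompt)
    # without SDK-level token streaming, chunk the finished text instead
    for i in range(0, len(text), chunk_size):
        yield text[i : i + chunk_size]
```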
Here's an example of using this with the ConversationalRetrievalChain. If y'all can think of some way to improve the code please reply here. Converting this to use EventSourceResponse might be a good first step :)

```python
class ChatRequest(BaseModel):
    """Request model for chat requests. Includes the conversation ID and the message from the user."""

    conversation_id: str
    message: str


class ChatResponse(BaseModel):
    """Chat response schema"""

    sender: str
    message: str
    type: str
    xtra: dict = None

    @validator("sender")
    def sender_must_be_bot_or_you(cls, v):
        if v not in ["bot", "you"]:
            raise ValueError("sender must be bot or you")
        return v

    @validator("type")
    def validate_message_type(cls, v):
        if v not in ["start", "stream", "end", "error", "info"]:
            raise ValueError("type must be start, stream or end")
        return v


class StreamingLLMCallbackHandler(AsyncIteratorCallbackHandler):
    """Callback handler for streaming LLM responses."""

    async def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        self.done.clear()
        self.queue.put_nowait(ChatResponse(sender="bot", message="", type="start"))

    async def on_llm_end(self, response, **kwargs) -> None:
        # we override this method since we want the ConversationalRetrievalChain to potentially return
        # other items (e.g., source_documents) after it is completed
        pass

    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs) -> None:
        self.queue.put_nowait(ChatResponse(sender="bot", message=str(error), type="error"))

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.queue.put_nowait(ChatResponse(sender="bot", message=token, type="stream"))


class ConvoChainCallbackHandler(AsyncCallbackHandler):
    """Use to add additional information (e.g., source_documents, etc...) once the chain finishes"""

    def __init__(self, callback_handler) -> None:
        super().__init__()
        self.callback_handler = callback_handler

    async def on_chain_end(self, outputs, *, run_id, parent_run_id, **kwargs) -> None:
        """Run after chain ends running."""
        source_docs = outputs.get("source_documents", None)
        source_docs_d = [{"page": doc.metadata["page"]} for doc in source_docs] if source_docs else None
        xtra = {"source_documents": source_docs_d}
        self.callback_handler.queue.put_nowait(ChatResponse(sender="bot", message="", xtra=xtra, type="info"))
        self.callback_handler.queue.put_nowait(ChatResponse(sender="bot", message="", type="end"))


class StreamingConversationChain:
    """Class for handling streaming conversation chains."""

    def __init__(self, openai_api_key: str, temperature: float = 0.0):
        self.memories = {}
        self.openai_api_key = openai_api_key
        self.temperature = temperature

    async def generate_response(self, conversation_id: str, message: str) -> AsyncGenerator[str, None]:
        """
        Asynchronous function to generate a response for a conversation.
        It creates a new conversation chain for each message and uses a
        callback handler to stream responses as they're generated.
        :param conversation_id: The ID of the conversation.
        :param message: The message from the user.
        """
        streaming_cb = StreamingLLMCallbackHandler()  # AsyncIteratorCallbackHandler()
        convo_cb_manager = AsyncCallbackManager([ConvoChainCallbackHandler(streaming_cb)])
        question_gen_llm = ChatOpenAI(
            model_name="gpt-3.5-turbo",
            max_retries=15,
            temperature=0.0,
            streaming=True,
            openai_api_key=self.openai_api_key,
        )
        streaming_llm = ChatOpenAI(
            max_retries=15,
            temperature=0,
            callbacks=[streaming_cb],
            streaming=True,
            openai_api_key=self.openai_api_key,
        )
        memory = self.memories.get(conversation_id)
        if memory is None:
            memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key="answer")
            self.memories[conversation_id] = memory
        question_gen_chain = LLMChain(llm=question_gen_llm, prompt=CONDENSE_QUESTION_PROMPT)  # , callback_manager=manager)
        final_qa_chain = load_qa_chain(
            streaming_llm,
            chain_type="stuff",
        )
        convo_chain = ConversationalRetrievalChain(
            retriever=retriever,
            question_generator=question_gen_chain,
            combine_docs_chain=final_qa_chain,
            memory=memory,
            return_source_documents=True,
            callback_manager=convo_cb_manager,
        )
        run = asyncio.create_task(convo_chain.acall({"question": message}))
        async for token in streaming_cb.aiter():
            # to return string
            # yield token
            # to return json
            if token.type in ["end", "error"]:
                streaming_cb.done.set()
            yield json.dumps(token.dict())
        await run


streaming_conversation_chain = StreamingConversationChain(
    openai_api_key="API_KEY", temperature=0.7
)


@app.post("/sse-chat-convo", response_class=StreamingResponse)
async def generate_response(data: ChatRequest) -> StreamingResponse:
    """Endpoint for chat requests"""
    return StreamingResponse(
        streaming_conversation_chain.generate_response(data.conversation_id, data.message),
        media_type="text/event-stream",
    )
```
Thank you so much, but how do I return a string? I keep getting errors when I try to do that.
Not sure if this helps ... but I've simplified my example to simply use a callback for the retriever.
Lmk if this works for llama and company ...
```python
# load document
loader = PyPDFLoader("example.pdf")
documents = loader.load()

# split the documents into chunks
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)

# select which embeddings we want to use
embeddings = OpenAIEmbeddings()

# create the vectorstore to use as the index
db = Chroma.from_documents(texts, embeddings)

# expose this index in a retriever interface
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})


class ChatRequest(BaseModel):
    """Request model for chat requests. Includes the conversation ID and the message from the user."""

    conversation_id: str
    message: str


class RetrieverCallbackHandler(AsyncIteratorCallbackHandler):
    def __init__(self, streaming_callback_handler) -> None:
        super().__init__()
        self.streaming_callback_handler = streaming_callback_handler

    async def on_retriever_end(self, source_docs, *, run_id, parent_run_id, tags, **kwargs):
        source_docs_d = [{"page": doc.metadata["page"]} for doc in source_docs] if source_docs else None
        xtra = {"source_documents": source_docs_d}
        self.streaming_callback_handler.queue.put_nowait(xtra)


class StreamingConversationChain:
    """Class for handling streaming conversation chains."""

    def __init__(self, openai_api_key: str, temperature: float = 0.0):
        self.memories = {}
        self.openai_api_key = openai_api_key
        self.temperature = temperature

    async def generate_response(self, conversation_id: str, message: str) -> AsyncGenerator[str, None]:
        streaming_cb = AsyncIteratorCallbackHandler()
        retriever_cb = RetrieverCallbackHandler(streaming_callback_handler=streaming_cb)
        question_gen_llm = ChatOpenAI(
            model_name="gpt-3.5-turbo",
            max_retries=15,
            temperature=0.0,
            streaming=True,
            openai_api_key=self.openai_api_key,
        )
        streaming_llm = ChatOpenAI(
            max_retries=15,
            temperature=0,
            callbacks=[streaming_cb],
            streaming=True,
            openai_api_key=self.openai_api_key,
        )
        memory = self.memories.get(conversation_id)
        if memory is None:
            memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key="answer")
            self.memories[conversation_id] = memory
        question_gen_chain = LLMChain(llm=question_gen_llm, prompt=CONDENSE_QUESTION_PROMPT)
        final_qa_chain = load_qa_chain(streaming_llm, chain_type="stuff")
        convo_chain = ConversationalRetrievalChain(
            retriever=retriever,
            question_generator=question_gen_chain,
            combine_docs_chain=final_qa_chain,
            memory=memory,
            return_source_documents=True,
        )
        run = asyncio.create_task(convo_chain.acall({"question": message}, callbacks=[retriever_cb]))
        async for token in streaming_cb.aiter():
            yield json.dumps(token) if isinstance(token, dict) else token
        await run
```
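Since that generator mixes plain answer tokens with a single JSON payload carrying the source documents, a consumer has to tell them apart. A small client-side sketch of my own (the `handle_chunk` helper is hypothetical, not part of the example above):

```python
import json


def handle_chunk(raw: str) -> None:
    # chunks are either ordinary answer tokens or one JSON object with metadata
    try:
        event = json.loads(raw)
    except json.JSONDecodeError:
        print(raw, end="", flush=True)  # ordinary answer token
        return
    if isinstance(event, dict) and "source_documents" in event:
        print("\nsources:", event["source_documents"])
    else:
        # valid JSON that isn't the metadata payload; treat it as text
        print(raw, end="", flush=True)
```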
Thanks so much. I have made some modifications so that the return is only a string.
import asyncio
from functools import lru_cache
from typing import AsyncGenerator
from fastapi import Depends, FastAPI
from fastapi.responses import StreamingResponse
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain, ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from pydantic import BaseModel, BaseSettings, Field, validator
import json
from langchain.chains.question_answering import load_qa_chain
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
persist_directory = '../../ai_agent_logic/vectorstores/chroma'
class Settings(BaseSettings):
"""
Settings class for this application.
Utilizes the BaseSettings from pydantic for environment variables.
"""
openai_api_key: str = Field(..., env="OPENAI_API_KEY")
class Config:
env_file = ".env"
@lru_cache()
def get_settings():
"""
Function to get and cache settings.
The settings are cached to avoid repeated disk I/O.
"""
return Settings()
class ChatRequest(BaseModel):
"""Request model for chat requests. Includes the conversation ID and the message from the user."""
conversation_id: str
message: str
embedding = OpenAIEmbeddings(openai_api_key=get_settings().openai_api_key)
vectordb = Chroma(
persist_directory=persist_directory,
embedding_function=embedding
)
class RetrieverCallbackHandler(AsyncIteratorCallbackHandler):
def __init__(self, streaming_callback_handler) -> None:
super().__init__()
self.streaming_callback_handler = streaming_callback_handler
async def on_retriever_end(self, source_docs, *, run_id, parent_run_id, tags, **kwargs):
source_docs_d = [{"page": doc.metadata["page"]}
for doc in source_docs] if source_docs else None
xtra = {"source_documents": source_docs_d}
self.streaming_callback_handler.queue.put_nowait(xtra)
class StreamingConversationChain:
"""Class for handling streaming conversation chains."""
def __init__(self, openai_api_key: str, temperature: float = 0.0):
self.memories = {}
self.openai_api_key = openai_api_key
self.temperature = temperature
async def generate_response(self, conversation_id: str, message: str) -> AsyncGenerator[str, None]:
streaming_cb = AsyncIteratorCallbackHandler()
retriever_cb = RetrieverCallbackHandler(
streaming_callback_handler=streaming_cb)
question_gen_llm = ChatOpenAI(
model_name="gpt-3.5-turbo",
max_retries=15,
temperature=0.0,
streaming=True,
openai_api_key=self.openai_api_key,
)
streaming_llm = ChatOpenAI(
max_retries=15,
temperature=0,
callbacks=[streaming_cb],
streaming=True,
openai_api_key=self.openai_api_key,
)
memory = self.memories.get(conversation_id)
if memory is None:
memory = ConversationBufferMemory(
memory_key="chat_history", return_messages=True, output_key="answer")
self.memories[conversation_id] = memory
question_gen_chain = LLMChain(
llm=question_gen_llm, prompt=CHAT_PROMPT_TEMPLATE)
final_qa_chain = load_qa_chain(streaming_llm, chain_type="stuff")
convo_chain = ConversationalRetrievalChain(
retriever=vectordb.as_retriever(),
question_generator=question_gen_chain,
combine_docs_chain=final_qa_chain,
memory=memory,
return_source_documents=True,
)
run = asyncio.create_task(convo_chain.acall(
{"question": message}, callbacks=[retriever_cb]))
async for token in streaming_cb.aiter():
yield "" if isinstance(token, dict) else token
await run
app = FastAPI(dependencies=[Depends(get_settings)])
streaming_conversation_chain = StreamingConversationChain(
openai_api_key=get_settings().openai_api_key, temperature=0.7
)
CHAT_PROMPT_TEMPLATE = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(
"You're a AI that knows everything about cats."
),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template("{input}"),
]
)
@app.post("/chat", response_class=StreamingResponse)
async def generate_response(data: ChatRequest) -> StreamingResponse:
"""Endpoint for chat requests"""
return StreamingResponse(
streaming_conversation_chain.generate_response(
data.conversation_id, data.message),
media_type="text/event-stream",
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app)
Hi, everyone! 🚀 Exciting news – LangChain Expression Language is now available, and it makes implementing streaming responses a breeze, eliminating the need for manual callbacks.
I've put together a special gist to showcase its capabilities. Inside, you'll find everything you need to set up a QA bot that streams responses along with source documents, all built on FastAPI.
gist: https://gist.github.com/jvelezmagic/f3653cc2ddab1c91e86751c8b423a1b6
Example includes:
- Persistent Chat Memory: Stores chat history in a local file.
- Persistent Vector Store: Stores document embeddings in a local vector store.
- Standalone Question Generation: Rephrases follow-up questions to standalone questions in their original language.
- Document Retrieval: Searches and retrieves relevant documents based on user queries.
- Context-Aware Responses: Generates responses based on a combined context from relevant documents.
- Streaming Responses: Streams responses in real time either as plain text or as Server-Sent Events (SSE). SSE also sends the relevant documents as context.
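For a taste of what that looks like, here is a minimal sketch of my own (not the linked gist, which is more complete); the `/lcel-chat` route and `ChatIn` model are made-up names:

```python
# With LCEL, the composed chain exposes .astream(), so no callback handler is needed.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from pydantic import BaseModel

app = FastAPI()

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You're an AI that knows everything about cats."),
        ("human", "{input}"),
    ]
)
# prompt | model | parser composes a runnable that streams parsed string chunks
chain = prompt | ChatOpenAI(streaming=True) | StrOutputParser()


class ChatIn(BaseModel):
    message: str


@app.post("/lcel-chat")
async def lcel_chat(data: ChatIn) -> StreamingResponse:
    # .astream() yields tokens as they arrive from the model
    return StreamingResponse(chain.astream({"input": data.message}), media_type="text/plain")
```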
Happy coding! 🐈
Amazing. Any tutorial example would be appreciated.
@jvelezmagic I tested your code, but it's not working.
This worked for me. Amazing. Thank you @jvelezmagic
For those getting errors: try updating to the latest langchain version; that fixed the AsyncIteratorCallbackHandler() issue for me.
I'm quite new to Python, and I'm a bit confused by how state is handled in the example. The only way it makes sense to me is that running the app via uvicorn.run()
means the StreamingConversationChain
object is not recreated on every request, but is kept "alive" and reused until the app is shut down, which happens when you restart the app to push a code update, for example.
Is that correct? If not, I'd much appreciate it if someone could elaborate on how state is maintained across different HTTP sessions in the example of @jvelezmagic. Much obliged, by the way! <3
@coreation, you are right, the application in the example preserves state, so memory is available until shutdown. In a real-case scenario you could use a database-backed memory, like Redis or PostgreSQL, to make the application itself unaware of the state. 🐾
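As a concrete sketch of the database-backed idea (my own example, assuming a Redis instance at redis://localhost:6379, which is not part of the gist):

```python
from langchain.memory import ConversationBufferMemory, RedisChatMessageHistory


def get_memory(conversation_id: str) -> ConversationBufferMemory:
    # chat history lives in Redis, keyed by conversation_id, so any app
    # instance (or a restarted one) sees the same conversation state
    history = RedisChatMessageHistory(
        session_id=conversation_id, url="redis://localhost:6379/0"
    )
    return ConversationBufferMemory(chat_memory=history, return_messages=True)
```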
@jvelezmagic thanks so much for the reply, I wasn't entirely sure but fully understanding how it works really helps out. Thanks for the gist, cheers!
Hi, I am trying to use ConversationalRetrievalChain with Azure Cognitive Search as the retriever, with streaming enabled. The code does not produce output in a streaming manner. I would like to know whether LangChain supports such a feature when combining Azure Cognitive Search with an LLM.
The code snippet I used is below.
Code Snippet
def search_docs_chain_with_memory_streaming(
search_index_name=os.getenv("AZURE_COGNITIVE_SEARCH_INDEX_NAME"),
question_list=[],
answer_list=[],
):
code = detect(question)
language_name = map_language_code_to_name(code)
embeddings = OpenAIEmbeddings(
deployment=oaienvs.OPENAI_EMBEDDING_DEPLOYMENT_NAME,
model=oaienvs.OPENAI_EMBEDDING_MODEL_NAME,
openai_api_base=os.environ["OPENAI_API_BASE"],
openai_api_type=os.environ["OPENAI_API_TYPE"],
)
memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer")
acs = AzureSearch(
azure_search_endpoint=os.getenv("AZURE_SEARCH_SERVICE_ENDPOINT"),
azure_search_key=os.getenv("AZURE_COGNITIVE_SEARCH_API_KEY"),
index_name=search_index_name,
search_type="similarity",
semantic_configuration_name="default",
embedding_function=embeddings.embed_query,
)
retriever = acs.as_retriever()
retriever.search_kwargs = {"score_threshold": 0.8} # {'k':1}
print("language_name-----", language_name)
hcp_conv_template = (
get_prompt(workflows, "retrievalchain_hcp_conv_template1", "system_prompt", "v0")
+ language_name +
get_prompt(workflows, "retrievalchain_hcp_conv_template2", "system_prompt", "v0")
)
CONDENSE_QUESTION_PROMPT = get_prompt(workflows, "retrievalchain_condense_question_prompt", "system_prompt", "v0")
prompt = PromptTemplate(
input_variables=["question"], template=CONDENSE_QUESTION_PROMPT
)
SYSTEM_MSG2 = get_prompt(workflows, "retrievalchain_system_msg_template", "system_prompt", "v0")
messages = [
SystemMessagePromptTemplate.from_template(SYSTEM_MSG2),
HumanMessagePromptTemplate.from_template(hcp_conv_template),
]
qa_prompt = ChatPromptTemplate.from_messages(messages)
llm = AzureChatOpenAI(
deployment_name=oaienvs.OPENAI_CHAT_MODEL_DEPLOYMENT_NAME, temperature=0.7, max_retries=4,
#callbacks=[streaming_cb],
streaming=True
#callback_manager=CallbackManager([MyCustomHandler()])
)
qa_chain = ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=retriever,
get_chat_history=lambda o: o,
memory=memory,
condense_question_prompt=prompt,
return_source_documents=True,
verbose=True,
#callback_manager=convo_cb_manager,
#condense_question_llm = llm_condense_ques,
combine_docs_chain_kwargs={"prompt": qa_prompt},
)
if len(question_list) == 0:
question = question + ". Give the answer only in " + language_name + "."
for i in range(len(question_list)):
qa_chain.memory.save_context(
inputs={"question": question_list[i]}, outputs={"answer": answer_list[i]}
)
#return qa_chain.stream({"question": question, "chat_history": []})
return qa_chain
Also I have tried different callback handlers and invoke methods as mentioned in https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68
Kindly suggest if there is any workaround to it.
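Not a confirmed fix, but one pattern worth trying is the callback approach from the gist at the top of this thread, applied to AzureChatOpenAI: attach the handler to the streaming LLM itself, run the chain as a background task, and read tokens from the handler instead of waiting for the chain's return value. A sketch under those assumptions (`build_qa_chain` and the deployment name are placeholders, not real names from the code above):

```python
import asyncio
from typing import AsyncGenerator

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chat_models import AzureChatOpenAI


async def stream_answer(question: str, build_qa_chain) -> AsyncGenerator[str, None]:
    handler = AsyncIteratorCallbackHandler()
    llm = AzureChatOpenAI(
        deployment_name="my-gpt-deployment",  # placeholder: use your deployment name
        temperature=0.7,
        streaming=True,
        callbacks=[handler],  # without this, tokens never reach the iterator
    )
    # build_qa_chain is a hypothetical factory, e.g. wrapping
    # ConversationalRetrievalChain.from_llm(llm, retriever, ...)
    qa_chain = build_qa_chain(llm)
    task = asyncio.create_task(qa_chain.acall({"question": question}))
    async for token in handler.aiter():
        yield token
    await task
```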
Has anyone gotten memory chat working well? It works for the first API call, but the second time it's not working.
@ZanyuanYang what database are you using? Is it in-memory?
If it's in-memory storage, building an API would be difficult because each call to the endpoint will not share the memory.
I used Elasticsearch, @AvikantSrivastava.
@AvikantSrivastava This is my code
class StreamingLLMCallbackHandler(AsyncIteratorCallbackHandler):
"""Callback handler for streaming LLM responses."""
async def on_llm_start(self, serialized, prompts, **kwargs) -> None:
logging.info("LLM start")
self.done.clear()
self.queue.put_nowait(ChatResponse(sender="bot", message="", type="start"))
async def on_llm_end(self, response, **kwargs) -> None:
logging.info("LLM end")
# we override this method since we want the ConversationalRetrievalChain to potentially return
# other items (e.g., source_documents) after it is completed
pass
async def on_llm_error(
self, error: Exception | KeyboardInterrupt, **kwargs
) -> None:
logging.error(f"LLM error: {error}")
self.queue.put_nowait(
ChatResponse(sender="bot", message=str(error), type="error")
)
async def on_llm_new_token(self, token: str, **kwargs) -> None:
# if token not in ['"', "}"]:
self.queue.put_nowait(ChatResponse(sender="bot", message=token, type="stream"))
class ConvoChainCallbackHandler(AsyncCallbackHandler):
"""Use to add additional information (e.g., source_documents, etc...) once the chain finishes"""
def __init__(self, callback_handler) -> None:
super().__init__()
self.callback_handler = callback_handler
async def on_chain_end(self, outputs, *, run_id, parent_run_id, **kwargs) -> None:
"""Run after chain ends running."""
source_docs = outputs.get("source_documents", None)
doc_list = [
{"page_content": doc.page_content, "metadata": doc.metadata}
for doc in source_docs
]
metadata_list = getMetadataFromCourtListener(doc_list)
# metadata = {"metadata": metadata_list}
self.callback_handler.queue.put_nowait(
ChatResponse(sender="bot", message="", metadata=metadata_list, type="info")
)
self.callback_handler.queue.put_nowait(
ChatResponse(sender="bot", message="", type="end")
)
class StreamingConversationChain:
"""Class for handling streaming conversation chains."""
def __init__(self, temperature: float = 0.0):
self.memories = {}
self.openai_api_key = openai_api_key
self.temperature = temperature
async def generate_response(
self, conversation_id: str, message: str
) -> AsyncGenerator[str, None]:
"""
Asynchronous function to generate a response for a conversation.
It creates a new conversation chain for each message and uses a
callback handler to stream responses as they're generated.
:param conversation_id: The ID of the conversation.
:param message: The message from the user.
"""
streaming_cb = StreamingLLMCallbackHandler() # AsyncIteratorCallbackHandler()
convo_cb_manager = AsyncCallbackManager(
[ConvoChainCallbackHandler(streaming_cb)]
)
question_gen_llm = ChatOpenAI(
model_name="gpt-3.5-turbo-1106",
temperature=0.0,
streaming=True,
openai_api_key=self.openai_api_key,
max_tokens=4097,
)
streaming_llm = ChatOpenAI(
model_name="gpt-3.5-turbo-1106",
temperature=0,
callbacks=[streaming_cb],
streaming=True,
openai_api_key=self.openai_api_key,
)
memory = self.memories.get(conversation_id)
if memory is None:
memory = ConversationBufferMemory(
memory_key="chat_history",
return_messages=True,
output_key="answer",
)
self.memories[conversation_id] = memory
prompt_template = "Tell me a {adjective} joke"
prompt = PromptTemplate(input_variables=["adjective"], template=prompt_template)
question_gen_chain = LLMChain(
llm=question_gen_llm, prompt=prompt
) # , callback_manager=manager)
final_qa_chain = load_qa_chain(
streaming_llm,
chain_type="stuff",
)
es_retriever = LexARIElasticSearchBM25Retriever(
client=ES_CLIENT, index_name=ES_INDEX_NAME
)
docs = es_retriever.get_relevant_documents(message)
qdrant = Qdrant.from_documents(
docs,
EMBEDDINGS,
location=":memory:", # Local mode with in-memory storage only
collection_name="my_documents",
)
convo_chain = ConversationalRetrievalChain(
retriever=qdrant.as_retriever(search_type="similarity"),
question_generator=question_gen_chain,
combine_docs_chain=final_qa_chain,
memory=memory,
return_source_documents=True,
callback_manager=convo_cb_manager,
max_tokens_limit=16385,
)
run = asyncio.create_task(convo_chain.acall({"question": message}))
async for token in streaming_cb.aiter():
# Print for debugging purposes
print("dict: ", token.dict())
# Yield the response as JSON
yield json.dumps(token.dict())
if token.dict().get("type") in ["end", "error"]:
streaming_cb.done.set()
# Wait for the conversation chain task to complete
await run
Does anyone know how to fix this error?
return inputs[prompt_input_key], outputs[output_key]
~~~~~~~^^^^^^^^^^^^
KeyError: 'answer'
Try setting return_source_documents=False.
I hope that after this change your chain will work fine.