import os

import openai
import chainlit as cl
from llama_index.core import (
    Settings,
    StorageContext,
    VectorStoreIndex,
    SimpleDirectoryReader,
    load_index_from_storage,
)
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.callbacks import CallbackManager
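# Install the dependencies before running (assuming the post-0.10
# llama-index package layout, whose starter bundle includes the OpenAI
# LLM and embedding integrations imported above):
#
#   pip install chainlit llama-index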
openai.api_key = os.environ.get("OPENAI_API_KEY")

try:
    # Rebuild the storage context
    storage_context = StorageContext.from_defaults(persist_dir="./storage")
    # Load the existing index
    index = load_index_from_storage(storage_context)
except Exception:
    # If loading fails, create a new index from the documents in ./data
    documents = SimpleDirectoryReader("./data").load_data(show_progress=True)
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist()
@cl.on_chat_start
async def start():
    # Configure global settings
    Settings.llm = OpenAI(
        model="gpt-4",
        temperature=0.1,
        max_tokens=1024,
        streaming=True,
    )
    Settings.embed_model = OpenAIEmbedding(model="text-embedding-ada-002")
    Settings.context_window = 4096

    # Set the callback manager globally so Chainlit can render
    # LlamaIndex's intermediate steps in the UI
    Settings.callback_manager = CallbackManager([cl.LlamaIndexCallbackHandler()])

    # Create a streaming query engine; it picks up the global callback
    # manager, so there is no need to pass one here
    query_engine = index.as_query_engine(streaming=True, similarity_top_k=2)

    # Store the query engine in the user session
    cl.user_session.set("query_engine", query_engine)

    # Send a welcome message to the user
    await cl.Message(
        author="Assistant", content="Hello! I'm an AI assistant. How may I help you?"
    ).send()
@cl.on_message
async def main(message: cl.Message):
    # Retrieve the query engine from the user session
    query_engine = cl.user_session.get("query_engine")  # type: RetrieverQueryEngine

    # Initialize a new message from the assistant
    msg = cl.Message(content="", author="Assistant")

    # Run the blocking query in a worker thread so the event loop stays free
    res = await cl.make_async(query_engine.query)(message.content)

    # Stream the response tokens to the user
    for token in res.response_gen:
        await msg.stream_token(token)

    await msg.send()
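# To launch the app, save this file (the name app.py below is just an
# example) and start the Chainlit server from the same directory:
#
#   chainlit run app.py -w
#
# The -w flag reloads the app automatically when the file changes.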