Created
September 8, 2024 13:32
-
-
Save kzinmr/678284ebba3ff004bc80719d75883e01 to your computer and use it in GitHub Desktop.
Adaptation from https://lancedb.github.io/lancedb/examples/serverless_qa_bot_with_modal_and_langchain/
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle | |
from pathlib import Path | |
import modal | |
# Python dependencies baked into the container image for the QA bot.
_PACKAGES = (
    "lancedb",
    "langchain_core",
    "langchain_community",
    "langchain_openai",
    "openai",
    "pandas",
    "tiktoken",
    "unstructured",
    "tabulate",
)

# Slim Debian base image with the packages above pip-installed.
lancedb_image = modal.Image.debian_slim().pip_install(*_PACKAGES)

# Modal app; the OpenAI API key is injected from the "openai-secret" secret.
app = modal.App(
    name="example-langchain-lancedb",
    image=lancedb_image,
    secrets=[modal.Secret.from_name("openai-secret")],
)

# OpenAI model identifiers used by the embedding and chat steps below.
embedding_model = "text-embedding-3-small"
chat_model = "gpt-4o-mini"
@app.function(
    mounts=[
        modal.Mount.from_local_file("docs.pkl", remote_path=Path("/root") / "docs.pkl")
    ]
)
def load_docs():
    """Return the pre-built document list from the mounted ``docs.pkl``.

    Raises:
        FileNotFoundError: if the mount did not place the pickle at /root/docs.pkl.
    """
    docs_path = Path("/root") / "docs.pkl"
    # Guard clause: fail fast with a clear message if the mount is missing.
    if not docs_path.exists():
        raise FileNotFoundError(f"File {docs_path} not found")
    # NOTE(review): pickle.load executes arbitrary code — only load trusted files.
    with docs_path.open("rb") as fh:
        return pickle.load(fh)
@app.cls(cpu=1)
class Model:
    """Serverless RAG question-answering service.

    On container startup it embeds the mounted documents into a LanceDB
    vector store and wires up a LangChain retrieval-QA chain; queries are
    then served via a Modal method and a GET web endpoint.
    """

    @modal.enter()
    def run_this_on_container_startup(self):
        """Build the vector store and the RAG chain once per container."""
        from langchain import hub
        from langchain.chains import create_retrieval_chain
        from langchain.chains.combine_documents import create_stuff_documents_chain
        from langchain_community.vectorstores.lancedb import LanceDB
        from langchain_openai import ChatOpenAI, OpenAIEmbeddings
        from langchain_text_splitters import RecursiveCharacterTextSplitter

        # Documents come from the sibling Modal function that has the
        # docs.pkl mount attached.
        self.documents = load_docs.remote()

        # Split documents into overlapping chunks so each embedded span
        # stays coherent and retrieval granularity is reasonable.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )
        documents = text_splitter.split_documents(self.documents)

        # Embed the chunks and index them in an ephemeral LanceDB store.
        embeddings = OpenAIEmbeddings(model=embedding_model)
        self.docsearch = LanceDB.from_documents(documents, embeddings)

        # Retrieval-augmented chain: retriever -> stuff docs into prompt -> LLM.
        self.llm = ChatOpenAI(model=chat_model)
        # See full prompt at https://smith.langchain.com/hub/langchain-ai/retrieval-qa-chat
        self.prompt = hub.pull("langchain-ai/retrieval-qa-chat")
        _combine_docs_chain = create_stuff_documents_chain(self.llm, self.prompt)
        self.rag_chain = create_retrieval_chain(
            self.docsearch.as_retriever(), _combine_docs_chain
        )

    @modal.method()
    def qanda_langchain(self, query: str):
        """Answer *query* with the RAG chain.

        Returns a JSON-serializable dict with the original input, the
        retrieved context documents, and the generated answer.
        """
        res = self.rag_chain.invoke({"input": query})
        return {
            "input": res["input"],
            # .dict() serializes each retrieved Document to a plain dict.
            "context": [doc.dict() for doc in res["context"]],
            "answer": res["answer"],
        }

    @modal.web_endpoint(method="GET")
    def web(self, query: str):
        """HTTP GET endpoint: ``/web?query=...`` → QA result dict."""
        return self.qanda_langchain.remote(query)
@app.local_entrypoint()
def cli(query: str):
    """Local CLI entrypoint: run one query against the remote model and print it."""
    print(Model().qanda_langchain.remote(query))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Build `docs.pkl` with the following script: