Skip to content

Instantly share code, notes, and snippets.

@kzinmr
Created September 8, 2024 13:32
Show Gist options
  • Save kzinmr/678284ebba3ff004bc80719d75883e01 to your computer and use it in GitHub Desktop.
import pickle
from pathlib import Path
import modal
# Container image with LanceDB + LangChain + OpenAI tooling preinstalled.
_PACKAGES = (
    "lancedb",
    "langchain_core",
    "langchain_community",
    "langchain_openai",
    "openai",
    "pandas",
    "tiktoken",
    "unstructured",
    "tabulate",
)
lancedb_image = modal.Image.debian_slim().pip_install(*_PACKAGES)

# Modal app; the OpenAI API key is injected from the "openai-secret" secret.
app = modal.App(
    name="example-langchain-lancedb",
    image=lancedb_image,
    secrets=[modal.Secret.from_name("openai-secret")],
)

# Model identifiers used for embedding and chat completion.
embedding_model = "text-embedding-3-small"
chat_model = "gpt-4o-mini"
@app.function(
    mounts=[
        modal.Mount.from_local_file("docs.pkl", remote_path=Path("/root") / "docs.pkl")
    ]
)
def load_docs():
    """Load the pre-built document list from the pickle file mounted into the container.

    Raises:
        FileNotFoundError: if the mount did not provide /root/docs.pkl.
    """
    pickle_path = Path("/root") / "docs.pkl"
    if not pickle_path.exists():
        raise FileNotFoundError(f"File {pickle_path} not found")
    with pickle_path.open("rb") as fh:
        return pickle.load(fh)
@app.cls(cpu=1)
class Model:
    """Serves a retrieval-augmented QA chain over the mounted document set."""

    @modal.enter()
    def run_this_on_container_startup(self):
        """Build the vector index and RAG chain once per container start."""
        from langchain import hub
        from langchain.chains import create_retrieval_chain
        from langchain.chains.combine_documents import create_stuff_documents_chain
        from langchain_community.vectorstores.lancedb import LanceDB
        from langchain_openai import ChatOpenAI, OpenAIEmbeddings
        from langchain_text_splitters import RecursiveCharacterTextSplitter

        # Fetch the pickled documents via the mounted load_docs function.
        self.documents = load_docs.remote()

        # Chunk the documents and index them in a LanceDB vector store.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )
        chunks = splitter.split_documents(self.documents)
        embeddings = OpenAIEmbeddings(model=embedding_model)
        self.docsearch = LanceDB.from_documents(chunks, embeddings)

        # Assemble the RAG chain: retriever -> stuff-documents -> chat model.
        self.llm = ChatOpenAI(model=chat_model)
        # See full prompt at https://smith.langchain.com/hub/langchain-ai/retrieval-qa-chat
        self.prompt = hub.pull("langchain-ai/retrieval-qa-chat")
        answer_chain = create_stuff_documents_chain(self.llm, self.prompt)
        self.rag_chain = create_retrieval_chain(
            self.docsearch.as_retriever(), answer_chain
        )

    @modal.method()
    def qanda_langchain(self, query: str):
        """Run the RAG chain on `query`; return a JSON-serializable dict."""
        result = self.rag_chain.invoke({"input": query})
        return {
            "input": result["input"],
            "context": [doc.dict() for doc in result["context"]],
            "answer": result["answer"],
        }

    @modal.web_endpoint(method="GET")
    def web(self, query: str):
        """HTTP GET endpoint delegating to qanda_langchain."""
        return self.qanda_langchain.remote(query)
@app.local_entrypoint()
def cli(query: str):
    """Query the deployed model from the command line and print the answer."""
    print(Model().qanda_langchain.remote(query))
@kzinmr
Copy link
Author

kzinmr commented Sep 8, 2024

Build docs.pkl with the following script:

import os
import pickle
import tempfile
import zipfile
from pathlib import Path
from urllib.request import urlopen

from langchain_community.document_loaders import UnstructuredHTMLLoader

# Scratch workspace for the downloaded docs archive and its extraction.
tempdir = tempfile.mkdtemp()
print(tempdir)
_workdir = Path(tempdir)
download_path = _workdir / "pandas.documentation.zip"
extract_path = _workdir / "pandas_docs"
# Output pickle, written to the current working directory.
docs_path = Path("docs.pkl")


def get_document_title(doc) -> str:
    """Return the first line of the document's page content as its title."""
    # partition("\n")[0] == split("\n")[0] for every input, including "".
    title, _, _ = doc.page_content.partition("\n")
    return title


def store_docs():
    """Download the pandas documentation archive, load every HTML page,
    and pickle the resulting documents to ``docs_path``.

    Returns:
        list: the loaded langchain documents, each annotated with a
        ``title`` and ``version`` in its metadata.
    """
    url = "https://eto-public.s3.us-west-2.amazonaws.com/datasets/pandas_docs/pandas.documentation.zip"
    # Context managers close the HTTP response and output file deterministically
    # (the original `urlopen(url).read()` leaked the connection).
    with urlopen(url) as resp, open(download_path, "wb") as fh:
        fh.write(resp.read())
    # Close the zip handle as well instead of leaving it open.
    with zipfile.ZipFile(download_path) as zf:
        zf.extractall(path=extract_path)

    docs = []
    root_path = extract_path / "pandas.documentation"
    # os.walk instead of Path.walk: Path.walk requires Python 3.12+.
    for root, _dirs, files in os.walk(root_path, onerror=print):
        for name in files:
            if Path(name).suffix != ".html":
                continue

            full_path = (Path(root) / name).absolute()
            loader = UnstructuredHTMLLoader(str(full_path))
            raw_document = loader.load()

            # The loader returns a single document; annotate it in place
            # (no need to write it back into the list).
            document = raw_document[0]
            document.metadata |= {
                "title": get_document_title(document),
                "version": "2.0rc0",
            }
            document.metadata["source"] = str(document.metadata["source"])

            docs.extend(raw_document)

    with docs_path.open("wb") as fh:
        pickle.dump(docs, fh)

    return docs


# Script entry point: build and pickle the docs when run directly.
if __name__ == "__main__":
    store_docs()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment