Last active
February 18, 2026 01:49
-
-
Save donbr/6c57f20f8008dbffddd5e8e7a72cb026 to your computer and use it in GitHub Desktop.
Advanced Retrieval with LangChain
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # # Advanced Retrieval with LangChain | |
| # Standard Library Imports | |
| import getpass | |
| import os | |
| from operator import itemgetter | |
| from uuid import uuid4 | |
| # Third-Party Imports | |
| from dotenv import load_dotenv | |
| # LangChain Core | |
| from langchain.retrievers import EnsembleRetriever, ParentDocumentRetriever | |
| from langchain.retrievers.contextual_compression import ContextualCompressionRetriever | |
| from langchain.retrievers.multi_query import MultiQueryRetriever | |
| from langchain.storage import InMemoryStore | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.runnables import RunnablePassthrough | |
| # LangChain Community | |
| from langchain_community.document_loaders import TextLoader | |
| from langchain_community.retrievers import BM25Retriever | |
| # LangChain Integrations | |
| from langchain_cohere import CohereRerank | |
| from langchain_experimental.text_splitter import SemanticChunker | |
| from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
| from langchain_qdrant import QdrantVectorStore | |
| # Qdrant | |
| from qdrant_client import QdrantClient, models | |
# Load variables from a local .env file, then prompt interactively for any
# API keys that are still missing from the environment.
load_dotenv()
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API Key:")
if not os.environ.get("COHERE_API_KEY"):
    os.environ["COHERE_API_KEY"] = getpass.getpass("Cohere API Key:")
if not os.environ.get("LANGCHAIN_API_KEY"):
    os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("LangChain API Key:")
# Enable LangSmith tracing; a random 8-hex-char suffix gives each run its own
# project so traces from different runs stay grouped separately.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = f"AIM - Advanced Retrieval - {uuid4().hex[0:8]}"
# Generation model for RAG answers and the embedding model used by every
# vector store below.
chat_model = ChatOpenAI(model="gpt-4.1-nano")
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
# Prompt that grounds the model in retrieved context and instructs it to
# admit uncertainty rather than guess.
RAG_TEMPLATE = """\
You are a helpful and kind assistant. Use the context provided below to answer the question.
If you do not know the answer, or are unsure, say you don't know.
Query:
{question}
Context:
{context}
"""
rag_prompt = ChatPromptTemplate.from_template(RAG_TEMPLATE)
# Load the source document and split it into fixed-size overlapping chunks
# (500 chars, 50-char overlap) — the baseline chunking strategy.
loader = TextLoader("data/HealthWellnessGuide.txt")
raw_docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
wellness_docs = text_splitter.split_documents(raw_docs)
print(f"Raw documents: {len(raw_docs)}")
print(f"Split chunks: {len(wellness_docs)}")
print(f"\nExample chunk:\n{wellness_docs[0]}")
# Baseline dense retriever: embed the fixed-size chunks into an in-memory
# Qdrant collection and fetch the top 10 by vector similarity.
vectorstore = QdrantVectorStore.from_documents(
    wellness_docs,
    embeddings,
    location=":memory:",
    collection_name="wellness_guide",
)
naive_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
# Semantic chunking: split where embedding similarity between sentences drops
# past a percentile threshold, instead of at fixed character counts, then
# index those chunks in their own collection.
semantic_chunker = SemanticChunker(embeddings, breakpoint_threshold_type="percentile")
semantic_documents = semantic_chunker.split_documents(raw_docs)
semantic_vectorstore = QdrantVectorStore.from_documents(
    semantic_documents,
    embeddings,
    location=":memory:",
    collection_name="wellness_guide_semantic_chunks",
)
semantic_retriever = semantic_vectorstore.as_retriever(search_kwargs={"k": 10})
# Parent-document retriever: search over small child chunks for precision,
# but return the larger parent chunks for context.
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
client = QdrantClient(location=":memory:")
client.create_collection(
    collection_name="wellness_parent_child",
    # 1536 = output dimension of text-embedding-3-small (must match the
    # embedding model configured above).
    vectors_config=models.VectorParams(size=1536, distance=models.Distance.COSINE),
)
parent_document_vectorstore = QdrantVectorStore(
    collection_name="wellness_parent_child",
    # Reuse the module-level embeddings client rather than constructing a
    # second identical OpenAIEmbeddings instance (same model, same vectors).
    embedding=embeddings,
    client=client,
)
# In-memory docstore holds the full parent chunks keyed by id.
store = InMemoryStore()
parent_document_retriever = ParentDocumentRetriever(
    vectorstore=parent_document_vectorstore,
    docstore=store,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)
parent_document_retriever.add_documents(raw_docs, ids=None)
# Lexical (keyword) retriever over the same fixed-size chunks.
bm25_retriever = BM25Retriever.from_documents(wellness_docs)
# Contextual compression: rerank the naive retriever's hits with Cohere's
# rerank-v3.5 cross-encoder and keep only the most relevant ones.
compressor = CohereRerank(model="rerank-v3.5")
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=naive_retriever
)
# Multi-query: the chat model generates query variations and the union of
# their results (deduplicated) is returned.
multi_query_retriever = MultiQueryRetriever.from_llm(
    retriever=naive_retriever, llm=chat_model
)
# Fuse five strategies into one ensemble (reciprocal-rank fusion); each
# retriever is weighted equally.
retriever_list = [
    bm25_retriever,
    naive_retriever,
    parent_document_retriever,
    compression_retriever,
    multi_query_retriever,
]
equal_weighting = [1 / len(retriever_list) for _ in retriever_list]
ensemble_retriever = EnsembleRetriever(
    retrievers=retriever_list, weights=equal_weighting
)
def _build_rag_chain(retriever):
    """Assemble the standard RAG pipeline (LCEL) around *retriever*.

    Maps {"question": str} to {"response": AIMessage, "context": list[Document]}:
      1. fan the question into the retriever (context) and pass it through;
      2. re-assign the retrieved context alongside the question;
      3. run rag_prompt | chat_model, keeping the raw context for inspection.
    """
    return (
        {
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
        }
        | RunnablePassthrough.assign(context=itemgetter("context"))
        | {"response": rag_prompt | chat_model, "context": itemgetter("context")}
    )


# One identically-shaped chain per retrieval strategy (previously seven
# copy-pasted literals of the same pipeline).
naive_retrieval_chain = _build_rag_chain(naive_retriever)
bm25_retrieval_chain = _build_rag_chain(bm25_retriever)
contextual_compression_retrieval_chain = _build_rag_chain(compression_retriever)
multi_query_retrieval_chain = _build_rag_chain(multi_query_retriever)
parent_document_retrieval_chain = _build_rag_chain(parent_document_retriever)
ensemble_retrieval_chain = _build_rag_chain(ensemble_retriever)
semantic_retrieval_chain = _build_rag_chain(semantic_retriever)
# Smoke-test every chain on the same three questions.  The original
# notebook-style bare expressions evaluated `...["response"].content` and
# discarded the value — a no-op (minus API cost) when run as a script —
# so print the answers instead.  Chain/question order is preserved.
_DEMO_QUESTIONS = [
    "What exercises can help with lower back pain?",
    "How does sleep affect overall health?",
    "What are some natural remedies for stress and headaches?",
]
for _label, _demo_chain in [
    ("naive", naive_retrieval_chain),
    ("bm25", bm25_retrieval_chain),
    ("contextual_compression", contextual_compression_retrieval_chain),
    ("multi_query", multi_query_retrieval_chain),
    ("parent_document", parent_document_retrieval_chain),
    ("ensemble", ensemble_retrieval_chain),
    ("semantic", semantic_retrieval_chain),
]:
    for _question in _DEMO_QUESTIONS:
        _answer = _demo_chain.invoke({"question": _question})["response"].content
        print(f"[{_label}] Q: {_question}\nA: {_answer}\n")
# --- 1) Chain registry (use your existing chain objects) ---
# Short name -> RAG chain; the keys double as experiment labels in the
# LangSmith harness below.
CHAINS = {
    "naive": naive_retrieval_chain,
    "bm25": bm25_retrieval_chain,
    "compression": contextual_compression_retrieval_chain,
    "multi_query": multi_query_retrieval_chain,
    "parent_doc": parent_document_retrieval_chain,
    "ensemble": ensemble_retrieval_chain,
    "semantic": semantic_retrieval_chain,
}
| # --- 2) Minimal helpers to normalize outputs --- | |
| def _to_text(resp_dict): | |
| """Your chains return {'response': <AIMessage|str>, 'context': [...] }.""" | |
| r = resp_dict.get("response") | |
| if hasattr(r, "content"): # AIMessage | |
| return r.content | |
| return str(r) if r is not None else "" | |
| def _to_context(resp_dict): | |
| return resp_dict.get("context", []) | |
# --- 3) Run a single question across selected chains ---
def run_all(question: str, chains=CHAINS):
    """Invoke each chain on *question*.

    Returns dict[chain_name] -> {"answer": str, "contexts": list}.
    """
    payload = {"question": question}
    return {
        chain_name: {
            "answer": _to_text(output),
            "contexts": _to_context(output),
        }
        for chain_name, output in (
            (name, chain.invoke(payload)) for name, chain in chains.items()
        )
    }
# --- 4) Convenience: quick pretty print for ad-hoc inspection --- 
def print_quick(results, max_len=200):
    """Print one abridged line per chain from a run_all() result dict."""
    for chain_name, record in results.items():
        answer = record["answer"].strip().replace("\n", " ")
        suffix = "…" if len(answer) > max_len else ""
        print(f"[{chain_name}] {answer[:max_len]}{suffix}")
# single question across all chains
# (one retriever + chat-model round trip per registered chain)
res = run_all("What exercises can help with lower back pain?")
print_quick(res)
def run_batch(questions, chains=CHAINS):
    """Batch every chain over *questions*.

    Returns: dict[chain_name] -> list of {question, answer, contexts}.
    """
    payloads = [{"question": question} for question in questions]
    batched = {}
    for chain_name, chain in chains.items():
        # .batch preserves input order, so zip realigns payloads and outputs.
        records = []
        for payload, output in zip(payloads, chain.batch(payloads)):
            records.append(
                {
                    "question": payload["question"],
                    "answer": _to_text(output),
                    "contexts": _to_context(output),
                }
            )
        batched[chain_name] = records
    return batched
def print_results(all_results, max_answer=150, max_context=100, max_ctxs=2):
    """Nicely print abridged chain results for inspection."""
    for chain_name, records in all_results.items():
        print(f"\n=== {chain_name.upper()} ===")
        for record in records:
            print(f"Q: {record['question']}")
            answer = record["answer"].strip().replace("\n", " ")
            answer_mark = "…" if len(answer) > max_answer else ""
            print(f"A: {answer[:max_answer]}{answer_mark}")
            # Show at most max_ctxs retrieved documents, each truncated.
            for idx, doc in enumerate(record["contexts"][:max_ctxs], 1):
                snippet = doc.page_content.strip().replace("\n", " ")
                snippet_mark = "…" if len(snippet) > max_context else ""
                print(f"  [ctx{idx}] {snippet[:max_context]}{snippet_mark}")
            print()  # blank line between questions
# Benchmark questions for the quick batch comparison below.
QUESTIONS = [
    "What exercises can help with lower back pain?",
    "How does sleep affect overall health?",
]
batched_results = run_batch(QUESTIONS)
print_results(batched_results)
| # ============================================================ | |
| # PHASE 2: LangSmith Evaluation Harness | |
| # ============================================================ | |
| # | |
| # Three-step process: | |
| # A. Create (or reuse) a benchmark dataset in LangSmith | |
| # B. Define evaluator functions (QA correctness + context relevance) | |
| # C. Run the harness across all chains and print a summary | |
| # ============================================================ | |
| from langsmith import Client # noqa: E402 | |
| from langsmith.evaluation import evaluate # noqa: E402 | |
| from langchain.evaluation.qa import ContextQAEvalChain, QAEvalChain # noqa: E402 | |
# ---------------------------------------------------------------------------
# A. Dataset Creation
# ---------------------------------------------------------------------------
# The LangSmith client reads LANGCHAIN_API_KEY from the environment (set above).
ls_client = Client()
# Dataset is looked up / created by this name.
DATASET_NAME = "Health Wellness Benchmark v1"
# Gold question/answer pairs that seed the LangSmith evaluation dataset.
# Each entry mirrors the dataset schema: "inputs" feeds the chain,
# "outputs" is the reference answer the evaluators compare against.
BENCHMARK_EXAMPLES = [
    {
        "inputs": {"question": "What exercises can help with lower back pain?"},
        "outputs": {
            "answer": (
                "Exercises that can help with lower back pain include the Cat-Cow "
                "Stretch for spinal flexibility, Bird Dog for core stability, and "
                "Pelvic Tilts to strengthen the lower back. Child's Pose is also "
                "recommended for gentle stretching, and Bridges help strengthen "
                "the glutes and lower back."
            )
        },
    },
    {
        "inputs": {"question": "How does sleep affect overall health?"},
        "outputs": {
            "answer": (
                "Sleep is essential for overall health. It allows the body to "
                "repair tissues, consolidate memories, and regulate hormones. "
                "Poor sleep is linked to increased risk of heart disease, obesity, "
                "weakened immunity, and impaired cognitive function. Adults should "
                "aim for 7-9 hours of quality sleep per night."
            )
        },
    },
    {
        "inputs": {
            "question": "What are some natural remedies for stress and headaches?"
        },
        "outputs": {
            "answer": (
                "Natural remedies for stress include deep breathing exercises, "
                "meditation, regular physical activity, and herbal teas such as "
                "chamomile. For headaches, staying hydrated, applying peppermint "
                "oil to the temples, practicing relaxation techniques, and "
                "ensuring adequate sleep can provide relief."
            )
        },
    },
]
# Create the benchmark dataset once; subsequent runs reuse it by name.
if ls_client.has_dataset(dataset_name=DATASET_NAME):
    print(f"Dataset '{DATASET_NAME}' already exists — skipping creation.")
else:
    dataset = ls_client.create_dataset(
        dataset_name=DATASET_NAME,
        description="Benchmark questions for the Health & Wellness RAG chains.",
    )
    example_inputs = [example["inputs"] for example in BENCHMARK_EXAMPLES]
    example_outputs = [example["outputs"] for example in BENCHMARK_EXAMPLES]
    ls_client.create_examples(
        inputs=example_inputs,
        outputs=example_outputs,
        dataset_id=dataset.id,
    )
    print(f"Created dataset '{DATASET_NAME}' with {len(BENCHMARK_EXAMPLES)} examples.")
# ---------------------------------------------------------------------------
# B. Evaluators
# ---------------------------------------------------------------------------
# Deterministic (temperature=0) judge model — a stronger model than the
# gpt-4.1-nano used for generation.
eval_llm = ChatOpenAI(model="gpt-4.1-mini", temperature=0)
| def _extract_text(value): | |
| """Defensively extract a plain string from an AIMessage, dict, or str.""" | |
| if value is None: | |
| return "" | |
| if hasattr(value, "content"): # AIMessage | |
| return value.content | |
| if isinstance(value, dict) and "content" in value: | |
| return value["content"] | |
| return str(value) | |
| def _extract_context(value): | |
| """Join retrieved documents into a single context string.""" | |
| if not value: | |
| return "" | |
| pieces = [] | |
| for doc in value: | |
| if hasattr(doc, "page_content"): | |
| pieces.append(doc.page_content) | |
| elif isinstance(doc, dict) and "page_content" in doc: | |
| pieces.append(doc["page_content"]) | |
| elif isinstance(doc, str): | |
| pieces.append(doc) | |
| return "\n\n".join(pieces) | |
# Grader chains: QAEvalChain compares a prediction against the reference
# answer; ContextQAEvalChain additionally conditions on retrieved context.
_qa_chain = QAEvalChain.from_llm(eval_llm)
_ctx_chain = ContextQAEvalChain.from_llm(eval_llm)
def qa_correctness(outputs: dict, reference_outputs: dict, inputs: dict) -> dict:
    """LangSmith evaluator: grade the generated answer against the reference.

    Args:
        outputs: Chain output ({"response": AIMessage, "context": [...]}).
        reference_outputs: Dataset example outputs ({"answer": str}).
        inputs: Dataset example inputs ({"question": str}).

    Returns:
        {"key": "qa_correctness", "score": 1.0 if judged CORRECT else 0.0}.
    """
    prediction = _extract_text(outputs.get("response"))
    reference = _extract_text(reference_outputs.get("answer", ""))
    question = inputs.get("question", "")
    eval_result = _qa_chain.evaluate_strings(
        prediction=prediction,
        reference=reference,
        input=question,
    )
    # Fix: the original `eval_result.get("value", "").upper()` raised
    # AttributeError when "value" was present but None; `or ""` covers both
    # the missing-key and None cases.
    verdict = (eval_result.get("value") or "").upper()
    return {"key": "qa_correctness", "score": 1.0 if verdict == "CORRECT" else 0.0}
def context_relevance(outputs: dict, reference_outputs: dict, inputs: dict) -> dict:
    """LangSmith evaluator: grade the answer given the retrieved context.

    Args:
        outputs: Chain output ({"response": AIMessage, "context": [...]}).
        reference_outputs: Dataset example outputs ({"answer": str}).
        inputs: Dataset example inputs ({"question": str}).

    Returns:
        {"key": "context_relevance", "score": 1.0 if judged CORRECT else 0.0}.
    """
    prediction = _extract_text(outputs.get("response"))
    reference = _extract_text(reference_outputs.get("answer", ""))
    context_str = _extract_context(outputs.get("context"))
    question = inputs.get("question", "")
    eval_result = _ctx_chain.evaluate_strings(
        prediction=prediction,
        reference=reference,
        input=question,
        context=context_str,
    )
    # Fix: guard against a None "value" — `.get("value", "").upper()` still
    # crashed when the key existed with value None.
    verdict = (eval_result.get("value") or "").upper()
    return {"key": "context_relevance", "score": 1.0 if verdict == "CORRECT" else 0.0}
# ---------------------------------------------------------------------------
# C. Execution Harness
# ---------------------------------------------------------------------------
print("\n" + "=" * 60)
print("Running LangSmith evaluation harness across all chains...")
print("=" * 60)
# chain name -> results object returned by langsmith.evaluation.evaluate.
all_experiment_results = {}
for chain_name, chain in CHAINS.items():
    print(f"\nEvaluating chain: {chain_name}")
    # evaluate() feeds each dataset example's inputs to chain.invoke and
    # applies both custom evaluators to the resulting output.
    results = evaluate(
        chain.invoke,
        data=DATASET_NAME,
        evaluators=[qa_correctness, context_relevance],
        experiment_prefix=f"retrieval-{chain_name}",
        metadata={
            "chain": chain_name,
            "eval_model": "gpt-4.1-mini",
            "rag_model": "gpt-4.1-nano",
        },
        max_concurrency=1,  # run serially to stay under API rate limits
    )
    all_experiment_results[chain_name] = results
# ---------------------------------------------------------------------------
# D. Summary Output
# ---------------------------------------------------------------------------
print("\n" + "=" * 60)
print("EVALUATION SUMMARY")
print("=" * 60)
for chain_name, results in all_experiment_results.items():
    # Aggregate scores per metric key across all evaluated examples.
    scores_by_key = {}
    for result in results:
        evaluation = result.get("evaluation_results", {})
        for eval_res in evaluation.get("results", []):
            if eval_res.score is None:
                continue
            scores_by_key.setdefault(eval_res.key, []).append(eval_res.score)
    print(f"\n--- {chain_name} ---")
    if not scores_by_key:
        print("  (no scores collected)")
    for metric in sorted(scores_by_key):
        metric_scores = scores_by_key[metric]
        avg = sum(metric_scores) / len(metric_scores) if metric_scores else 0.0
        print(f"  {metric}: {avg:.2f} ({len(metric_scores)} examples)")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment