# Advanced Retrieval with LangChain
# Standard Library Imports
import getpass
import os
from operator import itemgetter
from uuid import uuid4
# Third-Party Imports
from dotenv import load_dotenv
# LangChain Core
from langchain.retrievers import EnsembleRetriever, ParentDocumentRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.storage import InMemoryStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
# LangChain Community
from langchain_community.document_loaders import TextLoader
from langchain_community.retrievers import BM25Retriever
# LangChain Integrations
from langchain_cohere import CohereRerank
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_qdrant import QdrantVectorStore
# Qdrant
from qdrant_client import QdrantClient, models
load_dotenv()
if not os.environ.get("OPENAI_API_KEY"):
os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API Key:")
if not os.environ.get("COHERE_API_KEY"):
os.environ["COHERE_API_KEY"] = getpass.getpass("Cohere API Key:")
if not os.environ.get("LANGCHAIN_API_KEY"):
os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("LangChain API Key:")
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = f"AIM - Advanced Retrieval - {uuid4().hex[0:8]}"
chat_model = ChatOpenAI(model="gpt-4.1-nano")
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
RAG_TEMPLATE = """\
You are a helpful and kind assistant. Use the context provided below to answer the question.
If you do not know the answer, or are unsure, say you don't know.
Query:
{question}
Context:
{context}
"""
rag_prompt = ChatPromptTemplate.from_template(RAG_TEMPLATE)
loader = TextLoader("data/HealthWellnessGuide.txt")
raw_docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
wellness_docs = text_splitter.split_documents(raw_docs)
print(f"Raw documents: {len(raw_docs)}")
print(f"Split chunks: {len(wellness_docs)}")
print(f"\nExample chunk:\n{wellness_docs[0]}")
vectorstore = QdrantVectorStore.from_documents(
    wellness_docs,
    embeddings,
    location=":memory:",
    collection_name="wellness_guide",
)
naive_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
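# Quick sanity check: retrievers implement the Runnable interface, so
# .invoke(query) returns a list of Documents. (Illustrative; the exact
# results depend on your source document.)
sample_docs = naive_retriever.invoke("What helps with lower back pain?")
print(f"Naive retriever returned {len(sample_docs)} chunks")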
semantic_chunker = SemanticChunker(embeddings, breakpoint_threshold_type="percentile")
semantic_documents = semantic_chunker.split_documents(raw_docs)
semantic_vectorstore = QdrantVectorStore.from_documents(
    semantic_documents,
    embeddings,
    location=":memory:",
    collection_name="wellness_guide_semantic_chunks",
)
semantic_retriever = semantic_vectorstore.as_retriever(search_kwargs={"k": 10})
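# SemanticChunker splits wherever the embedding distance between adjacent
# sentences spikes, so chunk boundaries follow topic shifts rather than a
# fixed character count. "percentile" is the default breakpoint strategy;
# "standard_deviation" and "interquartile" are alternatives worth trying if
# the chunks come out too large or too small for your corpus.
print(f"Semantic chunks: {len(semantic_documents)}")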
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
client = QdrantClient(location=":memory:")
client.create_collection(
    collection_name="wellness_parent_child",
    vectors_config=models.VectorParams(size=1536, distance=models.Distance.COSINE),
)
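# (size=1536 matches the output dimension of text-embedding-3-small; a
# different embedding model would need a matching vector size here.)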
parent_document_vectorstore = QdrantVectorStore(
    collection_name="wellness_parent_child",
    embedding=embeddings,
    client=client,
)
store = InMemoryStore()
parent_document_retriever = ParentDocumentRetriever(
    vectorstore=parent_document_vectorstore,
    docstore=store,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)
parent_document_retriever.add_documents(raw_docs, ids=None)
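# Illustrative contrast between what gets searched and what gets returned:
# the vectorstore indexes the small child chunks, while the retriever maps
# matches back to their larger parent chunks via the docstore.
child_hits = parent_document_vectorstore.similarity_search("lower back pain", k=1)
parent_hits = parent_document_retriever.invoke("lower back pain")
print(f"Child chunk length: {len(child_hits[0].page_content)}")
print(f"Parent chunk length: {len(parent_hits[0].page_content)}")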
bm25_retriever = BM25Retriever.from_documents(wellness_docs)
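# BM25 is purely lexical (no embeddings), so it complements the dense
# retrievers above on exact-keyword queries. Note it returns k=4 documents
# by default; set `bm25_retriever.k = 10` if you want parity with the
# dense retrievers.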
compressor = CohereRerank(model="rerank-v3.5")
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=naive_retriever
)
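# The reranker does not fetch documents itself: the naive retriever supplies
# 10 candidates and Cohere's rerank model re-scores them against the query,
# keeping only the top few (CohereRerank's top_n defaults to 3 at the time
# of writing).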
multi_query_retriever = MultiQueryRetriever.from_llm(
    retriever=naive_retriever, llm=chat_model
)
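# To see the alternative phrasings the LLM generates for each query, enable
# the module logger (a standard debugging trick for MultiQueryRetriever):
import logging

logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)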
retriever_list = [
    bm25_retriever,
    naive_retriever,
    parent_document_retriever,
    compression_retriever,
    multi_query_retriever,
]
equal_weighting = [1 / len(retriever_list)] * len(retriever_list)
ensemble_retriever = EnsembleRetriever(
    retrievers=retriever_list, weights=equal_weighting
)
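# EnsembleRetriever fuses the ranked lists with Reciprocal Rank Fusion, so
# documents that rank highly across several retrievers float to the top.
# Equal weights are a sensible default; skewing them (e.g., toward BM25 for
# keyword-heavy corpora) is a cheap tuning lever.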
naive_retrieval_chain = (
    {
        "context": itemgetter("question") | naive_retriever,
        "question": itemgetter("question"),
    }
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt | chat_model, "context": itemgetter("context")}
)
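# The pattern above fans the input out to (context, question), then produces
# a dict holding both the model's "response" and the retrieved "context" so
# we can inspect which chunks grounded each answer. The six chains that
# follow are identical except for the retriever.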
bm25_retrieval_chain = (
    {
        "context": itemgetter("question") | bm25_retriever,
        "question": itemgetter("question"),
    }
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt | chat_model, "context": itemgetter("context")}
)
contextual_compression_retrieval_chain = (
    {
        "context": itemgetter("question") | compression_retriever,
        "question": itemgetter("question"),
    }
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt | chat_model, "context": itemgetter("context")}
)
multi_query_retrieval_chain = (
    {
        "context": itemgetter("question") | multi_query_retriever,
        "question": itemgetter("question"),
    }
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt | chat_model, "context": itemgetter("context")}
)
parent_document_retrieval_chain = (
    {
        "context": itemgetter("question") | parent_document_retriever,
        "question": itemgetter("question"),
    }
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt | chat_model, "context": itemgetter("context")}
)
ensemble_retrieval_chain = (
    {
        "context": itemgetter("question") | ensemble_retriever,
        "question": itemgetter("question"),
    }
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt | chat_model, "context": itemgetter("context")}
)
semantic_retrieval_chain = (
    {
        "context": itemgetter("question") | semantic_retriever,
        "question": itemgetter("question"),
    }
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt | chat_model, "context": itemgetter("context")}
)
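# Notebook-style spot checks: each bare `.invoke(...)["response"].content`
# expression below displays its value when run as a cell; wrap them in
# print() if you are running this file as a plain script.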
naive_retrieval_chain.invoke(
    {"question": "What exercises can help with lower back pain?"}
)["response"].content
naive_retrieval_chain.invoke({"question": "How does sleep affect overall health?"})[
    "response"
].content
naive_retrieval_chain.invoke(
    {"question": "What are some natural remedies for stress and headaches?"}
)["response"].content
bm25_retrieval_chain.invoke(
    {"question": "What exercises can help with lower back pain?"}
)["response"].content
bm25_retrieval_chain.invoke({"question": "How does sleep affect overall health?"})[
    "response"
].content
bm25_retrieval_chain.invoke(
    {"question": "What are some natural remedies for stress and headaches?"}
)["response"].content
contextual_compression_retrieval_chain.invoke(
    {"question": "What exercises can help with lower back pain?"}
)["response"].content
contextual_compression_retrieval_chain.invoke(
    {"question": "How does sleep affect overall health?"}
)["response"].content
contextual_compression_retrieval_chain.invoke(
    {"question": "What are some natural remedies for stress and headaches?"}
)["response"].content
multi_query_retrieval_chain.invoke(
    {"question": "What exercises can help with lower back pain?"}
)["response"].content
multi_query_retrieval_chain.invoke(
    {"question": "How does sleep affect overall health?"}
)["response"].content
multi_query_retrieval_chain.invoke(
    {"question": "What are some natural remedies for stress and headaches?"}
)["response"].content
parent_document_retrieval_chain.invoke(
    {"question": "What exercises can help with lower back pain?"}
)["response"].content
parent_document_retrieval_chain.invoke(
    {"question": "How does sleep affect overall health?"}
)["response"].content
parent_document_retrieval_chain.invoke(
    {"question": "What are some natural remedies for stress and headaches?"}
)["response"].content
ensemble_retrieval_chain.invoke(
    {"question": "What exercises can help with lower back pain?"}
)["response"].content
ensemble_retrieval_chain.invoke({"question": "How does sleep affect overall health?"})[
    "response"
].content
ensemble_retrieval_chain.invoke(
    {"question": "What are some natural remedies for stress and headaches?"}
)["response"].content
semantic_retrieval_chain.invoke(
    {"question": "What exercises can help with lower back pain?"}
)["response"].content
semantic_retrieval_chain.invoke({"question": "How does sleep affect overall health?"})[
    "response"
].content
semantic_retrieval_chain.invoke(
    {"question": "What are some natural remedies for stress and headaches?"}
)["response"].content
# --- 1) Chain registry (use your existing chain objects) ---
CHAINS = {
    "naive": naive_retrieval_chain,
    "bm25": bm25_retrieval_chain,
    "compression": contextual_compression_retrieval_chain,
    "multi_query": multi_query_retrieval_chain,
    "parent_doc": parent_document_retrieval_chain,
    "ensemble": ensemble_retrieval_chain,
    "semantic": semantic_retrieval_chain,
}
# --- 2) Minimal helpers to normalize outputs ---
def _to_text(resp_dict):
    """Your chains return {'response': <AIMessage|str>, 'context': [...] }."""
    r = resp_dict.get("response")
    if hasattr(r, "content"):  # AIMessage
        return r.content
    return str(r) if r is not None else ""


def _to_context(resp_dict):
    return resp_dict.get("context", [])
# --- 3) Run a single question across selected chains ---
def run_all(question: str, chains=CHAINS):
    results = {}
    for name, ch in chains.items():
        out = ch.invoke({"question": question})
        results[name] = {
            "answer": _to_text(out),
            "contexts": _to_context(out),
        }
    return results
# --- 4) Convenience: quick pretty print for ad-hoc inspection ---
def print_quick(results, max_len=200):
    for name, rec in results.items():
        ans = rec["answer"].strip().replace("\n", " ")
        print(f"[{name}] {ans[:max_len]}{'…' if len(ans) > max_len else ''}")
# single question across all chains
res = run_all("What exercises can help with lower back pain?")
print_quick(res)
def run_batch(questions, chains=CHAINS):
    """Returns: dict[chain_name] -> list of {question, answer, contexts}."""
    payloads = [{"question": q} for q in questions]
    all_results = {}
    for name, ch in chains.items():
        outs = ch.batch(payloads)
        all_results[name] = [
            {
                "question": q["question"],
                "answer": _to_text(o),
                "contexts": _to_context(o),
            }
            for q, o in zip(payloads, outs)
        ]
    return all_results
def print_results(all_results, max_answer=150, max_context=100, max_ctxs=2):
    """Nicely print abridged chain results for inspection."""
    for name, records in all_results.items():
        print(f"\n=== {name.upper()} ===")
        for rec in records:
            print(f"Q: {rec['question']}")
            ans = rec["answer"].strip().replace("\n", " ")
            print(f"A: {ans[:max_answer]}{'…' if len(ans) > max_answer else ''}")
            ctxs = rec["contexts"][:max_ctxs]
            for i, c in enumerate(ctxs, 1):
                snippet = c.page_content.strip().replace("\n", " ")
                print(
                    f" [ctx{i}] {snippet[:max_context]}{'…' if len(snippet) > max_context else ''}"
                )
            print()  # blank line between questions
QUESTIONS = [
    "What exercises can help with lower back pain?",
    "How does sleep affect overall health?",
]
batched_results = run_batch(QUESTIONS)
print_results(batched_results)
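# Optional: a rough per-chain latency check (illustrative only; wall-clock
# numbers will vary with API load, caching, and each retriever's fan-out).
import time

for name, ch in CHAINS.items():
    t0 = time.perf_counter()
    ch.invoke({"question": QUESTIONS[0]})
    print(f"{name}: {time.perf_counter() - t0:.2f}s")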
# ============================================================
# PHASE 2: LangSmith Evaluation Harness
# ============================================================
#
# Three-step process:
# A. Create (or reuse) a benchmark dataset in LangSmith
# B. Define evaluator functions (QA correctness + context relevance)
# C. Run the harness across all chains and print a summary
# ============================================================
from langsmith import Client # noqa: E402
from langsmith.evaluation import evaluate # noqa: E402
from langchain.evaluation.qa import ContextQAEvalChain, QAEvalChain # noqa: E402
# ---------------------------------------------------------------------------
# A. Dataset Creation
# ---------------------------------------------------------------------------
ls_client = Client()
DATASET_NAME = "Health Wellness Benchmark v1"
BENCHMARK_EXAMPLES = [
    {
        "inputs": {"question": "What exercises can help with lower back pain?"},
        "outputs": {
            "answer": (
                "Exercises that can help with lower back pain include the Cat-Cow "
                "Stretch for spinal flexibility, Bird Dog for core stability, and "
                "Pelvic Tilts to strengthen the lower back. Child's Pose is also "
                "recommended for gentle stretching, and Bridges help strengthen "
                "the glutes and lower back."
            )
        },
    },
    {
        "inputs": {"question": "How does sleep affect overall health?"},
        "outputs": {
            "answer": (
                "Sleep is essential for overall health. It allows the body to "
                "repair tissues, consolidate memories, and regulate hormones. "
                "Poor sleep is linked to increased risk of heart disease, obesity, "
                "weakened immunity, and impaired cognitive function. Adults should "
                "aim for 7-9 hours of quality sleep per night."
            )
        },
    },
    {
        "inputs": {
            "question": "What are some natural remedies for stress and headaches?"
        },
        "outputs": {
            "answer": (
                "Natural remedies for stress include deep breathing exercises, "
                "meditation, regular physical activity, and herbal teas such as "
                "chamomile. For headaches, staying hydrated, applying peppermint "
                "oil to the temples, practicing relaxation techniques, and "
                "ensuring adequate sleep can provide relief."
            )
        },
    },
]
if ls_client.has_dataset(dataset_name=DATASET_NAME):
    print(f"Dataset '{DATASET_NAME}' already exists, skipping creation.")
else:
    dataset = ls_client.create_dataset(
        dataset_name=DATASET_NAME,
        description="Benchmark questions for the Health & Wellness RAG chains.",
    )
    ls_client.create_examples(
        inputs=[ex["inputs"] for ex in BENCHMARK_EXAMPLES],
        outputs=[ex["outputs"] for ex in BENCHMARK_EXAMPLES],
        dataset_id=dataset.id,
    )
    print(f"Created dataset '{DATASET_NAME}' with {len(BENCHMARK_EXAMPLES)} examples.")
# ---------------------------------------------------------------------------
# B. Evaluators
# ---------------------------------------------------------------------------
eval_llm = ChatOpenAI(model="gpt-4.1-mini", temperature=0)
def _extract_text(value):
    """Defensively extract a plain string from an AIMessage, dict, or str."""
    if value is None:
        return ""
    if hasattr(value, "content"):  # AIMessage
        return value.content
    if isinstance(value, dict) and "content" in value:
        return value["content"]
    return str(value)
def _extract_context(value):
    """Join retrieved documents into a single context string."""
    if not value:
        return ""
    pieces = []
    for doc in value:
        if hasattr(doc, "page_content"):
            pieces.append(doc.page_content)
        elif isinstance(doc, dict) and "page_content" in doc:
            pieces.append(doc["page_content"])
        elif isinstance(doc, str):
            pieces.append(doc)
    return "\n\n".join(pieces)
_qa_chain = QAEvalChain.from_llm(eval_llm)
_ctx_chain = ContextQAEvalChain.from_llm(eval_llm)
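# QAEvalChain grades a prediction against the reference answer, while
# ContextQAEvalChain grades it with the retrieved context in view; both
# return a CORRECT/INCORRECT verdict that the evaluators below map to
# 1.0/0.0 scores for LangSmith.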
def qa_correctness(outputs: dict, reference_outputs: dict, inputs: dict) -> dict:
    """Compare generated answer against reference answer for correctness."""
    prediction = _extract_text(outputs.get("response"))
    reference = _extract_text(reference_outputs.get("answer", ""))
    question = inputs.get("question", "")
    eval_result = _qa_chain.evaluate_strings(
        prediction=prediction,
        reference=reference,
        input=question,
    )
    score = 1.0 if eval_result.get("value", "").upper() == "CORRECT" else 0.0
    return {"key": "qa_correctness", "score": score}
def context_relevance(outputs: dict, reference_outputs: dict, inputs: dict) -> dict:
    """Check whether retrieved context supports the generated answer."""
    prediction = _extract_text(outputs.get("response"))
    reference = _extract_text(reference_outputs.get("answer", ""))
    context_str = _extract_context(outputs.get("context"))
    question = inputs.get("question", "")
    eval_result = _ctx_chain.evaluate_strings(
        prediction=prediction,
        reference=reference,
        input=question,
        context=context_str,
    )
    score = 1.0 if eval_result.get("value", "").upper() == "CORRECT" else 0.0
    return {"key": "context_relevance", "score": score}
# ---------------------------------------------------------------------------
# C. Execution Harness
# ---------------------------------------------------------------------------
print("\n" + "=" * 60)
print("Running LangSmith evaluation harness across all chains...")
print("=" * 60)
all_experiment_results = {}
for chain_name, chain in CHAINS.items():
    print(f"\nEvaluating chain: {chain_name}")
    results = evaluate(
        chain.invoke,
        data=DATASET_NAME,
        evaluators=[qa_correctness, context_relevance],
        experiment_prefix=f"retrieval-{chain_name}",
        metadata={
            "chain": chain_name,
            "eval_model": "gpt-4.1-mini",
            "rag_model": "gpt-4.1-nano",
        },
        max_concurrency=1,
    )
    all_experiment_results[chain_name] = results
# ---------------------------------------------------------------------------
# D. Summary Output
# ---------------------------------------------------------------------------
print("\n" + "=" * 60)
print("EVALUATION SUMMARY")
print("=" * 60)
for chain_name, results in all_experiment_results.items():
    scores_by_key = {}
    for result in results:
        for eval_res in result.get("evaluation_results", {}).get("results", []):
            key = eval_res.key
            if eval_res.score is not None:
                scores_by_key.setdefault(key, []).append(eval_res.score)
    print(f"\n--- {chain_name} ---")
    if not scores_by_key:
        print(" (no scores collected)")
    for metric, scores in sorted(scores_by_key.items()):
        avg = sum(scores) / len(scores) if scores else 0.0
        print(f" {metric}: {avg:.2f} ({len(scores)} examples)")