#!/usr/bin/env python3
"""
Medical RAG Comparative Study - Full Implementation
Author: Raghunandan Kavi
Institution: Liverpool John Moores University
Dataset: MedQuAD (Medical Question Answering Dataset)
"""
import os
import sys
import json
import time
import zipfile
import requests
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import List, Dict, Tuple, Any
from dataclasses import dataclass
from collections import Counter, defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# Provide your OpenAI API key via the environment, or replace the placeholder below.
os.environ.setdefault('OPENAI_API_KEY', 'YOUR_OPENAI_API_KEY')
# Latest LangChain imports (modern package layout)
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents.stuff import create_stuff_documents_chain
# Evaluation frameworks
try:
from ragas import evaluate
from ragas.metrics import Faithfulness, AnswerRelevancy, ContextPrecision, ContextRecall
RAGAS_AVAILABLE = True
except ImportError:
RAGAS_AVAILABLE = False
print("Warning: RAGAS not installed. Running with custom metrics only.")
try:
from deepeval import evaluate as deepeval_evaluate
from deepeval.metrics import HallucinationMetric, AnswerRelevancyMetric
DEEPEVAL_AVAILABLE = True
except ImportError:
DEEPEVAL_AVAILABLE = False
print("Warning: DeepEval not installed.")
# Download required NLTK data
import nltk
try:
nltk.data.find('tokenizers/punkt')
except LookupError:
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
# Configuration
@dataclass
class Config:
DATA_DIR: str = "./medquad_data"
VECTORDB_DIR: str = "./chroma_db"
RESULTS_DIR: str = "./results"
CHUNK_SIZE: int = 400
CHUNK_OVERLAP: int = 50
EMBEDDING_MODEL: str = "text-embedding-3-large"
LLM_MODEL: str = "gpt-4o-mini"
TEMPERATURE: float = 0.0
TOP_K: int = 5
EVAL_SAMPLE_SIZE: int = 100 # Reduced for API cost control; increase for full study
config = Config()
# Create directories
for dir_path in [config.DATA_DIR, config.VECTORDB_DIR, config.RESULTS_DIR]:
os.makedirs(dir_path, exist_ok=True)
class MedQuADLoader:
"""Handles downloading and parsing of MedQuAD dataset."""
def __init__(self, data_dir: str):
self.data_dir = Path(data_dir)
self.csv_path = self.data_dir / "medquad.csv"
def download(self) -> None:
"""Download MedQuAD from GitHub repository."""
if self.csv_path.exists():
print(f"Dataset already exists at {self.csv_path}")
return
print("Downloading MedQuAD dataset...")
# Clone the repository
        repo_url = "https://github.com/abachaa/MedQuAD.git"
clone_dir = self.data_dir / "MedQuAD_repo"
if not clone_dir.exists():
import subprocess
subprocess.run(["git", "clone", repo_url, str(clone_dir)], check=True)
# Parse XML files and convert to CSV
data = []
        xml_dir = clone_dir / "MedQuAD-master"  # only present if the dataset came from a ZIP archive
        if not xml_dir.exists():
            xml_dir = clone_dir  # a git clone keeps the XML folders directly under the repo root
print(f"Parsing XML files from {xml_dir}...")
for xml_file in xml_dir.rglob("*.xml"):
try:
tree = ET.parse(xml_file)
root = tree.getroot()
# Extract QA pairs based on MedQuAD structure
focus = root.findtext(".//Focus", default="")
                for qa in root.findall(".//QAPair"):
                    q_el = qa.find("Question")
                    question = (q_el.text or "") if q_el is not None else ""
                    answer = qa.findtext("Answer", default="")
                    # In MedQuAD XML the qtype attribute sits on the Question element, not on QAPair
                    qtype = q_el.get("qtype", "general") if q_el is not None else "general"
if question and answer:
data.append({
"focus": focus,
"question": question.strip(),
"answer": answer.strip(),
"qtype": qtype,
"source_file": xml_file.name
})
except Exception as e:
print(f"Error parsing {xml_file}: {e}")
continue
# Save to CSV
df = pd.DataFrame(data)
df.to_csv(self.csv_path, index=False)
print(f"Saved {len(df)} QA pairs to {self.csv_path}")
# Cleanup
if clone_dir.exists():
import shutil
shutil.rmtree(clone_dir)
def load(self) -> pd.DataFrame:
"""Load and preprocess the dataset."""
if not self.csv_path.exists():
self.download()
df = pd.read_csv(self.csv_path)
# Clean data
df = df.dropna(subset=["question", "answer"])
df = df[df["question"].str.len() > 10]
df = df[df["answer"].str.len() > 20]
# Add metadata
df["answer_len"] = df["answer"].apply(lambda x: len(str(x).split()))
df["question_len"] = df["question"].apply(lambda x: len(str(x).split()))
return df
class DocumentProcessor:
"""Handles document chunking strategies."""
def __init__(self, chunk_size: int = config.CHUNK_SIZE,
chunk_overlap: int = config.CHUNK_OVERLAP):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
def prepare_documents(self, df: pd.DataFrame) -> List[Document]:
"""Convert DataFrame to LangChain Documents."""
docs = []
for _, row in df.iterrows():
content = (
f"FOCUS: {row['focus']}\n"
f"QTYPE: {row['qtype']}\n"
f"QUESTION: {row['question']}\n"
f"ANSWER: {row['answer']}"
)
metadata = {
"focus": row.get("focus", ""),
"qtype": row.get("qtype", ""),
"source": row.get("source_file", ""),
"question": row["question"],
"gold_answer": row["answer"]
}
docs.append(Document(page_content=content, metadata=metadata))
return docs
def chunk_recursive(self, docs: List[Document]) -> List[Document]:
"""Recursive character text splitting."""
splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
separators=["\n\n", "\n", ".", " ", ""],
length_function=len
)
return splitter.split_documents(docs)
def chunk_fixed(self, docs: List[Document]) -> List[Document]:
"""Fixed-size character splitting."""
splitter = CharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
separator=" "
)
return splitter.split_documents(docs)
def chunk_sentence(self, docs: List[Document]) -> List[Document]:
"""Sentence-aware splitting."""
chunks = []
for doc in docs:
sentences = nltk.sent_tokenize(doc.page_content)
current_chunk = []
current_length = 0
for sent in sentences:
sent_len = len(sent)
if current_length + sent_len > self.chunk_size and current_chunk:
chunk_text = " ".join(current_chunk)
chunks.append(Document(page_content=chunk_text, metadata=doc.metadata))
# Keep overlap
current_chunk = current_chunk[-2:] if len(current_chunk) > 2 else []
current_length = sum(len(s) for s in current_chunk)
current_chunk.append(sent)
current_length += sent_len
if current_chunk:
chunk_text = " ".join(current_chunk)
chunks.append(Document(page_content=chunk_text, metadata=doc.metadata))
return chunks
class RetrieverFactory:
"""Factory for creating different retrieval strategies."""
@staticmethod
def create_standard_retriever(vectorstore: Chroma, k: int = 5) -> BaseRetriever:
"""Standard dense retrieval with MMR."""
return vectorstore.as_retriever(
search_type="mmr",
search_kwargs={"k": k, "fetch_k": k * 4}
)
@staticmethod
def create_tfidf_retriever(chunks: List[Document], k: int = 5):
"""TF-IDF based keyword retriever."""
texts = [c.page_content for c in chunks]
vectorizer = TfidfVectorizer(stop_words='english', max_features=10000)
doc_matrix = vectorizer.fit_transform(texts)
class TFIDFRetriever(BaseRetriever):
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
query_vec = vectorizer.transform([query])
similarities = cosine_similarity(query_vec, doc_matrix)[0]
top_indices = np.argsort(similarities)[::-1][:k]
return [chunks[i] for i in top_indices]
async def _aget_relevant_documents(self, query: str, **kwargs) -> List[Document]:
return self._get_relevant_documents(query, **kwargs)
return TFIDFRetriever()
@staticmethod
def create_hybrid_retriever(vectorstore: Chroma, chunks: List[Document], k: int = 5):
"""Hybrid: Combine semantic and TF-IDF with Reciprocal Rank Fusion."""
semantic_retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
tfidf_retriever = RetrieverFactory.create_tfidf_retriever(chunks, k=k)
class HybridRetriever(BaseRetriever):
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
# Get results from both retrievers
semantic_docs = semantic_retriever.invoke(query)
tfidf_docs = tfidf_retriever.invoke(query)
# Reciprocal Rank Fusion
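                # RRF score for a document d across retrievers r: score(d) = sum_r 1 / (c + rank_r(d)),
                # with ranks starting at 0 here and the conventional constant c = 60. Documents that
                # rank well in either list rise to the top without any score normalisation.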
doc_scores = defaultdict(float)
doc_map = {}
for rank, doc in enumerate(semantic_docs):
key = doc.page_content[:200]
doc_scores[key] += 1.0 / (rank + 60) # k=60 constant
doc_map[key] = doc
for rank, doc in enumerate(tfidf_docs):
key = doc.page_content[:200]
doc_scores[key] += 1.0 / (rank + 60)
if key not in doc_map:
doc_map[key] = doc
# Sort by fused score
sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
return [doc_map[key] for key, _ in sorted_docs[:k]]
async def _aget_relevant_documents(self, query: str, **kwargs) -> List[Document]:
return self._get_relevant_documents(query, **kwargs)
return HybridRetriever()
@staticmethod
def create_multiquery_retriever(chunks: List[Document], k: int = 5):
"""Multi-Query Expansion using TF-IDF term expansion."""
texts = [c.page_content for c in chunks]
vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)
doc_matrix = vectorizer.fit_transform(texts)
class MultiQueryRetriever(BaseRetriever):
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
# Expand query using top TF-IDF terms
try:
query_vec = vectorizer.transform([query]).toarray()[0]
feature_names = vectorizer.get_feature_names_out()
top_indices = query_vec.argsort()[::-1][:3]
expansion_terms = [feature_names[i] for i in top_indices if query_vec[i] > 0]
expanded_queries = [query] + [f"{query} {term}" for term in expansion_terms]
except:
expanded_queries = [query]
# Retrieve for each query and merge
seen = set()
merged = []
for q in expanded_queries:
q_vec = vectorizer.transform([q])
sims = cosine_similarity(q_vec, doc_matrix)[0]
top_indices = np.argsort(sims)[::-1][:k]
for idx in top_indices:
doc = chunks[idx]
key = doc.page_content[:200]
if key not in seen:
seen.add(key)
merged.append(doc)
if len(merged) >= k:
break
return merged[:k]
async def _aget_relevant_documents(self, query: str, **kwargs) -> List[Document]:
return self._get_relevant_documents(query, **kwargs)
return MultiQueryRetriever()
@staticmethod
def create_reranking_retriever(vectorstore: Chroma, llm: ChatOpenAI, k: int = 5):
"""Query Reformulation + Broad Retrieval + Reranking."""
broad_retriever = vectorstore.as_retriever(search_kwargs={"k": 20})
class RerankingRetriever(BaseRetriever):
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
# Step 1: Reformulate query
reformulate_prompt = ChatPromptTemplate.from_messages([
("system", "Reformulate the following medical query to be more specific and clinical for information retrieval. Use medical terminology."),
("human", "Original: {query}\nReformulated:")
])
try:
chain = reformulate_prompt | llm
result = chain.invoke({"query": query})
reformulated = result.content.strip()
except:
reformulated = query
# Step 2: Broad retrieval
broad_docs = broad_retriever.invoke(reformulated)
                # Step 3: Lightweight reranking by query-term overlap. This stands in for true
                # cross-encoder scoring; in production, use a dedicated cross-encoder such as
                # BAAI/bge-reranker instead of this heuristic.
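                # A minimal cross-encoder sketch (commented out; assumes the optional
                # `sentence-transformers` package and a download of the reranker weights):
                #   from sentence_transformers import CrossEncoder
                #   ce = CrossEncoder("BAAI/bge-reranker-base")
                #   ce_scores = ce.predict([(reformulated, d.page_content) for d in broad_docs])
                #   order = np.argsort(ce_scores)[::-1][:k]
                #   return [broad_docs[i] for i in order]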
if len(broad_docs) <= k:
return broad_docs
# Simple heuristic reranking: prefer docs with query term overlap
query_terms = set(reformulated.lower().split())
scored_docs = []
for doc in broad_docs:
doc_terms = set(doc.page_content.lower().split())
score = len(query_terms & doc_terms) / max(len(query_terms), 1)
scored_docs.append((score, doc))
scored_docs.sort(key=lambda x: x[0], reverse=True)
return [doc for _, doc in scored_docs[:k]]
async def _aget_relevant_documents(self, query: str, **kwargs) -> List[Document]:
return self._get_relevant_documents(query, **kwargs)
return RerankingRetriever()
class MedicalRAGPipeline:
"""Complete RAG pipeline with evaluation."""
def __init__(self, df: pd.DataFrame):
self.df = df
self.processor = DocumentProcessor()
self.embeddings = OpenAIEmbeddings(model=config.EMBEDDING_MODEL)
self.llm = ChatOpenAI(
model=config.LLM_MODEL,
temperature=config.TEMPERATURE,
max_tokens=512
)
# Prepare documents
print("Preparing documents...")
self.raw_docs = self.processor.prepare_documents(df)
# System prompt for safety
self.system_prompt = """You are a medical information assistant. Use only the provided context to answer the question.
If the answer is not in the context, state clearly that you do not have sufficient information.
Do not provide specific medical advice, dosage recommendations, or diagnoses.
Always encourage consulting healthcare professionals for personal medical decisions.
Context: {context}"""
def create_vectorstore(self, chunks: List[Document], persist_dir: str) -> Chroma:
"""Create or load Chroma vector store."""
if os.path.exists(persist_dir) and os.listdir(persist_dir):
print(f"Loading existing vector store from {persist_dir}")
return Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
print(f"Creating new vector store at {persist_dir}")
vectorstore = Chroma.from_documents(
documents=chunks,
embedding=self.embeddings,
persist_directory=persist_dir
)
return vectorstore
def build_pipeline(self, strategy: str, chunk_method: str = "recursive") -> Tuple[Any, List[Document]]:
"""Build specific RAG pipeline."""
print(f"\nBuilding pipeline: {strategy} with {chunk_method} chunking")
# Chunking
if chunk_method == "recursive":
chunks = self.processor.chunk_recursive(self.raw_docs)
elif chunk_method == "fixed":
chunks = self.processor.chunk_fixed(self.raw_docs)
elif chunk_method == "sentence":
chunks = self.processor.chunk_sentence(self.raw_docs)
else:
chunks = self.processor.chunk_recursive(self.raw_docs)
persist_dir = f"{config.VECTORDB_DIR}/{chunk_method}"
vectorstore = self.create_vectorstore(chunks, persist_dir)
# Create retriever based on strategy
if strategy == "vanilla":
retriever = None
elif strategy == "standard":
retriever = RetrieverFactory.create_standard_retriever(vectorstore)
elif strategy == "hybrid":
retriever = RetrieverFactory.create_hybrid_retriever(vectorstore, chunks)
elif strategy == "multiquery":
retriever = RetrieverFactory.create_multiquery_retriever(chunks)
elif strategy == "reranking":
retriever = RetrieverFactory.create_reranking_retriever(vectorstore, self.llm)
else:
raise ValueError(f"Unknown strategy: {strategy}")
# Build chain
if strategy == "vanilla":
# No retrieval, direct generation
prompt = ChatPromptTemplate.from_messages([
("system", "You are a medical information assistant. Answer based on your training knowledge, but acknowledge uncertainty. Encourage consulting healthcare professionals."),
("human", "{input}")
])
chain = prompt | self.llm
else:
prompt = ChatPromptTemplate.from_messages([
("system", self.system_prompt),
("human", "{input}")
])
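            # create_stuff_documents_chain fills {context} in the prompt with the retrieved
            # chunks; create_retrieval_chain wires the retriever in front of it and returns a
            # dict containing "input", "context" (the retrieved Documents) and "answer",
            # which the evaluation loop below reads.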
doc_chain = create_stuff_documents_chain(self.llm, prompt)
chain = create_retrieval_chain(retriever, doc_chain)
return chain, chunks
def evaluate_semantic(self, ground_truth: str, prediction: str) -> float:
"""Calculate semantic similarity between answers."""
if not ground_truth or not prediction:
return 0.0
# Simple word overlap metric (proxy for semantic similarity)
gt_words = set(ground_truth.lower().split())
pred_words = set(prediction.lower().split())
if not gt_words:
return 0.0
intersection = len(gt_words & pred_words)
return intersection / len(gt_words)
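    # Optional alternative, not used in the evaluation loop below: a sketch of an
    # embedding-based similarity that reuses the pipeline's OpenAI embeddings. It gives a
    # truer "semantic" score than word overlap but adds one embedding call per comparison.
    def evaluate_semantic_embedding(self, ground_truth: str, prediction: str) -> float:
        """Cosine similarity between embeddings of the reference and generated answers."""
        if not ground_truth or not prediction:
            return 0.0
        ref_vec, pred_vec = self.embeddings.embed_documents([ground_truth, prediction])
        return float(cosine_similarity([ref_vec], [pred_vec])[0][0])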
def evaluate_faithfulness(self, answer: str, contexts: List[Document]) -> float:
"""Check if answer claims are supported by context."""
if not contexts:
return 0.0
context_text = " ".join([c.page_content for c in contexts]).lower()
answer_sentences = [s.strip() for s in answer.split('.') if len(s.strip()) > 10]
if not answer_sentences:
return 0.0
supported = 0
for sent in answer_sentences:
# Check if key terms from sentence appear in context
sent_words = set(sent.lower().split())
overlap = len(sent_words & set(context_text.split()))
if overlap / max(len(sent_words), 1) > 0.3: # Threshold
supported += 1
return supported / len(answer_sentences)
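    # If RAGAS is available (see RAGAS_AVAILABLE above), an LLM-judged faithfulness score
    # could replace this lexical-overlap heuristic. Rough sketch, assuming a recent ragas
    # release and a HuggingFace `datasets.Dataset` with question/answer/contexts columns:
    #   from datasets import Dataset
    #   ds = Dataset.from_dict({"question": [q], "answer": [a], "contexts": [[c1, c2]]})
    #   ragas_scores = evaluate(ds, metrics=[Faithfulness()])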
def run_evaluation(self, strategies: List[str], eval_samples: int = config.EVAL_SAMPLE_SIZE):
"""Run comparative evaluation across strategies."""
results = []
# Sample evaluation set
eval_df = self.df.sample(n=min(eval_samples, len(self.df)), random_state=42)
eval_set = [
{"question": row["question"], "answer": row["answer"], "qtype": row.get("qtype", "general")}
for _, row in eval_df.iterrows()
]
for strategy in strategies:
print(f"\n{'='*60}")
print(f"Evaluating Strategy: {strategy.upper()}")
print(f"{'='*60}")
try:
chain, chunks = self.build_pipeline(strategy)
strategy_results = {
"strategy": strategy,
"correct": 0,
"faithfulness_scores": [],
"semantic_scores": [],
"latencies": [],
"hallucination_flags": []
}
for i, sample in enumerate(eval_set):
question = sample["question"]
ground_truth = sample["answer"]
# Measure latency
start_time = time.time()
try:
if strategy == "vanilla":
response = chain.invoke({"input": question})
answer = response.content
contexts = []
else:
response = chain.invoke({"input": question})
answer = response["answer"]
contexts = response.get("context", [])
latency = (time.time() - start_time) * 1000 # ms
# Calculate metrics
semantic_sim = self.evaluate_semantic(ground_truth, answer)
faithfulness = self.evaluate_faithfulness(answer, contexts) if strategy != "vanilla" else 0.0
# Heuristic hallucination detection
is_hallucination = faithfulness < 0.5 if strategy != "vanilla" else semantic_sim < 0.3
strategy_results["semantic_scores"].append(semantic_sim)
strategy_results["faithfulness_scores"].append(faithfulness)
strategy_results["latencies"].append(latency)
strategy_results["hallucination_flags"].append(1 if is_hallucination else 0)
# Check correctness (simplified)
if semantic_sim > 0.5:
strategy_results["correct"] += 1
if (i + 1) % 20 == 0:
print(f" Processed {i+1}/{len(eval_set)} samples...")
except Exception as e:
print(f" Error on sample {i}: {e}")
continue
# Aggregate results
avg_faithfulness = np.mean(strategy_results["faithfulness_scores"]) if strategy_results["faithfulness_scores"] else 0
avg_semantic = np.mean(strategy_results["semantic_scores"]) if strategy_results["semantic_scores"] else 0
avg_latency = np.mean(strategy_results["latencies"]) if strategy_results["latencies"] else 0
hallucination_rate = np.mean(strategy_results["hallucination_flags"]) if strategy_results["hallucination_flags"] else 0
accuracy = strategy_results["correct"] / len(eval_set)
results.append({
"Strategy": strategy,
"Accuracy": accuracy,
"Faithfulness": avg_faithfulness,
"Semantic_Similarity": avg_semantic,
"Hallucination_Rate": hallucination_rate,
"Latency_ms": avg_latency
})
print(f"\nResults for {strategy}:")
print(f" Accuracy: {accuracy:.3f}")
print(f" Faithfulness: {avg_faithfulness:.3f}")
print(f" Hallucination Rate: {hallucination_rate:.3f}")
print(f" Avg Latency: {avg_latency:.0f}ms")
except Exception as e:
print(f"Failed to evaluate {strategy}: {e}")
continue
return pd.DataFrame(results)
def generate_visualizations(self, results_df: pd.DataFrame):
"""Generate comparison charts."""
if results_df.empty:
print("No results to visualize")
return
plt.figure(figsize=(15, 10))
# Plot 1: Performance Metrics Comparison
plt.subplot(2, 2, 1)
metrics = ["Accuracy", "Faithfulness", "Semantic_Similarity"]
x = np.arange(len(results_df))
width = 0.25
for i, metric in enumerate(metrics):
plt.bar(x + i*width, results_df[metric], width, label=metric)
plt.xlabel('Strategy')
plt.ylabel('Score')
plt.title('Performance Metrics Comparison')
plt.xticks(x + width, results_df["Strategy"], rotation=45)
plt.legend()
plt.ylim(0, 1)
# Plot 2: Hallucination Rate
plt.subplot(2, 2, 2)
colors = ['red' if x > 0.2 else 'orange' if x > 0.1 else 'green' for x in results_df["Hallucination_Rate"]]
plt.bar(results_df["Strategy"], results_df["Hallucination_Rate"], color=colors)
plt.xlabel('Strategy')
plt.ylabel('Hallucination Rate')
plt.title('Hallucination Rate by Strategy (Lower is Better)')
plt.xticks(rotation=45)
# Plot 3: Latency Comparison
plt.subplot(2, 2, 3)
plt.bar(results_df["Strategy"], results_df["Latency_ms"], color='skyblue')
plt.xlabel('Strategy')
plt.ylabel('Latency (ms)')
plt.title('Response Latency by Strategy')
plt.xticks(rotation=45)
# Plot 4: Safety vs Speed Trade-off
plt.subplot(2, 2, 4)
plt.scatter(results_df["Latency_ms"], 1 - results_df["Hallucination_Rate"],
s=200, alpha=0.6, c=range(len(results_df)), cmap='viridis')
for i, txt in enumerate(results_df["Strategy"]):
plt.annotate(txt, (results_df["Latency_ms"].iloc[i], 1 - results_df["Hallucination_Rate"].iloc[i]),
xytext=(5, 5), textcoords='offset points', fontsize=8)
plt.xlabel('Latency (ms)')
plt.ylabel('Safety Score (1 - Hallucination Rate)')
plt.title('Safety vs. Speed Trade-off')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(f"{config.RESULTS_DIR}/rag_comparison.png", dpi=300, bbox_inches='tight')
print(f"\nSaved visualization to {config.RESULTS_DIR}/rag_comparison.png")
plt.show()
# Save results table
results_df.to_csv(f"{config.RESULTS_DIR}/rag_results.csv", index=False)
print(f"Saved results to {config.RESULTS_DIR}/rag_results.csv")
def main():
"""Main execution function."""
print("Medical RAG Comparative Study")
print("=" * 60)
# Load data
loader = MedQuADLoader(config.DATA_DIR)
df = loader.load()
print(f"Loaded {len(df)} medical QA pairs")
# Initialize pipeline
pipeline = MedicalRAGPipeline(df)
# Define strategies to evaluate
strategies = [
"vanilla", # Baseline
"standard", # Basic RAG
"multiquery", # Multi-Query Expansion
"hybrid", # Hybrid Retrieval
"reranking" # Reranking + Reformulation
]
# Run evaluation
results_df = pipeline.run_evaluation(strategies, eval_samples=config.EVAL_SAMPLE_SIZE)
# Generate visualizations
pipeline.generate_visualizations(results_df)
# Print final summary
print("\n" + "="*60)
print("FINAL SUMMARY")
print("="*60)
print(results_df.to_string(index=False))
# Identify best strategy
if not results_df.empty:
best_safety = results_df.loc[results_df['Hallucination_Rate'].idxmin()]
best_speed = results_df.loc[results_df['Latency_ms'].idxmin()]
best_overall = results_df.loc[(results_df['Faithfulness'] * 0.5 +
(1 - results_df['Hallucination_Rate']) * 0.5).idxmax()]
print(f"\nBest Safety (Lowest Hallucination): {best_safety['Strategy']} ({best_safety['Hallucination_Rate']:.1%})")
print(f"Best Speed (Lowest Latency): {best_speed['Strategy']} ({best_speed['Latency_ms']:.0f}ms)")
print(f"Best Overall Balance: {best_overall['Strategy']}")
if __name__ == "__main__":
main()