#!/usr/bin/env python3
"""
Medical RAG Comparative Study - Full Implementation
Author: Raghunandan Kavi
Institution: Liverpool John Moores University
Dataset: MedQuAD (Medical Question Answering Dataset)
"""
import os
import sys
import json
import time
import zipfile
import requests
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import List, Dict, Tuple, Any
from dataclasses import dataclass
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# Provide the OpenAI API key via the environment (e.g. `export OPENAI_API_KEY=...`).
# The placeholder below is only a fallback and will not authenticate real API calls.
os.environ.setdefault('OPENAI_API_KEY', 'OPENAI_API_KEY')
# Latest LangChain imports (modern package layout)
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
# Evaluation frameworks
try:
    from ragas import evaluate
    from ragas.metrics import Faithfulness, AnswerRelevancy, ContextPrecision, ContextRecall
    RAGAS_AVAILABLE = True
except ImportError:
    RAGAS_AVAILABLE = False
    print("Warning: RAGAS not installed. Running with custom metrics only.")
try:
    from deepeval import evaluate as deepeval_evaluate
    from deepeval.metrics import HallucinationMetric, AnswerRelevancyMetric
    DEEPEVAL_AVAILABLE = True
except ImportError:
    DEEPEVAL_AVAILABLE = False
    print("Warning: DeepEval not installed.")
# Download required NLTK data
import nltk
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)
# Configuration
@dataclass
class Config:
    DATA_DIR: str = "./medquad_data"
    VECTORDB_DIR: str = "./chroma_db"
    RESULTS_DIR: str = "./results"
    CHUNK_SIZE: int = 400
    CHUNK_OVERLAP: int = 50
    EMBEDDING_MODEL: str = "text-embedding-3-large"
    LLM_MODEL: str = "gpt-4o-mini"
    TEMPERATURE: float = 0.0
    TOP_K: int = 5
    EVAL_SAMPLE_SIZE: int = 100  # Reduced for API cost control; increase for full study

config = Config()
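
# Notes on the defaults above: CHUNK_SIZE and CHUNK_OVERLAP are character counts
# (not tokens), TEMPERATURE = 0.0 keeps generation deterministic across runs, and
# TOP_K records the intended number of retrieved chunks per question (the retriever
# factories below rely on their own default k=5).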

# Create directories
for dir_path in [config.DATA_DIR, config.VECTORDB_DIR, config.RESULTS_DIR]:
    os.makedirs(dir_path, exist_ok=True)

class MedQuADLoader:
    """Handles downloading and parsing of MedQuAD dataset."""

    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        self.csv_path = self.data_dir / "medquad.csv"

    def download(self) -> None:
        """Download MedQuAD from GitHub repository."""
        if self.csv_path.exists():
            print(f"Dataset already exists at {self.csv_path}")
            return
        print("Downloading MedQuAD dataset...")
        # Clone the repository
        repo_url = "https://github.com/abachaa/MedQuAD.git"
        clone_dir = self.data_dir / "MedQuAD_repo"
        if not clone_dir.exists():
            import subprocess
            subprocess.run(["git", "clone", repo_url, str(clone_dir)], check=True)
        # Parse XML files and convert to CSV
        data = []
        xml_dir = clone_dir / "MedQuAD-master"  # Only present in zip archives; a git clone has no such subfolder
        if not xml_dir.exists():
            xml_dir = clone_dir  # Fallback
        print(f"Parsing XML files from {xml_dir}...")
        for xml_file in xml_dir.rglob("*.xml"):
            try:
                tree = ET.parse(xml_file)
                root = tree.getroot()
                # Extract QA pairs based on MedQuAD structure
                focus = root.findtext(".//Focus", default="")
                for qa in root.findall(".//QAPair"):
                    question = qa.findtext("Question", default="")
                    answer = qa.findtext("Answer", default="")
                    qtype = qa.get("qtype", "general")
                    if question and answer:
                        data.append({
                            "focus": focus,
                            "question": question.strip(),
                            "answer": answer.strip(),
                            "qtype": qtype,
                            "source_file": xml_file.name
                        })
            except Exception as e:
                print(f"Error parsing {xml_file}: {e}")
                continue
        # Save to CSV
        df = pd.DataFrame(data)
        df.to_csv(self.csv_path, index=False)
        print(f"Saved {len(df)} QA pairs to {self.csv_path}")
        # Cleanup
        if clone_dir.exists():
            import shutil
            shutil.rmtree(clone_dir)
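
    # For reference, the parser above assumes QA files shaped roughly like this
    # (sketch only; the exact layout differs slightly across MedQuAD collections):
    #
    #   <Document>
    #     <Focus>Glaucoma</Focus>
    #     <QAPairs>
    #       <QAPair qtype="treatment">
    #         <Question>What are the treatments for Glaucoma?</Question>
    #         <Answer>...</Answer>
    #       </QAPair>
    #     </QAPairs>
    #   </Document>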

    def load(self) -> pd.DataFrame:
        """Load and preprocess the dataset."""
        if not self.csv_path.exists():
            self.download()
        df = pd.read_csv(self.csv_path)
        # Clean data
        df = df.dropna(subset=["question", "answer"])
        df = df[df["question"].str.len() > 10]
        df = df[df["answer"].str.len() > 20]
        # Add metadata
        df["answer_len"] = df["answer"].apply(lambda x: len(str(x).split()))
        df["question_len"] = df["question"].apply(lambda x: len(str(x).split()))
        return df

class DocumentProcessor:
    """Handles document chunking strategies."""

    def __init__(self, chunk_size: int = config.CHUNK_SIZE,
                 chunk_overlap: int = config.CHUNK_OVERLAP):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def prepare_documents(self, df: pd.DataFrame) -> List[Document]:
        """Convert DataFrame to LangChain Documents."""
        docs = []
        for _, row in df.iterrows():
            content = (
                f"FOCUS: {row['focus']}\n"
                f"QTYPE: {row['qtype']}\n"
                f"QUESTION: {row['question']}\n"
                f"ANSWER: {row['answer']}"
            )
            metadata = {
                "focus": row.get("focus", ""),
                "qtype": row.get("qtype", ""),
                "source": row.get("source_file", ""),
                "question": row["question"],
                "gold_answer": row["answer"]
            }
            docs.append(Document(page_content=content, metadata=metadata))
        return docs

    def chunk_recursive(self, docs: List[Document]) -> List[Document]:
        """Recursive character text splitting."""
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            separators=["\n\n", "\n", ".", " ", ""],
            length_function=len
        )
        return splitter.split_documents(docs)

    def chunk_fixed(self, docs: List[Document]) -> List[Document]:
        """Fixed-size character splitting."""
        splitter = CharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            separator=" "
        )
        return splitter.split_documents(docs)

    def chunk_sentence(self, docs: List[Document]) -> List[Document]:
        """Sentence-aware splitting."""
        chunks = []
        for doc in docs:
            sentences = nltk.sent_tokenize(doc.page_content)
            current_chunk = []
            current_length = 0
            for sent in sentences:
                sent_len = len(sent)
                if current_length + sent_len > self.chunk_size and current_chunk:
                    chunk_text = " ".join(current_chunk)
                    chunks.append(Document(page_content=chunk_text, metadata=doc.metadata))
                    # Keep overlap
                    current_chunk = current_chunk[-2:] if len(current_chunk) > 2 else []
                    current_length = sum(len(s) for s in current_chunk)
                current_chunk.append(sent)
                current_length += sent_len
            if current_chunk:
                chunk_text = " ".join(current_chunk)
                chunks.append(Document(page_content=chunk_text, metadata=doc.metadata))
        return chunks
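
# Illustrative comparison of the three chunkers (not executed by the study itself;
# "df" stands for the loaded MedQuAD DataFrame):
#
#   proc = DocumentProcessor(chunk_size=400, chunk_overlap=50)
#   docs = proc.prepare_documents(df)
#   for name, fn in [("recursive", proc.chunk_recursive),
#                    ("fixed", proc.chunk_fixed),
#                    ("sentence", proc.chunk_sentence)]:
#       print(name, len(fn(docs)))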

class RetrieverFactory:
    """Factory for creating different retrieval strategies."""

    @staticmethod
    def create_standard_retriever(vectorstore: Chroma, k: int = 5) -> BaseRetriever:
        """Standard dense retrieval with MMR."""
        return vectorstore.as_retriever(
            search_type="mmr",
            search_kwargs={"k": k, "fetch_k": k * 4}
        )
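
    # Note: with search_type="mmr", Chroma first fetches fetch_k candidates by
    # embedding similarity, then greedily keeps the k that balance relevance to
    # the query against redundancy with already-selected chunks.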

    @staticmethod
    def create_tfidf_retriever(chunks: List[Document], k: int = 5):
        """TF-IDF based keyword retriever."""
        texts = [c.page_content for c in chunks]
        vectorizer = TfidfVectorizer(stop_words='english', max_features=10000)
        doc_matrix = vectorizer.fit_transform(texts)

        class TFIDFRetriever(BaseRetriever):
            def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
                query_vec = vectorizer.transform([query])
                similarities = cosine_similarity(query_vec, doc_matrix)[0]
                top_indices = np.argsort(similarities)[::-1][:k]
                return [chunks[i] for i in top_indices]

            async def _aget_relevant_documents(self, query: str, **kwargs) -> List[Document]:
                return self._get_relevant_documents(query, **kwargs)

        return TFIDFRetriever()

    @staticmethod
    def create_hybrid_retriever(vectorstore: Chroma, chunks: List[Document], k: int = 5):
        """Hybrid: Combine semantic and TF-IDF with Reciprocal Rank Fusion."""
        semantic_retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
        tfidf_retriever = RetrieverFactory.create_tfidf_retriever(chunks, k=k)

        class HybridRetriever(BaseRetriever):
            def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
                # Get results from both retrievers
                semantic_docs = semantic_retriever.invoke(query)
                tfidf_docs = tfidf_retriever.invoke(query)
                # Reciprocal Rank Fusion
                doc_scores = defaultdict(float)
                doc_map = {}
                for rank, doc in enumerate(semantic_docs):
                    key = doc.page_content[:200]
                    doc_scores[key] += 1.0 / (rank + 60)  # k=60 constant
                    doc_map[key] = doc
                for rank, doc in enumerate(tfidf_docs):
                    key = doc.page_content[:200]
                    doc_scores[key] += 1.0 / (rank + 60)
                    if key not in doc_map:
                        doc_map[key] = doc
                # Sort by fused score
                sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
                return [doc_map[key] for key, _ in sorted_docs[:k]]

            async def _aget_relevant_documents(self, query: str, **kwargs) -> List[Document]:
                return self._get_relevant_documents(query, **kwargs)

        return HybridRetriever()
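
    # Reciprocal Rank Fusion, as applied above: each retriever contributes
    #     score(d) += 1 / (rank(d) + 60)
    # with 0-based ranks and the conventional k = 60 constant, so a chunk ranked
    # highly by either retriever (or moderately by both) rises to the top without
    # the two retrievers' raw scores needing to be on comparable scales.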

    @staticmethod
    def create_multiquery_retriever(chunks: List[Document], k: int = 5):
        """Multi-Query Expansion using TF-IDF term expansion."""
        texts = [c.page_content for c in chunks]
        vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)
        doc_matrix = vectorizer.fit_transform(texts)

        class MultiQueryRetriever(BaseRetriever):
            def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
                # Expand query using top TF-IDF terms
                try:
                    query_vec = vectorizer.transform([query]).toarray()[0]
                    feature_names = vectorizer.get_feature_names_out()
                    top_indices = query_vec.argsort()[::-1][:3]
                    expansion_terms = [feature_names[i] for i in top_indices if query_vec[i] > 0]
                    expanded_queries = [query] + [f"{query} {term}" for term in expansion_terms]
                except Exception:
                    expanded_queries = [query]
                # Retrieve for each query and merge
                seen = set()
                merged = []
                for q in expanded_queries:
                    q_vec = vectorizer.transform([q])
                    sims = cosine_similarity(q_vec, doc_matrix)[0]
                    top_indices = np.argsort(sims)[::-1][:k]
                    for idx in top_indices:
                        doc = chunks[idx]
                        key = doc.page_content[:200]
                        if key not in seen:
                            seen.add(key)
                            merged.append(doc)
                    if len(merged) >= k:
                        break
                return merged[:k]

            async def _aget_relevant_documents(self, query: str, **kwargs) -> List[Document]:
                return self._get_relevant_documents(query, **kwargs)

        return MultiQueryRetriever()
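
    # The "multiquery" variant above is a lightweight, LLM-free approximation: it
    # expands the query with its own top TF-IDF terms, retrieves once per expanded
    # query, and merges results, de-duplicating on each chunk's first 200 characters.
    # (LangChain also ships an LLM-backed MultiQueryRetriever, which is not used here.)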

    @staticmethod
    def create_reranking_retriever(vectorstore: Chroma, llm: ChatOpenAI, k: int = 5):
        """Query Reformulation + Broad Retrieval + Reranking."""
        broad_retriever = vectorstore.as_retriever(search_kwargs={"k": 20})

        class RerankingRetriever(BaseRetriever):
            def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
                # Step 1: Reformulate query
                reformulate_prompt = ChatPromptTemplate.from_messages([
                    ("system", "Reformulate the following medical query to be more specific and clinical for information retrieval. Use medical terminology."),
                    ("human", "Original: {query}\nReformulated:")
                ])
                try:
                    chain = reformulate_prompt | llm
                    result = chain.invoke({"query": query})
                    reformulated = result.content.strip()
                except Exception:
                    reformulated = query
                # Step 2: Broad retrieval
                broad_docs = broad_retriever.invoke(reformulated)
                # Step 3: Lightweight lexical reranking (term overlap) as a stand-in for a
                # cross-encoder. In production, use a dedicated model like BAAI/bge-reranker.
                if len(broad_docs) <= k:
                    return broad_docs
                # Simple heuristic reranking: prefer docs with query term overlap
                query_terms = set(reformulated.lower().split())
                scored_docs = []
                for doc in broad_docs:
                    doc_terms = set(doc.page_content.lower().split())
                    score = len(query_terms & doc_terms) / max(len(query_terms), 1)
                    scored_docs.append((score, doc))
                scored_docs.sort(key=lambda x: x[0], reverse=True)
                return [doc for _, doc in scored_docs[:k]]

            async def _aget_relevant_documents(self, query: str, **kwargs) -> List[Document]:
                return self._get_relevant_documents(query, **kwargs)

        return RerankingRetriever()
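
    # A minimal sketch (an assumption, not part of this study) of replacing the
    # term-overlap heuristic above with a real cross-encoder reranker, e.g.
    # BAAI/bge-reranker-base via sentence-transformers (not imported in this script):
    #
    #     from sentence_transformers import CrossEncoder
    #     reranker = CrossEncoder("BAAI/bge-reranker-base")
    #     scores = reranker.predict([(reformulated, d.page_content) for d in broad_docs])
    #     order = np.argsort(scores)[::-1][:k]
    #     return [broad_docs[i] for i in order]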

class MedicalRAGPipeline:
    """Complete RAG pipeline with evaluation."""

    def __init__(self, df: pd.DataFrame):
        self.df = df
        self.processor = DocumentProcessor()
        self.embeddings = OpenAIEmbeddings(model=config.EMBEDDING_MODEL)
        self.llm = ChatOpenAI(
            model=config.LLM_MODEL,
            temperature=config.TEMPERATURE,
            max_tokens=512
        )
        # Prepare documents
        print("Preparing documents...")
        self.raw_docs = self.processor.prepare_documents(df)
        # System prompt for safety
        self.system_prompt = """You are a medical information assistant. Use only the provided context to answer the question.
If the answer is not in the context, state clearly that you do not have sufficient information.
Do not provide specific medical advice, dosage recommendations, or diagnoses.
Always encourage consulting healthcare professionals for personal medical decisions.
Context: {context}"""

    def create_vectorstore(self, chunks: List[Document], persist_dir: str) -> Chroma:
        """Create or load Chroma vector store."""
        if os.path.exists(persist_dir) and os.listdir(persist_dir):
            print(f"Loading existing vector store from {persist_dir}")
            return Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
        print(f"Creating new vector store at {persist_dir}")
        vectorstore = Chroma.from_documents(
            documents=chunks,
            embedding=self.embeddings,
            persist_directory=persist_dir
        )
        return vectorstore
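
    # Caveat: the persisted store is keyed only by chunk_method (see build_pipeline),
    # so after changing CHUNK_SIZE, CHUNK_OVERLAP, or EMBEDDING_MODEL, delete the
    # matching subfolder under VECTORDB_DIR to force the index to be rebuilt.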

    def build_pipeline(self, strategy: str, chunk_method: str = "recursive") -> Tuple[Any, List[Document]]:
        """Build specific RAG pipeline."""
        print(f"\nBuilding pipeline: {strategy} with {chunk_method} chunking")
        # Chunking
        if chunk_method == "recursive":
            chunks = self.processor.chunk_recursive(self.raw_docs)
        elif chunk_method == "fixed":
            chunks = self.processor.chunk_fixed(self.raw_docs)
        elif chunk_method == "sentence":
            chunks = self.processor.chunk_sentence(self.raw_docs)
        else:
            chunks = self.processor.chunk_recursive(self.raw_docs)
        persist_dir = f"{config.VECTORDB_DIR}/{chunk_method}"
        vectorstore = self.create_vectorstore(chunks, persist_dir)
        # Create retriever based on strategy
        if strategy == "vanilla":
            retriever = None
        elif strategy == "standard":
            retriever = RetrieverFactory.create_standard_retriever(vectorstore)
        elif strategy == "hybrid":
            retriever = RetrieverFactory.create_hybrid_retriever(vectorstore, chunks)
        elif strategy == "multiquery":
            retriever = RetrieverFactory.create_multiquery_retriever(chunks)
        elif strategy == "reranking":
            retriever = RetrieverFactory.create_reranking_retriever(vectorstore, self.llm)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")
        # Build chain
        if strategy == "vanilla":
            # No retrieval, direct generation
            prompt = ChatPromptTemplate.from_messages([
                ("system", "You are a medical information assistant. Answer based on your training knowledge, but acknowledge uncertainty. Encourage consulting healthcare professionals."),
                ("human", "{input}")
            ])
            chain = prompt | self.llm
        else:
            prompt = ChatPromptTemplate.from_messages([
                ("system", self.system_prompt),
                ("human", "{input}")
            ])
            doc_chain = create_stuff_documents_chain(self.llm, prompt)
            chain = create_retrieval_chain(retriever, doc_chain)
        return chain, chunks
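
    # Return shapes differ by strategy: the "vanilla" chain yields an AIMessage
    # (read the answer from .content), while create_retrieval_chain yields a dict
    # with "answer" and "context" keys; run_evaluation below branches on this.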

    def evaluate_semantic(self, ground_truth: str, prediction: str) -> float:
        """Calculate semantic similarity between answers."""
        if not ground_truth or not prediction:
            return 0.0
        # Simple word overlap metric (proxy for semantic similarity)
        gt_words = set(ground_truth.lower().split())
        pred_words = set(prediction.lower().split())
        if not gt_words:
            return 0.0
        intersection = len(gt_words & pred_words)
        return intersection / len(gt_words)
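
    # Worked example of the proxy above: ground truth "insulin lowers blood glucose"
    # vs. prediction "insulin reduces blood glucose" shares 3 of the 4 ground-truth
    # words ({insulin, blood, glucose}), so the score is 3/4 = 0.75.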

    def evaluate_faithfulness(self, answer: str, contexts: List[Document]) -> float:
        """Check if answer claims are supported by context."""
        if not contexts:
            return 0.0
        context_text = " ".join([c.page_content for c in contexts]).lower()
        answer_sentences = [s.strip() for s in answer.split('.') if len(s.strip()) > 10]
        if not answer_sentences:
            return 0.0
        supported = 0
        for sent in answer_sentences:
            # Check if key terms from sentence appear in context
            sent_words = set(sent.lower().split())
            overlap = len(sent_words & set(context_text.split()))
            if overlap / max(len(sent_words), 1) > 0.3:  # Threshold
                supported += 1
        return supported / len(answer_sentences)

    def run_evaluation(self, strategies: List[str], eval_samples: int = config.EVAL_SAMPLE_SIZE):
        """Run comparative evaluation across strategies."""
        results = []
        # Sample evaluation set
        eval_df = self.df.sample(n=min(eval_samples, len(self.df)), random_state=42)
        eval_set = [
            {"question": row["question"], "answer": row["answer"], "qtype": row.get("qtype", "general")}
            for _, row in eval_df.iterrows()
        ]
        for strategy in strategies:
            print(f"\n{'='*60}")
            print(f"Evaluating Strategy: {strategy.upper()}")
            print(f"{'='*60}")
            try:
                chain, chunks = self.build_pipeline(strategy)
                strategy_results = {
                    "strategy": strategy,
                    "correct": 0,
                    "faithfulness_scores": [],
                    "semantic_scores": [],
                    "latencies": [],
                    "hallucination_flags": []
                }
                for i, sample in enumerate(eval_set):
                    question = sample["question"]
                    ground_truth = sample["answer"]
                    # Measure latency
                    start_time = time.time()
                    try:
                        if strategy == "vanilla":
                            response = chain.invoke({"input": question})
                            answer = response.content
                            contexts = []
                        else:
                            response = chain.invoke({"input": question})
                            answer = response["answer"]
                            contexts = response.get("context", [])
                        latency = (time.time() - start_time) * 1000  # ms
                        # Calculate metrics
                        semantic_sim = self.evaluate_semantic(ground_truth, answer)
                        faithfulness = self.evaluate_faithfulness(answer, contexts) if strategy != "vanilla" else 0.0
                        # Heuristic hallucination detection
                        is_hallucination = faithfulness < 0.5 if strategy != "vanilla" else semantic_sim < 0.3
                        strategy_results["semantic_scores"].append(semantic_sim)
                        strategy_results["faithfulness_scores"].append(faithfulness)
                        strategy_results["latencies"].append(latency)
                        strategy_results["hallucination_flags"].append(1 if is_hallucination else 0)
                        # Check correctness (simplified)
                        if semantic_sim > 0.5:
                            strategy_results["correct"] += 1
                        if (i + 1) % 20 == 0:
                            print(f" Processed {i+1}/{len(eval_set)} samples...")
                    except Exception as e:
                        print(f" Error on sample {i}: {e}")
                        continue
                # Aggregate results
                avg_faithfulness = np.mean(strategy_results["faithfulness_scores"]) if strategy_results["faithfulness_scores"] else 0
                avg_semantic = np.mean(strategy_results["semantic_scores"]) if strategy_results["semantic_scores"] else 0
                avg_latency = np.mean(strategy_results["latencies"]) if strategy_results["latencies"] else 0
                hallucination_rate = np.mean(strategy_results["hallucination_flags"]) if strategy_results["hallucination_flags"] else 0
                accuracy = strategy_results["correct"] / len(eval_set)
                results.append({
                    "Strategy": strategy,
                    "Accuracy": accuracy,
                    "Faithfulness": avg_faithfulness,
                    "Semantic_Similarity": avg_semantic,
                    "Hallucination_Rate": hallucination_rate,
                    "Latency_ms": avg_latency
                })
                print(f"\nResults for {strategy}:")
                print(f" Accuracy: {accuracy:.3f}")
                print(f" Faithfulness: {avg_faithfulness:.3f}")
                print(f" Hallucination Rate: {hallucination_rate:.3f}")
                print(f" Avg Latency: {avg_latency:.0f}ms")
            except Exception as e:
                print(f"Failed to evaluate {strategy}: {e}")
                continue
        return pd.DataFrame(results)
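
    # Note: the RAGAS_AVAILABLE / DEEPEVAL_AVAILABLE flags set at import time are not
    # used here; the scores above are lightweight lexical proxies. Wiring in RAGAS
    # faithfulness / answer relevancy (LLM-judged) could tighten these metrics at the
    # cost of extra API calls per sample.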

    def generate_visualizations(self, results_df: pd.DataFrame):
        """Generate comparison charts."""
        if results_df.empty:
            print("No results to visualize")
            return
        plt.figure(figsize=(15, 10))
        # Plot 1: Performance Metrics Comparison
        plt.subplot(2, 2, 1)
        metrics = ["Accuracy", "Faithfulness", "Semantic_Similarity"]
        x = np.arange(len(results_df))
        width = 0.25
        for i, metric in enumerate(metrics):
            plt.bar(x + i*width, results_df[metric], width, label=metric)
        plt.xlabel('Strategy')
        plt.ylabel('Score')
        plt.title('Performance Metrics Comparison')
        plt.xticks(x + width, results_df["Strategy"], rotation=45)
        plt.legend()
        plt.ylim(0, 1)
        # Plot 2: Hallucination Rate
        plt.subplot(2, 2, 2)
        colors = ['red' if x > 0.2 else 'orange' if x > 0.1 else 'green' for x in results_df["Hallucination_Rate"]]
        plt.bar(results_df["Strategy"], results_df["Hallucination_Rate"], color=colors)
        plt.xlabel('Strategy')
        plt.ylabel('Hallucination Rate')
        plt.title('Hallucination Rate by Strategy (Lower is Better)')
        plt.xticks(rotation=45)
        # Plot 3: Latency Comparison
        plt.subplot(2, 2, 3)
        plt.bar(results_df["Strategy"], results_df["Latency_ms"], color='skyblue')
        plt.xlabel('Strategy')
        plt.ylabel('Latency (ms)')
        plt.title('Response Latency by Strategy')
        plt.xticks(rotation=45)
        # Plot 4: Safety vs Speed Trade-off
        plt.subplot(2, 2, 4)
        plt.scatter(results_df["Latency_ms"], 1 - results_df["Hallucination_Rate"],
                    s=200, alpha=0.6, c=range(len(results_df)), cmap='viridis')
        for i, txt in enumerate(results_df["Strategy"]):
            plt.annotate(txt, (results_df["Latency_ms"].iloc[i], 1 - results_df["Hallucination_Rate"].iloc[i]),
                         xytext=(5, 5), textcoords='offset points', fontsize=8)
        plt.xlabel('Latency (ms)')
        plt.ylabel('Safety Score (1 - Hallucination Rate)')
        plt.title('Safety vs. Speed Trade-off')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(f"{config.RESULTS_DIR}/rag_comparison.png", dpi=300, bbox_inches='tight')
        print(f"\nSaved visualization to {config.RESULTS_DIR}/rag_comparison.png")
        plt.show()
        # Save results table
        results_df.to_csv(f"{config.RESULTS_DIR}/rag_results.csv", index=False)
        print(f"Saved results to {config.RESULTS_DIR}/rag_results.csv")

def main():
    """Main execution function."""
    print("Medical RAG Comparative Study")
    print("=" * 60)
    # Load data
    loader = MedQuADLoader(config.DATA_DIR)
    df = loader.load()
    print(f"Loaded {len(df)} medical QA pairs")
    # Initialize pipeline
    pipeline = MedicalRAGPipeline(df)
    # Define strategies to evaluate
    strategies = [
        "vanilla",     # Baseline
        "standard",    # Basic RAG
        "multiquery",  # Multi-Query Expansion
        "hybrid",      # Hybrid Retrieval
        "reranking"    # Reranking + Reformulation
    ]
    # Run evaluation
    results_df = pipeline.run_evaluation(strategies, eval_samples=config.EVAL_SAMPLE_SIZE)
    # Generate visualizations
    pipeline.generate_visualizations(results_df)
    # Print final summary
    print("\n" + "="*60)
    print("FINAL SUMMARY")
    print("="*60)
    print(results_df.to_string(index=False))
    # Identify best strategy
    if not results_df.empty:
        best_safety = results_df.loc[results_df['Hallucination_Rate'].idxmin()]
        best_speed = results_df.loc[results_df['Latency_ms'].idxmin()]
        best_overall = results_df.loc[(results_df['Faithfulness'] * 0.5 +
                                       (1 - results_df['Hallucination_Rate']) * 0.5).idxmax()]
        print(f"\nBest Safety (Lowest Hallucination): {best_safety['Strategy']} ({best_safety['Hallucination_Rate']:.1%})")
        print(f"Best Speed (Lowest Latency): {best_speed['Strategy']} ({best_speed['Latency_ms']:.0f}ms)")
        print(f"Best Overall Balance: {best_overall['Strategy']}")


if __name__ == "__main__":
    main()