scoring_engine_with_metadata_creation.py
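A sketch of a local RAG pipeline for aviation regulations: it embeds regulation text with a local Ollama llama2 model, stores documents in ChromaDB with enriched metadata (key phrases, word count, a complexity score), retrieves sections by exact match or reranked semantic search, and prompts the model to generate compliance-scenario JSON that a validator scores.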
import chromadb
import json
from ollama import Client
from typing import List, Dict, Any
import re
from dataclasses import dataclass
import numpy as np


@dataclass
class QueryMetrics:
    """Metrics for evaluating query and response quality."""
    query_comprehension_score: float
    response_relevance_score: float
    response_accuracy_score: float
    citation_relevance_score: float
    topic_adherence_score: float
    content_effectiveness_score: float
    escalation_efficiency_score: float


class EnhancedLocalEmbeddingFunction:
    """Enhanced embedding function with better error handling and caching."""

    def __init__(self):
        self.client = Client(host='http://localhost:11434')
        self.cache = {}  # Simple embedding cache

    def __call__(self, texts: List[str]) -> List[List[float]]:
        embeddings = []
        for text in texts:
            # Check cache first
            cache_key = hash(text)
            if cache_key in self.cache:
                embeddings.append(self.cache[cache_key])
                continue
            try:
                # Clean and normalize text
                cleaned_text = self._preprocess_text(text)
                response = self.client.embeddings(
                    model='llama2',
                    prompt=cleaned_text
                )
                embedding = response['embedding']
                # Cache the result
                self.cache[cache_key] = embedding
                embeddings.append(embedding)
            except Exception as e:
                print(f"Error generating embedding: {e}")
                embeddings.append([0.0] * 4096)
        return embeddings

    def _preprocess_text(self, text: str) -> str:
        """Clean and normalize text for better embedding quality."""
        text = re.sub(r'\s+', ' ', text)  # Normalize whitespace
        text = text.strip()
        text = text.lower()  # Normalize case
        return text


class RegulationValidator:
    """Validates and scores regulation responses."""

    def __init__(self):
        self.client = Client(host='http://localhost:11434')

    def validate_response(self, query: str, response: Dict[str, Any],
                          original_text: str) -> QueryMetrics:
        """Validate and score the response against multiple criteria."""
        # Query comprehension
        query_score = self._evaluate_query_comprehension(query, response)
        # Response relevance and accuracy
        relevance_score = self._evaluate_response_relevance(response, original_text)
        accuracy_score = self._evaluate_response_accuracy(response, original_text)
        # Citation relevance
        citation_score = self._evaluate_citations(response, original_text)
        # Topic adherence
        topic_score = self._evaluate_topic_adherence(query, response, original_text)
        # Content effectiveness
        content_score = self._evaluate_content_effectiveness(response)
        # Escalation efficiency
        escalation_score = self._evaluate_escalation_efficiency(response)
        return QueryMetrics(
            query_comprehension_score=query_score,
            response_relevance_score=relevance_score,
            response_accuracy_score=accuracy_score,
            citation_relevance_score=citation_score,
            topic_adherence_score=topic_score,
            content_effectiveness_score=content_score,
            escalation_efficiency_score=escalation_score
        )

    def _evaluate_query_comprehension(self, query: str, response: Dict[str, Any]) -> float:
        """Evaluate how well the response addresses the query."""
        # Implement evaluation logic
        return 0.0  # Placeholder
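
    # Hypothetical placeholder stubs so validate_response() runs end to end;
    # swap in real heuristic or model-based scoring before trusting the metrics.
    def _evaluate_response_relevance(self, response: Dict[str, Any], original_text: str) -> float:
        return 0.0  # Placeholder

    def _evaluate_response_accuracy(self, response: Dict[str, Any], original_text: str) -> float:
        return 0.0  # Placeholder

    def _evaluate_citations(self, response: Dict[str, Any], original_text: str) -> float:
        return 0.0  # Placeholder

    def _evaluate_topic_adherence(self, query: str, response: Dict[str, Any], original_text: str) -> float:
        return 0.0  # Placeholder

    def _evaluate_content_effectiveness(self, response: Dict[str, Any]) -> float:
        return 0.0  # Placeholder

    def _evaluate_escalation_efficiency(self, response: Dict[str, Any]) -> float:
        return 0.0  # Placeholder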


class AviationRegulationRAG:
    def __init__(self):
        """Initialize enhanced RAG system."""
        self.chroma_client = chromadb.Client()
        self.embedding_function = EnhancedLocalEmbeddingFunction()
        self.validator = RegulationValidator()
        # Ollama client for key-phrase extraction (see _extract_key_phrases)
        self.client = Client(host='http://localhost:11434')

    def create_regulation_collection(self, regulations_data: List[Dict[str, str]]):
        """Create and populate enhanced ChromaDB collection."""
        self.collection = self.chroma_client.create_collection(
            name="aviation_regulations",
            embedding_function=self.embedding_function
        )
        # Enhanced document processing
        processed_docs = self._process_documents(regulations_data)
        self.collection.add(
            documents=[doc["text"] for doc in processed_docs],
            metadatas=[doc["metadata"] for doc in processed_docs],
            ids=[doc["id"] for doc in processed_docs]
        )

    def _process_documents(self, regulations_data: List[Dict[str, str]]) -> List[Dict[str, Any]]:
        """Process and enhance regulation documents."""
        processed_docs = []
        for reg in regulations_data:
            # Extract key phrases and entities
            key_phrases = self._extract_key_phrases(reg["text"])
            # Create enhanced metadata; ChromaDB metadata values must be scalar
            # (str/int/float/bool), so key phrases are stored comma-separated
            metadata = {
                "section": reg["section_number"],
                "title": reg["title"],
                "key_phrases": ", ".join(key_phrases),
                "word_count": len(reg["text"].split()),
                "complexity_score": self._calculate_complexity(reg["text"])
            }
            processed_docs.append({
                "id": f"reg_{reg['section_number']}",
                "text": reg["text"],
                "metadata": metadata
            })
        return processed_docs

    def _extract_key_phrases(self, text: str) -> List[str]:
        """Extract key phrases from text using local model."""
        try:
            response = self.client.chat(model='llama2', messages=[
                {
                    'role': 'system',
                    'content': 'Extract key technical phrases from the text. Return as JSON array.'
                },
                {
                    'role': 'user',
                    'content': text
                }
            ])
            return json.loads(response.message.content)
        except Exception:
            return []

    def _calculate_complexity(self, text: str) -> float:
        """Calculate text complexity score."""
        words = text.split()
        if not words:
            return 0.0
        avg_word_length = sum(len(word) for word in words) / len(words)
        # Count only non-empty sentences to avoid a zero denominator
        sentences = [s for s in re.split(r'[.!?]+', text) if s.strip()]
        sentence_count = max(len(sentences), 1)
        return (avg_word_length * 0.5) + (len(words) / sentence_count * 0.5)

    def get_relevant_regulation(self, section_number: str,
                                context: str = "") -> Dict[str, Any]:
        """Enhanced regulation retrieval with context awareness."""
        # Try an exact match on the section metadata first
        results = self.collection.get(
            where={"section": section_number}
        )
        if results and results['documents']:
            regulation = {
                "text": results['documents'][0],
                "metadata": results['metadatas'][0]
            }
            # Validate and score the result
            metrics = self.validator.validate_response(
                query=f"Section {section_number}",
                response=regulation,
                original_text=results['documents'][0]
            )
            regulation["metrics"] = metrics
            return regulation
        # If no exact match, do semantic search with context
        search_text = f"Section {section_number} {context}".strip()
        results = self.collection.query(
            query_texts=[search_text],
            n_results=3  # Get top 3 for reranking
        )
        # query() returns nested lists, one inner list per query text
        if not results['documents'] or not results['documents'][0]:
            return None
        # Rerank results based on multiple criteria
        ranked_results = self._rerank_results(
            query=search_text,
            documents=results['documents'][0],
            metadatas=results['metadatas'][0]
        )
        return ranked_results[0] if ranked_results else None

    def _rerank_results(self, query: str, documents: List[str],
                        metadatas: List[Dict]) -> List[Dict]:
        """Rerank results using multiple criteria."""
        scored_results = []
        for doc, metadata in zip(documents, metadatas):
            # Calculate various relevance scores
            semantic_score = self._calculate_semantic_similarity(query, doc)
            key_phrase_score = self._calculate_key_phrase_overlap(query, metadata)
            complexity_match = self._calculate_complexity_match(query, doc)
            # Combine scores
            final_score = (
                semantic_score * 0.5 +
                key_phrase_score * 0.3 +
                complexity_match * 0.2
            )
            scored_results.append({
                "text": doc,
                "metadata": metadata,
                "score": final_score
            })
        return sorted(scored_results, key=lambda x: x["score"], reverse=True)

    def _calculate_semantic_similarity(self, query: str, doc: str) -> float:
        """Calculate cosine similarity between query and document embeddings."""
        query_embedding = np.asarray(self.embedding_function([query])[0])
        doc_embedding = np.asarray(self.embedding_function([doc])[0])
        # Normalize so the score is bounded and comparable across documents
        norm = np.linalg.norm(query_embedding) * np.linalg.norm(doc_embedding)
        if norm == 0:
            return 0.0
        return float(np.dot(query_embedding, doc_embedding) / norm)

    def _calculate_key_phrase_overlap(self, query: str, metadata: Dict) -> float:
        """Calculate overlap between query and document key phrases."""
        query_phrases = set(self._extract_key_phrases(query))
        # key_phrases is stored as a comma-separated string (see _process_documents)
        doc_phrases = {p.strip() for p in metadata.get("key_phrases", "").split(",") if p.strip()}
        if not doc_phrases:
            return 0.0
        return len(query_phrases & doc_phrases) / len(doc_phrases)

    def _calculate_complexity_match(self, query: str, doc: str) -> float:
        """Calculate how well document complexity matches query needs."""
        query_complexity = self._calculate_complexity(query)
        doc_complexity = self._calculate_complexity(doc)
        # Clamp so a large mismatch scores 0 rather than going negative
        return max(0.0, 1.0 - abs(query_complexity - doc_complexity))


def generate_regulation_examples(rag: AviationRegulationRAG,
                                 section_number: str) -> Dict[str, Any]:
    """Enhanced example generation with better prompting and validation."""
    regulation = rag.get_relevant_regulation(section_number)
    if not regulation:
        print(f"No regulation found for section {section_number}")
        return None
    # Enhanced prompt with better structure and guidelines
    prompt = f"""
    Based on this aviation regulation:
    {regulation['text']}

    Generate a comprehensive JSON response that must:
    1. Accurately reflect the regulation's content
    2. Include relevant citations
    3. Maintain consistency with aviation standards
    4. Provide realistic scenarios
    5. Include clear compliance criteria

    The response must be a single JSON object with this exact structure:
    {{
        "metadata": {{
            "section": "{section_number}",
            "topic": "Extract from regulation",
            "compliance_requirements": ["List from regulation"],
            "keywords": ["Key terms"],
            "related_sections": ["Referenced sections"]
        }},
        "scenarios": {{
            "common_cases": [
                {{
                    "type": "Specific scenario type",
                    "situation_description": "Detailed description",
                    "compliance_status": "Clear status",
                    "violations": ["Specific violations if any"],
                    "relevant_regulation_text": "Exact quote",
                    "required_remediation": "Specific steps"
                }}
            ],
            "edge_cases": [...],
            "complex_cases": [...]
        }}
    }}

    Requirements:
    - Base ALL examples on the regulation text
    - Include EXACT quotes for regulation text
    - Provide SPECIFIC rather than generic descriptions
    - Ensure ALL scenarios are realistic and practical
    - Include precise references to regulation sections

    Return ONLY valid JSON with no additional text.
    """
    try:
        client = Client(host='http://localhost:11434')
        response = client.chat(model='llama2', messages=[
            {
                'role': 'system',
                'content': 'You are a precise JSON generator for aviation regulations. Generate only valid JSON that exactly matches the requested structure.'
            },
            {
                'role': 'user',
                'content': prompt
            }
        ])
        if response and hasattr(response, 'message'):
            try:
                result = json.loads(response.message.content)
                # Validate the response
                metrics = rag.validator.validate_response(
                    query=f"Section {section_number}",
                    response=result,
                    original_text=regulation['text']
                )
                # Add metrics to result
                result['performance_metrics'] = {
                    'query_comprehension': metrics.query_comprehension_score,
                    'response_relevance': metrics.response_relevance_score,
                    'response_accuracy': metrics.response_accuracy_score,
                    'citation_relevance': metrics.citation_relevance_score,
                    'topic_adherence': metrics.topic_adherence_score,
                    'content_effectiveness': metrics.content_effectiveness_score,
                    'escalation_efficiency': metrics.escalation_efficiency_score
                }
                return result
            except json.JSONDecodeError as e:
                print(f"Failed to parse response as JSON: {e}")
                print("Raw response:", response.message.content[:200] + "...")
                return None
        else:
            print("No valid response received from the model")
            return None
    except Exception as e:
        print(f"Error: {str(e)}")
        print("Please ensure Ollama server is running and accessible")
        return None


def main():
    # Example regulation data
    regulations_data = [
        {
            "section_number": "401.05",
            "title": "Flight Crew Requirements",
            "text": """[Your regulation text here]"""
        }
    ]
    # Initialize enhanced RAG
    rag = AviationRegulationRAG()
    # Create collection with regulations
    rag.create_regulation_collection(regulations_data)
    # Generate examples with performance metrics
    section_number = "401.05"
    result = generate_regulation_examples(rag, section_number)
    if result:
        print("Performance Metrics:")
        for metric, score in result['performance_metrics'].items():
            print(f"{metric}: {score:.2f}")
        print("\nGenerated Content:")
        print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()
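
To run this sketch you'll need a local Ollama server listening on http://localhost:11434 with the llama2 model pulled, plus the chromadb, ollama, and numpy Python packages. Note that the evaluator methods in RegulationValidator are placeholders, so every performance metric prints as 0.00 until real scoring logic is supplied.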