@bigsnarfdude · Created December 28, 2024 20:44
scoring_engine_with_metadata_creation.py
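# Prerequisites (assumed, not stated in the gist): a local Ollama server at
# http://localhost:11434 with the llama2 model pulled (`ollama pull llama2`),
# plus the chromadb and ollama Python packages (pip install chromadb ollama).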
import chromadb
import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List

import numpy as np
from ollama import Client

@dataclass
class QueryMetrics:
    """Metrics for evaluating query and response quality."""
    query_comprehension_score: float
    response_relevance_score: float
    response_accuracy_score: float
    citation_relevance_score: float
    topic_adherence_score: float
    content_effectiveness_score: float
    escalation_efficiency_score: float


class EnhancedLocalEmbeddingFunction:
    """Enhanced embedding function with better error handling and caching.

    Note: newer chromadb releases validate embedding functions against
    __call__(self, input); this plain-callable style matches the older
    chromadb API the gist appears to target.
    """

    def __init__(self):
        self.client = Client(host='http://localhost:11434')
        self.cache = {}  # Simple in-memory embedding cache, keyed by text hash

    def __call__(self, texts: List[str]) -> List[List[float]]:
        embeddings = []
        for text in texts:
            # Check cache first
            cache_key = hash(text)
            if cache_key in self.cache:
                embeddings.append(self.cache[cache_key])
                continue
            try:
                # Clean and normalize text
                cleaned_text = self._preprocess_text(text)
                response = self.client.embeddings(
                    model='llama2',
                    prompt=cleaned_text
                )
                embedding = response['embedding']
                # Cache the result
                self.cache[cache_key] = embedding
                embeddings.append(embedding)
            except Exception as e:
                print(f"Error generating embedding: {e}")
                # Fall back to a zero vector (4096 = llama2 embedding size)
                embeddings.append([0.0] * 4096)
        return embeddings

    def _preprocess_text(self, text: str) -> str:
        """Clean and normalize text for better embedding quality."""
        text = re.sub(r'\s+', ' ', text)  # Normalize whitespace
        text = text.strip()
        text = text.lower()  # Normalize case
        return text
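
# Quick usage sketch for the embedding function (illustrative; assumes the
# Ollama setup noted at the top of the file):
#   emb_fn = EnhancedLocalEmbeddingFunction()
#   vectors = emb_fn(["Section 401.05 flight crew requirements"])
#   len(vectors[0])  # 4096 for llama2 embeddings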

class RegulationValidator:
    """Validates and scores regulation responses."""

    def __init__(self):
        self.client = Client(host='http://localhost:11434')

    def validate_response(self, query: str, response: Dict[str, Any],
                          original_text: str) -> QueryMetrics:
        """Validate and score the response against multiple criteria."""
        # Query comprehension
        query_score = self._evaluate_query_comprehension(query, response)
        # Response relevance and accuracy
        relevance_score = self._evaluate_response_relevance(response, original_text)
        accuracy_score = self._evaluate_response_accuracy(response, original_text)
        # Citation relevance
        citation_score = self._evaluate_citations(response, original_text)
        # Topic adherence
        topic_score = self._evaluate_topic_adherence(query, response, original_text)
        # Content effectiveness
        content_score = self._evaluate_content_effectiveness(response)
        # Escalation efficiency
        escalation_score = self._evaluate_escalation_efficiency(response)
        return QueryMetrics(
            query_comprehension_score=query_score,
            response_relevance_score=relevance_score,
            response_accuracy_score=accuracy_score,
            citation_relevance_score=citation_score,
            topic_adherence_score=topic_score,
            content_effectiveness_score=content_score,
            escalation_efficiency_score=escalation_score
        )

    def _evaluate_query_comprehension(self, query: str, response: Dict[str, Any]) -> float:
        """Evaluate how well the response addresses the query."""
        return 0.0  # Placeholder: implement evaluation logic

    # The remaining evaluators are neutral placeholders so the class runs end
    # to end; implement them similarly. One fleshed-out example follows below.
    def _evaluate_response_accuracy(self, response: Dict[str, Any], original_text: str) -> float:
        return 0.0  # Placeholder

    def _evaluate_citations(self, response: Dict[str, Any], original_text: str) -> float:
        return 0.0  # Placeholder

    def _evaluate_topic_adherence(self, query: str, response: Dict[str, Any], original_text: str) -> float:
        return 0.0  # Placeholder

    def _evaluate_content_effectiveness(self, response: Dict[str, Any]) -> float:
        return 0.0  # Placeholder

    def _evaluate_escalation_efficiency(self, response: Dict[str, Any]) -> float:
        return 0.0  # Placeholder
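
    # Illustrative sketch, not from the original gist: one way to fill in an
    # evaluator, scoring relevance as lexical overlap between the generated
    # response and the source regulation. A stronger version would reuse the
    # embedding model instead of raw token overlap.
    def _evaluate_response_relevance(self, response: Dict[str, Any],
                                     original_text: str) -> float:
        """Score relevance as token overlap between response and source."""
        response_tokens = set(re.findall(r'\w+', json.dumps(response).lower()))
        source_tokens = set(re.findall(r'\w+', original_text.lower()))
        if not source_tokens:
            return 0.0
        # Fraction of the source's vocabulary echoed in the response
        return len(response_tokens & source_tokens) / len(source_tokens)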

class AviationRegulationRAG:
    def __init__(self):
        """Initialize enhanced RAG system."""
        self.chroma_client = chromadb.Client()
        self.embedding_function = EnhancedLocalEmbeddingFunction()
        self.validator = RegulationValidator()
        # Chat client used by _extract_key_phrases (missing in the original)
        self.client = Client(host='http://localhost:11434')

    def create_regulation_collection(self, regulations_data: List[Dict[str, str]]):
        """Create and populate enhanced ChromaDB collection."""
        self.collection = self.chroma_client.create_collection(
            name="aviation_regulations",
            embedding_function=self.embedding_function
        )
        # Enhanced document processing
        processed_docs = self._process_documents(regulations_data)
        self.collection.add(
            documents=[doc["text"] for doc in processed_docs],
            metadatas=[doc["metadata"] for doc in processed_docs],
            ids=[doc["id"] for doc in processed_docs]
        )

    def _process_documents(
        self, regulations_data: List[Dict[str, str]]
    ) -> List[Dict[str, Any]]:
        """Process and enhance regulation documents."""
        processed_docs = []
        for reg in regulations_data:
            # Extract key phrases and entities
            key_phrases = self._extract_key_phrases(reg["text"])
            # Create enhanced metadata. ChromaDB metadata values must be
            # scalars, so key phrases are stored as a comma-joined string.
            metadata = {
                "section": reg["section_number"],
                "title": reg["title"],
                "key_phrases": ", ".join(key_phrases),
                "word_count": len(reg["text"].split()),
                "complexity_score": self._calculate_complexity(reg["text"])
            }
            processed_docs.append({
                "id": f"reg_{reg['section_number']}",
                "text": reg["text"],
                "metadata": metadata
            })
        return processed_docs

    def _extract_key_phrases(self, text: str) -> List[str]:
        """Extract key phrases from text using local model."""
        try:
            response = self.client.chat(model='llama2', messages=[
                {
                    'role': 'system',
                    'content': 'Extract key technical phrases from the text. Return as JSON array.'
                },
                {
                    'role': 'user',
                    'content': text
                }
            ])
            return json.loads(response.message.content)
        except Exception:
            # Model output was not valid JSON (or the call failed)
            return []

    def _calculate_complexity(self, text: str) -> float:
        """Calculate text complexity score."""
        words = text.split()
        if not words:
            return 0.0
        avg_word_length = sum(len(word) for word in words) / len(words)
        # Drop empty fragments left by a trailing sentence delimiter
        sentences = [s for s in re.split(r'[.!?]+', text) if s.strip()]
        sentence_count = max(len(sentences), 1)
        return (avg_word_length * 0.5) + (len(words) / sentence_count * 0.5)
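
    # Worked example for _calculate_complexity (illustrative): 10 words
    # averaging 5 characters across 2 sentences score
    # 0.5 * 5 + 0.5 * (10 / 2) = 5.0.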
    def get_relevant_regulation(self, section_number: str,
                                context: str = "") -> Dict[str, Any]:
        """Enhanced regulation retrieval with context awareness."""
        # Try exact match first
        results = self.collection.get(
            where={"section": section_number}
        )
        if results and results['documents']:
            regulation = {
                "text": results['documents'][0],
                "metadata": results['metadatas'][0]
            }
            # Validate and score the result
            metrics = self.validator.validate_response(
                query=f"Section {section_number}",
                response=regulation,
                original_text=results['documents'][0]
            )
            regulation["metrics"] = metrics
            return regulation
        # If no exact match, do semantic search with context
        search_text = f"Section {section_number} {context}".strip()
        results = self.collection.query(
            query_texts=[search_text],
            n_results=3  # Get top 3 for reranking
        )
        # query() returns nested lists, one inner list per query text
        if not results['documents'] or not results['documents'][0]:
            return None
        # Rerank results based on multiple criteria
        ranked_results = self._rerank_results(
            query=search_text,
            documents=results['documents'][0],
            metadatas=results['metadatas'][0]
        )
        return ranked_results[0] if ranked_results else None

    def _rerank_results(self, query: str, documents: List[str],
                        metadatas: List[Dict]) -> List[Dict]:
        """Rerank results using multiple criteria."""
        scored_results = []
        for doc, metadata in zip(documents, metadatas):
            # Calculate various relevance scores
            semantic_score = self._calculate_semantic_similarity(query, doc)
            key_phrase_score = self._calculate_key_phrase_overlap(query, metadata)
            complexity_match = self._calculate_complexity_match(query, doc)
            # Combine scores with fixed weights
            final_score = (
                semantic_score * 0.5 +
                key_phrase_score * 0.3 +
                complexity_match * 0.2
            )
            scored_results.append({
                "text": doc,
                "metadata": metadata,
                "score": final_score
            })
        return sorted(scored_results, key=lambda x: x["score"], reverse=True)

    def _calculate_semantic_similarity(self, query: str, doc: str) -> float:
        """Calculate cosine similarity between query and document embeddings."""
        query_embedding = np.array(self.embedding_function([query])[0])
        doc_embedding = np.array(self.embedding_function([doc])[0])
        # Normalize so the score is scale-independent (a raw dot product is not)
        norm = np.linalg.norm(query_embedding) * np.linalg.norm(doc_embedding)
        if norm == 0:
            return 0.0
        return float(np.dot(query_embedding, doc_embedding) / norm)

    def _calculate_key_phrase_overlap(self, query: str, metadata: Dict) -> float:
        """Calculate overlap between query and document key phrases."""
        query_phrases = set(self._extract_key_phrases(query))
        # Key phrases are stored as a comma-joined string in the metadata
        doc_phrases = {p.strip() for p in metadata.get("key_phrases", "").split(",") if p.strip()}
        if not doc_phrases:
            return 0.0
        return len(query_phrases & doc_phrases) / len(doc_phrases)

    def _calculate_complexity_match(self, query: str, doc: str) -> float:
        """Calculate how well document complexity matches query needs."""
        query_complexity = self._calculate_complexity(query)
        doc_complexity = self._calculate_complexity(doc)
        # Clamp so a large complexity mismatch never yields a negative score
        return max(0.0, 1.0 - abs(query_complexity - doc_complexity))

def generate_regulation_examples(rag: AviationRegulationRAG,
                                 section_number: str) -> Dict[str, Any]:
    """Enhanced example generation with better prompting and validation."""
    regulation = rag.get_relevant_regulation(section_number)
    if not regulation:
        print(f"No regulation found for section {section_number}")
        return None
    # Enhanced prompt with better structure and guidelines
    prompt = f"""
Based on this aviation regulation:
{regulation['text']}

Generate a comprehensive JSON response that must:
1. Accurately reflect the regulation's content
2. Include relevant citations
3. Maintain consistency with aviation standards
4. Provide realistic scenarios
5. Include clear compliance criteria

The response must be a single JSON object with this exact structure:
{{
    "metadata": {{
        "section": "{section_number}",
        "topic": "Extract from regulation",
        "compliance_requirements": ["List from regulation"],
        "keywords": ["Key terms"],
        "related_sections": ["Referenced sections"]
    }},
    "scenarios": {{
        "common_cases": [
            {{
                "type": "Specific scenario type",
                "situation_description": "Detailed description",
                "compliance_status": "Clear status",
                "violations": ["Specific violations if any"],
                "relevant_regulation_text": "Exact quote",
                "required_remediation": "Specific steps"
            }}
        ],
        "edge_cases": [...],
        "complex_cases": [...]
    }}
}}

Requirements:
- Base ALL examples on the regulation text
- Include EXACT quotes for regulation text
- Provide SPECIFIC rather than generic descriptions
- Ensure ALL scenarios are realistic and practical
- Include precise references to regulation sections

Return ONLY valid JSON with no additional text.
"""
    try:
        client = Client(host='http://localhost:11434')
        response = client.chat(model='llama2', messages=[
            {
                'role': 'system',
                'content': 'You are a precise JSON generator for aviation regulations. Generate only valid JSON that exactly matches the requested structure.'
            },
            {
                'role': 'user',
                'content': prompt
            }
        ])
        if response and hasattr(response, 'message'):
            try:
                result = json.loads(response.message.content)
                # Validate the response
                metrics = rag.validator.validate_response(
                    query=f"Section {section_number}",
                    response=result,
                    original_text=regulation['text']
                )
                # Add metrics to result
                result['performance_metrics'] = {
                    'query_comprehension': metrics.query_comprehension_score,
                    'response_relevance': metrics.response_relevance_score,
                    'response_accuracy': metrics.response_accuracy_score,
                    'citation_relevance': metrics.citation_relevance_score,
                    'topic_adherence': metrics.topic_adherence_score,
                    'content_effectiveness': metrics.content_effectiveness_score,
                    'escalation_efficiency': metrics.escalation_efficiency_score
                }
                return result
            except json.JSONDecodeError as e:
                print(f"Failed to parse response as JSON: {e}")
                print("Raw response:", response.message.content[:200] + "...")
                return None
        else:
            print("No valid response received from the model")
            return None
    except Exception as e:
        print(f"Error: {e}")
        print("Please ensure the Ollama server is running and accessible")
        return None

def main():
    # Example regulation data
    regulations_data = [
        {
            "section_number": "401.05",
            "title": "Flight Crew Requirements",
            "text": """[Your regulation text here]"""
        }
    ]
    # Initialize enhanced RAG
    rag = AviationRegulationRAG()
    # Create collection with regulations
    rag.create_regulation_collection(regulations_data)
    # Generate examples with performance metrics
    section_number = "401.05"
    result = generate_regulation_examples(rag, section_number)
    if result:
        print("Performance Metrics:")
        for metric, score in result['performance_metrics'].items():
            print(f"{metric}: {score:.2f}")
        print("\nGenerated Content:")
        print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()