import os
import tempfile
import logging
import httpx
from typing import List, Dict, Any, Optional, Union, BinaryIO

from haystack import Document
from haystack.utils import Secret
from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
from haystack.components.preprocessors import DocumentSplitter, DocumentCleaner
from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever
from haystack.components.embedders import OpenAIDocumentEmbedder
from unstructured.partition.auto import partition
from unstructured.documents.elements import (
    Title, ListItem, Header, Footer, Table, Image
)

from .config import settings

# Configure logging
logging.basicConfig(level=getattr(logging, settings.LOG_LEVEL))
logger = logging.getLogger(__name__)


class UnstructuredConverter:
    """Converts documents to text using unstructured"""

    def _element_to_markdown(self, element) -> str:
        """Convert an unstructured element to markdown format"""
        text = str(element).strip()
        if not text:
            return ""
        if isinstance(element, Title):
            return f"# {text}"
        elif isinstance(element, Header):
            return f"## {text}"
        elif isinstance(element, ListItem):
            return f"- {text}"
        elif isinstance(element, Table):
            # Basic table formatting - could be enhanced
            return f"**Table**: {text}"
        elif isinstance(element, Image):
            # No image URL is available from unstructured here, so record the
            # extracted text the same way as the table case
            return f"**Image**: {text}"
        elif isinstance(element, Footer):
            return f"*{text}*"
        else:  # NarrativeText, Text, etc.
            return text
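
    # For example (values illustrative): a Title element containing
    # "Quarterly Report" renders as "# Quarterly Report", and a ListItem
    # containing "First point" renders as "- First point".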

    def convert(self, file: Union[str, BinaryIO], metadata: Dict[str, Any] = None) -> List[Document]:
        """Convert a file to text using unstructured"""
        logger.info("Converting file to text")
        try:
            # If file is a path, partition it directly; otherwise spool it to a temp file
            if isinstance(file, str):
                elements = partition(filename=file)
            else:
                with tempfile.NamedTemporaryFile(delete=False) as temp:
                    # File-like objects are read; str content is encoded, raw bytes written as-is
                    if hasattr(file, 'read'):
                        content = file.read()
                        if isinstance(content, str):
                            content = content.encode('utf-8')
                        temp.write(content)
                    else:
                        temp.write(file)
                    temp_path = temp.name
                try:
                    elements = partition(filename=temp_path)
                finally:
                    os.unlink(temp_path)
            markdown_elements = [
                self._element_to_markdown(el)
                for el in elements
            ]
            # Filter out empty strings and join with double newlines
            text = "\n\n".join(el for el in markdown_elements if el)
            if not text.strip():
                logger.warning("No text extracted from file")
                return []
            logger.info(f"Extracted {len(text)} characters")
            # Ensure metadata has required fields
            if metadata is None:
                metadata = {}
            if "data_entity_id" not in metadata:
                logger.warning("data_entity_id not provided in metadata")
            if "document_id" not in metadata:
                logger.warning("document_id not provided in metadata")
            return [Document(content=text, meta=metadata)]
        except Exception as e:
            logger.error(f"Document conversion error: {str(e)}")
            raise RuntimeError(f"Document conversion error: {str(e)}") from e

    async def extract_text_from_url(self, url: str) -> str:
        """Extract text from a URL"""
        logger.info(f"Extracting text from URL: {url}")
        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.get(url)
                response.raise_for_status()
            # Save content to a temporary file
            with tempfile.NamedTemporaryFile(delete=False) as temp:
                temp.write(response.content)
                temp_path = temp.name
            try:
                # Use the converter to extract text
                docs = self.convert(temp_path)
                if not docs:
                    return ""
                return docs[0].content
            finally:
                # Clean up the temp file
                os.unlink(temp_path)
        except Exception as e:
            logger.error(f"URL extraction error: {str(e)}")
            raise RuntimeError(f"URL extraction error: {str(e)}") from e


class HaystackService:
    """Main service class for Haystack RAG operations"""

    def __init__(self):
        """Initialize the Haystack service"""
        logger.info("Initializing HaystackService")
        # Initialize document store
        # Note: EMBEDDING_DIM must match the output dimension of EMBEDDINGS_MODEL
        try:
            self.document_store = PgvectorDocumentStore(
                connection_string=Secret.from_token(settings.PGVECTOR_DSN),
                embedding_dimension=settings.EMBEDDING_DIM,
                table_name=settings.PGVECTOR_TABLE,
                vector_function="cosine_similarity",
                search_strategy="hnsw",
                recreate_table=True  # XXX drops existing data on every restart; disable to avoid data loss?
            )
            logger.info(f"Connected to PgvectorDocumentStore: {settings.PGVECTOR_TABLE}")
        except Exception as e:
            logger.error(f"Failed to connect to PgvectorDocumentStore: {str(e)}")
            raise
        # Initialize components
        self.embedder = OpenAIDocumentEmbedder(
            api_key=Secret.from_token(settings.VLLM_API_KEY),
            api_base_url=settings.VLLM_BASE_URL,
            model=settings.EMBEDDINGS_MODEL
        )
        self.converter = UnstructuredConverter()
        # Initialize document cleaner with custom patterns
        self.cleaner = DocumentCleaner(
            remove_empty_lines=True,
            remove_extra_whitespaces=True,
            # Remove runs of 5 or more dots (e.g. leader lines in tables of contents)
            remove_regex=r'\.{5,}'
        )
        self.splitter = DocumentSplitter(
            split_length=settings.CHUNK_SIZE,
            split_overlap=settings.CHUNK_OVERLAP,
            split_by="word",
            respect_sentence_boundary=True
        )
        self.splitter.warm_up()
        logger.info(f"Initialized DocumentSplitter with chunk_size={settings.CHUNK_SIZE}, overlap={settings.CHUNK_OVERLAP}")
        # Initialize retriever
        self.retriever = PgvectorEmbeddingRetriever(
            document_store=self.document_store,
            filters=None,
            top_k=5
        )
        logger.info("HaystackService initialization complete")

    async def extract_text(self, file: Optional[Union[str, BinaryIO]] = None, url: Optional[str] = None) -> str:
        """Extract text from a file or URL without indexing it"""
        logger.info("Extracting text from file or URL")
        if file is not None:
            # Extract from file
            documents = self.converter.convert(file)
            if not documents:
                return ""
            return documents[0].content
        elif url:
            # Extract from URL
            return await self.converter.extract_text_from_url(url)
        else:
            raise ValueError("Either file or url must be provided")

    def _truncate_text_for_embedding(self, text: str) -> str:
        """Truncate text to a safe limit for the embedding API.

        Args:
            text: The text to truncate

        Returns:
            Truncated text
        """
        # Convert max tokens to characters (approximate - 4 chars per token is a conservative estimate)
        max_chars = settings.EMBEDDINGS_MAX_TOKENS * 4
        if len(text) <= max_chars:
            return text
        # Log lengths only; dumping the full text here would flood the logs
        logger.warning(
            f"Truncating text from {len(text)} to {max_chars} characters to avoid embedding API limits"
        )
        # Try to truncate at a sentence boundary if possible
        truncated = text[:max_chars]
        last_period = truncated.rfind('.')
        last_question = truncated.rfind('?')
        last_exclamation = truncated.rfind('!')
        # Find the last sentence boundary
        last_sentence = max(last_period, last_question, last_exclamation)
        # Only use the sentence boundary if we keep at least 80% of the desired length
        if last_sentence > max_chars * 0.8:
            return text[:last_sentence + 1].strip()
        # Otherwise just truncate at the character limit
        return truncated.strip()
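
    # Worked example (values illustrative, not taken from the real config):
    # with EMBEDDINGS_MAX_TOKENS = 8192, max_chars = 8192 * 4 = 32768. A
    # 40,000-character input is cut to 32,768 characters, and a trailing
    # sentence boundary is honored only if it falls after
    # 0.8 * 32768 = 26214 characters.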

    async def process_and_index(self, file: Union[str, BinaryIO], metadata: Dict[str, Any] = None) -> Dict[str, Any]:
        """Process a document and index it in the document store"""
        logger.info(f"Processing and indexing file with metadata: {metadata}")
        # Convert document
        documents = self.converter.convert(file, metadata)
        if not documents:
            logger.warning("No documents to index")
            return {"status": "warning", "message": "No content extracted from file"}
        # Clean documents before splitting
        cleaned_docs = self.cleaner.run(documents=documents)["documents"]
        # Split into chunks
        chunks = []
        for doc in cleaned_docs:
            result = self.splitter.run(documents=[doc])
            chunks.extend(result["documents"])
        logger.info(f"Split document into {len(chunks)} chunks")
        # Truncate chunks if needed
        for chunk in chunks:
            chunk.content = self._truncate_text_for_embedding(chunk.content)
        # Generate embeddings
        result = self.embedder.run(documents=chunks)
        chunks = result["documents"]
        # Store in database
        self.document_store.write_documents(chunks)
        logger.info(f"Successfully indexed {len(chunks)} chunks")
        return {
            "status": "success",
            "documents_processed": len(documents),
            "chunks_indexed": len(chunks)
        }
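
    # Example return value (numbers illustrative):
    #     {"status": "success", "documents_processed": 1, "chunks_indexed": 12}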

    async def query(self, query_text: str, filters: Dict[str, Any] = None, top_k: int = 5) -> List[Dict[str, Any]]:
        """Query the document store for relevant documents"""
        logger.info(f"Querying with: '{query_text}', filters: {filters}, top_k: {top_k}")
        # Generate the query embedding by reusing the document embedder on a
        # one-off Document (a dedicated text embedder would also work here)
        query_result = self.embedder.run(documents=[Document(content=query_text)])
        query_embedding = query_result["documents"][0].embedding
        # Translate flat key/value filters into Haystack's filter syntax
        if filters:
            formatted_filters = {
                "operator": "AND",
                "conditions": [
                    {"field": f"meta.{key}", "operator": "==", "value": value}
                    for key, value in filters.items()
                ]
            }
        else:
            formatted_filters = None
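        # For example (values hypothetical), {"document_id": "doc-1"} becomes:
        #     {"operator": "AND", "conditions": [
        #         {"field": "meta.document_id", "operator": "==", "value": "doc-1"}]}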
        # Retrieve documents
        results = self.retriever.run(
            query_embedding=query_embedding,
            filters=formatted_filters,
            top_k=top_k
        )["documents"]
        logger.info(f"Retrieved {len(results)} results")
        # Format results
        formatted_results = [
            {
                "content": doc.content,
                "metadata": doc.meta,
                "score": float(doc.score if doc.score is not None else 0.0)
            }
            for doc in results
        ]
        return formatted_results

    async def delete(self, filters: Dict[str, Any]) -> Dict[str, Any]:
        """Delete documents from the document store based on filters"""
        logger.info(f"Deleting documents with filters: {filters}")
        # Format filters with the meta. prefix, mirroring query()
        formatted_filters = {
            "operator": "AND",
            "conditions": [
                {"field": f"meta.{key}", "operator": "==", "value": value}
                for key, value in filters.items()
            ]
        }
        # Find matching documents
        matching_docs = self.document_store.filter_documents(filters=formatted_filters)
        if not matching_docs:
            logger.info("No documents found matching filters")
            return {"status": "success", "documents_deleted": 0}
        # Delete the matching documents by id
        self.document_store.delete_documents(document_ids=[doc.id for doc in matching_docs])
        deleted = len(matching_docs)
        logger.info(f"Deleted {deleted} documents")
        return {"status": "success", "documents_deleted": deleted}