Last active
June 7, 2025 06:45
-
-
Save anonymousmaharaj/5cac0cda8cb6c09699d5d6f8ed1d52cc to your computer and use it in GitHub Desktop.
Openwebui Custom RAG Filter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
title: Custom RAG Filter with OpenSearch & OpenAI-compatible APIs | |
author: Alexey Fateev | |
author_url: https://github.com/anonymousmaharaj | |
funding_url: https://gist.github.com/anonymousmaharaj | |
version: 1.0.0 | |
license: MIT | |
requirements: opensearch-py, httpx | |
""" | |
from typing import Optional, List, Dict | |
from pydantic import BaseModel, Field | |
from opensearchpy import AsyncOpenSearch | |
import httpx | |
class Filter:
    """Open WebUI filter implementing a custom RAG pipeline.

    Pipeline (see ``inlet``):
      1. Embed the user's last message via an OpenAI-compatible embeddings API.
      2. kNN-search OpenSearch for candidate chunks.
      3. Rerank candidates via an OpenAI-compatible rerank API.
      4. Prepend the top chunks as context to the user's message.

    All tunables live in ``Valves`` so they can be edited from the Open WebUI
    admin panel. Errors at any stage fall back to the unmodified request body.
    """

    class Valves(BaseModel):
        """User-tunable settings exposed by Open WebUI."""

        # RAG settings
        enable_rag: bool = Field(default=True, description="Enable custom RAG")
        top_k_chunks: int = Field(
            default=10, description="Number of top chunks for initial search"
        )
        final_k_chunks: int = Field(
            default=3, description="Number of final chunks after rerank"
        )
        similarity_threshold: float = Field(
            default=0.7, description="Minimum similarity threshold (0-1)"
        )
        rag_prompt: str = Field(
            default="", description="Additional prompt for RAG (optional)"
        )
        # OpenAI-compatible API settings for embedding
        embedding_api_url: str = Field(
            default="http://localhost:8000/v1/embeddings",
            description="URL for embedding API (OpenAI-compatible)",
        )
        embedding_model: str = Field(
            default="text-embedding-ada-002", description="Embedding model name"
        )
        embedding_api_key: str = Field(
            default="", description="API key for embedding service"
        )
        # OpenAI-compatible API settings for rerank
        rerank_api_url: str = Field(
            default="http://localhost:8001/v1/rerank",
            description="URL for rerank API (OpenAI-compatible)",
        )
        rerank_model: str = Field(
            default="rerank-multilingual-v3.0", description="Rerank model name"
        )
        rerank_api_key: str = Field(
            default="", description="API key for rerank service"
        )
        # OpenSearch settings
        opensearch_host: str = Field(default="localhost", description="OpenSearch host")
        opensearch_port: int = Field(default=9200, description="OpenSearch port")
        opensearch_username: str = Field(
            default="admin", description="OpenSearch username"
        )
        opensearch_password: str = Field(
            default="admin", description="OpenSearch password"
        )
        opensearch_use_ssl: bool = Field(
            default=False, description="Use SSL for OpenSearch"
        )
        opensearch_verify_certs: bool = Field(
            default=False, description="Verify SSL certificates"
        )
        index_name: str = Field(
            default="documents", description="Index name in OpenSearch"
        )
        embedding_field_name: str = Field(
            default="embedding", description="Embedding field name in OpenSearch"
        )

    def __init__(self):
        self.valves = self.Valves()
        # Lazily created in _get_opensearch_client so filter construction
        # never touches the network.
        self._opensearch_client: Optional[AsyncOpenSearch] = None

    def _get_opensearch_client(self) -> AsyncOpenSearch:
        """Gets async OpenSearch client with lazy initialization."""
        if self._opensearch_client is None:
            auth = (self.valves.opensearch_username, self.valves.opensearch_password)
            self._opensearch_client = AsyncOpenSearch(
                # FIX: include the configured port — the opensearch_port valve
                # was previously defined but never used, so non-default ports
                # silently fell back to the client default.
                hosts=[
                    {
                        "host": self.valves.opensearch_host,
                        "port": self.valves.opensearch_port,
                    }
                ],
                http_auth=auth,
                use_ssl=self.valves.opensearch_use_ssl,
                # FIX: honor the opensearch_verify_certs valve instead of
                # hardcoding False.
                verify_certs=self.valves.opensearch_verify_certs,
                ssl_assert_hostname=False,
                ssl_show_warn=False,
                ssl_context=None,
            )
        return self._opensearch_client

    async def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
        """
        RAG inlet - gets query embedding via OpenAI API,
        searches for similar chunks in OpenSearch, applies rerank and enriches the query.

        Returns ``body`` unchanged when RAG is disabled, there is no usable
        user message, or any pipeline stage fails.
        """
        if not self.valves.enable_rag:
            return body
        print(f"🔄 Starting custom RAG with OpenSearch + OpenAI APIs...")
        # Get last user message
        messages = body.get("messages", [])
        if not messages:
            return body
        last_message = messages[-1]
        user_query = last_message.get("content", "")
        # FIX: content may be a non-string payload (e.g. multimodal list);
        # previously .strip() raised AttributeError outside the try block.
        if not isinstance(user_query, str) or not user_query.strip():
            return body
        try:
            # Step 1: Get embedding via OpenAI API
            print(f"📊 Generating embedding via OpenAI API...")
            query_embedding = await self._get_embedding(user_query)
            if not query_embedding:
                print(f"❌ Failed to get embedding")
                return body
            # Step 2: Search for similar chunks in OpenSearch
            print(f"🔍 Searching for similar chunks in OpenSearch...")
            initial_chunks = await self._search_opensearch(query_embedding)
            if not initial_chunks:
                print(f"❌ No relevant chunks found")
                return body
            # Step 3: Apply rerank to improve results
            print(f"🎯 Applying rerank to {len(initial_chunks)} chunks...")
            ranked_chunks = await self._rerank_chunks(user_query, initial_chunks)
            if not ranked_chunks:
                print(f"❌ Rerank returned no results")
                # Fall back to the vector-similarity ordering.
                ranked_chunks = initial_chunks[: self.valves.final_k_chunks]
            # Step 4: Enrich the query with top chunks after rerank
            final_chunks = ranked_chunks[: self.valves.final_k_chunks]
            print(f"✨ Enriching query with {len(final_chunks)} top chunks")
            enriched_query = self._enrich_query_with_chunks(user_query, final_chunks)
            # Modify last message in place; Open WebUI forwards the body.
            last_message["content"] = enriched_query
            print(f"🔍 Enriched query: {enriched_query}")
            print(f"✅ RAG finished successfully")
        except Exception as e:
            print(f"❌ Error in RAG pipeline: {e}")
            # On error, return original query
        return body

    async def _get_embedding(self, text: str) -> Optional[List[float]]:
        """Gets embedding via OpenAI-compatible API.

        Returns the embedding vector for ``text``, or ``None`` on any HTTP or
        response-format error (errors are logged, never raised).
        """
        try:
            headers = {"Content-Type": "application/json"}
            if self.valves.embedding_api_key:
                headers["Authorization"] = f"Bearer {self.valves.embedding_api_key}"
            payload = {"model": self.valves.embedding_model, "input": [text]}
            async with httpx.AsyncClient(timeout=30.0, verify=False) as client:
                response = await client.post(
                    self.valves.embedding_api_url, headers=headers, json=payload
                )
                response.raise_for_status()
                result = response.json()
                # OpenAI embeddings format: {"data": [{"embedding": [...]}]}
                if "data" in result and len(result["data"]) > 0:
                    return result["data"][0]["embedding"]
                else:
                    print(f"❌ Unexpected response format from embedding API: {result}")
                    return None
        except Exception as e:
            print(f"❌ Error getting embedding: {e}")
            return None

    async def _search_opensearch(self, query_embedding: List[float]) -> List[Dict]:
        """Searches for similar chunks in OpenSearch.

        Runs a kNN query against the configured index and returns hits whose
        score is at or above ``similarity_threshold``, as dicts with
        content/metadata/similarity/title/source/file_name keys. Returns an
        empty list when the index is missing or the search fails.
        """
        try:
            client = self._get_opensearch_client()
            # Check if index exists
            index_exists = await client.indices.exists(index=self.valves.index_name)
            if not index_exists:
                print(f"❌ Index '{self.valves.index_name}' not found in OpenSearch")
                return []
            # OpenSearch kNN search query; exclude the (large) embedding
            # field from returned documents.
            search_body = {
                "size": self.valves.top_k_chunks,
                "_source": {"excludes": [self.valves.embedding_field_name]},
                "query": {
                    "knn": {
                        self.valves.embedding_field_name: {
                            "vector": query_embedding,
                            "k": self.valves.top_k_chunks,
                        }
                    }
                },
            }
            print(f"Search body: {search_body}")
            response = await client.search(
                index=self.valves.index_name, body=search_body
            )
            # Convert results to required format
            results = []
            for hit in response["hits"]["hits"]:
                score = hit["_score"]
                source = hit["_source"]
                # Filter by similarity threshold
                if score >= self.valves.similarity_threshold:
                    results.append(
                        {
                            # Indexers vary: prefer "content", fall back to "text".
                            "content": source.get("content", source.get("text", "")),
                            "metadata": source.get("metadata", {}),
                            "similarity": score,
                            "title": source.get("title", ""),
                            "source": source.get("source", ""),
                            "file_name": source.get("file_name", ""),
                        }
                    )
            return results
        except Exception as e:
            print(f"❌ Error searching OpenSearch: {e}")
            return []

    async def _rerank_chunks(self, query: str, chunks: List[Dict]) -> List[Dict]:
        """Applies rerank model to found chunks.

        Sends the query plus each chunk's content to the rerank API and
        returns the top ``final_k_chunks`` chunks sorted by rerank score
        (each copy annotated with ``rerank_score`` and
        ``original_similarity``). Returns ``chunks`` unchanged on any error
        or unexpected response format.
        """
        try:
            if not chunks:
                return chunks
            headers = {"Content-Type": "application/json"}
            if self.valves.rerank_api_key:
                headers["Authorization"] = f"Bearer {self.valves.rerank_api_key}"
            # Prepare documents for rerank.
            # FIX: keep documents and their source chunks index-aligned.
            # Previously empty-content chunks were skipped when building
            # `documents` but the API's returned `index` was mapped back into
            # the unfiltered `chunks` list, mis-associating scores whenever a
            # chunk had empty content.
            documents: List[str] = []
            doc_chunks: List[Dict] = []
            for chunk in chunks:
                content = chunk.get("content", "").strip()
                if content:
                    documents.append(content)
                    doc_chunks.append(chunk)
            if not documents:
                return chunks
            # Request format for your API
            payload = {
                "model": self.valves.rerank_model,
                "text_1": [query],  # query as array
                "text_2": documents,  # documents to compare
            }
            async with httpx.AsyncClient(timeout=30.0, verify=False) as client:
                response = await client.post(
                    self.valves.rerank_api_url, headers=headers, json=payload
                )
                response.raise_for_status()
                result = response.json()
                # Process rerank results
                if "data" in result:
                    # Create list of (chunk, score) pairs for sorting
                    chunk_scores = []
                    for score_result in result["data"]:
                        index = score_result.get("index", 0)
                        score = score_result.get("score", 0)
                        if 0 <= index < len(doc_chunks):
                            chunk = doc_chunks[index].copy()
                            chunk["rerank_score"] = score
                            chunk["original_similarity"] = chunk.get("similarity", 0)
                            chunk_scores.append((chunk, score))
                    # Sort by rerank score (descending)
                    chunk_scores.sort(key=lambda x: x[1], reverse=True)
                    # Return top chunks
                    ranked_chunks = [
                        chunk
                        for chunk, score in chunk_scores[: self.valves.final_k_chunks]
                    ]
                    return ranked_chunks
                else:
                    print(f"❌ Unexpected response format from rerank API: {result}")
                    return chunks
        except Exception as e:
            print(f"❌ Error in rerank: {e}")
            return chunks

    def _enrich_query_with_chunks(self, original_query: str, chunks: List[Dict]) -> str:
        """Enriches the original query with found chunks.

        Builds a prompt that wraps the original question with the retrieved
        context and the answering rules. Returns ``original_query`` unchanged
        when there are no chunks.
        """
        if not chunks:
            return original_query
        context_parts = []
        for chunk in chunks:
            content = chunk.get("content", "").strip()
            metadata = chunk.get("metadata", {})
            # Extract source info: prefer top-level fields, then metadata.
            source_info = ""
            if chunk.get("source"):
                source_info = f" (source: {chunk['source']})"
            elif chunk.get("title"):
                source_info = f" (from: {chunk['title']})"
            elif chunk.get("file_name"):
                source_info = f" (file: {chunk['file_name']})"
            elif "source" in metadata:
                source_info = f" (source: {metadata['source']})"
            elif "title" in metadata:
                source_info = f" (from: {metadata['title']})"
            elif "file_name" in metadata:
                source_info = f" (file: {metadata['file_name']})"
            # Add only content (plus its source tag) to final context;
            # similarity/rerank scores are intentionally excluded.
            context_parts.append(f"{content}{source_info}")
        # Form context
        context = "\n\n".join(context_parts)
        # Rules template
        rules = """Follow these rules:
1. Use ONLY the information from the provided context.
2. Do not add information that is not in the context, even if you know it.
3. Do not miss any information from the context that is important for answering the question, return links exactly as they are if present.
4. The answer must be CLEAR, DETAILED, and STRUCTURED. Use lists, tables, and bullet points if necessary.
5. If there are contradictions in the context, point them out and provide possible options."""
        # Add RAG_PROMPT if present
        rag_prompt_part = (
            f"\n{self.valves.rag_prompt}" if self.valves.rag_prompt.strip() else ""
        )
        enriched_query = f"""{rules}
Answer the user's question: <|qs_|>{original_query}<|qe_|>{rag_prompt_part}
Using as much as possible the following information: <|context|>
{context}
Answer"""
        return enriched_query
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment