Created
June 4, 2025 14:35
-
-
Save a-agmon/d303b5425647e5085fba54e134407f85 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from utils.database import get_db_connection, TABLE_NAME | |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Embedding models already constructed in this process, keyed by model name.
# Constructing HuggingFaceEmbedding loads the underlying model, which is slow,
# so it is done once per process rather than once per query.
_EMBED_MODELS = {}


def _get_embed_model(model_name=_EMBED_MODEL_NAME):
    """Return the embedding model for *model_name*, constructing it on first use."""
    model = _EMBED_MODELS.get(model_name)
    if model is None:
        model = HuggingFaceEmbedding(model_name=model_name)
        _EMBED_MODELS[model_name] = model
    return model


def _similarity_stats(cur, query_vector, sample_percent, min_sample):
    """Estimate mean and population std-dev of similarity to *query_vector*.

    Statistics are first computed over a ``TABLESAMPLE SYSTEM`` page sample.
    If that sample holds fewer than *min_sample* rows (common on small tables,
    where a 1% page sample is often empty), the whole table is scanned instead;
    a tiny sample usually means a small table, so the full scan stays cheap.

    Args:
        cur: Open DB-API cursor.
        query_vector: Query embedding (sequence of floats).
        sample_percent: Percentage passed to ``TABLESAMPLE SYSTEM``.
        min_sample: Minimum sampled rows required to trust the sample.

    Returns:
        ``(mu, sigma)``; both are ``None`` when no row has a vector.
    """
    sample_sql = f"""
        SELECT AVG(sim), STDDEV_POP(sim), COUNT(*)
        FROM (
            SELECT 1 - (vector <=> %(qvec)s::vector) AS sim
            FROM {TABLE_NAME} TABLESAMPLE SYSTEM (%(pct)s)
            WHERE vector IS NOT NULL
        ) AS s
    """
    cur.execute(sample_sql, {"qvec": query_vector, "pct": sample_percent})
    mu, sigma, n = cur.fetchone()
    if n >= min_sample:
        return mu, sigma

    # Sample too small to be meaningful: fall back to the full population.
    population_sql = f"""
        SELECT AVG(sim), STDDEV_POP(sim)
        FROM (
            SELECT 1 - (vector <=> %(qvec)s::vector) AS sim
            FROM {TABLE_NAME}
            WHERE vector IS NOT NULL
        ) AS p
    """
    cur.execute(population_sql, {"qvec": query_vector})
    mu, sigma = cur.fetchone()
    return mu, sigma


def zscore_search(query_text, z_threshold=2.5, top_k=1000, sample_percent=1, min_sample=30):
    """
    Perform z-score calibrated vector search.

    The query is embedded, the ``top_k`` nearest rows are fetched, and only
    rows whose similarity is at least ``mu + z_threshold * sigma`` are kept.
    ``mu`` and ``sigma`` describe the similarity of the query to the table
    at large, estimated from a page sample (see ``_similarity_stats``).
    Similarity is ``1 - (vector <=> query)``; ``<=>`` is assumed to be
    pgvector's cosine-distance operator.

    Args:
        query_text: Text to search for
        z_threshold: Z-score threshold for filtering
        top_k: Maximum number of results to consider
        sample_percent: Page-sample percentage used to estimate mu/sigma
        min_sample: Fall back to a full-table scan for mu/sigma when the
            sample has fewer rows than this

    Returns:
        List of ``(id, title, overview, sim, mu, sigma)`` tuples above the
        z-score threshold, most similar first. Empty if nothing qualifies.
    """
    query_vector = _get_embed_model().get_text_embedding(query_text)

    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            mu, sigma = _similarity_stats(cur, query_vector, sample_percent, min_sample)
            if mu is None or sigma is None:
                # No rows carry a vector, so there is nothing to score.
                return []

            search_sql = f"""
                SELECT id, title, overview, sim,
                       %(mu)s::float8 AS mu, %(sigma)s::float8 AS sigma
                FROM (
                    SELECT id, title, overview,
                           1 - (vector <=> %(qvec)s::vector) AS sim
                    FROM {TABLE_NAME}
                    WHERE vector IS NOT NULL
                    ORDER BY vector <=> %(qvec)s::vector
                    LIMIT %(top_k)s
                ) AS top
                WHERE sim >= %(mu)s::float8 + %(sigma)s::float8 * %(z_cut)s::float8
                ORDER BY sim DESC
            """
            cur.execute(
                search_sql,
                {
                    "qvec": query_vector,
                    "top_k": top_k,
                    "mu": mu,
                    "sigma": sigma,
                    "z_cut": z_threshold,
                },
            )
            return cur.fetchall()
    finally:
        conn.close()
def _format_result(rank, row):
    """Render one ``(id, title, overview, sim, mu, sigma)`` row as printable lines."""
    movie_id, title, overview, sim, mu, sigma = row
    # Guard against a degenerate (zero-spread) distribution.
    z_score = (sim - mu) / sigma if sigma > 0 else 0
    if overview and len(overview) > 200:
        overview_line = f" Overview: {overview[:200]}..."
    else:
        overview_line = f" Overview: {overview}"
    return "\n".join([
        f"{rank}. {title} (ID: {movie_id})",
        f" Similarity: {sim:.4f}, Z-score: {z_score:.2f}",
        overview_line,
    ])


def test_zscore_search():
    """Test the z-score search with a sample query.

    Runs a live query via ``zscore_search`` and prints the top results and,
    when there are more than five, the remaining tail (at most five rows).
    """
    print("Testing z-score calibrated search...")
    # Test search for "love" movies
    results = zscore_search("love emotions relationships", z_threshold=2.5, top_k=100000)
    if not results:
        print("No movies found above threshold")
        return

    print(f"Found {len(results)} movies above z-score threshold")
    # Every row carries the same mu/sigma, so the first row is representative.
    print(f"Sample statistics - Mean: {results[0][4]:.4f}, StdDev: {results[0][5]:.4f}")
    print("\n" + "="*80)
    print("TOP 5 RESULTS:")
    print("="*80)
    for rank, row in enumerate(results[:5], start=1):
        print(_format_result(rank, row))
        print()

    if len(results) > 5:
        # Start the tail no earlier than row 6 so that short result lists
        # (6-9 rows) do not re-print rows already shown in the top 5.
        tail_start = max(5, len(results) - 5)
        tail = results[tail_start:]
        print("\n" + "="*80)
        print(f"BOTTOM {len(tail)} RESULTS:")
        print("="*80)
        for rank, row in enumerate(tail, start=tail_start + 1):
            print(_format_result(rank, row))
            print()
if __name__ == "__main__":
    # Executed only when run as a script (not on import): performs the live
    # smoke test, which needs a reachable database.
    test_zscore_search()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment