Basic RAG search system that returns the top-k posts from r/LocalLLaMA, using gte-modernbert-base embeddings, FAISS retrieval, and the gte-reranker-modernbert-base cross-encoder for reranking.
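In short: posts are embedded with gte-modernbert-base, the top-M candidates are retrieved from a FAISS inner-product index, and the cross-encoder picks the final top-k. A minimal usage sketch, mirroring the __main__ block in the script below (the query string here is just an illustration):

rag = initialize_rag_system(max_rows=1000, rerank_top_m=20)
for hit in rag.query("quantization quality tradeoffs", n_results=3):
    print(hit["title"], round(hit["rerank_score"], 3))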
import logging
from typing import Dict, List

import faiss
from datasets import load_dataset
from sentence_transformers import CrossEncoder, SentenceTransformer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class RedditRAG:
    def __init__(self, model_max_length: int = None, rerank_top_m: int = 20):
        """Initialize the RAG system with FAISS, GTE ModernBERT, and reranking capabilities.

        Args:
            model_max_length: Maximum sequence length for the encoder
            rerank_top_m: Number of initial results to consider for reranking (M)
        """
        # Initialize embedding model (only override the tokenizer length if one was given)
        self.encoder = SentenceTransformer(
            "Alibaba-NLP/gte-modernbert-base",
            tokenizer_kwargs=(
                {"model_max_length": model_max_length}
                if model_max_length is not None
                else None
            ),
        )
        # Initialize reranker
        self.reranker = CrossEncoder(
            "Alibaba-NLP/gte-reranker-modernbert-base",
            automodel_args={"torch_dtype": "auto"},
        )
        self.rerank_top_m = rerank_top_m
        self.index = None
        self.metadata = []
        self.documents = []  # Store original documents for reranking

    def prepare_data(self, dataset, max_rows=None):
        """Process dataset and prepare content for embedding."""
        if max_rows is not None:
            dataset = dataset.select(range(max_rows))

        def combine_text(example):
            """Combine title and selftext."""
            selftext = example["selftext"] if example["selftext"] else ""
            example["content"] = f"{example['title']}\n\n{selftext}".strip()
            return example

        def standard_permalink(example):
            """Check and update relevant permalinks."""
            base_link = example["permalink"]
            if base_link is None or len(base_link.strip()) < 2:
                return {"permalink": None}
            return {
                "permalink": (
                    f"https://www.reddit.com{base_link}"
                    if base_link.startswith("/r/")
                    else base_link
                )
            }

        dataset = dataset.map(combine_text, desc="Combining title and selftext")
        dataset = dataset.map(standard_permalink, desc="Fixing permalinks")
        logger.info(f"Prepared dataset with {len(dataset)} entries")
        return dataset

    def build_index(self, dataset, batch_size: int = 32):
        """Build FAISS index from dataset embeddings."""
        logger.info("Computing embeddings...")
        self.documents = dataset["content"]
        embeddings = self.encoder.encode(
            self.documents,
            batch_size=batch_size,
            normalize_embeddings=True,
            show_progress_bar=True,
        )

        # Store metadata
        self.metadata = [
            {"title": title, "score": int(score), "permalink": url}
            for title, score, url in zip(
                dataset["title"], dataset["score"], dataset["permalink"]
            )
        ]

        # Initialize and populate FAISS index (inner product == cosine on normalized vectors)
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)

        # Normalize vectors for cosine similarity
        faiss.normalize_L2(embeddings)

        # Add vectors to index
        self.index.add(embeddings)
        logger.info(f"Built index with {len(dataset)} vectors")

    def rerank_results(
        self, query: str, initial_results: List[Dict], k: int
    ) -> List[Dict]:
        """Rerank the initial results using the cross-encoder model."""
        # Prepare (query, document) pairs for the cross-encoder
        pairs = [
            [query, self.documents[result["index"]]] for result in initial_results
        ]

        # Get reranking scores
        rerank_scores = self.reranker.predict(pairs)

        # Add reranking scores to results
        for score, result in zip(rerank_scores, initial_results):
            result["rerank_score"] = float(score)

        # Sort by reranking score and keep the top k
        reranked_results = sorted(
            initial_results, key=lambda x: x["rerank_score"], reverse=True
        )[:k]
        return reranked_results

    def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
        """Query the FAISS index, rerank results, and return the most relevant posts."""
        if self.index is None:
            raise ValueError("Index not built. Call build_index first.")

        # Encode and normalize query
        query_embedding = self.encoder.encode([query_text])
        faiss.normalize_L2(query_embedding)

        # Search index - retrieve more results than needed so the reranker has candidates
        scores, indices = self.index.search(query_embedding, self.rerank_top_m)

        # Format initial results
        initial_results = [
            {
                **self.metadata[idx],
                "similarity_score": float(score),
                "index": int(idx),  # Store index for accessing original document
            }
            for score, idx in zip(scores[0], indices[0])
        ]

        # Rerank results
        final_results = self.rerank_results(query_text, initial_results, n_results)

        # Remove the index from final results as it's no longer needed
        for result in final_results:
            del result["index"]

        return final_results


def initialize_rag_system(
    dataset_name="pszemraj/LocalLLaMA-posts",
    max_rows: int = None,
    model_max_length: int = None,
    rerank_top_m: int = 20,
):
    """Initialize and populate the RAG system with the dataset."""
    logger.info(f"Loading dataset: {dataset_name}")
    ds = load_dataset(dataset_name)
    logger.info(f"Dataset loaded with {len(ds['train'])} entries")

    # Initialize RAG system
    rag = RedditRAG(model_max_length=model_max_length, rerank_top_m=rerank_top_m)

    # Prepare and index data
    processed_ds = rag.prepare_data(ds["train"], max_rows)
    rag.build_index(processed_ds)

    return rag


# Example usage:
if __name__ == "__main__":
    # Initialize with reranking of the top 20 retrieved results
    rag = initialize_rag_system(max_rows=1000, rerank_top_m=20)

    my_query = "How to improve language model training?"
    results = rag.query(my_query, n_results=5)  # Get top 5 after reranking

    print(f"Query: {my_query}")
    print("-" * 80, "\n")
    for i, result in enumerate(results, start=1):
        print(f"RESULT {i}:")
        print(f"\tTitle: {result['title']}")
        print(f"\tReddit Score: {result['score']}")
        print(f"\tURL: {result['permalink']}")
        print(f"\tInitial similarity: {result['similarity_score']:.3f}")
        print(f"\tReranking score: {result['rerank_score']:.3f}")
        print("-" * 80, "\n")
Requirements:
datasets
faiss-cpu
numpy
sentence-transformers
transformers>=4.48.0  # ModernBERT support requires transformers >= 4.48
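A quick sanity check that the environment can actually load both models (the (1, 768) shape is what gte-modernbert-base should report; double-check against the model card):

from sentence_transformers import CrossEncoder, SentenceTransformer

encoder = SentenceTransformer("Alibaba-NLP/gte-modernbert-base")
reranker = CrossEncoder("Alibaba-NLP/gte-reranker-modernbert-base")
print(encoder.encode(["hello world"]).shape)  # expected: (1, 768)
print(reranker.predict([["a query", "a candidate document"]]))  # one relevance score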
Example output (with gte + reranker):

Batches: 100% 32/32 [00:03<00:00, 31.90it/s]
Query: How to improve language model training?
--------------------------------------------------------------------------------
RESULT 1:
Title: Engineering training for better quality
Reddit Score: 3
URL: https://www.reddit.com/r/LocalLLaMA/comments/13ejbwx/engineering_training_for_better_quality/
Initial similarity: 0.729
Reranking score: 0.909
--------------------------------------------------------------------------------
RESULT 2:
Title: Training a model with larger context
Reddit Score: 4
URL: https://www.reddit.com/r/LocalLLaMA/comments/13gzwmv/training_a_model_with_larger_context/
Initial similarity: 0.761
Reranking score: 0.906
--------------------------------------------------------------------------------
RESULT 3:
Title: How to improve the quality of Large Language Models and solve the alignment problem
Reddit Score: 8
URL: https://www.reddit.com/r/LocalLLaMA/comments/139iyl2/how_to_improve_the_quality_of_large_language/
Initial similarity: 0.795
Reranking score: 0.904
--------------------------------------------------------------------------------
RESULT 4:
Title: Is it possible to use llama.cpp or create Alpaca Lora for YALM-100b model?
Reddit Score: 14
URL: https://www.reddit.com/r/LocalLLaMA/comments/12q6288/is_it_possible_to_use_llamacpp_or_create_alpaca/
Initial similarity: 0.731
Reranking score: 0.887
--------------------------------------------------------------------------------
RESULT 5:
Title: Finetuning to beat ChatGPT: It's all about communication & management, these are already solved problems
Reddit Score: 32
URL: https://www.reddit.com/r/LocalLLaMA/comments/120e7m7/finetuning_to_beat_chatgpt_its_all_about/
Initial similarity: 0.747
Reranking score: 0.864
--------------------------------------------------------------------------------