@pszemraj
Last active January 22, 2025 17:18
Basic RAG search system for retrieving top-k posts with gte-modernbert-base, plus cross-encoder reranking
import logging
from typing import Dict, List

import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import CrossEncoder, SentenceTransformer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class RedditRAG:
    def __init__(self, model_max_length: int = None, rerank_top_m: int = 20):
        """Initialize the RAG system with FAISS, GTE ModernBERT, and reranking capabilities.

        Args:
            model_max_length: Maximum sequence length for the encoder
            rerank_top_m: Number of initial results to consider for reranking (M)
        """
        # Initialize embedding model
        self.encoder = SentenceTransformer(
            "Alibaba-NLP/gte-modernbert-base",
            tokenizer_kwargs={"model_max_length": model_max_length},
        )
        # Initialize reranker
        self.reranker = CrossEncoder(
            "Alibaba-NLP/gte-reranker-modernbert-base",
            automodel_args={"torch_dtype": "auto"},
        )
        self.rerank_top_m = rerank_top_m
        self.index = None
        self.metadata = []
        self.documents = []  # Store original documents for reranking

    def prepare_data(self, dataset, max_rows=None):
        """Process dataset and prepare content for embedding."""
        if max_rows is not None:
            dataset = dataset.select(range(max_rows))

        def combine_text(example):
            """Combine title and selftext."""
            selftext = example["selftext"] if example["selftext"] else ""
            example["content"] = f"{example['title']}\n\n{selftext}".strip()
            return example

        def standard_permalink(example):
            """Check and update relevant permalinks."""
            base_link = example["permalink"]
            if base_link is None or len(base_link.strip()) < 2:
                return {"permalink": None}
            return {
                "permalink": (
                    f"https://www.reddit.com{base_link}"
                    if base_link.startswith("/r/")
                    else base_link
                )
            }

        dataset = dataset.map(combine_text, desc="Combining title and selftext")
        dataset = dataset.map(standard_permalink, desc="Fixing permalinks")
        logger.info(f"Prepared dataset with {len(dataset)} entries")
        return dataset

    def build_index(self, dataset, batch_size: int = 32):
        """Build FAISS index from dataset embeddings."""
        logger.info("Computing embeddings...")
        self.documents = dataset["content"]
        embeddings = self.encoder.encode(
            self.documents,
            batch_size=batch_size,
            normalize_embeddings=True,
            show_progress_bar=True,
        )
        # Store metadata
        self.metadata = [
            {"title": title, "score": int(score), "permalink": url}
            for title, score, url in zip(
                dataset["title"], dataset["score"], dataset["permalink"]
            )
        ]
        # Initialize and populate FAISS index (inner product over L2-normalized
        # vectors is equivalent to cosine similarity)
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)
        # Normalize vectors for cosine similarity
        faiss.normalize_L2(embeddings)
        # Add vectors to index
        self.index.add(embeddings)
        logger.info(f"Built index with {len(dataset)} vectors")

    def rerank_results(
        self, query: str, initial_results: List[Dict], k: int
    ) -> List[Dict]:
        """Rerank the initial results using the cross-encoder model."""
        # Prepare (query, document) pairs for reranking
        pairs = [
            [query, self.documents[result["index"]]] for result in initial_results
        ]
        # Get reranking scores
        rerank_scores = self.reranker.predict(pairs)
        # Add reranking scores to results
        for score, result in zip(rerank_scores, initial_results):
            result["rerank_score"] = float(score)
        # Sort by reranking score and keep the top k
        reranked_results = sorted(
            initial_results, key=lambda x: x["rerank_score"], reverse=True
        )[:k]
        return reranked_results

    def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
        """Query the FAISS index, rerank results, and return the most relevant posts."""
        if self.index is None:
            raise ValueError("Index not built. Call build_index first.")
        # Encode and normalize the query
        query_embedding = self.encoder.encode([query_text])
        faiss.normalize_L2(query_embedding)
        # Search the index - retrieve more results than needed so the reranker has candidates
        scores, indices = self.index.search(query_embedding, self.rerank_top_m)
        # Format initial results
        initial_results = [
            {
                **self.metadata[idx],
                "similarity_score": float(score),
                "index": int(idx),  # Store index for accessing the original document
            }
            for score, idx in zip(scores[0], indices[0])
        ]
        # Rerank and keep the top n_results
        final_results = self.rerank_results(query_text, initial_results, n_results)
        # Remove the internal index field as it's no longer needed
        for result in final_results:
            del result["index"]
        return final_results


def initialize_rag_system(
    dataset_name="pszemraj/LocalLLaMA-posts",
    max_rows: int = None,
    model_max_length: int = None,
    rerank_top_m: int = 20,
):
    """Initialize and populate the RAG system with the dataset."""
    logger.info(f"Loading dataset: {dataset_name}")
    ds = load_dataset(dataset_name)
    logger.info(f"Dataset loaded with {len(ds['train'])} entries")
    # Initialize RAG system
    rag = RedditRAG(model_max_length=model_max_length, rerank_top_m=rerank_top_m)
    # Prepare and index data
    processed_ds = rag.prepare_data(ds["train"], max_rows)
    rag.build_index(processed_ds)
    return rag


# Example usage:
if __name__ == "__main__":
    # Initialize with reranking of top 20 results
    rag = initialize_rag_system(max_rows=1000, rerank_top_m=20)

    my_query = "How to improve language model training?"
    results = rag.query(my_query, n_results=5)  # Get top 5 after reranking

    print(f"Query: {my_query}")
    print("-" * 80, "\n")
    for i, result in enumerate(results, start=1):
        print(f"RESULT {i}:")
        print(f"\tTitle: {result['title']}")
        print(f"\tReddit Score: {result['score']}")
        print(f"\tURL: {result['permalink']}")
        print(f"\tInitial similarity: {result['similarity_score']:.3f}")
        print(f"\tReranking score: {result['rerank_score']:.3f}")
        print("-" * 80, "\n")
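The script is a two-stage retrieve-then-rerank pipeline: the bi-encoder plus FAISS pulls rerank_top_m candidates cheaply, then the cross-encoder rescores only those candidates for the final ordering. Besides running it directly, it can also be imported from another script; a minimal usage sketch (the module name rag_search.py and the query string are illustrative assumptions, not part of the gist):

# minimal usage sketch -- assumes the script above is saved as rag_search.py
from rag_search import initialize_rag_system

# build the index over the first 1000 posts, keep 20 candidates for reranking
rag = initialize_rag_system(max_rows=1000, rerank_top_m=20)

for hit in rag.query("best quantization for local inference", n_results=3):
    print(hit["title"], round(hit["rerank_score"], 3))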
datasets
faiss-cpu
numpy
sentence-transformers
transformers>=4.48.0
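Not part of the gist, but if rebuilding the index on every run becomes slow, the FAISS index (together with the documents and metadata the reranker needs) could be persisted to disk; a rough sketch using faiss.write_index / faiss.read_index, with hypothetical file paths:

# hypothetical persistence sketch -- helper names and paths are assumptions
import json
import faiss

def save_rag_index(rag, index_path="posts.faiss", meta_path="posts_meta.json"):
    """Write the FAISS index plus the documents/metadata used for reranking."""
    faiss.write_index(rag.index, index_path)
    with open(meta_path, "w") as f:
        json.dump({"documents": list(rag.documents), "metadata": rag.metadata}, f)

def load_rag_index(rag, index_path="posts.faiss", meta_path="posts_meta.json"):
    """Restore a previously built index into an existing RedditRAG instance."""
    rag.index = faiss.read_index(index_path)
    with open(meta_path) as f:
        saved = json.load(f)
    rag.documents = saved["documents"]
    rag.metadata = saved["metadata"]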
@pszemraj (Author)
example output

Query: How to improve language model training for waifu roleplay?
-------------------------------------------------------------------------------- 

RESULT 1:
	Title: Training Data Preparation (Instruction Fields)
	Reddit Score: 6
	URL: https://www.reddit.com/r/LocalLLaMA/comments/13g38hk/training_data_preparation_instruction_fields/
	Cos sim: 0.594
-------------------------------------------------------------------------------- 

RESULT 2:
	Title: Finetuning to beat ChatGPT: It's all about communication & management, these are already solved problems
	Reddit Score: 32
	URL: https://www.reddit.com/r/LocalLLaMA/comments/120e7m7/finetuning_to_beat_chatgpt_its_all_about/
	Cos sim: 0.576
-------------------------------------------------------------------------------- 

RESULT 3:
	Title: How do I write a role-play prompt for instruct-style models?
	Reddit Score: 9
	URL: https://www.reddit.com/r/LocalLLaMA/comments/12ma64h/how_do_i_write_a_roleplay_prompt_for/
	Cos sim: 0.565
-------------------------------------------------------------------------------- 

RESULT 4:
	Title: Does anyone else have the problem that their model is forgetting its initial prompt
	Reddit Score: 11
	URL: https://www.reddit.com/r/LocalLLaMA/comments/12qyeij/does_anyone_else_have_the_problem_that_their/
	Cos sim: 0.563
-------------------------------------------------------------------------------- 

RESULT 5:
	Title: The state of LLM AIs, as explained by somebody who doesn't actually understand LLM AIs
	Reddit Score: 86
	URL: https://www.reddit.com/r/LocalLLaMA/comments/12ld62s/the_state_of_llm_ais_as_explained_by_somebody_who/
	Cos sim: 0.559
--------------------------------------------------------------------------------

@pszemraj (Author)
with gte + reranker

Batches: 100% 32/32 [00:03<00:00, 31.90it/s]
Query: How to improve language model training?
-------------------------------------------------------------------------------- 

RESULT 1:
	Title: Engineering training for better quality
	Reddit Score: 3
	URL: https://www.reddit.com/r/LocalLLaMA/comments/13ejbwx/engineering_training_for_better_quality/
	Initial similarity: 0.729
	Reranking score: 0.909
-------------------------------------------------------------------------------- 

RESULT 2:
	Title: Training a model with larger context
	Reddit Score: 4
	URL: https://www.reddit.com/r/LocalLLaMA/comments/13gzwmv/training_a_model_with_larger_context/
	Initial similarity: 0.761
	Reranking score: 0.906
-------------------------------------------------------------------------------- 

RESULT 3:
	Title: How to improve the quality of Large Language Models and solve the alignment problem
	Reddit Score: 8
	URL: https://www.reddit.com/r/LocalLLaMA/comments/139iyl2/how_to_improve_the_quality_of_large_language/
	Initial similarity: 0.795
	Reranking score: 0.904
-------------------------------------------------------------------------------- 

RESULT 4:
	Title: Is it possible to use llama.cpp or create Alpaca Lora for YALM-100b model?
	Reddit Score: 14
	URL: https://www.reddit.com/r/LocalLLaMA/comments/12q6288/is_it_possible_to_use_llamacpp_or_create_alpaca/
	Initial similarity: 0.731
	Reranking score: 0.887
-------------------------------------------------------------------------------- 

RESULT 5:
	Title: Finetuning to beat ChatGPT: It's all about communication & management, these are already solved problems
	Reddit Score: 32
	URL: https://www.reddit.com/r/LocalLLaMA/comments/120e7m7/finetuning_to_beat_chatgpt_its_all_about/
	Initial similarity: 0.747
	Reranking score: 0.864
-------------------------------------------------------------------------------- 
