@Dhravya
Created March 1, 2025 22:14
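
"""Rough single-file sketch of InfiniRetri-style, attention-based retrieval.

The document is split into token-budgeted chunks. Each chunk is read together with
the query and a small cache of sentences kept from earlier chunks; query-to-context
attention from the model's last layer scores the context tokens, and the sentences
containing the top-scoring tokens become the cache carried into the next chunk.
"""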
import torch
import re
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict, Tuple
from tqdm import tqdm


class InfiniRetri:
    def __init__(self, model_name, chunk_size=512, top_k=50, phrase_window=5, device=None):
        self.chunk_size = chunk_size
        self.top_k = top_k
        self.phrase_window = phrase_window
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model {model_name} on {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Eager attention is required so the forward pass actually returns attention weights.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            output_attentions=True,
            attn_implementation="eager",
        ).to(self.device)
        self.model.eval()
        self.reset_cache()

    def reset_cache(self):
        self.cache = []

    def chunk_document(self, document):
        # Split on sentence boundaries, then pack sentences into chunks of at most chunk_size tokens.
        sentences = re.split(r'(?<=[.!?])\s+', document)
        sentences = [s for s in sentences if s.strip()]
        chunks = []
        current_chunk = []
        current_length = 0
        for sentence in sentences:
            tokens = self.tokenizer.encode(sentence, add_special_tokens=False)
            token_count = len(tokens)
            if current_length + token_count > self.chunk_size and current_chunk:
                chunks.append({
                    "text": " ".join(current_chunk),
                    "sentences": current_chunk,
                })
                current_chunk = [sentence]
                current_length = token_count
            else:
                current_chunk.append(sentence)
                current_length += token_count
        if current_chunk:
            chunks.append({
                "text": " ".join(current_chunk),
                "sentences": current_chunk,
            })
        print(f"Split document into {len(chunks)} chunks")
        return chunks

    def get_attention_scores(self, query, context):
        full_text = f"{query}\n\n{context}"
        # Skip special tokens so positions line up with the plain token counts used below.
        inputs = self.tokenizer(full_text, return_tensors="pt", add_special_tokens=False).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)
        # Average the last layer's attention over heads -> (seq_len, seq_len).
        last_layer_attention = outputs.attentions[-1]
        avg_attention = last_layer_attention.mean(dim=1).squeeze(0)
        # Count the query prefix including the "\n\n" separator so the remaining
        # columns of the attention matrix correspond to the context tokens.
        query_length = len(self.tokenizer.encode(f"{query}\n\n", add_special_tokens=False))
        return avg_attention, query_length

    def calculate_token_importance(self, attention_matrix, query_length):
        # Attention paid by each query token to each context token.
        query_to_context = attention_matrix[:query_length, query_length:]
        token_importance = query_to_context.sum(dim=0)
        # Smooth per-token scores over a sliding window so whole phrases score highly.
        context_length = token_importance.size(0)
        phrase_importance = torch.zeros(context_length, device=self.device)
        for i in range(context_length):
            start = max(0, i - self.phrase_window // 2)
            end = min(context_length, i + self.phrase_window // 2 + 1)
            phrase_importance[i] = token_importance[start:end].sum()
        return phrase_importance

    def identify_important_sentences(self, context, importance_scores, query_length):
        # Pick the top-k scoring context token positions.
        if len(importance_scores) > self.top_k:
            _, top_indices = torch.topk(importance_scores, self.top_k)
            top_indices = top_indices.cpu().numpy()
        else:
            top_indices = torch.arange(len(importance_scores)).cpu().numpy()
        context_tokens = self.tokenizer.encode(context, add_special_tokens=False)
        sentences = re.split(r'(?<=[.!?])\s+', context)
        sentences = [s for s in sentences if s.strip()]
        # Map each context token position to the index of the sentence containing it.
        token_to_sentence = {}
        position = 0
        for i, sentence in enumerate(sentences):
            sent_tokens = self.tokenizer.encode(sentence, add_special_tokens=False)
            for _ in range(len(sent_tokens)):
                if position < len(context_tokens):
                    token_to_sentence[position] = i
                position += 1
        important_sentence_indices = set()
        for token_idx in top_indices:
            token_idx = int(token_idx)
            if token_idx in token_to_sentence:
                important_sentence_indices.add(token_to_sentence[token_idx])
        # Keep document order when reassembling the retrieved sentences.
        important_sentences = [sentences[idx] for idx in sorted(important_sentence_indices) if idx < len(sentences)]
        return important_sentences

    def process_document(self, document, query):
        self.reset_cache()
        chunks = self.chunk_document(document)
        all_important_sentences = []
        for i, chunk in enumerate(tqdm(chunks, desc="Processing chunks")):
            print(f"\nProcessing chunk {i+1}/{len(chunks)}")
            # Merge sentences cached from earlier chunks with the current chunk.
            if self.cache:
                merged_context = " ".join(self.cache + [chunk["text"]])
            else:
                merged_context = chunk["text"]
            # Keep query + separator + context within the model's positional limit,
            # truncating by tokens rather than characters.
            max_len = getattr(self.model.config, "max_position_embeddings", 1024)
            budget = max_len - len(self.tokenizer.encode(f"{query}\n\n", add_special_tokens=False))
            context_tokens = self.tokenizer.encode(merged_context, add_special_tokens=False)
            if len(context_tokens) > budget:
                print("Warning: Merged context too large, truncating...")
                merged_context = self.tokenizer.decode(context_tokens[:budget])
            attention_matrix, query_length = self.get_attention_scores(query, merged_context)
            importance_scores = self.calculate_token_importance(attention_matrix, query_length)
            important_sentences = self.identify_important_sentences(
                merged_context, importance_scores, query_length
            )
            print(f"Found {len(important_sentences)} important sentences")
            # The retained sentences become the cache carried into the next chunk.
            self.cache = important_sentences
            all_important_sentences.extend(important_sentences)
        # Deduplicate while preserving first-seen order.
        seen = set()
        unique_sentences = [s for s in all_important_sentences if not (s in seen or seen.add(s))]
        return unique_sentences


def test_with_made_up_fact():
    print("Testing InfiniRetri with a made-up fact...")
    retriever = InfiniRetri(
        model_name="gpt2",
        chunk_size=256,
        top_k=5,
        phrase_window=3,
    )
    needle = "The fictitious city of Zorbania is the capital of the imaginary country Fantasia, which has a population of 5.7 million people."
    filler_sentences = [
        "Countries around the world have diverse cultures and histories.",
        "Some nations are known for their mountainous landscapes.",
        "Others are famous for their coastal areas and beaches.",
        "Political systems vary widely across different regions.",
        "Economic development follows different patterns globally.",
        "Educational systems are structured differently in various countries.",
        "Cultural heritage is preserved through numerous institutions.",
        "Technological advancement varies by region and country.",
        "Agricultural practices depend on climate and geography.",
        "Tourism is an important industry for many nations.",
    ]
    # Bury the needle in the middle of a long run of filler sentences.
    document_parts = []
    for i in range(30):
        document_parts.append(filler_sentences[i % len(filler_sentences)])
    filler_before = filler_sentences[:5]
    filler_after = filler_sentences[5:]
    document_parts.extend(filler_before)
    document_parts.append(needle)
    document_parts.extend(filler_after)
    for i in range(30):
        document_parts.append(filler_sentences[(i + 5) % len(filler_sentences)])
    full_document = " ".join(document_parts)
    query = "What is the capital of Fantasia?"
    retrieved_sentences = retriever.process_document(full_document, query)
    print("\nQuery:", query)
    print("\nRetrieved Sentences:")
    for i, sentence in enumerate(retrieved_sentences):
        print(f"{i+1}. {sentence}")
    if needle in retrieved_sentences:
        print("\nSUCCESS: The needle was found in the retrieved sentences!")
    else:
        print("\nFAILURE: The needle was not found in the retrieved sentences.")
        print("\nNeedle:", needle)


if __name__ == "__main__":
    test_with_made_up_fact()