import re

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


class InfiniRetri:
    """Attention-based retrieval over long documents: read the document chunk
    by chunk alongside the query, and carry forward the sentences the model
    attends to most."""

    def __init__(self, model_name, chunk_size=512, top_k=50, phrase_window=5, device=None):
        self.chunk_size = chunk_size        # max tokens per document chunk
        self.top_k = top_k                  # context tokens kept per step
        self.phrase_window = phrase_window  # smoothing window for token importance
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model {model_name} on {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Eager attention is required so that output_attentions returns real weights.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            output_attentions=True,
            attn_implementation="eager",
        ).to(self.device)
        self.model.eval()
        self.reset_cache()

    def reset_cache(self):
        # Sentences carried over from previously processed chunks.
        self.cache = []

    def chunk_document(self, document):
        # Split on sentence boundaries, then pack sentences greedily into
        # chunks of at most `chunk_size` tokens.
        sentences = re.split(r'(?<=[.!?])\s+', document)
        sentences = [s for s in sentences if s.strip()]
        chunks = []
        current_chunk = []
        current_length = 0
        for sentence in sentences:
            token_count = len(self.tokenizer.encode(sentence, add_special_tokens=False))
            if current_length + token_count > self.chunk_size and current_chunk:
                chunks.append({
                    "text": " ".join(current_chunk),
                    "sentences": current_chunk,
                })
                current_chunk = [sentence]
                current_length = token_count
            else:
                current_chunk.append(sentence)
                current_length += token_count
        if current_chunk:
            chunks.append({
                "text": " ".join(current_chunk),
                "sentences": current_chunk,
            })
        print(f"Split document into {len(chunks)} chunks")
        return chunks

    def get_attention_scores(self, query, context):
        # Run query + context through the model and return the head-averaged
        # attention matrix of the last layer, shape (seq_len, seq_len).
        full_text = f"{query}\n\n{context}"
        inputs = self.tokenizer(full_text, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)
        # attentions[-1] has shape (batch, heads, seq_len, seq_len).
        last_layer_attention = outputs.attentions[-1]
        avg_attention = last_layer_attention.mean(dim=1).squeeze(0)
        # Approximate the query span length by re-tokenizing the query alone;
        # the "\n\n" separator tokens are attributed to the context side.
        query_length = len(self.tokenizer.encode(query, add_special_tokens=False))
        return avg_attention, query_length

    def calculate_token_importance(self, attention_matrix, query_length):
        # How much attention each context token receives from the query tokens.
        query_to_context = attention_matrix[:query_length, query_length:]
        token_importance = query_to_context.sum(dim=0)
        # Smooth with a sliding window so that short phrases score as a unit.
        context_length = token_importance.size(0)
        phrase_importance = torch.zeros(context_length, device=self.device)
        for i in range(context_length):
            start = max(0, i - self.phrase_window // 2)
            end = min(context_length, i + self.phrase_window // 2 + 1)
            phrase_importance[i] = token_importance[start:end].sum()
        return phrase_importance
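
    # Alternative sketch (not called anywhere above): for an odd phrase_window,
    # the loop in calculate_token_importance is equivalent to a 1-D convolution
    # with a ones kernel and zero padding, which is much faster on long contexts.
    def calculate_token_importance_conv(self, attention_matrix, query_length):
        import torch.nn.functional as F  # local import keeps the sketch self-contained

        token_importance = attention_matrix[:query_length, query_length:].sum(dim=0)
        kernel = torch.ones(1, 1, self.phrase_window, device=token_importance.device)
        # Zero padding mirrors the clipped window sums at the sequence edges.
        return F.conv1d(
            token_importance.view(1, 1, -1), kernel, padding=self.phrase_window // 2
        ).view(-1)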

    def identify_important_sentences(self, context, importance_scores, query_length):
        # Pick the top-k most attended context tokens (or all of them if there
        # are fewer than k), then map each token back to its source sentence.
        if len(importance_scores) > self.top_k:
            _, top_indices = torch.topk(importance_scores, self.top_k)
            top_indices = top_indices.cpu().numpy()
        else:
            top_indices = torch.arange(len(importance_scores)).cpu().numpy()
        context_tokens = self.tokenizer.encode(context, add_special_tokens=False)
        sentences = re.split(r'(?<=[.!?])\s+', context)
        sentences = [s for s in sentences if s.strip()]
        # Approximate token -> sentence mapping by re-tokenizing each sentence.
        # This can drift by a token or two around whitespace, which is tolerable
        # because whole sentences are retrieved anyway.
        token_to_sentence = {}
        position = 0
        for i, sentence in enumerate(sentences):
            sent_tokens = self.tokenizer.encode(sentence, add_special_tokens=False)
            for _ in range(len(sent_tokens)):
                if position < len(context_tokens):
                    token_to_sentence[position] = i
                position += 1
        important_sentence_indices = set()
        for token_idx in top_indices:
            if token_idx in token_to_sentence:
                important_sentence_indices.add(token_to_sentence[token_idx])
        # Return the selected sentences in document order.
        return [sentences[idx] for idx in sorted(important_sentence_indices) if idx < len(sentences)]

    def process_document(self, document, query):
        # Slide over the document chunk by chunk. At each step the model reads
        # the cached sentences from earlier chunks plus the current chunk, and
        # the cache is replaced by the sentences it attends to most.
        self.reset_cache()
        chunks = self.chunk_document(document)
        all_important_sentences = []
        for i, chunk in enumerate(tqdm(chunks, desc="Processing chunks")):
            print(f"\nProcessing chunk {i+1}/{len(chunks)}")
            if self.cache:
                merged_context = " ".join(self.cache + [chunk["text"]])
            else:
                merged_context = chunk["text"]
            # Guard against contexts the model cannot fit. Truncate by tokens,
            # not characters, and leave room for the query and the separator.
            budget = 1024 - len(self.tokenizer.encode(query, add_special_tokens=False)) - 2
            merged_tokens = self.tokenizer.encode(merged_context, add_special_tokens=False)
            if len(merged_tokens) > budget:
                print("Warning: Merged context too large, truncating...")
                merged_context = self.tokenizer.decode(merged_tokens[:budget])
            attention_matrix, query_length = self.get_attention_scores(query, merged_context)
            importance_scores = self.calculate_token_importance(attention_matrix, query_length)
            important_sentences = self.identify_important_sentences(
                merged_context, importance_scores, query_length
            )
            print(f"Found {len(important_sentences)} important sentences")
            self.cache = important_sentences
            all_important_sentences.extend(important_sentences)
        # De-duplicate while keeping first-seen order.
        seen = set()
        return [s for s in all_important_sentences if not (s in seen or seen.add(s))]
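
    # A minimal end-to-end sketch of feeding the retrieved sentences back to
    # the same model to answer the query. The prompt template and generation
    # settings are illustrative assumptions, not part of the retrieval method.
    def answer(self, document, query, max_new_tokens=50):
        retrieved = self.process_document(document, query)
        prompt = " ".join(retrieved) + f"\n\nQuestion: {query}\nAnswer:"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,  # GPT-2 has no pad token
            )
        # Return only the newly generated continuation, not the prompt.
        return self.tokenizer.decode(
            output_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )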


def test_with_made_up_fact():
    # Needle-in-a-haystack test: bury one made-up fact in filler text and
    # check that it comes back among the retrieved sentences.
    print("Testing InfiniRetri with a made-up fact...")
    retriever = InfiniRetri(
        model_name="gpt2",
        chunk_size=256,
        top_k=5,
        phrase_window=3,
    )
    needle = "The fictitious city of Zorbania is the capital of the imaginary country Fantasia, which has a population of 5.7 million people."
    filler_sentences = [
        "Countries around the world have diverse cultures and histories.",
        "Some nations are known for their mountainous landscapes.",
        "Others are famous for their coastal areas and beaches.",
        "Political systems vary widely across different regions.",
        "Economic development follows different patterns globally.",
        "Educational systems are structured differently in various countries.",
        "Cultural heritage is preserved through numerous institutions.",
        "Technological advancement varies by region and country.",
        "Agricultural practices depend on climate and geography.",
        "Tourism is an important industry for many nations.",
    ]
    document_parts = []
    for i in range(30):
        document_parts.append(filler_sentences[i % len(filler_sentences)])
    document_parts.extend(filler_sentences[:5])
    document_parts.append(needle)
    document_parts.extend(filler_sentences[5:])
    for i in range(30):
        document_parts.append(filler_sentences[(i + 5) % len(filler_sentences)])
    full_document = " ".join(document_parts)
    query = "What is the capital of Fantasia?"
    retrieved_sentences = retriever.process_document(full_document, query)
    print("\nQuery:", query)
    print("\nRetrieved Sentences:")
    for i, sentence in enumerate(retrieved_sentences):
        print(f"{i+1}. {sentence}")
    if needle in retrieved_sentences:
        print("\nSUCCESS: The needle was found in the retrieved sentences!")
    else:
        print("\nFAILURE: The needle was not found in the retrieved sentences.")
        print("\nNeedle:", needle)


if __name__ == "__main__":
    test_with_made_up_fact()
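
# Minimal usage sketch for your own data (file name and question are placeholders):
#
#     retriever = InfiniRetri(model_name="gpt2", chunk_size=512, top_k=50)
#     with open("my_long_document.txt") as f:
#         sentences = retriever.process_document(f.read(), "What is X?")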