A Retrieval-Augmented Generation (RAG) proof-of-concept that uses Dense Passage Retrieval (DPR) and GPT-2 to answer questions from a document.
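The script expects its source document as plain text with one paragraph per line, since retrieval works at paragraph granularity. The snippet below is only a hypothetical illustration of that format (the real Confluence article is not part of this gist); the full script follows.

# Hypothetical example of the expected input format: one paragraph per line.
# The file contents here are invented purely for illustration.
sample_article = (
    "Which Node.js Version Can I Use? New services should target the current LTS release.\n"
    "Older Node.js versions are supported on a best-effort basis and may be dropped without notice.\n"
    "Ask the platform team before pinning a non-LTS version in production.\n"
)
with open('an-article-from-confluence.txt', 'w', encoding='utf-8') as fh:
    fh.write(sample_article)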
import sys
from typing import List, Optional, Tuple

import faiss
import numpy as np
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)


def read_and_split_text(filename: str) -> List[str]:
    """Read a text file and split it into paragraphs."""
    with open(filename, 'r', encoding='utf-8') as file:
        text = file.read()
    # Split the text into paragraphs (simple split by newline characters)
    paragraphs = text.split('\n')
    # Filter out any empty paragraphs or undesired entries
    return [p for p in paragraphs if p.strip()]


def encode_contexts(text_list: List[str], tokenizer, encoder) -> np.ndarray:
    """Encode a list of texts into DPR context embeddings."""
    embeddings = []
    for text in text_list:
        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=256)
        outputs = encoder(**inputs)
        embeddings.append(outputs.pooler_output)
    return torch.cat(embeddings).detach().numpy()


def search_relevant_contexts(question: str, tokenizer, encoder, index, k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
    """Search the index for the contexts most relevant to a given question."""
    # Tokenize the question
    question_inputs = tokenizer(question, return_tensors='pt')
    # Encode the question to get its embedding
    question_embedding = encoder(**question_inputs).pooler_output.detach().numpy()
    # Search the index for the top-k closest contexts; returns (distances, indices)
    return index.search(question_embedding, k)


def generate_answer(model, tokenizer, question: str, config: Optional[dict] = None, contexts: Optional[List[str]] = None) -> str:
    """Generate an answer to a question, optionally conditioned on retrieved contexts."""
    config = config or {}
    # Concatenate the retrieved contexts with the question to form the GPT-2 prompt
    input_text = (question + ' ' + ' '.join(contexts)) if contexts is not None else question
    # Tokenize the prompt, truncating to GPT-2's 1024-token context window
    inputs = tokenizer(input_text, return_tensors='pt', max_length=1024, truncation=True)
    # Generate a continuation with beam search.
    # Note: max_length counts the prompt tokens as well, so very long contexts
    # may need a larger value here.
    summary_ids = model.generate(
        inputs['input_ids'],
        max_length=config.get('max_length', 200),
        min_length=config.get('min_length', 40),
        length_penalty=config.get('length_penalty', 2.0),
        num_beams=config.get('num_beams', 4),
        early_stopping=config.get('early_stopping', True),
        pad_token_id=tokenizer.eos_token_id,
        attention_mask=inputs['attention_mask'],
    )
    # Decode and return the generated text
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


# Demo: run the full RAG pipeline end to end
if __name__ == "__main__":
    # Read the text file and split it into paragraphs
    paragraphs = read_and_split_text('an-article-from-confluence.txt')

    # Load the DPR context encoder and its tokenizer
    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
    context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

    # Encode the paragraphs to create embeddings
    context_embeddings = encode_contexts(paragraphs, context_tokenizer, context_encoder)
    # FAISS expects a float32 matrix of shape (num_contexts, embedding_dim)
    embedding_dim = 768  # Dimensionality of DPR embeddings
    context_embeddings_np = context_embeddings.astype('float32')

    # Create a FAISS index for the embeddings
    index = faiss.IndexFlatL2(embedding_dim)
    index.add(context_embeddings_np)  # Add the context embeddings to the index
    # Load the DPR question encoder and tokenizer
    question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')

    # GPT-2 model and tokenizer
    model_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    # GPT-2 has no pad token, so reuse the end-of-sequence token for padding
    model.generation_config.pad_token_id = model_tokenizer.eos_token_id
question = "Which Node.js Version Can I Use?" | |
_, I = search_relevant_contexts(question, question_tokenizer, question_encoder, index, k=5) | |
contexts = [paragraphs[idx] for idx in I[0]] | |
    # Check if the terminal supports colors
    colors_supported = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()

    # Print headers with ANSI colors if supported, plain text if not
    red_color = "\033[91m" if colors_supported else ""
    green_color = "\033[92m" if colors_supported else ""
    reset_color = "\033[0m" if colors_supported else ""

    print(f"\n{red_color}Answer without context:{reset_color}")
    print("-" * 100)
    print(generate_answer(model, model_tokenizer, question))

    print(f"\n{green_color}Answer with context:{reset_color}")
    print("-" * 100)
    print(generate_answer(model, model_tokenizer, question, contexts=contexts))
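Running this end to end requires torch, transformers, numpy, and a FAISS build (the CPU-only faiss-cpu package is sufficient for a corpus of this size); the DPR and GPT-2 checkpoints are downloaded from the Hugging Face Hub on first use. Keep in mind that GPT-2 is not instruction-tuned, so both outputs below are free-form continuations of the prompt rather than direct answers; the comparison only shows how much the retrieved paragraphs steer that continuation.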
Results
Answer without context
Answer with context