@sergeyklay
Last active January 23, 2025 19:50
A Retrieval-Augmented Generation (RAG) proof-of-concept that uses Dense Passage Retrieval (DPR) and GPT-2 to answer questions from a document.
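
The script below expects a plain-text source document named an-article-from-confluence.txt in the working directory and requires torch, transformers, faiss, and numpy to be installed. It splits the document into paragraphs, embeds them with the DPR context encoder, indexes the embeddings with FAISS, retrieves the paragraphs closest to the question via the DPR question encoder, and then asks GPT-2 to answer the question both with and without the retrieved paragraphs.
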
import sys
from typing import List, Optional, Tuple

import faiss
import numpy as np
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)
def read_and_split_text(filename: str) -> List[str]:
    """Read a text file and split it into non-empty paragraphs."""
    with open(filename, 'r', encoding='utf-8') as file:
        text = file.read()

    # Split the text into paragraphs (simple split by newline characters)
    paragraphs = text.split('\n')

    # Filter out any empty paragraphs or undesired entries
    return [p for p in paragraphs if p.strip()]
def encode_contexts(text_list: List[str], tokenizer, encoder) -> np.ndarray:
    """Encode a list of texts into DPR context embeddings."""
    embeddings = []
    for text in text_list:
        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=256)
        outputs = encoder(**inputs)
        embeddings.append(outputs.pooler_output)
    return torch.cat(embeddings).detach().numpy()
def search_relevant_contexts(question: str, tokenizer, encoder, index, k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
    """Search the FAISS index for the k contexts most relevant to a question."""
    # Tokenize the question
    question_inputs = tokenizer(question, return_tensors='pt')

    # Encode the question to get its embedding
    question_embedding = encoder(**question_inputs).pooler_output.detach().numpy()

    # Search the index for the top k contexts; returns (distances, indices)
    return index.search(question_embedding, k)
def generate_answer(model, tokenizer, question: str, config: Optional[dict] = None, contexts: Optional[List[str]] = None) -> str:
    """Generate an answer to a question, optionally conditioned on retrieved contexts."""
    config = config or {}

    # Concatenate the retrieved contexts (if any) with the question to form the prompt
    input_text = (question + ' ' + ' '.join(contexts)) if contexts is not None else question

    # Tokenize the prompt, truncating to GPT-2's 1024-token context window
    inputs = tokenizer(input_text, return_tensors='pt', max_length=1024, truncation=True)

    # Generate the answer with beam search
    summary_ids = model.generate(
        inputs['input_ids'],
        max_length=config.get('max_length', 200),
        min_length=config.get('min_length', 40),
        length_penalty=config.get('length_penalty', 2.0),
        num_beams=config.get('num_beams', 4),
        early_stopping=config.get('early_stopping', True),
        pad_token_id=tokenizer.eos_token_id,
        attention_mask=inputs['attention_mask'],
    )

    # Decode and return the generated text
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Testing the code
if __name__ == "__main__":
    # Read the text file and split it into paragraphs
    paragraphs = read_and_split_text('an-article-from-confluence.txt')

    # Load DPR context encoder and tokenizer
    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
    context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

    # Encode the paragraphs to create embeddings
    context_embeddings = encode_contexts(paragraphs, context_tokenizer, context_encoder)

    # Ensure float32 dtype, as required by FAISS
    embedding_dim = 768  # Must match the dimension of the DPR embeddings
    context_embeddings_np = np.array(context_embeddings).astype('float32')

    # Create a FAISS index for the embeddings
    index = faiss.IndexFlatL2(embedding_dim)
    index.add(context_embeddings_np)  # Add the context embeddings to the index

    # Load DPR question encoder and tokenizer
    question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')

    # GPT-2 model and tokenizer
    model_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    # GPT-2 has no pad token, so fall back to the EOS token for padding
    model.generation_config.pad_token_id = model_tokenizer.eos_token_id

    # Retrieve the top k most relevant paragraphs for the question
    question = "Which Node.js Version Can I Use?"
    _, I = search_relevant_contexts(question, question_tokenizer, question_encoder, index, k=5)
    contexts = [paragraphs[idx] for idx in I[0]]

    # Check if the terminal supports colors
    colors_supported = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()

    # Print headers with colors if supported, plain text if not
    red_color = "\033[91m" if colors_supported else ""
    green_color = "\033[92m" if colors_supported else ""
    reset_color = "\033[0m" if colors_supported else ""

    print(f"\n{red_color}Answer without context:{reset_color}")
    print("-" * 100)
    print(generate_answer(model, model_tokenizer, question))

    print(f"\n{green_color}Answer with context:{reset_color}")
    print("-" * 100)
    print(generate_answer(model, model_tokenizer, question, contexts=contexts))
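
Note: the DPR encoders were trained with dot-product (inner-product) similarity, so an inner-product index can rank contexts slightly differently than the L2 index used above. A minimal sketch of that variant, reusing embedding_dim, context_embeddings_np, and the question encoder from the main block:

    # Inner-product (dot-product) index: scores are similarities, so higher means more relevant
    ip_index = faiss.IndexFlatIP(embedding_dim)
    ip_index.add(context_embeddings_np)

    # The existing helper accepts any FAISS index, so retrieval works unchanged
    _, I_ip = search_relevant_contexts(question, question_tokenizer, question_encoder, ip_index, k=5)
    ip_contexts = [paragraphs[idx] for idx in I_ip[0]]
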

Results

Answer without context

Which Node.js Version Can I Use?

If you are using Node.js version 1.0.0 or later, you can use the following command to install Node.js version 1.0.0 or later:

npm install -g npm install -g

If you are using Node.js version 1.0.0 or later, you can use the following command to install Node.js version 1.0.0 or later:

npm install -g npm install -g

If you are using Node.js version 1.0.0 or later, you can use the following command to install Node.js version 1.0.0 or later:

npm install -g npm install -g

If you are using Node.js version 1.0.0 or later, you can use the following command to install Node.js version 1.0.0 or later:

npm install -

Answer with context

Which Node.js Version Can I Use? Standard Node.js Version: All projects must use the LTS (Long Term Support) version of Node.js defined in the GitHub variable ${{ vars.NODE_VERSION }}. Therefore, the Node.js update in the product should occur once a year and only to an even version. To ensure a stable and predictable development environment, as well as to optimize application testing and deployment processes, the following Node.js version management strategy will be employed: The transition to a new version of Node.js should be conducted synchronously across all projects. A separate branch should be used to test the new version of Node.js, to avoid impacting current development and the stability of projects. After successful testing and resolution of all identified issues, the new version can be implemented into the main branches of the projects.

The transition to a new version of Node.js should be conducted synchronously across all projects. A separate branch should be used to
