
@sergeyklay
Last active January 23, 2025 19:50

Revisions

  1. sergeyklay revised this gist Jan 23, 2025. No changes.
  2. sergeyklay revised this gist Jan 23, 2025. No changes.
  3. sergeyklay revised this gist Jan 23, 2025. 1 changed file with 5 additions and 6 deletions.
    11 changes: 5 additions & 6 deletions simple-rag-example.py
    @@ -1,17 +1,17 @@
     import sys
    +from typing import List, Optional, Tuple
     
     import faiss
    +import numpy as np
    +import torch
     from transformers import (
    -    AutoTokenizer,
         AutoModelForCausalLM,
    +    AutoTokenizer,
         DPRContextEncoder,
         DPRContextEncoderTokenizer,
         DPRQuestionEncoder,
    -    DPRQuestionEncoderTokenizer
    +    DPRQuestionEncoderTokenizer,
     )
    -import torch
    -from typing import List, Optional, Tuple
    -import numpy as np
     
     
     def read_and_split_text(filename: str) -> List[str]:
    @@ -107,7 +107,6 @@ def generate_answer(model, tokenizer, question: str, config: Optional[dict] = No
         contexts = [paragraphs[idx] for idx in I[0]]
     
         # Check if the terminal supports colors
    -    import sys
         colors_supported = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
     
         # Print headers with colors if supported, plain text if not
  4. sergeyklay revised this gist Jan 23, 2025. 1 changed file with 20 additions and 23 deletions.
    43 changes: 20 additions & 23 deletions simple-rag-example.py
    @@ -1,3 +1,5 @@
    +import sys
    +
     import faiss
     from transformers import (
         AutoTokenizer,
    @@ -8,7 +10,7 @@
         DPRQuestionEncoderTokenizer
     )
     import torch
    -from typing import List, Optional
    +from typing import List, Optional, Tuple
     import numpy as np
     
     
    @@ -24,23 +26,23 @@ def read_and_split_text(filename: str) -> List[str]:
         return [p for p in paragraphs if p.strip()]
     
     
    -def encode_contexts(text_list: list[str]) -> np.ndarray:
    +def encode_contexts(text_list: List[str], tokenizer, encoder) -> np.ndarray:
         "Encode a list of texts into embeddings"
         embeddings = []
         for text in text_list:
    -        inputs = context_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=256)
    -        outputs = context_encoder(**inputs)
    +        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=256)
    +        outputs = encoder(**inputs)
             embeddings.append(outputs.pooler_output)
         return torch.cat(embeddings).detach().numpy()
     
     
    -def search_relevant_contexts(question, question_tokenizer, question_encoder, index, k=5):
    +def search_relevant_contexts(question, tokenizer, encoder, index, k=5) -> Tuple[np.ndarray, np.ndarray]:
         "Searches for the most relevant contexts to a given question."
         # Tokenize the question
    -    question_inputs = question_tokenizer(question, return_tensors='pt')
    +    question_inputs = tokenizer(question, return_tensors='pt')
     
         # Encode the question to get the embedding
    -    question_embedding = question_encoder(**question_inputs).pooler_output.detach().numpy()
    +    question_embedding = encoder(**question_inputs).pooler_output.detach().numpy()
     
         # Search the index to retrieve top k relevant contexts
         return index.search(question_embedding, k=k)
    @@ -75,14 +77,13 @@ def generate_answer(model, tokenizer, question: str, config: Optional[dict] = No
     # Testing the code
     if __name__ == "__main__":
         # Read the text file and split it into paragraphs
    -    # Some Node.js version usage Policy for our Company
         paragraphs = read_and_split_text('an-article-from-confluence.txt')
     
         context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
         context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
     
         # Encode the paragraphs to create embeddings
    -    context_embeddings = encode_contexts(paragraphs)
    +    context_embeddings = encode_contexts(paragraphs, context_tokenizer, context_encoder)
     
         # Convert list of numpy arrays into a single numpy array
         embedding_dim = 768 # This should match the dimension of embeddings
    @@ -97,9 +98,9 @@ def generate_answer(model, tokenizer, question: str, config: Optional[dict] = No
         question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
     
         # GPT2 model and tokenizer
    -    tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    +    model_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
         model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    -    model.generation_config.pad_token_id = tokenizer.pad_token_id
    +    model.generation_config.pad_token_id = model_tokenizer.pad_token_id
     
         question = "Which Node.js Version Can I Use?"
         _, I = search_relevant_contexts(question, question_tokenizer, question_encoder, index, k=5)
    @@ -109,19 +110,15 @@ def generate_answer(model, tokenizer, question: str, config: Optional[dict] = No
         import sys
         colors_supported = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
     
    -    print("")
    -
         # Print headers with colors if supported, plain text if not
    -    if colors_supported:
    -        print("\033[91mAnswer without context:\033[0m")
    -    else:
    -        print("Answer without context:")
    +    red_color = "\033[91m" if colors_supported else ""
    +    green_color = "\033[92m" if colors_supported else ""
    +    reset_color = "\033[0m" if colors_supported else ""
    +
    +    print(f"\n{red_color}Answer without context:{reset_color}")
         print("-"*100)
    -    print(generate_answer(model, tokenizer, question))
    +    print(generate_answer(model, model_tokenizer, question))
     
    -    if colors_supported:
    -        print("\033[92mAnswer with context:\033[0m")
    -    else:
    -        print("Answer with context:")
    +    print(f"\n{green_color}Answer with context:{reset_color}")
         print("-"*100)
    -    print(generate_answer(model, tokenizer, question, contexts=contexts))
    +    print(generate_answer(model, model_tokenizer, question, contexts=contexts))
  5. sergeyklay revised this gist Jan 23, 2025. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion simple-rag-example.py
    @@ -12,7 +12,7 @@
     import numpy as np
     
     
    -def read_and_split_text(filename: str) -> list[str]:
    +def read_and_split_text(filename: str) -> List[str]:
         "Reads a text file and splits it into paragraphs."
         with open(filename, 'r', encoding='utf-8') as file:
             text = file.read()
  6. sergeyklay renamed this gist Jan 23, 2025. 1 changed file with 0 additions and 0 deletions.
    File renamed without changes.
  7. sergeyklay created this gist Jan 23, 2025.
    127 changes: 127 additions & 0 deletions sumple-rag-example.py
    @@ -0,0 +1,127 @@
    import faiss
    from transformers import (
        AutoTokenizer,
        AutoModelForCausalLM,
        DPRContextEncoder,
        DPRContextEncoderTokenizer,
        DPRQuestionEncoder,
        DPRQuestionEncoderTokenizer
    )
    import torch
    from typing import List, Optional
    import numpy as np


    def read_and_split_text(filename: str) -> list[str]:
        "Reads a text file and splits it into paragraphs."
        with open(filename, 'r', encoding='utf-8') as file:
            text = file.read()

        # Split the text into paragraphs (simple split by newline characters)
        paragraphs = text.split('\n')

        # Filter out any empty paragraphs or undesired entries
        return [p for p in paragraphs if p.strip()]


    def encode_contexts(text_list: list[str]) -> np.ndarray:
        "Encode a list of texts into embeddings"
        embeddings = []
        for text in text_list:
            inputs = context_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=256)
            outputs = context_encoder(**inputs)
            embeddings.append(outputs.pooler_output)
        return torch.cat(embeddings).detach().numpy()


    def search_relevant_contexts(question, question_tokenizer, question_encoder, index, k=5):
        "Searches for the most relevant contexts to a given question."
        # Tokenize the question
        question_inputs = question_tokenizer(question, return_tensors='pt')

        # Encode the question to get the embedding
        question_embedding = question_encoder(**question_inputs).pooler_output.detach().numpy()

        # Search the index to retrieve top k relevant contexts
        return index.search(question_embedding, k=k)


    def generate_answer(model, tokenizer, question: str, config: Optional[dict] = None, contexts: Optional[List[str]] = None) -> str:
        "Generates an answer to a question using the retrieved contexts."
        config = config or {}

        # Concatenate the retrieved contexts to form the input to GPT2
        input_text = question + ' ' + ' '.join(contexts) if contexts is not None else question

        # Tokenize the input question
        inputs = tokenizer(input_text, return_tensors='pt', max_length=1024, truncation=True)

        # Generate output directly from the question without additional context
        summary_ids = model.generate(
            inputs['input_ids'],
            max_length=config.get('max_length', 200),
            min_length=config.get('min_length', 40),
            length_penalty=config.get('length_penalty', 2.0),
            num_beams=config.get('num_beams', 4),
            early_stopping=config.get('early_stopping', True),
            pad_token_id=tokenizer.eos_token_id,
            attention_mask=inputs['attention_mask'],
        )

        # Decode and return the generated text
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


    # Testing the code
    if __name__ == "__main__":
        # Read the text file and split it into paragraphs
        # Some Node.js version usage Policy for our Company
        paragraphs = read_and_split_text('an-article-from-confluence.txt')

        context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
        context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

        # Encode the paragraphs to create embeddings
        context_embeddings = encode_contexts(paragraphs)

        # Convert list of numpy arrays into a single numpy array
        embedding_dim = 768 # This should match the dimension of embeddings
        context_embeddings_np = np.array(context_embeddings).astype('float32')

        # Create a FAISS index for the embeddings
        index = faiss.IndexFlatL2(embedding_dim)
        index.add(context_embeddings_np) # Add the context embeddings to the index

        # Load DPR question encoder and tokenizer
        question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
        question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')

        # GPT2 model and tokenizer
        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        model.generation_config.pad_token_id = tokenizer.pad_token_id

        question = "Which Node.js Version Can I Use?"
        _, I = search_relevant_contexts(question, question_tokenizer, question_encoder, index, k=5)
        contexts = [paragraphs[idx] for idx in I[0]]

        # Check if the terminal supports colors
        import sys
        colors_supported = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()

        print("")

        # Print headers with colors if supported, plain text if not
        if colors_supported:
            print("\033[91mAnswer without context:\033[0m")
        else:
            print("Answer without context:")
        print("-"*100)
        print(generate_answer(model, tokenizer, question))

        if colors_supported:
            print("\033[92mAnswer with context:\033[0m")
        else:
            print("Answer with context:")
        print("-"*100)
        print(generate_answer(model, tokenizer, question, contexts=contexts))
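
Note (not part of the gist): the script reads a local file named an-article-from-confluence.txt, which is not included in any revision. A minimal smoke test, assuming the torch, transformers, faiss-cpu, and numpy packages are installed, is to write a small stand-in policy file first and then run the script from the same directory; the file name matches the gist, but the sample text below is invented purely for illustration.

    # Hypothetical helper (an assumption, not from the gist): create a tiny
    # stand-in for 'an-article-from-confluence.txt' so the example can run.
    sample_policy = (
        "Node.js Version Policy\n"
        "All backend services must run on an Active LTS release of Node.js.\n"
        "As of January 2025 that means Node.js 20 or Node.js 22.\n"
        "Odd-numbered releases such as 21 and 23 are not allowed in production.\n"
    )

    # Write the stand-in article to the path the script expects
    with open('an-article-from-confluence.txt', 'w', encoding='utf-8') as f:
        f.write(sample_policy)

    # Then, in the same directory:
    #   python simple-rag-example.py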