Last active: January 23, 2025 19:50
Revisions
- sergeyklay revised this gist (Jan 23, 2025). No changes.
- sergeyklay revised this gist (Jan 23, 2025). No changes.
- sergeyklay revised this gist (Jan 23, 2025). 1 changed file with 5 additions and 6 deletions.
@@ -1,17 +1,17 @@
import sys
from typing import List, Optional, Tuple

import faiss
import numpy as np
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)


def read_and_split_text(filename: str) -> List[str]:
@@ -107,7 +107,6 @@ def generate_answer(model, tokenizer, question: str, config: Optional[dict] = No
    contexts = [paragraphs[idx] for idx in I[0]]

    # Check if the terminal supports colors
    colors_supported = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()

    # Print headers with colors if supported, plain text if not
- sergeyklay revised this gist (Jan 23, 2025). 1 changed file with 20 additions and 23 deletions.
@@ -1,3 +1,5 @@
import sys

import faiss
from transformers import (
    AutoTokenizer,
@@ -8,7 +10,7 @@
    DPRQuestionEncoderTokenizer
)
import torch
from typing import List, Optional, Tuple
import numpy as np


@@ -24,23 +26,23 @@ def read_and_split_text(filename: str) -> List[str]:
    return [p for p in paragraphs if p.strip()]


def encode_contexts(text_list: List[str], tokenizer, encoder) -> np.ndarray:
    "Encode a list of texts into embeddings"
    embeddings = []
    for text in text_list:
        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=256)
        outputs = encoder(**inputs)
        embeddings.append(outputs.pooler_output)
    return torch.cat(embeddings).detach().numpy()


def search_relevant_contexts(question, tokenizer, encoder, index, k=5) -> Tuple[np.ndarray, np.ndarray]:
    "Searches for the most relevant contexts to a given question."
    # Tokenize the question
    question_inputs = tokenizer(question, return_tensors='pt')

    # Encode the question to get the embedding
    question_embedding = encoder(**question_inputs).pooler_output.detach().numpy()

    # Search the index to retrieve top k relevant contexts
    return index.search(question_embedding, k=k)
@@ -75,14 +77,13 @@ def generate_answer(model, tokenizer, question: str, config: Optional[dict] = No
# Testing the code
if __name__ == "__main__":
    # Read the text file and split it into paragraphs
    paragraphs = read_and_split_text('an-article-from-confluence.txt')

    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
    context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

    # Encode the paragraphs to create embeddings
    context_embeddings = encode_contexts(paragraphs, context_tokenizer, context_encoder)

    # Convert list of numpy arrays into a single numpy array
    embedding_dim = 768  # This should match the dimension of embeddings
@@ -97,9 +98,9 @@ def generate_answer(model, tokenizer, question: str, config: Optional[dict] = No
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')

    # GPT2 model and tokenizer
    model_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    model.generation_config.pad_token_id = model_tokenizer.pad_token_id

    question = "Which Node.js Version Can I Use?"
    _, I = search_relevant_contexts(question, question_tokenizer, question_encoder, index, k=5)
@@ -109,19 +110,15 @@ def generate_answer(model, tokenizer, question: str, config: Optional[dict] = No
    import sys
    colors_supported = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()

    # Print headers with colors if supported, plain text if not
    red_color = "\033[91m" if colors_supported else ""
    green_color = "\033[92m" if colors_supported else ""
    reset_color = "\033[0m" if colors_supported else ""

    print(f"\n{red_color}Answer without context:{reset_color}")
    print("-"*100)
    print(generate_answer(model, model_tokenizer, question))

    print(f"\n{green_color}Answer with context:{reset_color}")
    print("-"*100)
    print(generate_answer(model, model_tokenizer, question, contexts=contexts))
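A side note on the GPT-2 loading step in the hunk above: the stock openai-community/gpt2 tokenizer ships without a padding token, so model_tokenizer.pad_token_id is None and the generation_config assignment is effectively a no-op; generation still works because generate_answer() passes pad_token_id=tokenizer.eos_token_id explicitly. Below is a minimal sketch (not part of the gist) of the common EOS-as-pad workaround, assuming reusing the EOS token for padding is acceptable in this single-question setup.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

# GPT-2 has no dedicated pad token; reuse EOS so pad_token_id is a real id
# (assumption: EOS-as-pad is fine for this beam-search, single-sequence use).
if model_tokenizer.pad_token is None:
    model_tokenizer.pad_token = model_tokenizer.eos_token
model.generation_config.pad_token_id = model_tokenizer.pad_token_id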
- sergeyklay revised this gist (Jan 23, 2025). 1 changed file with 1 addition and 1 deletion.
@@ -12,7 +12,7 @@
import numpy as np


def read_and_split_text(filename: str) -> List[str]:
    "Reads a text file and splits it into paragraphs."
    with open(filename, 'r', encoding='utf-8') as file:
        text = file.read()
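For reference, a tiny self-contained illustration (not part of the gist) of what read_and_split_text() produces: the file is split on newlines, blank lines are dropped, and every remaining line becomes one "paragraph" that later serves as a retrieval unit. The sample file name and contents below are made up.

from pathlib import Path

Path("sample.txt").write_text(
    "Node.js 20 is the current default.\n"
    "\n"
    "Node.js 16 reached end-of-life.\n",
    encoding="utf-8",
)

with open("sample.txt", "r", encoding="utf-8") as file:
    text = file.read()

paragraphs = [p for p in text.split("\n") if p.strip()]
print(paragraphs)
# ['Node.js 20 is the current default.', 'Node.js 16 reached end-of-life.']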
- sergeyklay renamed this gist (Jan 23, 2025). File renamed without changes.
- sergeyklay created this gist (Jan 23, 2025).
@@ -0,0 +1,127 @@
import faiss
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer
)
import torch
from typing import List, Optional
import numpy as np


def read_and_split_text(filename: str) -> list[str]:
    "Reads a text file and splits it into paragraphs."
    with open(filename, 'r', encoding='utf-8') as file:
        text = file.read()

    # Split the text into paragraphs (simple split by newline characters)
    paragraphs = text.split('\n')

    # Filter out any empty paragraphs or undesired entries
    return [p for p in paragraphs if p.strip()]


def encode_contexts(text_list: list[str]) -> np.ndarray:
    "Encode a list of texts into embeddings"
    embeddings = []
    for text in text_list:
        inputs = context_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=256)
        outputs = context_encoder(**inputs)
        embeddings.append(outputs.pooler_output)
    return torch.cat(embeddings).detach().numpy()


def search_relevant_contexts(question, question_tokenizer, question_encoder, index, k=5):
    "Searches for the most relevant contexts to a given question."
    # Tokenize the question
    question_inputs = question_tokenizer(question, return_tensors='pt')

    # Encode the question to get the embedding
    question_embedding = question_encoder(**question_inputs).pooler_output.detach().numpy()

    # Search the index to retrieve top k relevant contexts
    return index.search(question_embedding, k=k)


def generate_answer(model, tokenizer, question: str, config: Optional[dict] = None, contexts: Optional[List[str]] = None) -> str:
    "Generates an answer to a question using the retrieved contexts."
    config = config or {}

    # Concatenate the retrieved contexts to form the input to GPT2
    input_text = question + ' ' + ' '.join(contexts) if contexts is not None else question

    # Tokenize the input question
    inputs = tokenizer(input_text, return_tensors='pt', max_length=1024, truncation=True)

    # Generate output directly from the question without additional context
    summary_ids = model.generate(
        inputs['input_ids'],
        max_length=config.get('max_length', 200),
        min_length=config.get('min_length', 40),
        length_penalty=config.get('length_penalty', 2.0),
        num_beams=config.get('num_beams', 4),
        early_stopping=config.get('early_stopping', True),
        pad_token_id=tokenizer.eos_token_id,
        attention_mask=inputs['attention_mask'],
    )

    # Decode and return the generated text
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


# Testing the code
if __name__ == "__main__":
    # Read the text file and split it into paragraphs
    # Some Node.js version usage Policy for our Company
    paragraphs = read_and_split_text('an-article-from-confluence.txt')

    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
    context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

    # Encode the paragraphs to create embeddings
    context_embeddings = encode_contexts(paragraphs)

    # Convert list of numpy arrays into a single numpy array
    embedding_dim = 768  # This should match the dimension of embeddings
    context_embeddings_np = np.array(context_embeddings).astype('float32')

    # Create a FAISS index for the embeddings
    index = faiss.IndexFlatL2(embedding_dim)
    index.add(context_embeddings_np)  # Add the context embeddings to the index

    # Load DPR question encoder and tokenizer
    question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')

    # GPT2 model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    model.generation_config.pad_token_id = tokenizer.pad_token_id

    question = "Which Node.js Version Can I Use?"
    _, I = search_relevant_contexts(question, question_tokenizer, question_encoder, index, k=5)
    contexts = [paragraphs[idx] for idx in I[0]]

    # Check if the terminal supports colors
    import sys
    colors_supported = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()

    print("")

    # Print headers with colors if supported, plain text if not
    if colors_supported:
        print("\033[91mAnswer without context:\033[0m")
    else:
        print("Answer without context:")
    print("-"*100)
    print(generate_answer(model, tokenizer, question))

    if colors_supported:
        print("\033[92mAnswer with context:\033[0m")
    else:
        print("Answer with context:")
    print("-"*100)
    print(generate_answer(model, tokenizer, question, contexts=contexts))
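To make the retrieval step easier to follow: index.search() returns a pair of arrays, distances D and row indices I, each shaped (number of queries, k), which is why the script unpacks `_, I` and iterates over I[0] for its single question. A minimal sketch with random vectors standing in for the DPR embeddings (the dimensions and data below are purely illustrative, not from the gist):

import faiss
import numpy as np

embedding_dim = 768                      # matches the DPR pooler output size
rng = np.random.default_rng(0)

# Ten fake "context" embeddings in place of encode_contexts() output
context_embeddings_np = rng.random((10, embedding_dim), dtype=np.float32)

index = faiss.IndexFlatL2(embedding_dim)
index.add(context_embeddings_np)

# One fake "question" embedding in place of the DPR question encoder output
question_embedding = rng.random((1, embedding_dim), dtype=np.float32)

D, I = index.search(question_embedding, 5)
print(D.shape, I.shape)   # (1, 5) (1, 5): distances and indices of the top-5 contexts

Because IndexFlatL2 stores vectors in insertion order, each value in I[0] maps directly back to a position in the paragraphs list, which is exactly how the script rebuilds the retrieved contexts.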