Created
April 24, 2023 14:46
-
-
Save bitbutter/981b18b48195163bb94c25a38b3d3d8d to your computer and use it in GitHub Desktop.
A Python console program that uses Chroma and OpenAI to query text files.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import chromadb | |
import openai | |
import os | |
from tqdm import tqdm | |
import tiktoken | |
from chromadb.errors import NotEnoughElementsException | |
import re | |
from colorama import Fore, Style | |
# Instructions (assumes Windows OS)
# Install the required libraries: pip install chromadb openai tqdm tiktoken colorama
# Place this script (knowledge_extractor.py) next to a directory named 'documents'
# and put the text files you want to query inside that folder.
# Either hard-code your OpenAI API key on the line below, or export it as the
# OPENAI_API_KEY environment variable and keep the os.getenv() line.
# Run with: python ./knowledge_extractor.py
# openai.api_key = YOUR_OPENAI_API_KEY_HERE
openai.api_key = os.getenv("OPENAI_API_KEY")  # delete this line if you hard-code the key above

# Shared in-memory Chroma client and the collection every file is embedded into.
chroma_client = chromadb.Client()
collection = chroma_client.create_collection(name="my_collection")
def detect_hard_line_breaks(data):
    """Heuristically decide whether *data* is hard-wrapped.

    A line containing a sentence terminator followed by whitespace and a
    lowercase letter is treated as evidence of a "soft" (flowing) line.
    If fewer than 10% of the lines show that pattern, the text is assumed
    to use hard line breaks.
    """
    stripped_lines = data.strip().split('\n')
    soft_break_pattern = re.compile(r'[.!?;:]\s+[a-z]')
    soft_count = sum(1 for line in stripped_lines if soft_break_pattern.search(line))
    # 10% threshold is empirical — adjust based on your documents.
    return soft_count / len(stripped_lines) < 0.1
def chunk_text(text, chunk_size=300):
    """Split *text* into chunks of at most *chunk_size* whitespace-separated words."""
    words = text.split()
    return [
        " ".join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]
def read_and_embed_file(file_name, title):
    """Read a UTF-8 text file, split it into chunks, and add them to the collection.

    Args:
        file_name: Path to the text file to ingest.
        title: Label stored as each chunk's "source" metadata; also used to
            namespace the chunk ids.
    """
    with open(file_name, 'r', encoding="utf-8") as file:
        data = file.read()

    if detect_hard_line_breaks(data):
        # Hard-wrapped text: paragraph boundaries are unreliable, so fall
        # back to fixed-size chunks of 300 words each.
        paragraphs = chunk_text(data, 300)
    else:
        # Split into paragraphs on one or two consecutive newlines.
        paragraphs = re.split(r'\n{1,2}', data.strip())

    # Bug fix: ids were "id1", "id2", ... for *every* file, so the second
    # file ingested would collide with the first file's ids in the shared
    # collection. Prefixing with the title keeps ids unique across files.
    metadata_list = [{"source": f"{title}"} for _ in paragraphs]
    ids_list = [f"{title}-id{idx + 1}" for idx in range(len(paragraphs))]

    collection.add(
        documents=paragraphs,
        metadatas=metadata_list,
        ids=ids_list
    )
def generate_prompt(sources, question):
    """Build the completion prompt: cited sources, the question, and a "Result:" stub.

    Bug fix: the instruction text read "usine the sources" — corrected to
    "using the sources" so the model receives a well-formed instruction.
    """
    return f"""
Answer the question below using the sources to get relevant information. Cite sources using in-text citations with square brackets.
For example: [1] refers to source 1 and [2] refers to source 2. Cite once per sentence.
If the context doesn't answer the question, output "I don't know".
Sources: {sources}
Question: {question}
Result:"""
def make_openai_call(context, question):
    """Pack as many context paragraphs as fit into the token budget and query OpenAI.

    Args:
        context: Iterable of paragraph strings, in relevance order.
        question: The user's question.

    Returns:
        The raw completion text from the model.
    """
    # NOTE(review): text-davinci-003 and the Completion endpoint are legacy
    # OpenAI APIs — consider migrating to chat completions.
    model_name = "text-davinci-003"
    completion_budget = 200                      # tokens reserved for the answer
    max_prompt_tokens = 4097 - completion_budget  # model context window minus the reserve

    # Consistency fix: the tokenizer was fetched with a second hard-coded
    # "text-davinci-003" literal; derive it from model_name instead.
    encoding = tiktoken.encoding_for_model(model_name)
    question_tokens = len(encoding.encode(question))

    # Greedily append paragraphs (most relevant first) until the budget is hit.
    sources = ''
    total_tokens = 0
    for idx, paragraph in enumerate(context):
        paragraph_tokens = len(encoding.encode(paragraph))
        if total_tokens + paragraph_tokens + question_tokens > max_prompt_tokens:
            break
        sources += f"Source {idx + 1}: {paragraph}\n"
        total_tokens += paragraph_tokens

    prompt = generate_prompt(sources, question)
    response = openai.Completion.create(
        model=model_name,
        prompt=prompt,
        temperature=0.6,
        max_tokens=200,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=["\n"]
    )
    return response['choices'][0]['text']
def pretty_print_results(query, summary, sources):
    """Print the model's answer in yellow followed by a blue numbered source list."""
    print("--------------------")
    print(Fore.YELLOW + summary)
    print(Fore.BLUE + "*********")
    print("Sources:")
    for number, source in enumerate(sources, start=1):
        print(f"{number}: {source}\n")
    print("*********" + Style.RESET_ALL)
if __name__ == "__main__":
    documents_folder = "documents"
    all_files = [file for file in os.listdir(documents_folder) if file.endswith(".txt")]

    # Use tqdm to create a progress bar for the ingestion loop.
    for file in tqdm(all_files, desc="Processing files"):
        file_path = os.path.join(documents_folder, file)
        print(f"Processing: {file_path}")
        read_and_embed_file(file_path, file.split('.')[0])

    # Interactive query loop.
    while True:
        query = input(Fore.GREEN + "Enter your question or type 'exit' to quit: ")
        if query.lower() == "exit":
            break
        try:
            results = collection.query(
                query_texts=[query],
                n_results=5
            )
        except NotEnoughElementsException:
            # Bug fix: this exception was imported but never handled, so the
            # program crashed whenever fewer than 5 chunks were indexed.
            # Retry asking for everything the collection actually holds.
            results = collection.query(
                query_texts=[query],
                n_results=collection.count()
            )
        top_results = results["documents"][0]
        summary = make_openai_call(top_results, query).strip()
        pretty_print_results(query, summary, top_results)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment