Skip to content

Instantly share code, notes, and snippets.

@bitbutter
Created April 24, 2023 14:46
Show Gist options
  • Save bitbutter/981b18b48195163bb94c25a38b3d3d8d to your computer and use it in GitHub Desktop.
A python console program that uses Chroma and OpenAI to query text files
import chromadb
import openai
import os
from tqdm import tqdm
import tiktoken
from chromadb.errors import NotEnoughElementsException
import re
from colorama import Fore, Style
# Instructions (assumes Windows OS)
# In the console/terminal use this command to install the necessary python libraries on your machine: pip install chromadb openai tqdm tiktoken colorama
# Place this script (knowledge_extractor.py) next to a directory named 'documents'. Put text files you want to use as sources of information inside this folder.
# Edit the line below to reflect your openAI api key.
# In the console, at the location of this script enter this command to run it: python ./knowledge_extractor.py
# openai.api_key = YOUR_OPENAI_API_KEY_HERE
openai.api_key = os.getenv("OPENAI_API_KEY") # delete this line if you're using the line above
# NOTE(review): the default Client appears to be in-memory, so documents are
# re-embedded on every run — confirm if persistence is needed.
chroma_client = chromadb.Client()
# Single global collection: read_and_embed_file() writes passages into it and
# the main loop queries it.
collection = chroma_client.create_collection(name="my_collection")
def detect_hard_line_breaks(data):
    """Heuristically decide whether *data* is hard-wrapped.

    A line where a sentence ends mid-line and continues with a lowercase
    word (punctuation, whitespace, lowercase letter) is evidence of text
    flowing across the line break. If fewer than 10% of lines show that
    pattern, treat the text as using hard line breaks.
    """
    sentence_flow = re.compile(r'[.!?;:]\s+[a-z]')
    lines = data.strip().split('\n')
    soft_count = sum(bool(sentence_flow.search(line)) for line in lines)
    # Adjust the 0.1 threshold based on your observations.
    return soft_count / len(lines) < 0.1
def chunk_text(text, chunk_size=300):
    """Split *text* into consecutive chunks of at most *chunk_size* words.

    Words are whitespace-delimited; each chunk is rejoined with single
    spaces. The final chunk may be shorter than *chunk_size*.
    """
    words = text.split()
    return [
        " ".join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]
def read_and_embed_file(file_name, title):
    """Read a UTF-8 text file, split it into passages, and add them to
    the global Chroma collection.

    file_name -- path of the text file to ingest.
    title     -- label stored as the "source" metadata of every passage
                 and used to namespace the passage ids.
    """
    with open(file_name, 'r', encoding="utf-8") as file:
        data = file.read()
    if detect_hard_line_breaks(data):
        # Hard-wrapped text: line breaks carry no meaning, so fall back
        # to fixed-size chunks of 300 words each.
        paragraphs = chunk_text(data, 300)
    else:
        # Soft-wrapped text: split on single or double newlines.
        paragraphs = re.split(r'\n{1,2}', data.strip())
    metadata_list = [{"source": f"{title}"} for _ in paragraphs]
    # BUG FIX: the ids were plain "id1", "id2", ... — identical for every
    # file, so ingesting a second file collided with (and could overwrite)
    # the first file's entries. Namespace ids with the title instead.
    ids_list = [f"{title}-id{idx + 1}" for idx in range(len(paragraphs))]
    collection.add(
        documents=paragraphs,
        metadatas=metadata_list,
        ids=ids_list
    )
def generate_prompt(sources, question):
    """Build the completion prompt asking the model to answer *question*
    from the numbered *sources* using [n] in-text citations.

    Returns the full prompt string, ending with "Result:" so the model
    continues with the answer.
    """
    # Fixed "usine" -> "using" in the instruction text sent to the model.
    return f"""
Answer the question below using the sources to get relevant information. Cite sources using in-text citations with square brackets.
For example: [1] refers to source 1 and [2] refers to source 2. Cite once per sentence.
If the context doesn't answer the question, output "I don't know".
Sources: {sources}
Question: {question}
Result:"""
def make_openai_call(context, question):
    """Pack as many retrieved paragraphs as fit in the model's context
    window, then ask the completion endpoint to answer *question*.

    context  -- list of paragraph strings, assumed ordered by relevance.
    question -- the user's question.
    Returns the raw completion text from the first choice.
    """
    # NOTE(review): text-davinci-003 and the Completion endpoint are
    # deprecated in current OpenAI APIs; migrating to chat completions
    # would change behavior and is deliberately not done here.
    model_name = "text-davinci-003"
    context_window = 4097
    prompt_budget = context_window - 200  # reserve 200 tokens for the completion
    # Derive the tokenizer from model_name instead of repeating the
    # hard-coded literal, so the two can never drift apart.
    encoding = tiktoken.encoding_for_model(model_name)
    question_tokens = len(encoding.encode(question))
    sources = ''
    total_tokens = 0
    for idx, paragraph in enumerate(context):
        paragraph_tokens = len(encoding.encode(paragraph))
        # Stop once the next paragraph would overflow the prompt budget.
        if total_tokens + paragraph_tokens + question_tokens > prompt_budget:
            break
        sources += f"Source {idx + 1}: {paragraph}\n"
        total_tokens += paragraph_tokens
    prompt = generate_prompt(sources, question)
    response = openai.Completion.create(
        model=model_name,
        prompt=prompt,
        temperature=0.6,
        max_tokens=200,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=["\n"]
    )
    return response['choices'][0]['text']
def pretty_print_results(query, summary, sources):
    """Print the model's answer (yellow) followed by a numbered list of
    the source passages (blue), then reset the terminal color.

    *query* is accepted for interface compatibility but is not printed.
    """
    divider = "*********"
    print("--------------------")
    print(Fore.YELLOW + summary)
    print(Fore.BLUE + divider)
    print("Sources:")
    number = 1
    for source in sources:
        print(f"{number}: {source}\n")
        number += 1
    print(divider + Style.RESET_ALL)
if __name__ == "__main__":
    documents_folder = "documents"
    all_files = [file for file in os.listdir(documents_folder) if file.endswith(".txt")]
    # tqdm shows a progress bar while the documents are embedded.
    for file in tqdm(all_files, desc="Processing files"):
        file_path = os.path.join(documents_folder, file)
        print(f"Processing: {file_path}")
        # Use splitext instead of split('.')[0] so filenames containing
        # dots (e.g. "notes.v2.txt") keep their full stem as the title.
        title = os.path.splitext(file)[0]
        read_and_embed_file(file_path, title)
    # Interactive question loop; 'exit' (any case) quits.
    while True:
        query = input(Fore.GREEN + "Enter your question or type 'exit' to quit: ")
        if query.lower() == "exit":
            break
        try:
            results = collection.query(
                query_texts=[query],
                n_results=5
            )
        except NotEnoughElementsException:
            # Fewer than 5 passages in the collection (the exception was
            # imported but never handled): retry asking for everything
            # there is instead of crashing.
            results = collection.query(
                query_texts=[query],
                n_results=collection.count()
            )
        top_results = results["documents"][0]
        pretty_print_results(query, make_openai_call(
            top_results, query).strip(), top_results)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment