Skip to content

Instantly share code, notes, and snippets.

@dpflucas
Last active August 9, 2024 15:55
Show Gist options
  • Save dpflucas/079b0644990eaf13a0600a5964d14f4a to your computer and use it in GitHub Desktop.
AI agent with RAG to provide information about the 2024 Summer Olympics.
import os
import wikipedia
import pandas as pd
import pickle
import argparse
import warnings
import io
from dotenv import load_dotenv
from urllib3.exceptions import NotOpenSSLWarning
from llama_index.core import VectorStoreIndex, Settings, Document, StorageContext, load_index_from_storage
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
def load_environment():
    """Load Azure OpenAI configuration from the environment / .env file.

    Sets the OPENAI_* variables expected by the llama-index Azure clients
    and publishes the deployment names as module-level globals
    (AZURE_DEPLOYMENT_NAME, AZURE_EMBEDDING_DEPLOYMENT_NAME).

    Raises:
        RuntimeError: if a required Azure variable is absent. The original
            code assigned os.getenv(...) (possibly None) into os.environ,
            which fails later with an opaque ``TypeError: str expected``.
    """
    load_dotenv()
    # Fail fast with an actionable message instead of a TypeError below.
    required = ("AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY")
    missing = [name for name in required if not os.getenv(name)]
    if missing:
        raise RuntimeError(
            f"Missing required environment variables: {', '.join(missing)}"
        )
    os.environ["OPENAI_API_TYPE"] = "azure"
    os.environ["OPENAI_API_VERSION"] = "2023-05-15"  # Update this if you're using a different API version
    os.environ["OPENAI_API_BASE"] = os.environ["AZURE_OPENAI_ENDPOINT"]
    os.environ["OPENAI_API_KEY"] = os.environ["AZURE_OPENAI_API_KEY"]
    global AZURE_DEPLOYMENT_NAME, AZURE_EMBEDDING_DEPLOYMENT_NAME
    AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")
    AZURE_EMBEDDING_DEPLOYMENT_NAME = os.getenv("AZURE_EMBEDDING_DEPLOYMENT_NAME")

load_environment()
def fetch_wikipedia_content(title):
    """Fetch a Wikipedia article.

    Returns a tuple ``(text, tables)`` where *text* is the article's plain
    text and *tables* is the list of DataFrames parsed from its HTML.
    """
    page = wikipedia.page(title)
    html_tables = pd.read_html(io.StringIO(page.html()))
    return page.content, html_tables
def table_to_markdown(df):
return df.to_markdown(index=False)
def load_or_fetch_wikipedia_data(wiki_titles, cache_file='wikipedia_cache.pkl', force_fetch=False):
    """Return ``{title: (text, tables)}`` for *wiki_titles*.

    Serves the result from the local pickle cache when present (unless
    *force_fetch* is set); otherwise downloads each page and rewrites the cache.
    """
    cache_usable = os.path.exists(cache_file) and not force_fetch
    if cache_usable:
        print("Loading cached Wikipedia data...")
        with open(cache_file, 'rb') as f:
            return pickle.load(f)

    print("Fetching Wikipedia data...")
    # fetch_wikipedia_content returns a (text, tables) pair per title.
    data = {title: fetch_wikipedia_content(title) for title in wiki_titles}
    with open(cache_file, 'wb') as f:
        pickle.dump(data, f)
    return data
def setup_llm_and_embed_model():
    """Point llama-index's global Settings at the Azure chat and embedding deployments."""
    Settings.llm = AzureOpenAI(
        engine=AZURE_DEPLOYMENT_NAME,
        temperature=0,
    )
    Settings.embed_model = AzureOpenAIEmbedding(
        model=AZURE_EMBEDDING_DEPLOYMENT_NAME,
        deployment_name=AZURE_EMBEDDING_DEPLOYMENT_NAME,
    )
def create_index(force_fetch=False):
    """Build a vector index over the Olympics Wikipedia pages and persist it to ./storage.

    Each page contributes one text document plus one document per extracted
    table (rendered as Markdown so tabular facts survive chunking).
    """
    wiki_titles = [
        "List of 2024 Summer Olympics medal winners",
        "2024 Summer Olympics medal table"
    ]
    wikipedia_data = load_or_fetch_wikipedia_data(wiki_titles, force_fetch=force_fetch)

    documents = []
    for title, (text_content, tables) in wikipedia_data.items():
        # Article prose as a single document.
        documents.append(
            Document(text=text_content, metadata={"source": title, "type": "text"})
        )
        # One additional document per table, numbered from 1.
        for table_num, table in enumerate(tables, start=1):
            table_md = table_to_markdown(table)
            documents.append(
                Document(
                    text=f"Table {table_num} from {title}:\n{table_md}",
                    metadata={"source": f"{title} - Table {table_num}", "type": "table"},
                )
            )

    parser = SimpleNodeParser.from_defaults(chunk_size=500, chunk_overlap=50)
    nodes = parser.get_nodes_from_documents(documents)
    setup_llm_and_embed_model()
    index = VectorStoreIndex(nodes)
    index.storage_context.persist("./storage")
    return index
def load_or_create_index(force_fetch=False):
    """Load the persisted index from ./storage, or build a fresh one when absent or forced."""
    setup_llm_and_embed_model()
    if force_fetch or not os.path.exists("./storage"):
        print("Creating new index...")
        return create_index(force_fetch)
    print("Loading existing index...")
    storage_context = StorageContext.from_defaults(persist_dir="./storage")
    return load_index_from_storage(storage_context)
def query_index(index, query):
    """Answer *query* with the index's query engine (top-10 similar chunks)."""
    engine = index.as_query_engine(similarity_top_k=10)
    return engine.query(query)
def main(force_fetch):
    """Interactive loop: answer user questions about the 2024 Summer Olympics.

    Reads queries from stdin until the user types 'exit'; prints each answer
    followed by the sources (metadata of the retrieved nodes).
    """
    warnings.filterwarnings("ignore", category=NotOpenSSLWarning)
    index = load_or_create_index(force_fetch)
    print("AI agent is ready. Type 'exit' to quit.")
    while True:
        user_query = input("Enter your question about the 2024 Summer Olympics: ")
        if user_query.lower() == 'exit':
            break
        try:
            response = query_index(index, user_query)
            print(f"AI: {response.response}\n")
            print("Sources:")
            for node in response.source_nodes:
                source = node.metadata['source']
                node_type = node.metadata['type']
                print(f"- {source} ({node_type})")
            print()
        except Exception as e:
            # Degrade gracefully so a single bad query doesn't kill the session.
            print(f"An error occurred: {str(e)}")
            print("I apologize, but I'm unable to provide a reliable answer to this question at the moment.")
if __name__ == "__main__":
    # CLI entry point: optionally bypass the Wikipedia cache with --force-fetch.
    cli = argparse.ArgumentParser(description="AI agent for 2024 Summer Olympics information")
    cli.add_argument("--force-fetch", action="store_true", help="Force fetch new data from Wikipedia")
    cli_args = cli.parse_args()
    main(cli_args.force_fetch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment