An AI agent that uses retrieval-augmented generation (RAG) over Wikipedia to answer questions about the 2024 Summer Olympics.
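The script reads its Azure OpenAI configuration from a .env file via python-dotenv. A minimal sketch of that file, using the variable names the code reads; every value below is a placeholder to replace with your own Azure resources:

# .env — placeholder values, substitute your own Azure OpenAI resources
AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com/
AZURE_OPENAI_API_KEY=your-api-key
AZURE_OPENAI_DEPLOYMENT_NAME=your-chat-deployment
AZURE_EMBEDDING_DEPLOYMENT_NAME=your-embedding-deployment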
import os
import wikipedia
import pandas as pd
import pickle
import argparse
import warnings
import io
from dotenv import load_dotenv
from urllib3.exceptions import NotOpenSSLWarning
from llama_index.core import VectorStoreIndex, Settings, Document, StorageContext, load_index_from_storage
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
def load_environment():
    """Load Azure OpenAI settings from .env into the process environment."""
    load_dotenv()
    os.environ["OPENAI_API_TYPE"] = "azure"
    os.environ["OPENAI_API_VERSION"] = "2023-05-15"  # Update this if you're using a different API version
    os.environ["OPENAI_API_BASE"] = os.getenv("AZURE_OPENAI_ENDPOINT")
    os.environ["OPENAI_API_KEY"] = os.getenv("AZURE_OPENAI_API_KEY")
    global AZURE_DEPLOYMENT_NAME, AZURE_EMBEDDING_DEPLOYMENT_NAME
    AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")
    AZURE_EMBEDDING_DEPLOYMENT_NAME = os.getenv("AZURE_EMBEDDING_DEPLOYMENT_NAME")

load_environment()
def fetch_wikipedia_content(title):
    """Fetch a Wikipedia page's plain text plus its HTML tables as DataFrames."""
    page = wikipedia.page(title)
    text_content = page.content
    tables = pd.read_html(io.StringIO(page.html()))
    return text_content, tables

def table_to_markdown(df):
    """Render a DataFrame as a Markdown table so it embeds cleanly as text."""
    return df.to_markdown(index=False)
def load_or_fetch_wikipedia_data(wiki_titles, cache_file='wikipedia_cache.pkl', force_fetch=False):
    """Return {title: (text, tables)}, using a local pickle cache unless force_fetch is set."""
    if os.path.exists(cache_file) and not force_fetch:
        print("Loading cached Wikipedia data...")
        with open(cache_file, 'rb') as f:
            return pickle.load(f)
    print("Fetching Wikipedia data...")
    data = {}
    for title in wiki_titles:
        text_content, tables = fetch_wikipedia_content(title)
        data[title] = (text_content, tables)
    with open(cache_file, 'wb') as f:
        pickle.dump(data, f)
    return data
def setup_llm_and_embed_model():
    """Point LlamaIndex's global Settings at the Azure OpenAI chat and embedding deployments."""
    llm = AzureOpenAI(
        engine=AZURE_DEPLOYMENT_NAME,
        temperature=0,
    )
    embed_model = AzureOpenAIEmbedding(
        model=AZURE_EMBEDDING_DEPLOYMENT_NAME,
        deployment_name=AZURE_EMBEDDING_DEPLOYMENT_NAME,
    )
    Settings.llm = llm
    Settings.embed_model = embed_model
def create_index(force_fetch=False):
    """Build a vector index over the Olympics pages and persist it to ./storage."""
    wiki_titles = [
        "List of 2024 Summer Olympics medal winners",
        "2024 Summer Olympics medal table"
    ]
    wikipedia_data = load_or_fetch_wikipedia_data(wiki_titles, force_fetch=force_fetch)
    documents = []
    for title, (text_content, tables) in wikipedia_data.items():
        documents.append(Document(text=text_content, metadata={"source": title, "type": "text"}))
        # Index each table as its own document so tabular medal data survives chunking.
        for i, table in enumerate(tables):
            table_md = table_to_markdown(table)
            documents.append(Document(text=f"Table {i+1} from {title}:\n{table_md}", metadata={"source": f"{title} - Table {i+1}", "type": "table"}))
    parser = SimpleNodeParser.from_defaults(chunk_size=500, chunk_overlap=50)
    nodes = parser.get_nodes_from_documents(documents)
    setup_llm_and_embed_model()
    index = VectorStoreIndex(nodes)
    index.storage_context.persist("./storage")
    return index
def load_or_create_index(force_fetch=False):
    """Reload the persisted index from ./storage, or build it from scratch."""
    setup_llm_and_embed_model()
    if os.path.exists("./storage") and not force_fetch:
        print("Loading existing index...")
        storage_context = StorageContext.from_defaults(persist_dir="./storage")
        return load_index_from_storage(storage_context)
    else:
        print("Creating new index...")
        return create_index(force_fetch)
def query_index(index, query):
    """Answer a query by retrieving the 10 most similar nodes and synthesizing a response."""
    query_engine = index.as_query_engine(similarity_top_k=10)
    response = query_engine.query(query)
    return response
def main(force_fetch):
    """Interactive loop: load or build the index, then answer questions until 'exit'."""
    warnings.filterwarnings("ignore", category=NotOpenSSLWarning)
    index = load_or_create_index(force_fetch)
    print("AI agent is ready. Type 'exit' to quit.")
    while True:
        user_query = input("Enter your question about the 2024 Summer Olympics: ")
        if user_query.lower() == 'exit':
            break
        try:
            response = query_index(index, user_query)
            print(f"AI: {response.response}\n")
            print("Sources:")
            for node in response.source_nodes:
                print(f"- {node.metadata['source']} ({node.metadata['type']})")
            print()
        except Exception as e:
            print(f"An error occurred: {str(e)}")
            print("I apologize, but I'm unable to provide a reliable answer to this question at the moment.")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="AI agent for 2024 Summer Olympics information")
    parser.add_argument("--force-fetch", action="store_true", help="Force fetch new data from Wikipedia")
    args = parser.parse_args()
    main(args.force_fetch)
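A sketch of a typical session, assuming the script is saved locally as olympics_agent.py (the filename is hypothetical) and the .env file above is in place:

# First run: fetches the two Wikipedia pages, builds the index, persists it to ./storage
python olympics_agent.py

# Later runs reuse wikipedia_cache.pkl and ./storage; --force-fetch rebuilds both
python olympics_agent.py --force-fetch

Caching at two levels, the raw pages (wikipedia_cache.pkl) and the built index (./storage), keeps restarts cheap, while --force-fetch is a single switch for refreshing the medal data end to end.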