Created
September 20, 2024 11:16
-
-
Save janakiramm/30cdc8dda557379d76e85078f1fb48ef to your computer and use it in GitHub Desktop.
NIM-LangChain-RAG-Agent
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from datetime import datetime, timedelta | |
import pytz | |
import requests | |
import json | |
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings | |
from langchain_community.document_loaders import WebBaseLoader | |
from langchain_text_splitters import RecursiveCharacterTextSplitter | |
from langchain_community.vectorstores import FAISS | |
from langchain_core.tools import tool | |
from langchain_core.messages import HumanMessage | |
from langchain_core.runnables import RunnablePassthrough, RunnableLambda | |
# Environment setup
# Replace these placeholders with your actual API keys, or export
# NVIDIA_API_KEY / AEROAPI_KEY in the environment (an exported value wins).
# setdefault avoids clobbering a key that is already set in the environment.
os.environ.setdefault("NVIDIA_API_KEY", "your_nvidia_api_key")
AEROAPI_BASE_URL = "https://aeroapi.flightaware.com/aeroapi"
AEROAPI_KEY = os.environ.get("AEROAPI_KEY", "your_aero_api_key")
# Initialize LLM
llm = ChatNVIDIA(model="meta/llama-3.1-405b-instruct")
# Flight status tool
def _normalize_flight_id(raw_flight_id):
    """
    Clean up a flight ID produced by the LLM.

    Drops everything up to and including a "flight_id=" marker (in case the
    model echoes the argument name), then strips surrounding whitespace and
    quote characters.
    """
    flight_id = str(raw_flight_id)
    if "flight_id=" in flight_id:
        flight_id = flight_id.split("flight_id=", 1)[1]
    return flight_id.strip().strip("'\"")


def _utc_to_local(utc_date_str, local_timezone_str):
    """
    Convert an AeroAPI UTC timestamp ("YYYY-MM-DDTHH:MM:SSZ") to local time.

    Args:
        utc_date_str: UTC timestamp string, or None/empty when no time is known.
        local_timezone_str: Time zone name accepted by ``pytz.timezone``,
            or None/empty.

    Returns:
        "YYYY-MM-DD HH:MM:SS" in the given zone; the UTC time suffixed with
        " UTC" when no zone is given; "unknown" when there is no timestamp.
    """
    if not utc_date_str:
        return "unknown"
    utc_datetime = datetime.strptime(utc_date_str, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=pytz.utc)
    if not local_timezone_str:
        return utc_datetime.strftime("%Y-%m-%d %H:%M:%S") + " UTC"
    local_datetime = utc_datetime.astimezone(pytz.timezone(local_timezone_str))
    return local_datetime.strftime("%Y-%m-%d %H:%M:%S")


def _fetch_flight_data(flight_id, session, timeout=30):
    """
    Return the first flight AeroAPI reports for *flight_id* from today to tomorrow.

    Args:
        flight_id: Already-normalized flight identifier.
        session: requests.Session carrying the AeroAPI key header.
        timeout: Seconds before the HTTP request is abandoned, so a stalled
            connection cannot hang the agent.

    Raises:
        requests.HTTPError: for a non-2xx response.
        requests.Timeout: when the request exceeds *timeout*.
        ValueError: when the response lists no flights.
    """
    from urllib.parse import quote

    # AeroAPI timestamps are UTC (see the 'Z' format in _utc_to_local), so build
    # the query window from the UTC date rather than the host's local date.
    today_utc = datetime.now(pytz.utc).date()
    params = {
        "start": today_utc.strftime("%Y-%m-%d"),
        "end": (today_utc + timedelta(days=1)).strftime("%Y-%m-%d"),
    }
    # The ID comes from LLM output; percent-encode it so it cannot alter the URL path.
    url = f"{AEROAPI_BASE_URL}/flights/{quote(flight_id, safe='')}"
    response = session.get(url, params=params, timeout=timeout)
    response.raise_for_status()
    flights = response.json().get("flights", [])
    if not flights:
        raise ValueError(f"No flight data found for flight ID {flight_id}.")
    return flights[0]


# Raises requests exceptions for HTTP failures/timeouts and ValueError when no
# flight is found; use_flight_status_tool converts these into error strings.
@tool
def get_flight_status(flight_id: str) -> str:
    """
    Returns flight information for a given flight ID (e.g. "EK524"):
    origin and destination cities, current status, and departure and
    arrival times in each airport's local time zone.
    """
    # Normalize once up front so the returned sentence shows the clean ID too.
    flight_id = _normalize_flight_id(flight_id)

    # Context manager guarantees the session's connections are released.
    with requests.Session() as session:
        session.headers.update({"x-apikey": AEROAPI_KEY})
        flight_data = _fetch_flight_data(flight_id, session)

    # Prefer the estimated times; fall back to the scheduled ones when absent.
    dep_key = "estimated_out" if flight_data.get("estimated_out") else "scheduled_out"
    arr_key = "estimated_in" if flight_data.get("estimated_in") else "scheduled_in"

    origin = flight_data["origin"]
    destination = flight_data["destination"]
    depart_time = _utc_to_local(flight_data.get(dep_key), origin.get("timezone"))
    arrival_time = _utc_to_local(flight_data.get(arr_key), destination.get("timezone"))

    return (
        f"The current status of flight {flight_id} from {origin['city']} to {destination['city']} "
        f"is {flight_data['status']} with departure time at {depart_time} and arrival time at "
        f"{arrival_time}."
    )
# Tool-bound variant of the LLM used for flight-status lookups.
# tool_choice="required" asks the model to always respond with a call to one
# of the bound tools (only get_flight_status is bound here).
llm_with_tools = llm.bind_tools(
    tools=[get_flight_status],
    tool_choice="required",
)
# Document loading and processing
def load_and_process_documents(url, chunk_size=500, chunk_overlap=50):
    """
    Load a web page and split it into overlapping text chunks.

    Args:
        url: Page to fetch with WebBaseLoader.
        chunk_size: Maximum chunk length for RecursiveCharacterTextSplitter
            (default 500, the previously hard-coded value).
        chunk_overlap: Length shared between consecutive chunks (default 50).

    Returns:
        The list of Document chunks produced by the splitter.
    """
    docs = WebBaseLoader(url).load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return text_splitter.split_documents(docs)
# Vector store setup
def setup_vector_store(documents):
    """
    Embed *documents* with NVIDIA embeddings and return a FAISS retriever.

    Args:
        documents: Non-empty list of Document chunks to index.

    Returns:
        A retriever backed by an in-memory FAISS index of the documents.

    Raises:
        ValueError: if *documents* is empty. Without this check the failure
            would surface deeper inside the embedding/FAISS code with a less
            obvious error (e.g. when the source page yielded no text).
    """
    if not documents:
        raise ValueError(
            "No documents to index; check that the source page loaded and produced chunks."
        )
    embeddings = NVIDIAEmbeddings()
    vector_store = FAISS.from_documents(documents, embeddings)
    return vector_store.as_retriever()
# Fetch the Emirates cabin-baggage rules page and chunk it for retrieval.
_CABIN_BAGGAGE_URL = "https://www.emirates.com/in/english/before-you-fly/baggage/cabin-baggage-rules/"
documents = load_and_process_documents(_CABIN_BAGGAGE_URL)

# Index the chunks so `retrieve` can look up passages relevant to a question.
retriever = setup_vector_store(documents)
# Retrieval function
def retrieve(input_dict):
    """
    Answer a question from the indexed documents, falling back to the flight tool.

    The retrieved chunks are given to the LLM with an instruction to reply
    "Unable to answer based on the given context." when they do not contain
    the answer. In that case the question is routed to
    ``use_flight_status_tool`` instead.

    Args:
        input_dict: Mapping with a "question" key holding the user's question.

    Returns:
        dict with keys "context" (retrieved chunk text joined by spaces),
        "question", and "answer".
    """
    question = input_dict["question"]
    docs = retriever.invoke(question)
    context = " ".join(doc.page_content for doc in docs)
    evaluation_prompt = (
        f"Based on the following context, can you answer the question '{question}'? "
        "If yes, provide the answer. If no, respond with 'Unable to answer based on the given context.'\n\n"
        f"Context: {context}"
    )
    evaluation_result = llm.invoke([HumanMessage(content=evaluation_prompt)])
    reply = evaluation_result.content.strip()
    # Match the sentinel case-insensitively so a reply such as "unable to answer
    # based on the given context" still triggers the fallback.
    if "unable to answer based on the given context" in reply.lower():
        final_answer = use_flight_status_tool(question)
    else:
        final_answer = reply
    return {
        "context": context,
        "question": question,
        "answer": final_answer,
    }
def use_flight_status_tool(question):
    """
    Answer a flight-status question by letting the tool-bound LLM call a tool.

    Args:
        question: The user's question; the model is expected to pick out the
            flight ID and emit a get_flight_status tool call.

    Returns:
        The tool's status sentence on success. Otherwise a human-readable
        message: no tool call, unknown tool name, or an error raised while
        running the tool. Errors are returned as text, not propagated.
    """
    ai_msg = llm_with_tools.invoke([HumanMessage(content=question)])
    tool_calls = getattr(ai_msg, "tool_calls", None)
    if not tool_calls:
        return "Unable to retrieve flight status information."

    # Only the first tool call is honoured; get_flight_status is the only bound tool.
    tool_call = tool_calls[0]
    available_tools = {"get_flight_status": get_flight_status}
    requested_name = tool_call.get("name")
    selected_tool = available_tools.get(str(requested_name or "").lower())
    if selected_tool is None:
        return f"Error retrieving flight status: unknown tool {requested_name!r}"

    try:
        # Pass the whole args dict so the tool validates it against its schema;
        # a missing/misnamed argument then yields a descriptive error message.
        return selected_tool.invoke(tool_call.get("args") or {})
    except Exception as e:  # boundary: report any tool/API failure as text
        return f"Error retrieving flight status: {e}"
# RAG chain: pass the input dict through unchanged, run retrieval/answering,
# then keep only the "answer" entry of the resulting dict.
rag_chain = (
    RunnablePassthrough()
    | RunnableLambda(retrieve)
    | RunnableLambda(lambda result: result["answer"])
)
def process_question(question):
    """
    Run *question* through ``rag_chain`` and return the answer it produces.

    Args:
        question: Free-text user question.

    Returns:
        The "answer" value selected by the last step of the chain.
    """
    chain_input = {"question": question}
    answer = rag_chain.invoke(chain_input)
    return answer
# Main execution
if __name__ == "__main__":
    # Demo: a flight-status question (expected to fall through to the flight
    # tool) and a cabin-baggage question (expected to be answered from the page).
    sample_questions = (
        "What is flight status of EK524?",
        "What is the cabin baggage size?",
    )
    for sample in sample_questions:
        answer = process_question(sample)
        print(f"Question: {sample}\nAnswer: {answer}\n")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment