import os
import argparse
import operator
import shutil
from typing import Annotated, List, TypedDict, Union, Sequence

from dotenv import load_dotenv
from langchain_community.document_loaders import DirectoryLoader, TextLoader, UnstructuredHTMLLoader
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, END
from langgraph.graph.graph import CompiledGraph
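
# Assumed runtime dependencies (not pinned here): python-dotenv, langchain-community,
# langchain-huggingface, langchain-openai, langchain-text-splitters, langgraph,
# plus chromadb and sentence-transformers, which Chroma and HuggingFaceEmbeddings rely on.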

# --- Configuration Loading ---
load_dotenv()

# Load essential configs, others can be passed via CLI
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
if not OPENROUTER_API_KEY:
    raise ValueError("OPENROUTER_API_KEY not found in .env file")

# Default values from .env, overridden by CLI args
DEFAULT_DOCS_DIR = os.getenv("DEFAULT_DOCS_DIR", "./text_docs")
DEFAULT_INDEX_DIR = os.getenv("DEFAULT_INDEX_DIR", "./indexes/default_index")
DEFAULT_EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL_NAME", "all-MiniLM-L6-v2")
DEFAULT_CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", 1000))
DEFAULT_CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", 100))
DEFAULT_LLM_MODEL = os.getenv("OPENROUTER_MODEL_NAME", "mistralai/mistral-7b-instruct:free")
DEFAULT_COLLECTION_NAME = "rag_docs_collection"  # Chroma collection name
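
# Example .env file (illustrative values; the keys match the os.getenv lookups in this script):
#   OPENROUTER_API_KEY=<your OpenRouter key>
#   OPENROUTER_MODEL_NAME=mistralai/mistral-7b-instruct:free
#   EMBEDDING_MODEL_NAME=all-MiniLM-L6-v2
#   DEFAULT_DOCS_DIR=./text_docs
#   DEFAULT_INDEX_DIR=./indexes/default_index
#   CHUNK_SIZE=1000
#   CHUNK_OVERLAP=100
#   RAG_RETRIEVE_NUM=4
#   PROMPT=You are a helpful assistant. Answer using only the provided context.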


# --- Helper Functions ---

def _get_embedding_function(model_name: str):
    """Initializes and returns the embedding function."""
    print(f"Initializing embedding model: {model_name}")
    # Use a cache folder for Sentence Transformers within the project
    script_dir = os.path.dirname(os.path.abspath(__file__))
    cache_folder_path = os.path.join(script_dir, ".sentence_transformers_cache")
    os.makedirs(cache_folder_path, exist_ok=True)
    print(f"Sentence Transformers cache directory: {cache_folder_path}")
    return HuggingFaceEmbeddings(
        model_name=model_name,
        cache_folder=cache_folder_path
    )
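
# Illustrative usage (not executed here): the returned object implements the standard
# LangChain embeddings interface, e.g.
#   emb = _get_embedding_function("all-MiniLM-L6-v2")
#   vector = emb.embed_query("hello world")  # -> list of floats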


def _build_index(docs_dir: str, index_dir: str, embedding_model: str, chunk_size: int, chunk_overlap: int, collection_name: str, extension: str = "txt"):
    """Loads docs, splits, embeds, and persists a ChromaDB index."""
    print(f"\n--- Building Index ---")
    print(f"Source Documents: {docs_dir}")
    print(f"Index Directory: {index_dir}")
    print(f"Embedding Model: {embedding_model}")

    if not os.path.exists(docs_dir):
        raise FileNotFoundError(f"Document directory not found: {docs_dir}")

    # Clean existing index directory if it exists
    if os.path.exists(index_dir):
        print(f"Removing existing index directory: {index_dir}")
        shutil.rmtree(index_dir)
    os.makedirs(index_dir, exist_ok=True)

    # Load documents
    print("Loading documents...")
    loader = DirectoryLoader(
        docs_dir,
        glob=f"**/*.{extension}",
        loader_cls=TextLoader if extension == "txt" else UnstructuredHTMLLoader,
        show_progress=True,
        use_multithreading=True,
        silent_errors=True
    )
    documents = loader.load()
    if not documents:
        print(f"Warning: No .{extension} documents found in {docs_dir}. Index will be empty.")
        # Nothing to embed; leave the (now empty) index directory in place and exit.
        # Loading it later simply yields an empty collection, which _get_retriever warns about.
        return

    print(f"Loaded {len(documents)} documents.")

    # Split documents
    print("Splitting documents...")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    splits = text_splitter.split_documents(documents)
    print(f"Split into {len(splits)} chunks.")
    if not splits:
        print("Warning: No text chunks generated after splitting. Index will be empty.")
        return  # Exit if no splits

    # Embed and store
    print("Initializing embeddings and vector store...")
    embeddings = _get_embedding_function(embedding_model)
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embeddings,
        collection_name=collection_name,
        persist_directory=index_dir  # Persist to the specified directory
    )
    # vectorstore.persist()  # Persist is often implicitly called by `from_documents` with persist_directory
    print(f"Index built successfully at {index_dir} with collection '{collection_name}'")


# --- LangGraph Setup ---

class AgentState(TypedDict):
    """Represents the state of our RAG agent."""
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # We can add retrieved docs to state if needed for more complex routing
    # retrieved_docs: List[Document]
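
# Note: operator.add is the reducer for "messages", so a node that returns
# {"messages": [msg]} appends msg to the running conversation instead of
# replacing it. The nodes below rely on this append behavior.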


def _get_retriever(index_dir: str, embedding_model: str, collection_name: str):
    """Loads the retriever from a persisted ChromaDB index."""
    print(f"Loading index from: {index_dir}")
    if not os.path.exists(index_dir):
        raise FileNotFoundError(f"Index directory not found: {index_dir}. Please build the index first using the 'build' command.")

    embeddings = _get_embedding_function(embedding_model)
    vectorstore = Chroma(
        persist_directory=index_dir,
        embedding_function=embeddings,
        collection_name=collection_name
    )
    print(f"Loaded collection '{vectorstore._collection.name}' with {vectorstore._collection.count()} documents.")
    if vectorstore._collection.count() == 0:
        print("Warning: Loaded index is empty.")
    # Retrieve the top-k chunks; fall back to 4 if RAG_RETRIEVE_NUM is not set in the environment
    return vectorstore.as_retriever(search_kwargs={"k": int(os.getenv("RAG_RETRIEVE_NUM", "4"))})


def _create_graph(llm, retriever) -> CompiledGraph:
    """Creates and compiles the LangGraph RAG agent."""

    # --- Nodes ---
    def retrieve(state: AgentState):
        """Retrieves documents based on the last human message."""
        print("---NODE: Retrieving documents---")
        last_message = state["messages"][-1].content
        print(f"Retrieving based on: '{last_message[:100]}...'")  # Log query snippet
        docs = retriever.invoke(last_message)
        print(f"Retrieved {len(docs)} documents.")
        # Format context for the LLM
        context = "\n\n".join([doc.page_content for doc in docs])
        formatted_message = f"Context:\n{context}\n\nQuestion: {last_message}"
        # Return only the new, context-augmented message: because "messages" uses the
        # operator.add reducer, returning the whole history here would duplicate it.
        # A more robust design could store the docs in state (e.g. a 'retrieved_docs'
        # field) or inject the context via a system message instead.
        return {"messages": [HumanMessage(content=formatted_message)]}

    def generate(state: AgentState):
        """Generates response using the LLM."""
        print("---NODE: Generating response---")
        # Use the PROMPT env var as the system prompt; fall back to a generic instruction
        # (an assumed default) so the template never receives None.
        system_prompt = os.getenv("PROMPT", "You are a helpful assistant. Answer the question using the provided context.")
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ])
        chain = prompt | llm
        # The context is now part of the last HumanMessage added by the 'retrieve' node
        response_message = chain.invoke({"messages": state["messages"]})
        return {"messages": [response_message]}  # Add AI response to messages

    # --- Graph Definition ---
    graph = StateGraph(AgentState)
    graph.add_node("retrieve", retrieve)
    graph.add_node("generate", generate)

    graph.set_entry_point("retrieve")
    graph.add_edge("retrieve", "generate")
    graph.add_edge("generate", END)  # End after generation

    print("Compiling graph...")
    compiled_graph = graph.compile()
    print("Graph compiled.")
    return compiled_graph


def _run_agent(index_dir: str, embedding_model: str, llm_model: str, collection_name: str):
    """Loads the index, compiles the graph, and runs the interactive RAG agent."""
    print(f"\n--- Running RAG Agent ---")
    print(f"Index Directory: {index_dir}")
    print(f"Embedding Model: {embedding_model}")
    print(f"LLM: {llm_model}")

    # Load retriever
    retriever = _get_retriever(index_dir, embedding_model, collection_name)

    # Setup LLM
    llm = ChatOpenAI(
        model=llm_model,
        openai_api_key=OPENROUTER_API_KEY,
        openai_api_base="https://openrouter.ai/api/v1",
        temperature=0.7,
    )
    print("LLM initialized.")

    # Create and compile graph
    app = _create_graph(llm, retriever)

    # Interaction loop
    print("\n--- Chat with your documents (type 'exit' to quit) ---")
    while True:
        try:
            user_input = input("You: ")
            if user_input.lower() == 'exit':
                print("Exiting.")
                break
            if not user_input:
                continue

            inputs = {"messages": [HumanMessage(content=user_input)]}
            # Stream or invoke - invoke waits for the final result
            # for event in app.stream(inputs):
            #     print(event, flush=True)  # Print graph events
            final_state = app.invoke(inputs)
            ai_response = final_state['messages'][-1].content
            print(f"AI: {ai_response}")

        except KeyboardInterrupt:
            print("\nExiting.")
            break
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            # Optionally break or continue
            # break
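

# Example invocations (illustrative; substitute the actual script filename):
#   python rag_agent.py build --docs-dir ./text_docs --index-dir ./indexes/default_index
#   python rag_agent.py run --index-dir ./indexes/default_index --llm-model mistralai/mistral-7b-instruct:free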

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Build a document index or run a RAG agent.")
    subparsers = parser.add_subparsers(dest="command", required=True, help="Available commands")

    # --- Build Command ---
    parser_build = subparsers.add_parser("build", help="Build a vector index from documents.")
    parser_build.add_argument("--docs-dir", type=str, default=DEFAULT_DOCS_DIR, help="Directory containing .txt source documents.")
    parser_build.add_argument("--index-dir", type=str, default=DEFAULT_INDEX_DIR, help="Directory to save the persistent index.")
    parser_build.add_argument("--embed-model", type=str, default=DEFAULT_EMBEDDING_MODEL, help="Name of the Sentence Transformer embedding model.")
    parser_build.add_argument("--chunk-size", type=int, default=DEFAULT_CHUNK_SIZE, help="Chunk size for splitting documents.")
    parser_build.add_argument("--chunk-overlap", type=int, default=DEFAULT_CHUNK_OVERLAP, help="Chunk overlap for splitting documents.")
    parser_build.add_argument("--collection-name", type=str, default=DEFAULT_COLLECTION_NAME, help="Name for the ChromaDB collection.")
    # The dispatch lambda receives the parsed args namespace
    parser_build.set_defaults(func=lambda args: _build_index(args.docs_dir, args.index_dir, args.embed_model, args.chunk_size, args.chunk_overlap, args.collection_name))

    # --- Run Command ---
    parser_run = subparsers.add_parser("run", help="Run the RAG agent using an existing index.")
    parser_run.add_argument("--index-dir", type=str, default=DEFAULT_INDEX_DIR, help="Directory containing the persistent index.")
    parser_run.add_argument("--embed-model", type=str, default=DEFAULT_EMBEDDING_MODEL, help="Name of the Sentence Transformer embedding model (must match the one used for building).")
    parser_run.add_argument("--llm-model", type=str, default=DEFAULT_LLM_MODEL, help="Name of the OpenRouter model to use (e.g., 'mistralai/mistral-7b-instruct:free').")
    parser_run.add_argument("--collection-name", type=str, default=DEFAULT_COLLECTION_NAME, help="Name of the ChromaDB collection within the index.")
    parser_run.set_defaults(func=lambda args: _run_agent(args.index_dir, args.embed_model, args.llm_model, args.collection_name))

    # --- Parse Args and Execute ---
    args = parser.parse_args()
    args.func(args)  # Call the function associated with the chosen command