| """ | |
| project/ | |
| │ | |
| ├── state.py # Defines the agent's state | |
| ├── prompts.py # Manages prompt templates | |
| ├── config.py # Handles configuration and initialization | |
| ├── tools.py # Contains utility functions and tools | |
| ├── guardrails.py # Implements safety checks | |
| ├── nodes.py # Defines workflow nodes for LangGraph | |
| ├── workflow.py # Orchestrates the main workflow using LangGraph | |
| └── main.py # Runs the agent with example queries | |
| """ | |
# 1. state.py - Agent State Definition
# This module defines the AgentState using a TypedDict to manage the agent's state throughout the workflow.
from typing import TypedDict, Annotated, List
from langchain_core.messages import HumanMessage, AIMessage
import operator


class AgentState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    intent: str
    confluence_context: List[str]
    databricks_context: List[str]
    generated_sql: str
    sql_attempts: int
    sql_error: str
    needs_clarification: bool
    final_answer: str
    rewritten_query: str
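
# Illustrative sketch (not part of the module): the Annotated reducer above
# tells LangGraph to merge `messages` updates with operator.add, i.e. list
# concatenation, instead of overwriting, so conversation history accumulates:
#
#   existing = [HumanMessage(content="hi")]
#   update = [AIMessage(content="hello")]
#   merged = operator.add(existing, update)  # -> both messages, history preserved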

# 2. prompts.py - Prompt Templates Management
# This module defines all prompt templates and a PromptManager class to handle them.
from langchain.prompts import PromptTemplate


class PromptManager:
    def __init__(self):
        self.intent_template = PromptTemplate(
            input_variables=["question"],
            template="""Classify this question into one of: 'confluence', 'databricks', 'both', or 'ambiguous':
Question: {question}
Provide a one-word response: """
        )
        self.sql_generation_template = PromptTemplate(
            input_variables=["query", "metadata"],
            template="Generate a SQL query for '{query}' using metadata: {metadata}"
        )
        self.sql_refinement_template = PromptTemplate(
            input_variables=["previous_sql", "error", "metadata"],
            template="Refine this SQL query '{previous_sql}' that caused error '{error}' using metadata: {metadata}"
        )
        self.answer_template = PromptTemplate(
            input_variables=["query", "confluence_ctx", "databricks_ctx"],
            template="""Answer this question: '{query}'
Using Confluence context: {confluence_ctx}
And Databricks context: {databricks_ctx}
Provide a concise, accurate response."""
        )
        self.safety_template = PromptTemplate(
            input_variables=["answer"],
            template="Is this answer safe and appropriate for all audiences? Answer 'yes' or 'no': {answer}"
        )

    def get_intent_prompt(self, question: str) -> str:
        return self.intent_template.format(question=question)

    def get_sql_generation_prompt(self, query: str, metadata: str) -> str:
        return self.sql_generation_template.format(query=query, metadata=metadata)

    def get_sql_refinement_prompt(self, previous_sql: str, error: str, metadata: str) -> str:
        return self.sql_refinement_template.format(previous_sql=previous_sql, error=error, metadata=metadata)

    def get_answer_prompt(self, query: str, confluence_ctx: str, databricks_ctx: str) -> str:
        return self.answer_template.format(query=query, confluence_ctx=confluence_ctx, databricks_ctx=databricks_ctx)

    def get_safety_prompt(self, answer: str) -> str:
        return self.safety_template.format(answer=answer)
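
# Example usage of PromptManager (illustrative; the question is made up):
#
#   pm = PromptManager()
#   print(pm.get_intent_prompt("How many rows are in the sales table?"))
#   # -> the classification template with the question spliced in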

# 3. config.py - Configuration and Initialization
# This module initializes the language model, embeddings, vector stores, and prompt manager.
import os
from langchain_aws import ChatBedrock
from langchain_community.vectorstores import PGVector
from langchain_openai import OpenAIEmbeddings
from prompts import PromptManager

# Load environment variables
os.environ["AWS_REGION"] = "us-east-1"
CONFLUENCE_CONNECTION_STRING = os.getenv("CONFLUENCE_DB_URL")
DATABRICKS_CONNECTION_STRING = os.getenv("DATABRICKS_DB_URL")

# Initialize LLM
llm = ChatBedrock(
    model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
    region_name="us-east-1",
    model_kwargs={"temperature": 0.7}
)

# Initialize embeddings
embeddings = OpenAIEmbeddings()

# Initialize vector stores
confluence_store = PGVector(
    collection_name="confluence_docs",
    connection_string=CONFLUENCE_CONNECTION_STRING,
    embedding_function=embeddings
)
databricks_store = PGVector(
    collection_name="databricks_metadata",
    connection_string=DATABRICKS_CONNECTION_STRING,
    embedding_function=embeddings
)

# Initialize prompt manager
prompt_manager = PromptManager()

# 4. tools.py - Utility Functions and Tools
# This module contains tools for SQL generation, query rewriting, result reranking, and simulated SQL execution.
from typing import List

from langchain.tools import tool
from sentence_transformers import CrossEncoder

from config import llm, prompt_manager

reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')


@tool
def generate_sql(user_query: str = None, previous_sql: str = None, error: str = None, metadata: List[str] = None) -> str:
    """
    Generates or refines a SQL query based on inputs.

    Args:
        user_query (str, optional): The user's original query for initial SQL generation.
        previous_sql (str, optional): The previous SQL query to refine if an error occurred.
        error (str, optional): The error message from the previous SQL execution.
        metadata (List[str]): Metadata to assist SQL generation.

    Returns:
        str: The generated or refined SQL query.

    Raises:
        ValueError: If inputs are invalid.
    """
    if previous_sql and error:
        prompt = prompt_manager.get_sql_refinement_prompt(previous_sql, error, metadata)
    elif user_query:
        prompt = prompt_manager.get_sql_generation_prompt(user_query, metadata)
    else:
        raise ValueError("Must provide either user_query or previous_sql and error")
    response = llm.invoke(prompt)
    return response.content


def rewrite_query(query: str) -> str:
    prompt = f"Rewrite this query to make it clearer and more specific: '{query}'"
    return llm.invoke(prompt).content


def rerank_results(results: List[str], query: str) -> List[str]:
    pairs = [(query, doc) for doc in results]
    scores = reranker.predict(pairs)
    # Sort by score only so tied scores never fall back to comparing documents.
    sorted_results = [doc for _, doc in sorted(zip(scores, results), key=lambda pair: pair[0], reverse=True)]
    return sorted_results


def execute_databricks_sql(sql: str) -> dict:
    if "error" in sql.lower():
        return {"status": "error", "error": "Simulated SQL error"}
    return {"status": "success", "results": ["Simulated results"]}

# 5. guardrails.py - Safety Checks
# This module implements safety checks for answers and SQL queries.
from config import llm, prompt_manager


def check_answer_safety(answer: str) -> bool:
    safety_prompt = prompt_manager.get_safety_prompt(answer)
    safety_response = llm.invoke(safety_prompt).content.strip().lower()
    if safety_response != "yes":
        return False
    forbidden_words = ["confidential", "password", "secret"]
    return not any(word in answer.lower() for word in forbidden_words)


def is_safe_sql(sql: str) -> bool:
    sql_lower = sql.lower().strip()
    return (sql_lower.startswith("select") and
            "insert" not in sql_lower and
            "update" not in sql_lower and
            "delete" not in sql_lower and
            "drop" not in sql_lower)

# 6. nodes.py - Workflow Nodes
# This module defines the nodes for the LangGraph workflow.
from langchain_core.messages import AIMessage

from state import AgentState
from config import llm, prompt_manager, confluence_store, databricks_store
from tools import generate_sql, rewrite_query, rerank_results, execute_databricks_sql
from guardrails import check_answer_safety, is_safe_sql


def parse_question(state: AgentState) -> AgentState:
    question = state["messages"][-1].content
    prompt = prompt_manager.get_intent_prompt(question)
    response = llm.invoke(prompt).content.strip()
    state["intent"] = response
    return state


def rewrite_query_node(state: AgentState) -> AgentState:
    question = state["messages"][-1].content
    state["rewritten_query"] = rewrite_query(question)
    return state


def route_context(state: AgentState) -> AgentState:
    query = state["rewritten_query"]
    if state["intent"] in ["confluence", "both"]:
        results = confluence_store.similarity_search(query, k=5)
        state["confluence_context"] = rerank_results([doc.page_content for doc in results], query)
    if state["intent"] in ["databricks", "both"]:
        results = databricks_store.similarity_search(query, k=5)
        state["databricks_context"] = rerank_results([doc.page_content for doc in results], query)
    if state["intent"] == "ambiguous":
        state["needs_clarification"] = True
    return state


def clarify_question(state: AgentState) -> AgentState:
    # Use .get(): route_context only sets this flag for ambiguous intents.
    if state.get("needs_clarification"):
        state["messages"].append(AIMessage(content="Please clarify your query."))
        state["needs_clarification"] = False
    return state


def generate_sql_node(state: AgentState) -> AgentState:
    metadata = state.get("databricks_context", [])
    # generate_sql is decorated with @tool, so invoke it with a dict rather than calling it directly.
    state["generated_sql"] = generate_sql.invoke({"user_query": state["rewritten_query"], "metadata": metadata})
    state["sql_attempts"] = state.get("sql_attempts", 0) + 1
    return state


def execute_sql_node(state: AgentState) -> AgentState:
    if not is_safe_sql(state["generated_sql"]):
        state["sql_error"] = "Unsafe SQL query detected."
        return state
    result = execute_databricks_sql(state["generated_sql"])
    if result["status"] == "error":
        state["sql_error"] = result["error"]
    else:
        state["databricks_context"] = result["results"]
    return state


def generate_answer(state: AgentState) -> AgentState:
    # Contexts may be absent depending on intent, so default to empty lists.
    prompt = prompt_manager.get_answer_prompt(
        state["rewritten_query"],
        "\n".join(state.get("confluence_context", [])),
        "\n".join(state.get("databricks_context", []))
    )
    answer = llm.invoke(prompt).content
    if not check_answer_safety(answer):
        state["final_answer"] = "Answer deemed unsafe or inappropriate."
    else:
        state["final_answer"] = answer
    state["messages"].append(AIMessage(content=state["final_answer"]))
    return state

# 7. workflow.py - Main Workflow Orchestration
# This module defines the LangGraph workflow, connecting all nodes.
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver

from state import AgentState
from nodes import parse_question, rewrite_query_node, route_context, clarify_question, generate_sql_node, execute_sql_node, generate_answer

workflow = StateGraph(AgentState)

# Add nodes
workflow.add_node("parse_question", parse_question)
workflow.add_node("rewrite_query", rewrite_query_node)
workflow.add_node("route_context", route_context)
workflow.add_node("clarify_question", clarify_question)
workflow.add_node("generate_sql", generate_sql_node)
workflow.add_node("execute_sql", execute_sql_node)
workflow.add_node("generate_answer", generate_answer)

# Define edges
workflow.add_edge("parse_question", "rewrite_query")
workflow.add_edge("rewrite_query", "route_context")
workflow.add_edge("route_context", "clarify_question")
workflow.add_conditional_edges(
    "clarify_question",
    lambda state: "generate_sql" if state["intent"] in ["databricks", "both"] else "generate_answer",
    {"generate_sql": "generate_sql", "generate_answer": "generate_answer"}
)
workflow.add_edge("generate_sql", "execute_sql")
workflow.add_edge("execute_sql", "generate_answer")
workflow.add_edge("generate_answer", END)

# Set entry point
workflow.set_entry_point("parse_question")

# Compile with memory
checkpointer = MemorySaver()
graph = workflow.compile(checkpointer=checkpointer)

# 8. main.py - Running the Agent
# This module runs the agent with example queries.
from langchain_core.messages import HumanMessage

from workflow import graph
from state import AgentState


def run_agent(question: str, thread_id: str = "thread_1"):
    initial_state = AgentState(messages=[HumanMessage(content=question)])
    result = graph.invoke(initial_state, config={"configurable": {"thread_id": thread_id}})
    return result["messages"][-1].content


if __name__ == "__main__":
    print(run_agent("What is the process for onboarding in Confluence?"))
    print(run_agent("How many rows are in the sales table in Databricks?"))
    print(run_agent("What’s the latest update on project X?"))

# This modular structure separates concerns effectively:
# State Management: state.py defines the agent's state.
# Prompt Templates: prompts.py manages all prompt templates.
# Configuration: config.py initializes models and stores.
# Tools: tools.py provides utility functions and tools.
# Safety Checks: guardrails.py ensures safe outputs.
# Workflow Nodes: nodes.py implements individual workflow steps.
# Main Workflow: workflow.py orchestrates the entire process.
# Execution: main.py runs the agent with example queries.
# Each module has clear dependencies, avoiding circular imports, and the code is easier to maintain and extend.


# ======================== next gist file ========================

# pip install langchain langchain-aws langchain-community langchain-openai langgraph boto3
"""
How It Works

State Management:
- The AgentState tracks conversation history, intent, context, SQL details, and more.
- DynamoDBSaver persists state across invocations, supporting concurrent users.

Workflow:
- Parse Question: Determines if the query relates to Confluence, Databricks, both, or is ambiguous.
- Rewrite Query: Enhances query clarity.
- Route Context: Retrieves relevant documents from PGVector stores.
- Clarify Question: Prompts for clarification if needed.
- Generate/Execute SQL: Creates and runs SQL queries for Databricks questions, with retries on errors.
- Generate Answer: Combines history (last 5 messages) and context to produce a response.

Guardrails:
- Ensures SQL queries are read-only.
- Validates answer safety and appropriateness.

Agentic Behavior:
- The LangGraph workflow enables decision-making (e.g., SQL generation) and multi-step processing.
"""
import os
import json
import operator
from typing import TypedDict, Annotated, List

import boto3
from langchain_aws import ChatBedrock
from langchain_community.vectorstores import PGVector
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.base import BaseCheckpointSaver


# --- DynamoDB State Persistence ---
class DynamoDBSaver(BaseCheckpointSaver):
    """Custom checkpointer to save and load conversation state from DynamoDB.

    Simplified illustration: langgraph's full checkpointer protocol has more
    methods (e.g. get_tuple, list), so treat this class as a sketch of the idea.
    """

    def __init__(self, table_name: str = "conversation_state"):
        self.dynamodb = boto3.resource('dynamodb')
        self.table = self.dynamodb.Table(table_name)

    def get(self, config: dict) -> dict:
        thread_id = config["configurable"]["thread_id"]
        response = self.table.get_item(Key={"thread_id": thread_id})
        return json.loads(response["Item"]["state"]) if "Item" in response else {}

    def put(self, config: dict, checkpoint: dict) -> None:
        thread_id = config["configurable"]["thread_id"]
        self.table.put_item(Item={"thread_id": thread_id, "state": json.dumps(checkpoint)})
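
# A hedged sketch for provisioning the backing table with boto3 (the table and
# key names match DynamoDBSaver's defaults above; the billing mode is an assumption):
#
#   dynamodb = boto3.client("dynamodb")
#   dynamodb.create_table(
#       TableName="conversation_state",
#       KeySchema=[{"AttributeName": "thread_id", "KeyType": "HASH"}],
#       AttributeDefinitions=[{"AttributeName": "thread_id", "AttributeType": "S"}],
#       BillingMode="PAY_PER_REQUEST",
#   )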

# --- Agent State Definition ---
class AgentState(TypedDict):
    """Schema for the agent's state, tracking conversation and processing details."""
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]  # Conversation history
    intent: str  # Classified intent: 'confluence', 'databricks', 'both', or 'ambiguous'
    confluence_context: List[str]  # Retrieved Confluence context
    databricks_context: List[str]  # Retrieved Databricks context
    generated_sql: str  # Generated SQL query
    sql_attempts: int  # Number of SQL generation attempts
    sql_error: str  # Error message from SQL execution
    needs_clarification: bool  # Flag for ambiguous queries
    final_answer: str  # Generated response
    rewritten_query: str  # Rewritten user query for clarity


# --- Configurations ---
os.environ["AWS_REGION"] = "us-east-1"
CONFLUENCE_CONNECTION_STRING = os.getenv("CONFLUENCE_DB_URL")  # Set in environment
DATABRICKS_CONNECTION_STRING = os.getenv("DATABRICKS_DB_URL")  # Set in environment

# Initialize Bedrock model
llm = ChatBedrock(
    model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
    region_name="us-east-1",
    model_kwargs={"temperature": 0.7}
)

# Initialize vector stores with OpenAI embeddings (for simplicity; adjust as needed)
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
confluence_store = PGVector(
    collection_name="confluence_docs",
    connection_string=CONFLUENCE_CONNECTION_STRING,
    embedding_function=embeddings
)
databricks_store = PGVector(
    collection_name="databricks_metadata",
    connection_string=DATABRICKS_CONNECTION_STRING,
    embedding_function=embeddings
)

# --- Prompt Templates ---
intent_template = PromptTemplate(
    input_variables=["question"],
    template="""Classify this question into one of: 'confluence', 'databricks', 'both', or 'ambiguous':
Question: {question}
Provide a one-word response: """
)
sql_generation_template = PromptTemplate(
    input_variables=["query", "metadata"],
    template="Generate a SQL query for '{query}' using metadata: {metadata}"
)
sql_refinement_template = PromptTemplate(
    input_variables=["previous_sql", "error", "metadata"],
    template="Refine this SQL query '{previous_sql}' that caused error '{error}' using metadata: {metadata}"
)
answer_template = PromptTemplate(
    input_variables=["history", "query", "confluence_ctx", "databricks_ctx"],
    template="""Conversation history:
{history}
Current question: {query}
Using Confluence context: {confluence_ctx}
And Databricks context: {databricks_ctx}
Provide a concise, accurate response."""
)
safety_template = PromptTemplate(
    input_variables=["answer"],
    template="Is this answer safe and appropriate? Answer 'yes' or 'no': {answer}"
)

# --- Tools and Utilities ---
@tool
def generate_sql(user_query: str = None, previous_sql: str = None, error: str = None, metadata: List[str] = None) -> str:
    """Generate or refine a SQL query based on input."""
    if previous_sql and error:
        prompt = sql_refinement_template.format(previous_sql=previous_sql, error=error, metadata=metadata)
    elif user_query:
        prompt = sql_generation_template.format(query=user_query, metadata=metadata)
    else:
        raise ValueError("Must provide either user_query or previous_sql and error")
    response = llm.invoke(prompt)
    return response.content


def execute_databricks_sql(sql: str) -> dict:
    """Simulate SQL execution on Databricks (replace with actual implementation)."""
    if "error" in sql.lower():  # Simulated error condition
        return {"status": "error", "error": "Simulated SQL error"}
    return {"status": "success", "results": ["Simulated results"]}


def rewrite_query(query: str) -> str:
    """Rewrite the user's query for clarity and specificity."""
    prompt = f"Rewrite this query to make it clearer and more specific: '{query}'"
    return llm.invoke(prompt).content

# --- Guardrails ---
def check_answer_safety(answer: str) -> bool:
    """Ensure the generated answer is safe and appropriate."""
    safety_prompt = safety_template.format(answer=answer)
    safety_response = llm.invoke(safety_prompt).content.strip().lower()
    return safety_response == "yes"


def is_safe_sql(sql: str) -> bool:
    """Check if the SQL query is safe (read-only)."""
    sql_lower = sql.lower().strip()
    return sql_lower.startswith("select") and all(
        keyword not in sql_lower for keyword in ["insert", "update", "delete", "drop"]
    )

# --- Workflow Nodes ---
def parse_question(state: AgentState) -> AgentState:
    """Parse the user's question to determine intent."""
    question = state["messages"][-1].content
    prompt = intent_template.format(question=question)
    response = llm.invoke(prompt).content.strip()
    state["intent"] = response
    return state


def rewrite_query_node(state: AgentState) -> AgentState:
    """Rewrite the query for better processing."""
    question = state["messages"][-1].content
    state["rewritten_query"] = rewrite_query(question)
    return state


def route_context(state: AgentState) -> AgentState:
    """Retrieve context from Confluence and/or Databricks based on intent."""
    query = state["rewritten_query"]
    if state["intent"] in ["confluence", "both"]:
        results = confluence_store.similarity_search(query, k=3)
        state["confluence_context"] = [doc.page_content for doc in results]
    if state["intent"] in ["databricks", "both"]:
        results = databricks_store.similarity_search(query, k=3)
        state["databricks_context"] = [doc.page_content for doc in results]
    if state["intent"] == "ambiguous":
        state["needs_clarification"] = True
    return state


def clarify_question(state: AgentState) -> AgentState:
    """Ask for clarification if the query is ambiguous."""
    # Use .get(): route_context only sets this flag for ambiguous intents.
    if state.get("needs_clarification"):
        state["messages"].append(AIMessage(content="Please clarify your query."))
        state["needs_clarification"] = False
    return state

def generate_sql_node(state: AgentState) -> AgentState:
    """Generate a SQL query for Databricks-related questions."""
    if state["intent"] in ["databricks", "both"] and state.get("databricks_context"):
        metadata = state["databricks_context"]
        # generate_sql is decorated with @tool, so invoke it with a dict rather than calling it directly.
        state["generated_sql"] = generate_sql.invoke({"user_query": state["rewritten_query"], "metadata": metadata})
        state["sql_attempts"] = state.get("sql_attempts", 0) + 1
    return state


def execute_sql_node(state: AgentState) -> AgentState:
    """Execute the SQL query with a feedback loop for retries."""
    if state.get("generated_sql"):
        if not is_safe_sql(state["generated_sql"]):
            state["databricks_context"] = ["Unsafe SQL query detected."]
            state["sql_error"] = None
        else:
            result = execute_databricks_sql(state["generated_sql"])
            if result["status"] == "success":
                state["databricks_context"] = result["results"]
                state["sql_error"] = None
            else:
                state["sql_error"] = result["error"]
                if state["sql_attempts"] < 3:
                    state["generated_sql"] = generate_sql.invoke({
                        "previous_sql": state["generated_sql"],
                        "error": state["sql_error"],
                        "metadata": state["databricks_context"]
                    })
                    state["sql_attempts"] += 1  # Count the retry so the loop terminates
                    state = execute_sql_node(state)  # Recursive retry
                else:
                    state["databricks_context"] = ["Unable to retrieve data due to persistent errors."]
                    state["sql_error"] = None
    return state

def generate_answer(state: AgentState) -> AgentState:
    """Generate the final answer using conversation history and context."""
    history = "\n".join([f"{msg.type}: {msg.content}" for msg in state["messages"][-5:]])
    confluence_ctx = "\n".join(state.get("confluence_context", []))
    databricks_ctx = "\n".join(state.get("databricks_context", []))
    prompt = answer_template.format(
        history=history,
        query=state["rewritten_query"],
        confluence_ctx=confluence_ctx,
        databricks_ctx=databricks_ctx
    )
    answer = llm.invoke(prompt).content
    if check_answer_safety(answer):
        state["final_answer"] = answer
    else:
        state["final_answer"] = "I'm sorry, but I can't provide that information."
    state["messages"].append(AIMessage(content=state["final_answer"]))
    return state

# --- Workflow Definition ---
workflow = StateGraph(AgentState)

# Add nodes to the workflow
workflow.add_node("parse_question", parse_question)
workflow.add_node("rewrite_query", rewrite_query_node)
workflow.add_node("route_context", route_context)
workflow.add_node("clarify_question", clarify_question)
workflow.add_node("generate_sql", generate_sql_node)
workflow.add_node("execute_sql", execute_sql_node)
workflow.add_node("generate_answer", generate_answer)

# Define the flow of execution
workflow.set_entry_point("parse_question")
workflow.add_edge("parse_question", "rewrite_query")
workflow.add_edge("rewrite_query", "route_context")
workflow.add_edge("route_context", "clarify_question")
workflow.add_conditional_edges(
    "clarify_question",
    lambda state: "generate_sql" if state["intent"] in ["databricks", "both"] else "generate_answer",
    {"generate_sql": "generate_sql", "generate_answer": "generate_answer"}
)
workflow.add_edge("generate_sql", "execute_sql")
workflow.add_edge("execute_sql", "generate_answer")
workflow.add_edge("generate_answer", END)

# Compile the workflow with state persistence
checkpointer = DynamoDBSaver(table_name="conversation_state")  # Assumes table exists
graph = workflow.compile(checkpointer=checkpointer)


# --- Main Function to Run the Agent ---
def run_agent(question: str, thread_id: str = "thread_1") -> str:
    """Run the agent with a user question and return the response."""
    initial_state = {"messages": [HumanMessage(content=question)]}
    result = graph.invoke(initial_state, config={"configurable": {"thread_id": thread_id}})
    return result["messages"][-1].content


# --- Example Usage ---
if __name__ == "__main__":
    response = run_agent("What is the process for onboarding in Confluence?")
    print(response)


# ======================== next gist file ========================

import logging
from typing import TypedDict, List, Optional

from langchain_core.messages import HumanMessage
from langchain_aws import ChatBedrock
from langchain_community.vectorstores import PGVector
from langchain_community.embeddings import BedrockEmbeddings
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from databricks import sql
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain.prompts import PromptTemplate

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Bedrock Sonnet model
llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    region_name="us-east-1"  # Replace with your AWS region
)

# Initialize embedding function
embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")

# Initialize PGVector stores (replace connection strings)
confluence_store = PGVector(
    collection_name="confluence",
    connection_string="postgresql+psycopg2://user:password@localhost:5432/dbname",
    embedding_function=embeddings
)
databricks_store = PGVector(
    collection_name="databricks_metadata",
    connection_string="postgresql+psycopg2://user:password@localhost:5432/dbname",
    embedding_function=embeddings
)

# Databricks SQL connection (replace with your credentials)
databricks_connection = sql.connect(
    server_hostname="your_databricks_host",
    http_path="your_http_path",
    access_token="your_access_token"
)

# Define State schema
class State(TypedDict):
    question: str  # Original user question
    rewritten_query: Optional[str]  # Rewritten query for retrieval
    intent: Optional[str]  # Intent: "confluence", "databricks_metadata", or "sql_generation"
    confluence_docs: List[str]  # Retrieved Confluence documents
    databricks_metadata: List[str]  # Retrieved Databricks metadata
    sql_query: Optional[str]  # Generated SQL query
    sql_result: Optional[str]  # Raw SQL execution result
    answer: Optional[str]  # Intermediate answer
    clarification_needed: bool  # Flag for clarification
    final_answer: Optional[str]  # Final output answer
    attempt_count: int  # Retry counter for SQL execution


# Define Structured Output Schemas
question_analyzer_schema = [
    ResponseSchema(name="intent", description="The classified intent: 'confluence', 'databricks_metadata', or 'sql_generation'"),
    ResponseSchema(name="clarification_needed", description="Boolean indicating if clarification is needed", type="boolean")
]
query_rewriter_schema = [
    ResponseSchema(name="rewritten_query", description="The rewritten query optimized for retrieval")
]
reranker_schema = [
    ResponseSchema(name="indices", description="List of indices of the top 3 most relevant documents", type="list")
]
sql_generation_schema = [
    ResponseSchema(name="sql_query", description="The generated SQL query")
]
clarification_schema = [
    ResponseSchema(name="clarification_question", description="A follow-up question to clarify intent")
]

# Initialize Structured Output Parsers
question_analyzer_parser = StructuredOutputParser.from_response_schemas(question_analyzer_schema)
query_rewriter_parser = StructuredOutputParser.from_response_schemas(query_rewriter_schema)
reranker_parser = StructuredOutputParser.from_response_schemas(reranker_schema)
sql_generation_parser = StructuredOutputParser.from_response_schemas(sql_generation_schema)
clarification_parser = StructuredOutputParser.from_response_schemas(clarification_schema)
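
# Illustrative: each parser turns its schema into JSON format instructions that
# are spliced into the prompts below, and parses the model's reply back into a
# dict. For example (reply shape assumed, not verbatim model output):
#
#   print(question_analyzer_parser.get_format_instructions())
#   # -> instructions asking for a ```json block with keys "intent"
#   #    and "clarification_needed"
#   result = question_analyzer_parser.parse(
#       '```json\n{"intent": "confluence", "clarification_needed": false}\n```'
#   )
#   # -> {"intent": "confluence", "clarification_needed": False}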

# Define Nodes with Claude-Optimized Prompts
def question_analyzer(state: State) -> State:
    logger.info("Entering question_analyzer stage")
    prompt_template = PromptTemplate(
        input_variables=["question"],
        template="""
You are an expert intent classifier. Given this question: "{question}", determine its intent. Choose one:
- "confluence" for general knowledge or documentation queries
- "databricks_metadata" for Databricks table schema or metadata questions
- "sql_generation" for questions needing an SQL query
If the question is unclear, set clarification_needed to true; otherwise, false.
Return your response in this JSON format:
{format_instructions}
""",
        partial_variables={"format_instructions": question_analyzer_parser.get_format_instructions()}
    )
    prompt = prompt_template.format(question=state['question'])
    response = llm.invoke([HumanMessage(content=prompt)])
    result = question_analyzer_parser.parse(response.content)
    state['intent'] = result['intent']
    state['clarification_needed'] = result['clarification_needed']
    state['confluence_docs'] = []
    state['databricks_metadata'] = []
    state['attempt_count'] = 0
    logger.debug(f"Intent: {state['intent']}, Clarification needed: {state['clarification_needed']}")
    return state

def query_rewriter(state: State) -> State:
    logger.info("Entering query_rewriter stage")
    prompt_template = PromptTemplate(
        input_variables=["question"],
        template="""
You are a query optimization expert. Rewrite this question to make it more precise for document retrieval: "{question}".
Return the rewritten query in this JSON format:
{format_instructions}
""",
        partial_variables={"format_instructions": query_rewriter_parser.get_format_instructions()}
    )
    prompt = prompt_template.format(question=state['question'])
    response = llm.invoke([HumanMessage(content=prompt)])
    result = query_rewriter_parser.parse(response.content)
    state['rewritten_query'] = result['rewritten_query']
    logger.debug(f"Rewritten query: {state['rewritten_query']}")
    return state

def confluence_retrieval(state: State) -> State:
    logger.info("Entering confluence_retrieval stage")
    query = state.get('rewritten_query', state['question'])
    docs = confluence_store.similarity_search(query, k=5)
    state['confluence_docs'] = [doc.page_content for doc in docs]
    logger.debug(f"Retrieved {len(state['confluence_docs'])} Confluence docs")
    return state


def databricks_retrieval(state: State) -> State:
    logger.info("Entering databricks_retrieval stage")
    query = state.get('rewritten_query', state['question'])
    docs = databricks_store.similarity_search(query, k=5)
    state['databricks_metadata'] = [doc.page_content for doc in docs]
    logger.debug(f"Retrieved {len(state['databricks_metadata'])} Databricks metadata entries")
    return state

def result_reranker(state: State) -> State:
    logger.info("Entering result_reranker stage")
    if state['confluence_docs']:
        docs_str = "\n".join([f"Doc {i}: {doc}" for i, doc in enumerate(state['confluence_docs'])])
        prompt_template = PromptTemplate(
            input_variables=["question", "docs"],
            template="""
You are a document relevance analyst. Given the question "{question}" and these documents:
{docs}
Identify the top 3 most relevant documents by their indices (0-based).
Return the indices in this JSON format:
{format_instructions}
""",
            partial_variables={"format_instructions": reranker_parser.get_format_instructions()}
        )
        prompt = prompt_template.format(question=state['question'], docs=docs_str)
        response = llm.invoke([HumanMessage(content=prompt)])
        result = reranker_parser.parse(response.content)
        indices = result['indices']
        state['confluence_docs'] = [state['confluence_docs'][i] for i in indices if i < len(state['confluence_docs'])]
        logger.debug(f"Reranked Confluence docs: {state['confluence_docs']}")
    if state['databricks_metadata']:
        docs_str = "\n".join([f"Doc {i}: {doc}" for i, doc in enumerate(state['databricks_metadata'])])
        prompt_template = PromptTemplate(
            input_variables=["question", "docs"],
            template="""
You are a metadata relevance analyst. Given the question "{question}" and these metadata entries:
{docs}
Identify the top 3 most relevant entries by their indices (0-based).
Return the indices in this JSON format:
{format_instructions}
""",
            partial_variables={"format_instructions": reranker_parser.get_format_instructions()}
        )
        prompt = prompt_template.format(question=state['question'], docs=docs_str)
        response = llm.invoke([HumanMessage(content=prompt)])
        result = reranker_parser.parse(response.content)
        indices = result['indices']
        state['databricks_metadata'] = [state['databricks_metadata'][i] for i in indices if i < len(state['databricks_metadata'])]
        logger.debug(f"Reranked Databricks metadata: {state['databricks_metadata']}")
    return state

def sql_generation(state: State) -> State:
    logger.info("Entering sql_generation stage")
    if not state['databricks_metadata']:
        state = databricks_retrieval(state)
        state = result_reranker(state)
    metadata = "\n".join(state['databricks_metadata'])
    prompt_template = PromptTemplate(
        input_variables=["metadata", "question"],
        template="""
You are an expert SQL generator for Databricks. Using this metadata:
{metadata}
Write an SQL query to answer the question: "{question}".
Return the query in this JSON format:
{format_instructions}
""",
        partial_variables={"format_instructions": sql_generation_parser.get_format_instructions()}
    )
    prompt = prompt_template.format(metadata=metadata, question=state['question'])
    response = llm.invoke([HumanMessage(content=prompt)])
    result = sql_generation_parser.parse(response.content)
    state['sql_query'] = result['sql_query']
    logger.debug(f"Generated SQL query: {state['sql_query']}")
    return state

def sql_execution(state: State) -> State:
    logger.info("Entering sql_execution stage")
    sql_query = state['sql_query']
    max_attempts = 3
    attempt = state['attempt_count']
    if attempt >= max_attempts:
        state['final_answer'] = "Failed to execute SQL query after 3 attempts."
        logger.info("SQL execution abandoned after max retries")
        return state
    try:
        with databricks_connection.cursor() as cursor:
            logger.info(f"Executing SQL (Attempt {attempt + 1}): {sql_query}")
            cursor.execute(sql_query)
            result = cursor.fetchall()
            state['sql_result'] = str(result)
            logger.debug(f"SQL result: {state['sql_result']}")
            return state
    except Exception as e:
        logger.error(f"SQL execution failed: {e}")
        state['attempt_count'] += 1
        logger.info(f"Retrying SQL generation (attempt {state['attempt_count']})")
        # Return the state and let route_after_sql_execution send the flow back
        # to sql_generation; calling sql_generation here as well would
        # regenerate the query twice per failure.
        return state

def answer_generation(state: State) -> State:
    logger.info("Entering answer_generation stage")
    context = ""
    if state['confluence_docs']:
        context += "Confluence Documents:\n" + "\n".join(state['confluence_docs']) + "\n"
    if state['databricks_metadata']:
        context += "Databricks Metadata:\n" + "\n".join(state['databricks_metadata']) + "\n"
    if state.get('sql_query'):
        context += "Generated SQL Query:\n" + state['sql_query'] + "\n"
    if state.get('sql_result'):
        context += "SQL Result:\n" + state['sql_result'] + "\n"
    if not context.strip():
        state['answer'] = "I'm sorry, I couldn't find sufficient information to answer your question."
    else:
        prompt = f"""
You are an expert answer generator. Using this context:
{context}
Provide a clear and concise answer to the question: "{state['question']}".
If an SQL result is included, explain it in text and, if suitable, format it as a table.
"""
        response = llm.invoke([HumanMessage(content=prompt)])
        state['answer'] = response.content.strip()
    state['final_answer'] = state['answer']
    logger.debug(f"Generated answer: {state['answer']}")
    return state

def guardrails(state: State) -> State:
    logger.info("Entering guardrails stage")
    answer = state['answer'].lower()
    prohibited_words = ["harmful", "offensive", "confidential"]
    if any(word in answer for word in prohibited_words):
        state['final_answer'] = "I'm sorry, I cannot provide that information due to content restrictions."
    elif len(state['answer']) > 1000:
        state['final_answer'] = "The response is too long to display fully. Please refine your question."
    logger.debug(f"Final answer after guardrails: {state['final_answer']}")
    return state

def clarification_node(state: State) -> State:
    logger.info("Entering clarification_node stage")
    prompt_template = PromptTemplate(
        input_variables=["question"],
        template="""
You are a clarification expert. The question "{question}" is unclear.
Ask a concise follow-up question to clarify the user's intent.
Return the follow-up question in this JSON format:
{format_instructions}
""",
        partial_variables={"format_instructions": clarification_parser.get_format_instructions()}
    )
    prompt = prompt_template.format(question=state['question'])
    response = llm.invoke([HumanMessage(content=prompt)])
    result = clarification_parser.parse(response.content)
    state['final_answer'] = result['clarification_question']
    logger.debug(f"Clarification question: {state['final_answer']}")
    return state

# Define Graph
graph = StateGraph(State)
memory = MemorySaver()

# Add nodes
graph.add_node("question_analyzer", question_analyzer)
graph.add_node("query_rewriter", query_rewriter)
graph.add_node("confluence_retrieval", confluence_retrieval)
graph.add_node("databricks_retrieval", databricks_retrieval)
graph.add_node("result_reranker", result_reranker)
graph.add_node("sql_generation", sql_generation)
graph.add_node("sql_execution", sql_execution)
graph.add_node("answer_generation", answer_generation)
graph.add_node("guardrails", guardrails)
graph.add_node("clarification_node", clarification_node)


# Define routing functions
def route_after_analyzer(state: State) -> str:
    return "clarification_node" if state['clarification_needed'] else "query_rewriter"


def route_after_rewriter(state: State) -> str:
    intent = state['intent']
    if intent == "confluence":
        return "confluence_retrieval"
    elif intent == "databricks_metadata":
        return "databricks_retrieval"
    elif intent == "sql_generation":
        return "sql_generation"
    return "clarification_node"


def route_after_sql_execution(state: State) -> str:
    if state['attempt_count'] < 3 and not state.get('sql_result'):
        return "sql_generation"
    return "answer_generation"


# Define edges with START and END
graph.add_edge(START, "question_analyzer")
graph.add_conditional_edges("question_analyzer", route_after_analyzer, {"clarification_node": "clarification_node", "query_rewriter": "query_rewriter"})
graph.add_conditional_edges("query_rewriter", route_after_rewriter, {
    "confluence_retrieval": "confluence_retrieval",
    "databricks_retrieval": "databricks_retrieval",
    "sql_generation": "sql_generation",
    "clarification_node": "clarification_node"
})
graph.add_edge("confluence_retrieval", "result_reranker")
graph.add_edge("databricks_retrieval", "result_reranker")
graph.add_edge("sql_generation", "sql_execution")
graph.add_conditional_edges("sql_execution", route_after_sql_execution, {"sql_generation": "sql_generation", "answer_generation": "answer_generation"})
graph.add_edge("result_reranker", "answer_generation")
graph.add_edge("answer_generation", "guardrails")
graph.add_edge("guardrails", END)
graph.add_edge("clarification_node", END)

# Compile the graph with memory
app = graph.compile(checkpointer=memory)
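
# Optional sanity check (assumes the optional grandalf dependency is
# installed): print an ASCII rendering of the compiled graph to verify
# the wiring above before running queries.
#
#   app.get_graph().print_ascii()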

# Interface for user questions with thread_id
def ask_question(question: str, thread_id: str) -> str:
    logger.info(f"Processing question '{question}' with thread_id '{thread_id}'")
    config = {"configurable": {"thread_id": thread_id}}
    initial_state = {
        "question": question,
        "rewritten_query": None,
        "intent": None,
        "confluence_docs": [],
        "databricks_metadata": [],
        "sql_query": None,
        "sql_result": None,
        "answer": None,
        "clarification_needed": False,
        "final_answer": None,
        "attempt_count": 0
    }
    result = app.invoke(initial_state, config=config)
    logger.info(f"Returning answer: {result['final_answer']}")
    return result['final_answer']


# Example usage
if __name__ == "__main__":
    # New conversation
    answer1 = ask_question("What is the total sales in the last quarter?", "thread_1")
    print("Answer 1:", answer1)
    # Follow-up in same thread
    answer2 = ask_question("Break it down by region", "thread_1")
    print("Answer 2:", answer2)
    # New conversation with different thread_id
    answer3 = ask_question("How do I configure Databricks?", "thread_2")
    print("Answer 3:", answer3)


# ======================== next gist file ========================

import os
import json
import operator
import logging
from typing import TypedDict, Annotated, List

import boto3
import sqlparse
from pydantic import BaseModel, Field
from langchain_aws import ChatBedrock
from langchain_community.vectorstores import PGVector
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from langchain.output_parsers import PydanticOutputParser
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.base import BaseCheckpointSaver

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


# --- DynamoDB State Persistence ---
class DynamoDBSaver(BaseCheckpointSaver):
    def __init__(self, table_name: str = "conversation_state"):
        self.dynamodb = boto3.resource('dynamodb')
        self.table = self.dynamodb.Table(table_name)

    def get(self, config: dict) -> dict:
        thread_id = config["configurable"]["thread_id"]
        response = self.table.get_item(Key={"thread_id": thread_id})
        return json.loads(response["Item"]["state"]) if "Item" in response else {}

    def put(self, config: dict, checkpoint: dict) -> None:
        thread_id = config["configurable"]["thread_id"]
        self.table.put_item(Item={"thread_id": thread_id, "state": json.dumps(checkpoint)})

# --- Pydantic Models for Output Parsing ---
class IntentOutput(BaseModel):
    intent: str = Field(description="The classified intent: 'confluence', 'databricks', 'both', or 'ambiguous'")


class SQLOutput(BaseModel):
    sql: str = Field(description="The generated or refined SQL query")


class RewrittenQueryOutput(BaseModel):
    rewritten_query: str = Field(description="The rewritten query for clarity and specificity")


class SafetyOutput(BaseModel):
    is_safe: str = Field(description="Whether the answer is safe and appropriate: 'yes' or 'no'")


class AnswerOutput(BaseModel):
    answer: str = Field(description="The final answer to the user's question")


class ValidationOutput(BaseModel):
    is_valid: str = Field(description="Whether the SQL results make sense for the question: 'yes' or 'no'")
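
# Illustrative: PydanticOutputParser (used throughout below) derives JSON
# format instructions from these models and validates the model's reply, e.g.:
#
#   parser = PydanticOutputParser(pydantic_object=IntentOutput)
#   parsed = parser.parse('{"intent": "databricks"}')
#   assert parsed.intent == "databricks"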

# --- Agent State Definition ---
class AgentState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    intent: str
    confluence_context: List[str]
    databricks_context: List[str]
    generated_sql: str
    sql_attempts: int
    sql_error: str
    needs_clarification: bool
    final_answer: str
    rewritten_query: str


# --- Configurations ---
os.environ["AWS_REGION"] = "us-east-1"
CONFLUENCE_CONNECTION_STRING = os.getenv("CONFLUENCE_DB_URL")
DATABRICKS_CONNECTION_STRING = os.getenv("DATABRICKS_DB_URL")

# Initialize Bedrock model
llm = ChatBedrock(
    model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
    region_name="us-east-1",
    model_kwargs={"temperature": 0.7}
)

# Initialize vector stores
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
confluence_store = PGVector(
    collection_name="confluence_docs",
    connection_string=CONFLUENCE_CONNECTION_STRING,
    embedding_function=embeddings
)
databricks_store = PGVector(
    collection_name="databricks_metadata",
    connection_string=DATABRICKS_CONNECTION_STRING,
    embedding_function=embeddings
)

# --- Tools and Utilities ---
@tool
def generate_sql(user_query: str = None, previous_sql: str = None, error: str = None, metadata: List[str] = None) -> str:
    """Generate or refine a SQL query based on input."""  # @tool requires a docstring for the tool description
    logger.debug("Generating SQL with user_query=%s, previous_sql=%s, error=%s, metadata=%s", user_query, previous_sql, error, metadata)
    parser = PydanticOutputParser(pydantic_object=SQLOutput)
    format_instructions = parser.get_format_instructions()
    if previous_sql and error:
        prompt = PromptTemplate(
            input_variables=["previous_sql", "error", "metadata"],
            partial_variables={"format_instructions": format_instructions},
            template="Refine this SQL query '{previous_sql}' that caused error '{error}' using metadata: {metadata}\n\n{format_instructions}"
        )
        chain = prompt | llm | parser
        response = chain.invoke({"previous_sql": previous_sql, "error": error, "metadata": metadata})
    elif user_query:
        prompt = PromptTemplate(
            input_variables=["query", "metadata"],
            partial_variables={"format_instructions": format_instructions},
            template="Generate a SQL query for '{query}' using metadata: {metadata}\n\n{format_instructions}"
        )
        chain = prompt | llm | parser
        response = chain.invoke({"query": user_query, "metadata": metadata})
    else:
        raise ValueError("Must provide either user_query or previous_sql and error")
    sql = response.sql
    logger.debug("Generated SQL: %s", sql)
    return sql

def execute_databricks_sql(sql: str) -> dict:
    logger.debug("Executing SQL: %s", sql)
    if "error" in sql.lower():  # Simulated error condition
        return {"status": "error", "error": "Simulated SQL error"}
    return {"status": "success", "results": ["Simulated results"]}


def rewrite_query(query: str) -> str:
    logger.debug("Rewriting query: %s", query)
    parser = PydanticOutputParser(pydantic_object=RewrittenQueryOutput)
    format_instructions = parser.get_format_instructions()
    prompt = PromptTemplate(
        input_variables=["query"],
        partial_variables={"format_instructions": format_instructions},
        template="Rewrite this query to make it clearer and more specific: '{query}'\n\n{format_instructions}"
    )
    chain = prompt | llm | parser
    response = chain.invoke({"query": query})
    rewritten = response.rewritten_query
    logger.debug("Rewritten query: %s", rewritten)
    return rewritten

# --- Guardrails ---
def check_answer_safety(answer: str) -> bool:
    logger.debug("Checking answer safety: %s", answer)
    parser = PydanticOutputParser(pydantic_object=SafetyOutput)
    format_instructions = parser.get_format_instructions()
    prompt = PromptTemplate(
        input_variables=["answer"],
        partial_variables={"format_instructions": format_instructions},
        template="Is this answer safe and appropriate? Answer 'yes' or 'no': {answer}\n\n{format_instructions}"
    )
    chain = prompt | llm | parser
    response = chain.invoke({"answer": answer})
    is_safe = response.is_safe.lower() == "yes"
    logger.debug("Answer safety: %s", is_safe)
    return is_safe


def is_safe_sql(sql: str) -> bool:
    logger.debug("Checking SQL safety: %s", sql)
    sql_lower = sql.lower().strip()
    is_safe = sql_lower.startswith("select") and all(
        keyword not in sql_lower for keyword in ["insert", "update", "delete", "drop"]
    )
    logger.debug("SQL is safe: %s", is_safe)
    return is_safe


def is_valid_sql(sql: str) -> bool:
    try:
        parsed = sqlparse.parse(sql)
        return len(parsed) > 0 and all(stmt.get_type() != 'UNKNOWN' for stmt in parsed)
    except Exception:
        return False

def validate_results(question: str, results: List[str]) -> bool:
    logger.debug("Validating results for question: %s, results: %s", question, results)
    results_str = "\n".join(results)
    parser = PydanticOutputParser(pydantic_object=ValidationOutput)
    format_instructions = parser.get_format_instructions()
    prompt = PromptTemplate(
        input_variables=["question", "results"],
        partial_variables={"format_instructions": format_instructions},
        template="""Does the following SQL result answer the question?
Question: {question}
SQL Result:
{results}
Respond with 'yes' or 'no'.\n\n{format_instructions}"""
    )
    chain = prompt | llm | parser
    response = chain.invoke({"question": question, "results": results_str})
    is_valid = response.is_valid.lower() == "yes"
    logger.debug("Results validation: %s", is_valid)
    return is_valid

# --- Workflow Nodes ---
def parse_question(state: AgentState) -> AgentState:
    logger.info("Entering parse_question")
    question = state["messages"][-1].content
    logger.debug("Question: %s", question)
    parser = PydanticOutputParser(pydantic_object=IntentOutput)
    format_instructions = parser.get_format_instructions()
    prompt = PromptTemplate(
        input_variables=["question"],
        partial_variables={"format_instructions": format_instructions},
        template="""Classify this question into one of: 'confluence', 'databricks', 'both', or 'ambiguous'.
Question: {question}
{format_instructions}"""
    )
    chain = prompt | llm | parser
    response = chain.invoke({"question": question})
    state["intent"] = response.intent
    logger.debug("Intent: %s", state["intent"])
    logger.info("Exiting parse_question")
    return state


def rewrite_query_node(state: AgentState) -> AgentState:
    logger.info("Entering rewrite_query_node")
    question = state["messages"][-1].content
    state["rewritten_query"] = rewrite_query(question)
    logger.debug("Rewritten query: %s", state["rewritten_query"])
    logger.info("Exiting rewrite_query_node")
    return state

def route_context(state: AgentState) -> AgentState:
    logger.info("Entering route_context")
    query = state["rewritten_query"]
    if state["intent"] in ["confluence", "both"]:
        results = confluence_store.similarity_search(query, k=3)
        state["confluence_context"] = [doc.page_content for doc in results]
        logger.debug("Confluence context: %s", state["confluence_context"])
    if state["intent"] in ["databricks", "both"]:
        results = databricks_store.similarity_search(query, k=3)
        state["databricks_context"] = [doc.page_content for doc in results]
        logger.debug("Databricks context: %s", state["databricks_context"])
    if state["intent"] == "ambiguous":
        state["needs_clarification"] = True
    logger.info("Exiting route_context")
    return state


def clarify_question(state: AgentState) -> AgentState:
    logger.info("Entering clarify_question")
    # Use .get(): route_context only sets this flag for ambiguous intents.
    if state.get("needs_clarification"):
        state["messages"].append(AIMessage(content="Please clarify your query."))
        state["needs_clarification"] = False
    logger.info("Exiting clarify_question")
    return state

def generate_sql_node(state: AgentState) -> AgentState:
    logger.info("Entering generate_sql_node")
    if state["intent"] in ["databricks", "both"] and state.get("databricks_context"):
        metadata = state["databricks_context"]
        # generate_sql is decorated with @tool, so invoke it with a dict rather than calling it directly.
        state["generated_sql"] = generate_sql.invoke({"user_query": state["rewritten_query"], "metadata": metadata})
        state["sql_attempts"] = 1
    logger.info("Exiting generate_sql_node")
    return state


def execute_sql_node(state: AgentState) -> AgentState:
    logger.info("Entering execute_sql_node")
    if state.get("generated_sql"):
        if not is_safe_sql(state["generated_sql"]):
            state["databricks_context"] = ["Unsafe SQL query detected."]
            state["sql_error"] = None
        elif not is_valid_sql(state["generated_sql"]):
            state["sql_error"] = "Invalid SQL syntax"
        else:
            result = execute_databricks_sql(state["generated_sql"])
            if result["status"] == "success":
                if validate_results(state["rewritten_query"], result["results"]):
                    state["databricks_context"] = result["results"]
                    state["sql_error"] = None
                else:
                    state["sql_error"] = "Results do not make sense for the question"
            else:
                state["sql_error"] = result["error"]
        if state["sql_error"] and state.get("sql_attempts", 0) < 3:
            state["generated_sql"] = generate_sql.invoke({
                "previous_sql": state["generated_sql"],
                "error": state["sql_error"],
                "metadata": state["databricks_context"]
            })
            state["sql_attempts"] = state.get("sql_attempts", 0) + 1
            state = execute_sql_node(state)  # Recursive retry
        elif state["sql_error"]:
            state["databricks_context"] = ["Unable to retrieve data due to persistent errors."]
            state["sql_error"] = None
    logger.info("Exiting execute_sql_node")
    return state
| def generate_answer(state: AgentState) -> AgentState: | |
| logger.info("Entering generate_answer") | |
| history = "\n".join([f"{msg.type}: {msg.content}" for msg in state["messages"][-5:]]) | |
| confluence_ctx = "\n".join(state.get("confluence_context", [])) | |
| databricks_ctx = "\n".join(state.get("databricks_context", [])) | |
| parser = PydanticOutputParser(pydantic_object=AnswerOutput) | |
| format_instructions = parser.get_format_instructions() | |
| prompt = PromptTemplate( | |
| input_variables=["history", "query", "confluence_ctx", "databricks_ctx"], | |
| partial_variables={"format_instructions": format_instructions}, | |
| template="""Conversation history: | |
| {history} | |
| Current question: {query} | |
| Using Confluence context: {confluence_ctx} | |
| And Databricks context: {databricks_ctx} | |
| Provide a concise, accurate response.\n\n{format_instructions}""" | |
| ) | |
| chain = prompt | llm | parser | |
| response = chain.invoke({ | |
| "history": history, | |
| "query": state["rewritten_query"], | |
| "confluence_ctx": confluence_ctx, | |
| "databricks_ctx": databricks_ctx | |
| }) | |
| answer = response.answer | |
| if check_answer_safety(answer): | |
| state["final_answer"] = answer | |
| else: | |
| state["final_answer"] = "I'm sorry, but I can't provide that information." | |
| state["messages"].append(AIMessage(content=state["final_answer"])) | |
| logger.debug("Final answer: %s", state["final_answer"]) | |
| logger.info("Exiting generate_answer") | |
| return state | |
| # --- Workflow Definition --- | |
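| # Flow: parse_question -> rewrite_query -> route_context -> clarify_question, | |
| # then either the SQL branch (generate_sql -> execute_sql) or directly to | |
| # generate_answer, depending on the classified intent. | |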
| workflow = StateGraph(AgentState) | |
| workflow.add_node("parse_question", parse_question) | |
| workflow.add_node("rewrite_query", rewrite_query_node) | |
| workflow.add_node("route_context", route_context) | |
| workflow.add_node("clarify_question", clarify_question) | |
| workflow.add_node("generate_sql", generate_sql_node) | |
| workflow.add_node("execute_sql", execute_sql_node) | |
| workflow.add_node("generate_answer", generate_answer) | |
| workflow.set_entry_point("parse_question") | |
| workflow.add_edge("parse_question", "rewrite_query") | |
| workflow.add_edge("rewrite_query", "route_context") | |
| workflow.add_edge("route_context", "clarify_question") | |
| workflow.add_conditional_edges( | |
| "clarify_question", | |
| # Ambiguous questions end here so the clarification request is the reply; | |
| # otherwise route to the SQL branch or straight to answer generation. | |
| lambda state: END if state["intent"] == "ambiguous" else ("generate_sql" if state["intent"] in ["databricks", "both"] else "generate_answer"), | |
| {END: END, "generate_sql": "generate_sql", "generate_answer": "generate_answer"} | |
| ) | |
| workflow.add_edge("generate_sql", "execute_sql") | |
| workflow.add_edge("execute_sql", "generate_answer") | |
| workflow.add_edge("generate_answer", END) | |
| # Compile the workflow with state persistence | |
| checkpointer = DynamoDBSaver(table_name="conversation_state") | |
| graph = workflow.compile(checkpointer=checkpointer) | |
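| # Assuming DynamoDBSaver implements LangGraph's checkpointer interface, state | |
| # is persisted per thread_id, so conversations survive process restarts. | |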
| # --- Main Function to Run the Agent --- | |
| def run_agent(question: str, thread_id: str = "thread_1") -> str: | |
| logger.info("Running agent for question: %s, thread_id: %s", question, thread_id) | |
| initial_state = {"messages": [HumanMessage(content=question)]} | |
| result = graph.invoke(initial_state, config={"configurable": {"thread_id": thread_id}}) | |
| response = result["messages"][-1].content | |
| logger.info("Agent response: %s", response) | |
| return response | |
| # --- Example Usage --- | |
| if __name__ == "__main__": | |
| response = run_agent("What is the process for onboarding in Confluence?") | |
| print(response) |
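| # Follow-up turns reuse the same thread_id so the checkpointer restores the | |
| # message history (illustrative call, same API as above): | |
| # follow_up = run_agent("Can you give more detail?", thread_id="thread_1") | |
| # print(follow_up) | |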
| # ------------------------------------------------------------------ | |
| # Alternative single-file implementation of the same agent | |
| # ------------------------------------------------------------------ | |
| import logging | |
| import json | |
| from typing import TypedDict, List, Optional | |
| from langchain_core.messages import HumanMessage, AIMessage, BaseMessage | |
| from langchain_core.documents import Document | |
| from langchain_aws import ChatBedrock, BedrockEmbeddings | |
| from langchain_community.vectorstores import PGVector | |
| from langgraph.graph import StateGraph, START, END | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from databricks import sql | |
| from langchain.output_parsers import PydanticOutputParser | |
| from langchain.prompts import PromptTemplate | |
| from pydantic import BaseModel, Field | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Initialize Bedrock Sonnet model | |
| llm = ChatBedrock( | |
| model_id="anthropic.claude-3-sonnet-20240229-v1:0", | |
| region_name="us-east-1" # Replace with your AWS region | |
| ) | |
| # Initialize embedding function | |
| embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1") | |
| # Initialize PGVector stores (replace connection strings) | |
| confluence_store = PGVector( | |
| collection_name="confluence", | |
| connection_string="postgresql+psycopg2://user:password@localhost:5432/dbname", | |
| embedding_function=embeddings | |
| ) | |
| databricks_store = PGVector( | |
| collection_name="databricks_system_metadata", | |
| connection_string="postgresql+psycopg2://user:password@localhost:5432/dbname", | |
| embedding_function=embeddings | |
| ) | |
| # Databricks SQL connection (replace with your credentials) | |
| databricks_connection = sql.connect( | |
| server_hostname="your_databricks_host", | |
| http_path="your_http_path", | |
| access_token="your_access_token" | |
| ) | |
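| # In practice the placeholders above would come from the environment, e.g. | |
| # (variable names illustrative): | |
| # import os | |
| # databricks_connection = sql.connect( | |
| #     server_hostname=os.environ["DATABRICKS_SERVER_HOSTNAME"], | |
| #     http_path=os.environ["DATABRICKS_HTTP_PATH"], | |
| #     access_token=os.environ["DATABRICKS_TOKEN"], | |
| # ) | |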
| # Define State schema. Fields without a LangGraph reducer are last-write-wins, | |
| # which matches the mutate-and-return style of the nodes below; missing keys | |
| # are filled in with .get() defaults where the nodes read them. | |
| class State(TypedDict, total=False): | |
| messages: List[BaseMessage] | |
| question: str | |
| rewritten_query: Optional[str] | |
| intent: Optional[str] | |
| confluence_docs: List[Document] | |
| databricks_metadata: List[str] | |
| sql_query: Optional[str] | |
| sql_result: Optional[str] | |
| answer: Optional[str] | |
| clarification_needed: bool | |
| final_answer: Optional[str] | |
| attempt_count: int | |
| sql_valid: bool | |
| history_answer: Optional[str] | |
| # Define Pydantic Models for Structured Output | |
| class IntentResponse(BaseModel): | |
| intent: str = Field(description="The classified intent: 'confluence' or 'databricks'") | |
| clarification_needed: bool = Field(description="Boolean indicating if clarification is needed") | |
| class ClarificationResponse(BaseModel): | |
| clarification_question: str = Field(description="A follow-up question to clarify intent") | |
| # Initialize Pydantic Output Parsers | |
| question_analyzer_parser = PydanticOutputParser(pydantic_object=IntentResponse) | |
| clarification_parser = PydanticOutputParser(pydantic_object=ClarificationResponse) | |
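| # PydanticOutputParser injects a JSON schema into the prompt and validates | |
| # the model's reply against it, so downstream code gets typed fields. | |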
| # Define Nodes | |
| def start_conversation(state: State) -> State: | |
| logger.info("Entering start_conversation stage") | |
| human_message = HumanMessage(content=state['question']) | |
| state['messages'] = state.get('messages', []) + [human_message] # First turn starts with an empty history | |
| logger.debug(f"Appended human message: {human_message.content}") | |
| return state | |
| def check_history(state: State) -> State: | |
| logger.info("Entering check_history stage") | |
| if len(state['messages']) > 1: # Check prior history | |
| history = "\n".join([f"{msg.type}: {msg.content}" for msg in state['messages'][:-1]]) | |
| prompt = f""" | |
| You are a history analysis expert. Given this conversation history: | |
| {history} | |
| Can the question "{state['question']}" be answered based on this history? If yes, provide the answer as plain text. If no, return "no". | |
| """ | |
| response = llm.invoke([HumanMessage(content=prompt)]) | |
| answer = response.content.strip() | |
| if answer.lower() != "no": # Case-insensitive check on the model's verdict | |
| state['history_answer'] = answer | |
| state['final_answer'] = answer | |
| logger.debug(f"Answer found in history: {state['history_answer']}") | |
| else: | |
| state['history_answer'] = None | |
| logger.debug("No answer found in history") | |
| return state | |
| def question_analyzer(state: State) -> State: | |
| logger.info("Entering question_analyzer stage") | |
| if state.get('history_answer'): | |
| return state | |
| prompt_template = PromptTemplate( | |
| input_variables=["question"], | |
| template=""" | |
| You are an expert intent classifier. Given this question: "{question}", determine its intent. Choose one: | |
| - "confluence" for general knowledge or documentation queries | |
| - "databricks" for questions requiring Databricks system table data | |
| If the question is unclear, set clarification_needed to true; otherwise, false. | |
| Return your response in this JSON format: | |
| {format_instructions} | |
| """, | |
| partial_variables={"format_instructions": question_analyzer_parser.get_format_instructions()} | |
| ) | |
| prompt = prompt_template.format(question=state['question']) | |
| response = llm.invoke([HumanMessage(content=prompt)]) | |
| result = question_analyzer_parser.parse(response.content) | |
| state['intent'] = result.intent | |
| state['clarification_needed'] = result.clarification_needed | |
| logger.debug(f"Intent: {state['intent']}, Clarification needed: {state['clarification_needed']}") | |
| return state | |
| def query_rewriter(state: State) -> State: | |
| logger.info("Entering query_rewriter stage") | |
| if state.get('history_answer'): | |
| return state | |
| prompt = f""" | |
| You are a query optimization expert. Rewrite this question to make it more precise for retrieval: "{state['question']}". | |
| Provide the rewritten query as plain text. | |
| """ | |
| response = llm.invoke([HumanMessage(content=prompt)]) | |
| state['rewritten_query'] = response.content.strip() | |
| logger.debug(f"Rewritten query: {state['rewritten_query']}") | |
| return state | |
| def confluence_retrieval(state: State) -> State: | |
| logger.info("Entering confluence_retrieval stage") | |
| query = state.get('rewritten_query', state['question']) | |
| docs = confluence_store.similarity_search(query, k=5) | |
| state['confluence_docs'] = docs # similarity_search returns Document objects; kept whole for reranking | |
| logger.debug(f"Retrieved {len(state['confluence_docs'])} Confluence docs") | |
| return state | |
| def databricks_retrieval(state: State) -> State: | |
| logger.info("Entering databricks_retrieval stage") | |
| query = state.get('rewritten_query', state['question']) | |
| docs = databricks_store.similarity_search(query, k=5) | |
| state['databricks_metadata'] = [doc.page_content for doc in docs] | |
| logger.debug(f"Retrieved {len(state['databricks_metadata'])} Databricks metadata entries") | |
| return state | |
| def result_reranker(state: State) -> State: | |
| logger.info("Entering result_reranker stage") | |
| if state['intent'] == "confluence" and state['confluence_docs']: | |
| docs_str = "\n".join([f"Doc {i}: {doc.page_content}" for i, doc in enumerate(state['confluence_docs'])]) | |
| prompt = f""" | |
| You are a document relevance analyst. Given the question "{state['question']}" and these documents: | |
| {docs_str} | |
| Return the indices (0-based) of the top 3 most relevant documents as a JSON list, e.g., [0, 1, 2]. | |
| """ | |
| response = llm.invoke([HumanMessage(content=prompt)]) | |
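| # Assumes the model returns a bare JSON list like [0, 1, 2]; a chatty or | |
| # malformed reply would raise here, hence the strict format in the prompt. | |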
| indices = json.loads(response.content.strip()) | |
| state['confluence_docs'] = [state['confluence_docs'][i] for i in indices if i < len(state['confluence_docs'])] | |
| logger.debug(f"Reranked Confluence docs: {[doc.page_content for doc in state['confluence_docs']]}") | |
| elif state['intent'] == "databricks" and state['databricks_metadata']: | |
| docs_str = "\n".join([f"Doc {i}: {doc}" for i, doc in enumerate(state['databricks_metadata'])]) | |
| prompt = f""" | |
| You are a metadata relevance analyst. Given the question "{state['question']}" and these Databricks system table metadata entries: | |
| {docs_str} | |
| Return the indices (0-based) of the top 3 most relevant entries as a JSON list, e.g., [0, 1, 2]. | |
| """ | |
| response = llm.invoke([HumanMessage(content=prompt)]) | |
| indices = json.loads(response.content.strip()) | |
| state['databricks_metadata'] = [state['databricks_metadata'][i] for i in indices if i < len(state['databricks_metadata'])] | |
| logger.debug(f"Reranked Databricks metadata: {state['databricks_metadata']}") | |
| return state | |
| def sql_generation(state: State) -> State: | |
| logger.info("Entering sql_generation stage") | |
| if not state['databricks_metadata']: | |
| state = databricks_retrieval(state) | |
| state = result_reranker(state) | |
| metadata = "\n".join(state['databricks_metadata']) | |
| prompt = f""" | |
| You are an expert SQL generator for Databricks. Using this system table metadata: | |
| {metadata} | |
| Write an SQL query to fetch data from Databricks that answers the question: "{state['question']}". | |
| Provide the SQL query as plain text. | |
| """ | |
| response = llm.invoke([HumanMessage(content=prompt)]) | |
| state['sql_query'] = response.content.strip() | |
| logger.debug(f"Generated SQL query: {state['sql_query']}") | |
| return state | |
| def sql_validation(state: State) -> State: | |
| logger.info("Entering sql_validation stage") | |
| metadata = "\n".join(state['databricks_metadata']) | |
| prompt = f""" | |
| You are an SQL validation expert. Given this Databricks system table metadata: | |
| {metadata} | |
| Check if this SQL query is valid and likely to answer the question "{state['question']}": | |
| {state['sql_query']} | |
| Return "valid" if the query is correct, or "invalid" if it’s not. | |
| """ | |
| response = llm.invoke([HumanMessage(content=prompt)]) | |
| state['sql_valid'] = response.content.strip().lower() == "valid" # Any other reply counts as invalid | |
| logger.debug(f"SQL validation result: {state['sql_valid']}") | |
| return state | |
| def sql_execution(state: State) -> State: | |
| logger.info("Entering sql_execution stage") | |
| sql_query = state['sql_query'] | |
| max_attempts = 3 | |
| attempt = state.get('attempt_count', 0) | |
| if attempt >= max_attempts: | |
| state['final_answer'] = "Failed to generate a valid SQL query or retrieve meaningful results from Databricks after 3 attempts." | |
| logger.info("SQL execution abandoned after max retries") | |
| return state | |
| if not state.get('sql_valid', False): | |
| logger.info(f"SQL invalid, retrying generation (attempt {attempt + 1})") | |
| state['attempt_count'] = attempt + 1 | |
| # Regenerate, re-validate, and re-execute inside this node so the retry | |
| # loop finishes before control returns to the graph. | |
| return sql_execution(sql_validation(sql_generation(state))) | |
| try: | |
| with databricks_connection.cursor() as cursor: | |
| logger.info(f"Executing SQL in Databricks (Attempt {attempt + 1}): {sql_query}") | |
| cursor.execute(sql_query) | |
| result = cursor.fetchall() | |
| state['sql_result'] = str(result) | |
| logger.debug(f"SQL result: {state['sql_result']}") | |
| return state | |
| except Exception as e: | |
| logger.error(f"SQL execution failed: {e}") | |
| state['attempt_count'] = attempt + 1 | |
| logger.info(f"Retrying SQL generation due to execution error (attempt {attempt + 1})") | |
| return sql_execution(sql_validation(sql_generation(state))) | |
| def sql_refinement(state: State) -> State: | |
| logger.info("Entering sql_refinement stage") | |
| metadata = "\n".join(state['databricks_metadata']) | |
| prompt = f""" | |
| You are an SQL refinement expert. The previous SQL query: | |
| {state['sql_query']} | |
| Produced this result: | |
| {state['sql_result']} | |
| For the question "{state['question']}", the result didn’t fully satisfy the query intent. Using this Databricks system table metadata: | |
| {metadata} | |
| Refine the SQL query to better answer the question. Provide the improved SQL query as plain text. | |
| """ | |
| response = llm.invoke([HumanMessage(content=prompt)]) | |
| state['sql_query'] = response.content.strip() | |
| state['sql_valid'] = False # Reset for new query | |
| logger.debug(f"Refined SQL query: {state['sql_query']}") | |
| return state | |
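| # answer_generation assembles whatever context the branch produced and, for | |
| # Databricks questions, self-checks the answer against the SQL result. | |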
| def answer_generation(state: State) -> State: | |
| logger.info("Entering answer_generation stage") | |
| context = "" | |
| if state['intent'] == "confluence" and state['confluence_docs']: | |
| context += "Confluence Documents:\n" + "\n".join([doc.page_content for doc in state['confluence_docs']]) + "\n" | |
| elif state['intent'] == "databricks": | |
| if state.get('databricks_metadata'): | |
| context += "Databricks System Table Metadata:\n" + "\n".join(state['databricks_metadata']) + "\n" | |
| if state.get('sql_query'): | |
| context += "Generated SQL Query:\n" + state['sql_query'] + "\n" | |
| if state.get('sql_result'): | |
| context += "SQL Result:\n" + state['sql_result'] + "\n" | |
| history = "\n".join([f"{msg.type}: {msg.content}" for msg in state['messages']]) if state['messages'] else "No prior history" | |
| if not context.strip(): | |
| state['answer'] = "I'm sorry, I couldn't find sufficient information to answer your question." | |
| else: | |
| prompt = f""" | |
| You are an expert answer generator. Using this context: | |
| {context} | |
| And this conversation history: | |
| {history} | |
| Provide a clear and concise answer to the question: "{state['question']}". | |
| If an SQL result is included, explain it in text and, if suitable, format it as a table. Ensure the answer is consistent with or builds on the history where relevant. | |
| """ | |
| response = llm.invoke([HumanMessage(content=prompt)]) | |
| state['answer'] = response.content.strip() | |
| if state['intent'] == "databricks" and state.get('sql_result'): | |
| prompt = f""" | |
| You are a result validation expert. Given the question "{state['question']}" and this answer: | |
| {state['answer']} | |
| Does the answer fully satisfy the question based on the SQL result? Return "yes" or "no". | |
| """ | |
| response = llm.invoke([HumanMessage(content=prompt)]) | |
| if response.content.strip().lower() == "no" and state.get('attempt_count', 0) < 3: | |
| logger.info(f"Answer doesn't satisfy question, refining SQL (attempt {state.get('attempt_count', 0) + 1})") | |
| state['attempt_count'] = state.get('attempt_count', 0) + 1 | |
| # Refine, re-validate, re-execute, and regenerate the answer in-node so | |
| # the graph can proceed straight to guardrails with a finished answer. | |
| return answer_generation(sql_execution(sql_validation(sql_refinement(state)))) | |
| state['final_answer'] = state['answer'] | |
| logger.debug(f"Generated answer: {state['answer']}") | |
| return state | |
| def guardrails(state: State) -> State: | |
| logger.info("Entering guardrails stage") | |
| answer = state['answer'].lower() | |
| prohibited_words = ["harmful", "offensive", "confidential"] | |
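| # Illustrative keyword screen; a production agent would call a dedicated | |
| # moderation or guardrails service instead of substring matching. | |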
| if any(word in answer for word in prohibited_words): | |
| state['final_answer'] = "I'm sorry, I cannot provide that information due to content restrictions." | |
| elif len(state['answer']) > 1000: | |
| state['final_answer'] = "The response is too long to display fully. Please refine your question." | |
| logger.debug(f"Final answer after guardrails: {state['final_answer']}") | |
| return state | |
| def clarification_node(state: State) -> State: | |
| logger.info("Entering clarification_node stage") | |
| prompt_template = PromptTemplate( | |
| input_variables=["question"], | |
| template=""" | |
| You are a clarification expert. The question "{question}" is unclear. | |
| Ask a concise follow-up question to clarify the user's intent. | |
| Return the follow-up question in this JSON format: | |
| {format_instructions} | |
| """, | |
| partial_variables={"format_instructions": clarification_parser.get_format_instructions()} | |
| ) | |
| prompt = prompt_template.format(question=state['question']) | |
| response = llm.invoke([HumanMessage(content=prompt)]) | |
| result = clarification_parser.parse(response.content) | |
| state['final_answer'] = result.clarification_question | |
| logger.debug(f"Clarification question: {state['final_answer']}") | |
| return state | |
| def end_conversation(state: State) -> State: | |
| logger.info("Entering end_conversation stage") | |
| ai_message = AIMessage(content=state['final_answer']) | |
| state['messages'] = state['messages'] + [ai_message] | |
| logger.debug(f"Appended AI message: {ai_message.content}") | |
| return state | |
| # Define Graph | |
| graph = StateGraph(State) | |
| memory = MemorySaver() | |
| # Add nodes | |
| graph.add_node("start_conversation", start_conversation) | |
| graph.add_node("check_history", check_history) | |
| graph.add_node("question_analyzer", question_analyzer) | |
| graph.add_node("query_rewriter", query_rewriter) | |
| graph.add_node("confluence_retrieval", confluence_retrieval) | |
| graph.add_node("databricks_retrieval", databricks_retrieval) | |
| graph.add_node("result_reranker", result_reranker) | |
| graph.add_node("sql_generation", sql_generation) | |
| graph.add_node("sql_validation", sql_validation) | |
| graph.add_node("sql_execution", sql_execution) | |
| graph.add_node("sql_refinement", sql_refinement) | |
| graph.add_node("answer_generation", answer_generation) | |
| graph.add_node("guardrails", guardrails) | |
| graph.add_node("clarification_node", clarification_node) | |
| graph.add_node("end_conversation", end_conversation) | |
| # Define routing functions | |
| def route_after_history(state: State) -> str: | |
| return "end_conversation" if state.get('history_answer') else "question_analyzer" | |
| def route_after_analyzer(state: State) -> str: | |
| return "clarification_node" if state['clarification_needed'] else "query_rewriter" | |
| def route_after_rewriter(state: State) -> str: | |
| intent = state['intent'] | |
| if intent == "confluence": | |
| return "confluence_retrieval" | |
| elif intent == "databricks": | |
| return "databricks_retrieval" | |
| return "clarification_node" | |
| # Define edges with START and END | |
| graph.add_edge(START, "start_conversation") | |
| graph.add_edge("start_conversation", "check_history") | |
| graph.add_conditional_edges("check_history", route_after_history, { | |
| "end_conversation": "end_conversation", | |
| "question_analyzer": "question_analyzer" | |
| }) | |
| graph.add_conditional_edges("question_analyzer", route_after_analyzer, { | |
| "clarification_node": "clarification_node", | |
| "query_rewriter": "query_rewriter" | |
| }) | |
| graph.add_conditional_edges("query_rewriter", route_after_rewriter, { | |
| "confluence_retrieval": "confluence_retrieval", | |
| "databricks_retrieval": "databricks_retrieval", | |
| "clarification_node": "clarification_node" | |
| }) | |
| graph.add_edge("confluence_retrieval", "result_reranker") | |
| graph.add_edge("databricks_retrieval", "result_reranker") | |
| graph.add_edge("result_reranker", "sql_generation") | |
| graph.add_edge("sql_generation", "sql_validation") | |
| graph.add_edge("sql_validation", "sql_execution") | |
| graph.add_edge("sql_execution", "answer_generation") | |
| graph.add_edge("result_reranker", "answer_generation") | |
| graph.add_edge("answer_generation", "sql_refinement") | |
| graph.add_edge("sql_refinement", "sql_validation") | |
| graph.add_edge("answer_generation", "guardrails") | |
| graph.add_edge("guardrails", "end_conversation") | |
| graph.add_edge("clarification_node", "end_conversation") | |
| graph.add_edge("end_conversation", END) | |
| # Compile the graph with memory | |
| app = graph.compile(checkpointer=memory) | |
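| # MemorySaver keeps checkpoints in process memory only; swapping in a durable | |
| # checkpointer (e.g. the DynamoDB-backed one in the first file) would persist | |
| # threads across restarts. | |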
| # Interface for user questions with thread_id | |
| def ask_question(question: str, thread_id: str) -> str: | |
| logger.info(f"Processing question '{question}' with thread_id '{thread_id}'") | |
| config = {"configurable": {"thread_id": thread_id}} | |
| initial_state = {"question": question} # Only question required, others default | |
| result = app.invoke(initial_state, config=config) | |
| logger.info(f"Returning answer: {result['final_answer']}") | |
| logger.debug(f"Conversation history: {[msg.content for msg in result['messages']]}") | |
| return result['final_answer'] | |
| # Example usage | |
| if __name__ == "__main__": | |
| # New conversation (Databricks example) | |
| answer1 = ask_question("What is the total sales in the last quarter?", "thread_1") | |
| print("Answer 1:", answer1) | |
| # Follow-up in same thread | |
| answer2 = ask_question("Break it down by region", "thread_1") | |
| print("Answer 2:", answer2) | |
| # New conversation (Confluence example) | |
| answer3 = ask_question("How do I configure Databricks?", "thread_2") | |
| print("Answer 3:", answer3) |