"""
project/
├── state.py # Defines the agent's state
├── prompts.py # Manages prompt templates
├── config.py # Handles configuration and initialization
├── tools.py # Contains utility functions and tools
├── guardrails.py # Implements safety checks
├── nodes.py # Defines workflow nodes for LangGraph
├── workflow.py # Orchestrates the main workflow using LangGraph
└── main.py # Runs the agent with example queries
"""
# 1. state.py - Agent State Definition
# This module defines the AgentState using a TypedDict to manage the agent's state throughout the workflow.
from typing import TypedDict, Annotated, List
from langchain_core.messages import HumanMessage, AIMessage
import operator
class AgentState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    intent: str
    confluence_context: List[str]
    databricks_context: List[str]
    generated_sql: str
    sql_attempts: int
    sql_error: str
    needs_clarification: bool
    final_answer: str
    rewritten_query: str
# 2. prompts.py - Prompt Templates Management
# This module defines all prompt templates and a PromptManager class to handle them.
from langchain.prompts import PromptTemplate
class PromptManager:
    def __init__(self):
        self.intent_template = PromptTemplate(
            input_variables=["question"],
            template="""Classify this question into one of: 'confluence', 'databricks', 'both', or 'ambiguous':
Question: {question}
Provide a one-word response: """
        )
        self.sql_generation_template = PromptTemplate(
            input_variables=["query", "metadata"],
            template="Generate a SQL query for '{query}' using metadata: {metadata}"
        )
        self.sql_refinement_template = PromptTemplate(
            input_variables=["previous_sql", "error", "metadata"],
            template="Refine this SQL query '{previous_sql}' that caused error '{error}' using metadata: {metadata}"
        )
        self.answer_template = PromptTemplate(
            input_variables=["query", "confluence_ctx", "databricks_ctx"],
            template="""Answer this question: '{query}'
Using Confluence context: {confluence_ctx}
And Databricks context: {databricks_ctx}
Provide a concise, accurate response."""
        )
        self.safety_template = PromptTemplate(
            input_variables=["answer"],
            template="Is this answer safe and appropriate for all audiences? Answer 'yes' or 'no': {answer}"
        )
    def get_intent_prompt(self, question: str) -> str:
        return self.intent_template.format(question=question)
    def get_sql_generation_prompt(self, query: str, metadata: str) -> str:
        return self.sql_generation_template.format(query=query, metadata=metadata)
    def get_sql_refinement_prompt(self, previous_sql: str, error: str, metadata: str) -> str:
        return self.sql_refinement_template.format(previous_sql=previous_sql, error=error, metadata=metadata)
    def get_answer_prompt(self, query: str, confluence_ctx: str, databricks_ctx: str) -> str:
        return self.answer_template.format(query=query, confluence_ctx=confluence_ctx, databricks_ctx=databricks_ctx)
    def get_safety_prompt(self, answer: str) -> str:
        return self.safety_template.format(answer=answer)
# 3. config.py - Configuration and Initialization
# This module initializes the language model, embeddings, vector stores, and prompt manager.
import os
from langchain_aws import ChatBedrock
from langchain_community.vectorstores import PGVector
from langchain_openai import OpenAIEmbeddings
from prompts import PromptManager
# Load environment variables
os.environ["AWS_REGION"] = "us-east-1"
CONFLUENCE_CONNECTION_STRING = os.getenv("CONFLUENCE_DB_URL")
DATABRICKS_CONNECTION_STRING = os.getenv("DATABRICKS_DB_URL")
# Initialize LLM
llm = ChatBedrock(
    model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
    region_name="us-east-1",
    model_kwargs={"temperature": 0.7}
)
# Initialize embeddings
embeddings = OpenAIEmbeddings()
# Initialize vector stores
confluence_store = PGVector(
    collection_name="confluence_docs",
    connection_string=CONFLUENCE_CONNECTION_STRING,
    embedding_function=embeddings
)
databricks_store = PGVector(
    collection_name="databricks_metadata",
    connection_string=DATABRICKS_CONNECTION_STRING,
    embedding_function=embeddings
)
# Initialize prompt manager
prompt_manager = PromptManager()
# 4. tools.py - Utility Functions and Tools
# This module contains tools for SQL generation, query rewriting, result reranking, and simulated SQL execution.
from typing import List
from config import llm, prompt_manager
from langchain.tools import tool
from sentence_transformers import CrossEncoder
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
@tool
def generate_sql(user_query: str = None, previous_sql: str = None, error: str = None, metadata: List[str] = None) -> str:
    """
    Generates or refines a SQL query based on inputs.
    Args:
        user_query (str, optional): The user's original query for initial SQL generation.
        previous_sql (str, optional): The previous SQL query to refine if an error occurred.
        error (str, optional): The error message from the previous SQL execution.
        metadata (List[str]): Metadata to assist SQL generation.
    Returns:
        str: The generated or refined SQL query.
    Raises:
        ValueError: If inputs are invalid.
    """
    if previous_sql and error:
        prompt = prompt_manager.get_sql_refinement_prompt(previous_sql, error, metadata)
    elif user_query:
        prompt = prompt_manager.get_sql_generation_prompt(user_query, metadata)
    else:
        raise ValueError("Must provide either user_query or previous_sql and error")
    response = llm.invoke(prompt)
    return response.content
def rewrite_query(query: str) -> str:
    prompt = f"Rewrite this query to make it clearer and more specific: '{query}'"
    return llm.invoke(prompt).content
def rerank_results(results: List[str], query: str) -> List[str]:
    pairs = [(query, doc) for doc in results]
    scores = reranker.predict(pairs)
    sorted_results = [doc for _, doc in sorted(zip(scores, results), reverse=True)]
    return sorted_results
def execute_databricks_sql(sql: str) -> dict:
    if "error" in sql.lower():
        return {"status": "error", "error": "Simulated SQL error"}
    return {"status": "success", "results": ["Simulated results"]}
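# The simulator above stands in for a real warehouse call. A minimal live
# version (a sketch under assumptions: the env var names are illustrative)
# can use the databricks-sql-connector client that later variants in this
# gist import as `from databricks import sql`.
import os
from databricks import sql as databricks_sql
def execute_databricks_sql_live(query: str) -> dict:
    """Run a query against a Databricks SQL warehouse and normalize the result."""
    try:
        with databricks_sql.connect(
            server_hostname=os.environ["DATABRICKS_SERVER_HOSTNAME"],  # illustrative env var
            http_path=os.environ["DATABRICKS_HTTP_PATH"],  # illustrative env var
            access_token=os.environ["DATABRICKS_TOKEN"],  # illustrative env var
        ) as connection:
            with connection.cursor() as cursor:
                cursor.execute(query)
                rows = cursor.fetchall()
        return {"status": "success", "results": [str(row) for row in rows]}
    except Exception as e:
        return {"status": "error", "error": str(e)}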
# 5. guardrails.py - Safety Checks
# This module implements safety checks for answers and SQL queries.
from config import llm, prompt_manager
def check_answer_safety(answer: str) -> bool:
    safety_prompt = prompt_manager.get_safety_prompt(answer)
    safety_response = llm.invoke(safety_prompt).content.strip().lower()
    if safety_response != "yes":
        return False
    forbidden_words = ["confidential", "password", "secret"]
    return not any(word in answer.lower() for word in forbidden_words)
def is_safe_sql(sql: str) -> bool:
    sql_lower = sql.lower().strip()
    return (sql_lower.startswith("select") and
            "insert" not in sql_lower and
            "update" not in sql_lower and
            "delete" not in sql_lower and
            "drop" not in sql_lower)
# 6. nodes.py - Workflow Nodes
# This module defines the nodes for the LangGraph workflow.
from state import AgentState
from config import llm, prompt_manager, confluence_store, databricks_store
from tools import generate_sql, rewrite_query, rerank_results, execute_databricks_sql
from guardrails import check_answer_safety, is_safe_sql
from langchain_core.messages import AIMessage
def parse_question(state: AgentState) -> AgentState:
    question = state["messages"][-1].content
    prompt = prompt_manager.get_intent_prompt(question)
    response = llm.invoke(prompt).content.strip()
    state["intent"] = response
    return state
def rewrite_query_node(state: AgentState) -> AgentState:
    question = state["messages"][-1].content
    state["rewritten_query"] = rewrite_query(question)
    return state
def route_context(state: AgentState) -> AgentState:
    query = state["rewritten_query"]
    if state["intent"] in ["confluence", "both"]:
        results = confluence_store.similarity_search(query, k=5)
        state["confluence_context"] = rerank_results([doc.page_content for doc in results], query)
    if state["intent"] in ["databricks", "both"]:
        results = databricks_store.similarity_search(query, k=5)
        state["databricks_context"] = rerank_results([doc.page_content for doc in results], query)
    if state["intent"] == "ambiguous":
        state["needs_clarification"] = True
    return state
def clarify_question(state: AgentState) -> AgentState:
    if state["needs_clarification"]:
        state["messages"].append(AIMessage(content="Please clarify your query."))
        state["needs_clarification"] = False
    return state
def generate_sql_node(state: AgentState) -> AgentState:
    metadata = state["databricks_context"]
    # generate_sql is a @tool, so invoke it with a dict of arguments rather than calling it directly
    state["generated_sql"] = generate_sql.invoke({"user_query": state["rewritten_query"], "metadata": metadata})
    state["sql_attempts"] = state.get("sql_attempts", 0) + 1
    return state
def execute_sql_node(state: AgentState) -> AgentState:
    if not is_safe_sql(state["generated_sql"]):
        state["sql_error"] = "Unsafe SQL query detected."
        return state
    result = execute_databricks_sql(state["generated_sql"])
    if result["status"] == "error":
        state["sql_error"] = result["error"]
    else:
        state["databricks_context"] = result["results"]
    return state
def generate_answer(state: AgentState) -> AgentState:
    prompt = prompt_manager.get_answer_prompt(
        state["rewritten_query"],
        "\n".join(state.get("confluence_context", [])),
        "\n".join(state.get("databricks_context", []))
    )
    answer = llm.invoke(prompt).content
    if not check_answer_safety(answer):
        state["final_answer"] = "Answer deemed unsafe or inappropriate."
    else:
        state["final_answer"] = answer
    state["messages"].append(AIMessage(content=state["final_answer"]))
    return state
# 7. workflow.py - Main Workflow Orchestration
# This module defines the LangGraph workflow, connecting all nodes.
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from state import AgentState
from nodes import parse_question, rewrite_query_node, route_context, clarify_question, generate_sql_node, execute_sql_node, generate_answer
workflow = StateGraph(AgentState)
# Add nodes
workflow.add_node("parse_question", parse_question)
workflow.add_node("rewrite_query", rewrite_query_node)
workflow.add_node("route_context", route_context)
workflow.add_node("clarify_question", clarify_question)
workflow.add_node("generate_sql", generate_sql_node)
workflow.add_node("execute_sql", execute_sql_node)
workflow.add_node("generate_answer", generate_answer)
# Define edges
workflow.add_edge("parse_question", "rewrite_query")
workflow.add_edge("rewrite_query", "route_context")
workflow.add_edge("route_context", "clarify_question")
workflow.add_conditional_edges(
    "clarify_question",
    lambda state: "generate_sql" if state["intent"] in ["databricks", "both"] else "generate_answer",
    {"generate_sql": "generate_sql", "generate_answer": "generate_answer"}
)
workflow.add_edge("generate_sql", "execute_sql")
workflow.add_edge("execute_sql", "generate_answer")
workflow.add_edge("generate_answer", END)
# Set entry point
workflow.set_entry_point("parse_question")
# Compile with memory
checkpointer = MemorySaver()
graph = workflow.compile(checkpointer=checkpointer)
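# Optional sanity check (a sketch; availability depends on your langgraph
# version): render the compiled graph as Mermaid to verify the wiring above.
# print(graph.get_graph().draw_mermaid())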
# 8. main.py - Running the Agent
# This module runs the agent with example queries.
from workflow import graph
from state import AgentState
from langchain_core.messages import HumanMessage
def run_agent(question: str, thread_id: str = "thread_1"):
    initial_state = AgentState(messages=[HumanMessage(content=question)])
    result = graph.invoke(initial_state, config={"configurable": {"thread_id": thread_id}})
    return result["messages"][-1].content
if __name__ == "__main__":
    print(run_agent("What is the process for onboarding in Confluence?"))
    print(run_agent("How many rows are in the sales table in Databricks?"))
    print(run_agent("What’s the latest update on project X?"))
# This modular structure separates concerns effectively:
# State Management: state.py defines the agent's state.
# Prompt Templates: prompts.py manages all prompt templates.
# Configuration: config.py initializes models and stores.
# Tools: tools.py provides utility functions and tools.
# Safety Checks: guardrails.py ensures safe outputs.
# Workflow Nodes: nodes.py implements individual workflow steps.
# Main Workflow: workflow.py orchestrates the entire process.
# Execution: main.py runs the agent with example queries.
# Each module has clear dependencies, avoiding circular imports, which makes the code easier to maintain and extend.
# pip install langchain langchain-aws langchain-community langchain-openai langgraph boto3 sentence-transformers psycopg2-binary pgvector
"""
How It Works
State Management:
The AgentState tracks conversation history, intent, context, SQL details, and more.
DynamoDBSaver persists state across invocations, supporting concurrent users.
Workflow:
Parse Question: Determines if the query relates to Confluence, Databricks, both, or is ambiguous.
Rewrite Query: Enhances query clarity.
Route Context: Retrieves relevant documents from PGVector stores.
Clarify Question: Prompts for clarification if needed.
Generate/Execute SQL: Creates and runs SQL queries for Databricks questions, with retries on errors.
Generate Answer: Combines history (last 5 messages) and context to produce a response.
Guardrails:
Ensures SQL queries are read-only.
Validates answer safety and appropriateness.
Agentic Behavior:
The LangGraph workflow enables decision-making (e.g., SQL generation) and multi-step processing.
"""
import os
import json
import boto3
from typing import TypedDict, Annotated, List
from langchain_aws import ChatBedrock
from langchain_community.vectorstores import PGVector
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.base import BaseCheckpointSaver
import operator
# --- DynamoDB State Persistence ---
class DynamoDBSaver(BaseCheckpointSaver):
    """Custom checkpointer to save and load conversation state from DynamoDB.
    Simplified sketch: recent langgraph releases expect the fuller
    BaseCheckpointSaver interface (e.g. get_tuple and list) as well.
    """
    def __init__(self, table_name: str = "conversation_state"):
        self.dynamodb = boto3.resource('dynamodb')
        self.table = self.dynamodb.Table(table_name)
    def get(self, config: dict) -> dict:
        thread_id = config["configurable"]["thread_id"]
        response = self.table.get_item(Key={"thread_id": thread_id})
        return json.loads(response["Item"]["state"]) if "Item" in response else {}
    def put(self, config: dict, checkpoint: dict) -> None:
        thread_id = config["configurable"]["thread_id"]
        self.table.put_item(Item={"thread_id": thread_id, "state": json.dumps(checkpoint)})
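# The checkpointer assumes the DynamoDB table already exists (see the compile
# step below). A one-time setup sketch with standard boto3; the table and key
# names are the ones this gist uses:
# boto3.client("dynamodb").create_table(
#     TableName="conversation_state",
#     KeySchema=[{"AttributeName": "thread_id", "KeyType": "HASH"}],
#     AttributeDefinitions=[{"AttributeName": "thread_id", "AttributeType": "S"}],
#     BillingMode="PAY_PER_REQUEST",
# )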
# --- Agent State Definition ---
class AgentState(TypedDict):
    """Schema for the agent's state, tracking conversation and processing details."""
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]  # Conversation history
    intent: str  # Classified intent: 'confluence', 'databricks', 'both', or 'ambiguous'
    confluence_context: List[str]  # Retrieved Confluence context
    databricks_context: List[str]  # Retrieved Databricks context
    generated_sql: str  # Generated SQL query
    sql_attempts: int  # Number of SQL generation attempts
    sql_error: str  # Error message from SQL execution
    needs_clarification: bool  # Flag for ambiguous queries
    final_answer: str  # Generated response
    rewritten_query: str  # Rewritten user query for clarity
# --- Configurations ---
os.environ["AWS_REGION"] = "us-east-1"
CONFLUENCE_CONNECTION_STRING = os.getenv("CONFLUENCE_DB_URL") # Set in environment
DATABRICKS_CONNECTION_STRING = os.getenv("DATABRICKS_DB_URL") # Set in environment
# Initialize Bedrock model
llm = ChatBedrock(
    model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
    region_name="us-east-1",
    model_kwargs={"temperature": 0.7}
)
# Initialize vector stores with OpenAI embeddings (for simplicity; adjust as needed)
from langchain_openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
confluence_store = PGVector(
    collection_name="confluence_docs",
    connection_string=CONFLUENCE_CONNECTION_STRING,
    embedding_function=embeddings
)
databricks_store = PGVector(
    collection_name="databricks_metadata",
    connection_string=DATABRICKS_CONNECTION_STRING,
    embedding_function=embeddings
)
# --- Prompt Templates ---
intent_template = PromptTemplate(
    input_variables=["question"],
    template="""Classify this question into one of: 'confluence', 'databricks', 'both', or 'ambiguous':
Question: {question}
Provide a one-word response: """
)
sql_generation_template = PromptTemplate(
    input_variables=["query", "metadata"],
    template="Generate a SQL query for '{query}' using metadata: {metadata}"
)
sql_refinement_template = PromptTemplate(
    input_variables=["previous_sql", "error", "metadata"],
    template="Refine this SQL query '{previous_sql}' that caused error '{error}' using metadata: {metadata}"
)
answer_template = PromptTemplate(
    input_variables=["history", "query", "confluence_ctx", "databricks_ctx"],
    template="""Conversation history:
{history}
Current question: {query}
Using Confluence context: {confluence_ctx}
And Databricks context: {databricks_ctx}
Provide a concise, accurate response."""
)
safety_template = PromptTemplate(
    input_variables=["answer"],
    template="Is this answer safe and appropriate? Answer 'yes' or 'no': {answer}"
)
# --- Tools and Utilities ---
@tool
def generate_sql(user_query: str = None, previous_sql: str = None, error: str = None, metadata: List[str] = None) -> str:
    """Generate or refine a SQL query based on input."""
    if previous_sql and error:
        prompt = sql_refinement_template.format(previous_sql=previous_sql, error=error, metadata=metadata)
    elif user_query:
        prompt = sql_generation_template.format(query=user_query, metadata=metadata)
    else:
        raise ValueError("Must provide either user_query or previous_sql and error")
    response = llm.invoke(prompt)
    return response.content
def execute_databricks_sql(sql: str) -> dict:
    """Simulate SQL execution on Databricks (replace with actual implementation)."""
    if "error" in sql.lower():  # Simulated error condition
        return {"status": "error", "error": "Simulated SQL error"}
    return {"status": "success", "results": ["Simulated results"]}
def rewrite_query(query: str) -> str:
    """Rewrite the user's query for clarity and specificity."""
    prompt = f"Rewrite this query to make it clearer and more specific: '{query}'"
    return llm.invoke(prompt).content
# --- Guardrails ---
def check_answer_safety(answer: str) -> bool:
    """Ensure the generated answer is safe and appropriate."""
    safety_prompt = safety_template.format(answer=answer)
    safety_response = llm.invoke(safety_prompt).content.strip().lower()
    return safety_response == "yes"
def is_safe_sql(sql: str) -> bool:
    """Check if the SQL query is safe (read-only)."""
    sql_lower = sql.lower().strip()
    return sql_lower.startswith("select") and all(
        keyword not in sql_lower for keyword in ["insert", "update", "delete", "drop"]
    )
# --- Workflow Nodes ---
def parse_question(state: AgentState) -> AgentState:
    """Parse the user's question to determine intent."""
    question = state["messages"][-1].content
    prompt = intent_template.format(question=question)
    response = llm.invoke(prompt).content.strip()
    state["intent"] = response
    return state
def rewrite_query_node(state: AgentState) -> AgentState:
    """Rewrite the query for better processing."""
    question = state["messages"][-1].content
    state["rewritten_query"] = rewrite_query(question)
    return state
def route_context(state: AgentState) -> AgentState:
    """Retrieve context from Confluence and/or Databricks based on intent."""
    query = state["rewritten_query"]
    if state["intent"] in ["confluence", "both"]:
        results = confluence_store.similarity_search(query, k=3)
        state["confluence_context"] = [doc.page_content for doc in results]
    if state["intent"] in ["databricks", "both"]:
        results = databricks_store.similarity_search(query, k=3)
        state["databricks_context"] = [doc.page_content for doc in results]
    if state["intent"] == "ambiguous":
        state["needs_clarification"] = True
    return state
def clarify_question(state: AgentState) -> AgentState:
    """Ask for clarification if the query is ambiguous."""
    if state["needs_clarification"]:
        state["messages"].append(AIMessage(content="Please clarify your query."))
        state["needs_clarification"] = False
    return state
def generate_sql_node(state: AgentState) -> AgentState:
    """Generate a SQL query for Databricks-related questions."""
    if state["intent"] in ["databricks", "both"] and state["databricks_context"]:
        metadata = state["databricks_context"]
        # generate_sql is a @tool, so invoke it with a dict of arguments
        state["generated_sql"] = generate_sql.invoke({"user_query": state["rewritten_query"], "metadata": metadata})
        state["sql_attempts"] = state.get("sql_attempts", 0) + 1
    return state
def execute_sql_node(state: AgentState) -> AgentState:
    """Execute the SQL query with a feedback loop for retries."""
    if state.get("generated_sql"):
        if not is_safe_sql(state["generated_sql"]):
            state["databricks_context"] = ["Unsafe SQL query detected."]
            state["sql_error"] = None
        else:
            result = execute_databricks_sql(state["generated_sql"])
            if result["status"] == "success":
                state["databricks_context"] = result["results"]
                state["sql_error"] = None
            else:
                state["sql_error"] = result["error"]
                if state["sql_attempts"] < 3:
                    state["sql_attempts"] += 1  # bound the retry loop
                    state["generated_sql"] = generate_sql.invoke({
                        "previous_sql": state["generated_sql"],
                        "error": state["sql_error"],
                        "metadata": state["databricks_context"]
                    })
                    state = execute_sql_node(state)  # Recursive retry, bounded by sql_attempts
                else:
                    state["databricks_context"] = ["Unable to retrieve data due to persistent errors."]
                    state["sql_error"] = None
    return state
def generate_answer(state: AgentState) -> AgentState:
    """Generate the final answer using conversation history and context."""
    history = "\n".join([f"{msg.type}: {msg.content}" for msg in state["messages"][-5:]])
    confluence_ctx = "\n".join(state.get("confluence_context", []))
    databricks_ctx = "\n".join(state.get("databricks_context", []))
    prompt = answer_template.format(
        history=history,
        query=state["rewritten_query"],
        confluence_ctx=confluence_ctx,
        databricks_ctx=databricks_ctx
    )
    answer = llm.invoke(prompt).content
    if check_answer_safety(answer):
        state["final_answer"] = answer
    else:
        state["final_answer"] = "I'm sorry, but I can't provide that information."
    state["messages"].append(AIMessage(content=state["final_answer"]))
    return state
# --- Workflow Definition ---
workflow = StateGraph(AgentState)
# Add nodes to the workflow
workflow.add_node("parse_question", parse_question)
workflow.add_node("rewrite_query", rewrite_query_node)
workflow.add_node("route_context", route_context)
workflow.add_node("clarify_question", clarify_question)
workflow.add_node("generate_sql", generate_sql_node)
workflow.add_node("execute_sql", execute_sql_node)
workflow.add_node("generate_answer", generate_answer)
# Define the flow of execution
workflow.set_entry_point("parse_question")
workflow.add_edge("parse_question", "rewrite_query")
workflow.add_edge("rewrite_query", "route_context")
workflow.add_edge("route_context", "clarify_question")
workflow.add_conditional_edges(
    "clarify_question",
    lambda state: "generate_sql" if state["intent"] in ["databricks", "both"] else "generate_answer",
    {"generate_sql": "generate_sql", "generate_answer": "generate_answer"}
)
workflow.add_edge("generate_sql", "execute_sql")
workflow.add_edge("execute_sql", "generate_answer")
workflow.add_edge("generate_answer", END)
# Compile the workflow with state persistence
checkpointer = DynamoDBSaver(table_name="conversation_state") # Assumes table exists
graph = workflow.compile(checkpointer=checkpointer)
# --- Main Function to Run the Agent ---
def run_agent(question: str, thread_id: str = "thread_1") -> str:
    """Run the agent with a user question and return the response."""
    initial_state = {"messages": [HumanMessage(content=question)]}
    result = graph.invoke(initial_state, config={"configurable": {"thread_id": thread_id}})
    return result["messages"][-1].content
# --- Example Usage ---
if __name__ == "__main__":
    response = run_agent("What is the process for onboarding in Confluence?")
    print(response)
import logging
import json
from typing import TypedDict, List, Optional
from langchain_core.messages import HumanMessage
from langchain_aws import ChatBedrock
from langchain_community.vectorstores import PGVector
from langchain_community.embeddings import BedrockEmbeddings
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from databricks import sql
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain.prompts import PromptTemplate
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize Bedrock Sonnet model
llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    region_name="us-east-1"  # Replace with your AWS region
)
# Initialize embedding function
embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")
# Initialize PGVector stores (replace connection strings)
confluence_store = PGVector(
    collection_name="confluence",
    connection_string="postgresql+psycopg2://user:password@localhost:5432/dbname",
    embedding_function=embeddings
)
databricks_store = PGVector(
    collection_name="databricks_metadata",
    connection_string="postgresql+psycopg2://user:password@localhost:5432/dbname",
    embedding_function=embeddings
)
# Databricks SQL connection (replace with your credentials)
databricks_connection = sql.connect(
    server_hostname="your_databricks_host",
    http_path="your_http_path",
    access_token="your_access_token"
)
# Define State schema
class State(TypedDict):
    question: str  # Original user question
    rewritten_query: Optional[str]  # Rewritten query for retrieval
    intent: Optional[str]  # Intent: "confluence", "databricks_metadata", or "sql_generation"
    confluence_docs: List[str]  # Retrieved Confluence documents
    databricks_metadata: List[str]  # Retrieved Databricks metadata
    sql_query: Optional[str]  # Generated SQL query
    sql_result: Optional[str]  # Raw SQL execution result
    answer: Optional[str]  # Intermediate answer
    clarification_needed: bool  # Flag for clarification
    final_answer: Optional[str]  # Final output answer
    attempt_count: int  # Retry counter for SQL execution
# Define Structured Output Schemas
question_analyzer_schema = [
    ResponseSchema(name="intent", description="The classified intent: 'confluence', 'databricks_metadata', or 'sql_generation'"),
    ResponseSchema(name="clarification_needed", description="Boolean indicating if clarification is needed", type="boolean")
]
query_rewriter_schema = [
    ResponseSchema(name="rewritten_query", description="The rewritten query optimized for retrieval")
]
reranker_schema = [
    ResponseSchema(name="indices", description="List of indices of the top 3 most relevant documents", type="list")
]
sql_generation_schema = [
    ResponseSchema(name="sql_query", description="The generated SQL query")
]
clarification_schema = [
    ResponseSchema(name="clarification_question", description="A follow-up question to clarify intent")
]
# Initialize Structured Output Parsers
question_analyzer_parser = StructuredOutputParser.from_response_schemas(question_analyzer_schema)
query_rewriter_parser = StructuredOutputParser.from_response_schemas(query_rewriter_schema)
reranker_parser = StructuredOutputParser.from_response_schemas(reranker_schema)
sql_generation_parser = StructuredOutputParser.from_response_schemas(sql_generation_schema)
clarification_parser = StructuredOutputParser.from_response_schemas(clarification_schema)
# Define Nodes with Claude-Optimized Prompts
def question_analyzer(state: State) -> State:
    logger.info("Entering question_analyzer stage")
    prompt_template = PromptTemplate(
        input_variables=["question"],
        template="""
You are an expert intent classifier. Given this question: "{question}", determine its intent. Choose one:
- "confluence" for general knowledge or documentation queries
- "databricks_metadata" for Databricks table schema or metadata questions
- "sql_generation" for questions needing an SQL query
If the question is unclear, set clarification_needed to true; otherwise, false.
Return your response in this JSON format:
{format_instructions}
""",
        partial_variables={"format_instructions": question_analyzer_parser.get_format_instructions()}
    )
    prompt = prompt_template.format(question=state['question'])
    response = llm.invoke([HumanMessage(content=prompt)])
    result = question_analyzer_parser.parse(response.content)
    state['intent'] = result['intent']
    state['clarification_needed'] = result['clarification_needed']
    state['confluence_docs'] = []
    state['databricks_metadata'] = []
    state['attempt_count'] = 0
    logger.debug(f"Intent: {state['intent']}, Clarification needed: {state['clarification_needed']}")
    return state
def query_rewriter(state: State) -> State:
    logger.info("Entering query_rewriter stage")
    prompt_template = PromptTemplate(
        input_variables=["question"],
        template="""
You are a query optimization expert. Rewrite this question to make it more precise for document retrieval: "{question}".
Return the rewritten query in this JSON format:
{format_instructions}
""",
        partial_variables={"format_instructions": query_rewriter_parser.get_format_instructions()}
    )
    prompt = prompt_template.format(question=state['question'])
    response = llm.invoke([HumanMessage(content=prompt)])
    result = query_rewriter_parser.parse(response.content)
    state['rewritten_query'] = result['rewritten_query']
    logger.debug(f"Rewritten query: {state['rewritten_query']}")
    return state
def confluence_retrieval(state: State) -> State:
    logger.info("Entering confluence_retrieval stage")
    query = state.get('rewritten_query', state['question'])
    docs = confluence_store.similarity_search(query, k=5)
    state['confluence_docs'] = [doc.page_content for doc in docs]
    logger.debug(f"Retrieved {len(state['confluence_docs'])} Confluence docs")
    return state
def databricks_retrieval(state: State) -> State:
    logger.info("Entering databricks_retrieval stage")
    query = state.get('rewritten_query', state['question'])
    docs = databricks_store.similarity_search(query, k=5)
    state['databricks_metadata'] = [doc.page_content for doc in docs]
    logger.debug(f"Retrieved {len(state['databricks_metadata'])} Databricks metadata entries")
    return state
def result_reranker(state: State) -> State:
    logger.info("Entering result_reranker stage")
    if state['confluence_docs']:
        docs_str = "\n".join([f"Doc {i}: {doc}" for i, doc in enumerate(state['confluence_docs'])])
        prompt_template = PromptTemplate(
            input_variables=["question", "docs"],
            template="""
You are a document relevance analyst. Given the question "{question}" and these documents:
{docs}
Identify the top 3 most relevant documents by their indices (0-based).
Return the indices in this JSON format:
{format_instructions}
""",
            partial_variables={"format_instructions": reranker_parser.get_format_instructions()}
        )
        prompt = prompt_template.format(question=state['question'], docs=docs_str)
        response = llm.invoke([HumanMessage(content=prompt)])
        result = reranker_parser.parse(response.content)
        indices = result['indices']
        state['confluence_docs'] = [state['confluence_docs'][i] for i in indices if i < len(state['confluence_docs'])]
        logger.debug(f"Reranked Confluence docs: {state['confluence_docs']}")
    if state['databricks_metadata']:
        docs_str = "\n".join([f"Doc {i}: {doc}" for i, doc in enumerate(state['databricks_metadata'])])
        prompt_template = PromptTemplate(
            input_variables=["question", "docs"],
            template="""
You are a metadata relevance analyst. Given the question "{question}" and these metadata entries:
{docs}
Identify the top 3 most relevant entries by their indices (0-based).
Return the indices in this JSON format:
{format_instructions}
""",
            partial_variables={"format_instructions": reranker_parser.get_format_instructions()}
        )
        prompt = prompt_template.format(question=state['question'], docs=docs_str)
        response = llm.invoke([HumanMessage(content=prompt)])
        result = reranker_parser.parse(response.content)
        indices = result['indices']
        state['databricks_metadata'] = [state['databricks_metadata'][i] for i in indices if i < len(state['databricks_metadata'])]
        logger.debug(f"Reranked Databricks metadata: {state['databricks_metadata']}")
    return state
def sql_generation(state: State) -> State:
    logger.info("Entering sql_generation stage")
    if not state['databricks_metadata']:
        state = databricks_retrieval(state)
        state = result_reranker(state)
    metadata = "\n".join(state['databricks_metadata'])
    prompt_template = PromptTemplate(
        input_variables=["metadata", "question"],
        template="""
You are an expert SQL generator for Databricks. Using this metadata:
{metadata}
Write an SQL query to answer the question: "{question}".
Return the query in this JSON format:
{format_instructions}
""",
        partial_variables={"format_instructions": sql_generation_parser.get_format_instructions()}
    )
    prompt = prompt_template.format(metadata=metadata, question=state['question'])
    response = llm.invoke([HumanMessage(content=prompt)])
    result = sql_generation_parser.parse(response.content)
    state['sql_query'] = result['sql_query']
    logger.debug(f"Generated SQL query: {state['sql_query']}")
    return state
def sql_execution(state: State) -> State:
    logger.info("Entering sql_execution stage")
    sql_query = state['sql_query']
    max_attempts = 3
    attempt = state['attempt_count']
    if attempt >= max_attempts:
        state['final_answer'] = "Failed to execute SQL query after 3 attempts."
        logger.info("SQL execution abandoned after max retries")
        return state
    try:
        with databricks_connection.cursor() as cursor:
            logger.info(f"Executing SQL (Attempt {attempt + 1}): {sql_query}")
            cursor.execute(sql_query)
            result = cursor.fetchall()
        state['sql_result'] = str(result)
        logger.debug(f"SQL result: {state['sql_result']}")
        return state
    except Exception as e:
        logger.error(f"SQL execution failed: {e}")
        state['attempt_count'] += 1
        logger.info(f"Retrying SQL generation (attempt {state['attempt_count']})")
        # Return and let route_after_sql_execution send us back to sql_generation;
        # also calling sql_generation here would regenerate twice per failure.
        return state
def answer_generation(state: State) -> State:
    logger.info("Entering answer_generation stage")
    context = ""
    if state['confluence_docs']:
        context += "Confluence Documents:\n" + "\n".join(state['confluence_docs']) + "\n"
    if state['databricks_metadata']:
        context += "Databricks Metadata:\n" + "\n".join(state['databricks_metadata']) + "\n"
    if state.get('sql_query'):
        context += "Generated SQL Query:\n" + state['sql_query'] + "\n"
    if state.get('sql_result'):
        context += "SQL Result:\n" + state['sql_result'] + "\n"
    if not context.strip():
        state['answer'] = "I'm sorry, I couldn't find sufficient information to answer your question."
    else:
        prompt = f"""
You are an expert answer generator. Using this context:
{context}
Provide a clear and concise answer to the question: "{state['question']}".
If an SQL result is included, explain it in text and, if suitable, format it as a table.
"""
        response = llm.invoke([HumanMessage(content=prompt)])
        state['answer'] = response.content.strip()
    state['final_answer'] = state['answer']
    logger.debug(f"Generated answer: {state['answer']}")
    return state
def guardrails(state: State) -> State:
    logger.info("Entering guardrails stage")
    answer = state['answer'].lower()
    prohibited_words = ["harmful", "offensive", "confidential"]
    if any(word in answer for word in prohibited_words):
        state['final_answer'] = "I'm sorry, I cannot provide that information due to content restrictions."
    elif len(state['answer']) > 1000:
        state['final_answer'] = "The response is too long to display fully. Please refine your question."
    logger.debug(f"Final answer after guardrails: {state['final_answer']}")
    return state
def clarification_node(state: State) -> State:
    logger.info("Entering clarification_node stage")
    prompt_template = PromptTemplate(
        input_variables=["question"],
        template="""
You are a clarification expert. The question "{question}" is unclear.
Ask a concise follow-up question to clarify the user's intent.
Return the follow-up question in this JSON format:
{format_instructions}
""",
        partial_variables={"format_instructions": clarification_parser.get_format_instructions()}
    )
    prompt = prompt_template.format(question=state['question'])
    response = llm.invoke([HumanMessage(content=prompt)])
    result = clarification_parser.parse(response.content)
    state['final_answer'] = result['clarification_question']
    logger.debug(f"Clarification question: {state['final_answer']}")
    return state
# Define Graph
graph = StateGraph(State)
memory = MemorySaver()
# Add nodes
graph.add_node("question_analyzer", question_analyzer)
graph.add_node("query_rewriter", query_rewriter)
graph.add_node("confluence_retrieval", confluence_retrieval)
graph.add_node("databricks_retrieval", databricks_retrieval)
graph.add_node("result_reranker", result_reranker)
graph.add_node("sql_generation", sql_generation)
graph.add_node("sql_execution", sql_execution)
graph.add_node("answer_generation", answer_generation)
graph.add_node("guardrails", guardrails)
graph.add_node("clarification_node", clarification_node)
# Define routing functions
def route_after_analyzer(state: State) -> str:
    return "clarification_node" if state['clarification_needed'] else "query_rewriter"
def route_after_rewriter(state: State) -> str:
    intent = state['intent']
    if intent == "confluence":
        return "confluence_retrieval"
    elif intent == "databricks_metadata":
        return "databricks_retrieval"
    elif intent == "sql_generation":
        return "sql_generation"
    return "clarification_node"
def route_after_sql_execution(state: State) -> str:
    if state['attempt_count'] < 3 and not state.get('sql_result'):
        return "sql_generation"
    return "answer_generation"
# Define edges with START and END
graph.add_edge(START, "question_analyzer")
graph.add_conditional_edges("question_analyzer", route_after_analyzer, {"clarification_node": "clarification_node", "query_rewriter": "query_rewriter"})
graph.add_conditional_edges("query_rewriter", route_after_rewriter, {
    "confluence_retrieval": "confluence_retrieval",
    "databricks_retrieval": "databricks_retrieval",
    "sql_generation": "sql_generation",
    "clarification_node": "clarification_node"
})
graph.add_edge("confluence_retrieval", "result_reranker")
graph.add_edge("databricks_retrieval", "result_reranker")
graph.add_edge("sql_generation", "sql_execution")
graph.add_conditional_edges("sql_execution", route_after_sql_execution, {"sql_generation": "sql_generation", "answer_generation": "answer_generation"})
graph.add_edge("result_reranker", "answer_generation")
graph.add_edge("answer_generation", "guardrails")
graph.add_edge("guardrails", END)
graph.add_edge("clarification_node", END)
# Compile the graph with memory
app = graph.compile(checkpointer=memory)
# Interface for user questions with thread_id
def ask_question(question: str, thread_id: str) -> str:
    logger.info(f"Processing question '{question}' with thread_id '{thread_id}'")
    config = {"configurable": {"thread_id": thread_id}}
    initial_state = {
        "question": question,
        "rewritten_query": None,
        "intent": None,
        "confluence_docs": [],
        "databricks_metadata": [],
        "sql_query": None,
        "sql_result": None,
        "answer": None,
        "clarification_needed": False,
        "final_answer": None,
        "attempt_count": 0
    }
    result = app.invoke(initial_state, config=config)
    logger.info(f"Returning answer: {result['final_answer']}")
    return result['final_answer']
# Example usage
if __name__ == "__main__":
    # New conversation
    answer1 = ask_question("What is the total sales in the last quarter?", "thread_1")
    print("Answer 1:", answer1)
    # Follow-up in same thread
    answer2 = ask_question("Break it down by region", "thread_1")
    print("Answer 2:", answer2)
    # New conversation with different thread_id
    answer3 = ask_question("How do I configure Databricks?", "thread_2")
    print("Answer 3:", answer3)
import os
import json
import boto3
from typing import TypedDict, Annotated, List
from langchain_aws import ChatBedrock
from langchain_community.vectorstores import PGVector
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.base import BaseCheckpointSaver
import operator
import logging
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
import sqlparse
# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# --- DynamoDB State Persistence ---
class DynamoDBSaver(BaseCheckpointSaver):
    def __init__(self, table_name: str = "conversation_state"):
        self.dynamodb = boto3.resource('dynamodb')
        self.table = self.dynamodb.Table(table_name)
    def get(self, config: dict) -> dict:
        thread_id = config["configurable"]["thread_id"]
        response = self.table.get_item(Key={"thread_id": thread_id})
        return json.loads(response["Item"]["state"]) if "Item" in response else {}
    def put(self, config: dict, checkpoint: dict) -> None:
        thread_id = config["configurable"]["thread_id"]
        self.table.put_item(Item={"thread_id": thread_id, "state": json.dumps(checkpoint)})
# --- Pydantic Models for Output Parsing ---
class IntentOutput(BaseModel):
    intent: str = Field(description="The classified intent: 'confluence', 'databricks', 'both', or 'ambiguous'")
class SQLOutput(BaseModel):
    sql: str = Field(description="The generated or refined SQL query")
class RewrittenQueryOutput(BaseModel):
    rewritten_query: str = Field(description="The rewritten query for clarity and specificity")
class SafetyOutput(BaseModel):
    is_safe: str = Field(description="Whether the answer is safe and appropriate: 'yes' or 'no'")
class AnswerOutput(BaseModel):
    answer: str = Field(description="The final answer to the user's question")
class ValidationOutput(BaseModel):
    is_valid: str = Field(description="Whether the SQL results make sense for the question: 'yes' or 'no'")
# --- Agent State Definition ---
class AgentState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    intent: str
    confluence_context: List[str]
    databricks_context: List[str]
    generated_sql: str
    sql_attempts: int
    sql_error: str
    needs_clarification: bool
    final_answer: str
    rewritten_query: str
# --- Configurations ---
os.environ["AWS_REGION"] = "us-east-1"
CONFLUENCE_CONNECTION_STRING = os.getenv("CONFLUENCE_DB_URL")
DATABRICKS_CONNECTION_STRING = os.getenv("DATABRICKS_DB_URL")
# Initialize Bedrock model
llm = ChatBedrock(
    model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
    region_name="us-east-1",
    model_kwargs={"temperature": 0.7}
)
# Initialize vector stores
from langchain_openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
confluence_store = PGVector(
    collection_name="confluence_docs",
    connection_string=CONFLUENCE_CONNECTION_STRING,
    embedding_function=embeddings
)
databricks_store = PGVector(
    collection_name="databricks_metadata",
    connection_string=DATABRICKS_CONNECTION_STRING,
    embedding_function=embeddings
)
# --- Tools and Utilities ---
@tool
def generate_sql(user_query: str = None, previous_sql: str = None, error: str = None, metadata: List[str] = None) -> str:
    """Generate or refine a SQL query based on input."""
    logger.debug("Generating SQL with user_query=%s, previous_sql=%s, error=%s, metadata=%s", user_query, previous_sql, error, metadata)
    parser = PydanticOutputParser(pydantic_object=SQLOutput)
    format_instructions = parser.get_format_instructions()
    if previous_sql and error:
        prompt = PromptTemplate(
            input_variables=["previous_sql", "error", "metadata"],
            partial_variables={"format_instructions": format_instructions},
            template="Refine this SQL query '{previous_sql}' that caused error '{error}' using metadata: {metadata}\n\n{format_instructions}"
        )
        chain = prompt | llm | parser
        response = chain.invoke({"previous_sql": previous_sql, "error": error, "metadata": metadata})
    elif user_query:
        prompt = PromptTemplate(
            input_variables=["query", "metadata"],
            partial_variables={"format_instructions": format_instructions},
            template="Generate a SQL query for '{query}' using metadata: {metadata}\n\n{format_instructions}"
        )
        chain = prompt | llm | parser
        response = chain.invoke({"query": user_query, "metadata": metadata})
    else:
        raise ValueError("Must provide either user_query or previous_sql and error")
    sql = response.sql
    logger.debug("Generated SQL: %s", sql)
    return sql
def execute_databricks_sql(sql: str) -> dict:
    logger.debug("Executing SQL: %s", sql)
    if "error" in sql.lower():  # Simulated error condition
        return {"status": "error", "error": "Simulated SQL error"}
    return {"status": "success", "results": ["Simulated results"]}
def rewrite_query(query: str) -> str:
    logger.debug("Rewriting query: %s", query)
    parser = PydanticOutputParser(pydantic_object=RewrittenQueryOutput)
    format_instructions = parser.get_format_instructions()
    prompt = PromptTemplate(
        input_variables=["query"],
        partial_variables={"format_instructions": format_instructions},
        template="Rewrite this query to make it clearer and more specific: '{query}'\n\n{format_instructions}"
    )
    chain = prompt | llm | parser
    response = chain.invoke({"query": query})
    rewritten = response.rewritten_query
    logger.debug("Rewritten query: %s", rewritten)
    return rewritten
# --- Guardrails ---
def check_answer_safety(answer: str) -> bool:
    logger.debug("Checking answer safety: %s", answer)
    parser = PydanticOutputParser(pydantic_object=SafetyOutput)
    format_instructions = parser.get_format_instructions()
    prompt = PromptTemplate(
        input_variables=["answer"],
        partial_variables={"format_instructions": format_instructions},
        template="Is this answer safe and appropriate? Answer 'yes' or 'no': {answer}\n\n{format_instructions}"
    )
    chain = prompt | llm | parser
    response = chain.invoke({"answer": answer})
    is_safe = response.is_safe.lower() == "yes"
    logger.debug("Answer safety: %s", is_safe)
    return is_safe
def is_safe_sql(sql: str) -> bool:
    logger.debug("Checking SQL safety: %s", sql)
    sql_lower = sql.lower().strip()
    is_safe = sql_lower.startswith("select") and all(
        keyword not in sql_lower for keyword in ["insert", "update", "delete", "drop"]
    )
    logger.debug("SQL is safe: %s", is_safe)
    return is_safe
def is_valid_sql(sql: str) -> bool:
    try:
        parsed = sqlparse.parse(sql)
        return len(parsed) > 0 and all(stmt.get_type() != 'UNKNOWN' for stmt in parsed)
    except Exception:
        return False
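# Illustrative behavior of the sqlparse check above:
# is_valid_sql("SELECT 1")         -> True  (statement type is SELECT)
# is_valid_sql("hello databricks") -> False (sqlparse types it as UNKNOWN)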
def validate_results(question: str, results: List[str]) -> bool:
    logger.debug("Validating results for question: %s, results: %s", question, results)
    results_str = "\n".join(results)
    parser = PydanticOutputParser(pydantic_object=ValidationOutput)
    format_instructions = parser.get_format_instructions()
    prompt = PromptTemplate(
        input_variables=["question", "results"],
        partial_variables={"format_instructions": format_instructions},
        template="""Does the following SQL result answer the question?
Question: {question}
SQL Result:
{results}
Respond with 'yes' or 'no'.\n\n{format_instructions}"""
    )
    chain = prompt | llm | parser
    response = chain.invoke({"question": question, "results": results_str})
    is_valid = response.is_valid.lower() == "yes"
    logger.debug("Results validation: %s", is_valid)
    return is_valid
# --- Workflow Nodes ---
def parse_question(state: AgentState) -> AgentState:
    logger.info("Entering parse_question")
    question = state["messages"][-1].content
    logger.debug("Question: %s", question)
    parser = PydanticOutputParser(pydantic_object=IntentOutput)
    format_instructions = parser.get_format_instructions()
    prompt = PromptTemplate(
        input_variables=["question"],
        partial_variables={"format_instructions": format_instructions},
        template="""Classify this question into one of: 'confluence', 'databricks', 'both', or 'ambiguous'.
Question: {question}
{format_instructions}"""
    )
    chain = prompt | llm | parser
    response = chain.invoke({"question": question})
    state["intent"] = response.intent
    logger.debug("Intent: %s", state["intent"])
    logger.info("Exiting parse_question")
    return state
def rewrite_query_node(state: AgentState) -> AgentState:
    logger.info("Entering rewrite_query_node")
    question = state["messages"][-1].content
    state["rewritten_query"] = rewrite_query(question)
    logger.debug("Rewritten query: %s", state["rewritten_query"])
    logger.info("Exiting rewrite_query_node")
    return state
def route_context(state: AgentState) -> AgentState:
    logger.info("Entering route_context")
    query = state["rewritten_query"]
    if state["intent"] in ["confluence", "both"]:
        results = confluence_store.similarity_search(query, k=3)
        state["confluence_context"] = [doc.page_content for doc in results]
        logger.debug("Confluence context: %s", state["confluence_context"])
    if state["intent"] in ["databricks", "both"]:
        results = databricks_store.similarity_search(query, k=3)
        state["databricks_context"] = [doc.page_content for doc in results]
        logger.debug("Databricks context: %s", state["databricks_context"])
    if state["intent"] == "ambiguous":
        state["needs_clarification"] = True
    logger.info("Exiting route_context")
    return state
def clarify_question(state: AgentState) -> AgentState:
    logger.info("Entering clarify_question")
    if state["needs_clarification"]:
        state["messages"].append(AIMessage(content="Please clarify your query."))
        state["needs_clarification"] = False
    logger.info("Exiting clarify_question")
    return state
def generate_sql_node(state: AgentState) -> AgentState:
    logger.info("Entering generate_sql_node")
    if state["intent"] in ["databricks", "both"] and state["databricks_context"]:
        metadata = state["databricks_context"]
        # generate_sql is a @tool, so invoke it with a dict of arguments
        state["generated_sql"] = generate_sql.invoke({"user_query": state["rewritten_query"], "metadata": metadata})
        state["sql_attempts"] = 1
    logger.info("Exiting generate_sql_node")
    return state
def execute_sql_node(state: AgentState) -> AgentState:
    logger.info("Entering execute_sql_node")
    if state.get("generated_sql"):
        if not is_safe_sql(state["generated_sql"]):
            state["databricks_context"] = ["Unsafe SQL query detected."]
            state["sql_error"] = None
        elif not is_valid_sql(state["generated_sql"]):
            state["sql_error"] = "Invalid SQL syntax"
        else:
            result = execute_databricks_sql(state["generated_sql"])
            if result["status"] == "success":
                if validate_results(state["rewritten_query"], result["results"]):
                    state["databricks_context"] = result["results"]
                    state["sql_error"] = None
                else:
                    state["sql_error"] = "Results do not make sense for the question"
            else:
                state["sql_error"] = result["error"]
        if state["sql_error"] and state.get("sql_attempts", 0) < 3:
            state["generated_sql"] = generate_sql.invoke({
                "previous_sql": state["generated_sql"],
                "error": state["sql_error"],
                "metadata": state["databricks_context"]
            })
            state["sql_attempts"] = state.get("sql_attempts", 0) + 1
            state = execute_sql_node(state)  # Recursive retry, bounded by sql_attempts
        elif state["sql_error"]:
            state["databricks_context"] = ["Unable to retrieve data due to persistent errors."]
            state["sql_error"] = None
    logger.info("Exiting execute_sql_node")
    return state
def generate_answer(state: AgentState) -> AgentState:
    logger.info("Entering generate_answer")
    history = "\n".join([f"{msg.type}: {msg.content}" for msg in state["messages"][-5:]])
    confluence_ctx = "\n".join(state.get("confluence_context", []))
    databricks_ctx = "\n".join(state.get("databricks_context", []))
    parser = PydanticOutputParser(pydantic_object=AnswerOutput)
    format_instructions = parser.get_format_instructions()
    prompt = PromptTemplate(
        input_variables=["history", "query", "confluence_ctx", "databricks_ctx"],
        partial_variables={"format_instructions": format_instructions},
        template="""Conversation history:
{history}
Current question: {query}
Using Confluence context: {confluence_ctx}
And Databricks context: {databricks_ctx}
Provide a concise, accurate response.\n\n{format_instructions}"""
    )
    chain = prompt | llm | parser
    response = chain.invoke({
        "history": history,
        "query": state["rewritten_query"],
        "confluence_ctx": confluence_ctx,
        "databricks_ctx": databricks_ctx
    })
    answer = response.answer
    if check_answer_safety(answer):
        state["final_answer"] = answer
    else:
        state["final_answer"] = "I'm sorry, but I can't provide that information."
    state["messages"].append(AIMessage(content=state["final_answer"]))
    logger.debug("Final answer: %s", state["final_answer"])
    logger.info("Exiting generate_answer")
    return state
# --- Workflow Definition ---
workflow = StateGraph(AgentState)
workflow.add_node("parse_question", parse_question)
workflow.add_node("rewrite_query", rewrite_query_node)
workflow.add_node("route_context", route_context)
workflow.add_node("clarify_question", clarify_question)
workflow.add_node("generate_sql", generate_sql_node)
workflow.add_node("execute_sql", execute_sql_node)
workflow.add_node("generate_answer", generate_answer)
workflow.set_entry_point("parse_question")
workflow.add_edge("parse_question", "rewrite_query")
workflow.add_edge("rewrite_query", "route_context")
workflow.add_edge("route_context", "clarify_question")
workflow.add_conditional_edges(
    "clarify_question",
    lambda state: "generate_sql" if state["intent"] in ["databricks", "both"] else "generate_answer",
    {"generate_sql": "generate_sql", "generate_answer": "generate_answer"}
)
workflow.add_edge("generate_sql", "execute_sql")
workflow.add_edge("execute_sql", "generate_answer")
workflow.add_edge("generate_answer", END)
# Compile the workflow with state persistence
checkpointer = DynamoDBSaver(table_name="conversation_state")
graph = workflow.compile(checkpointer=checkpointer)
# --- Main Function to Run the Agent ---
def run_agent(question: str, thread_id: str = "thread_1") -> str:
    logger.info("Running agent for question: %s, thread_id: %s", question, thread_id)
    initial_state = {"messages": [HumanMessage(content=question)]}
    result = graph.invoke(initial_state, config={"configurable": {"thread_id": thread_id}})
    response = result["messages"][-1].content
    logger.info("Agent response: %s", response)
    return response
# --- Example Usage ---
if __name__ == "__main__":
    response = run_agent("What is the process for onboarding in Confluence?")
    print(response)
import logging
import json
from typing import TypedDict, List, Optional, Annotated
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_aws import ChatBedrock
from langchain_community.vectorstores import PGVector
from langchain_community.embeddings import BedrockEmbeddings
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from databricks import sql
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from pydantic import BaseModel, Field
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize Bedrock Sonnet model
llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    region_name="us-east-1"  # Replace with your AWS region
)
# Initialize embedding function
embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")
# Initialize PGVector stores (replace connection strings)
confluence_store = PGVector(
    collection_name="confluence",
    connection_string="postgresql+psycopg2://user:password@localhost:5432/dbname",
    embedding_function=embeddings
)
databricks_store = PGVector(
    collection_name="databricks_system_metadata",
    connection_string="postgresql+psycopg2://user:password@localhost:5432/dbname",
    embedding_function=embeddings
)
# Databricks SQL connection (replace with your credentials)
databricks_connection = sql.connect(
    server_hostname="your_databricks_host",
    http_path="your_http_path",
    access_token="your_access_token"
)
# Define State schema with defaults
def _default_messages() -> List[BaseMessage]:
    return []
def _append_message(current: List[BaseMessage], new: BaseMessage) -> List[BaseMessage]:
    return current + [new]
class State(TypedDict):
    # Note: LangGraph reads the reducer from the Annotated metadata; the trailing
    # "default" callables are informational only, since TypedDicts cannot carry
    # runtime defaults (the initial state passed at invoke time supplies them).
    messages: Annotated[List[BaseMessage], _append_message, _default_messages]
    question: str
    rewritten_query: Optional[str]
    intent: Optional[str]
    confluence_docs: Annotated[List[str], lambda x, y: x + [y], lambda: []]
    databricks_metadata: Annotated[List[str], lambda x, y: x + [y], lambda: []]
    sql_query: Optional[str]
    sql_result: Optional[str]
    answer: Optional[str]
    clarification_needed: Annotated[bool, lambda x, y: y, lambda: False]
    final_answer: Optional[str]
    attempt_count: Annotated[int, lambda x, y: y, lambda: 0]
    sql_valid: Annotated[bool, lambda x, y: y, lambda: False]
    history_answer: Optional[str]
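# A conventional alternative (a sketch, not the original gist's code): one
# reducer per Annotated field, with defaults supplied in the initial state.
# import operator
# class StandardState(TypedDict):
#     messages: Annotated[List[BaseMessage], operator.add]  # nodes return {"messages": [msg]}
#     confluence_docs: Annotated[List[str], operator.add]
#     attempt_count: int  # plain fields are simply overwritten by node updates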
# Define Pydantic Models for Structured Output
class IntentResponse(BaseModel):
    intent: str = Field(description="The classified intent: 'confluence' or 'databricks'")
    clarification_needed: bool = Field(description="Boolean indicating if clarification is needed")
class ClarificationResponse(BaseModel):
    clarification_question: str = Field(description="A follow-up question to clarify intent")
# Initialize Pydantic Output Parsers
question_analyzer_parser = PydanticOutputParser(pydantic_object=IntentResponse)
clarification_parser = PydanticOutputParser(pydantic_object=ClarificationResponse)
# Define Nodes
def start_conversation(state: State) -> State:
logger.info("Entering start_conversation stage")
human_message = HumanMessage(content=state['question'])
state['messages'] = state['messages'] + [human_message] # Append to default empty list
logger.debug(f"Appended human message: {human_message.content}")
return state
def check_history(state: State) -> State:
logger.info("Entering check_history stage")
if len(state['messages']) > 1: # Check prior history
history = "\n".join([f"{msg.type}: {msg.content}" for msg in state['messages'][:-1]])
prompt = f"""
You are a history analysis expert. Given this conversation history:
{history}
Can the question "{state['question']}" be answered based on this history? If yes, provide the answer as plain text. If no, return "no".
"""
response = llm.invoke([HumanMessage(content=prompt)])
answer = response.content.strip()
        if answer.lower().rstrip('.') != "no":
state['history_answer'] = answer
state['final_answer'] = answer
logger.debug(f"Answer found in history: {state['history_answer']}")
else:
state['history_answer'] = None
logger.debug("No answer found in history")
return state
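# check_history short-circuits the graph: when the LLM can answer from the
# conversation so far, route_after_history (defined below) jumps straight
# to end_conversation and skips retrieval and SQL generation entirely.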
def question_analyzer(state: State) -> State:
logger.info("Entering question_analyzer stage")
if state.get('history_answer'):
return state
prompt_template = PromptTemplate(
input_variables=["question"],
template="""
You are an expert intent classifier. Given this question: "{question}", determine its intent. Choose one:
- "confluence" for general knowledge or documentation queries
- "databricks" for questions requiring Databricks system table data
If the question is unclear, set clarification_needed to true; otherwise, false.
Return your response in this JSON format:
{format_instructions}
""",
partial_variables={"format_instructions": question_analyzer_parser.get_format_instructions()}
)
prompt = prompt_template.format(question=state['question'])
response = llm.invoke([HumanMessage(content=prompt)])
result = question_analyzer_parser.parse(response.content)
state['intent'] = result.intent
state['clarification_needed'] = result.clarification_needed
logger.debug(f"Intent: {state['intent']}, Clarification needed: {state['clarification_needed']}")
return state
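# If the model returns malformed JSON, parser.parse() raises an
# OutputParserException; a production version might wrap this call in a
# retry (e.g. langchain's OutputFixingParser) rather than failing the run.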
def query_rewriter(state: State) -> State:
logger.info("Entering query_rewriter stage")
if state.get('history_answer'):
return state
prompt = f"""
You are a query optimization expert. Rewrite this question to make it more precise for retrieval: "{state['question']}".
Provide the rewritten query as plain text.
"""
response = llm.invoke([HumanMessage(content=prompt)])
state['rewritten_query'] = response.content.strip()
logger.debug(f"Rewritten query: {state['rewritten_query']}")
return state
def confluence_retrieval(state: State) -> State:
logger.info("Entering confluence_retrieval stage")
query = state.get('rewritten_query', state['question'])
docs = confluence_store.similarity_search(query, k=5)
    state['confluence_docs'] = docs  # list of Document objects from similarity_search
logger.debug(f"Retrieved {len(state['confluence_docs'])} Confluence docs")
return state
def databricks_retrieval(state: State) -> State:
logger.info("Entering databricks_retrieval stage")
query = state.get('rewritten_query', state['question'])
docs = databricks_store.similarity_search(query, k=5)
state['databricks_metadata'] = [doc.page_content for doc in docs]
logger.debug(f"Retrieved {len(state['databricks_metadata'])} Databricks metadata entries")
return state
def result_reranker(state: State) -> State:
logger.info("Entering result_reranker stage")
if state['intent'] == "confluence" and state['confluence_docs']:
docs_str = "\n".join([f"Doc {i}: {doc.page_content}" for i, doc in enumerate(state['confluence_docs'])])
prompt = f"""
You are a document relevance analyst. Given the question "{state['question']}" and these documents:
{docs_str}
Return the indices (0-based) of the top 3 most relevant documents as a JSON list, e.g., [0, 1, 2].
"""
response = llm.invoke([HumanMessage(content=prompt)])
        try:
            indices = json.loads(response.content.strip())
        except (json.JSONDecodeError, TypeError):
            indices = list(range(min(3, len(state['confluence_docs']))))  # fall back to vector-search order
        state['confluence_docs'] = [state['confluence_docs'][i] for i in indices if isinstance(i, int) and 0 <= i < len(state['confluence_docs'])]
logger.debug(f"Reranked Confluence docs: {[doc.page_content for doc in state['confluence_docs']]}")
elif state['intent'] == "databricks" and state['databricks_metadata']:
docs_str = "\n".join([f"Doc {i}: {doc}" for i, doc in enumerate(state['databricks_metadata'])])
prompt = f"""
You are a metadata relevance analyst. Given the question "{state['question']}" and these Databricks system table metadata entries:
{docs_str}
Return the indices (0-based) of the top 3 most relevant entries as a JSON list, e.g., [0, 1, 2].
"""
response = llm.invoke([HumanMessage(content=prompt)])
        try:
            indices = json.loads(response.content.strip())
        except (json.JSONDecodeError, TypeError):
            indices = list(range(min(3, len(state['databricks_metadata']))))  # fall back to vector-search order
        state['databricks_metadata'] = [state['databricks_metadata'][i] for i in indices if isinstance(i, int) and 0 <= i < len(state['databricks_metadata'])]
logger.debug(f"Reranked Databricks metadata: {state['databricks_metadata']}")
return state
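# LLM-based reranking as done here trades extra latency for relevance. A
# dedicated cross-encoder reranker would be a common lower-latency
# alternative, but the LLM approach needs no additional model.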
def sql_generation(state: State) -> State:
logger.info("Entering sql_generation stage")
if not state['databricks_metadata']:
state = databricks_retrieval(state)
state = result_reranker(state)
metadata = "\n".join(state['databricks_metadata'])
prompt = f"""
You are an expert SQL generator for Databricks. Using this system table metadata:
{metadata}
Write an SQL query to fetch data from Databricks that answers the question: "{state['question']}".
    Return only the SQL statement as plain text, with no markdown fences or explanation.
"""
response = llm.invoke([HumanMessage(content=prompt)])
state['sql_query'] = response.content.strip()
logger.debug(f"Generated SQL query: {state['sql_query']}")
return state
def sql_validation(state: State) -> State:
logger.info("Entering sql_validation stage")
metadata = "\n".join(state['databricks_metadata'])
prompt = f"""
You are an SQL validation expert. Given this Databricks system table metadata:
{metadata}
Check if this SQL query is valid and likely to answer the question "{state['question']}":
{state['sql_query']}
Return "valid" if the query is correct, or "invalid" if it’s not.
"""
response = llm.invoke([HumanMessage(content=prompt)])
    state['sql_valid'] = response.content.strip().lower().startswith("valid")
logger.debug(f"SQL validation result: {state['sql_valid']}")
return state
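# This validation is a heuristic LLM check, not a real parse. A stricter
# (hypothetical) variant could dry-run the statement first, e.g.
# cursor.execute(f"EXPLAIN {state['sql_query']}"), and treat any error as
# an invalid query.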
def sql_execution(state: State) -> State:
    logger.info("Entering sql_execution stage")
    max_attempts = 3
    attempt = state['attempt_count']
    if attempt >= max_attempts:
        state['final_answer'] = "Failed to generate a valid SQL query or retrieve meaningful results from Databricks after 3 attempts."
        logger.info("SQL execution abandoned after max retries")
        return state
    if not state['sql_valid']:
        logger.info(f"SQL invalid, retrying generation (attempt {attempt + 1})")
        state['attempt_count'] += 1
        return sql_generation(state)  # route_after_execution sends this back to sql_validation
    try:
        with databricks_connection.cursor() as cursor:
            logger.info(f"Executing SQL in Databricks (Attempt {attempt + 1}): {state['sql_query']}")
            cursor.execute(state['sql_query'])
            result = cursor.fetchall()
            state['sql_result'] = str(result)
            logger.debug(f"SQL result: {state['sql_result']}")
            return state
    except Exception as e:
        logger.error(f"SQL execution failed: {e}")
        state['attempt_count'] += 1
        state['sql_valid'] = False  # the regenerated query must be re-validated
        logger.info(f"Retrying SQL generation due to execution error (attempt {attempt + 1})")
        return sql_generation(state)  # route_after_execution sends this back to sql_validation
def sql_refinement(state: State) -> State:
logger.info("Entering sql_refinement stage")
metadata = "\n".join(state['databricks_metadata'])
prompt = f"""
You are an SQL refinement expert. The previous SQL query:
{state['sql_query']}
Produced this result:
{state['sql_result']}
For the question "{state['question']}", the result didn’t fully satisfy the query intent. Using this Databricks system table metadata:
{metadata}
    Refine the SQL query to better answer the question. Return only the improved SQL statement as plain text, with no markdown fences or explanation.
"""
response = llm.invoke([HumanMessage(content=prompt)])
state['sql_query'] = response.content.strip()
    state['sql_valid'] = False  # force re-validation of the refined query
    state['sql_result'] = None  # cleared so route_after_execution re-runs the query
logger.debug(f"Refined SQL query: {state['sql_query']}")
return state
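# sql_refinement feeds back into sql_validation (see the edges below), so a
# refined query is re-validated and re-executed until it succeeds or
# attempt_count reaches the 3-attempt budget enforced in sql_execution.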
def answer_generation(state: State) -> State:
logger.info("Entering answer_generation stage")
context = ""
if state['intent'] == "confluence" and state['confluence_docs']:
context += "Confluence Documents:\n" + "\n".join([doc.page_content for doc in state['confluence_docs']]) + "\n"
elif state['intent'] == "databricks":
if state.get('databricks_metadata'):
context += "Databricks System Table Metadata:\n" + "\n".join(state['databricks_metadata']) + "\n"
if state.get('sql_query'):
context += "Generated SQL Query:\n" + state['sql_query'] + "\n"
if state.get('sql_result'):
context += "SQL Result:\n" + state['sql_result'] + "\n"
history = "\n".join([f"{msg.type}: {msg.content}" for msg in state['messages']]) if state['messages'] else "No prior history"
if not context.strip():
state['answer'] = "I'm sorry, I couldn't find sufficient information to answer your question."
else:
prompt = f"""
You are an expert answer generator. Using this context:
{context}
And this conversation history:
{history}
Provide a clear and concise answer to the question: "{state['question']}".
If an SQL result is included, explain it in text and, if suitable, format it as a table. Ensure the answer is consistent with or builds on the history where relevant.
"""
response = llm.invoke([HumanMessage(content=prompt)])
state['answer'] = response.content.strip()
if state['intent'] == "databricks" and state.get('sql_result'):
prompt = f"""
You are a result validation expert. Given the question "{state['question']}" and this answer:
{state['answer']}
Does the answer fully satisfy the question based on the SQL result? Return "yes" or "no".
"""
response = llm.invoke([HumanMessage(content=prompt)])
        if response.content.strip().lower().startswith("no") and state['attempt_count'] < 3:
            logger.info(f"Answer doesn't satisfy question, refining SQL (attempt {state['attempt_count'] + 1})")
            state['attempt_count'] += 1
            state['final_answer'] = None  # routed to sql_refinement by route_after_answer below
            return state
state['final_answer'] = state['answer']
logger.debug(f"Generated answer: {state['answer']}")
return state
def guardrails(state: State) -> State:
logger.info("Entering guardrails stage")
answer = state['answer'].lower()
prohibited_words = ["harmful", "offensive", "confidential"]
if any(word in answer for word in prohibited_words):
state['final_answer'] = "I'm sorry, I cannot provide that information due to content restrictions."
elif len(state['answer']) > 1000:
state['final_answer'] = "The response is too long to display fully. Please refine your question."
logger.debug(f"Final answer after guardrails: {state['final_answer']}")
return state
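# The keyword list above is a minimal illustration; a real deployment would
# typically call a dedicated moderation layer (for example Amazon Bedrock
# Guardrails) instead of substring matching on the answer text.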
def clarification_node(state: State) -> State:
logger.info("Entering clarification_node stage")
prompt_template = PromptTemplate(
input_variables=["question"],
template="""
You are a clarification expert. The question "{question}" is unclear.
Ask a concise follow-up question to clarify the user's intent.
Return the follow-up question in this JSON format:
{format_instructions}
""",
partial_variables={"format_instructions": clarification_parser.get_format_instructions()}
)
prompt = prompt_template.format(question=state['question'])
response = llm.invoke([HumanMessage(content=prompt)])
result = clarification_parser.parse(response.content)
state['final_answer'] = result.clarification_question
logger.debug(f"Clarification question: {state['final_answer']}")
return state
def end_conversation(state: State) -> State:
logger.info("Entering end_conversation stage")
ai_message = AIMessage(content=state['final_answer'])
state['messages'] = state['messages'] + [ai_message]
logger.debug(f"Appended AI message: {ai_message.content}")
return state
# Define Graph
graph = StateGraph(State)
memory = MemorySaver()
# Add nodes
graph.add_node("start_conversation", start_conversation)
graph.add_node("check_history", check_history)
graph.add_node("question_analyzer", question_analyzer)
graph.add_node("query_rewriter", query_rewriter)
graph.add_node("confluence_retrieval", confluence_retrieval)
graph.add_node("databricks_retrieval", databricks_retrieval)
graph.add_node("result_reranker", result_reranker)
graph.add_node("sql_generation", sql_generation)
graph.add_node("sql_validation", sql_validation)
graph.add_node("sql_execution", sql_execution)
graph.add_node("sql_refinement", sql_refinement)
graph.add_node("answer_generation", answer_generation)
graph.add_node("guardrails", guardrails)
graph.add_node("clarification_node", clarification_node)
graph.add_node("end_conversation", end_conversation)
# Define routing functions
def route_after_history(state: State) -> str:
    return "end_conversation" if state.get('history_answer') else "question_analyzer"
def route_after_analyzer(state: State) -> str:
    return "clarification_node" if state['clarification_needed'] else "query_rewriter"
def route_after_rewriter(state: State) -> str:
    intent = state['intent']
    if intent == "confluence":
        return "confluence_retrieval"
    elif intent == "databricks":
        return "databricks_retrieval"
    return "clarification_node"
def route_after_reranker(state: State) -> str:
    # Only Databricks questions need SQL; Confluence goes straight to answering
    return "sql_generation" if state['intent'] == "databricks" else "answer_generation"
def route_after_execution(state: State) -> str:
    if state.get('final_answer'):  # gave up after max retries
        return "end_conversation"
    if state.get('sql_result') is None:  # query was regenerated; re-validate it
        return "sql_validation"
    return "answer_generation"
def route_after_answer(state: State) -> str:
    # final_answer is cleared when the SQL answer needs another refinement pass
    return "sql_refinement" if state.get('final_answer') is None else "guardrails"
# Define edges with START and END
graph.add_edge(START, "start_conversation")
graph.add_edge("start_conversation", "check_history")
graph.add_conditional_edges("check_history", route_after_history, {
"end_conversation": "end_conversation",
"question_analyzer": "question_analyzer"
})
graph.add_conditional_edges("question_analyzer", route_after_analyzer, {
"clarification_node": "clarification_node",
"query_rewriter": "query_rewriter"
})
graph.add_conditional_edges("query_rewriter", route_after_rewriter, {
"confluence_retrieval": "confluence_retrieval",
"databricks_retrieval": "databricks_retrieval",
"clarification_node": "clarification_node"
})
graph.add_edge("confluence_retrieval", "result_reranker")
graph.add_edge("databricks_retrieval", "result_reranker")
graph.add_edge("result_reranker", "sql_generation")
graph.add_edge("sql_generation", "sql_validation")
graph.add_edge("sql_validation", "sql_execution")
graph.add_edge("sql_execution", "answer_generation")
graph.add_edge("result_reranker", "answer_generation")
graph.add_edge("answer_generation", "sql_refinement")
graph.add_edge("sql_refinement", "sql_validation")
graph.add_edge("answer_generation", "guardrails")
graph.add_edge("guardrails", "end_conversation")
graph.add_edge("clarification_node", "end_conversation")
graph.add_edge("end_conversation", END)
# Compile the graph with memory
app = graph.compile(checkpointer=memory)
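# Optional: print the wiring for debugging (requires the grandalf package):
# app.get_graph().print_ascii()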
# Interface for user questions with thread_id
def ask_question(question: str, thread_id: str) -> str:
    logger.info(f"Processing question '{question}' with thread_id '{thread_id}'")
    config = {"configurable": {"thread_id": thread_id}}
    # Reset per-question keys so values from a previous turn cannot leak in;
    # 'messages' persists via the checkpointer and is only seeded for new threads.
    initial_state = {
        "question": question, "attempt_count": 0, "clarification_needed": False,
        "sql_valid": False, "rewritten_query": None, "intent": None,
        "sql_query": None, "sql_result": None, "answer": None,
        "final_answer": None, "history_answer": None,
        "confluence_docs": [], "databricks_metadata": [],
    }
    if not app.get_state(config).values:
        initial_state["messages"] = []
    result = app.invoke(initial_state, config=config)
    logger.info(f"Returning answer: {result['final_answer']}")
    logger.debug(f"Conversation history: {[msg.content for msg in result['messages']]}")
    return result['final_answer']
# Example usage
if __name__ == "__main__":
# New conversation (Databricks example)
answer1 = ask_question("What is the total sales in the last quarter?", "thread_1")
print("Answer 1:", answer1)
# Follow-up in same thread
answer2 = ask_question("Break it down by region", "thread_1")
print("Answer 2:", answer2)
# New conversation (Confluence example)
answer3 = ask_question("How do I configure Databricks?", "thread_2")
print("Answer 3:", answer3)