Skip to content

Instantly share code, notes, and snippets.

@tbensonwest
Forked from rvndbalaji/langgraph_builder.py
Created September 29, 2024 15:18
Show Gist options
  • Save tbensonwest/e57459156cf126d9bba53225398258c1 to your computer and use it in GitHub Desktop.
Save tbensonwest/e57459156cf126d9bba53225398258c1 to your computer and use it in GitHub Desktop.
Utility class to build dynamic LangGraph agents
import base64
import functools
import json
import operator
from enum import Enum
from typing import *
from typing import Annotated, Any, Dict, List, Sequence

from dto.chat import *
from dto.graph import *
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools import BaseTool
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, START, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from models import get_text_model
from tools._index import get_tools
from typing_extensions import TypedDict
from utils.common import *
class GraphState(TypedDict):
    """State dictionary used as the schema of the ``StateGraph`` built below."""

    # Message history. The ``operator.add`` annotation is the reducer attached
    # to this key, so values written by nodes are concatenated rather than
    # replacing the previous value.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the agent node that produced the latest update; the tool node's
    # outgoing edge reads it to route back to that agent.
    sender: str
class Agent:
    """A named agent: a system prompt piped into a chat model.

    Attributes:
        name: Identifier of the agent (used as its node name by the builder).
        prompt: The system prompt text exactly as passed in.
        agent: The runnable ``prompt_template | model`` that is invoked with
            the graph state.
        tools: The tools available to this agent (empty list when none).
    """

    def __init__(self, name: str, llm, prompt: str, tools: List[BaseTool]):
        """Build the prompt/model chain for one agent.

        Args:
            name: Agent name.
            llm: Chat model object exposing ``bind_tools``.
            prompt: System prompt text.
            tools: Tools the model may call; may be empty.
        """
        self.name = name
        self.prompt = prompt
        # Normalise a missing tool list to an empty one so later iteration
        # over ``agent.tools`` never fails.
        self.tools = list(tools) if tools else []

        # NOTE(review): the prompt is handed to ChatPromptTemplate as a
        # template string, so literal braces in it presumably need escaping -
        # confirm against the prompts stored in the graph config.
        prompt_template = ChatPromptTemplate.from_messages([
            ("system", prompt),
            MessagesPlaceholder(variable_name="messages"),
        ])

        # Only bind tools when there are some: sending an empty tool list is
        # rejected by some chat-model providers, and a tool-less agent simply
        # needs the plain model.
        model = llm.bind_tools(self.tools) if self.tools else llm
        self.agent = prompt_template | model
class GraphBuilder:
    """Builds, compiles and runs a multi-agent LangGraph from a ``GraphConfig``.

    Typical use: construct with a config, call :meth:`build`, then use
    :meth:`run` / :meth:`get_stream` / :meth:`get_image`.
    """

    def __init__(self, graph_config: "GraphConfig"):
        """Store the configuration; nothing is built until :meth:`build`."""
        self.graph_config = graph_config
        self.agents: List[Agent] = []
        self.graph: CompiledStateGraph | None = None

    @staticmethod
    def _content_text(content) -> str:
        """Flatten message content into plain text.

        Message content may be a plain string or a list of blocks (strings or
        dicts carrying a ``"text"`` entry). Non-text blocks are ignored.
        """
        if isinstance(content, str):
            return content
        parts = []
        for block in content or []:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                parts.append(block["text"])
        return "".join(parts)

    @staticmethod
    def _router(state) -> Literal["call_tool", "__end__", "continue"]:
        """Decide where to go after an agent node has produced a message.

        Returns:
            ``"call_tool"`` when the last message carries tool calls,
            ``"__end__"`` when its text contains ``FINISHED``,
            otherwise ``"continue"``.
        """
        last_message = state["messages"][-1]
        if getattr(last_message, "tool_calls", None):
            return "call_tool"
        # Content can be a list of blocks, in which case a plain ``in`` test
        # would check list membership instead of searching the text.
        if "FINISHED" in GraphBuilder._content_text(last_message.content):
            return "__end__"
        return "continue"

    @staticmethod
    def _agent_node(state, agent, name):
        """Invoke ``agent`` on ``state`` and tag the reply with the agent name.

        Returns a state update holding the reply (rebuilt as an ``AIMessage``
        named ``name``) and ``sender`` set to ``name``.
        """
        result = agent.invoke(state)
        # Prefer the pydantic-v2 serializer when the message offers it and
        # fall back to the older ``.dict()`` otherwise.
        dump = getattr(result, "model_dump", None) or result.dict
        result = AIMessage(**dump(exclude={"type", "name"}), name=name)
        return {
            "messages": [result],
            "sender": name,
        }

    def _create_agent_node(self, agent: "Agent"):
        """Return a one-argument node callable bound to ``agent``."""
        return functools.partial(self._agent_node, agent=agent.agent, name=agent.name)

    def _add_nodes(self, state_graph):
        """Register one node per agent plus a shared ``call_tool`` node."""
        for agent in self.agents:
            state_graph.add_node(agent.name, self._create_agent_node(agent))
        # A single tool node serves every agent, so it gets all their tools.
        all_tools = [tool for agent in self.agents for tool in agent.tools]
        state_graph.add_node("call_tool", ToolNode(all_tools))

    def _add_edges(self, state_graph):
        """Wire the start edge, per-agent conditional edges and tool return edge.

        Raises:
            KeyError: If the config has no edge whose source is ``'start'``.
        """
        # Maps source -> target; a later edge with the same source replaces an
        # earlier one.
        graph_structure = {edge.source: edge.target for edge in self.graph_config.edges}

        # Add start edge
        try:
            start_node = graph_structure['start']
        except KeyError:
            raise KeyError(
                "graph_config.edges must contain an edge whose source is 'start'"
            ) from None
        state_graph.add_edge(START, start_node)

        for agent in self.agents:
            # Agents without an outgoing edge fall through to the end.
            next_node = graph_structure.get(agent.name, 'end')
            next_node = END if next_node.lower() == 'end' else next_node
            state_graph.add_conditional_edges(
                agent.name,
                self._router,
                {
                    "call_tool": "call_tool",
                    "continue": next_node,
                    "__end__": END,
                },
            )

        # Add edges from call_tool back to the agent that requested the tool
        state_graph.add_conditional_edges(
            "call_tool",
            lambda state: state["sender"],
            {agent.name: agent.name for agent in self.agents},
        )

    def get_stream(self, initial_state: Dict[str, Any], user: "User", recursion_limit: int = 150):
        """Stream the compiled graph from ``initial_state``.

        Args:
            initial_state: Initial graph state.
            user: Passed to ``get_llm_config`` to obtain extra run config.
            recursion_limit: Recursion limit placed in the run config. Keys
                returned by ``get_llm_config(user)`` take precedence on
                collision.

        Raises:
            ValueError: If :meth:`build` has not been called.
        """
        if not self.graph:
            raise ValueError("Graph not created. Call build() first.")
        config = {"recursion_limit": recursion_limit, **get_llm_config(user)}
        return self.graph.stream(initial_state, config=config)

    def get_image(self, xray=True, mermaid=True):
        """Return the graph drawing as a base64-encoded string.

        Args:
            xray: Forwarded to ``get_graph``.
            mermaid: Use ``draw_mermaid_png`` when true, ``draw_png`` otherwise.

        Raises:
            ValueError: If :meth:`build` has not been called.
        """
        if not self.graph:
            raise ValueError("Graph not created. Call build() first.")
        drawable = self.graph.get_graph(xray=xray)
        img_data = drawable.draw_mermaid_png() if mermaid else drawable.draw_png()
        return base64.b64encode(img_data).decode('utf-8')

    def get_compiled_graph(self):
        """Return ``get_graph()`` of the compiled graph.

        Raises:
            ValueError: If :meth:`build` has not been called.
        """
        if not self.graph:
            raise ValueError("Graph not created. Call build() first.")
        return self.graph.get_graph()

    def build(self):
        """Create the agents, assemble the state graph and compile it.

        Safe to call more than once: the agent list is rebuilt from the
        config each time instead of being appended to.

        Returns:
            The compiled graph (also stored on ``self.graph``).
        """
        # Create agents
        agents: List[Agent] = []
        for agent_params in self.graph_config.agents:
            # 'start' and 'end' are reserved edge endpoints, not real agents.
            if agent_params.name in ('start', 'end'):
                continue
            # Build a fresh list so the one returned by get_tools is not mutated.
            tools: List[BaseTool] = [
                *get_tools(tool_names=agent_params.tool_names),
                *(agent_params.tools or []),
            ]
            agents.append(Agent(
                name=agent_params.name,
                llm=get_text_model(agent_params.llm),
                prompt=agent_params.prompt,
                tools=tools,
            ))
        self.agents = agents

        # Build & compile the graph
        state_graph = StateGraph(GraphState)
        self._add_nodes(state_graph)
        self._add_edges(state_graph)
        self.graph = state_graph.compile()
        return self.graph

    @staticmethod
    def _chat_to_json(chat_item) -> str:
        """Serialize a ``Chat`` object to a JSON string (enum values unwrapped)."""
        chat_dict = chat_item.model_dump(exclude_unset=True)
        # Convert any Enum values to their underlying value so json can encode them
        chat_dict = {k: v.value if isinstance(v, Enum) else v for k, v in chat_dict.items()}
        return json.dumps(chat_dict)

    def run(self, user: "User", messages: List["Chat"]):
        """Run the graph on ``messages`` and yield each produced message as JSON.

        This is a generator, so the "not built" check below only fires when
        iteration starts.

        Args:
            user: Forwarded to :meth:`get_stream`.
            messages: Conversation so far; converted with ``to_dict_list``.

        Yields:
            One JSON string per message found in the streamed events. If a
            message cannot be converted, a JSON error chat entry is yielded
            in its place.

        Raises:
            ValueError: If :meth:`build` has not been called.
        """
        if not self.graph:
            raise ValueError("Graph not created. Call build() first.")
        messages_dict = to_dict_list(messages)
        events = self.get_stream({"messages": messages_dict}, user)
        for event in events:
            for value in event.values():
                if not value or 'messages' not in value:
                    continue
                for message in value['messages']:
                    try:
                        raw_tool_calls = getattr(message, 'tool_calls', None)
                        chat_item = Chat(
                            role=map_to_chat_role(message.type),
                            content=message.content,
                            name=getattr(message, 'name', None),
                            tool_calls=[ToolCall(**tool_call) for tool_call in raw_tool_calls] if raw_tool_calls else None,
                            tool_call_id=getattr(message, 'tool_call_id', None),
                        )
                        payload = self._chat_to_json(chat_item)
                    except Exception as e:
                        # Best effort: report the failure in-band and keep streaming.
                        payload = self._chat_to_json(Chat(
                            role=ChatRole.assistant,
                            content=f"Error processing message: {str(e)}",
                            name="Error",
                        ))
                    # Yield outside the try so only conversion errors are caught.
                    yield payload
# Usage example: two agents chained start -> Researcher -> Writer -> end.
body = GraphConfig(
    graph_id=3,
    name="Graph",
    description="Graph",
    session_id=None,
    agents=[
        AgentConfig(
            name="Researcher",
            prompt="You are a research assistant",
            llm="claude-sonnet-3.5",
            tools=[],
            color="#FF871F"
        ),
        AgentConfig(
            name="Writer",
            prompt="You are a great writer",
            llm="claude-sonnet-3.5",
            tools=[],
            color="#FF871F"
        ),
    ],
    edges=[
        GraphEdge(source="start", target="Researcher"),
        GraphEdge(source="Researcher", target="Writer"),
        GraphEdge(source="Writer", target="end")
    ]
)
graph = GraphBuilder(graph_config=body)
graph.build()
response = GraphResponse(
    # GraphBuilder defines get_compiled_graph(); it has no get_compiled_config(),
    # so the previous call raised AttributeError.
    # NOTE(review): confirm GraphResponse.config accepts this return value.
    config=graph.get_compiled_graph(),
    base64=graph.get_image(xray=True, mermaid=True)
)
# Bare expression so a notebook cell displays the response.
response
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment