Last active
April 14, 2025 16:45
-
-
Save rvndbalaji/2f50bf3d2fc635bf2f36f71c724e3347 to your computer and use it in GitHub Desktop.
Utility class to build dynamic LangGraph agents
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64 | |
import functools | |
import operator | |
from typing import * | |
from typing import Annotated, Any, Dict, List, Sequence | |
from dto.chat import * | |
from dto.graph import * | |
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langchain.tools import BaseTool | |
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langgraph.graph import END, START, StateGraph | |
from langgraph.graph.state import CompiledStateGraph | |
from langgraph.prebuilt import ToolNode | |
from models import get_text_model | |
from tools._index import get_tools | |
from typing_extensions import TypedDict | |
from utils.common import * | |
class GraphState(TypedDict):
    """Shared state passed between nodes of the compiled graph."""
    # Conversation history; operator.add makes each node's returned
    # messages append to the history instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the agent that produced the most recent message(s).
    sender: str
class Agent:
    """Wraps an LLM behind a system prompt and a set of tools.

    The runnable stored in ``self.agent`` pipes the running message
    history through a prompt template and into the tool-bound model.
    """

    def __init__(self, name: str, llm, prompt: str, tools: List[BaseTool]):
        self.name = name
        self.prompt = prompt
        self.tools = tools
        # System prompt first, then the accumulated chat history.
        template = ChatPromptTemplate.from_messages(
            [
                ("system", prompt),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )
        self.agent = template | llm.bind_tools(tools)
class GraphBuilder:
    """Builds and runs a multi-agent LangGraph from a ``GraphConfig``.

    Typical flow: ``build()`` compiles the graph; then ``run()`` streams
    JSON-serialized chat messages for a conversation, ``get_image()``
    renders a diagram, and ``get_compiled_graph()`` exposes the drawable
    graph object.
    """

    def __init__(self, graph_config: GraphConfig):
        self.graph_config = graph_config
        self.agents: List[Agent] = []
        # Populated by build(); every public accessor checks this first.
        self.graph: CompiledStateGraph | None = None

    @staticmethod
    def _router(state) -> Literal["call_tool", "__end__", "continue"]:
        """Decide where control goes after an agent's turn.

        Routes to the shared tool node when the last message requested
        tool calls, ends the graph when the agent emitted the FINISHED
        sentinel, and otherwise continues to the next agent.
        """
        last_message = state["messages"][-1]
        # getattr: non-AI messages (e.g. HumanMessage) have no tool_calls
        # attribute; the original attribute access would raise here.
        if getattr(last_message, "tool_calls", None):
            return "call_tool"
        # content may be a plain string or a list of content blocks;
        # coerce to str so the sentinel check never raises.
        if "FINISHED" in str(last_message.content):
            return "__end__"
        return "continue"

    @staticmethod
    def _agent_node(state, agent, name):
        """Invoke one agent and re-tag its reply with the agent's name."""
        result = agent.invoke(state)
        # Rebuild as an AIMessage carrying the sender name so downstream
        # routing (lambda x: x["sender"]) knows who produced it.
        # NOTE(review): .dict() is the pydantic v1-compatible spelling;
        # it still works (deprecated) under pydantic v2.
        result = AIMessage(**result.dict(exclude={"type", "name"}), name=name)
        return {
            "messages": [result],
            "sender": name,
        }

    def _create_agent_node(self, agent: Agent):
        """Bind an Agent into a node callable for the state graph."""
        return functools.partial(self._agent_node, agent=agent.agent, name=agent.name)

    def _add_nodes(self, state_graph):
        """Add one node per agent plus a single shared tool node."""
        for agent in self.agents:
            state_graph.add_node(agent.name, self._create_agent_node(agent))
        # All agents share one ToolNode holding the union of their tools.
        all_tools = [tool for agent in self.agents for tool in agent.tools]
        state_graph.add_node("call_tool", ToolNode(all_tools))

    def _add_edges(self, state_graph):
        """Wire START -> first agent, agent -> agent chains, and tool loops."""
        graph_structure = {edge.source: edge.target for edge in self.graph_config.edges}
        if 'start' not in graph_structure:
            # Fail fast with a clear message instead of a bare KeyError.
            raise ValueError("GraphConfig must contain an edge with source 'start'.")
        state_graph.add_edge(START, graph_structure['start'])
        for agent in self.agents:
            # Agents with no outgoing edge in the config terminate the graph.
            next_node = graph_structure.get(agent.name, 'end')
            next_node = END if next_node.lower() == 'end' else next_node
            state_graph.add_conditional_edges(
                agent.name,
                self._router,
                {
                    "call_tool": "call_tool",
                    "continue": next_node,
                    "__end__": END,
                },
            )
        # After a tool runs, return control to whichever agent called it.
        state_graph.add_conditional_edges(
            "call_tool",
            lambda x: x["sender"],
            {agent.name: agent.name for agent in self.agents},
        )

    def get_stream(self, initial_state: Dict[str, Any], user: User):
        """Stream raw graph events for *initial_state* on behalf of *user*.

        Raises ValueError if build() has not been called yet.
        """
        if not self.graph:
            raise ValueError("Graph not created. Call build() first.")
        return self.graph.stream(
            initial_state,
            config={"recursion_limit": 150, **get_llm_config(user)},
        )

    def get_image(self, xray=True, mermaid=True):
        """Return the compiled graph rendered as a base64-encoded PNG.

        Raises ValueError if build() has not been called yet.
        """
        if not self.graph:
            raise ValueError("Graph not created. Call build() first.")
        drawable = self.graph.get_graph(xray=xray)
        png = drawable.draw_mermaid_png() if mermaid else drawable.draw_png()
        return base64.b64encode(png).decode('utf-8')

    def get_compiled_graph(self):
        """Return the drawable graph object of the compiled graph."""
        if not self.graph:
            raise ValueError("Graph not created. Call build() first.")
        return self.graph.get_graph()

    def get_compiled_config(self):
        """Backward-compatible alias for get_compiled_graph().

        Existing callers (see the usage example) invoke this name.
        """
        return self.get_compiled_graph()

    def build(self):
        """Create agents from the config, then build and compile the graph."""
        for agent_params in self.graph_config.agents:
            # 'start'/'end' are pseudo-nodes in the config, not real agents.
            if agent_params.name not in ('start', 'end'):
                tools: List[BaseTool] = get_tools(tool_names=agent_params.tool_names)
                tools.extend(agent_params.tools)
                self.agents.append(Agent(
                    name=agent_params.name,
                    llm=get_text_model(agent_params.llm),
                    prompt=agent_params.prompt,
                    tools=tools,
                ))
        state_graph = StateGraph(GraphState)
        self._add_nodes(state_graph)
        self._add_edges(state_graph)
        self.graph = state_graph.compile()
        return self.graph

    @staticmethod
    def _chat_to_json(chat) -> str:
        """Serialize a Chat model to a JSON string, stringifying enum values."""
        data = chat.model_dump(exclude_unset=True)
        data = {k: v.value if isinstance(v, Enum) else v for k, v in data.items()}
        return json.dumps(data)

    def run(self, user: User, messages: List[Chat]):
        """Stream the graph over *messages*, yielding one JSON string per chat.

        Raises ValueError if build() has not been called yet. Per-message
        serialization errors are reported in-stream as assistant messages
        rather than aborting the whole run.
        """
        if not self.graph:
            raise ValueError("Graph not created. Call build() first.")
        events = self.get_stream({"messages": to_dict_list(messages)}, user)
        for event in events:
            # Event keys (node names) are unused; only payloads matter here.
            for value in event.values():
                if 'messages' not in value:
                    continue
                for message in value['messages']:
                    try:
                        tool_calls = getattr(message, 'tool_calls', None)
                        chat_item = Chat(
                            role=map_to_chat_role(message.type),
                            content=message.content,
                            name=getattr(message, 'name', None),
                            tool_calls=[ToolCall(**tc) for tc in tool_calls] if tool_calls else None,
                            tool_call_id=getattr(message, 'tool_call_id', None),
                        )
                        yield self._chat_to_json(chat_item)
                    except Exception as e:
                        # Boundary handler: surface the failure to the client
                        # instead of killing the stream mid-conversation.
                        yield self._chat_to_json(Chat(
                            role=ChatRole.assistant,
                            content=f"Error processing message: {str(e)}",
                            name="Error",
                        ))
# Usage example: configure a two-agent pipeline, compile it, and render it.
body = GraphConfig(
    graph_id=3,
    name="Graph",
    description="Graph",
    session_id=None,
    agents=[
        AgentConfig(
            name="Researcher",
            prompt="You are a research assistant",
            llm="claude-sonnet-3.5",
            tools=[],
            color="#FF871F",
        ),
        AgentConfig(
            name="Writer",
            prompt="You are a great writer",
            llm="claude-sonnet-3.5",
            tools=[],
            color="#FF871F",
        ),
    ],
    edges=[
        # Linear flow: start -> Researcher -> Writer -> end.
        GraphEdge(source="start", target="Researcher"),
        GraphEdge(source="Researcher", target="Writer"),
        GraphEdge(source="Writer", target="end"),
    ],
)
graph = GraphBuilder(graph_config=body)
graph.build()
response = GraphResponse(
    # Fixed: the builder defines get_compiled_graph(); the original example
    # called a nonexistent get_compiled_config().
    config=graph.get_compiled_graph(),
    base64=graph.get_image(xray=True, mermaid=True),
)
# A bare `response` expression is a no-op outside a notebook; print instead.
print(response)
Could you make a quick read me on using this? Thanks in advance!
Just copy this whole file and place it somewhere and refer to the usage example to import the class and use it to build graphs
Might have some compile errors but you can create the required pydantic classes to make it work
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
In my org, I implemented a utility class to dynamically build LangGraph graphs
and abstracted away some of the common details
It is by no means perfect, and is probably incomplete.
I have this raw implementation here
This can be used as a reference to implement a proper utility class