Skip to content

Instantly share code, notes, and snippets.

@rvndbalaji
Last active April 14, 2025 16:45
Show Gist options
  • Save rvndbalaji/2f50bf3d2fc635bf2f36f71c724e3347 to your computer and use it in GitHub Desktop.
Save rvndbalaji/2f50bf3d2fc635bf2f36f71c724e3347 to your computer and use it in GitHub Desktop.
Utility class to build dynamic LangGraph agents
import base64
import functools
import operator
from typing import *
from typing import Annotated, Any, Dict, List, Sequence
from dto.chat import *
from dto.graph import *
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools import BaseTool
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, START, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from models import get_text_model
from tools._index import get_tools
from typing_extensions import TypedDict
from utils.common import *
class GraphState(TypedDict):
    """Shared state passed between the nodes of the graph built by GraphBuilder."""

    # Conversation history. The ``operator.add`` reducer makes LangGraph
    # concatenate the ``messages`` list each node returns onto the existing
    # history instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the agent node that produced the latest message; the
    # "call_tool" node's routing reads it to hand control back to that agent.
    sender: str
class Agent:
    """An LLM agent: a fixed system prompt followed by the conversation,
    piped into a chat model that may call ``tools``.

    Attributes:
        name: Agent identifier (GraphBuilder uses it as the graph node name).
        prompt: The system prompt, exactly as supplied.
        tools: Tools the model may call.
        agent: Runnable ``prompt_template | model``; it is invoked with a
            state dict that supplies ``messages``.
    """

    def __init__(self, name: str, llm, prompt: str, tools: List[BaseTool]):
        self.name = name
        self.prompt = prompt
        self.tools = tools
        # The prompt is configuration text, not a template. ChatPromptTemplate
        # parses "system" strings as f-string templates, so any "{...}" in it
        # (e.g. a JSON example) would become a template variable and fail at
        # invoke time. Doubling the braces makes them literal.
        literal_prompt = prompt.replace("{", "{{").replace("}", "}}")
        prompt_template = ChatPromptTemplate.from_messages([
            ("system", literal_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ])
        # Only bind tools when there are some: certain chat-model providers
        # reject a request whose ``tools`` array is empty.
        model = llm.bind_tools(tools) if tools else llm
        self.agent = prompt_template | model
class GraphBuilder:
    """Build, render and run a LangGraph graph described by a ``GraphConfig``.

    Every entry in ``graph_config.agents`` (except the ``start``/``end``
    pseudo-nodes) becomes an ``Agent`` node. All agents' tools are served by
    one shared ``"call_tool"`` node, which hands control back to whichever
    agent requested the tool. ``graph_config.edges`` define where each agent
    goes next once it stops calling tools.

    Call :meth:`build` before :meth:`run`, :meth:`get_stream`,
    :meth:`get_image` or :meth:`get_compiled_graph`.
    """

    # Pseudo-node names used in edge configs (compared case-insensitively).
    _START_NAME = "start"
    _END_NAME = "end"

    def __init__(self, graph_config: GraphConfig):
        """Store the config; nothing is built until :meth:`build` is called."""
        self.graph_config = graph_config
        self.agents: List[Agent] = []
        self.graph: CompiledStateGraph | None = None

    # ------------------------------------------------------------------ helpers

    def _ensure_built(self):
        """Return the compiled graph.

        Raises:
            ValueError: if :meth:`build` has not been called yet.
        """
        if self.graph is None:
            raise ValueError("Graph not created. Call build() first.")
        return self.graph

    @classmethod
    def _resolve_target(cls, target: str):
        """Map the ``end`` pseudo-node (any letter case) to LangGraph's ``END``."""
        return END if target.lower() == cls._END_NAME else target

    @staticmethod
    def _message_text(content) -> str:
        """Return the text carried by a message ``content`` value.

        ``content`` may be a plain string or a list of content blocks (strings,
        or dicts with a ``"text"`` entry); non-text blocks are ignored.
        """
        if isinstance(content, str):
            return content
        texts = []
        for block in content or ():
            if isinstance(block, str):
                texts.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                texts.append(block["text"])
        return "\n".join(texts)

    @staticmethod
    def _message_to_chat(message) -> Chat:
        """Convert a LangChain message into the API's ``Chat`` model."""
        tool_calls = getattr(message, "tool_calls", None)
        return Chat(
            role=map_to_chat_role(message.type),
            content=message.content,
            name=getattr(message, "name", None),
            tool_calls=[ToolCall(**tool_call) for tool_call in tool_calls] if tool_calls else None,
            tool_call_id=getattr(message, "tool_call_id", None),
        )

    @staticmethod
    def _chat_to_json(chat: Chat) -> str:
        """Serialize ``chat`` (only fields that were set) to a JSON string."""
        # Local imports: the module's visible imports do not bring these in
        # explicitly (they otherwise arrive only via star imports).
        import json
        from enum import Enum

        data = chat.model_dump(exclude_unset=True)
        # Enum members are not JSON-serializable; emit their values instead
        # (top-level fields only).
        data = {key: value.value if isinstance(value, Enum) else value for key, value in data.items()}
        return json.dumps(data)

    # ------------------------------------------------------------ graph wiring

    @staticmethod
    def _router(state) -> Literal["call_tool", "__end__", "continue"]:
        """Decide where an agent's turn goes next.

        Returns:
            ``"call_tool"`` if the last message requests tool calls,
            ``"__end__"`` if its text contains ``FINISHED``, otherwise
            ``"continue"`` (follow the configured edge).
        """
        last_message = state["messages"][-1]
        if getattr(last_message, "tool_calls", None):
            return "call_tool"
        # Content can be a list of blocks; search its text, not the list itself.
        if "FINISHED" in GraphBuilder._message_text(last_message.content):
            return "__end__"
        return "continue"

    @staticmethod
    def _agent_node(state, agent, name):
        """Invoke ``agent`` on ``state`` and tag the reply with ``name``.

        The reply is rebuilt as an ``AIMessage`` carrying ``name``, and
        ``sender`` is set so the ``call_tool`` node can route tool results
        back to this agent.
        """
        result = agent.invoke(state)
        result = AIMessage(**result.dict(exclude={"type", "name"}), name=name)
        return {
            "messages": [result],
            "sender": name,
        }

    def _create_agent_node(self, agent: Agent):
        """Return a node callable that runs ``agent`` under its own name."""
        return functools.partial(self._agent_node, agent=agent.agent, name=agent.name)

    def _add_nodes(self, state_graph):
        """Add one node per agent plus a shared ``"call_tool"`` node."""
        for agent in self.agents:
            state_graph.add_node(agent.name, self._create_agent_node(agent))
        all_tools = [tool for agent in self.agents for tool in agent.tools]
        state_graph.add_node("call_tool", ToolNode(all_tools))

    def _add_edges(self, state_graph):
        """Wire START, agent hand-offs, and tool round-trips into ``state_graph``.

        Raises:
            KeyError: if no edge has ``start`` as its source.
        """
        # One outgoing edge per source: if a source appears more than once,
        # the last edge listed wins.
        graph_structure = {edge.source: edge.target for edge in self.graph_config.edges}

        start_node = next(
            (target for source, target in graph_structure.items()
             if source.lower() == self._START_NAME),
            None,
        )
        if start_node is None:
            raise KeyError("Graph config has no edge whose source is 'start'.")
        state_graph.add_edge(START, self._resolve_target(start_node))

        for agent in self.agents:
            # An agent without an outgoing edge terminates the graph.
            next_node = self._resolve_target(graph_structure.get(agent.name, self._END_NAME))
            state_graph.add_conditional_edges(
                agent.name,
                self._router,
                {
                    "call_tool": "call_tool",
                    "continue": next_node,
                    "__end__": END,
                },
            )

        # Tool results always return to the agent that requested them.
        state_graph.add_conditional_edges(
            "call_tool",
            lambda state: state["sender"],
            {agent.name: agent.name for agent in self.agents},
        )

    # ------------------------------------------------------------- public API

    def get_stream(self, initial_state: Dict[str, Any], user: User, recursion_limit: int = 150):
        """Return the compiled graph's ``stream`` for ``initial_state``.

        Args:
            initial_state: Initial graph state (at least ``messages``).
            user: Passed to ``get_llm_config`` to build the run config.
            recursion_limit: LangGraph step limit. A ``recursion_limit``
                returned by ``get_llm_config`` takes precedence.

        Raises:
            ValueError: if :meth:`build` has not been called.
        """
        graph = self._ensure_built()
        return graph.stream(
            initial_state,
            config={"recursion_limit": recursion_limit, **get_llm_config(user)},
        )

    def get_image(self, xray=True, mermaid=True):
        """Render the graph to an image and return it base64-encoded.

        Args:
            xray: Passed through to ``get_graph``.
            mermaid: Use ``draw_mermaid_png`` when true, else ``draw_png``.

        Raises:
            ValueError: if :meth:`build` has not been called.
        """
        drawable = self._ensure_built().get_graph(xray=xray)
        image_bytes = drawable.draw_mermaid_png() if mermaid else drawable.draw_png()
        return base64.b64encode(image_bytes).decode("utf-8")

    def get_compiled_graph(self):
        """Return the drawable graph (``get_graph()``) of the compiled graph.

        Raises:
            ValueError: if :meth:`build` has not been called.
        """
        return self._ensure_built().get_graph()

    def get_compiled_config(self):
        """Alias of :meth:`get_compiled_graph`, kept for callers using this name."""
        return self.get_compiled_graph()

    def build(self):
        """Create the agents, wire nodes and edges, and compile the graph.

        Can be called repeatedly: the agent list is rebuilt from the config on
        each call rather than appended to.

        Returns:
            The compiled graph (also stored on ``self.graph``).
        """
        self.agents = []
        for agent_params in self.graph_config.agents:
            # start/end are edge pseudo-nodes, not agents.
            if agent_params.name.lower() in (self._START_NAME, self._END_NAME):
                continue
            # Build a new list instead of extending the one get_tools returned,
            # so that list is never mutated.
            tools: List[BaseTool] = [
                *get_tools(tool_names=agent_params.tool_names),
                *(agent_params.tools or []),
            ]
            self.agents.append(Agent(
                name=agent_params.name,
                llm=get_text_model(agent_params.llm),
                prompt=agent_params.prompt,
                tools=tools,
            ))

        state_graph = StateGraph(GraphState)
        self._add_nodes(state_graph)
        self._add_edges(state_graph)
        self.graph = state_graph.compile()
        return self.graph

    def run(self, user: User, messages: List[Chat]):
        """Run the graph and yield every emitted message as a JSON string.

        Args:
            user: Forwarded to :meth:`get_stream`.
            messages: Conversation so far; converted with ``to_dict_list``.

        Yields:
            One JSON-encoded ``Chat`` per message found in each streamed node
            update. If a message cannot be converted, an assistant ``Chat``
            named ``"Error"`` describing the failure is yielded instead, so a
            single bad message does not abort the stream.

        Raises:
            ValueError: if :meth:`build` has not been called (raised on first
                iteration, since this is a generator).
        """
        self._ensure_built()
        events = self.get_stream({"messages": to_dict_list(messages)}, user)
        for event in events:
            for update in event.values():
                if not update or "messages" not in update:
                    continue
                for message in update["messages"]:
                    try:
                        payload = self._chat_to_json(self._message_to_chat(message))
                    except Exception as e:  # deliberate: report and keep streaming
                        payload = self._chat_to_json(Chat(
                            role=ChatRole.assistant,
                            content=f"Error processing message: {str(e)}",
                            name="Error",
                        ))
                    yield payload
# Usage example: build a two-agent Researcher -> Writer graph and render it.
# Guarded so that importing this module does not build a graph (which calls
# get_text_model for every agent) as a side effect.
if __name__ == "__main__":
    body = GraphConfig(
        graph_id=3,
        name="Graph",
        description="Graph",
        session_id=None,
        agents=[
            AgentConfig(
                name="Researcher",
                prompt="You are a research assistant",
                llm="claude-sonnet-3.5",
                tools=[],
                color="#FF871F",
            ),
            AgentConfig(
                name="Writer",
                prompt="You are a great writer",
                llm="claude-sonnet-3.5",
                tools=[],
                color="#FF871F",
            ),
        ],
        edges=[
            GraphEdge(source="start", target="Researcher"),
            GraphEdge(source="Researcher", target="Writer"),
            GraphEdge(source="Writer", target="end"),
        ],
    )

    graph = GraphBuilder(graph_config=body)
    graph.build()
    response = GraphResponse(
        # get_compiled_graph() is the graph accessor GraphBuilder defines.
        # NOTE(review): confirm this is what GraphResponse.config expects.
        config=graph.get_compiled_graph(),
        base64=graph.get_image(xray=True, mermaid=True),
    )
    # A bare ``response`` expression only displays in a notebook; print it.
    print(response)
@RobSisson
Copy link

Could you write a quick README on using this? Thanks in advance!

@rvndbalaji
Copy link
Author

rvndbalaji commented Jan 12, 2025

In my org, I implemented a utility class to dynamically build LangGraph graphs
and abstract away some of the common details.

It is by no means perfect, and it is probably incomplete.

I have shared this raw implementation here.

It can be used as a reference for implementing a proper utility class.

body = GraphConfig(
    graph_id=3,
    name="Graph",
    description="Graph",
    session_id=None,
    agents=[
        AgentConfig(
                    name="Researcher", 
                    prompt="You are a research assistant", 
                    llm="claude-sonnet-3.5", 
                    tools=[], 
                    color="#FF871F"
                   ),
        AgentConfig(
                    name="Writer", 
                    prompt="You are a great writer",
                    llm="claude-sonnet-3.5", 
                    tools=[], 
                    color="#FF871F"
                   ),
    ],
    edges=[
        GraphEdge(source="start", target="Researcher"),
         GraphEdge(source="Researcher", target="Writer"),
        GraphEdge(source="Writer", target="end")
    ]
)


graph = GraphBuilder(graph_config=body)
graph.build()
response = GraphResponse(
    config = graph.get_compiled_config(),
    base64 = graph.get_image(xray=True, mermaid=True)
)
response

@rvndbalaji
Copy link
Author

Could you write a quick README on using this? Thanks in advance!

Just copy this whole file into your project, and refer to the usage example to see how to import the class and use it to build graphs.
It may fail with missing names (e.g. GraphConfig, AgentConfig, GraphEdge, Chat), but you can create the required Pydantic classes to make it work.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment