Skip to content

Instantly share code, notes, and snippets.

@Finndersen
Last active January 29, 2025 15:59
Show Gist options
  • Save Finndersen/c8cbcd257fa4b7cf19407df27286c66a to your computer and use it in GitHub Desktop.
Example code showing how the PydanticAI graph framework could be adapted to use explicit edges for greater flexibility.
# NODES
@dataclass
class UserPromptNode(BaseNode[...]):
    """Get the system & user prompt parts for the initial message, or the user prompt message if there is chat history."""

    user_prompt: str
    system_prompts: tuple[str, ...]
    ...  # further fields elided in this sketch

    def run(self, ctx: GraphRunContext[...]) -> _messages.ModelRequest:
        """Build the first model request of the run.

        The returned ModelRequest is what `user_prompt_edge` receives when the graph
        routes out of this node.
        """
        # Takes `self` so it can be invoked as `node.run(ctx)`; the context parameter is
        # named `ctx` to match the edge functions.
        ...
@dataclass
class ModelResponse:
    """Result returned by `ModelRequestNode.run`, consumed by `model_request_edge`."""

    tool_calls: list[_messages.ToolCallPart]  # tool calls found in the model's reply
    texts: list[str]  # text content found in the model's reply
@dataclass
class ModelRequestNode(BaseNode[...]):
    """
    Make a request to the model and return the results of the model request.

    The request sent is the `request` field. (The intent may also be to fall back to the last
    message in state.message_history, but `request` is currently a required field -- TODO
    confirm the intended behaviour.)
    """

    request: _messages.ModelRequest
    # The model is a per-node field so it can be configured for this node specifically.
    model: models.Model

    def run(self, ctx: GraphRunContext[...]) -> ModelResponse:
        """Send `request` to `model` and return the response's tool calls and texts."""
        ...
@dataclass
class ToolCallResult:
    """Result returned by `HandleToolCallsNode.run`, consumed by `handle_tool_call_edge`."""

    # Parts sent back to the model in a new ModelRequest when there is no final result.
    tool_responses: list[_messages.ModelRequestPart]
    # Populated when the result tool produced a final result; None keeps the loop going.
    final_result: MarkFinalResult | None
@dataclass
class HandleToolCallsNode(BaseNode[...]):
    """
    Handles the tool calls of a ModelResponse, including structured output via result_tool.
    """

    # Same element type as ModelResponse.tool_calls, which model_request_edge passes in as-is.
    tool_calls: list[_messages.ToolCallPart]

    def run(self, ctx: GraphRunContext[...]) -> ToolCallResult:
        """Call the requested tools and collect their responses.

        If the result tool is used, the returned ToolCallResult carries its MarkFinalResult
        in `final_result`; otherwise `final_result` is None.
        """
        ...
@dataclass
class FinalResultNode(BaseNode[...]):
    """
    Finish the run with a final result.

    For backwards compatibility, append a new ModelRequest to message history using the tool
    returns and retries (`extra_parts`). Also set logging run span attributes etc.
    """

    data: MarkFinalResult[NodeRunEndT]
    # Tool returns/retries from the same model turn (handle_tool_call_edge passes
    # ToolCallResult.tool_responses here). Defaults to None so FinalResultNode(data=...)
    # remains valid.
    extra_parts: list[_messages.ModelRequestPart] | None = None

    def run(self, ctx: GraphRunContext[...]) -> MarkFinalResult:
        """Record the final result and return it for `final_result_edge` to wrap in End."""
        ...
# EDGES
def user_prompt_edge(
    ctx: GraphRunContext[...], user_prompt: _messages.ModelRequest
) -> ModelRequestNode:
    """Route UserPromptNode's request to a ModelRequestNode that uses the model from deps."""
    configured_model = ctx.deps.model
    return ModelRequestNode(request=user_prompt, model=configured_model)
def model_request_edge(
    ctx: GraphRunContext[...], model_response: ModelResponse
) -> HandleToolCallsNode | ModelRequestNode | FinalResultNode:
    """Pick the node that follows a model request.

    Tool calls take priority over text. A response containing neither is rejected with
    UnexpectedModelBehavior.
    """
    calls = model_response.tool_calls
    if calls:
        return HandleToolCallsNode(tool_calls=calls)

    text_parts = model_response.texts
    if not text_parts:
        raise exceptions.UnexpectedModelBehavior('Received empty model response')

    # Same as ModelRequestNode._handle_text_response() in the linked PR.
    return _handle_text_response(ctx=ctx, texts=text_parts)
def handle_tool_call_edge(
    ctx: GraphRunContext[...], tool_call_result: ToolCallResult
) -> ModelRequestNode | FinalResultNode:
    """Finish the run if a final result was produced, else send the tool responses back to the model.

    Returns:
        FinalResultNode when `tool_call_result.final_result` is set, carrying the tool
        responses as `extra_parts`; otherwise a ModelRequestNode whose request is built
        from those tool responses.
    """
    if tool_call_result.final_result is not None:
        return FinalResultNode(
            data=tool_call_result.final_result,
            extra_parts=tool_call_result.tool_responses,
        )
    # ModelRequestNode's fields are `request` and `model` (there is no `messages` field);
    # the model is taken from deps, as in user_prompt_edge.
    return ModelRequestNode(
        model=ctx.deps.model,
        request=_messages.ModelRequest(parts=tool_call_result.tool_responses),
    )
def final_result_edge(ctx: GraphRunContext[...], data: MarkFinalResult) -> End:
    """Terminate the graph by wrapping the final result produced by FinalResultNode in End."""
    terminal = End(data)
    return terminal
## GRAPH
# The nodes and edges are provided together when building the graph: each node class is paired
# with the edge function that turns that node's `run` output into the next node (or End).
graph = Graph[...](
    nodes_and_edges=[
        (UserPromptNode, user_prompt_edge),
        (ModelRequestNode, model_request_edge),
        (HandleToolCallsNode, handle_tool_call_edge),
        (FinalResultNode, final_result_edge),
    ],
    # ... other Graph arguments elided. (A bare `...` here would be a positional argument
    # following a keyword argument, which is a SyntaxError.)
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment