Last active
January 29, 2025 15:59
-
-
Save Finndersen/c8cbcd257fa4b7cf19407df27286c66a to your computer and use it in GitHub Desktop.
Example code for how the PydanticAI graph framework could be adapted to use Edges for flexibility
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NODES
@dataclass
class UserPromptNode(BaseNode[...]):
    """Get the system & user prompt parts for initial message, or user prompt message if there is chat history"""

    # Raw prompt text supplied by the user for this run.
    user_prompt: str
    # Static system prompt fragments configured on the agent.
    system_prompts: tuple[str, ...]
    ...  # placeholder for further fields in the full implementation

    def run(self, ctx: GraphRunContext[...]) -> _messages.ModelRequest:
        # Fixed: the original sketch omitted `self`, so calling the method on an
        # instance would have bound the node itself to the context parameter.
        # Parameter renamed `tx` -> `ctx` to match the edge functions below.
        ...
@dataclass
class ModelResponse:
    """Structure to store output of ModelRequestNode."""

    # Tool invocation parts extracted from the model's response.
    tool_calls: list[_messages.ToolCallPart]
    # Plain-text parts extracted from the model's response.
    texts: list[str]
@dataclass
class ModelRequestNode(BaseNode[...]):
    """
    Make a request to the model using the last message in state.message_history (or a specified request).
    Returns the results of the model request.
    """

    # The request to send to the model.
    request: _messages.ModelRequest
    model: models.Model  # IMO model should be configurable for this node

    def run(self, ctx: GraphRunContext[...]) -> ModelResponse:
        # Fixed: the original sketch omitted `self`, so calling the method on an
        # instance would have bound the node itself to the context parameter.
        # Parameter renamed `tx` -> `ctx` to match the edge functions below.
        ...
@dataclass
class ToolCallResult:
    """Output of HandleToolCallsNode: tool responses plus an optional final result."""

    # Response parts produced by executing the tool calls.
    tool_responses: list[_messages.ModelRequestPart]
    # Set when a result tool produced the run's final output; None otherwise.
    final_result: MarkFinalResult | None
@dataclass
class HandleToolCallsNode(BaseNode[...]):
    """
    Handles tool calls of a ModelResponse, including structured output via result_tool.
    """

    # Fixed annotation: this node is constructed from ModelResponse.tool_calls,
    # which is list[_messages.ToolCallPart] (see model_request_edge), not
    # list[_messages.ModelRequestPart] as the original sketch declared.
    tool_calls: list[_messages.ToolCallPart]

    def run(self, ctx: GraphRunContext[...]) -> ToolCallResult:
        # Get results from calling tools.
        # If result_tool is used, set final_result = True.
        # Fixed: added `self` (instance method) and an explicit `...` body —
        # the original had only comments, which is a syntax error in Python.
        ...
@dataclass
class FinalResultNode(BaseNode[...]):
    """
    For backwards compatibility, append a new ModelRequest to message history using the tool returns and retries.
    Also set logging run span attributes etc.
    """

    # The final result payload to return from the graph run.
    data: MarkFinalResult[NodeRunEndT]
    # Extra request parts (tool returns/retries) to append to message history.
    # Added with a default because handle_tool_call_edge constructs this node
    # with `extra_parts=`, which the original sketch never declared.
    extra_parts: list[_messages.ModelRequestPart] | None = None

    def run(self, ctx: GraphRunContext[...]) -> MarkFinalResult:
        # Fixed: the original sketch omitted `self`, so calling the method on an
        # instance would have bound the node itself to the context parameter.
        ...
# EDGES
def user_prompt_edge(ctx: GraphRunContext[...], user_prompt: _messages.ModelRequest) -> ModelRequestNode:
    """Forward the initial user prompt request to a model request node."""
    run_model = ctx.deps.model
    return ModelRequestNode(request=user_prompt, model=run_model)
def model_request_edge(ctx: GraphRunContext[...], model_response: ModelResponse) -> HandleToolCallsNode | ModelRequestNode | FinalResultNode:
    """Route a model response: tool calls to the tool-handling node, text to the
    text handler, and fail loudly when the response contains neither."""
    tool_calls = model_response.tool_calls
    texts = model_response.texts
    # Guard clause: a response with neither tool calls nor text is a model error.
    if not tool_calls and not texts:
        raise exceptions.UnexpectedModelBehavior('Received empty model response')
    if tool_calls:
        return HandleToolCallsNode(tool_calls=tool_calls)
    # Same as ModelRequestNode._handle_text_response() in linked PR
    return _handle_text_response(ctx=ctx, texts=texts)
def handle_tool_call_edge(ctx: GraphRunContext[...], tool_call_result: ToolCallResult) -> ModelRequestNode | FinalResultNode:
    """Route tool-call results: finish the run when a final result was produced,
    otherwise feed the tool responses back to the model."""
    if tool_call_result.final_result:
        return FinalResultNode(data=tool_call_result.final_result, extra_parts=tool_call_result.tool_responses)
    # Fixed: the original used a non-existent `messages=` keyword (the field is
    # `request`, cf. user_prompt_edge) and omitted the required `model` field;
    # build the node the same way user_prompt_edge does.
    return ModelRequestNode(
        model=ctx.deps.model,
        request=_messages.ModelRequest(parts=tool_call_result.tool_responses),
    )
def final_result_edge(ctx: GraphRunContext[...], data: MarkFinalResult) -> End:
    """Terminate the graph run, wrapping the final result in an End marker."""
    return End(data)
## GRAPH
# The nodes and edges are provided together when building the graph.
# Each entry pairs a node type with the edge function that routes its output
# to the next node (or to End).
# NOTE(review): the trailing bare `...` stands in for the remaining Graph
# kwargs; as written (a positional after a keyword argument) this call is
# illustrative pseudocode, not runnable Python.
graph = Graph[...](
    nodes_and_edges=[
        (UserPromptNode, user_prompt_edge),
        (ModelRequestNode, model_request_edge),
        (HandleToolCallsNode, handle_tool_call_edge),
        (FinalResultNode, final_result_edge)
    ],
    ...
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment