Created December 17, 2024 13:54
-
-
Save bemijonathan/4515aedebc3a009960a9c10212fc3aff to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import ABC, abstractmethod | |
from typing import Any, Dict, Optional, List, Type, get_type_hints, Union | |
from dataclasses import dataclass | |
import inspect | |
import json | |
from functools import wraps | |
@dataclass()
class ToolResult:
    """Outcome of running one tool action.

    ``success`` says whether the action completed, ``data`` holds whatever
    the action returned, and ``error`` carries a message describing a failure.
    """

    success: bool  # True when the action ran without raising
    data: Any  # action payload; BaseTool.execute stores None on failure
    error: Union[str, None] = None  # failure message; stays None on success
@dataclass()
class MethodInfo:
    """Introspected record describing a single tool action.

    Fields, in constructor order (as filled in by BaseTool's discovery):
      name        -- attribute name of the action on the tool instance.
      description -- text given to ``@tool_method``.
      parameters  -- per-parameter info dicts keyed by parameter name.
      return_type -- the return annotation, or ``Any`` when there is none.
      method      -- the bound callable that performs the action.
    """

    name: str
    description: str
    parameters: Dict[str, Dict[str, Any]]
    return_type: Any
    method: Any
def tool_method(description: str = ""):
    """Mark a method as an action that BaseTool should expose.

    Supports both ``@tool_method(description="...")`` / ``@tool_method()``
    and the bare form ``@tool_method``. The decorated function is returned
    unchanged apart from two attributes: ``is_tool_method = True`` and
    ``description``.

    :param description: Human-readable summary of the action. When the
        decorator is applied bare, this is the decorated function itself and
        the description defaults to "".
    :return: A decorator that tags a function, or (bare form) the tagged
        function itself.
    """

    def _mark(func, text):
        # Tag the function in place; BaseTool discovery looks for these attributes.
        func.is_tool_method = True
        func.description = text
        return func

    if callable(description):
        # Bare ``@tool_method``: Python passed the function as ``description``.
        # Without this branch the method would be silently replaced by the
        # inner decorator and never discovered.
        return _mark(description, "")

    def decorator(func):
        return _mark(func, description)

    return decorator
class BaseTool(ABC):
    """Base class for tools whose actions are discovered automatically.

    Subclasses expose actions by decorating methods with ``@tool_method``.
    On construction every such method is introspected (signature, type
    hints, ``:param ...:`` docstring lines) and recorded in ``self.methods``,
    keyed by method name. Actions can then be listed with
    ``get_available_methods`` and invoked by name with ``execute``.
    """

    def __init__(self):
        # Maps method name -> MethodInfo for every @tool_method on this instance.
        self.methods = self._discover_methods()

    def _discover_methods(self) -> Dict[str, MethodInfo]:
        """Build a MethodInfo for every bound method tagged by ``@tool_method``.

        :return: Mapping of method name to its MethodInfo.
        """
        methods = {}
        for name, method in inspect.getmembers(self, inspect.ismethod):
            # getattr (not hasattr) so an explicit ``is_tool_method = False`` opts out.
            if not getattr(method, "is_tool_method", False):
                continue
            hints = self._resolve_hints(method)
            methods[name] = MethodInfo(
                name=name,
                description=getattr(method, "description", ""),
                parameters=self._describe_parameters(method, hints),
                return_type=hints.get("return", Any),
                method=method,
            )
        return methods

    def _resolve_hints(self, method: Any) -> Dict[str, Any]:
        """Return the method's type hints, tolerating unresolvable annotations.

        ``get_type_hints`` raises when an annotation names something that is
        not available at runtime (e.g. a ``TYPE_CHECKING``-only import). In that
        case, fall back to the raw ``__annotations__`` so one bad annotation
        does not make the whole tool fail to construct.
        """
        try:
            return get_type_hints(method)
        except (NameError, AttributeError, TypeError):
            return dict(getattr(method, "__annotations__", {}))

    def _describe_parameters(
        self, method: Any, hints: Dict[str, Any]
    ) -> Dict[str, Dict[str, Any]]:
        """Build the info dict (type, required, default, description) per parameter."""
        params = {}
        for param_name, param in inspect.signature(method).parameters.items():
            # Bound methods already omit ``self``; kept as a guard for odd callables.
            if param_name == "self":
                continue
            # *args / **kwargs are not named parameters a caller fills in. They
            # have no default, so listing them would wrongly mark them required.
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue
            # Identity check: ``==`` misbehaves on defaults with custom __eq__ (e.g. arrays).
            required = param.default is inspect.Parameter.empty
            params[param_name] = {
                "type": self._type_to_str(hints.get(param_name, Any)),
                "required": required,
                "default": None if required else param.default,
                "description": self._get_param_doc(method, param_name),
            }
        return params

    def _type_to_str(self, type_hint: Any) -> str:
        """Render a type hint as a readable string, e.g. ``Optional[list[str]]``.

        Parameterized generics render as ``origin[args]``. Unions containing
        ``None`` (``Optional[X]``, ``Union[X, Y, None]``, ``X | None``) render as
        ``Optional[...]``. Other unions render as ``Union[...]``.
        """
        from typing import get_args, get_origin
        import types

        if type_hint is Ellipsis:
            return "..."
        # Argument lists such as the ``[int]`` in ``Callable[[int], str]``.
        if isinstance(type_hint, list):
            return f"[{', '.join(self._type_to_str(arg) for arg in type_hint)}]"

        origin = get_origin(type_hint)
        args = get_args(type_hint)
        # PEP 604 unions (``int | None``) report types.UnionType (3.10+) as origin.
        union_origins = (Union, getattr(types, "UnionType", Union))
        if origin in union_origins:
            non_none = [arg for arg in args if arg is not type(None)]
            if len(non_none) < len(args):
                if len(non_none) == 1:
                    return f"Optional[{self._type_to_str(non_none[0])}]"
                # Keep every non-None member rather than only the first one.
                members = ", ".join(self._type_to_str(arg) for arg in non_none)
                return f"Optional[Union[{members}]]"
            members = ", ".join(self._type_to_str(arg) for arg in args)
            return f"Union[{members}]"

        if origin is not None:
            origin_name = getattr(origin, "__name__", None) or str(origin)
            if not args:
                # Unparameterized generic such as bare ``typing.List``.
                return origin_name
            return f"{origin_name}[{', '.join(self._type_to_str(arg) for arg in args)}]"

        return getattr(type_hint, "__name__", None) or str(type_hint)

    def _get_param_doc(self, method: Any, param_name: str) -> str:
        """Return the reST ``:param`` description for ``param_name``, or "".

        Recognizes ``:param name: text`` and the typed form
        ``:param type name: text``. Only the text on that same line is returned.
        """
        doc = method.__doc__
        if not doc:
            return ""
        prefix = ":param "
        for raw_line in doc.splitlines():
            line = raw_line.strip()
            if not line.startswith(prefix):
                continue
            field, sep, text = line[len(prefix):].partition(":")
            if not sep:
                continue
            # The parameter name is the last word of the field; any earlier
            # words are an optional type.
            words = field.split()
            if words and words[-1] == param_name:
                return text.strip()
        return ""

    def get_available_methods(self) -> Dict[str, Dict[str, Any]]:
        """Describe every discovered action.

        :return: Mapping of method name to its description, parameter info and
            rendered return type.
        """
        return {
            name: {
                "description": info.description,
                "parameters": info.parameters,
                "return_type": self._type_to_str(info.return_type),
            }
            for name, info in self.methods.items()
        }

    async def execute(self, method_name: str, params: Optional[Dict[str, Any]] = None) -> ToolResult:
        """Invoke a discovered tool action by name.

        Required parameters are checked before the call. Keys in ``params``
        that the action does not declare are ignored, and omitted optional
        parameters use the signature default. The action may be a coroutine
        function or a plain function; an awaitable result is awaited. Any
        ``Exception`` raised during validation or the call is reported as a
        failed ToolResult instead of propagating.

        :param method_name: Name of an entry in ``self.methods``.
        :param params: Keyword arguments for the action; ``None`` means none.
        :return: ``ToolResult(success=True, data=<return value>)`` on success,
            otherwise ``success=False`` with an ``error`` message.
        """
        try:
            params = params or {}

            method_info = self.methods.get(method_name)
            if method_info is None:
                available = list(self.methods.keys())
                return ToolResult(
                    success=False,
                    data=None,
                    error=f"Unknown method: {method_name}. Available methods: {available}",
                )

            call_kwargs = {}
            for param_name, param_info in method_info.parameters.items():
                if param_name in params:
                    # TODO: Add type conversion here if needed
                    call_kwargs[param_name] = params[param_name]
                elif param_info["required"]:
                    return ToolResult(
                        success=False,
                        data=None,
                        error=f"Missing required parameter: {param_name}",
                    )
                # Omitted optional parameters: Python applies the signature default.

            result = method_info.method(**call_kwargs)
            # Actions may be sync or async; only await when the call returned an awaitable.
            if inspect.isawaitable(result):
                result = await result
            return ToolResult(success=True, data=result)
        except Exception as e:
            # Tool boundary: report failures to the caller instead of raising.
            # Fall back to the exception type when its message is empty.
            return ToolResult(success=False, data=None, error=str(e) or type(e).__name__)

    @tool_method(description="Get information about available methods")
    async def help(self) -> Dict[str, Any]:
        """Get help information about available methods"""
        return self.get_available_methods()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.