Skip to content

Instantly share code, notes, and snippets.

@bemijonathan
Created December 17, 2024 13:54
Show Gist options
  • Save bemijonathan/4515aedebc3a009960a9c10212fc3aff to your computer and use it in GitHub Desktop.
Save bemijonathan/4515aedebc3a009960a9c10212fc3aff to your computer and use it in GitHub Desktop.
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, List, Type, get_type_hints, Union
from dataclasses import dataclass
import inspect
import json
from functools import wraps
@dataclass()
class ToolResult:
    """Outcome of invoking a tool method.

    ``BaseTool.execute`` builds one of these for every call: on success
    ``data`` holds the method's return value; on failure ``success`` is
    False, ``data`` is None and ``error`` carries a message.
    """

    success: bool  # whether the call completed without an error
    data: Any  # value returned by the tool method (None when the call failed)
    error: Optional[str] = None  # failure reason; None for successful calls
@dataclass()
class MethodInfo:
    """Metadata describing one discovered tool method.

    Instances are built during method discovery and consulted when the
    method is executed or listed.
    """

    name: str  # attribute name the method was discovered under
    description: str  # text supplied to the @tool_method decorator
    parameters: Dict[str, Dict[str, Any]]  # per-parameter schema keyed by parameter name
    return_type: Any  # resolved return annotation, or typing.Any when absent
    method: Any  # the bound method that is actually invoked
def tool_method(description: str = ""):
    """Build a decorator that marks a method as an exposed tool action.

    The decorated function is tagged in place: it gains an
    ``is_tool_method`` attribute set to True and a ``description``
    attribute holding *description*. No wrapper is created; the very same
    function object is returned.

    :param description: Human-readable summary stored on the function.
    :return: The tagging decorator.
    """
    def _mark(fn):
        setattr(fn, "is_tool_method", True)
        setattr(fn, "description", description)
        return fn

    return _mark
class BaseTool(ABC):
    """Base class for tools whose actions are discovered automatically.

    Every bound method decorated with ``@tool_method`` is collected at
    construction time into ``self.methods`` (name -> MethodInfo). Each entry
    carries a parameter schema built from the signature, type hints and
    ``:param name:`` docstring lines. ``execute`` invokes these methods by
    name with a dict of keyword arguments.
    """

    def __init__(self):
        # Discovery runs once per instance; the resulting mapping is what
        # execute() and get_available_methods() read from.
        self.methods = self._discover_methods()

    def _discover_methods(self) -> Dict[str, MethodInfo]:
        """Collect every bound method tagged by ``@tool_method``.

        :return: Mapping of method name to MethodInfo. Each MethodInfo holds
            the description, the parameter schema, the return type and the
            bound method itself.
        """
        methods = {}
        for name, method in inspect.getmembers(self, inspect.ismethod):
            if not hasattr(method, 'is_tool_method'):
                continue
            sig = inspect.signature(method)
            try:
                hints = get_type_hints(method)
            except (NameError, TypeError):
                # An annotation that cannot be evaluated here (e.g. a forward
                # reference to a type imported only for type checking) would
                # otherwise abort __init__. Fall back to the raw annotations.
                hints = dict(getattr(method, '__annotations__', {}))
            params = {}
            for param_name, param in sig.parameters.items():
                # Defensive: the signature of a bound method already omits self.
                if param_name == 'self':
                    continue
                if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                    # *args / **kwargs cannot be filled by a single named
                    # value. Listing them as required made the method
                    # impossible to execute.
                    continue
                # Identity check: `==` can misbehave for defaults that
                # override __eq__.
                has_default = param.default is not param.empty
                params[param_name] = {
                    'type': self._type_to_str(hints.get(param_name, Any)),
                    'required': not has_default,
                    'default': param.default if has_default else None,
                    'description': self._get_param_doc(method, param_name),
                }
            methods[name] = MethodInfo(
                name=name,
                description=getattr(method, 'description', ''),
                parameters=params,
                return_type=hints.get('return', Any),
                method=method,
            )
        return methods

    def _type_to_str(self, type_hint: Any) -> str:
        """Render a type hint as a readable string such as ``list[str]``.

        Generic aliases are rendered recursively from their ``__origin__`` and
        ``__args__``. A Union that contains NoneType is rendered as
        ``Optional[...]``, keeping every non-None member. Anything else
        falls back to ``__name__`` or ``str()``.
        """
        origin = getattr(type_hint, '__origin__', None)
        if origin is None:
            return type_hint.__name__ if hasattr(type_hint, '__name__') else str(type_hint)
        # Unsubscripted aliases (e.g. bare typing.List) may lack __args__.
        args = getattr(type_hint, '__args__', None) or ()
        if origin is Union:
            non_none = [arg for arg in args if arg is not type(None)]
            if len(non_none) < len(args):
                # Optional[X] is Union[X, None]. Keep all remaining members
                # instead of silently dropping everything but the first.
                if len(non_none) == 1:
                    inner = self._type_to_str(non_none[0])
                else:
                    inner = "Union[" + ", ".join(self._type_to_str(arg) for arg in non_none) + "]"
                return f"Optional[{inner}]"
        # typing special forms such as Union expose `_name` rather than
        # `__name__` on some Python versions.
        origin_name = (
            getattr(origin, '__name__', None)
            or getattr(origin, '_name', None)
            or str(origin)
        )
        if not args:
            return origin_name
        return f"{origin_name}[{', '.join(self._type_to_str(arg) for arg in args)}]"

    def _get_param_doc(self, method: Any, param_name: str) -> str:
        """Return the text after ``:param <param_name>:`` in the method's docstring.

        Only single-line ``:param name:`` entries are recognised. An empty
        string is returned when there is no docstring or no matching line.
        """
        if not method.__doc__:
            return ""
        marker = f":param {param_name}:"
        for line in method.__doc__.split('\n'):
            line = line.strip()
            if line.startswith(marker):
                return line[len(marker):].strip()
        return ""

    def get_available_methods(self) -> Dict[str, Dict[str, Any]]:
        """Describe every discovered method.

        :return: Mapping of method name to a dict with keys ``description``,
            ``parameters`` (the schema built during discovery) and
            ``return_type`` (rendered as a string).
        """
        return {
            name: {
                'description': info.description,
                'parameters': info.parameters,
                'return_type': self._type_to_str(info.return_type),
            }
            for name, info in self.methods.items()
        }

    async def execute(self, method_name: str, params: Optional[Dict[str, Any]] = None) -> ToolResult:
        """Run a discovered tool method and wrap the outcome in a ToolResult.

        Both ``async def`` and plain ``def`` tool methods are supported: the
        result is awaited only when it is awaitable.

        :param method_name: Name of a method found during discovery.
        :param params: Keyword arguments for the method. Keys that are not
            parameters of the method are ignored.
        :return: ``ToolResult(success=True, data=<return value>)`` on success.
            Otherwise ``success=False`` with ``error`` describing the unknown
            method, the missing required parameter, or the raised exception.
        """
        try:
            params = params or {}
            if method_name not in self.methods:
                available = list(self.methods.keys())
                return ToolResult(
                    success=False,
                    data=None,
                    error=f"Unknown method: {method_name}. Available methods: {available}"
                )
            method_info = self.methods[method_name]
            validated_params = {}
            for param_name, param_info in method_info.parameters.items():
                if param_name in params:
                    # TODO: Add type conversion here if needed
                    validated_params[param_name] = params[param_name]
                elif param_info['required']:
                    return ToolResult(
                        success=False,
                        data=None,
                        error=f"Missing required parameter: {param_name}"
                    )
                # Omitted optional parameters are not passed at all, so Python
                # applies the method's own default. Passing the default
                # explicitly by keyword would fail for positional-only
                # parameters.
            result = method_info.method(**validated_params)
            # Plain (synchronous) tool methods return a value directly;
            # awaiting it unconditionally used to raise TypeError.
            if inspect.isawaitable(result):
                result = await result
            return ToolResult(success=True, data=result)
        except Exception as e:
            # Some exceptions stringify to "" (e.g. KeyError()). Fall back to
            # the exception type so a failure always carries a reason.
            return ToolResult(success=False, data=None, error=str(e) or type(e).__name__)

    @tool_method(description="Get information about available methods")
    async def help(self) -> Dict[str, Any]:
        """Get help information about available methods"""
        return self.get_available_methods()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment