# suna/agentpress/tool.py
"""
Core tool system providing the foundation for creating and managing tools.

This module defines the base classes and decorators for creating tools in AgentPress:
- Tool base class for implementing tool functionality
- Schema decorators for OpenAPI and XML tool definitions
- Result containers for standardized tool outputs
"""
from typing import Dict, Any, Union, Optional, List, Type
2024-11-18 08:38:31 +08:00
from dataclasses import dataclass, field
2024-10-23 09:28:12 +08:00
from abc import ABC
2024-10-06 01:04:15 +08:00
import json
2024-10-23 09:28:12 +08:00
import inspect
2024-11-18 08:38:31 +08:00
from enum import Enum
class SchemaType(Enum):
    """Kinds of schema a tool method can be registered with.

    Members:
        OPENAPI: OpenAPI-style function schema (value ``"openapi"``).
        XML: XML tag schema (value ``"xml"``).
        CUSTOM: Free-form schema supplied by the tool (value ``"custom"``).
    """

    OPENAPI = "openapi"
    XML = "xml"
    CUSTOM = "custom"
@dataclass
class XMLNodeMapping:
    """Describe how one function parameter is read from an XML node.

    Attributes:
        param_name (str): Name of the function parameter receiving the value.
        node_type (str): Kind of node to read: ``"element"`` (the default),
            ``"attribute"``, or ``"content"``.
        path (str): XPath-like location of the node; the default ``"."``
            denotes the root element.
    """

    param_name: str
    node_type: str = "element"
    path: str = "."
@dataclass
class XMLTagSchema:
    """Schema describing how a tool is invoked through an XML tag.

    Attributes:
        tag_name (str): Root tag name for the tool.
        mappings (List[XMLNodeMapping]): Ordered parameter mappings for the
            tag; each instance gets its own empty list by default.
        example (str, optional): Example showing tag usage.

    Methods:
        add_mapping: Append a new parameter mapping to the schema.
    """

    tag_name: str
    mappings: List[XMLNodeMapping] = field(default_factory=list)
    example: Optional[str] = None

    def add_mapping(self, param_name: str, node_type: str = "element", path: str = ".") -> None:
        """Append a mapping from an XML node to a function parameter.

        Args:
            param_name: Name of the function parameter.
            node_type: Kind of node ("element", "attribute", or "content").
            path: XPath-like path to the node.
        """
        new_mapping = XMLNodeMapping(
            param_name=param_name,
            node_type=node_type,
            path=path,
        )
        self.mappings.append(new_mapping)
@dataclass
class ToolSchema:
    """A single schema registered for a tool method, tagged with its kind.

    Attributes:
        schema_type (SchemaType): Which kind of schema this is (OpenAPI, XML,
            or Custom).
        schema (Dict[str, Any]): The schema definition itself.
        xml_schema (XMLTagSchema, optional): XML tag description; defaults to
            ``None`` when not applicable.
    """

    schema_type: SchemaType
    schema: Dict[str, Any]
    xml_schema: Optional[XMLTagSchema] = None
@dataclass
class ToolResult:
    """Outcome of a single tool execution.

    Attributes:
        success (bool): ``True`` if the tool ran successfully.
        output (str): Output text on success, or the error description on
            failure.
    """

    success: bool
    output: str
class Tool(ABC):
    """Abstract base class for all tools.

    Provides the foundation for implementing tools with schema registration
    and result handling capabilities. On construction, every bound method that
    carries a ``tool_schemas`` attribute (as attached by this module's schema
    decorators) is registered under its attribute name.

    Attributes:
        _schemas (Dict[str, List[ToolSchema]]): Registered schemas for tool
            methods, keyed by method name.

    Methods:
        get_schemas: Get all registered tool schemas
        success_response: Create a successful result
        fail_response: Create a failed result
    """

    def __init__(self):
        """Initialize the tool and populate its schema registry."""
        self._schemas: Dict[str, List[ToolSchema]] = {}
        self._register_schemas()

    def _register_schemas(self):
        """Register schemas from all decorated methods.

        A bound method forwards attribute lookups to its underlying function,
        so ``method.tool_schemas`` is one list shared by the class and every
        instance. A shallow copy is stored so that mutating this instance's
        registry (e.g. through ``get_schemas()``) cannot alter the decorated
        function or any other instance.
        """
        # NOTE: inspect.getmembers retrieves every attribute via getattr, so
        # any properties defined on a subclass are evaluated here.
        for name, method in inspect.getmembers(self, predicate=inspect.ismethod):
            if hasattr(method, 'tool_schemas'):
                self._schemas[name] = list(method.tool_schemas)

    def get_schemas(self) -> Dict[str, List[ToolSchema]]:
        """Get all registered tool schemas.

        Returns:
            Dict mapping method names to their schema definitions. This is
            the instance's own registry, not a copy.
        """
        return self._schemas

    def success_response(self, data: Union[Dict[str, Any], str]) -> ToolResult:
        """Create a successful tool result.

        Args:
            data: Result data. A string is used verbatim; anything else is
                serialized with ``json.dumps(data, indent=2)``.

        Returns:
            ToolResult with success=True and the formatted output.

        Raises:
            TypeError: If non-string ``data`` is not JSON-serializable.
        """
        text = data if isinstance(data, str) else json.dumps(data, indent=2)
        return ToolResult(success=True, output=text)

    def fail_response(self, msg: str) -> ToolResult:
        """Create a failed tool result.

        Args:
            msg: Error message describing the failure.

        Returns:
            ToolResult with success=False and ``msg`` as its output.
        """
        return ToolResult(success=False, output=msg)
def _add_schema(func, schema: ToolSchema):
    """Append *schema* to *func*'s ``tool_schemas`` list and return *func*.

    The list is created on first use, so stacking several schema decorators
    on the same function accumulates all of their schemas in order.
    """
    try:
        registry = func.tool_schemas
    except AttributeError:
        registry = func.tool_schemas = []
    registry.append(schema)
    return func
def openapi_schema(schema: Dict[str, Any]):
    """Decorator that registers an OpenAPI schema on a tool method.

    Args:
        schema: The OpenAPI schema definition; it is stored as given, not copied.

    Returns:
        A decorator that attaches a new ``SchemaType.OPENAPI`` ``ToolSchema``
        to the decorated function and returns that same function.
    """
    def attach(func):
        tool_schema = ToolSchema(schema_type=SchemaType.OPENAPI, schema=schema)
        return _add_schema(func, tool_schema)

    return attach
def xml_schema(
    tag_name: str,
    mappings: Optional[List[Dict[str, str]]] = None,
    example: Optional[str] = None,
):
    """
    Decorator for XML schema tools with node mapping.

    Args:
        tag_name: Name of the root XML tag
        mappings: List of mapping definitions, each containing:
            - param_name: Name of the function parameter (required)
            - node_type: "element", "attribute", or "content" (default "element")
            - path: Path to the node (default "." for root)
        example: Optional example showing how to use the XML tag

    Returns:
        A decorator that attaches a ``SchemaType.XML`` ``ToolSchema`` (with an
        empty ``schema`` dict and the constructed ``XMLTagSchema``) to the
        decorated function and returns that same function.

    Raises:
        KeyError: When the decorator is applied, if a mapping lacks
            ``"param_name"``.

    Example:
        @xml_schema(
            tag_name="str-replace",
            mappings=[
                {"param_name": "file_path", "node_type": "attribute", "path": "."},
                {"param_name": "old_str", "node_type": "element", "path": "old_str"},
                {"param_name": "new_str", "node_type": "element", "path": "new_str"}
            ],
            example='''
            <str-replace file_path="path/to/file">
                <old_str>text to replace</old_str>
                <new_str>replacement text</new_str>
            </str-replace>
            '''
        )
    """
    def decorator(func):
        # Named tag_schema rather than xml_schema so the local does not shadow
        # this decorator factory's own name.
        tag_schema = XMLTagSchema(tag_name=tag_name, example=example)
        for mapping in mappings or ():
            tag_schema.add_mapping(
                param_name=mapping["param_name"],
                node_type=mapping.get("node_type", "element"),
                path=mapping.get("path", "."),
            )
        return _add_schema(func, ToolSchema(
            schema_type=SchemaType.XML,
            schema={},  # XML tools carry no OpenAPI schema; one could be added here
            xml_schema=tag_schema,
        ))

    return decorator
def custom_schema(schema: Dict[str, Any]):
    """Decorator that registers a custom schema on a tool method.

    Args:
        schema: The custom schema definition; it is stored as given, not copied.

    Returns:
        A decorator that attaches a new ``SchemaType.CUSTOM`` ``ToolSchema``
        to the decorated function and returns that same function.
    """
    def attach(func):
        tool_schema = ToolSchema(schema_type=SchemaType.CUSTOM, schema=schema)
        return _add_schema(func, tool_schema)

    return attach