suna/agentpress/tool_registry.py

42 lines
1.5 KiB
Python
Raw Normal View History

2024-10-23 09:28:12 +08:00
from typing import Dict, Type, Any, List, Optional
2024-10-10 22:21:39 +08:00
from agentpress.tool import Tool
from agentpress.config import settings
2024-10-06 01:04:15 +08:00
import importlib.util
import os
import inspect
class ToolRegistry:
    """Registry that maps function names to a tool instance and its schema.

    Each entry in ``self.tools`` has the form
    ``{"instance": <tool object>, "schema": <schema for that function>}``.
    """

    def __init__(self):
        """Create an empty registry."""
        # function name -> {"instance": ..., "schema": ...}
        self.tools: Dict[str, Dict[str, Any]] = {}

    def register_tool(self, tool_cls: Type["Tool"], function_names: Optional[List[str]] = None):
        """Instantiate ``tool_cls`` and register its functions.

        The class is called with no arguments and ``get_schemas()`` is invoked
        on the resulting instance; the returned mapping (function name ->
        schema) decides what can be registered. A single instance is shared
        by every function registered in this call. Existing entries with the
        same function name are overwritten.

        Args:
            tool_cls: Tool class to instantiate.
            function_names: Names of the functions to register. ``None``
                (the default) registers every function the tool exposes.

        Raises:
            ValueError: If a requested name is not present in the tool's
                schemas. The registry is left untouched in that case.
        """
        tool_instance = tool_cls()
        schemas = tool_instance.get_schemas()

        if function_names is None:
            # Register all functions
            selected = list(schemas)
        else:
            # Register only specified functions. Validate every name before
            # mutating the registry so a bad name cannot leave a partial
            # registration behind.
            selected = list(function_names)
            for func_name in selected:
                if func_name not in schemas:
                    raise ValueError(f"Function '{func_name}' not found in {tool_cls.__name__}")

        for func_name in selected:
            self.tools[func_name] = {
                "instance": tool_instance,
                "schema": schemas[func_name],
            }

    def get_tool(self, tool_name: str) -> Dict[str, Any]:
        """Return the entry registered under ``tool_name``, or ``{}`` if none."""
        return self.tools.get(tool_name, {})

    def get_all_tools(self) -> Dict[str, Dict[str, Any]]:
        """Return the registry mapping itself (not a copy)."""
        return self.tools

    def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
        """Return the schema of every registered function, in insertion order."""
        return [tool_info['schema'] for tool_info in self.tools.values()]