2024-11-18 06:36:37 +08:00
|
|
|
from typing import List, Dict, Any, Set, Callable, Optional
|
|
|
|
import asyncio
|
|
|
|
import json
|
|
|
|
import logging
|
2024-11-18 11:21:08 +08:00
|
|
|
from agentpress.base_processors import ToolExecutorBase
|
2024-11-18 06:36:37 +08:00
|
|
|
from agentpress.tool import ToolResult
|
2024-11-18 08:38:31 +08:00
|
|
|
from agentpress.tool_registry import ToolRegistry
|
2024-11-18 06:36:37 +08:00
|
|
|
|
|
|
|
class XMLToolExecutor(ToolExecutorBase):
    """Execute tool calls by resolving them through a ``ToolRegistry``.

    Each tool call is a dict with an ``'id'`` and a ``'function'`` entry that
    holds ``'name'`` and ``'arguments'`` (either a JSON string or an
    already-decoded mapping of keyword arguments). Every executed call yields a
    message dict ``{"role": "tool", "tool_call_id", "name", "content"}`` whose
    ``content`` is the stringified tool result, or a stringified failed
    ``ToolResult`` when the call could not be completed.
    """

    def __init__(self, parallel: bool = True, tool_registry: Optional[ToolRegistry] = None):
        """Create an executor.

        Args:
            parallel: If True, run tool calls concurrently via ``asyncio.gather``;
                otherwise run them one after another in input order.
            tool_registry: Registry used to look tools up by function name. A new
                ``ToolRegistry()`` is created when omitted.
        """
        self.parallel = parallel
        # Compare against None rather than using ``or`` so that an explicitly
        # supplied registry is kept even if it happens to evaluate as falsy.
        self.tool_registry = tool_registry if tool_registry is not None else ToolRegistry()

    async def execute_tool_calls(
        self,
        tool_calls: List[Dict[str, Any]],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Optional[Set[str]] = None
    ) -> List[Dict[str, Any]]:
        """Execute ``tool_calls`` and return one tool message per executed call.

        Args:
            tool_calls: Tool-call dicts (see the class docstring for their shape).
            available_functions: Not consulted by this implementation; tools are
                resolved through ``self.tool_registry``.
            thread_id: Not consulted by this implementation.
            executed_tool_calls: IDs of calls that have already run. Calls whose
                ID is in this set are skipped, and the IDs of calls that complete
                successfully are added to it (the set is mutated in place). A
                fresh empty set is used when omitted.

        Returns:
            Tool-role message dicts for the calls that were not skipped, in the
            same order as ``tool_calls``.

        Raises:
            KeyError: if a tool call has no ``'id'`` entry.
        """
        logging.info(f"Executing {len(tool_calls)} tool calls")

        if executed_tool_calls is None:
            executed_tool_calls = set()

        run = self._execute_parallel if self.parallel else self._execute_sequential
        return await run(tool_calls, available_functions, thread_id, executed_tool_calls)

    async def _execute_parallel(
        self,
        tool_calls: List[Dict[str, Any]],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Set[str]
    ) -> List[Dict[str, Any]]:
        """Run all not-yet-executed tool calls concurrently, preserving input order."""

        async def execute_single_tool(tool_call: Dict[str, Any]) -> Optional[Dict[str, Any]]:
            # The skip check runs inside the coroutine (not while building the task
            # list) so it observes IDs recorded by calls that have already finished.
            if tool_call['id'] in executed_tool_calls:
                logging.info(f"Tool call {tool_call['id']} already executed")
                return None
            return await self._run_tool_call(tool_call, executed_tool_calls)

        # NOTE(review): an ID is only recorded after its tool finishes, so two calls
        # sharing an ID within one batch can both run here, whereas sequential mode
        # skips the second one once the first has succeeded.
        results = await asyncio.gather(*(execute_single_tool(tc) for tc in tool_calls))
        # gather returns results in input order; drop the skipped (None) entries.
        return [r for r in results if r is not None]

    async def _execute_sequential(
        self,
        tool_calls: List[Dict[str, Any]],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Set[str]
    ) -> List[Dict[str, Any]]:
        """Run not-yet-executed tool calls one at a time, in input order."""
        results = []
        for tool_call in tool_calls:
            if tool_call['id'] in executed_tool_calls:
                logging.info(f"Tool call {tool_call['id']} already executed")
                continue
            results.append(await self._run_tool_call(tool_call, executed_tool_calls))
        return results

    async def _run_tool_call(
        self,
        tool_call: Dict[str, Any],
        executed_tool_calls: Set[str]
    ) -> Dict[str, Any]:
        """Run one tool call and return its tool message.

        Failures (malformed payload, invalid JSON arguments, tool not in the
        registry, method missing on the tool instance, or an exception raised by
        the tool itself) are reported as a failed ``ToolResult`` in the returned
        message instead of being raised. On success the call's ID is added to
        ``executed_tool_calls``.
        """
        tool_call_id = tool_call['id']
        # Bound before the ``try`` so the error handler can always reference it,
        # even when the payload is malformed and the name lookup itself raises.
        function_name = "unknown"
        try:
            function_name = tool_call['function']['name']
            function_args = tool_call['function']['arguments']
            logging.info(f"Executing tool: {function_name} with args: {function_args}")

            if isinstance(function_args, str):
                function_args = json.loads(function_args)

            # Get tool info from registry
            tool_info = self.tool_registry.get_tool(function_name)
            if not tool_info:
                return self._tool_error_message(
                    tool_call_id,
                    function_name,
                    f"Function {function_name} not found in registry",
                )

            # Look the method up on the tool instance. The ``None`` default keeps a
            # missing attribute on the "not found" path instead of raising.
            function_to_call = getattr(tool_info['instance'], function_name, None)
            if not callable(function_to_call):
                return self._tool_error_message(
                    tool_call_id,
                    function_name,
                    f"Function {function_name} not found on tool instance",
                )

            logging.info(f"Calling function {function_name} with args: {function_args}")
            result = await function_to_call(**function_args)
            executed_tool_calls.add(tool_call_id)
            logging.info(f"Function {function_name} completed with result: {result}")
            return self._tool_message(tool_call_id, function_name, str(result))
        except Exception as e:
            # Boundary handler: any failure becomes a failed ToolResult for the model
            # instead of aborting the remaining tool calls.
            error_msg = f"Error executing {function_name}: {str(e)}"
            logging.error(error_msg, exc_info=True)
            return self._tool_message(
                tool_call_id,
                function_name,
                str(ToolResult(success=False, output=error_msg)),
            )

    @staticmethod
    def _tool_message(tool_call_id: str, name: str, content: str) -> Dict[str, Any]:
        """Build the tool-role message dict returned for one tool call."""
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": name,
            "content": content,
        }

    @classmethod
    def _tool_error_message(cls, tool_call_id: str, name: str, error_msg: str) -> Dict[str, Any]:
        """Log ``error_msg`` and wrap it in a failed ``ToolResult`` tool message."""
        logging.error(error_msg)
        return cls._tool_message(
            tool_call_id,
            name,
            str(ToolResult(success=False, output=error_msg)),
        )
|