# suna/backend/agent/tools/agent_execution_tool.py
#
# 397 lines
# 18 KiB
# Python

from agentpress.tool import ToolResult, openapi_schema, usage_example
from agentpress.thread_manager import ThreadManager
from sandbox.tool_base import SandboxToolsBase
from utils.logger import logger
from typing import Optional
import json
import asyncio
from datetime import datetime
class AgentExecutionTool(SandboxToolsBase):
    """
    Tool for executing other agents with prompts or workflows.

    Enables agent-to-agent communication by allowing one agent to call another
    and receive the results. Supports both custom prompt execution and workflow execution.
    """

    def __init__(self, project_id: str, thread_manager: ThreadManager, account_id: str):
        """Initialise the base tool and remember the account the calls run as."""
        super().__init__(project_id, thread_manager)
        self.account_id = account_id

    # ------------------------------------------------------------------
    # Pure helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _stringify(value) -> str:
        """Return *value* as text so downstream string handling cannot crash.

        ``None`` becomes an empty string, strings pass through unchanged and
        anything else is JSON-encoded (falling back to ``str``).
        """
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        try:
            return json.dumps(value, ensure_ascii=False)
        except (TypeError, ValueError):
            return str(value)

    @staticmethod
    def _clamp_int(value, default: int, low: int, high: int) -> int:
        """Coerce *value* to ``int`` and clamp it to ``[low, high]``.

        Values that cannot be converted (``None``, non-numeric strings,
        infinities) fall back to *default* before clamping.
        """
        try:
            number = int(value)
        except (TypeError, ValueError, OverflowError):
            number = default
        return max(low, min(high, number))

    def _extract_last_message(self, full_output: str) -> str:
        """Extract the last meaningful message from agent output.

        Walks the lines backwards to the last non-blank line that does not
        start with ``#`` or a code fence, then returns that line with up to
        three lines of preceding context and everything that follows it.

        Returns:
            ``"No output received"`` for empty or whitespace-only input; the
            extracted tail otherwise. If no line qualifies, roughly the last
            20% of the text is returned when it is longer than 100 characters,
            else the text unchanged.
        """
        if not full_output or not full_output.strip():
            return "No output received"
        lines = full_output.strip().split('\n')
        # Iterate by position: looking the line up by value (list.index) finds
        # the FIRST occurrence, which is wrong when the text repeats earlier.
        for index in range(len(lines) - 1, -1, -1):
            line = lines[index]
            if line.strip() and not line.startswith(('#', '```')):
                start_index = max(0, index - 3)
                return '\n'.join(lines[start_index:]).strip()
        # Fallback: return last 20% of the output
        return full_output[-len(full_output)//5:].strip() if len(full_output) > 100 else full_output

    def _truncate_from_end(self, text: str, max_tokens: int) -> str:
        """Truncate text from the beginning, keeping the end.

        The budget is estimated at four characters per token. When text is
        cut, a marker line stating how many characters were dropped is
        prepended, so the result is slightly longer than the budget.
        """
        if max_tokens <= 0:
            return ""
        max_chars = max_tokens * 4  # Rough token estimation
        if len(text) <= max_chars:
            return text
        truncated = text[-max_chars:]
        return f"...[truncated {len(text) - max_chars} characters]...\n{truncated}"

    def _get_fallback_model(self, requested_model: Optional[str] = None) -> str:
        """Get a reliable model with fallback logic.

        A requested model is accepted when its name mentions one of the known
        providers; otherwise the fixed fallback model is returned.
        """
        if requested_model:
            # Validate the requested model is reasonable
            lowered = requested_model.lower()
            if any(provider in lowered for provider in ('openrouter', 'anthropic', 'openai', 'google')):
                return requested_model
        # Use a reliable free-tier model as fallback
        fallback_model = "openrouter/google/gemini-2.5-flash"
        if requested_model and requested_model != fallback_model:
            logger.info(f"Model {requested_model} not validated, using fallback: {fallback_model}")
        return fallback_model

    def _find_assistant_reply(self, messages) -> Optional[str]:
        """Return the content of the first assistant message in *messages*.

        Each row's ``content`` is parsed as JSON when it is a string. A parsed
        object whose ``role`` is ``assistant`` yields its ``content``. When the
        content cannot be parsed or is not a mapping, the raw content is
        returned if the row's ``type`` is ``assistant``.

        Returns:
            The reply text, or ``None`` when no assistant message is present.
        """
        for msg in messages or []:
            content = msg.get('content')
            if not content:
                continue
            try:
                parsed_content = json.loads(content) if isinstance(content, str) else content
                if parsed_content.get('role') == 'assistant':
                    return self._stringify(parsed_content.get('content', ''))
            except (ValueError, AttributeError, TypeError):
                # Not JSON, or JSON that is not an object: fall back to the
                # row-level type marker.
                if msg.get('type') == 'assistant':
                    return self._stringify(content)
        return None

    # ------------------------------------------------------------------
    # Public tool entry point
    # ------------------------------------------------------------------

    @openapi_schema({
        "type": "function",
        "function": {
            "name": "call_agent",
            "description": "Execute another agent with a custom prompt or workflow. This allows inter-agent communication and delegation of tasks to specialized agents.",
            "parameters": {
                "type": "object",
                "properties": {
                    "agent_id": {
                        "type": "string",
                        "description": "The ID of the agent to call"
                    },
                    "message": {
                        "type": "string",
                        "description": "The message/prompt to send to the agent"
                    },
                    "execution_mode": {
                        "type": "string",
                        "enum": ["prompt", "workflow"],
                        "description": "Either 'prompt' for custom prompt execution or 'workflow' for workflow execution",
                        "default": "prompt"
                    },
                    "workflow_id": {
                        "type": "string",
                        "description": "Required when execution_mode is 'workflow' - the ID of the workflow to run"
                    },
                    "model_name": {
                        "type": "string",
                        "description": "Model to use for the agent execution. If not specified, uses the agent's configured model or fallback."
                    },
                    "timeout": {
                        "type": "integer",
                        "description": "Maximum time to wait for agent response in seconds",
                        "default": 60,
                        "minimum": 10,
                        "maximum": 300
                    },
                    "max_tokens": {
                        "type": "integer",
                        "description": "Maximum tokens in response",
                        "default": 1000,
                        "minimum": 100,
                        "maximum": 4000
                    },
                    "output_mode": {
                        "type": "string",
                        "enum": ["last_message", "full"],
                        "description": "How to format output: 'last_message' (default) extracts key results, 'full' returns complete output",
                        "default": "last_message"
                    }
                },
                "required": ["agent_id", "message"]
            }
        }
    })
    @usage_example('''
<function_calls>
<invoke name="call_agent">
<parameter name="agent_id">data_analyst_agent_123</parameter>
<parameter name="message">Analyze the sales data from Q3 and provide key insights</parameter>
<parameter name="execution_mode">prompt</parameter>
<parameter name="timeout">120</parameter>
</invoke>
</function_calls>
<!-- Example with workflow -->
<function_calls>
<invoke name="call_agent">
<parameter name="agent_id">report_generator_456</parameter>
<parameter name="message">Generate monthly report with latest metrics</parameter>
<parameter name="execution_mode">workflow</parameter>
<parameter name="workflow_id">monthly_report_workflow_789</parameter>
</invoke>
</function_calls>
''')
    async def call_agent(
        self,
        agent_id: str,
        message: str,
        execution_mode: str = "prompt",
        workflow_id: Optional[str] = None,
        model_name: Optional[str] = None,
        timeout: int = 60,
        max_tokens: int = 1000,
        output_mode: str = "last_message"
    ) -> ToolResult:
        """Execute another agent and return the results.

        Args:
            agent_id: ID of the agent to call; must belong to this account.
            message: Prompt sent to the agent.
            execution_mode: ``"prompt"`` or ``"workflow"``.
            workflow_id: Required when ``execution_mode`` is ``"workflow"``.
            model_name: Requested model; replaced by the fallback model when
                it does not name a known provider.
            timeout: Seconds to wait, clamped to 10-300.
            max_tokens: Output budget, clamped to 100-4000.
            output_mode: ``"last_message"`` or ``"full"``; unknown values are
                treated as ``"last_message"``.

        Returns:
            A success response carrying the structured result, or a fail
            response when validation fails or an exception is raised.
        """
        try:
            # Validate execution mode and workflow parameters
            if execution_mode not in ["prompt", "workflow"]:
                return self.fail_response("Error: execution_mode must be either 'prompt' or 'workflow'")
            if execution_mode == "workflow" and not workflow_id:
                return self.fail_response("Error: workflow_id is required when execution_mode is 'workflow'")

            # Verify agent access
            await self._verify_agent_access(agent_id)

            # Apply model fallback logic
            model_name = self._get_fallback_model(model_name)

            # Validate parameters. Values are coerced first so that numbers
            # supplied as strings do not abort the call with a TypeError.
            timeout = self._clamp_int(timeout, 60, 10, 300)
            max_tokens = self._clamp_int(max_tokens, 1000, 100, 4000)

            # Validate output mode
            if output_mode not in ["last_message", "full"]:
                output_mode = "last_message"

            logger.info(f"Calling agent {agent_id} in {execution_mode} mode with timeout {timeout}s")

            if execution_mode == "workflow":
                raw_result = await self._execute_agent_workflow(agent_id, workflow_id, message, model_name, timeout)
            else:
                raw_result = await self._execute_agent_prompt(agent_id, message, model_name, timeout)

            # NOTE(review): the execution helpers report their own failures as
            # "Error ..." strings, which end up inside a *success* response
            # below. Confirm whether callers rely on that before changing it.
            raw_result = self._stringify(raw_result)

            # Process the output based on the requested mode
            if output_mode == "last_message":
                processed_result = self._extract_last_message(raw_result)
            else:
                processed_result = raw_result

            # Apply token limiting
            final_result = self._truncate_from_end(processed_result, max_tokens)

            # Return structured result
            response_data = {
                "agent_id": agent_id,
                "execution_mode": execution_mode,
                "workflow_id": workflow_id if execution_mode == "workflow" else None,
                "model_used": model_name,
                "output_mode": output_mode,
                "max_tokens": max_tokens,
                "result": final_result,
                "timestamp": datetime.now().isoformat()
            }
            logger.info(f"Agent call completed for agent {agent_id} in {execution_mode} mode")
            return self.success_response(response_data)
        except Exception as e:
            logger.error(f"Error calling agent {agent_id}: {str(e)}")
            return self.fail_response(f"Error calling agent: {str(e)}")

    # ------------------------------------------------------------------
    # Execution strategies
    # ------------------------------------------------------------------

    async def _execute_agent_workflow(self, agent_id: str, workflow_id: str, message: str, model_name: str, timeout: int) -> str:
        """Execute an agent workflow.

        Looks the workflow up for the given agent, requires its status to be
        ``active`` and runs it through ``triggers.execution_service``. When
        that module cannot be imported, the agent is run on a fresh thread
        with the workflow as context instead.

        Returns:
            The workflow output, or a human-readable error string; this
            method does not raise.
        """
        try:
            client = await self.thread_manager.db.client

            # Verify workflow exists and is active
            workflow_result = await client.table('agent_workflows').select('*').eq('id', workflow_id).eq('agent_id', agent_id).execute()
            if not workflow_result.data:
                return f"Error: Workflow {workflow_id} not found for agent {agent_id}"
            workflow = workflow_result.data[0]
            # Fall back to the id so a row without a name cannot raise KeyError.
            workflow_name = workflow.get('name') or workflow_id
            if workflow.get('status') != 'active':
                return f"Error: Workflow {workflow_name} is not active (status: {workflow.get('status')})"

            # Execute workflow through the execution service
            try:
                from triggers.execution_service import execute_workflow

                # NOTE(review): `timeout` is not applied on this path; only the
                # thread-based fallback below honours it.
                execution_result = await execute_workflow(
                    workflow_id=workflow_id,
                    agent_id=agent_id,
                    input_data={"message": message},
                    user_id=self.account_id
                )
                if execution_result.get('success'):
                    output = execution_result.get('output', f"Workflow '{workflow_name}' executed successfully")
                    return self._stringify(output)
                return f"Workflow execution failed: {execution_result.get('error', 'Unknown error')}"
            except ImportError:
                logger.warning("Execution service not available, using fallback workflow execution")
                # Fallback: Create a thread and run the agent with workflow context
                return await self._execute_agent_with_thread(agent_id, message, model_name, timeout, workflow)
        except Exception as e:
            logger.error(f"Error executing workflow {workflow_id}: {str(e)}")
            return f"Error executing workflow: {str(e)}"

    async def _execute_agent_prompt(self, agent_id: str, message: str, model_name: str, timeout: int) -> str:
        """Execute an agent with a custom prompt on a fresh thread.

        Returns the agent's reply, or an error string if execution raises.
        """
        try:
            return await self._execute_agent_with_thread(agent_id, message, model_name, timeout)
        except Exception as e:
            logger.error(f"Error executing agent prompt: {str(e)}")
            return f"Error executing agent: {str(e)}"

    async def _execute_agent_with_thread(self, agent_id: str, message: str, model_name: str, timeout: int, workflow: Optional[dict] = None) -> str:
        """Execute agent using thread-based approach.

        Creates a thread, posts the message (prefixed with workflow context
        when *workflow* is given), starts the agent and polls for its reply.

        Returns:
            The polled result, or an error string if any step raises.
        """
        try:
            # Import existing agent execution functions
            from agent.api import create_thread, add_message_to_thread, start_agent, AgentStartRequest

            workflow_name = workflow.get('name', 'unnamed') if workflow else None

            # Create thread name based on execution type
            thread_name = f"Workflow: {workflow_name}" if workflow else "Agent-to-Agent Call"

            # Create a new thread
            thread_response = await create_thread(name=thread_name, user_id=self.account_id)
            if isinstance(thread_response, dict):
                thread_id = thread_response.get('thread_id')
            else:
                thread_id = thread_response.thread_id

            # Prepare message with workflow context if needed
            final_message = message
            if workflow:
                workflow_context = f"Executing workflow '{workflow_name}'"
                if workflow.get('description'):
                    workflow_context += f": {workflow['description']}"
                final_message = f"{workflow_context}\n\nUser message: {message}"

            # Add the message to the thread
            await add_message_to_thread(
                thread_id=thread_id,
                message=final_message,
                user_id=self.account_id
            )

            # Start the agent
            agent_request = AgentStartRequest(
                agent_id=agent_id,
                enable_thinking=False,
                stream=False,
                model_name=model_name
            )
            await start_agent(
                thread_id=thread_id,
                body=agent_request,
                user_id=self.account_id
            )

            # Wait for agent completion and get response
            return await self._poll_for_completion(thread_id, timeout)
        except Exception as e:
            logger.error(f"Error executing agent with thread: {str(e)}")
            return f"Error executing agent: {str(e)}"

    async def _poll_for_completion(self, thread_id: str, timeout: int) -> str:
        """Poll for agent completion and return the result.

        Every two seconds the newest messages of the thread are checked for an
        assistant reply, then the newest ``agent_runs`` row is checked for a
        final status. The wait is bounded by the event-loop clock, so time
        spent in the database queries counts towards *timeout*.

        Returns:
            The assistant reply if one is found; otherwise a message
            describing the failure, cancellation, timeout or missing
            response.
        """
        client = await self.thread_manager.db.client
        poll_interval = 2
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        run_completed = False

        while True:
            # NOTE(review): this returns as soon as ANY assistant message is
            # stored, which may be an intermediate turn while the run is still
            # in progress - confirm that is the intended contract.
            messages_result = await client.table('messages').select('*').eq('thread_id', thread_id).order('created_at', desc=True).limit(5).execute()
            reply = self._find_assistant_reply(messages_result.data)
            if reply is not None:
                return reply

            # Check if agent run is complete by checking agent_runs table
            runs_result = await client.table('agent_runs').select('status, error').eq('thread_id', thread_id).order('created_at', desc=True).limit(1).execute()
            if runs_result.data:
                run = runs_result.data[0]
                status = run.get('status')
                if status == 'failed':
                    # `or` also covers a NULL error column, not just a missing key.
                    return f"Agent execution failed: {run.get('error') or 'Unknown error'}"
                if status == 'cancelled':
                    return "Agent execution was cancelled"
                if status == 'completed':
                    # Completed: do one final, wider message lookup below.
                    run_completed = True
                    break

            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            await asyncio.sleep(min(poll_interval, remaining))

        # Final lookup (after timeout or completion) over a wider window
        messages_result = await client.table('messages').select('*').eq('thread_id', thread_id).order('created_at', desc=True).limit(10).execute()
        reply = self._find_assistant_reply(messages_result.data)
        if reply is not None:
            return reply
        # Only report a timeout when the run did not actually finish.
        if messages_result.data and not run_completed:
            return f"Agent execution timed out after {timeout}s. Thread ID: {thread_id}"
        return f"No response received from agent. Thread ID: {thread_id}"

    async def _verify_agent_access(self, agent_id: str):
        """Verify account has access to the agent.

        Raises:
            ValueError: "Agent not found or access denied" when no matching
                row exists for this account, or "Database connection error"
                (chained to the original exception) when the lookup fails.
        """
        try:
            client = await self.thread_manager.db.client
            result = await client.table('agents').select('agent_id').eq('agent_id', agent_id).eq('account_id', self.account_id).execute()
        except Exception as e:
            logger.error(f"Database error in verify_agent_access: {str(e)}")
            raise ValueError("Database connection error") from e
        if not result.data:
            raise ValueError("Agent not found or access denied")