mirror of https://github.com/kortix-ai/suna.git
BIG FEATURE: agent calls another agent
This commit is contained in:
parent
cd3ebb5e78
commit
389883f185
|
@ -213,7 +213,9 @@ def _get_default_agentpress_tools() -> Dict[str, bool]:
|
|||
"mcp_search_tool": True,
|
||||
"credential_profile_tool": True,
|
||||
"workflow_tool": True,
|
||||
"trigger_tool": True
|
||||
"trigger_tool": True,
|
||||
"agent_discovery_tool": True,
|
||||
"agent_execution_tool": True
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -37,6 +37,8 @@ from agentpress.tool import SchemaType
|
|||
from agent.tools.sb_sheets_tool import SandboxSheetsTool
|
||||
from agent.tools.sb_web_dev_tool import SandboxWebDevTool
|
||||
from agent.tools.sb_upload_file_tool import SandboxUploadFileTool
|
||||
from agent.tools.agent_discovery_tool import AgentDiscoveryTool
|
||||
from agent.tools.agent_execution_tool import AgentExecutionTool
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
@ -59,10 +61,11 @@ class AgentConfig:
|
|||
|
||||
|
||||
class ToolManager:
|
||||
def __init__(self, thread_manager: ThreadManager, project_id: str, thread_id: str):
|
||||
def __init__(self, thread_manager: ThreadManager, project_id: str, thread_id: str, account_id: Optional[str] = None):
|
||||
self.thread_manager = thread_manager
|
||||
self.project_id = project_id
|
||||
self.thread_id = thread_id
|
||||
self.account_id = account_id
|
||||
|
||||
def register_all_tools(self, agent_id: Optional[str] = None, disabled_tools: Optional[List[str]] = None):
|
||||
"""Register all available tools by default, with optional exclusions.
|
||||
|
@ -88,6 +91,9 @@ class ToolManager:
|
|||
if agent_id:
|
||||
self._register_agent_builder_tools(agent_id, disabled_tools)
|
||||
|
||||
# Always register agent communication tools
|
||||
self._register_agent_communication_tools(disabled_tools)
|
||||
|
||||
# Browser tool
|
||||
self._register_browser_tool(disabled_tools)
|
||||
|
||||
|
@ -152,6 +158,27 @@ class ToolManager:
|
|||
self.thread_manager.add_tool(tool_class, thread_manager=self.thread_manager, db_connection=db, agent_id=agent_id)
|
||||
logger.debug(f"Registered {tool_name}")
|
||||
|
||||
def _register_agent_communication_tools(self, disabled_tools: List[str]):
|
||||
"""Register agent-to-agent communication tools."""
|
||||
if not self.account_id:
|
||||
logger.warning("Agent communication tools disabled: account_id not available")
|
||||
return
|
||||
|
||||
agent_comm_tools = [
|
||||
('agent_discovery_tool', AgentDiscoveryTool),
|
||||
('agent_execution_tool', AgentExecutionTool),
|
||||
]
|
||||
|
||||
for tool_name, tool_class in agent_comm_tools:
|
||||
if tool_name not in disabled_tools:
|
||||
self.thread_manager.add_tool(
|
||||
tool_class,
|
||||
project_id=self.project_id,
|
||||
thread_manager=self.thread_manager,
|
||||
account_id=self.account_id
|
||||
)
|
||||
logger.debug(f"Registered {tool_name}")
|
||||
|
||||
def _register_browser_tool(self, disabled_tools: List[str]):
|
||||
"""Register browser tool."""
|
||||
if 'browser_tool' not in disabled_tools:
|
||||
|
@ -471,7 +498,7 @@ class AgentRunner:
|
|||
logger.debug(f"No sandbox found for project {self.config.project_id}; will create lazily when needed")
|
||||
|
||||
async def setup_tools(self):
|
||||
tool_manager = ToolManager(self.thread_manager, self.config.project_id, self.config.thread_id)
|
||||
tool_manager = ToolManager(self.thread_manager, self.config.project_id, self.config.thread_id, self.account_id)
|
||||
|
||||
# Determine agent ID for agent builder tools
|
||||
agent_id = None
|
||||
|
|
|
@ -29,7 +29,9 @@ SUNA_CONFIG = {
|
|||
"mcp_search_tool": True,
|
||||
"credential_profile_tool": True,
|
||||
"workflow_tool": True,
|
||||
"trigger_tool": True
|
||||
"trigger_tool": True,
|
||||
"agent_discovery_tool": True,
|
||||
"agent_execution_tool": True
|
||||
},
|
||||
"is_default": True
|
||||
}
|
||||
|
|
|
@ -0,0 +1,197 @@
|
|||
from agentpress.tool import ToolResult, openapi_schema, usage_example
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from sandbox.tool_base import SandboxToolsBase
|
||||
from utils.logger import logger
|
||||
from typing import Optional
|
||||
import json
|
||||
|
||||
class AgentDiscoveryTool(SandboxToolsBase):
|
||||
"""
|
||||
Tool for discovering and listing available agents and their workflows.
|
||||
|
||||
Allows agents to discover other agents in the same account and view their capabilities.
|
||||
This enables agent-to-agent communication and coordination within the platform.
|
||||
"""
|
||||
|
||||
def __init__(self, project_id: str, thread_manager: ThreadManager, account_id: str):
|
||||
super().__init__(project_id, thread_manager)
|
||||
self.account_id = account_id
|
||||
|
||||
@openapi_schema({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "list_available_agents",
|
||||
"description": "List all available agents in the current account that can be called by this agent. Returns agent IDs, names, descriptions, and basic information.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"include_self": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to include the current agent in the results",
|
||||
"default": False
|
||||
},
|
||||
"search": {
|
||||
"type": "string",
|
||||
"description": "Optional search term to filter agents by name or description"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
})
|
||||
@usage_example('''
|
||||
<function_calls>
|
||||
<invoke name="list_available_agents">
|
||||
<parameter name="include_self">false</parameter>
|
||||
<parameter name="search">data analysis</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
''')
|
||||
async def list_available_agents(self, include_self: bool = False, search: Optional[str] = None) -> ToolResult:
|
||||
"""List all available agents in the account."""
|
||||
try:
|
||||
# Import the get_agents function from agent.api
|
||||
try:
|
||||
from agent.api import get_agents
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import get_agents: {str(e)}")
|
||||
return self.fail_response("Agent discovery service is not available")
|
||||
|
||||
# Call the existing endpoint to get agents
|
||||
response = await get_agents(
|
||||
user_id=self.account_id,
|
||||
page=1,
|
||||
limit=100, # Get all agents
|
||||
search=search,
|
||||
sort_by="created_at",
|
||||
sort_order="desc",
|
||||
has_default=None,
|
||||
has_mcp_tools=None,
|
||||
has_agentpress_tools=None,
|
||||
tools=None
|
||||
)
|
||||
|
||||
# Handle both dict and object response formats
|
||||
if hasattr(response, 'agents'):
|
||||
agents = response.agents
|
||||
elif isinstance(response, dict) and 'agents' in response:
|
||||
agents = response['agents']
|
||||
else:
|
||||
logger.error(f"Unexpected response format from get_agents: {type(response)}")
|
||||
return self.fail_response(f"Error: Unexpected response format from get_agents")
|
||||
|
||||
if not agents:
|
||||
return self.success_response("No agents found in your account.")
|
||||
|
||||
# Format agents for easy consumption
|
||||
agent_list = []
|
||||
for agent in agents:
|
||||
# Handle both dict and object formats for individual agents
|
||||
agent_id = agent.agent_id if hasattr(agent, 'agent_id') else agent.get('agent_id')
|
||||
name = agent.name if hasattr(agent, 'name') else agent.get('name')
|
||||
description = agent.description if hasattr(agent, 'description') else agent.get('description')
|
||||
|
||||
# Skip self if not requested
|
||||
if not include_self and agent_id == getattr(self.thread_manager, 'current_agent_id', None):
|
||||
continue
|
||||
|
||||
agent_info = {
|
||||
"agent_id": agent_id,
|
||||
"name": name,
|
||||
"description": description or "No description available"
|
||||
}
|
||||
agent_list.append(agent_info)
|
||||
|
||||
result = {
|
||||
"total_agents": len(agent_list),
|
||||
"agents": agent_list,
|
||||
"note": "Use 'call_agent' tool with the agent_id to invoke any of these agents"
|
||||
}
|
||||
|
||||
logger.info(f"Listed {len(agent_list)} agents via native tool")
|
||||
return self.success_response(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in list_available_agents: {str(e)}")
|
||||
return self.fail_response(f"Error listing agents: {str(e)}")
|
||||
|
||||
@openapi_schema({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_agent_workflows",
|
||||
"description": "Get all available workflows for a specific agent. Workflows are pre-configured execution paths with specific parameters.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "The ID of the agent to get workflows for"
|
||||
}
|
||||
},
|
||||
"required": ["agent_id"]
|
||||
}
|
||||
}
|
||||
})
|
||||
@usage_example('''
|
||||
<function_calls>
|
||||
<invoke name="get_agent_workflows">
|
||||
<parameter name="agent_id">agent_12345</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
''')
|
||||
async def get_agent_workflows(self, agent_id: str) -> ToolResult:
|
||||
"""Get workflows for a specific agent."""
|
||||
try:
|
||||
# Verify agent access
|
||||
await self._verify_agent_access(agent_id)
|
||||
|
||||
# Get workflows from database
|
||||
client = await self.thread_manager.db.client
|
||||
result = await client.table('agent_workflows').select('*').eq('agent_id', agent_id).order('created_at', desc=True).execute()
|
||||
|
||||
if not result.data:
|
||||
return self.success_response({
|
||||
"agent_id": agent_id,
|
||||
"workflows": [],
|
||||
"message": f"No workflows found for agent {agent_id}. This agent can only be run with custom prompts."
|
||||
})
|
||||
|
||||
# Format workflows for consumption
|
||||
workflows = []
|
||||
for workflow in result.data:
|
||||
workflow_info = {
|
||||
"workflow_id": workflow['id'],
|
||||
"name": workflow['name'],
|
||||
"description": workflow.get('description', 'No description available'),
|
||||
"status": workflow.get('status', 'unknown')
|
||||
}
|
||||
workflows.append(workflow_info)
|
||||
|
||||
result_data = {
|
||||
"agent_id": agent_id,
|
||||
"total_workflows": len(workflows),
|
||||
"workflows": workflows,
|
||||
"note": "Use 'call_agent' tool with execution_mode='workflow' and workflow_id to run a specific workflow"
|
||||
}
|
||||
|
||||
logger.info(f"Listed {len(workflows)} workflows for agent {agent_id} via native tool")
|
||||
return self.success_response(result_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_agent_workflows: {str(e)}")
|
||||
return self.fail_response(f"Error listing workflows: {str(e)}")
|
||||
|
||||
async def _verify_agent_access(self, agent_id: str):
|
||||
"""Verify account has access to the agent."""
|
||||
try:
|
||||
client = await self.thread_manager.db.client
|
||||
result = await client.table('agents').select('agent_id').eq('agent_id', agent_id).eq('account_id', self.account_id).execute()
|
||||
|
||||
if not result.data:
|
||||
raise ValueError("Agent not found or access denied")
|
||||
except ValueError:
|
||||
# Re-raise ValueError for proper error messages
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Database error in verify_agent_access: {str(e)}")
|
||||
raise ValueError("Database connection error")
|
|
@ -0,0 +1,397 @@
|
|||
from agentpress.tool import ToolResult, openapi_schema, usage_example
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from sandbox.tool_base import SandboxToolsBase
|
||||
from utils.logger import logger
|
||||
from typing import Optional
|
||||
import json
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
class AgentExecutionTool(SandboxToolsBase):
|
||||
"""
|
||||
Tool for executing other agents with prompts or workflows.
|
||||
|
||||
Enables agent-to-agent communication by allowing one agent to call another
|
||||
and receive the results. Supports both custom prompt execution and workflow execution.
|
||||
"""
|
||||
|
||||
def __init__(self, project_id: str, thread_manager: ThreadManager, account_id: str):
|
||||
super().__init__(project_id, thread_manager)
|
||||
self.account_id = account_id
|
||||
|
||||
def _extract_last_message(self, full_output: str) -> str:
|
||||
"""Extract the last meaningful message from agent output."""
|
||||
if not full_output.strip():
|
||||
return "No output received"
|
||||
|
||||
lines = full_output.strip().split('\n')
|
||||
|
||||
# Look for the last substantial message
|
||||
for line in reversed(lines):
|
||||
if line.strip() and not line.startswith('#') and not line.startswith('```'):
|
||||
try:
|
||||
line_index = lines.index(line)
|
||||
start_index = max(0, line_index - 3)
|
||||
return '\n'.join(lines[start_index:]).strip()
|
||||
except ValueError:
|
||||
return line.strip()
|
||||
|
||||
# Fallback: return last 20% of the output
|
||||
return full_output[-len(full_output)//5:].strip() if len(full_output) > 100 else full_output
|
||||
|
||||
def _truncate_from_end(self, text: str, max_tokens: int) -> str:
|
||||
"""Truncate text from the beginning, keeping the end."""
|
||||
if max_tokens <= 0:
|
||||
return ""
|
||||
|
||||
max_chars = max_tokens * 4 # Rough token estimation
|
||||
|
||||
if len(text) <= max_chars:
|
||||
return text
|
||||
|
||||
truncated = text[-max_chars:]
|
||||
return f"...[truncated {len(text) - max_chars} characters]...\n{truncated}"
|
||||
|
||||
def _get_fallback_model(self, requested_model: Optional[str] = None) -> str:
|
||||
"""Get a reliable model with fallback logic."""
|
||||
if requested_model:
|
||||
# Validate the requested model is reasonable
|
||||
if any(provider in requested_model.lower() for provider in ['openrouter', 'anthropic', 'openai', 'google']):
|
||||
return requested_model
|
||||
|
||||
# Use a reliable free-tier model as fallback
|
||||
fallback_model = "openrouter/google/gemini-2.5-flash"
|
||||
if requested_model and requested_model != fallback_model:
|
||||
logger.info(f"Model {requested_model} not validated, using fallback: {fallback_model}")
|
||||
|
||||
return fallback_model
|
||||
|
||||
@openapi_schema({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "call_agent",
|
||||
"description": "Execute another agent with a custom prompt or workflow. This allows inter-agent communication and delegation of tasks to specialized agents.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "The ID of the agent to call"
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The message/prompt to send to the agent"
|
||||
},
|
||||
"execution_mode": {
|
||||
"type": "string",
|
||||
"enum": ["prompt", "workflow"],
|
||||
"description": "Either 'prompt' for custom prompt execution or 'workflow' for workflow execution",
|
||||
"default": "prompt"
|
||||
},
|
||||
"workflow_id": {
|
||||
"type": "string",
|
||||
"description": "Required when execution_mode is 'workflow' - the ID of the workflow to run"
|
||||
},
|
||||
"model_name": {
|
||||
"type": "string",
|
||||
"description": "Model to use for the agent execution. If not specified, uses the agent's configured model or fallback."
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Maximum time to wait for agent response in seconds",
|
||||
"default": 60,
|
||||
"minimum": 10,
|
||||
"maximum": 300
|
||||
},
|
||||
"max_tokens": {
|
||||
"type": "integer",
|
||||
"description": "Maximum tokens in response",
|
||||
"default": 1000,
|
||||
"minimum": 100,
|
||||
"maximum": 4000
|
||||
},
|
||||
"output_mode": {
|
||||
"type": "string",
|
||||
"enum": ["last_message", "full"],
|
||||
"description": "How to format output: 'last_message' (default) extracts key results, 'full' returns complete output",
|
||||
"default": "last_message"
|
||||
}
|
||||
},
|
||||
"required": ["agent_id", "message"]
|
||||
}
|
||||
}
|
||||
})
|
||||
@usage_example('''
|
||||
<function_calls>
|
||||
<invoke name="call_agent">
|
||||
<parameter name="agent_id">data_analyst_agent_123</parameter>
|
||||
<parameter name="message">Analyze the sales data from Q3 and provide key insights</parameter>
|
||||
<parameter name="execution_mode">prompt</parameter>
|
||||
<parameter name="timeout">120</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
<!-- Example with workflow -->
|
||||
<function_calls>
|
||||
<invoke name="call_agent">
|
||||
<parameter name="agent_id">report_generator_456</parameter>
|
||||
<parameter name="message">Generate monthly report with latest metrics</parameter>
|
||||
<parameter name="execution_mode">workflow</parameter>
|
||||
<parameter name="workflow_id">monthly_report_workflow_789</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
''')
|
||||
async def call_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
message: str,
|
||||
execution_mode: str = "prompt",
|
||||
workflow_id: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
timeout: int = 60,
|
||||
max_tokens: int = 1000,
|
||||
output_mode: str = "last_message"
|
||||
) -> ToolResult:
|
||||
"""Execute another agent and return the results."""
|
||||
try:
|
||||
# Validate execution mode and workflow parameters
|
||||
if execution_mode not in ["prompt", "workflow"]:
|
||||
return self.fail_response("Error: execution_mode must be either 'prompt' or 'workflow'")
|
||||
|
||||
if execution_mode == "workflow" and not workflow_id:
|
||||
return self.fail_response("Error: workflow_id is required when execution_mode is 'workflow'")
|
||||
|
||||
# Verify agent access
|
||||
await self._verify_agent_access(agent_id)
|
||||
|
||||
# Apply model fallback logic
|
||||
model_name = self._get_fallback_model(model_name)
|
||||
|
||||
# Validate parameters
|
||||
timeout = max(10, min(300, timeout)) # Clamp between 10 and 300 seconds
|
||||
max_tokens = max(100, min(4000, max_tokens)) # Clamp between 100 and 4000 tokens
|
||||
|
||||
# Validate output mode
|
||||
if output_mode not in ["last_message", "full"]:
|
||||
output_mode = "last_message"
|
||||
|
||||
logger.info(f"Calling agent {agent_id} in {execution_mode} mode with timeout {timeout}s")
|
||||
|
||||
if execution_mode == "workflow":
|
||||
# Execute workflow
|
||||
raw_result = await self._execute_agent_workflow(agent_id, workflow_id, message, model_name, timeout)
|
||||
else:
|
||||
# Execute agent with prompt
|
||||
raw_result = await self._execute_agent_prompt(agent_id, message, model_name, timeout)
|
||||
|
||||
# Process the output based on the requested mode
|
||||
if output_mode == "last_message":
|
||||
processed_result = self._extract_last_message(raw_result)
|
||||
else:
|
||||
processed_result = raw_result
|
||||
|
||||
# Apply token limiting
|
||||
final_result = self._truncate_from_end(processed_result, max_tokens)
|
||||
|
||||
# Return structured result
|
||||
response_data = {
|
||||
"agent_id": agent_id,
|
||||
"execution_mode": execution_mode,
|
||||
"workflow_id": workflow_id if execution_mode == "workflow" else None,
|
||||
"model_used": model_name,
|
||||
"output_mode": output_mode,
|
||||
"max_tokens": max_tokens,
|
||||
"result": final_result,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
logger.info(f"Agent call completed for agent {agent_id} in {execution_mode} mode")
|
||||
return self.success_response(response_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling agent {agent_id}: {str(e)}")
|
||||
return self.fail_response(f"Error calling agent: {str(e)}")
|
||||
|
||||
async def _execute_agent_workflow(self, agent_id: str, workflow_id: str, message: str, model_name: str, timeout: int) -> str:
|
||||
"""Execute an agent workflow."""
|
||||
try:
|
||||
client = await self.thread_manager.db.client
|
||||
|
||||
# Verify workflow exists and is active
|
||||
workflow_result = await client.table('agent_workflows').select('*').eq('id', workflow_id).eq('agent_id', agent_id).execute()
|
||||
if not workflow_result.data:
|
||||
return f"Error: Workflow {workflow_id} not found for agent {agent_id}"
|
||||
|
||||
workflow = workflow_result.data[0]
|
||||
if workflow.get('status') != 'active':
|
||||
return f"Error: Workflow {workflow['name']} is not active (status: {workflow.get('status')})"
|
||||
|
||||
# Execute workflow through the execution service
|
||||
try:
|
||||
from triggers.execution_service import execute_workflow
|
||||
|
||||
# Execute the workflow with the provided message
|
||||
execution_result = await execute_workflow(
|
||||
workflow_id=workflow_id,
|
||||
agent_id=agent_id,
|
||||
input_data={"message": message},
|
||||
user_id=self.account_id
|
||||
)
|
||||
|
||||
if execution_result.get('success'):
|
||||
return execution_result.get('output', f"Workflow '{workflow['name']}' executed successfully")
|
||||
else:
|
||||
return f"Workflow execution failed: {execution_result.get('error', 'Unknown error')}"
|
||||
|
||||
except ImportError:
|
||||
logger.warning("Execution service not available, using fallback workflow execution")
|
||||
# Fallback: Create a thread and run the agent with workflow context
|
||||
return await self._execute_agent_with_thread(agent_id, message, model_name, timeout, workflow)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing workflow {workflow_id}: {str(e)}")
|
||||
return f"Error executing workflow: {str(e)}"
|
||||
|
||||
async def _execute_agent_prompt(self, agent_id: str, message: str, model_name: str, timeout: int) -> str:
|
||||
"""Execute an agent with a custom prompt."""
|
||||
try:
|
||||
return await self._execute_agent_with_thread(agent_id, message, model_name, timeout)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing agent prompt: {str(e)}")
|
||||
return f"Error executing agent: {str(e)}"
|
||||
|
||||
async def _execute_agent_with_thread(self, agent_id: str, message: str, model_name: str, timeout: int, workflow: Optional[dict] = None) -> str:
|
||||
"""Execute agent using thread-based approach."""
|
||||
try:
|
||||
# Import existing agent execution functions
|
||||
from agent.api import create_thread, add_message_to_thread, start_agent, AgentStartRequest
|
||||
|
||||
# Create thread name based on execution type
|
||||
thread_name = f"Workflow: {workflow['name']}" if workflow else "Agent-to-Agent Call"
|
||||
|
||||
# Create a new thread
|
||||
thread_response = await create_thread(name=thread_name, user_id=self.account_id)
|
||||
thread_id = thread_response.get('thread_id') if isinstance(thread_response, dict) else thread_response.thread_id
|
||||
|
||||
# Prepare message with workflow context if needed
|
||||
final_message = message
|
||||
if workflow:
|
||||
workflow_context = f"Executing workflow '{workflow['name']}'"
|
||||
if workflow.get('description'):
|
||||
workflow_context += f": {workflow['description']}"
|
||||
final_message = f"{workflow_context}\n\nUser message: {message}"
|
||||
|
||||
# Add the message to the thread
|
||||
await add_message_to_thread(
|
||||
thread_id=thread_id,
|
||||
message=final_message,
|
||||
user_id=self.account_id
|
||||
)
|
||||
|
||||
# Start the agent
|
||||
agent_request = AgentStartRequest(
|
||||
agent_id=agent_id,
|
||||
enable_thinking=False,
|
||||
stream=False,
|
||||
model_name=model_name
|
||||
)
|
||||
|
||||
await start_agent(
|
||||
thread_id=thread_id,
|
||||
body=agent_request,
|
||||
user_id=self.account_id
|
||||
)
|
||||
|
||||
# Wait for agent completion and get response
|
||||
return await self._poll_for_completion(thread_id, timeout)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing agent with thread: {str(e)}")
|
||||
return f"Error executing agent: {str(e)}"
|
||||
|
||||
async def _poll_for_completion(self, thread_id: str, timeout: int) -> str:
|
||||
"""Poll for agent completion and return the result."""
|
||||
client = await self.thread_manager.db.client
|
||||
|
||||
poll_interval = 2
|
||||
elapsed = 0
|
||||
|
||||
while elapsed < timeout:
|
||||
# Check thread messages for agent response
|
||||
messages_result = await client.table('messages').select('*').eq('thread_id', thread_id).order('created_at', desc=True).limit(5).execute()
|
||||
|
||||
if messages_result.data:
|
||||
# Look for the most recent agent message (not user message)
|
||||
for msg in messages_result.data:
|
||||
# Parse JSON content to check role
|
||||
content = msg.get('content')
|
||||
if content:
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
parsed_content = json.loads(content)
|
||||
else:
|
||||
parsed_content = content
|
||||
|
||||
if parsed_content.get('role') == 'assistant':
|
||||
return parsed_content.get('content', '')
|
||||
except:
|
||||
# If parsing fails, check if it's a direct assistant message
|
||||
if msg.get('type') == 'assistant':
|
||||
return content
|
||||
|
||||
# Check if agent run is complete by checking agent_runs table
|
||||
runs_result = await client.table('agent_runs').select('status, error').eq('thread_id', thread_id).order('created_at', desc=True).limit(1).execute()
|
||||
|
||||
if runs_result.data:
|
||||
run = runs_result.data[0]
|
||||
if run['status'] in ['completed', 'failed', 'cancelled']:
|
||||
if run['status'] == 'failed':
|
||||
return f"Agent execution failed: {run.get('error', 'Unknown error')}"
|
||||
elif run['status'] == 'cancelled':
|
||||
return "Agent execution was cancelled"
|
||||
# If completed, continue to check for messages
|
||||
break
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
elapsed += poll_interval
|
||||
|
||||
# Timeout fallback - get latest messages
|
||||
messages_result = await client.table('messages').select('*').eq('thread_id', thread_id).order('created_at', desc=True).limit(10).execute()
|
||||
|
||||
if messages_result.data:
|
||||
# Return the most recent assistant message or fallback message
|
||||
for msg in messages_result.data:
|
||||
# Parse JSON content to check role
|
||||
content = msg.get('content')
|
||||
if content:
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
parsed_content = json.loads(content)
|
||||
else:
|
||||
parsed_content = content
|
||||
|
||||
if parsed_content.get('role') == 'assistant':
|
||||
return parsed_content.get('content', '')
|
||||
except:
|
||||
# If parsing fails, check if it's a direct assistant message
|
||||
if msg.get('type') == 'assistant':
|
||||
return content
|
||||
|
||||
return f"Agent execution timed out after {timeout}s. Thread ID: {thread_id}"
|
||||
|
||||
return f"No response received from agent. Thread ID: {thread_id}"
|
||||
|
||||
async def _verify_agent_access(self, agent_id: str):
|
||||
"""Verify account has access to the agent."""
|
||||
try:
|
||||
client = await self.thread_manager.db.client
|
||||
result = await client.table('agents').select('agent_id').eq('agent_id', agent_id).eq('account_id', self.account_id).execute()
|
||||
|
||||
if not result.data:
|
||||
raise ValueError("Agent not found or access denied")
|
||||
except ValueError:
|
||||
# Re-raise ValueError for proper error messages
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Database error in verify_agent_access: {str(e)}")
|
||||
raise ValueError("Database connection error")
|
Loading…
Reference in New Issue