mirror of https://github.com/kortix-ai/suna.git
fix trigger import error
This commit is contained in:
parent
3fb032185c
commit
46010875d8
|
@ -76,14 +76,17 @@ async def run_agent(
|
||||||
if not sandbox_info.get('id'):
|
if not sandbox_info.get('id'):
|
||||||
raise ValueError(f"No sandbox found for project {project_id}")
|
raise ValueError(f"No sandbox found for project {project_id}")
|
||||||
|
|
||||||
# Initialize tools with project_id instead of sandbox object
|
enabled_tools = {}
|
||||||
# This ensures each tool independently verifies it's operating on the correct project
|
|
||||||
|
|
||||||
# Get enabled tools from agent config, or use defaults
|
|
||||||
enabled_tools = None
|
|
||||||
if agent_config and 'agentpress_tools' in agent_config:
|
if agent_config and 'agentpress_tools' in agent_config:
|
||||||
enabled_tools = agent_config['agentpress_tools']
|
raw_tools = agent_config['agentpress_tools']
|
||||||
logger.info(f"Using custom tool configuration from agent")
|
logger.info(f"Raw agentpress_tools type: {type(raw_tools)}, value: {raw_tools}")
|
||||||
|
|
||||||
|
if isinstance(raw_tools, dict):
|
||||||
|
enabled_tools = raw_tools
|
||||||
|
logger.info(f"Using custom tool configuration from agent")
|
||||||
|
else:
|
||||||
|
logger.warning(f"agentpress_tools is not a dict (got {type(raw_tools)}), using empty dict")
|
||||||
|
enabled_tools = {}
|
||||||
|
|
||||||
|
|
||||||
# Check if this is Suna (default agent) and enable builder capabilities for self-configuration
|
# Check if this is Suna (default agent) and enable builder capabilities for self-configuration
|
||||||
|
@ -142,23 +145,43 @@ async def run_agent(
|
||||||
thread_manager.add_tool(DataProvidersTool)
|
thread_manager.add_tool(DataProvidersTool)
|
||||||
else:
|
else:
|
||||||
logger.info("Custom agent specified - registering only enabled tools")
|
logger.info("Custom agent specified - registering only enabled tools")
|
||||||
|
|
||||||
|
# Final safety check: ensure enabled_tools is always a dictionary
|
||||||
|
if not isinstance(enabled_tools, dict):
|
||||||
|
logger.error(f"CRITICAL: enabled_tools is still not a dict at runtime! Type: {type(enabled_tools)}, Value: {enabled_tools}")
|
||||||
|
enabled_tools = {}
|
||||||
|
|
||||||
thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager)
|
thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager)
|
||||||
thread_manager.add_tool(MessageTool)
|
thread_manager.add_tool(MessageTool)
|
||||||
if enabled_tools.get('sb_shell_tool', {}).get('enabled', False):
|
|
||||||
|
def safe_tool_check(tool_name: str) -> bool:
|
||||||
|
try:
|
||||||
|
if not isinstance(enabled_tools, dict):
|
||||||
|
logger.error(f"enabled_tools is {type(enabled_tools)} at tool check for {tool_name}")
|
||||||
|
return False
|
||||||
|
tool_config = enabled_tools.get(tool_name, {})
|
||||||
|
if not isinstance(tool_config, dict):
|
||||||
|
return bool(tool_config) if isinstance(tool_config, bool) else False
|
||||||
|
return tool_config.get('enabled', False)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception in tool check for {tool_name}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if safe_tool_check('sb_shell_tool'):
|
||||||
thread_manager.add_tool(SandboxShellTool, project_id=project_id, thread_manager=thread_manager)
|
thread_manager.add_tool(SandboxShellTool, project_id=project_id, thread_manager=thread_manager)
|
||||||
if enabled_tools.get('sb_files_tool', {}).get('enabled', False):
|
if safe_tool_check('sb_files_tool'):
|
||||||
thread_manager.add_tool(SandboxFilesTool, project_id=project_id, thread_manager=thread_manager)
|
thread_manager.add_tool(SandboxFilesTool, project_id=project_id, thread_manager=thread_manager)
|
||||||
if enabled_tools.get('sb_browser_tool', {}).get('enabled', False):
|
if safe_tool_check('sb_browser_tool'):
|
||||||
thread_manager.add_tool(SandboxBrowserTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
thread_manager.add_tool(SandboxBrowserTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
||||||
if enabled_tools.get('sb_deploy_tool', {}).get('enabled', False):
|
if safe_tool_check('sb_deploy_tool'):
|
||||||
thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager)
|
thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager)
|
||||||
if enabled_tools.get('sb_expose_tool', {}).get('enabled', False):
|
if safe_tool_check('sb_expose_tool'):
|
||||||
thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager)
|
thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager)
|
||||||
if enabled_tools.get('web_search_tool', {}).get('enabled', False):
|
if safe_tool_check('web_search_tool'):
|
||||||
thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
|
thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
|
||||||
if enabled_tools.get('sb_vision_tool', {}).get('enabled', False):
|
if safe_tool_check('sb_vision_tool'):
|
||||||
thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
||||||
if config.RAPID_API_KEY and enabled_tools.get('data_providers_tool', {}).get('enabled', False):
|
if config.RAPID_API_KEY and safe_tool_check('data_providers_tool'):
|
||||||
thread_manager.add_tool(DataProvidersTool)
|
thread_manager.add_tool(DataProvidersTool)
|
||||||
|
|
||||||
# Register MCP tool wrapper if agent has configured MCPs or custom MCPs
|
# Register MCP tool wrapper if agent has configured MCPs or custom MCPs
|
||||||
|
|
|
@ -6,7 +6,7 @@ from .base_tool import AgentBuilderBaseTool
|
||||||
from utils.logger import logger
|
from utils.logger import logger
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from services.supabase import DBConnection
|
from services.supabase import DBConnection
|
||||||
from triggers.support.factory import TriggerModuleFactory
|
from triggers import get_trigger_service
|
||||||
|
|
||||||
|
|
||||||
class TriggerTool(AgentBuilderBaseTool):
|
class TriggerTool(AgentBuilderBaseTool):
|
||||||
|
@ -124,8 +124,7 @@ class TriggerTool(AgentBuilderBaseTool):
|
||||||
else:
|
else:
|
||||||
trigger_config["agent_prompt"] = agent_prompt
|
trigger_config["agent_prompt"] = agent_prompt
|
||||||
|
|
||||||
trigger_db = DBConnection()
|
trigger_svc = get_trigger_service(self.db)
|
||||||
trigger_svc, _, _ = await TriggerModuleFactory.create_trigger_module(trigger_db)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trigger = await trigger_svc.create_trigger(
|
trigger = await trigger_svc.create_trigger(
|
||||||
|
@ -153,12 +152,12 @@ class TriggerTool(AgentBuilderBaseTool):
|
||||||
"message": result_message,
|
"message": result_message,
|
||||||
"trigger": {
|
"trigger": {
|
||||||
"id": trigger.trigger_id,
|
"id": trigger.trigger_id,
|
||||||
"name": trigger.config.name,
|
"name": trigger.name,
|
||||||
"description": trigger.config.description,
|
"description": trigger.description,
|
||||||
"cron_expression": cron_expression,
|
"cron_expression": cron_expression,
|
||||||
"execution_type": execution_type,
|
"execution_type": execution_type,
|
||||||
"is_active": trigger.config.is_active,
|
"is_active": trigger.is_active,
|
||||||
"created_at": trigger.metadata.created_at.isoformat()
|
"created_at": trigger.created_at.isoformat()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
except ValueError as ve:
|
except ValueError as ve:
|
||||||
|
@ -195,12 +194,11 @@ class TriggerTool(AgentBuilderBaseTool):
|
||||||
)
|
)
|
||||||
async def get_scheduled_triggers(self) -> ToolResult:
|
async def get_scheduled_triggers(self) -> ToolResult:
|
||||||
try:
|
try:
|
||||||
from triggers.core import TriggerType
|
from triggers import TriggerType
|
||||||
|
|
||||||
trigger_db = DBConnection()
|
trigger_svc = get_trigger_service(self.db)
|
||||||
trigger_manager = TriggerManager(trigger_db)
|
|
||||||
|
|
||||||
triggers = await trigger_manager.get_agent_triggers(self.agent_id)
|
triggers = await trigger_svc.get_agent_triggers(self.agent_id)
|
||||||
|
|
||||||
schedule_triggers = [t for t in triggers if t.trigger_type == TriggerType.SCHEDULE]
|
schedule_triggers = [t for t in triggers if t.trigger_type == TriggerType.SCHEDULE]
|
||||||
|
|
||||||
|
@ -282,10 +280,9 @@ class TriggerTool(AgentBuilderBaseTool):
|
||||||
)
|
)
|
||||||
async def delete_scheduled_trigger(self, trigger_id: str) -> ToolResult:
|
async def delete_scheduled_trigger(self, trigger_id: str) -> ToolResult:
|
||||||
try:
|
try:
|
||||||
trigger_db = DBConnection()
|
trigger_svc = get_trigger_service(self.db)
|
||||||
trigger_manager = TriggerManager(trigger_db)
|
|
||||||
|
|
||||||
trigger_config = await trigger_manager.get_trigger(trigger_id)
|
trigger_config = await trigger_svc.get_trigger(trigger_id)
|
||||||
|
|
||||||
if not trigger_config:
|
if not trigger_config:
|
||||||
return self.fail_response("Trigger not found")
|
return self.fail_response("Trigger not found")
|
||||||
|
@ -293,7 +290,7 @@ class TriggerTool(AgentBuilderBaseTool):
|
||||||
if trigger_config.agent_id != self.agent_id:
|
if trigger_config.agent_id != self.agent_id:
|
||||||
return self.fail_response("This trigger doesn't belong to the current agent")
|
return self.fail_response("This trigger doesn't belong to the current agent")
|
||||||
|
|
||||||
success = await trigger_manager.delete_trigger(trigger_id)
|
success = await trigger_svc.delete_trigger(trigger_id)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
return self.success_response({
|
return self.success_response({
|
||||||
|
@ -345,10 +342,9 @@ class TriggerTool(AgentBuilderBaseTool):
|
||||||
)
|
)
|
||||||
async def toggle_scheduled_trigger(self, trigger_id: str, is_active: bool) -> ToolResult:
|
async def toggle_scheduled_trigger(self, trigger_id: str, is_active: bool) -> ToolResult:
|
||||||
try:
|
try:
|
||||||
trigger_db = DBConnection()
|
trigger_svc = get_trigger_service(self.db)
|
||||||
trigger_manager = TriggerManager(trigger_db)
|
|
||||||
|
|
||||||
trigger_config = await trigger_manager.get_trigger(trigger_id)
|
trigger_config = await trigger_svc.get_trigger(trigger_id)
|
||||||
|
|
||||||
if not trigger_config:
|
if not trigger_config:
|
||||||
return self.fail_response("Trigger not found")
|
return self.fail_response("Trigger not found")
|
||||||
|
@ -356,7 +352,7 @@ class TriggerTool(AgentBuilderBaseTool):
|
||||||
if trigger_config.agent_id != self.agent_id:
|
if trigger_config.agent_id != self.agent_id:
|
||||||
return self.fail_response("This trigger doesn't belong to the current agent")
|
return self.fail_response("This trigger doesn't belong to the current agent")
|
||||||
|
|
||||||
updated_config = await trigger_manager.update_trigger(
|
updated_config = await trigger_svc.update_trigger(
|
||||||
trigger_id=trigger_id,
|
trigger_id=trigger_id,
|
||||||
is_active=is_active
|
is_active=is_active
|
||||||
)
|
)
|
||||||
|
|
|
@ -269,7 +269,7 @@ class AgentExecutor:
|
||||||
}
|
}
|
||||||
for mcp in active_version.custom_mcps
|
for mcp in active_version.custom_mcps
|
||||||
],
|
],
|
||||||
'agentpress_tools': active_version.tool_configuration.tools,
|
'agentpress_tools': active_version.tool_configuration.tools if isinstance(active_version.tool_configuration.tools, dict) else {},
|
||||||
'current_version_id': str(active_version.version_id),
|
'current_version_id': str(active_version.version_id),
|
||||||
'version_name': active_version.version_name
|
'version_name': active_version.version_name
|
||||||
}
|
}
|
||||||
|
@ -461,7 +461,7 @@ class WorkflowExecutor:
|
||||||
}
|
}
|
||||||
for mcp in active_version.custom_mcps
|
for mcp in active_version.custom_mcps
|
||||||
],
|
],
|
||||||
'agentpress_tools': active_version.tool_configuration.tools,
|
'agentpress_tools': active_version.tool_configuration.tools if isinstance(active_version.tool_configuration.tools, dict) else {},
|
||||||
'current_version_id': str(active_version.version_id),
|
'current_version_id': str(active_version.version_id),
|
||||||
'version_name': active_version.version_name
|
'version_name': active_version.version_name
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue