fix trigger import error

This commit is contained in:
Saumya 2025-07-28 18:05:59 +05:30
parent 3fb032185c
commit 46010875d8
3 changed files with 55 additions and 36 deletions

View File

@ -76,14 +76,17 @@ async def run_agent(
if not sandbox_info.get('id'): if not sandbox_info.get('id'):
raise ValueError(f"No sandbox found for project {project_id}") raise ValueError(f"No sandbox found for project {project_id}")
# Initialize tools with project_id instead of sandbox object enabled_tools = {}
# This ensures each tool independently verifies it's operating on the correct project
# Get enabled tools from agent config, or use defaults
enabled_tools = None
if agent_config and 'agentpress_tools' in agent_config: if agent_config and 'agentpress_tools' in agent_config:
enabled_tools = agent_config['agentpress_tools'] raw_tools = agent_config['agentpress_tools']
logger.info(f"Using custom tool configuration from agent") logger.info(f"Raw agentpress_tools type: {type(raw_tools)}, value: {raw_tools}")
if isinstance(raw_tools, dict):
enabled_tools = raw_tools
logger.info(f"Using custom tool configuration from agent")
else:
logger.warning(f"agentpress_tools is not a dict (got {type(raw_tools)}), using empty dict")
enabled_tools = {}
# Check if this is Suna (default agent) and enable builder capabilities for self-configuration # Check if this is Suna (default agent) and enable builder capabilities for self-configuration
@ -142,23 +145,43 @@ async def run_agent(
thread_manager.add_tool(DataProvidersTool) thread_manager.add_tool(DataProvidersTool)
else: else:
logger.info("Custom agent specified - registering only enabled tools") logger.info("Custom agent specified - registering only enabled tools")
# Final safety check: ensure enabled_tools is always a dictionary
if not isinstance(enabled_tools, dict):
logger.error(f"CRITICAL: enabled_tools is still not a dict at runtime! Type: {type(enabled_tools)}, Value: {enabled_tools}")
enabled_tools = {}
thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager) thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager)
thread_manager.add_tool(MessageTool) thread_manager.add_tool(MessageTool)
if enabled_tools.get('sb_shell_tool', {}).get('enabled', False):
def safe_tool_check(tool_name: str) -> bool:
try:
if not isinstance(enabled_tools, dict):
logger.error(f"enabled_tools is {type(enabled_tools)} at tool check for {tool_name}")
return False
tool_config = enabled_tools.get(tool_name, {})
if not isinstance(tool_config, dict):
return bool(tool_config) if isinstance(tool_config, bool) else False
return tool_config.get('enabled', False)
except Exception as e:
logger.error(f"Exception in tool check for {tool_name}: {e}")
return False
if safe_tool_check('sb_shell_tool'):
thread_manager.add_tool(SandboxShellTool, project_id=project_id, thread_manager=thread_manager) thread_manager.add_tool(SandboxShellTool, project_id=project_id, thread_manager=thread_manager)
if enabled_tools.get('sb_files_tool', {}).get('enabled', False): if safe_tool_check('sb_files_tool'):
thread_manager.add_tool(SandboxFilesTool, project_id=project_id, thread_manager=thread_manager) thread_manager.add_tool(SandboxFilesTool, project_id=project_id, thread_manager=thread_manager)
if enabled_tools.get('sb_browser_tool', {}).get('enabled', False): if safe_tool_check('sb_browser_tool'):
thread_manager.add_tool(SandboxBrowserTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) thread_manager.add_tool(SandboxBrowserTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
if enabled_tools.get('sb_deploy_tool', {}).get('enabled', False): if safe_tool_check('sb_deploy_tool'):
thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager) thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager)
if enabled_tools.get('sb_expose_tool', {}).get('enabled', False): if safe_tool_check('sb_expose_tool'):
thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager) thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager)
if enabled_tools.get('web_search_tool', {}).get('enabled', False): if safe_tool_check('web_search_tool'):
thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager) thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
if enabled_tools.get('sb_vision_tool', {}).get('enabled', False): if safe_tool_check('sb_vision_tool'):
thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
if config.RAPID_API_KEY and enabled_tools.get('data_providers_tool', {}).get('enabled', False): if config.RAPID_API_KEY and safe_tool_check('data_providers_tool'):
thread_manager.add_tool(DataProvidersTool) thread_manager.add_tool(DataProvidersTool)
# Register MCP tool wrapper if agent has configured MCPs or custom MCPs # Register MCP tool wrapper if agent has configured MCPs or custom MCPs

View File

@ -6,7 +6,7 @@ from .base_tool import AgentBuilderBaseTool
from utils.logger import logger from utils.logger import logger
from datetime import datetime from datetime import datetime
from services.supabase import DBConnection from services.supabase import DBConnection
from triggers.support.factory import TriggerModuleFactory from triggers import get_trigger_service
class TriggerTool(AgentBuilderBaseTool): class TriggerTool(AgentBuilderBaseTool):
@ -124,8 +124,7 @@ class TriggerTool(AgentBuilderBaseTool):
else: else:
trigger_config["agent_prompt"] = agent_prompt trigger_config["agent_prompt"] = agent_prompt
trigger_db = DBConnection() trigger_svc = get_trigger_service(self.db)
trigger_svc, _, _ = await TriggerModuleFactory.create_trigger_module(trigger_db)
try: try:
trigger = await trigger_svc.create_trigger( trigger = await trigger_svc.create_trigger(
@ -153,12 +152,12 @@ class TriggerTool(AgentBuilderBaseTool):
"message": result_message, "message": result_message,
"trigger": { "trigger": {
"id": trigger.trigger_id, "id": trigger.trigger_id,
"name": trigger.config.name, "name": trigger.name,
"description": trigger.config.description, "description": trigger.description,
"cron_expression": cron_expression, "cron_expression": cron_expression,
"execution_type": execution_type, "execution_type": execution_type,
"is_active": trigger.config.is_active, "is_active": trigger.is_active,
"created_at": trigger.metadata.created_at.isoformat() "created_at": trigger.created_at.isoformat()
} }
}) })
except ValueError as ve: except ValueError as ve:
@ -195,12 +194,11 @@ class TriggerTool(AgentBuilderBaseTool):
) )
async def get_scheduled_triggers(self) -> ToolResult: async def get_scheduled_triggers(self) -> ToolResult:
try: try:
from triggers.core import TriggerType from triggers import TriggerType
trigger_db = DBConnection() trigger_svc = get_trigger_service(self.db)
trigger_manager = TriggerManager(trigger_db)
triggers = await trigger_manager.get_agent_triggers(self.agent_id) triggers = await trigger_svc.get_agent_triggers(self.agent_id)
schedule_triggers = [t for t in triggers if t.trigger_type == TriggerType.SCHEDULE] schedule_triggers = [t for t in triggers if t.trigger_type == TriggerType.SCHEDULE]
@ -282,10 +280,9 @@ class TriggerTool(AgentBuilderBaseTool):
) )
async def delete_scheduled_trigger(self, trigger_id: str) -> ToolResult: async def delete_scheduled_trigger(self, trigger_id: str) -> ToolResult:
try: try:
trigger_db = DBConnection() trigger_svc = get_trigger_service(self.db)
trigger_manager = TriggerManager(trigger_db)
trigger_config = await trigger_manager.get_trigger(trigger_id) trigger_config = await trigger_svc.get_trigger(trigger_id)
if not trigger_config: if not trigger_config:
return self.fail_response("Trigger not found") return self.fail_response("Trigger not found")
@ -293,7 +290,7 @@ class TriggerTool(AgentBuilderBaseTool):
if trigger_config.agent_id != self.agent_id: if trigger_config.agent_id != self.agent_id:
return self.fail_response("This trigger doesn't belong to the current agent") return self.fail_response("This trigger doesn't belong to the current agent")
success = await trigger_manager.delete_trigger(trigger_id) success = await trigger_svc.delete_trigger(trigger_id)
if success: if success:
return self.success_response({ return self.success_response({
@ -345,10 +342,9 @@ class TriggerTool(AgentBuilderBaseTool):
) )
async def toggle_scheduled_trigger(self, trigger_id: str, is_active: bool) -> ToolResult: async def toggle_scheduled_trigger(self, trigger_id: str, is_active: bool) -> ToolResult:
try: try:
trigger_db = DBConnection() trigger_svc = get_trigger_service(self.db)
trigger_manager = TriggerManager(trigger_db)
trigger_config = await trigger_manager.get_trigger(trigger_id) trigger_config = await trigger_svc.get_trigger(trigger_id)
if not trigger_config: if not trigger_config:
return self.fail_response("Trigger not found") return self.fail_response("Trigger not found")
@ -356,7 +352,7 @@ class TriggerTool(AgentBuilderBaseTool):
if trigger_config.agent_id != self.agent_id: if trigger_config.agent_id != self.agent_id:
return self.fail_response("This trigger doesn't belong to the current agent") return self.fail_response("This trigger doesn't belong to the current agent")
updated_config = await trigger_manager.update_trigger( updated_config = await trigger_svc.update_trigger(
trigger_id=trigger_id, trigger_id=trigger_id,
is_active=is_active is_active=is_active
) )

View File

@ -269,7 +269,7 @@ class AgentExecutor:
} }
for mcp in active_version.custom_mcps for mcp in active_version.custom_mcps
], ],
'agentpress_tools': active_version.tool_configuration.tools, 'agentpress_tools': active_version.tool_configuration.tools if isinstance(active_version.tool_configuration.tools, dict) else {},
'current_version_id': str(active_version.version_id), 'current_version_id': str(active_version.version_id),
'version_name': active_version.version_name 'version_name': active_version.version_name
} }
@ -461,7 +461,7 @@ class WorkflowExecutor:
} }
for mcp in active_version.custom_mcps for mcp in active_version.custom_mcps
], ],
'agentpress_tools': active_version.tool_configuration.tools, 'agentpress_tools': active_version.tool_configuration.tools if isinstance(active_version.tool_configuration.tools, dict) else {},
'current_version_id': str(active_version.version_id), 'current_version_id': str(active_version.version_id),
'version_name': active_version.version_name 'version_name': active_version.version_name
} }