diff --git a/backend/agent/run.py b/backend/agent/run.py index 4828e55e..97d21f64 100644 --- a/backend/agent/run.py +++ b/backend/agent/run.py @@ -1,9 +1,9 @@ import os import json import asyncio -from typing import Optional +from typing import Optional, Dict, List, Any, AsyncGenerator +from dataclasses import dataclass -# from agent.tools.message_tool import MessageTool from agent.tools.message_tool import MessageTool from agent.tools.sb_deploy_tool import SandboxDeployTool from agent.tools.sb_expose_tool import SandboxExposeTool @@ -27,202 +27,129 @@ from agent.tools.sb_vision_tool import SandboxVisionTool from agent.tools.sb_image_edit_tool import SandboxImageEditTool from services.langfuse import langfuse from langfuse.client import StatefulTraceClient -from services.langfuse import langfuse from agent.gemini_prompt import get_gemini_system_prompt from agent.tools.mcp_tool_wrapper import MCPToolWrapper from agentpress.tool import SchemaType load_dotenv() -async def run_agent( - thread_id: str, - project_id: str, - stream: bool, - thread_manager: Optional[ThreadManager] = None, - native_max_auto_continues: int = 25, - max_iterations: int = 100, - model_name: str = "anthropic/claude-sonnet-4-20250514", - enable_thinking: Optional[bool] = False, - reasoning_effort: Optional[str] = 'low', - enable_context_manager: bool = True, - agent_config: Optional[dict] = None, - trace: Optional[StatefulTraceClient] = None, - is_agent_builder: Optional[bool] = False, + +@dataclass +class AgentConfig: + thread_id: str + project_id: str + stream: bool + native_max_auto_continues: int = 25 + max_iterations: int = 100 + model_name: str = "anthropic/claude-sonnet-4-20250514" + enable_thinking: Optional[bool] = False + reasoning_effort: Optional[str] = 'low' + enable_context_manager: bool = True + agent_config: Optional[dict] = None + trace: Optional[StatefulTraceClient] = None + is_agent_builder: Optional[bool] = False target_agent_id: Optional[str] = None -): - """Run the development agent with specified configuration.""" - logger.info(f"šŸš€ Starting agent with model: {model_name}") - if agent_config: - logger.info(f"Using custom agent: {agent_config.get('name', 'Unknown')}") - if not trace: - trace = langfuse.trace(name="run_agent", session_id=thread_id, metadata={"project_id": project_id}) - thread_manager = ThreadManager(trace=trace, is_agent_builder=is_agent_builder or False, target_agent_id=target_agent_id, agent_config=agent_config) - client = await thread_manager.db.client - - # Get account ID from thread for billing checks - account_id = await get_account_id_from_thread(client, thread_id) - if not account_id: - raise ValueError("Could not determine account ID for thread") - - # Get sandbox info from project - project = await client.table('projects').select('*').eq('project_id', project_id).execute() - if not project.data or len(project.data) == 0: - raise ValueError(f"Project {project_id} not found") - - project_data = project.data[0] - sandbox_info = project_data.get('sandbox', {}) - if not sandbox_info.get('id'): - raise ValueError(f"No sandbox found for project {project_id}") - - enabled_tools = {} - if agent_config and 'agentpress_tools' in agent_config: - raw_tools = agent_config['agentpress_tools'] - logger.info(f"Raw agentpress_tools type: {type(raw_tools)}, value: {raw_tools}") - - if isinstance(raw_tools, dict): - enabled_tools = raw_tools - logger.info(f"Using custom tool configuration from agent") - else: - logger.warning(f"agentpress_tools is not a dict (got {type(raw_tools)}), using empty dict") - enabled_tools = {} +class ToolManager: + def __init__(self, thread_manager: ThreadManager, project_id: str, thread_id: str): + self.thread_manager = thread_manager + self.project_id = project_id + self.thread_id = thread_id - - # Check if this is Suna (default agent) and enable builder capabilities for self-configuration - if agent_config and agent_config.get('is_suna_default', False): - logger.info("Detected Suna default agent - enabling self-configuration capabilities") - - from agent.tools.agent_builder_tools.agent_config_tool import AgentConfigTool - from agent.tools.agent_builder_tools.mcp_search_tool import MCPSearchTool - from agent.tools.agent_builder_tools.credential_profile_tool import CredentialProfileTool - from agent.tools.agent_builder_tools.workflow_tool import WorkflowTool - from agent.tools.agent_builder_tools.trigger_tool import TriggerTool - from services.supabase import DBConnection - db = DBConnection() - - # Use Suna's own agent ID for self-configuration - suna_agent_id = agent_config['agent_id'] - - thread_manager.add_tool(AgentConfigTool, thread_manager=thread_manager, db_connection=db, agent_id=suna_agent_id) - thread_manager.add_tool(MCPSearchTool, thread_manager=thread_manager, db_connection=db, agent_id=suna_agent_id) - thread_manager.add_tool(CredentialProfileTool, thread_manager=thread_manager, db_connection=db, agent_id=suna_agent_id) - thread_manager.add_tool(WorkflowTool, thread_manager=thread_manager, db_connection=db, agent_id=suna_agent_id) - thread_manager.add_tool(TriggerTool, thread_manager=thread_manager, db_connection=db, agent_id=suna_agent_id) - - logger.info(f"Enabled Suna self-configuration with agent ID: {suna_agent_id}") - - # Original agent builder logic for custom agents (preserved) - if is_agent_builder: - from agent.tools.agent_builder_tools.agent_config_tool import AgentConfigTool - from agent.tools.agent_builder_tools.mcp_search_tool import MCPSearchTool - from agent.tools.agent_builder_tools.credential_profile_tool import CredentialProfileTool - from agent.tools.agent_builder_tools.workflow_tool import WorkflowTool - from agent.tools.agent_builder_tools.trigger_tool import TriggerTool - from services.supabase import DBConnection - db = DBConnection() - - thread_manager.add_tool(AgentConfigTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) - thread_manager.add_tool(MCPSearchTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) - thread_manager.add_tool(CredentialProfileTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) - thread_manager.add_tool(WorkflowTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) - thread_manager.add_tool(TriggerTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) - - - if enabled_tools is None: - logger.info("No agent specified - registering all tools for full Suna capabilities") - thread_manager.add_tool(SandboxShellTool, project_id=project_id, thread_manager=thread_manager) - thread_manager.add_tool(SandboxFilesTool, project_id=project_id, thread_manager=thread_manager) - thread_manager.add_tool(SandboxBrowserTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) - thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager) - thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager) - thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager) - thread_manager.add_tool(MessageTool) - thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager) - thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) - thread_manager.add_tool(SandboxImageEditTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) + def register_all_tools(self): + self.thread_manager.add_tool(SandboxShellTool, project_id=self.project_id, thread_manager=self.thread_manager) + self.thread_manager.add_tool(SandboxFilesTool, project_id=self.project_id, thread_manager=self.thread_manager) + self.thread_manager.add_tool(SandboxBrowserTool, project_id=self.project_id, thread_id=self.thread_id, thread_manager=self.thread_manager) + self.thread_manager.add_tool(SandboxDeployTool, project_id=self.project_id, thread_manager=self.thread_manager) + self.thread_manager.add_tool(SandboxExposeTool, project_id=self.project_id, thread_manager=self.thread_manager) + self.thread_manager.add_tool(ExpandMessageTool, thread_id=self.thread_id, thread_manager=self.thread_manager) + self.thread_manager.add_tool(MessageTool) + self.thread_manager.add_tool(SandboxWebSearchTool, project_id=self.project_id, thread_manager=self.thread_manager) + self.thread_manager.add_tool(SandboxVisionTool, project_id=self.project_id, thread_id=self.thread_id, thread_manager=self.thread_manager) + self.thread_manager.add_tool(SandboxImageEditTool, project_id=self.project_id, thread_id=self.thread_id, thread_manager=self.thread_manager) if config.RAPID_API_KEY: - thread_manager.add_tool(DataProvidersTool) - else: - logger.info("Custom agent specified - registering only enabled tools") + self.thread_manager.add_tool(DataProvidersTool) + + def register_agent_builder_tools(self, agent_id: str): + from agent.tools.agent_builder_tools.agent_config_tool import AgentConfigTool + from agent.tools.agent_builder_tools.mcp_search_tool import MCPSearchTool + from agent.tools.agent_builder_tools.credential_profile_tool import CredentialProfileTool + from agent.tools.agent_builder_tools.workflow_tool import WorkflowTool + from agent.tools.agent_builder_tools.trigger_tool import TriggerTool + from services.supabase import DBConnection - # Final safety check: ensure enabled_tools is always a dictionary - if not isinstance(enabled_tools, dict): - logger.error(f"CRITICAL: enabled_tools is still not a dict at runtime! Type: {type(enabled_tools)}, Value: {enabled_tools}") - enabled_tools = {} - - thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager) - thread_manager.add_tool(MessageTool) + db = DBConnection() + self.thread_manager.add_tool(AgentConfigTool, thread_manager=self.thread_manager, db_connection=db, agent_id=agent_id) + self.thread_manager.add_tool(MCPSearchTool, thread_manager=self.thread_manager, db_connection=db, agent_id=agent_id) + self.thread_manager.add_tool(CredentialProfileTool, thread_manager=self.thread_manager, db_connection=db, agent_id=agent_id) + self.thread_manager.add_tool(WorkflowTool, thread_manager=self.thread_manager, db_connection=db, agent_id=agent_id) + self.thread_manager.add_tool(TriggerTool, thread_manager=self.thread_manager, db_connection=db, agent_id=agent_id) + + def register_custom_tools(self, enabled_tools: Dict[str, Any]): + self.thread_manager.add_tool(ExpandMessageTool, thread_id=self.thread_id, thread_manager=self.thread_manager) + self.thread_manager.add_tool(MessageTool) def safe_tool_check(tool_name: str) -> bool: try: if not isinstance(enabled_tools, dict): - logger.error(f"enabled_tools is {type(enabled_tools)} at tool check for {tool_name}") return False tool_config = enabled_tools.get(tool_name, {}) if not isinstance(tool_config, dict): return bool(tool_config) if isinstance(tool_config, bool) else False return tool_config.get('enabled', False) - except Exception as e: - logger.error(f"Exception in tool check for {tool_name}: {e}") + except Exception: return False if safe_tool_check('sb_shell_tool'): - thread_manager.add_tool(SandboxShellTool, project_id=project_id, thread_manager=thread_manager) + self.thread_manager.add_tool(SandboxShellTool, project_id=self.project_id, thread_manager=self.thread_manager) if safe_tool_check('sb_files_tool'): - thread_manager.add_tool(SandboxFilesTool, project_id=project_id, thread_manager=thread_manager) + self.thread_manager.add_tool(SandboxFilesTool, project_id=self.project_id, thread_manager=self.thread_manager) if safe_tool_check('sb_browser_tool'): - thread_manager.add_tool(SandboxBrowserTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) + self.thread_manager.add_tool(SandboxBrowserTool, project_id=self.project_id, thread_id=self.thread_id, thread_manager=self.thread_manager) if safe_tool_check('sb_deploy_tool'): - thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager) + self.thread_manager.add_tool(SandboxDeployTool, project_id=self.project_id, thread_manager=self.thread_manager) if safe_tool_check('sb_expose_tool'): - thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager) + self.thread_manager.add_tool(SandboxExposeTool, project_id=self.project_id, thread_manager=self.thread_manager) if safe_tool_check('web_search_tool'): - thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager) + self.thread_manager.add_tool(SandboxWebSearchTool, project_id=self.project_id, thread_manager=self.thread_manager) if safe_tool_check('sb_vision_tool'): - thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) + self.thread_manager.add_tool(SandboxVisionTool, project_id=self.project_id, thread_id=self.thread_id, thread_manager=self.thread_manager) if config.RAPID_API_KEY and safe_tool_check('data_providers_tool'): - thread_manager.add_tool(DataProvidersTool) + self.thread_manager.add_tool(DataProvidersTool) - # Register MCP tool wrapper if agent has configured MCPs or custom MCPs - mcp_wrapper_instance = None - if agent_config: - # Merge configured_mcps and custom_mcps + +class MCPManager: + def __init__(self, thread_manager: ThreadManager, account_id: str): + self.thread_manager = thread_manager + self.account_id = account_id + + async def register_mcp_tools(self, agent_config: dict) -> Optional[MCPToolWrapper]: all_mcps = [] - # Add standard configured MCPs if agent_config.get('configured_mcps'): all_mcps.extend(agent_config['configured_mcps']) - # Add custom MCPs if agent_config.get('custom_mcps'): for custom_mcp in agent_config['custom_mcps']: - # Transform custom MCP to standard format custom_type = custom_mcp.get('customType', custom_mcp.get('type', 'sse')) - # For Pipedream MCPs, ensure we have the user ID and proper config if custom_type == 'pipedream': - # Get user ID from thread if 'config' not in custom_mcp: custom_mcp['config'] = {} - # Get external_user_id from profile if not present if not custom_mcp['config'].get('external_user_id'): profile_id = custom_mcp['config'].get('profile_id') if profile_id: try: - from pipedream.profiles import get_profile_manager + from pipedream.facade import get_profile_manager from services.supabase import DBConnection profile_db = DBConnection() profile_manager = get_profile_manager(profile_db) - # Get the profile to retrieve external_user_id - profile = await profile_manager.get_profile(account_id, profile_id) + profile = await profile_manager.get_profile(self.account_id, profile_id) if profile: custom_mcp['config']['external_user_id'] = profile.external_user_id - logger.info(f"Retrieved external_user_id from profile {profile_id} for Pipedream MCP") - else: - logger.error(f"Could not find profile {profile_id} for Pipedream MCP") except Exception as e: logger.error(f"Error retrieving external_user_id from profile {profile_id}: {e}") @@ -240,203 +167,130 @@ async def run_agent( } all_mcps.append(mcp_config) - if all_mcps: - logger.info(f"Registering MCP tool wrapper for {len(all_mcps)} MCP servers (including {len(agent_config.get('custom_mcps', []))} custom)") - thread_manager.add_tool(MCPToolWrapper, mcp_configs=all_mcps) - - for tool_name, tool_info in thread_manager.tool_registry.tools.items(): - if isinstance(tool_info['instance'], MCPToolWrapper): - mcp_wrapper_instance = tool_info['instance'] - break - - if mcp_wrapper_instance: - try: - await mcp_wrapper_instance.initialize_and_register_tools() - logger.info("MCP tools initialized successfully") - updated_schemas = mcp_wrapper_instance.get_schemas() - logger.info(f"MCP wrapper has {len(updated_schemas)} schemas available") - for method_name, schema_list in updated_schemas.items(): - if method_name != 'call_mcp_tool': - for schema in schema_list: - if schema.schema_type == SchemaType.OPENAPI: - thread_manager.tool_registry.tools[method_name] = { - "instance": mcp_wrapper_instance, - "schema": schema - } - logger.info(f"Registered dynamic MCP tool: {method_name}") - - # Log all registered tools for debugging - all_tools = list(thread_manager.tool_registry.tools.keys()) - logger.info(f"All registered tools after MCP initialization: {all_tools}") - mcp_tools = [tool for tool in all_tools if tool not in ['call_mcp_tool', 'sb_files_tool', 'message_tool', 'expand_msg_tool', 'web_search_tool', 'sb_shell_tool', 'sb_vision_tool', 'sb_browser_tool', 'computer_use_tool', 'data_providers_tool', 'sb_deploy_tool', 'sb_expose_tool', 'update_agent_tool']] - logger.info(f"MCP tools registered: {mcp_tools}") - - except Exception as e: - logger.error(f"Failed to initialize MCP tools: {e}") - # Continue without MCP tools if initialization fails - - # Prepare system prompt - # First, get the default system prompt - if "gemini-2.5-flash" in model_name.lower() and "gemini-2.5-pro" not in model_name.lower(): - default_system_content = get_gemini_system_prompt() - else: - # Use the original prompt - the LLM can only use tools that are registered - default_system_content = get_system_prompt() + if not all_mcps: + return None - # Add sample response for non-anthropic models - if "anthropic" not in model_name.lower(): - sample_response_path = os.path.join(os.path.dirname(__file__), 'sample_responses/1.txt') - with open(sample_response_path, 'r') as file: - sample_response = file.read() - default_system_content = default_system_content + "\n\n " + sample_response + "" - - # Handle custom agent system prompt - if agent_config and agent_config.get('system_prompt'): - custom_system_prompt = agent_config['system_prompt'].strip() - # Completely replace the default system prompt with the custom one - # This prevents confusion and tool hallucination - system_content = custom_system_prompt - logger.info(f"Using ONLY custom agent system prompt for: {agent_config.get('name', 'Unknown')}") - elif is_agent_builder: - system_content = get_agent_builder_prompt() - logger.info("Using agent builder system prompt") - else: - # Use just the default system prompt - system_content = default_system_content - logger.info("Using default system prompt only") - - if await is_enabled("knowledge_base"): + mcp_wrapper_instance = MCPToolWrapper(mcp_configs=all_mcps) try: - from services.supabase import DBConnection - kb_db = DBConnection() - kb_client = await kb_db.client + await mcp_wrapper_instance.initialize_and_register_tools() - current_agent_id = agent_config.get('agent_id') if agent_config else None - - kb_result = await kb_client.rpc('get_combined_knowledge_base_context', { - 'p_thread_id': thread_id, - 'p_agent_id': current_agent_id, - 'p_max_tokens': 4000 - }).execute() - - if kb_result.data and kb_result.data.strip(): - logger.info(f"Adding combined knowledge base context to system prompt for thread {thread_id}, agent {current_agent_id}") - system_content += "\n\n" + kb_result.data - else: - logger.debug(f"No knowledge base context found for thread {thread_id}, agent {current_agent_id}") - - except Exception as e: - logger.error(f"Error retrieving knowledge base context for thread {thread_id}: {e}") - - - if agent_config and (agent_config.get('configured_mcps') or agent_config.get('custom_mcps')) and mcp_wrapper_instance and mcp_wrapper_instance._initialized: - mcp_info = "\n\n--- MCP Tools Available ---\n" - mcp_info += "You have access to external MCP (Model Context Protocol) server tools.\n" - mcp_info += "MCP tools can be called directly using their native function names in the standard function calling format:\n" - mcp_info += '\n' - mcp_info += '\n' - mcp_info += 'value1\n' - mcp_info += 'value2\n' - mcp_info += '\n' - mcp_info += '\n\n' - - # List available MCP tools - mcp_info += "Available MCP tools:\n" - try: - # Get the actual registered schemas from the wrapper - registered_schemas = mcp_wrapper_instance.get_schemas() - for method_name, schema_list in registered_schemas.items(): - if method_name == 'call_mcp_tool': - continue # Skip the fallback method - - # Get the schema info + updated_schemas = mcp_wrapper_instance.get_schemas() + for method_name, schema_list in updated_schemas.items(): for schema in schema_list: - if schema.schema_type == SchemaType.OPENAPI: - func_info = schema.schema.get('function', {}) - description = func_info.get('description', 'No description available') - # Extract server name from description if available - server_match = description.find('(MCP Server: ') - if server_match != -1: - server_end = description.find(')', server_match) - server_info = description[server_match:server_end+1] - else: - server_info = '' - - mcp_info += f"- **{method_name}**: {description}\n" - - # Show parameter info - params = func_info.get('parameters', {}) - props = params.get('properties', {}) - if props: - mcp_info += f" Parameters: {', '.join(props.keys())}\n" - + self.thread_manager.tool_registry.tools[method_name] = { + "instance": mcp_wrapper_instance, + "schema": schema + } + + return mcp_wrapper_instance except Exception as e: - logger.error(f"Error listing MCP tools: {e}") - mcp_info += "- Error loading MCP tool list\n" + logger.error(f"Failed to initialize MCP tools: {e}") + return None + + +class PromptManager: + @staticmethod + async def build_system_prompt(model_name: str, agent_config: Optional[dict], + is_agent_builder: bool, thread_id: str, + mcp_wrapper_instance: Optional[MCPToolWrapper]) -> dict: - # Add critical instructions for using search results - mcp_info += "\n🚨 CRITICAL MCP TOOL RESULT INSTRUCTIONS 🚨\n" - mcp_info += "When you use ANY MCP (Model Context Protocol) tools:\n" - mcp_info += "1. ALWAYS read and use the EXACT results returned by the MCP tool\n" - mcp_info += "2. For search tools: ONLY cite URLs, sources, and information from the actual search results\n" - mcp_info += "3. For any tool: Base your response entirely on the tool's output - do NOT add external information\n" - mcp_info += "4. DO NOT fabricate, invent, hallucinate, or make up any sources, URLs, or data\n" - mcp_info += "5. If you need more information, call the MCP tool again with different parameters\n" - mcp_info += "6. When writing reports/summaries: Reference ONLY the data from MCP tool results\n" - mcp_info += "7. If the MCP tool doesn't return enough information, explicitly state this limitation\n" - mcp_info += "8. Always double-check that every fact, URL, and reference comes from the MCP tool output\n" - mcp_info += "\nIMPORTANT: MCP tool results are your PRIMARY and ONLY source of truth for external data!\n" - mcp_info += "NEVER supplement MCP results with your training data or make assumptions beyond what the tools provide.\n" + if "gemini-2.5-flash" in model_name.lower() and "gemini-2.5-pro" not in model_name.lower(): + default_system_content = get_gemini_system_prompt() + else: + default_system_content = get_system_prompt() - system_content += mcp_info + if "anthropic" not in model_name.lower(): + sample_response_path = os.path.join(os.path.dirname(__file__), 'sample_responses/1.txt') + with open(sample_response_path, 'r') as file: + sample_response = file.read() + default_system_content = default_system_content + "\n\n " + sample_response + "" + + if agent_config and agent_config.get('system_prompt'): + system_content = agent_config['system_prompt'].strip() + elif is_agent_builder: + system_content = get_agent_builder_prompt() + else: + system_content = default_system_content + + if await is_enabled("knowledge_base"): + try: + from services.supabase import DBConnection + kb_db = DBConnection() + kb_client = await kb_db.client + + current_agent_id = agent_config.get('agent_id') if agent_config else None + + kb_result = await kb_client.rpc('get_combined_knowledge_base_context', { + 'p_thread_id': thread_id, + 'p_agent_id': current_agent_id, + 'p_max_tokens': 4000 + }).execute() + + if kb_result.data and kb_result.data.strip(): + system_content += "\n\n" + kb_result.data + + except Exception as e: + logger.error(f"Error retrieving knowledge base context for thread {thread_id}: {e}") - system_message = { "role": "system", "content": system_content } + if agent_config and (agent_config.get('configured_mcps') or agent_config.get('custom_mcps')) and mcp_wrapper_instance and mcp_wrapper_instance._initialized: + mcp_info = "\n\n--- MCP Tools Available ---\n" + mcp_info += "You have access to external MCP (Model Context Protocol) server tools.\n" + mcp_info += "MCP tools can be called directly using their native function names in the standard function calling format:\n" + mcp_info += '\n' + mcp_info += '\n' + mcp_info += 'value1\n' + mcp_info += 'value2\n' + mcp_info += '\n' + mcp_info += '\n\n' + + mcp_info += "Available MCP tools:\n" + try: + registered_schemas = mcp_wrapper_instance.get_schemas() + for method_name, schema_list in registered_schemas.items(): + for schema in schema_list: + if schema.schema_type == SchemaType.OPENAPI: + func_info = schema.schema.get('function', {}) + description = func_info.get('description', 'No description available') + mcp_info += f"- **{method_name}**: {description}\n" + + params = func_info.get('parameters', {}) + props = params.get('properties', {}) + if props: + mcp_info += f" Parameters: {', '.join(props.keys())}\n" + + except Exception as e: + logger.error(f"Error listing MCP tools: {e}") + mcp_info += "- Error loading MCP tool list\n" + + mcp_info += "\n🚨 CRITICAL MCP TOOL RESULT INSTRUCTIONS 🚨\n" + mcp_info += "When you use ANY MCP (Model Context Protocol) tools:\n" + mcp_info += "1. ALWAYS read and use the EXACT results returned by the MCP tool\n" + mcp_info += "2. For search tools: ONLY cite URLs, sources, and information from the actual search results\n" + mcp_info += "3. For any tool: Base your response entirely on the tool's output - do NOT add external information\n" + mcp_info += "4. DO NOT fabricate, invent, hallucinate, or make up any sources, URLs, or data\n" + mcp_info += "5. If you need more information, call the MCP tool again with different parameters\n" + mcp_info += "6. When writing reports/summaries: Reference ONLY the data from MCP tool results\n" + mcp_info += "7. If the MCP tool doesn't return enough information, explicitly state this limitation\n" + mcp_info += "8. Always double-check that every fact, URL, and reference comes from the MCP tool output\n" + mcp_info += "\nIMPORTANT: MCP tool results are your PRIMARY and ONLY source of truth for external data!\n" + mcp_info += "NEVER supplement MCP results with your training data or make assumptions beyond what the tools provide.\n" + + system_content += mcp_info - iteration_count = 0 - continue_execution = True + return {"role": "system", "content": system_content} - latest_user_message = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'user').order('created_at', desc=True).limit(1).execute() - if latest_user_message.data and len(latest_user_message.data) > 0: - data = latest_user_message.data[0]['content'] - if isinstance(data, str): - data = json.loads(data) - if trace: - trace.update(input=data['content']) - while continue_execution and iteration_count < max_iterations: - iteration_count += 1 - logger.info(f"šŸ”„ Running iteration {iteration_count} of {max_iterations}...") +class MessageManager: + def __init__(self, client, thread_id: str, model_name: str, trace: Optional[StatefulTraceClient]): + self.client = client + self.thread_id = thread_id + self.model_name = model_name + self.trace = trace + + async def build_temporary_message(self) -> Optional[dict]: + temp_message_content_list = [] - # Billing check on each iteration - still needed within the iterations - can_run, message, subscription = await check_billing_status(client, account_id) - if not can_run: - error_msg = f"Billing limit reached: {message}" - if trace: - trace.event(name="billing_limit_reached", level="ERROR", status_message=(f"{error_msg}")) - # Yield a special message to indicate billing limit reached - yield { - "type": "status", - "status": "stopped", - "message": error_msg - } - break - # Check if last message is from assistant using direct Supabase query - latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).in_('type', ['assistant', 'tool', 'user']).order('created_at', desc=True).limit(1).execute() - if latest_message.data and len(latest_message.data) > 0: - message_type = latest_message.data[0].get('type') - if message_type == 'assistant': - logger.info(f"Last message was from assistant, stopping execution") - if trace: - trace.event(name="last_message_from_assistant", level="DEFAULT", status_message=(f"Last message was from assistant, stopping execution")) - continue_execution = False - break - - # ---- Temporary Message Handling (Browser State & Image Context) ---- - temporary_message = None - temp_message_content_list = [] # List to hold text/image blocks - - # Get the latest browser_state message - latest_browser_state_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute() + latest_browser_state_msg = await self.client.table('messages').select('*').eq('thread_id', self.thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute() if latest_browser_state_msg.data and len(latest_browser_state_msg.data) > 0: try: browser_content = latest_browser_state_msg.data[0]["content"] @@ -445,7 +299,6 @@ async def run_agent( screenshot_base64 = browser_content.get("screenshot_base64") screenshot_url = browser_content.get("image_url") - # Create a copy of the browser state without screenshot data browser_state_text = browser_content.copy() browser_state_text.pop('screenshot_base64', None) browser_state_text.pop('image_url', None) @@ -456,9 +309,7 @@ async def run_agent( "text": f"The following is the current state of the browser:\n{json.dumps(browser_state_text, indent=2)}" }) - # Only add screenshot if model is not Gemini, Anthropic, or OpenAI - if 'gemini' in model_name.lower() or 'anthropic' in model_name.lower() or 'openai' in model_name.lower(): - # Prioritize screenshot_url if available + if 'gemini' in self.model_name.lower() or 'anthropic' in self.model_name.lower() or 'openai' in self.model_name.lower(): if screenshot_url: temp_message_content_list.append({ "type": "image_url", @@ -467,34 +318,18 @@ async def run_agent( "format": "image/jpeg" } }) - if trace: - trace.event(name="screenshot_url_added_to_temporary_message", level="DEFAULT", status_message=(f"Screenshot URL added to temporary message.")) elif screenshot_base64: - # Fallback to base64 if URL not available temp_message_content_list.append({ "type": "image_url", "image_url": { "url": f"data:image/jpeg;base64,{screenshot_base64}", } }) - if trace: - trace.event(name="screenshot_base64_added_to_temporary_message", level="WARNING", status_message=(f"Screenshot base64 added to temporary message. Prefer screenshot_url if available.")) - else: - logger.warning("Browser state found but no screenshot data.") - if trace: - trace.event(name="browser_state_found_but_no_screenshot_data", level="WARNING", status_message=(f"Browser state found but no screenshot data.")) - else: - logger.warning("Model is Gemini, Anthropic, or OpenAI, so not adding screenshot to temporary message.") - if trace: - trace.event(name="model_is_gemini_anthropic_or_openai", level="WARNING", status_message=(f"Model is Gemini, Anthropic, or OpenAI, so not adding screenshot to temporary message.")) except Exception as e: logger.error(f"Error parsing browser state: {e}") - if trace: - trace.event(name="error_parsing_browser_state", level="ERROR", status_message=(f"{e}")) - # Get the latest image_context message (NEW) - latest_image_context_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'image_context').order('created_at', desc=True).limit(1).execute() + latest_image_context_msg = await self.client.table('messages').select('*').eq('thread_id', self.thread_id).eq('type', 'image_context').order('created_at', desc=True).limit(1).execute() if latest_image_context_msg.data and len(latest_image_context_msg.data) > 0: try: image_context_content = latest_image_context_msg.data[0]["content"] if isinstance(latest_image_context_msg.data[0]["content"], dict) else json.loads(latest_image_context_msg.data[0]["content"]) @@ -513,208 +348,303 @@ async def run_agent( "url": f"data:{mime_type};base64,{base64_image}", } }) - else: - logger.warning(f"Image context found for '{file_path}' but missing base64 or mime_type.") - await client.table('messages').delete().eq('message_id', latest_image_context_msg.data[0]["message_id"]).execute() + await self.client.table('messages').delete().eq('message_id', latest_image_context_msg.data[0]["message_id"]).execute() except Exception as e: logger.error(f"Error parsing image context: {e}") - if trace: - trace.event(name="error_parsing_image_context", level="ERROR", status_message=(f"{e}")) - # If we have any content, construct the temporary_message if temp_message_content_list: - temporary_message = {"role": "user", "content": temp_message_content_list} - # logger.debug(f"Constructed temporary message with {len(temp_message_content_list)} content blocks.") - # ---- End Temporary Message Handling ---- + return {"role": "user", "content": temp_message_content_list} + return None - # Set max_tokens based on model - max_tokens = None - if "sonnet" in model_name.lower(): - # Claude 3.5 Sonnet has a limit of 8192 tokens - max_tokens = 8192 - elif "gpt-4" in model_name.lower(): - max_tokens = 4096 - elif "gemini-2.5-pro" in model_name.lower(): - # Gemini 2.5 Pro has 64k max output tokens - max_tokens = 64000 - elif "kimi-k2" in model_name.lower(): - # Kimi-K2 has 120K context, set reasonable max output tokens - max_tokens = 8192 + +class AgentRunner: + def __init__(self, config: AgentConfig): + self.config = config + + async def setup(self): + if not self.config.trace: + self.config.trace = langfuse.trace(name="run_agent", session_id=self.config.thread_id, metadata={"project_id": self.config.project_id}) + + self.thread_manager = ThreadManager( + trace=self.config.trace, + is_agent_builder=self.config.is_agent_builder or False, + target_agent_id=self.config.target_agent_id, + agent_config=self.config.agent_config + ) + + self.client = await self.thread_manager.db.client + self.account_id = await get_account_id_from_thread(self.client, self.config.thread_id) + if not self.account_id: + raise ValueError("Could not determine account ID for thread") + + project = await self.client.table('projects').select('*').eq('project_id', self.config.project_id).execute() + if not project.data or len(project.data) == 0: + raise ValueError(f"Project {self.config.project_id} not found") + + project_data = project.data[0] + sandbox_info = project_data.get('sandbox', {}) + if not sandbox_info.get('id'): + raise ValueError(f"No sandbox found for project {self.config.project_id}") + + async def setup_tools(self): + tool_manager = ToolManager(self.thread_manager, self.config.project_id, self.config.thread_id) + + if self.config.agent_config and self.config.agent_config.get('is_suna_default', False): + suna_agent_id = self.config.agent_config['agent_id'] + tool_manager.register_agent_builder_tools(suna_agent_id) + + if self.config.is_agent_builder: + tool_manager.register_agent_builder_tools(self.config.target_agent_id) + + enabled_tools = None + if self.config.agent_config and 'agentpress_tools' in self.config.agent_config: + raw_tools = self.config.agent_config['agentpress_tools'] - generation = trace.generation(name="thread_manager.run_thread") if trace else None - try: - # Make the LLM call and process the response - response = await thread_manager.run_thread( - thread_id=thread_id, - system_prompt=system_message, - stream=stream, - llm_model=model_name, - llm_temperature=0, - llm_max_tokens=max_tokens, - tool_choice="auto", - max_xml_tool_calls=1, - temporary_message=temporary_message, - processor_config=ProcessorConfig( - xml_tool_calling=True, - native_tool_calling=False, - execute_tools=True, - execute_on_stream=True, - tool_execution_strategy="parallel", - xml_adding_strategy="user_message" - ), - native_max_auto_continues=native_max_auto_continues, - include_xml_examples=True, - enable_thinking=enable_thinking, - reasoning_effort=reasoning_effort, - enable_context_manager=enable_context_manager, - generation=generation - ) + if isinstance(raw_tools, dict): + if self.config.agent_config.get('is_suna_default', False) and not raw_tools: + enabled_tools = None + else: + enabled_tools = raw_tools + else: + enabled_tools = None - if isinstance(response, dict) and "status" in response and response["status"] == "error": - logger.error(f"Error response from run_thread: {response.get('message', 'Unknown error')}") - if trace: - trace.event(name="error_response_from_run_thread", level="ERROR", status_message=(f"{response.get('message', 'Unknown error')}")) - yield response + if enabled_tools is None: + tool_manager.register_all_tools() + else: + if not isinstance(enabled_tools, dict): + enabled_tools = {} + tool_manager.register_custom_tools(enabled_tools) + + async def setup_mcp_tools(self) -> Optional[MCPToolWrapper]: + if not self.config.agent_config: + return None + + mcp_manager = MCPManager(self.thread_manager, self.account_id) + return await mcp_manager.register_mcp_tools(self.config.agent_config) + + def get_max_tokens(self) -> Optional[int]: + if "sonnet" in self.config.model_name.lower(): + return 8192 + elif "gpt-4" in self.config.model_name.lower(): + return 4096 + elif "gemini-2.5-pro" in self.config.model_name.lower(): + return 64000 + elif "kimi-k2" in self.config.model_name.lower(): + return 8192 + return None + + async def run(self) -> AsyncGenerator[Dict[str, Any], None]: + await self.setup() + await self.setup_tools() + mcp_wrapper_instance = await self.setup_mcp_tools() + + system_message = await PromptManager.build_system_prompt( + self.config.model_name, self.config.agent_config, + self.config.is_agent_builder, self.config.thread_id, + mcp_wrapper_instance + ) + + iteration_count = 0 + continue_execution = True + + latest_user_message = await self.client.table('messages').select('*').eq('thread_id', self.config.thread_id).eq('type', 'user').order('created_at', desc=True).limit(1).execute() + if latest_user_message.data and len(latest_user_message.data) > 0: + data = latest_user_message.data[0]['content'] + if isinstance(data, str): + data = json.loads(data) + if self.config.trace: + self.config.trace.update(input=data['content']) + + message_manager = MessageManager(self.client, self.config.thread_id, self.config.model_name, self.config.trace) + + while continue_execution and iteration_count < self.config.max_iterations: + iteration_count += 1 + + can_run, message, subscription = await check_billing_status(self.client, self.account_id) + if not can_run: + error_msg = f"Billing limit reached: {message}" + yield { + "type": "status", + "status": "stopped", + "message": error_msg + } break - # Track if we see ask, complete, or web-browser-takeover tool calls - last_tool_call = None - agent_should_terminate = False + latest_message = await self.client.table('messages').select('*').eq('thread_id', self.config.thread_id).in_('type', ['assistant', 'tool', 'user']).order('created_at', desc=True).limit(1).execute() + if latest_message.data and len(latest_message.data) > 0: + message_type = latest_message.data[0].get('type') + if message_type == 'assistant': + continue_execution = False + break - # Process the response - error_detected = False - full_response = "" + temporary_message = await message_manager.build_temporary_message() + max_tokens = self.get_max_tokens() + + generation = self.config.trace.generation(name="thread_manager.run_thread") if self.config.trace else None try: - # Check if response is iterable (async generator) or a dict (error case) - if hasattr(response, '__aiter__') and not isinstance(response, dict): - async for chunk in response: - # If we receive an error chunk, we should stop after this iteration - if isinstance(chunk, dict) and chunk.get('type') == 'status' and chunk.get('status') == 'error': - logger.error(f"Error chunk detected: {chunk.get('message', 'Unknown error')}") - if trace: - trace.event(name="error_chunk_detected", level="ERROR", status_message=(f"{chunk.get('message', 'Unknown error')}")) - error_detected = True - yield chunk # Forward the error chunk - continue # Continue processing other chunks but don't break yet - - # Check for termination signal in status messages - if chunk.get('type') == 'status': - try: - # Parse the metadata to check for termination signal - metadata = chunk.get('metadata', {}) - if isinstance(metadata, str): - metadata = json.loads(metadata) - - if metadata.get('agent_should_terminate'): - agent_should_terminate = True - logger.info("Agent termination signal detected in status message") - if trace: - trace.event(name="agent_termination_signal_detected", level="DEFAULT", status_message="Agent termination signal detected in status message") + response = await self.thread_manager.run_thread( + thread_id=self.config.thread_id, + system_prompt=system_message, + stream=self.config.stream, + llm_model=self.config.model_name, + llm_temperature=0, + llm_max_tokens=max_tokens, + tool_choice="auto", + max_xml_tool_calls=1, + temporary_message=temporary_message, + processor_config=ProcessorConfig( + xml_tool_calling=True, + native_tool_calling=False, + execute_tools=True, + execute_on_stream=True, + tool_execution_strategy="parallel", + xml_adding_strategy="user_message" + ), + native_max_auto_continues=self.config.native_max_auto_continues, + include_xml_examples=True, + enable_thinking=self.config.enable_thinking, + reasoning_effort=self.config.reasoning_effort, + enable_context_manager=self.config.enable_context_manager, + generation=generation + ) + + if isinstance(response, dict) and "status" in response and response["status"] == "error": + yield response + break + + last_tool_call = None + agent_should_terminate = False + error_detected = False + full_response = "" + + try: + if hasattr(response, '__aiter__') and not isinstance(response, dict): + async for chunk in response: + if isinstance(chunk, dict) and chunk.get('type') == 'status' and chunk.get('status') == 'error': + error_detected = True + yield chunk + continue + + if chunk.get('type') == 'status': + try: + metadata = chunk.get('metadata', {}) + if isinstance(metadata, str): + metadata = json.loads(metadata) - # Extract the tool name from the status content if available - content = chunk.get('content', {}) - if isinstance(content, str): - content = json.loads(content) - - if content.get('function_name'): - last_tool_call = content['function_name'] - elif content.get('xml_tag_name'): - last_tool_call = content['xml_tag_name'] + if metadata.get('agent_should_terminate'): + agent_should_terminate = True - except Exception as e: - logger.debug(f"Error parsing status message for termination check: {e}") + content = chunk.get('content', {}) + if isinstance(content, str): + content = json.loads(content) + + if content.get('function_name'): + last_tool_call = content['function_name'] + elif content.get('xml_tag_name'): + last_tool_call = content['xml_tag_name'] + + except Exception: + pass - # Check for XML versions like , , or in assistant content chunks - if chunk.get('type') == 'assistant' and 'content' in chunk: - try: - # The content field might be a JSON string or object - content = chunk.get('content', '{}') - if isinstance(content, str): - assistant_content_json = json.loads(content) - else: - assistant_content_json = content + if chunk.get('type') == 'assistant' and 'content' in chunk: + try: + content = chunk.get('content', '{}') + if isinstance(content, str): + assistant_content_json = json.loads(content) + else: + assistant_content_json = content - # The actual text content is nested within - assistant_text = assistant_content_json.get('content', '') - full_response += assistant_text - if isinstance(assistant_text, str): - if '' in assistant_text or '' in assistant_text or '' in assistant_text: - if '' in assistant_text: - xml_tool = 'ask' - elif '' in assistant_text: - xml_tool = 'complete' - elif '' in assistant_text: - xml_tool = 'web-browser-takeover' + assistant_text = assistant_content_json.get('content', '') + full_response += assistant_text + if isinstance(assistant_text, str): + if '' in assistant_text or '' in assistant_text or '' in assistant_text: + if '' in assistant_text: + xml_tool = 'ask' + elif '' in assistant_text: + xml_tool = 'complete' + elif '' in assistant_text: + xml_tool = 'web-browser-takeover' - last_tool_call = xml_tool - logger.info(f"Agent used XML tool: {xml_tool}") - if trace: - trace.event(name="agent_used_xml_tool", level="DEFAULT", status_message=(f"Agent used XML tool: {xml_tool}")) - - except json.JSONDecodeError: - # Handle cases where content might not be valid JSON - logger.warning(f"Warning: Could not parse assistant content JSON: {chunk.get('content')}") - if trace: - trace.event(name="warning_could_not_parse_assistant_content_json", level="WARNING", status_message=(f"Warning: Could not parse assistant content JSON: {chunk.get('content')}")) - except Exception as e: - logger.error(f"Error processing assistant chunk: {e}") - if trace: - trace.event(name="error_processing_assistant_chunk", level="ERROR", status_message=(f"Error processing assistant chunk: {e}")) + last_tool_call = xml_tool + + except json.JSONDecodeError: + pass + except Exception: + pass - yield chunk - else: - # Response is not iterable, likely an error dict - logger.error(f"Response is not iterable: {response}") - error_detected = True + yield chunk + else: + error_detected = True - # Check if we should stop based on the last tool call or error - if error_detected: - logger.info(f"Stopping due to error detected in response") - if trace: - trace.event(name="stopping_due_to_error_detected_in_response", level="DEFAULT", status_message=(f"Stopping due to error detected in response")) + if error_detected: + if generation: + generation.end(output=full_response, status_message="error_detected", level="ERROR") + break + + if agent_should_terminate or last_tool_call in ['ask', 'complete', 'web-browser-takeover']: + if generation: + generation.end(output=full_response, status_message="agent_stopped") + continue_execution = False + + except Exception as e: + error_msg = f"Error during response streaming: {str(e)}" if generation: - generation.end(output=full_response, status_message="error_detected", level="ERROR") + generation.end(output=full_response, status_message=error_msg, level="ERROR") + yield { + "type": "status", + "status": "error", + "message": error_msg + } break - if agent_should_terminate or last_tool_call in ['ask', 'complete', 'web-browser-takeover']: - logger.info(f"Agent decided to stop with tool: {last_tool_call}") - if trace: - trace.event(name="agent_decided_to_stop_with_tool", level="DEFAULT", status_message=(f"Agent decided to stop with tool: {last_tool_call}")) - if generation: - generation.end(output=full_response, status_message="agent_stopped") - continue_execution = False - except Exception as e: - # Just log the error and re-raise to stop all iterations - error_msg = f"Error during response streaming: {str(e)}" - logger.error(f"Error: {error_msg}") - if trace: - trace.event(name="error_during_response_streaming", level="ERROR", status_message=(f"Error during response streaming: {str(e)}")) - if generation: - generation.end(output=full_response, status_message=error_msg, level="ERROR") + error_msg = f"Error running thread: {str(e)}" yield { "type": "status", "status": "error", "message": error_msg } - # Stop execution immediately on any error break - - except Exception as e: - # Just log the error and re-raise to stop all iterations - error_msg = f"Error running thread: {str(e)}" - logger.error(f"Error: {error_msg}") - if trace: - trace.event(name="error_running_thread", level="ERROR", status_message=(f"Error running thread: {str(e)}")) - yield { - "type": "status", - "status": "error", - "message": error_msg - } - # Stop execution immediately on any error - break - if generation: - generation.end(output=full_response) + + if generation: + generation.end(output=full_response) - asyncio.create_task(asyncio.to_thread(lambda: langfuse.flush())) \ No newline at end of file + asyncio.create_task(asyncio.to_thread(lambda: langfuse.flush())) + + +async def run_agent( + thread_id: str, + project_id: str, + stream: bool, + thread_manager: Optional[ThreadManager] = None, + native_max_auto_continues: int = 25, + max_iterations: int = 100, + model_name: str = "anthropic/claude-sonnet-4-20250514", + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low', + enable_context_manager: bool = True, + agent_config: Optional[dict] = None, + trace: Optional[StatefulTraceClient] = None, + is_agent_builder: Optional[bool] = False, + target_agent_id: Optional[str] = None +): + config = AgentConfig( + thread_id=thread_id, + project_id=project_id, + stream=stream, + native_max_auto_continues=native_max_auto_continues, + max_iterations=max_iterations, + model_name=model_name, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort, + enable_context_manager=enable_context_manager, + agent_config=agent_config, + trace=trace, + is_agent_builder=is_agent_builder, + target_agent_id=target_agent_id + ) + + runner = AgentRunner(config) + async for chunk in runner.run(): + yield chunk \ No newline at end of file diff --git a/backend/agent/tools/agent_builder_tools/agent_config_tool.py b/backend/agent/tools/agent_builder_tools/agent_config_tool.py index 44170279..6515053c 100644 --- a/backend/agent/tools/agent_builder_tools/agent_config_tool.py +++ b/backend/agent/tools/agent_builder_tools/agent_config_tool.py @@ -194,16 +194,40 @@ class AgentConfigTool(AgentBuilderBaseTool): if isinstance(configured_mcps, str): configured_mcps = json.loads(configured_mcps) - existing_mcps_by_name = {mcp.get('qualifiedName', ''): mcp for mcp in current_configured_mcps} + def get_mcp_identifier(mcp): + if not isinstance(mcp, dict): + return None + return ( + mcp.get('qualifiedName') or + mcp.get('name') or + f"{mcp.get('type', 'unknown')}_{mcp.get('config', {}).get('url', 'nourl')}" or + str(hash(json.dumps(mcp, sort_keys=True))) + ) + + merged_mcps = [] + existing_identifiers = set() + + for existing_mcp in current_configured_mcps: + identifier = get_mcp_identifier(existing_mcp) + if identifier: + existing_identifiers.add(identifier) + merged_mcps.append(existing_mcp) for new_mcp in configured_mcps: - qualified_name = new_mcp.get('qualifiedName', '') - if qualified_name: - existing_mcps_by_name[qualified_name] = new_mcp + identifier = get_mcp_identifier(new_mcp) + + if identifier and identifier in existing_identifiers: + for i, existing_mcp in enumerate(merged_mcps): + if get_mcp_identifier(existing_mcp) == identifier: + merged_mcps[i] = new_mcp + break else: - current_configured_mcps.append(new_mcp) + merged_mcps.append(new_mcp) + if identifier: + existing_identifiers.add(identifier) - current_configured_mcps = list(existing_mcps_by_name.values()) + current_configured_mcps = merged_mcps + logger.info(f"MCP merge result: {len(current_configured_mcps)} total MCPs (was {len(current_version.get('configured_mcps', []))}, adding {len(configured_mcps)})") current_custom_mcps = current_version.get('custom_mcps', []) diff --git a/backend/agent/tools/agent_builder_tools/credential_profile_tool.py b/backend/agent/tools/agent_builder_tools/credential_profile_tool.py index 6ef93d20..c134f6e5 100644 --- a/backend/agent/tools/agent_builder_tools/credential_profile_tool.py +++ b/backend/agent/tools/agent_builder_tools/credential_profile_tool.py @@ -278,13 +278,11 @@ class CredentialProfileTool(AgentBuilderBaseTool): if profile.is_connected and connections: try: - # directly discover MCP servers via the facade from pipedream.domain.entities import ConnectionStatus servers = await self.pipedream_manager.discover_mcp_servers( external_user_id=profile.external_user_id.value if hasattr(profile.external_user_id, 'value') else str(profile.external_user_id), app_slug=profile.app_slug.value if hasattr(profile.app_slug, 'value') else str(profile.app_slug) ) - # filter connected servers connected_servers = [s for s in servers if s.status == ConnectionStatus.CONNECTED] if connected_servers: tools = [t.name for t in connected_servers[0].available_tools] @@ -422,7 +420,6 @@ class CredentialProfileTool(AgentBuilderBaseTool): if not profile: return self.fail_response("Credential profile not found") - # Get current version config agent_result = await client.table('agents').select('current_version_id').eq('agent_id', self.agent_id).execute() if agent_result.data and agent_result.data[0].get('current_version_id'): version_result = await client.table('agent_versions')\ diff --git a/backend/agent/tools/agent_builder_tools/mcp_search_tool.py b/backend/agent/tools/agent_builder_tools/mcp_search_tool.py index d7824596..1fb51065 100644 --- a/backend/agent/tools/agent_builder_tools/mcp_search_tool.py +++ b/backend/agent/tools/agent_builder_tools/mcp_search_tool.py @@ -145,10 +145,54 @@ class MCPSearchTool(AgentBuilderBaseTool): "available_triggers": getattr(app_data, 'available_triggers', []) } - return self.success_response({ + available_tools = [] + try: + import httpx + import json + + url = f"https://remote.mcp.pipedream.net/?app={app_slug}&externalUserId=tools_preview" + payload = {"jsonrpc": "2.0", "method": "tools/list", "params": {}, "id": 1} + headers = {"Content-Type": "application/json", "Accept": "application/json, text/event-stream"} + + async with httpx.AsyncClient(timeout=30.0) as client: + async with client.stream("POST", url, json=payload, headers=headers) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + if not line or not line.startswith("data:"): + continue + data_str = line[len("data:"):].strip() + try: + data_obj = json.loads(data_str) + tools = data_obj.get("result", {}).get("tools", []) + for tool in tools: + desc = tool.get("description", "") or "" + idx = desc.find("[") + if idx != -1: + desc = desc[:idx].strip() + + available_tools.append({ + "name": tool.get("name", ""), + "description": desc + }) + break + except json.JSONDecodeError: + logger.warning(f"Failed to parse JSON data: {data_str}") + continue + + except Exception as tools_error: + logger.warning(f"Could not fetch MCP tools for {app_slug}: {tools_error}") + + result = { "message": f"Retrieved details for {formatted_app['name']}", - "app": formatted_app - }) + "app": formatted_app, + "available_mcp_tools": available_tools, + "total_mcp_tools": len(available_tools) + } + + if available_tools: + result["message"] += f" - {len(available_tools)} MCP tools available" + + return self.success_response(result) except Exception as e: return self.fail_response(f"Error getting app details: {str(e)}") diff --git a/backend/agent/tools/mcp_tool_wrapper.py b/backend/agent/tools/mcp_tool_wrapper.py index 79a735a3..17169d66 100644 --- a/backend/agent/tools/mcp_tool_wrapper.py +++ b/backend/agent/tools/mcp_tool_wrapper.py @@ -1,5 +1,5 @@ from typing import Any, Dict, List, Optional -from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema, ToolSchema, SchemaType +from agentpress.tool import Tool, ToolResult, ToolSchema, SchemaType from mcp_module import mcp_manager from utils.logger import logger import inspect @@ -77,33 +77,55 @@ class MCPToolWrapper(Tool): logger.info(f"Created {len(self._dynamic_tools)} dynamic MCP tool methods") + # Re-register schemas to pick up the dynamic methods + self._register_schemas() + logger.info(f"Re-registered schemas after creating dynamic tools - total: {len(self._schemas)}") + except Exception as e: logger.error(f"Error creating dynamic MCP tools: {e}") def _register_schemas(self): + self._schemas.clear() + for name, method in inspect.getmembers(self, predicate=inspect.ismethod): if hasattr(method, 'tool_schemas'): self._schemas[name] = method.tool_schemas logger.debug(f"Registered schemas for method '{name}' in {self.__class__.__name__}") - logger.debug(f"Initial registration complete for MCPToolWrapper") + if hasattr(self, '_dynamic_tools') and self._dynamic_tools: + for tool_name, tool_data in self._dynamic_tools.items(): + method_name = tool_data.get('method_name') + if method_name and method_name in self._schemas: + continue + + method = tool_data.get('method') + if method and hasattr(method, 'tool_schemas'): + self._schemas[method_name] = method.tool_schemas + logger.debug(f"Registered dynamic method schemas for '{method_name}'") + + logger.debug(f"Registration complete for MCPToolWrapper - total schemas: {len(self._schemas)}") def get_schemas(self) -> Dict[str, List[ToolSchema]]: + logger.debug(f"get_schemas called - returning {len(self._schemas)} schemas") + for method_name in self._schemas: + logger.debug(f" - Schema available for: {method_name}") return self._schemas def __getattr__(self, name: str): - method = self.tool_builder.find_method_by_name(name) - if method: - return method + if hasattr(self, 'tool_builder') and self.tool_builder: + method = self.tool_builder.find_method_by_name(name) + if method: + return method - for tool_data in self._dynamic_tools.values(): - if tool_data.get('method_name') == name: - return tool_data.get('method') - - name_with_hyphens = name.replace('_', '-') - for tool_name, tool_data in self._dynamic_tools.items(): - if tool_data.get('method_name') == name or tool_name == name_with_hyphens: - return tool_data.get('method') + if hasattr(self, '_dynamic_tools') and self._dynamic_tools: + for tool_data in self._dynamic_tools.values(): + if tool_data.get('method_name') == name: + return tool_data.get('method') + + name_with_hyphens = name.replace('_', '-') + for tool_name, tool_data in self._dynamic_tools.items(): + if tool_data.get('method_name') == name or tool_name == name_with_hyphens: + return tool_data.get('method') raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") @@ -111,10 +133,7 @@ class MCPToolWrapper(Tool): await self._ensure_initialized() if tool_registry and self._dynamic_tools: logger.info(f"Updating tool registry with {len(self._dynamic_tools)} MCP tools") - for method_name, schemas in self._schemas.items(): - if method_name not in ['call_mcp_tool']: - pass - + async def get_available_tools(self) -> List[Dict[str, Any]]: await self._ensure_initialized() return self.mcp_manager.get_all_tools_openapi() @@ -123,46 +142,6 @@ class MCPToolWrapper(Tool): await self._ensure_initialized() return await self.tool_executor.execute_tool(tool_name, arguments) - @openapi_schema({ - "type": "function", - "function": { - "name": "call_mcp_tool", - "description": "Execute a tool from any connected MCP server. This is a fallback wrapper that forwards calls to MCP tools. The tool_name should be in the format 'mcp_{server}_{tool}' where {server} is the MCP server's qualified name and {tool} is the specific tool name.", - "parameters": { - "type": "object", - "properties": { - "tool_name": { - "type": "string", - "description": "The full MCP tool name in format 'mcp_{server}_{tool}', e.g., 'mcp_exa_web_search_exa'" - }, - "arguments": { - "type": "object", - "description": "The arguments to pass to the MCP tool, as a JSON object. The required arguments depend on the specific tool being called.", - "additionalProperties": True - } - }, - "required": ["tool_name", "arguments"] - } - } - }) - @xml_schema( - tag_name="call-mcp-tool", - mappings=[ - {"param_name": "tool_name", "node_type": "attribute", "path": "."}, - {"param_name": "arguments", "node_type": "content", "path": "."} - ], - example=''' - - - mcp_exa_web_search_exa - {"query": "latest developments in AI", "num_results": 10} - - - ''' - ) - async def call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any]) -> ToolResult: - return await self._execute_mcp_tool(tool_name, arguments) - async def cleanup(self): if self._initialized: try: diff --git a/backend/pipedream/facade.py b/backend/pipedream/facade.py index 4c47a000..5f35bc0f 100644 --- a/backend/pipedream/facade.py +++ b/backend/pipedream/facade.py @@ -337,6 +337,7 @@ class PipedreamManager: user_id: str, enabled_tools: List[str] ) -> Dict[str, Any]: + from services.supabase import DBConnection from agent.versioning.version_service import get_version_service import copy @@ -353,17 +354,24 @@ class PipedreamManager: agent = agent_result.data[0] + print(f"[DEBUG] Starting update_agent_profile_tools for agent {agent_id}, profile {profile_id}") + current_version_data = None if agent.get('current_version_id'): try: - current_version_data = await get_version_service().get_version( + version_service = await get_version_service() + current_version_data = await version_service.get_version( agent_id=agent_id, version_id=agent['current_version_id'], user_id=user_id ) version_data = current_version_data.to_dict() current_custom_mcps = version_data.get('custom_mcps', []) + print(f"[DEBUG] Retrieved current version {agent['current_version_id']}") + print(f"[DEBUG] Current custom_mcps count: {len(current_custom_mcps)}") + print(f"[DEBUG] Current custom_mcps: {current_custom_mcps}") except Exception as e: + print(f"[DEBUG] Error getting current version: {e}") pass @@ -376,30 +384,40 @@ class PipedreamManager: configured_mcps = current_version_data.configured_mcps agentpress_tools = current_version_data.agentpress_tools current_custom_mcps = current_version_data.custom_mcps + print(f"[DEBUG] Using version data - custom_mcps count: {len(current_custom_mcps)}") else: system_prompt = '' configured_mcps = [] agentpress_tools = {} current_custom_mcps = [] + print(f"[DEBUG] No version data - starting with empty custom_mcps") + updated_custom_mcps = copy.deepcopy(current_custom_mcps) + print(f"[DEBUG] After deepcopy - updated_custom_mcps count: {len(updated_custom_mcps)}") + print(f"[DEBUG] After deepcopy - updated_custom_mcps: {updated_custom_mcps}") + # Normalize enabledTools vs enabled_tools for mcp in updated_custom_mcps: if 'enabled_tools' in mcp and 'enabledTools' not in mcp: mcp['enabledTools'] = mcp['enabled_tools'] elif 'enabledTools' not in mcp and 'enabled_tools' not in mcp: mcp['enabledTools'] = [] + # Look for existing MCP with same profile_id found_match = False - for mcp in updated_custom_mcps: + for i, mcp in enumerate(updated_custom_mcps): + print(f"[DEBUG] Checking MCP {i}: type={mcp.get('type')}, profile_id={mcp.get('config', {}).get('profile_id')}") if (mcp.get('type') == 'pipedream' and mcp.get('config', {}).get('profile_id') == profile_id): + print(f"[DEBUG] Found existing MCP at index {i}, updating tools from {mcp.get('enabledTools', [])} to {enabled_tools}") mcp['enabledTools'] = enabled_tools mcp['enabled_tools'] = enabled_tools found_match = True break if not found_match: + print(f"[DEBUG] No existing MCP found, creating new one") new_mcp_config = { "name": profile.app_name, "type": "pipedream", @@ -413,7 +431,12 @@ class PipedreamManager: "enabledTools": enabled_tools, "enabled_tools": enabled_tools } + print(f"[DEBUG] New MCP config: {new_mcp_config}") updated_custom_mcps.append(new_mcp_config) + print(f"[DEBUG] After append - updated_custom_mcps count: {len(updated_custom_mcps)}") + + print(f"[DEBUG] Final updated_custom_mcps count: {len(updated_custom_mcps)}") + print(f"[DEBUG] Final updated_custom_mcps: {updated_custom_mcps}") version_service = await get_version_service() @@ -428,6 +451,8 @@ class PipedreamManager: change_description=f"Updated {profile.app_name} tools" ) + print(f"[DEBUG] Created new version {new_version.version_id} with {len(updated_custom_mcps)} custom MCPs") + update_result = await client.table('agents').update({ 'current_version_id': new_version.version_id }).eq('agent_id', agent_id).execute() @@ -440,7 +465,7 @@ class PipedreamManager: 'enabled_tools': enabled_tools, 'total_tools': len(enabled_tools), 'version_id': new_version.version_id, - 'version_name': new_version['version_name'] + 'version_name': new_version.version_name } async def close(self): diff --git a/frontend/src/components/thread/content/ThreadContent.tsx b/frontend/src/components/thread/content/ThreadContent.tsx index ae37cdb2..f2a8f81f 100644 --- a/frontend/src/components/thread/content/ThreadContent.tsx +++ b/frontend/src/components/thread/content/ThreadContent.tsx @@ -49,8 +49,6 @@ const HIDE_STREAMING_XML_TAGS = new Set([ 'crawl-webpage', 'web-search', 'see-image', - 'call-mcp-tool', - 'execute_data_provider_call', 'execute_data_provider_endpoint', diff --git a/frontend/src/components/thread/tool-views/get-current-agent-config/_utils.ts b/frontend/src/components/thread/tool-views/get-current-agent-config/_utils.ts index bd7f0477..9b5fbaf0 100644 --- a/frontend/src/components/thread/tool-views/get-current-agent-config/_utils.ts +++ b/frontend/src/components/thread/tool-views/get-current-agent-config/_utils.ts @@ -13,7 +13,7 @@ export interface CustomMcp { headers?: Record; profile_id?: string; }; - enabled_tools: string[]; + enabledTools: string[]; } export interface AgentConfiguration { diff --git a/frontend/src/components/thread/tool-views/get-current-agent-config/get-current-agent-config.tsx b/frontend/src/components/thread/tool-views/get-current-agent-config/get-current-agent-config.tsx index e7474315..0e885d37 100644 --- a/frontend/src/components/thread/tool-views/get-current-agent-config/get-current-agent-config.tsx +++ b/frontend/src/components/thread/tool-views/get-current-agent-config/get-current-agent-config.tsx @@ -83,7 +83,10 @@ export function GetCurrentAgentConfigToolView({ }; const getTotalMcpToolsCount = (mcps: CustomMcp[]) => { - return mcps.reduce((total, mcp) => total + mcp.enabled_tools.length, 0); + return mcps.reduce((total, mcp) => { + const enabledTools = mcp.enabledTools || []; + return total + (Array.isArray(enabledTools) ? enabledTools.length : 0); + }, 0); }; return ( @@ -268,12 +271,12 @@ export function GetCurrentAgentConfigToolView({ - {mcp.enabled_tools.length} tools + {mcp.enabledTools.length} tools
- {mcp.enabled_tools.map((tool, toolIndex) => ( + {mcp.enabledTools.map((tool, toolIndex) => (
diff --git a/frontend/src/components/thread/utils.ts b/frontend/src/components/thread/utils.ts index c5942ce4..57f03d69 100644 --- a/frontend/src/components/thread/utils.ts +++ b/frontend/src/components/thread/utils.ts @@ -387,8 +387,6 @@ const TOOL_DISPLAY_NAMES = new Map([ ['web_search', 'Searching Web'], ['see_image', 'Viewing Image'], - ['call_mcp_tool', 'External Tool'], - ['update_agent', 'Updating Agent'], ['get_current_agent_config', 'Getting Agent Config'], ['search_mcp_servers', 'Searching MCP Servers'],