mirror of https://github.com/kortix-ai/suna.git
refactor versioning & run.py
This commit is contained in:
parent
f848d5c10f
commit
8e1cce5cbd
|
@ -1,9 +1,9 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Optional
|
from typing import Optional, Dict, List, Any, AsyncGenerator
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
# from agent.tools.message_tool import MessageTool
|
|
||||||
from agent.tools.message_tool import MessageTool
|
from agent.tools.message_tool import MessageTool
|
||||||
from agent.tools.sb_deploy_tool import SandboxDeployTool
|
from agent.tools.sb_deploy_tool import SandboxDeployTool
|
||||||
from agent.tools.sb_expose_tool import SandboxExposeTool
|
from agent.tools.sb_expose_tool import SandboxExposeTool
|
||||||
|
@ -27,202 +27,129 @@ from agent.tools.sb_vision_tool import SandboxVisionTool
|
||||||
from agent.tools.sb_image_edit_tool import SandboxImageEditTool
|
from agent.tools.sb_image_edit_tool import SandboxImageEditTool
|
||||||
from services.langfuse import langfuse
|
from services.langfuse import langfuse
|
||||||
from langfuse.client import StatefulTraceClient
|
from langfuse.client import StatefulTraceClient
|
||||||
from services.langfuse import langfuse
|
|
||||||
from agent.gemini_prompt import get_gemini_system_prompt
|
from agent.gemini_prompt import get_gemini_system_prompt
|
||||||
from agent.tools.mcp_tool_wrapper import MCPToolWrapper
|
from agent.tools.mcp_tool_wrapper import MCPToolWrapper
|
||||||
from agentpress.tool import SchemaType
|
from agentpress.tool import SchemaType
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
async def run_agent(
|
|
||||||
thread_id: str,
|
@dataclass
|
||||||
project_id: str,
|
class AgentConfig:
|
||||||
stream: bool,
|
thread_id: str
|
||||||
thread_manager: Optional[ThreadManager] = None,
|
project_id: str
|
||||||
native_max_auto_continues: int = 25,
|
stream: bool
|
||||||
max_iterations: int = 100,
|
native_max_auto_continues: int = 25
|
||||||
model_name: str = "anthropic/claude-sonnet-4-20250514",
|
max_iterations: int = 100
|
||||||
enable_thinking: Optional[bool] = False,
|
model_name: str = "anthropic/claude-sonnet-4-20250514"
|
||||||
reasoning_effort: Optional[str] = 'low',
|
enable_thinking: Optional[bool] = False
|
||||||
enable_context_manager: bool = True,
|
reasoning_effort: Optional[str] = 'low'
|
||||||
agent_config: Optional[dict] = None,
|
enable_context_manager: bool = True
|
||||||
trace: Optional[StatefulTraceClient] = None,
|
agent_config: Optional[dict] = None
|
||||||
is_agent_builder: Optional[bool] = False,
|
trace: Optional[StatefulTraceClient] = None
|
||||||
|
is_agent_builder: Optional[bool] = False
|
||||||
target_agent_id: Optional[str] = None
|
target_agent_id: Optional[str] = None
|
||||||
):
|
|
||||||
"""Run the development agent with specified configuration."""
|
|
||||||
logger.info(f"🚀 Starting agent with model: {model_name}")
|
|
||||||
if agent_config:
|
|
||||||
logger.info(f"Using custom agent: {agent_config.get('name', 'Unknown')}")
|
|
||||||
|
|
||||||
if not trace:
|
|
||||||
trace = langfuse.trace(name="run_agent", session_id=thread_id, metadata={"project_id": project_id})
|
|
||||||
thread_manager = ThreadManager(trace=trace, is_agent_builder=is_agent_builder or False, target_agent_id=target_agent_id, agent_config=agent_config)
|
|
||||||
|
|
||||||
client = await thread_manager.db.client
|
|
||||||
|
|
||||||
# Get account ID from thread for billing checks
|
|
||||||
account_id = await get_account_id_from_thread(client, thread_id)
|
|
||||||
if not account_id:
|
|
||||||
raise ValueError("Could not determine account ID for thread")
|
|
||||||
|
|
||||||
# Get sandbox info from project
|
|
||||||
project = await client.table('projects').select('*').eq('project_id', project_id).execute()
|
|
||||||
if not project.data or len(project.data) == 0:
|
|
||||||
raise ValueError(f"Project {project_id} not found")
|
|
||||||
|
|
||||||
project_data = project.data[0]
|
|
||||||
sandbox_info = project_data.get('sandbox', {})
|
|
||||||
if not sandbox_info.get('id'):
|
|
||||||
raise ValueError(f"No sandbox found for project {project_id}")
|
|
||||||
|
|
||||||
enabled_tools = {}
|
|
||||||
if agent_config and 'agentpress_tools' in agent_config:
|
|
||||||
raw_tools = agent_config['agentpress_tools']
|
|
||||||
logger.info(f"Raw agentpress_tools type: {type(raw_tools)}, value: {raw_tools}")
|
|
||||||
|
|
||||||
if isinstance(raw_tools, dict):
|
|
||||||
enabled_tools = raw_tools
|
|
||||||
logger.info(f"Using custom tool configuration from agent")
|
|
||||||
else:
|
|
||||||
logger.warning(f"agentpress_tools is not a dict (got {type(raw_tools)}), using empty dict")
|
|
||||||
enabled_tools = {}
|
|
||||||
|
|
||||||
|
|
||||||
# Check if this is Suna (default agent) and enable builder capabilities for self-configuration
|
class ToolManager:
|
||||||
if agent_config and agent_config.get('is_suna_default', False):
|
def __init__(self, thread_manager: ThreadManager, project_id: str, thread_id: str):
|
||||||
logger.info("Detected Suna default agent - enabling self-configuration capabilities")
|
self.thread_manager = thread_manager
|
||||||
|
self.project_id = project_id
|
||||||
|
self.thread_id = thread_id
|
||||||
|
|
||||||
from agent.tools.agent_builder_tools.agent_config_tool import AgentConfigTool
|
def register_all_tools(self):
|
||||||
from agent.tools.agent_builder_tools.mcp_search_tool import MCPSearchTool
|
self.thread_manager.add_tool(SandboxShellTool, project_id=self.project_id, thread_manager=self.thread_manager)
|
||||||
from agent.tools.agent_builder_tools.credential_profile_tool import CredentialProfileTool
|
self.thread_manager.add_tool(SandboxFilesTool, project_id=self.project_id, thread_manager=self.thread_manager)
|
||||||
from agent.tools.agent_builder_tools.workflow_tool import WorkflowTool
|
self.thread_manager.add_tool(SandboxBrowserTool, project_id=self.project_id, thread_id=self.thread_id, thread_manager=self.thread_manager)
|
||||||
from agent.tools.agent_builder_tools.trigger_tool import TriggerTool
|
self.thread_manager.add_tool(SandboxDeployTool, project_id=self.project_id, thread_manager=self.thread_manager)
|
||||||
from services.supabase import DBConnection
|
self.thread_manager.add_tool(SandboxExposeTool, project_id=self.project_id, thread_manager=self.thread_manager)
|
||||||
db = DBConnection()
|
self.thread_manager.add_tool(ExpandMessageTool, thread_id=self.thread_id, thread_manager=self.thread_manager)
|
||||||
|
self.thread_manager.add_tool(MessageTool)
|
||||||
# Use Suna's own agent ID for self-configuration
|
self.thread_manager.add_tool(SandboxWebSearchTool, project_id=self.project_id, thread_manager=self.thread_manager)
|
||||||
suna_agent_id = agent_config['agent_id']
|
self.thread_manager.add_tool(SandboxVisionTool, project_id=self.project_id, thread_id=self.thread_id, thread_manager=self.thread_manager)
|
||||||
|
self.thread_manager.add_tool(SandboxImageEditTool, project_id=self.project_id, thread_id=self.thread_id, thread_manager=self.thread_manager)
|
||||||
thread_manager.add_tool(AgentConfigTool, thread_manager=thread_manager, db_connection=db, agent_id=suna_agent_id)
|
|
||||||
thread_manager.add_tool(MCPSearchTool, thread_manager=thread_manager, db_connection=db, agent_id=suna_agent_id)
|
|
||||||
thread_manager.add_tool(CredentialProfileTool, thread_manager=thread_manager, db_connection=db, agent_id=suna_agent_id)
|
|
||||||
thread_manager.add_tool(WorkflowTool, thread_manager=thread_manager, db_connection=db, agent_id=suna_agent_id)
|
|
||||||
thread_manager.add_tool(TriggerTool, thread_manager=thread_manager, db_connection=db, agent_id=suna_agent_id)
|
|
||||||
|
|
||||||
logger.info(f"Enabled Suna self-configuration with agent ID: {suna_agent_id}")
|
|
||||||
|
|
||||||
# Original agent builder logic for custom agents (preserved)
|
|
||||||
if is_agent_builder:
|
|
||||||
from agent.tools.agent_builder_tools.agent_config_tool import AgentConfigTool
|
|
||||||
from agent.tools.agent_builder_tools.mcp_search_tool import MCPSearchTool
|
|
||||||
from agent.tools.agent_builder_tools.credential_profile_tool import CredentialProfileTool
|
|
||||||
from agent.tools.agent_builder_tools.workflow_tool import WorkflowTool
|
|
||||||
from agent.tools.agent_builder_tools.trigger_tool import TriggerTool
|
|
||||||
from services.supabase import DBConnection
|
|
||||||
db = DBConnection()
|
|
||||||
|
|
||||||
thread_manager.add_tool(AgentConfigTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
|
|
||||||
thread_manager.add_tool(MCPSearchTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
|
|
||||||
thread_manager.add_tool(CredentialProfileTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
|
|
||||||
thread_manager.add_tool(WorkflowTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
|
|
||||||
thread_manager.add_tool(TriggerTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
|
|
||||||
|
|
||||||
|
|
||||||
if enabled_tools is None:
|
|
||||||
logger.info("No agent specified - registering all tools for full Suna capabilities")
|
|
||||||
thread_manager.add_tool(SandboxShellTool, project_id=project_id, thread_manager=thread_manager)
|
|
||||||
thread_manager.add_tool(SandboxFilesTool, project_id=project_id, thread_manager=thread_manager)
|
|
||||||
thread_manager.add_tool(SandboxBrowserTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
|
||||||
thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager)
|
|
||||||
thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager)
|
|
||||||
thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager)
|
|
||||||
thread_manager.add_tool(MessageTool)
|
|
||||||
thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
|
|
||||||
thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
|
||||||
thread_manager.add_tool(SandboxImageEditTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
|
||||||
if config.RAPID_API_KEY:
|
if config.RAPID_API_KEY:
|
||||||
thread_manager.add_tool(DataProvidersTool)
|
self.thread_manager.add_tool(DataProvidersTool)
|
||||||
else:
|
|
||||||
logger.info("Custom agent specified - registering only enabled tools")
|
|
||||||
|
|
||||||
# Final safety check: ensure enabled_tools is always a dictionary
|
def register_agent_builder_tools(self, agent_id: str):
|
||||||
if not isinstance(enabled_tools, dict):
|
from agent.tools.agent_builder_tools.agent_config_tool import AgentConfigTool
|
||||||
logger.error(f"CRITICAL: enabled_tools is still not a dict at runtime! Type: {type(enabled_tools)}, Value: {enabled_tools}")
|
from agent.tools.agent_builder_tools.mcp_search_tool import MCPSearchTool
|
||||||
enabled_tools = {}
|
from agent.tools.agent_builder_tools.credential_profile_tool import CredentialProfileTool
|
||||||
|
from agent.tools.agent_builder_tools.workflow_tool import WorkflowTool
|
||||||
|
from agent.tools.agent_builder_tools.trigger_tool import TriggerTool
|
||||||
|
from services.supabase import DBConnection
|
||||||
|
|
||||||
thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager)
|
db = DBConnection()
|
||||||
thread_manager.add_tool(MessageTool)
|
self.thread_manager.add_tool(AgentConfigTool, thread_manager=self.thread_manager, db_connection=db, agent_id=agent_id)
|
||||||
|
self.thread_manager.add_tool(MCPSearchTool, thread_manager=self.thread_manager, db_connection=db, agent_id=agent_id)
|
||||||
|
self.thread_manager.add_tool(CredentialProfileTool, thread_manager=self.thread_manager, db_connection=db, agent_id=agent_id)
|
||||||
|
self.thread_manager.add_tool(WorkflowTool, thread_manager=self.thread_manager, db_connection=db, agent_id=agent_id)
|
||||||
|
self.thread_manager.add_tool(TriggerTool, thread_manager=self.thread_manager, db_connection=db, agent_id=agent_id)
|
||||||
|
|
||||||
|
def register_custom_tools(self, enabled_tools: Dict[str, Any]):
|
||||||
|
self.thread_manager.add_tool(ExpandMessageTool, thread_id=self.thread_id, thread_manager=self.thread_manager)
|
||||||
|
self.thread_manager.add_tool(MessageTool)
|
||||||
|
|
||||||
def safe_tool_check(tool_name: str) -> bool:
|
def safe_tool_check(tool_name: str) -> bool:
|
||||||
try:
|
try:
|
||||||
if not isinstance(enabled_tools, dict):
|
if not isinstance(enabled_tools, dict):
|
||||||
logger.error(f"enabled_tools is {type(enabled_tools)} at tool check for {tool_name}")
|
|
||||||
return False
|
return False
|
||||||
tool_config = enabled_tools.get(tool_name, {})
|
tool_config = enabled_tools.get(tool_name, {})
|
||||||
if not isinstance(tool_config, dict):
|
if not isinstance(tool_config, dict):
|
||||||
return bool(tool_config) if isinstance(tool_config, bool) else False
|
return bool(tool_config) if isinstance(tool_config, bool) else False
|
||||||
return tool_config.get('enabled', False)
|
return tool_config.get('enabled', False)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"Exception in tool check for {tool_name}: {e}")
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if safe_tool_check('sb_shell_tool'):
|
if safe_tool_check('sb_shell_tool'):
|
||||||
thread_manager.add_tool(SandboxShellTool, project_id=project_id, thread_manager=thread_manager)
|
self.thread_manager.add_tool(SandboxShellTool, project_id=self.project_id, thread_manager=self.thread_manager)
|
||||||
if safe_tool_check('sb_files_tool'):
|
if safe_tool_check('sb_files_tool'):
|
||||||
thread_manager.add_tool(SandboxFilesTool, project_id=project_id, thread_manager=thread_manager)
|
self.thread_manager.add_tool(SandboxFilesTool, project_id=self.project_id, thread_manager=self.thread_manager)
|
||||||
if safe_tool_check('sb_browser_tool'):
|
if safe_tool_check('sb_browser_tool'):
|
||||||
thread_manager.add_tool(SandboxBrowserTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
self.thread_manager.add_tool(SandboxBrowserTool, project_id=self.project_id, thread_id=self.thread_id, thread_manager=self.thread_manager)
|
||||||
if safe_tool_check('sb_deploy_tool'):
|
if safe_tool_check('sb_deploy_tool'):
|
||||||
thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager)
|
self.thread_manager.add_tool(SandboxDeployTool, project_id=self.project_id, thread_manager=self.thread_manager)
|
||||||
if safe_tool_check('sb_expose_tool'):
|
if safe_tool_check('sb_expose_tool'):
|
||||||
thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager)
|
self.thread_manager.add_tool(SandboxExposeTool, project_id=self.project_id, thread_manager=self.thread_manager)
|
||||||
if safe_tool_check('web_search_tool'):
|
if safe_tool_check('web_search_tool'):
|
||||||
thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
|
self.thread_manager.add_tool(SandboxWebSearchTool, project_id=self.project_id, thread_manager=self.thread_manager)
|
||||||
if safe_tool_check('sb_vision_tool'):
|
if safe_tool_check('sb_vision_tool'):
|
||||||
thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
self.thread_manager.add_tool(SandboxVisionTool, project_id=self.project_id, thread_id=self.thread_id, thread_manager=self.thread_manager)
|
||||||
if config.RAPID_API_KEY and safe_tool_check('data_providers_tool'):
|
if config.RAPID_API_KEY and safe_tool_check('data_providers_tool'):
|
||||||
thread_manager.add_tool(DataProvidersTool)
|
self.thread_manager.add_tool(DataProvidersTool)
|
||||||
|
|
||||||
# Register MCP tool wrapper if agent has configured MCPs or custom MCPs
|
|
||||||
mcp_wrapper_instance = None
|
class MCPManager:
|
||||||
if agent_config:
|
def __init__(self, thread_manager: ThreadManager, account_id: str):
|
||||||
# Merge configured_mcps and custom_mcps
|
self.thread_manager = thread_manager
|
||||||
|
self.account_id = account_id
|
||||||
|
|
||||||
|
async def register_mcp_tools(self, agent_config: dict) -> Optional[MCPToolWrapper]:
|
||||||
all_mcps = []
|
all_mcps = []
|
||||||
|
|
||||||
# Add standard configured MCPs
|
|
||||||
if agent_config.get('configured_mcps'):
|
if agent_config.get('configured_mcps'):
|
||||||
all_mcps.extend(agent_config['configured_mcps'])
|
all_mcps.extend(agent_config['configured_mcps'])
|
||||||
|
|
||||||
# Add custom MCPs
|
|
||||||
if agent_config.get('custom_mcps'):
|
if agent_config.get('custom_mcps'):
|
||||||
for custom_mcp in agent_config['custom_mcps']:
|
for custom_mcp in agent_config['custom_mcps']:
|
||||||
# Transform custom MCP to standard format
|
|
||||||
custom_type = custom_mcp.get('customType', custom_mcp.get('type', 'sse'))
|
custom_type = custom_mcp.get('customType', custom_mcp.get('type', 'sse'))
|
||||||
|
|
||||||
# For Pipedream MCPs, ensure we have the user ID and proper config
|
|
||||||
if custom_type == 'pipedream':
|
if custom_type == 'pipedream':
|
||||||
# Get user ID from thread
|
|
||||||
if 'config' not in custom_mcp:
|
if 'config' not in custom_mcp:
|
||||||
custom_mcp['config'] = {}
|
custom_mcp['config'] = {}
|
||||||
|
|
||||||
# Get external_user_id from profile if not present
|
|
||||||
if not custom_mcp['config'].get('external_user_id'):
|
if not custom_mcp['config'].get('external_user_id'):
|
||||||
profile_id = custom_mcp['config'].get('profile_id')
|
profile_id = custom_mcp['config'].get('profile_id')
|
||||||
if profile_id:
|
if profile_id:
|
||||||
try:
|
try:
|
||||||
from pipedream.profiles import get_profile_manager
|
from pipedream.facade import get_profile_manager
|
||||||
from services.supabase import DBConnection
|
from services.supabase import DBConnection
|
||||||
profile_db = DBConnection()
|
profile_db = DBConnection()
|
||||||
profile_manager = get_profile_manager(profile_db)
|
profile_manager = get_profile_manager(profile_db)
|
||||||
|
|
||||||
# Get the profile to retrieve external_user_id
|
profile = await profile_manager.get_profile(self.account_id, profile_id)
|
||||||
profile = await profile_manager.get_profile(account_id, profile_id)
|
|
||||||
if profile:
|
if profile:
|
||||||
custom_mcp['config']['external_user_id'] = profile.external_user_id
|
custom_mcp['config']['external_user_id'] = profile.external_user_id
|
||||||
logger.info(f"Retrieved external_user_id from profile {profile_id} for Pipedream MCP")
|
|
||||||
else:
|
|
||||||
logger.error(f"Could not find profile {profile_id} for Pipedream MCP")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving external_user_id from profile {profile_id}: {e}")
|
logger.error(f"Error retrieving external_user_id from profile {profile_id}: {e}")
|
||||||
|
|
||||||
|
@ -240,70 +167,50 @@ async def run_agent(
|
||||||
}
|
}
|
||||||
all_mcps.append(mcp_config)
|
all_mcps.append(mcp_config)
|
||||||
|
|
||||||
if all_mcps:
|
if not all_mcps:
|
||||||
logger.info(f"Registering MCP tool wrapper for {len(all_mcps)} MCP servers (including {len(agent_config.get('custom_mcps', []))} custom)")
|
return None
|
||||||
thread_manager.add_tool(MCPToolWrapper, mcp_configs=all_mcps)
|
|
||||||
|
|
||||||
for tool_name, tool_info in thread_manager.tool_registry.tools.items():
|
mcp_wrapper_instance = MCPToolWrapper(mcp_configs=all_mcps)
|
||||||
if isinstance(tool_info['instance'], MCPToolWrapper):
|
|
||||||
mcp_wrapper_instance = tool_info['instance']
|
|
||||||
break
|
|
||||||
|
|
||||||
if mcp_wrapper_instance:
|
|
||||||
try:
|
try:
|
||||||
await mcp_wrapper_instance.initialize_and_register_tools()
|
await mcp_wrapper_instance.initialize_and_register_tools()
|
||||||
logger.info("MCP tools initialized successfully")
|
|
||||||
updated_schemas = mcp_wrapper_instance.get_schemas()
|
updated_schemas = mcp_wrapper_instance.get_schemas()
|
||||||
logger.info(f"MCP wrapper has {len(updated_schemas)} schemas available")
|
|
||||||
for method_name, schema_list in updated_schemas.items():
|
for method_name, schema_list in updated_schemas.items():
|
||||||
if method_name != 'call_mcp_tool':
|
|
||||||
for schema in schema_list:
|
for schema in schema_list:
|
||||||
if schema.schema_type == SchemaType.OPENAPI:
|
self.thread_manager.tool_registry.tools[method_name] = {
|
||||||
thread_manager.tool_registry.tools[method_name] = {
|
|
||||||
"instance": mcp_wrapper_instance,
|
"instance": mcp_wrapper_instance,
|
||||||
"schema": schema
|
"schema": schema
|
||||||
}
|
}
|
||||||
logger.info(f"Registered dynamic MCP tool: {method_name}")
|
|
||||||
|
|
||||||
# Log all registered tools for debugging
|
|
||||||
all_tools = list(thread_manager.tool_registry.tools.keys())
|
|
||||||
logger.info(f"All registered tools after MCP initialization: {all_tools}")
|
|
||||||
mcp_tools = [tool for tool in all_tools if tool not in ['call_mcp_tool', 'sb_files_tool', 'message_tool', 'expand_msg_tool', 'web_search_tool', 'sb_shell_tool', 'sb_vision_tool', 'sb_browser_tool', 'computer_use_tool', 'data_providers_tool', 'sb_deploy_tool', 'sb_expose_tool', 'update_agent_tool']]
|
|
||||||
logger.info(f"MCP tools registered: {mcp_tools}")
|
|
||||||
|
|
||||||
|
return mcp_wrapper_instance
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize MCP tools: {e}")
|
logger.error(f"Failed to initialize MCP tools: {e}")
|
||||||
# Continue without MCP tools if initialization fails
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class PromptManager:
|
||||||
|
@staticmethod
|
||||||
|
async def build_system_prompt(model_name: str, agent_config: Optional[dict],
|
||||||
|
is_agent_builder: bool, thread_id: str,
|
||||||
|
mcp_wrapper_instance: Optional[MCPToolWrapper]) -> dict:
|
||||||
|
|
||||||
# Prepare system prompt
|
|
||||||
# First, get the default system prompt
|
|
||||||
if "gemini-2.5-flash" in model_name.lower() and "gemini-2.5-pro" not in model_name.lower():
|
if "gemini-2.5-flash" in model_name.lower() and "gemini-2.5-pro" not in model_name.lower():
|
||||||
default_system_content = get_gemini_system_prompt()
|
default_system_content = get_gemini_system_prompt()
|
||||||
else:
|
else:
|
||||||
# Use the original prompt - the LLM can only use tools that are registered
|
|
||||||
default_system_content = get_system_prompt()
|
default_system_content = get_system_prompt()
|
||||||
|
|
||||||
# Add sample response for non-anthropic models
|
|
||||||
if "anthropic" not in model_name.lower():
|
if "anthropic" not in model_name.lower():
|
||||||
sample_response_path = os.path.join(os.path.dirname(__file__), 'sample_responses/1.txt')
|
sample_response_path = os.path.join(os.path.dirname(__file__), 'sample_responses/1.txt')
|
||||||
with open(sample_response_path, 'r') as file:
|
with open(sample_response_path, 'r') as file:
|
||||||
sample_response = file.read()
|
sample_response = file.read()
|
||||||
default_system_content = default_system_content + "\n\n <sample_assistant_response>" + sample_response + "</sample_assistant_response>"
|
default_system_content = default_system_content + "\n\n <sample_assistant_response>" + sample_response + "</sample_assistant_response>"
|
||||||
|
|
||||||
# Handle custom agent system prompt
|
|
||||||
if agent_config and agent_config.get('system_prompt'):
|
if agent_config and agent_config.get('system_prompt'):
|
||||||
custom_system_prompt = agent_config['system_prompt'].strip()
|
system_content = agent_config['system_prompt'].strip()
|
||||||
# Completely replace the default system prompt with the custom one
|
|
||||||
# This prevents confusion and tool hallucination
|
|
||||||
system_content = custom_system_prompt
|
|
||||||
logger.info(f"Using ONLY custom agent system prompt for: {agent_config.get('name', 'Unknown')}")
|
|
||||||
elif is_agent_builder:
|
elif is_agent_builder:
|
||||||
system_content = get_agent_builder_prompt()
|
system_content = get_agent_builder_prompt()
|
||||||
logger.info("Using agent builder system prompt")
|
|
||||||
else:
|
else:
|
||||||
# Use just the default system prompt
|
|
||||||
system_content = default_system_content
|
system_content = default_system_content
|
||||||
logger.info("Using default system prompt only")
|
|
||||||
|
|
||||||
if await is_enabled("knowledge_base"):
|
if await is_enabled("knowledge_base"):
|
||||||
try:
|
try:
|
||||||
|
@ -320,15 +227,11 @@ async def run_agent(
|
||||||
}).execute()
|
}).execute()
|
||||||
|
|
||||||
if kb_result.data and kb_result.data.strip():
|
if kb_result.data and kb_result.data.strip():
|
||||||
logger.info(f"Adding combined knowledge base context to system prompt for thread {thread_id}, agent {current_agent_id}")
|
|
||||||
system_content += "\n\n" + kb_result.data
|
system_content += "\n\n" + kb_result.data
|
||||||
else:
|
|
||||||
logger.debug(f"No knowledge base context found for thread {thread_id}, agent {current_agent_id}")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving knowledge base context for thread {thread_id}: {e}")
|
logger.error(f"Error retrieving knowledge base context for thread {thread_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
if agent_config and (agent_config.get('configured_mcps') or agent_config.get('custom_mcps')) and mcp_wrapper_instance and mcp_wrapper_instance._initialized:
|
if agent_config and (agent_config.get('configured_mcps') or agent_config.get('custom_mcps')) and mcp_wrapper_instance and mcp_wrapper_instance._initialized:
|
||||||
mcp_info = "\n\n--- MCP Tools Available ---\n"
|
mcp_info = "\n\n--- MCP Tools Available ---\n"
|
||||||
mcp_info += "You have access to external MCP (Model Context Protocol) server tools.\n"
|
mcp_info += "You have access to external MCP (Model Context Protocol) server tools.\n"
|
||||||
|
@ -340,31 +243,16 @@ async def run_agent(
|
||||||
mcp_info += '</invoke>\n'
|
mcp_info += '</invoke>\n'
|
||||||
mcp_info += '</function_calls>\n\n'
|
mcp_info += '</function_calls>\n\n'
|
||||||
|
|
||||||
# List available MCP tools
|
|
||||||
mcp_info += "Available MCP tools:\n"
|
mcp_info += "Available MCP tools:\n"
|
||||||
try:
|
try:
|
||||||
# Get the actual registered schemas from the wrapper
|
|
||||||
registered_schemas = mcp_wrapper_instance.get_schemas()
|
registered_schemas = mcp_wrapper_instance.get_schemas()
|
||||||
for method_name, schema_list in registered_schemas.items():
|
for method_name, schema_list in registered_schemas.items():
|
||||||
if method_name == 'call_mcp_tool':
|
|
||||||
continue # Skip the fallback method
|
|
||||||
|
|
||||||
# Get the schema info
|
|
||||||
for schema in schema_list:
|
for schema in schema_list:
|
||||||
if schema.schema_type == SchemaType.OPENAPI:
|
if schema.schema_type == SchemaType.OPENAPI:
|
||||||
func_info = schema.schema.get('function', {})
|
func_info = schema.schema.get('function', {})
|
||||||
description = func_info.get('description', 'No description available')
|
description = func_info.get('description', 'No description available')
|
||||||
# Extract server name from description if available
|
|
||||||
server_match = description.find('(MCP Server: ')
|
|
||||||
if server_match != -1:
|
|
||||||
server_end = description.find(')', server_match)
|
|
||||||
server_info = description[server_match:server_end+1]
|
|
||||||
else:
|
|
||||||
server_info = ''
|
|
||||||
|
|
||||||
mcp_info += f"- **{method_name}**: {description}\n"
|
mcp_info += f"- **{method_name}**: {description}\n"
|
||||||
|
|
||||||
# Show parameter info
|
|
||||||
params = func_info.get('parameters', {})
|
params = func_info.get('parameters', {})
|
||||||
props = params.get('properties', {})
|
props = params.get('properties', {})
|
||||||
if props:
|
if props:
|
||||||
|
@ -374,7 +262,6 @@ async def run_agent(
|
||||||
logger.error(f"Error listing MCP tools: {e}")
|
logger.error(f"Error listing MCP tools: {e}")
|
||||||
mcp_info += "- Error loading MCP tool list\n"
|
mcp_info += "- Error loading MCP tool list\n"
|
||||||
|
|
||||||
# Add critical instructions for using search results
|
|
||||||
mcp_info += "\n🚨 CRITICAL MCP TOOL RESULT INSTRUCTIONS 🚨\n"
|
mcp_info += "\n🚨 CRITICAL MCP TOOL RESULT INSTRUCTIONS 🚨\n"
|
||||||
mcp_info += "When you use ANY MCP (Model Context Protocol) tools:\n"
|
mcp_info += "When you use ANY MCP (Model Context Protocol) tools:\n"
|
||||||
mcp_info += "1. ALWAYS read and use the EXACT results returned by the MCP tool\n"
|
mcp_info += "1. ALWAYS read and use the EXACT results returned by the MCP tool\n"
|
||||||
|
@ -390,53 +277,20 @@ async def run_agent(
|
||||||
|
|
||||||
system_content += mcp_info
|
system_content += mcp_info
|
||||||
|
|
||||||
system_message = { "role": "system", "content": system_content }
|
return {"role": "system", "content": system_content}
|
||||||
|
|
||||||
iteration_count = 0
|
|
||||||
continue_execution = True
|
|
||||||
|
|
||||||
latest_user_message = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'user').order('created_at', desc=True).limit(1).execute()
|
class MessageManager:
|
||||||
if latest_user_message.data and len(latest_user_message.data) > 0:
|
def __init__(self, client, thread_id: str, model_name: str, trace: Optional[StatefulTraceClient]):
|
||||||
data = latest_user_message.data[0]['content']
|
self.client = client
|
||||||
if isinstance(data, str):
|
self.thread_id = thread_id
|
||||||
data = json.loads(data)
|
self.model_name = model_name
|
||||||
if trace:
|
self.trace = trace
|
||||||
trace.update(input=data['content'])
|
|
||||||
|
|
||||||
while continue_execution and iteration_count < max_iterations:
|
async def build_temporary_message(self) -> Optional[dict]:
|
||||||
iteration_count += 1
|
temp_message_content_list = []
|
||||||
logger.info(f"🔄 Running iteration {iteration_count} of {max_iterations}...")
|
|
||||||
|
|
||||||
# Billing check on each iteration - still needed within the iterations
|
latest_browser_state_msg = await self.client.table('messages').select('*').eq('thread_id', self.thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute()
|
||||||
can_run, message, subscription = await check_billing_status(client, account_id)
|
|
||||||
if not can_run:
|
|
||||||
error_msg = f"Billing limit reached: {message}"
|
|
||||||
if trace:
|
|
||||||
trace.event(name="billing_limit_reached", level="ERROR", status_message=(f"{error_msg}"))
|
|
||||||
# Yield a special message to indicate billing limit reached
|
|
||||||
yield {
|
|
||||||
"type": "status",
|
|
||||||
"status": "stopped",
|
|
||||||
"message": error_msg
|
|
||||||
}
|
|
||||||
break
|
|
||||||
# Check if last message is from assistant using direct Supabase query
|
|
||||||
latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).in_('type', ['assistant', 'tool', 'user']).order('created_at', desc=True).limit(1).execute()
|
|
||||||
if latest_message.data and len(latest_message.data) > 0:
|
|
||||||
message_type = latest_message.data[0].get('type')
|
|
||||||
if message_type == 'assistant':
|
|
||||||
logger.info(f"Last message was from assistant, stopping execution")
|
|
||||||
if trace:
|
|
||||||
trace.event(name="last_message_from_assistant", level="DEFAULT", status_message=(f"Last message was from assistant, stopping execution"))
|
|
||||||
continue_execution = False
|
|
||||||
break
|
|
||||||
|
|
||||||
# ---- Temporary Message Handling (Browser State & Image Context) ----
|
|
||||||
temporary_message = None
|
|
||||||
temp_message_content_list = [] # List to hold text/image blocks
|
|
||||||
|
|
||||||
# Get the latest browser_state message
|
|
||||||
latest_browser_state_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute()
|
|
||||||
if latest_browser_state_msg.data and len(latest_browser_state_msg.data) > 0:
|
if latest_browser_state_msg.data and len(latest_browser_state_msg.data) > 0:
|
||||||
try:
|
try:
|
||||||
browser_content = latest_browser_state_msg.data[0]["content"]
|
browser_content = latest_browser_state_msg.data[0]["content"]
|
||||||
|
@ -445,7 +299,6 @@ async def run_agent(
|
||||||
screenshot_base64 = browser_content.get("screenshot_base64")
|
screenshot_base64 = browser_content.get("screenshot_base64")
|
||||||
screenshot_url = browser_content.get("image_url")
|
screenshot_url = browser_content.get("image_url")
|
||||||
|
|
||||||
# Create a copy of the browser state without screenshot data
|
|
||||||
browser_state_text = browser_content.copy()
|
browser_state_text = browser_content.copy()
|
||||||
browser_state_text.pop('screenshot_base64', None)
|
browser_state_text.pop('screenshot_base64', None)
|
||||||
browser_state_text.pop('image_url', None)
|
browser_state_text.pop('image_url', None)
|
||||||
|
@ -456,9 +309,7 @@ async def run_agent(
|
||||||
"text": f"The following is the current state of the browser:\n{json.dumps(browser_state_text, indent=2)}"
|
"text": f"The following is the current state of the browser:\n{json.dumps(browser_state_text, indent=2)}"
|
||||||
})
|
})
|
||||||
|
|
||||||
# Only add screenshot if model is not Gemini, Anthropic, or OpenAI
|
if 'gemini' in self.model_name.lower() or 'anthropic' in self.model_name.lower() or 'openai' in self.model_name.lower():
|
||||||
if 'gemini' in model_name.lower() or 'anthropic' in model_name.lower() or 'openai' in model_name.lower():
|
|
||||||
# Prioritize screenshot_url if available
|
|
||||||
if screenshot_url:
|
if screenshot_url:
|
||||||
temp_message_content_list.append({
|
temp_message_content_list.append({
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
|
@ -467,34 +318,18 @@ async def run_agent(
|
||||||
"format": "image/jpeg"
|
"format": "image/jpeg"
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if trace:
|
|
||||||
trace.event(name="screenshot_url_added_to_temporary_message", level="DEFAULT", status_message=(f"Screenshot URL added to temporary message."))
|
|
||||||
elif screenshot_base64:
|
elif screenshot_base64:
|
||||||
# Fallback to base64 if URL not available
|
|
||||||
temp_message_content_list.append({
|
temp_message_content_list.append({
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": f"data:image/jpeg;base64,{screenshot_base64}",
|
"url": f"data:image/jpeg;base64,{screenshot_base64}",
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if trace:
|
|
||||||
trace.event(name="screenshot_base64_added_to_temporary_message", level="WARNING", status_message=(f"Screenshot base64 added to temporary message. Prefer screenshot_url if available."))
|
|
||||||
else:
|
|
||||||
logger.warning("Browser state found but no screenshot data.")
|
|
||||||
if trace:
|
|
||||||
trace.event(name="browser_state_found_but_no_screenshot_data", level="WARNING", status_message=(f"Browser state found but no screenshot data."))
|
|
||||||
else:
|
|
||||||
logger.warning("Model is Gemini, Anthropic, or OpenAI, so not adding screenshot to temporary message.")
|
|
||||||
if trace:
|
|
||||||
trace.event(name="model_is_gemini_anthropic_or_openai", level="WARNING", status_message=(f"Model is Gemini, Anthropic, or OpenAI, so not adding screenshot to temporary message."))
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error parsing browser state: {e}")
|
logger.error(f"Error parsing browser state: {e}")
|
||||||
if trace:
|
|
||||||
trace.event(name="error_parsing_browser_state", level="ERROR", status_message=(f"{e}"))
|
|
||||||
|
|
||||||
# Get the latest image_context message (NEW)
|
latest_image_context_msg = await self.client.table('messages').select('*').eq('thread_id', self.thread_id).eq('type', 'image_context').order('created_at', desc=True).limit(1).execute()
|
||||||
latest_image_context_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'image_context').order('created_at', desc=True).limit(1).execute()
|
|
||||||
if latest_image_context_msg.data and len(latest_image_context_msg.data) > 0:
|
if latest_image_context_msg.data and len(latest_image_context_msg.data) > 0:
|
||||||
try:
|
try:
|
||||||
image_context_content = latest_image_context_msg.data[0]["content"] if isinstance(latest_image_context_msg.data[0]["content"], dict) else json.loads(latest_image_context_msg.data[0]["content"])
|
image_context_content = latest_image_context_msg.data[0]["content"] if isinstance(latest_image_context_msg.data[0]["content"], dict) else json.loads(latest_image_context_msg.data[0]["content"])
|
||||||
|
@ -513,43 +348,146 @@ async def run_agent(
|
||||||
"url": f"data:{mime_type};base64,{base64_image}",
|
"url": f"data:{mime_type};base64,{base64_image}",
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
else:
|
|
||||||
logger.warning(f"Image context found for '{file_path}' but missing base64 or mime_type.")
|
|
||||||
|
|
||||||
await client.table('messages').delete().eq('message_id', latest_image_context_msg.data[0]["message_id"]).execute()
|
await self.client.table('messages').delete().eq('message_id', latest_image_context_msg.data[0]["message_id"]).execute()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error parsing image context: {e}")
|
logger.error(f"Error parsing image context: {e}")
|
||||||
if trace:
|
|
||||||
trace.event(name="error_parsing_image_context", level="ERROR", status_message=(f"{e}"))
|
|
||||||
|
|
||||||
# If we have any content, construct the temporary_message
|
|
||||||
if temp_message_content_list:
|
if temp_message_content_list:
|
||||||
temporary_message = {"role": "user", "content": temp_message_content_list}
|
return {"role": "user", "content": temp_message_content_list}
|
||||||
# logger.debug(f"Constructed temporary message with {len(temp_message_content_list)} content blocks.")
|
return None
|
||||||
# ---- End Temporary Message Handling ----
|
|
||||||
|
|
||||||
# Set max_tokens based on model
|
|
||||||
max_tokens = None
|
|
||||||
if "sonnet" in model_name.lower():
|
|
||||||
# Claude 3.5 Sonnet has a limit of 8192 tokens
|
|
||||||
max_tokens = 8192
|
|
||||||
elif "gpt-4" in model_name.lower():
|
|
||||||
max_tokens = 4096
|
|
||||||
elif "gemini-2.5-pro" in model_name.lower():
|
|
||||||
# Gemini 2.5 Pro has 64k max output tokens
|
|
||||||
max_tokens = 64000
|
|
||||||
elif "kimi-k2" in model_name.lower():
|
|
||||||
# Kimi-K2 has 120K context, set reasonable max output tokens
|
|
||||||
max_tokens = 8192
|
|
||||||
|
|
||||||
generation = trace.generation(name="thread_manager.run_thread") if trace else None
|
class AgentRunner:
|
||||||
|
def __init__(self, config: AgentConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
if not self.config.trace:
|
||||||
|
self.config.trace = langfuse.trace(name="run_agent", session_id=self.config.thread_id, metadata={"project_id": self.config.project_id})
|
||||||
|
|
||||||
|
self.thread_manager = ThreadManager(
|
||||||
|
trace=self.config.trace,
|
||||||
|
is_agent_builder=self.config.is_agent_builder or False,
|
||||||
|
target_agent_id=self.config.target_agent_id,
|
||||||
|
agent_config=self.config.agent_config
|
||||||
|
)
|
||||||
|
|
||||||
|
self.client = await self.thread_manager.db.client
|
||||||
|
self.account_id = await get_account_id_from_thread(self.client, self.config.thread_id)
|
||||||
|
if not self.account_id:
|
||||||
|
raise ValueError("Could not determine account ID for thread")
|
||||||
|
|
||||||
|
project = await self.client.table('projects').select('*').eq('project_id', self.config.project_id).execute()
|
||||||
|
if not project.data or len(project.data) == 0:
|
||||||
|
raise ValueError(f"Project {self.config.project_id} not found")
|
||||||
|
|
||||||
|
project_data = project.data[0]
|
||||||
|
sandbox_info = project_data.get('sandbox', {})
|
||||||
|
if not sandbox_info.get('id'):
|
||||||
|
raise ValueError(f"No sandbox found for project {self.config.project_id}")
|
||||||
|
|
||||||
|
async def setup_tools(self):
|
||||||
|
tool_manager = ToolManager(self.thread_manager, self.config.project_id, self.config.thread_id)
|
||||||
|
|
||||||
|
if self.config.agent_config and self.config.agent_config.get('is_suna_default', False):
|
||||||
|
suna_agent_id = self.config.agent_config['agent_id']
|
||||||
|
tool_manager.register_agent_builder_tools(suna_agent_id)
|
||||||
|
|
||||||
|
if self.config.is_agent_builder:
|
||||||
|
tool_manager.register_agent_builder_tools(self.config.target_agent_id)
|
||||||
|
|
||||||
|
enabled_tools = None
|
||||||
|
if self.config.agent_config and 'agentpress_tools' in self.config.agent_config:
|
||||||
|
raw_tools = self.config.agent_config['agentpress_tools']
|
||||||
|
|
||||||
|
if isinstance(raw_tools, dict):
|
||||||
|
if self.config.agent_config.get('is_suna_default', False) and not raw_tools:
|
||||||
|
enabled_tools = None
|
||||||
|
else:
|
||||||
|
enabled_tools = raw_tools
|
||||||
|
else:
|
||||||
|
enabled_tools = None
|
||||||
|
|
||||||
|
if enabled_tools is None:
|
||||||
|
tool_manager.register_all_tools()
|
||||||
|
else:
|
||||||
|
if not isinstance(enabled_tools, dict):
|
||||||
|
enabled_tools = {}
|
||||||
|
tool_manager.register_custom_tools(enabled_tools)
|
||||||
|
|
||||||
|
async def setup_mcp_tools(self) -> Optional[MCPToolWrapper]:
|
||||||
|
if not self.config.agent_config:
|
||||||
|
return None
|
||||||
|
|
||||||
|
mcp_manager = MCPManager(self.thread_manager, self.account_id)
|
||||||
|
return await mcp_manager.register_mcp_tools(self.config.agent_config)
|
||||||
|
|
||||||
|
def get_max_tokens(self) -> Optional[int]:
|
||||||
|
if "sonnet" in self.config.model_name.lower():
|
||||||
|
return 8192
|
||||||
|
elif "gpt-4" in self.config.model_name.lower():
|
||||||
|
return 4096
|
||||||
|
elif "gemini-2.5-pro" in self.config.model_name.lower():
|
||||||
|
return 64000
|
||||||
|
elif "kimi-k2" in self.config.model_name.lower():
|
||||||
|
return 8192
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def run(self) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
|
await self.setup()
|
||||||
|
await self.setup_tools()
|
||||||
|
mcp_wrapper_instance = await self.setup_mcp_tools()
|
||||||
|
|
||||||
|
system_message = await PromptManager.build_system_prompt(
|
||||||
|
self.config.model_name, self.config.agent_config,
|
||||||
|
self.config.is_agent_builder, self.config.thread_id,
|
||||||
|
mcp_wrapper_instance
|
||||||
|
)
|
||||||
|
|
||||||
|
iteration_count = 0
|
||||||
|
continue_execution = True
|
||||||
|
|
||||||
|
latest_user_message = await self.client.table('messages').select('*').eq('thread_id', self.config.thread_id).eq('type', 'user').order('created_at', desc=True).limit(1).execute()
|
||||||
|
if latest_user_message.data and len(latest_user_message.data) > 0:
|
||||||
|
data = latest_user_message.data[0]['content']
|
||||||
|
if isinstance(data, str):
|
||||||
|
data = json.loads(data)
|
||||||
|
if self.config.trace:
|
||||||
|
self.config.trace.update(input=data['content'])
|
||||||
|
|
||||||
|
message_manager = MessageManager(self.client, self.config.thread_id, self.config.model_name, self.config.trace)
|
||||||
|
|
||||||
|
while continue_execution and iteration_count < self.config.max_iterations:
|
||||||
|
iteration_count += 1
|
||||||
|
|
||||||
|
can_run, message, subscription = await check_billing_status(self.client, self.account_id)
|
||||||
|
if not can_run:
|
||||||
|
error_msg = f"Billing limit reached: {message}"
|
||||||
|
yield {
|
||||||
|
"type": "status",
|
||||||
|
"status": "stopped",
|
||||||
|
"message": error_msg
|
||||||
|
}
|
||||||
|
break
|
||||||
|
|
||||||
|
latest_message = await self.client.table('messages').select('*').eq('thread_id', self.config.thread_id).in_('type', ['assistant', 'tool', 'user']).order('created_at', desc=True).limit(1).execute()
|
||||||
|
if latest_message.data and len(latest_message.data) > 0:
|
||||||
|
message_type = latest_message.data[0].get('type')
|
||||||
|
if message_type == 'assistant':
|
||||||
|
continue_execution = False
|
||||||
|
break
|
||||||
|
|
||||||
|
temporary_message = await message_manager.build_temporary_message()
|
||||||
|
max_tokens = self.get_max_tokens()
|
||||||
|
|
||||||
|
generation = self.config.trace.generation(name="thread_manager.run_thread") if self.config.trace else None
|
||||||
try:
|
try:
|
||||||
# Make the LLM call and process the response
|
response = await self.thread_manager.run_thread(
|
||||||
response = await thread_manager.run_thread(
|
thread_id=self.config.thread_id,
|
||||||
thread_id=thread_id,
|
|
||||||
system_prompt=system_message,
|
system_prompt=system_message,
|
||||||
stream=stream,
|
stream=self.config.stream,
|
||||||
llm_model=model_name,
|
llm_model=self.config.model_name,
|
||||||
llm_temperature=0,
|
llm_temperature=0,
|
||||||
llm_max_tokens=max_tokens,
|
llm_max_tokens=max_tokens,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
|
@ -563,56 +501,40 @@ async def run_agent(
|
||||||
tool_execution_strategy="parallel",
|
tool_execution_strategy="parallel",
|
||||||
xml_adding_strategy="user_message"
|
xml_adding_strategy="user_message"
|
||||||
),
|
),
|
||||||
native_max_auto_continues=native_max_auto_continues,
|
native_max_auto_continues=self.config.native_max_auto_continues,
|
||||||
include_xml_examples=True,
|
include_xml_examples=True,
|
||||||
enable_thinking=enable_thinking,
|
enable_thinking=self.config.enable_thinking,
|
||||||
reasoning_effort=reasoning_effort,
|
reasoning_effort=self.config.reasoning_effort,
|
||||||
enable_context_manager=enable_context_manager,
|
enable_context_manager=self.config.enable_context_manager,
|
||||||
generation=generation
|
generation=generation
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(response, dict) and "status" in response and response["status"] == "error":
|
if isinstance(response, dict) and "status" in response and response["status"] == "error":
|
||||||
logger.error(f"Error response from run_thread: {response.get('message', 'Unknown error')}")
|
|
||||||
if trace:
|
|
||||||
trace.event(name="error_response_from_run_thread", level="ERROR", status_message=(f"{response.get('message', 'Unknown error')}"))
|
|
||||||
yield response
|
yield response
|
||||||
break
|
break
|
||||||
|
|
||||||
# Track if we see ask, complete, or web-browser-takeover tool calls
|
|
||||||
last_tool_call = None
|
last_tool_call = None
|
||||||
agent_should_terminate = False
|
agent_should_terminate = False
|
||||||
|
|
||||||
# Process the response
|
|
||||||
error_detected = False
|
error_detected = False
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Check if response is iterable (async generator) or a dict (error case)
|
|
||||||
if hasattr(response, '__aiter__') and not isinstance(response, dict):
|
if hasattr(response, '__aiter__') and not isinstance(response, dict):
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
# If we receive an error chunk, we should stop after this iteration
|
|
||||||
if isinstance(chunk, dict) and chunk.get('type') == 'status' and chunk.get('status') == 'error':
|
if isinstance(chunk, dict) and chunk.get('type') == 'status' and chunk.get('status') == 'error':
|
||||||
logger.error(f"Error chunk detected: {chunk.get('message', 'Unknown error')}")
|
|
||||||
if trace:
|
|
||||||
trace.event(name="error_chunk_detected", level="ERROR", status_message=(f"{chunk.get('message', 'Unknown error')}"))
|
|
||||||
error_detected = True
|
error_detected = True
|
||||||
yield chunk # Forward the error chunk
|
yield chunk
|
||||||
continue # Continue processing other chunks but don't break yet
|
continue
|
||||||
|
|
||||||
# Check for termination signal in status messages
|
|
||||||
if chunk.get('type') == 'status':
|
if chunk.get('type') == 'status':
|
||||||
try:
|
try:
|
||||||
# Parse the metadata to check for termination signal
|
|
||||||
metadata = chunk.get('metadata', {})
|
metadata = chunk.get('metadata', {})
|
||||||
if isinstance(metadata, str):
|
if isinstance(metadata, str):
|
||||||
metadata = json.loads(metadata)
|
metadata = json.loads(metadata)
|
||||||
|
|
||||||
if metadata.get('agent_should_terminate'):
|
if metadata.get('agent_should_terminate'):
|
||||||
agent_should_terminate = True
|
agent_should_terminate = True
|
||||||
logger.info("Agent termination signal detected in status message")
|
|
||||||
if trace:
|
|
||||||
trace.event(name="agent_termination_signal_detected", level="DEFAULT", status_message="Agent termination signal detected in status message")
|
|
||||||
|
|
||||||
# Extract the tool name from the status content if available
|
|
||||||
content = chunk.get('content', {})
|
content = chunk.get('content', {})
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
content = json.loads(content)
|
content = json.loads(content)
|
||||||
|
@ -622,20 +544,17 @@ async def run_agent(
|
||||||
elif content.get('xml_tag_name'):
|
elif content.get('xml_tag_name'):
|
||||||
last_tool_call = content['xml_tag_name']
|
last_tool_call = content['xml_tag_name']
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.debug(f"Error parsing status message for termination check: {e}")
|
pass
|
||||||
|
|
||||||
# Check for XML versions like <ask>, <complete>, or <web-browser-takeover> in assistant content chunks
|
|
||||||
if chunk.get('type') == 'assistant' and 'content' in chunk:
|
if chunk.get('type') == 'assistant' and 'content' in chunk:
|
||||||
try:
|
try:
|
||||||
# The content field might be a JSON string or object
|
|
||||||
content = chunk.get('content', '{}')
|
content = chunk.get('content', '{}')
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
assistant_content_json = json.loads(content)
|
assistant_content_json = json.loads(content)
|
||||||
else:
|
else:
|
||||||
assistant_content_json = content
|
assistant_content_json = content
|
||||||
|
|
||||||
# The actual text content is nested within
|
|
||||||
assistant_text = assistant_content_json.get('content', '')
|
assistant_text = assistant_content_json.get('content', '')
|
||||||
full_response += assistant_text
|
full_response += assistant_text
|
||||||
if isinstance(assistant_text, str):
|
if isinstance(assistant_text, str):
|
||||||
|
@ -648,49 +567,28 @@ async def run_agent(
|
||||||
xml_tool = 'web-browser-takeover'
|
xml_tool = 'web-browser-takeover'
|
||||||
|
|
||||||
last_tool_call = xml_tool
|
last_tool_call = xml_tool
|
||||||
logger.info(f"Agent used XML tool: {xml_tool}")
|
|
||||||
if trace:
|
|
||||||
trace.event(name="agent_used_xml_tool", level="DEFAULT", status_message=(f"Agent used XML tool: {xml_tool}"))
|
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# Handle cases where content might not be valid JSON
|
pass
|
||||||
logger.warning(f"Warning: Could not parse assistant content JSON: {chunk.get('content')}")
|
except Exception:
|
||||||
if trace:
|
pass
|
||||||
trace.event(name="warning_could_not_parse_assistant_content_json", level="WARNING", status_message=(f"Warning: Could not parse assistant content JSON: {chunk.get('content')}"))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing assistant chunk: {e}")
|
|
||||||
if trace:
|
|
||||||
trace.event(name="error_processing_assistant_chunk", level="ERROR", status_message=(f"Error processing assistant chunk: {e}"))
|
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
else:
|
else:
|
||||||
# Response is not iterable, likely an error dict
|
|
||||||
logger.error(f"Response is not iterable: {response}")
|
|
||||||
error_detected = True
|
error_detected = True
|
||||||
|
|
||||||
# Check if we should stop based on the last tool call or error
|
|
||||||
if error_detected:
|
if error_detected:
|
||||||
logger.info(f"Stopping due to error detected in response")
|
|
||||||
if trace:
|
|
||||||
trace.event(name="stopping_due_to_error_detected_in_response", level="DEFAULT", status_message=(f"Stopping due to error detected in response"))
|
|
||||||
if generation:
|
if generation:
|
||||||
generation.end(output=full_response, status_message="error_detected", level="ERROR")
|
generation.end(output=full_response, status_message="error_detected", level="ERROR")
|
||||||
break
|
break
|
||||||
|
|
||||||
if agent_should_terminate or last_tool_call in ['ask', 'complete', 'web-browser-takeover']:
|
if agent_should_terminate or last_tool_call in ['ask', 'complete', 'web-browser-takeover']:
|
||||||
logger.info(f"Agent decided to stop with tool: {last_tool_call}")
|
|
||||||
if trace:
|
|
||||||
trace.event(name="agent_decided_to_stop_with_tool", level="DEFAULT", status_message=(f"Agent decided to stop with tool: {last_tool_call}"))
|
|
||||||
if generation:
|
if generation:
|
||||||
generation.end(output=full_response, status_message="agent_stopped")
|
generation.end(output=full_response, status_message="agent_stopped")
|
||||||
continue_execution = False
|
continue_execution = False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Just log the error and re-raise to stop all iterations
|
|
||||||
error_msg = f"Error during response streaming: {str(e)}"
|
error_msg = f"Error during response streaming: {str(e)}"
|
||||||
logger.error(f"Error: {error_msg}")
|
|
||||||
if trace:
|
|
||||||
trace.event(name="error_during_response_streaming", level="ERROR", status_message=(f"Error during response streaming: {str(e)}"))
|
|
||||||
if generation:
|
if generation:
|
||||||
generation.end(output=full_response, status_message=error_msg, level="ERROR")
|
generation.end(output=full_response, status_message=error_msg, level="ERROR")
|
||||||
yield {
|
yield {
|
||||||
|
@ -698,23 +596,55 @@ async def run_agent(
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": error_msg
|
"message": error_msg
|
||||||
}
|
}
|
||||||
# Stop execution immediately on any error
|
|
||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Just log the error and re-raise to stop all iterations
|
|
||||||
error_msg = f"Error running thread: {str(e)}"
|
error_msg = f"Error running thread: {str(e)}"
|
||||||
logger.error(f"Error: {error_msg}")
|
|
||||||
if trace:
|
|
||||||
trace.event(name="error_running_thread", level="ERROR", status_message=(f"Error running thread: {str(e)}"))
|
|
||||||
yield {
|
yield {
|
||||||
"type": "status",
|
"type": "status",
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": error_msg
|
"message": error_msg
|
||||||
}
|
}
|
||||||
# Stop execution immediately on any error
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if generation:
|
if generation:
|
||||||
generation.end(output=full_response)
|
generation.end(output=full_response)
|
||||||
|
|
||||||
asyncio.create_task(asyncio.to_thread(lambda: langfuse.flush()))
|
asyncio.create_task(asyncio.to_thread(lambda: langfuse.flush()))
|
||||||
|
|
||||||
|
|
||||||
|
async def run_agent(
|
||||||
|
thread_id: str,
|
||||||
|
project_id: str,
|
||||||
|
stream: bool,
|
||||||
|
thread_manager: Optional[ThreadManager] = None,
|
||||||
|
native_max_auto_continues: int = 25,
|
||||||
|
max_iterations: int = 100,
|
||||||
|
model_name: str = "anthropic/claude-sonnet-4-20250514",
|
||||||
|
enable_thinking: Optional[bool] = False,
|
||||||
|
reasoning_effort: Optional[str] = 'low',
|
||||||
|
enable_context_manager: bool = True,
|
||||||
|
agent_config: Optional[dict] = None,
|
||||||
|
trace: Optional[StatefulTraceClient] = None,
|
||||||
|
is_agent_builder: Optional[bool] = False,
|
||||||
|
target_agent_id: Optional[str] = None
|
||||||
|
):
|
||||||
|
config = AgentConfig(
|
||||||
|
thread_id=thread_id,
|
||||||
|
project_id=project_id,
|
||||||
|
stream=stream,
|
||||||
|
native_max_auto_continues=native_max_auto_continues,
|
||||||
|
max_iterations=max_iterations,
|
||||||
|
model_name=model_name,
|
||||||
|
enable_thinking=enable_thinking,
|
||||||
|
reasoning_effort=reasoning_effort,
|
||||||
|
enable_context_manager=enable_context_manager,
|
||||||
|
agent_config=agent_config,
|
||||||
|
trace=trace,
|
||||||
|
is_agent_builder=is_agent_builder,
|
||||||
|
target_agent_id=target_agent_id
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = AgentRunner(config)
|
||||||
|
async for chunk in runner.run():
|
||||||
|
yield chunk
|
|
@ -194,16 +194,40 @@ class AgentConfigTool(AgentBuilderBaseTool):
|
||||||
if isinstance(configured_mcps, str):
|
if isinstance(configured_mcps, str):
|
||||||
configured_mcps = json.loads(configured_mcps)
|
configured_mcps = json.loads(configured_mcps)
|
||||||
|
|
||||||
existing_mcps_by_name = {mcp.get('qualifiedName', ''): mcp for mcp in current_configured_mcps}
|
def get_mcp_identifier(mcp):
|
||||||
|
if not isinstance(mcp, dict):
|
||||||
|
return None
|
||||||
|
return (
|
||||||
|
mcp.get('qualifiedName') or
|
||||||
|
mcp.get('name') or
|
||||||
|
f"{mcp.get('type', 'unknown')}_{mcp.get('config', {}).get('url', 'nourl')}" or
|
||||||
|
str(hash(json.dumps(mcp, sort_keys=True)))
|
||||||
|
)
|
||||||
|
|
||||||
|
merged_mcps = []
|
||||||
|
existing_identifiers = set()
|
||||||
|
|
||||||
|
for existing_mcp in current_configured_mcps:
|
||||||
|
identifier = get_mcp_identifier(existing_mcp)
|
||||||
|
if identifier:
|
||||||
|
existing_identifiers.add(identifier)
|
||||||
|
merged_mcps.append(existing_mcp)
|
||||||
|
|
||||||
for new_mcp in configured_mcps:
|
for new_mcp in configured_mcps:
|
||||||
qualified_name = new_mcp.get('qualifiedName', '')
|
identifier = get_mcp_identifier(new_mcp)
|
||||||
if qualified_name:
|
|
||||||
existing_mcps_by_name[qualified_name] = new_mcp
|
|
||||||
else:
|
|
||||||
current_configured_mcps.append(new_mcp)
|
|
||||||
|
|
||||||
current_configured_mcps = list(existing_mcps_by_name.values())
|
if identifier and identifier in existing_identifiers:
|
||||||
|
for i, existing_mcp in enumerate(merged_mcps):
|
||||||
|
if get_mcp_identifier(existing_mcp) == identifier:
|
||||||
|
merged_mcps[i] = new_mcp
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
merged_mcps.append(new_mcp)
|
||||||
|
if identifier:
|
||||||
|
existing_identifiers.add(identifier)
|
||||||
|
|
||||||
|
current_configured_mcps = merged_mcps
|
||||||
|
logger.info(f"MCP merge result: {len(current_configured_mcps)} total MCPs (was {len(current_version.get('configured_mcps', []))}, adding {len(configured_mcps)})")
|
||||||
|
|
||||||
current_custom_mcps = current_version.get('custom_mcps', [])
|
current_custom_mcps = current_version.get('custom_mcps', [])
|
||||||
|
|
||||||
|
|
|
@ -278,13 +278,11 @@ class CredentialProfileTool(AgentBuilderBaseTool):
|
||||||
|
|
||||||
if profile.is_connected and connections:
|
if profile.is_connected and connections:
|
||||||
try:
|
try:
|
||||||
# directly discover MCP servers via the facade
|
|
||||||
from pipedream.domain.entities import ConnectionStatus
|
from pipedream.domain.entities import ConnectionStatus
|
||||||
servers = await self.pipedream_manager.discover_mcp_servers(
|
servers = await self.pipedream_manager.discover_mcp_servers(
|
||||||
external_user_id=profile.external_user_id.value if hasattr(profile.external_user_id, 'value') else str(profile.external_user_id),
|
external_user_id=profile.external_user_id.value if hasattr(profile.external_user_id, 'value') else str(profile.external_user_id),
|
||||||
app_slug=profile.app_slug.value if hasattr(profile.app_slug, 'value') else str(profile.app_slug)
|
app_slug=profile.app_slug.value if hasattr(profile.app_slug, 'value') else str(profile.app_slug)
|
||||||
)
|
)
|
||||||
# filter connected servers
|
|
||||||
connected_servers = [s for s in servers if s.status == ConnectionStatus.CONNECTED]
|
connected_servers = [s for s in servers if s.status == ConnectionStatus.CONNECTED]
|
||||||
if connected_servers:
|
if connected_servers:
|
||||||
tools = [t.name for t in connected_servers[0].available_tools]
|
tools = [t.name for t in connected_servers[0].available_tools]
|
||||||
|
@ -422,7 +420,6 @@ class CredentialProfileTool(AgentBuilderBaseTool):
|
||||||
if not profile:
|
if not profile:
|
||||||
return self.fail_response("Credential profile not found")
|
return self.fail_response("Credential profile not found")
|
||||||
|
|
||||||
# Get current version config
|
|
||||||
agent_result = await client.table('agents').select('current_version_id').eq('agent_id', self.agent_id).execute()
|
agent_result = await client.table('agents').select('current_version_id').eq('agent_id', self.agent_id).execute()
|
||||||
if agent_result.data and agent_result.data[0].get('current_version_id'):
|
if agent_result.data and agent_result.data[0].get('current_version_id'):
|
||||||
version_result = await client.table('agent_versions')\
|
version_result = await client.table('agent_versions')\
|
||||||
|
|
|
@ -145,10 +145,54 @@ class MCPSearchTool(AgentBuilderBaseTool):
|
||||||
"available_triggers": getattr(app_data, 'available_triggers', [])
|
"available_triggers": getattr(app_data, 'available_triggers', [])
|
||||||
}
|
}
|
||||||
|
|
||||||
return self.success_response({
|
available_tools = []
|
||||||
"message": f"Retrieved details for {formatted_app['name']}",
|
try:
|
||||||
"app": formatted_app
|
import httpx
|
||||||
|
import json
|
||||||
|
|
||||||
|
url = f"https://remote.mcp.pipedream.net/?app={app_slug}&externalUserId=tools_preview"
|
||||||
|
payload = {"jsonrpc": "2.0", "method": "tools/list", "params": {}, "id": 1}
|
||||||
|
headers = {"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
async with client.stream("POST", url, json=payload, headers=headers) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
async for line in resp.aiter_lines():
|
||||||
|
if not line or not line.startswith("data:"):
|
||||||
|
continue
|
||||||
|
data_str = line[len("data:"):].strip()
|
||||||
|
try:
|
||||||
|
data_obj = json.loads(data_str)
|
||||||
|
tools = data_obj.get("result", {}).get("tools", [])
|
||||||
|
for tool in tools:
|
||||||
|
desc = tool.get("description", "") or ""
|
||||||
|
idx = desc.find("[")
|
||||||
|
if idx != -1:
|
||||||
|
desc = desc[:idx].strip()
|
||||||
|
|
||||||
|
available_tools.append({
|
||||||
|
"name": tool.get("name", ""),
|
||||||
|
"description": desc
|
||||||
})
|
})
|
||||||
|
break
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"Failed to parse JSON data: {data_str}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as tools_error:
|
||||||
|
logger.warning(f"Could not fetch MCP tools for {app_slug}: {tools_error}")
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"message": f"Retrieved details for {formatted_app['name']}",
|
||||||
|
"app": formatted_app,
|
||||||
|
"available_mcp_tools": available_tools,
|
||||||
|
"total_mcp_tools": len(available_tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
if available_tools:
|
||||||
|
result["message"] += f" - {len(available_tools)} MCP tools available"
|
||||||
|
|
||||||
|
return self.success_response(result)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return self.fail_response(f"Error getting app details: {str(e)}")
|
return self.fail_response(f"Error getting app details: {str(e)}")
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema, ToolSchema, SchemaType
|
from agentpress.tool import Tool, ToolResult, ToolSchema, SchemaType
|
||||||
from mcp_module import mcp_manager
|
from mcp_module import mcp_manager
|
||||||
from utils.logger import logger
|
from utils.logger import logger
|
||||||
import inspect
|
import inspect
|
||||||
|
@ -77,25 +77,47 @@ class MCPToolWrapper(Tool):
|
||||||
|
|
||||||
logger.info(f"Created {len(self._dynamic_tools)} dynamic MCP tool methods")
|
logger.info(f"Created {len(self._dynamic_tools)} dynamic MCP tool methods")
|
||||||
|
|
||||||
|
# Re-register schemas to pick up the dynamic methods
|
||||||
|
self._register_schemas()
|
||||||
|
logger.info(f"Re-registered schemas after creating dynamic tools - total: {len(self._schemas)}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating dynamic MCP tools: {e}")
|
logger.error(f"Error creating dynamic MCP tools: {e}")
|
||||||
|
|
||||||
def _register_schemas(self):
|
def _register_schemas(self):
|
||||||
|
self._schemas.clear()
|
||||||
|
|
||||||
for name, method in inspect.getmembers(self, predicate=inspect.ismethod):
|
for name, method in inspect.getmembers(self, predicate=inspect.ismethod):
|
||||||
if hasattr(method, 'tool_schemas'):
|
if hasattr(method, 'tool_schemas'):
|
||||||
self._schemas[name] = method.tool_schemas
|
self._schemas[name] = method.tool_schemas
|
||||||
logger.debug(f"Registered schemas for method '{name}' in {self.__class__.__name__}")
|
logger.debug(f"Registered schemas for method '{name}' in {self.__class__.__name__}")
|
||||||
|
|
||||||
logger.debug(f"Initial registration complete for MCPToolWrapper")
|
if hasattr(self, '_dynamic_tools') and self._dynamic_tools:
|
||||||
|
for tool_name, tool_data in self._dynamic_tools.items():
|
||||||
|
method_name = tool_data.get('method_name')
|
||||||
|
if method_name and method_name in self._schemas:
|
||||||
|
continue
|
||||||
|
|
||||||
|
method = tool_data.get('method')
|
||||||
|
if method and hasattr(method, 'tool_schemas'):
|
||||||
|
self._schemas[method_name] = method.tool_schemas
|
||||||
|
logger.debug(f"Registered dynamic method schemas for '{method_name}'")
|
||||||
|
|
||||||
|
logger.debug(f"Registration complete for MCPToolWrapper - total schemas: {len(self._schemas)}")
|
||||||
|
|
||||||
def get_schemas(self) -> Dict[str, List[ToolSchema]]:
|
def get_schemas(self) -> Dict[str, List[ToolSchema]]:
|
||||||
|
logger.debug(f"get_schemas called - returning {len(self._schemas)} schemas")
|
||||||
|
for method_name in self._schemas:
|
||||||
|
logger.debug(f" - Schema available for: {method_name}")
|
||||||
return self._schemas
|
return self._schemas
|
||||||
|
|
||||||
def __getattr__(self, name: str):
|
def __getattr__(self, name: str):
|
||||||
|
if hasattr(self, 'tool_builder') and self.tool_builder:
|
||||||
method = self.tool_builder.find_method_by_name(name)
|
method = self.tool_builder.find_method_by_name(name)
|
||||||
if method:
|
if method:
|
||||||
return method
|
return method
|
||||||
|
|
||||||
|
if hasattr(self, '_dynamic_tools') and self._dynamic_tools:
|
||||||
for tool_data in self._dynamic_tools.values():
|
for tool_data in self._dynamic_tools.values():
|
||||||
if tool_data.get('method_name') == name:
|
if tool_data.get('method_name') == name:
|
||||||
return tool_data.get('method')
|
return tool_data.get('method')
|
||||||
|
@ -111,9 +133,6 @@ class MCPToolWrapper(Tool):
|
||||||
await self._ensure_initialized()
|
await self._ensure_initialized()
|
||||||
if tool_registry and self._dynamic_tools:
|
if tool_registry and self._dynamic_tools:
|
||||||
logger.info(f"Updating tool registry with {len(self._dynamic_tools)} MCP tools")
|
logger.info(f"Updating tool registry with {len(self._dynamic_tools)} MCP tools")
|
||||||
for method_name, schemas in self._schemas.items():
|
|
||||||
if method_name not in ['call_mcp_tool']:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def get_available_tools(self) -> List[Dict[str, Any]]:
|
async def get_available_tools(self) -> List[Dict[str, Any]]:
|
||||||
await self._ensure_initialized()
|
await self._ensure_initialized()
|
||||||
|
@ -123,46 +142,6 @@ class MCPToolWrapper(Tool):
|
||||||
await self._ensure_initialized()
|
await self._ensure_initialized()
|
||||||
return await self.tool_executor.execute_tool(tool_name, arguments)
|
return await self.tool_executor.execute_tool(tool_name, arguments)
|
||||||
|
|
||||||
@openapi_schema({
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "call_mcp_tool",
|
|
||||||
"description": "Execute a tool from any connected MCP server. This is a fallback wrapper that forwards calls to MCP tools. The tool_name should be in the format 'mcp_{server}_{tool}' where {server} is the MCP server's qualified name and {tool} is the specific tool name.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"tool_name": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The full MCP tool name in format 'mcp_{server}_{tool}', e.g., 'mcp_exa_web_search_exa'"
|
|
||||||
},
|
|
||||||
"arguments": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "The arguments to pass to the MCP tool, as a JSON object. The required arguments depend on the specific tool being called.",
|
|
||||||
"additionalProperties": True
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["tool_name", "arguments"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
@xml_schema(
|
|
||||||
tag_name="call-mcp-tool",
|
|
||||||
mappings=[
|
|
||||||
{"param_name": "tool_name", "node_type": "attribute", "path": "."},
|
|
||||||
{"param_name": "arguments", "node_type": "content", "path": "."}
|
|
||||||
],
|
|
||||||
example='''
|
|
||||||
<function_calls>
|
|
||||||
<invoke name="call_mcp_tool">
|
|
||||||
<parameter name="tool_name">mcp_exa_web_search_exa</parameter>
|
|
||||||
<parameter name="arguments">{"query": "latest developments in AI", "num_results": 10}</parameter>
|
|
||||||
</invoke>
|
|
||||||
</function_calls>
|
|
||||||
'''
|
|
||||||
)
|
|
||||||
async def call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any]) -> ToolResult:
|
|
||||||
return await self._execute_mcp_tool(tool_name, arguments)
|
|
||||||
|
|
||||||
async def cleanup(self):
|
async def cleanup(self):
|
||||||
if self._initialized:
|
if self._initialized:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -337,6 +337,7 @@ class PipedreamManager:
|
||||||
user_id: str,
|
user_id: str,
|
||||||
enabled_tools: List[str]
|
enabled_tools: List[str]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
|
|
||||||
from services.supabase import DBConnection
|
from services.supabase import DBConnection
|
||||||
from agent.versioning.version_service import get_version_service
|
from agent.versioning.version_service import get_version_service
|
||||||
import copy
|
import copy
|
||||||
|
@ -353,17 +354,24 @@ class PipedreamManager:
|
||||||
|
|
||||||
agent = agent_result.data[0]
|
agent = agent_result.data[0]
|
||||||
|
|
||||||
|
print(f"[DEBUG] Starting update_agent_profile_tools for agent {agent_id}, profile {profile_id}")
|
||||||
|
|
||||||
current_version_data = None
|
current_version_data = None
|
||||||
if agent.get('current_version_id'):
|
if agent.get('current_version_id'):
|
||||||
try:
|
try:
|
||||||
current_version_data = await get_version_service().get_version(
|
version_service = await get_version_service()
|
||||||
|
current_version_data = await version_service.get_version(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
version_id=agent['current_version_id'],
|
version_id=agent['current_version_id'],
|
||||||
user_id=user_id
|
user_id=user_id
|
||||||
)
|
)
|
||||||
version_data = current_version_data.to_dict()
|
version_data = current_version_data.to_dict()
|
||||||
current_custom_mcps = version_data.get('custom_mcps', [])
|
current_custom_mcps = version_data.get('custom_mcps', [])
|
||||||
|
print(f"[DEBUG] Retrieved current version {agent['current_version_id']}")
|
||||||
|
print(f"[DEBUG] Current custom_mcps count: {len(current_custom_mcps)}")
|
||||||
|
print(f"[DEBUG] Current custom_mcps: {current_custom_mcps}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"[DEBUG] Error getting current version: {e}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -376,30 +384,40 @@ class PipedreamManager:
|
||||||
configured_mcps = current_version_data.configured_mcps
|
configured_mcps = current_version_data.configured_mcps
|
||||||
agentpress_tools = current_version_data.agentpress_tools
|
agentpress_tools = current_version_data.agentpress_tools
|
||||||
current_custom_mcps = current_version_data.custom_mcps
|
current_custom_mcps = current_version_data.custom_mcps
|
||||||
|
print(f"[DEBUG] Using version data - custom_mcps count: {len(current_custom_mcps)}")
|
||||||
else:
|
else:
|
||||||
system_prompt = ''
|
system_prompt = ''
|
||||||
configured_mcps = []
|
configured_mcps = []
|
||||||
agentpress_tools = {}
|
agentpress_tools = {}
|
||||||
current_custom_mcps = []
|
current_custom_mcps = []
|
||||||
|
print(f"[DEBUG] No version data - starting with empty custom_mcps")
|
||||||
|
|
||||||
|
|
||||||
updated_custom_mcps = copy.deepcopy(current_custom_mcps)
|
updated_custom_mcps = copy.deepcopy(current_custom_mcps)
|
||||||
|
print(f"[DEBUG] After deepcopy - updated_custom_mcps count: {len(updated_custom_mcps)}")
|
||||||
|
print(f"[DEBUG] After deepcopy - updated_custom_mcps: {updated_custom_mcps}")
|
||||||
|
|
||||||
|
# Normalize enabledTools vs enabled_tools
|
||||||
for mcp in updated_custom_mcps:
|
for mcp in updated_custom_mcps:
|
||||||
if 'enabled_tools' in mcp and 'enabledTools' not in mcp:
|
if 'enabled_tools' in mcp and 'enabledTools' not in mcp:
|
||||||
mcp['enabledTools'] = mcp['enabled_tools']
|
mcp['enabledTools'] = mcp['enabled_tools']
|
||||||
elif 'enabledTools' not in mcp and 'enabled_tools' not in mcp:
|
elif 'enabledTools' not in mcp and 'enabled_tools' not in mcp:
|
||||||
mcp['enabledTools'] = []
|
mcp['enabledTools'] = []
|
||||||
|
|
||||||
|
# Look for existing MCP with same profile_id
|
||||||
found_match = False
|
found_match = False
|
||||||
for mcp in updated_custom_mcps:
|
for i, mcp in enumerate(updated_custom_mcps):
|
||||||
|
print(f"[DEBUG] Checking MCP {i}: type={mcp.get('type')}, profile_id={mcp.get('config', {}).get('profile_id')}")
|
||||||
if (mcp.get('type') == 'pipedream' and
|
if (mcp.get('type') == 'pipedream' and
|
||||||
mcp.get('config', {}).get('profile_id') == profile_id):
|
mcp.get('config', {}).get('profile_id') == profile_id):
|
||||||
|
print(f"[DEBUG] Found existing MCP at index {i}, updating tools from {mcp.get('enabledTools', [])} to {enabled_tools}")
|
||||||
mcp['enabledTools'] = enabled_tools
|
mcp['enabledTools'] = enabled_tools
|
||||||
mcp['enabled_tools'] = enabled_tools
|
mcp['enabled_tools'] = enabled_tools
|
||||||
found_match = True
|
found_match = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if not found_match:
|
if not found_match:
|
||||||
|
print(f"[DEBUG] No existing MCP found, creating new one")
|
||||||
new_mcp_config = {
|
new_mcp_config = {
|
||||||
"name": profile.app_name,
|
"name": profile.app_name,
|
||||||
"type": "pipedream",
|
"type": "pipedream",
|
||||||
|
@ -413,7 +431,12 @@ class PipedreamManager:
|
||||||
"enabledTools": enabled_tools,
|
"enabledTools": enabled_tools,
|
||||||
"enabled_tools": enabled_tools
|
"enabled_tools": enabled_tools
|
||||||
}
|
}
|
||||||
|
print(f"[DEBUG] New MCP config: {new_mcp_config}")
|
||||||
updated_custom_mcps.append(new_mcp_config)
|
updated_custom_mcps.append(new_mcp_config)
|
||||||
|
print(f"[DEBUG] After append - updated_custom_mcps count: {len(updated_custom_mcps)}")
|
||||||
|
|
||||||
|
print(f"[DEBUG] Final updated_custom_mcps count: {len(updated_custom_mcps)}")
|
||||||
|
print(f"[DEBUG] Final updated_custom_mcps: {updated_custom_mcps}")
|
||||||
|
|
||||||
version_service = await get_version_service()
|
version_service = await get_version_service()
|
||||||
|
|
||||||
|
@ -428,6 +451,8 @@ class PipedreamManager:
|
||||||
change_description=f"Updated {profile.app_name} tools"
|
change_description=f"Updated {profile.app_name} tools"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print(f"[DEBUG] Created new version {new_version.version_id} with {len(updated_custom_mcps)} custom MCPs")
|
||||||
|
|
||||||
update_result = await client.table('agents').update({
|
update_result = await client.table('agents').update({
|
||||||
'current_version_id': new_version.version_id
|
'current_version_id': new_version.version_id
|
||||||
}).eq('agent_id', agent_id).execute()
|
}).eq('agent_id', agent_id).execute()
|
||||||
|
@ -440,7 +465,7 @@ class PipedreamManager:
|
||||||
'enabled_tools': enabled_tools,
|
'enabled_tools': enabled_tools,
|
||||||
'total_tools': len(enabled_tools),
|
'total_tools': len(enabled_tools),
|
||||||
'version_id': new_version.version_id,
|
'version_id': new_version.version_id,
|
||||||
'version_name': new_version['version_name']
|
'version_name': new_version.version_name
|
||||||
}
|
}
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
|
|
|
@ -49,8 +49,6 @@ const HIDE_STREAMING_XML_TAGS = new Set([
|
||||||
'crawl-webpage',
|
'crawl-webpage',
|
||||||
'web-search',
|
'web-search',
|
||||||
'see-image',
|
'see-image',
|
||||||
'call-mcp-tool',
|
|
||||||
|
|
||||||
'execute_data_provider_call',
|
'execute_data_provider_call',
|
||||||
'execute_data_provider_endpoint',
|
'execute_data_provider_endpoint',
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ export interface CustomMcp {
|
||||||
headers?: Record<string, string>;
|
headers?: Record<string, string>;
|
||||||
profile_id?: string;
|
profile_id?: string;
|
||||||
};
|
};
|
||||||
enabled_tools: string[];
|
enabledTools: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface AgentConfiguration {
|
export interface AgentConfiguration {
|
||||||
|
|
|
@ -83,7 +83,10 @@ export function GetCurrentAgentConfigToolView({
|
||||||
};
|
};
|
||||||
|
|
||||||
const getTotalMcpToolsCount = (mcps: CustomMcp[]) => {
|
const getTotalMcpToolsCount = (mcps: CustomMcp[]) => {
|
||||||
return mcps.reduce((total, mcp) => total + mcp.enabled_tools.length, 0);
|
return mcps.reduce((total, mcp) => {
|
||||||
|
const enabledTools = mcp.enabledTools || [];
|
||||||
|
return total + (Array.isArray(enabledTools) ? enabledTools.length : 0);
|
||||||
|
}, 0);
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -268,12 +271,12 @@ export function GetCurrentAgentConfigToolView({
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<Badge variant="outline" className="text-xs">
|
<Badge variant="outline" className="text-xs">
|
||||||
{mcp.enabled_tools.length} tools
|
{mcp.enabledTools.length} tools
|
||||||
</Badge>
|
</Badge>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="grid grid-cols-2 md:grid-cols-3 gap-2">
|
<div className="grid grid-cols-2 md:grid-cols-3 gap-2">
|
||||||
{mcp.enabled_tools.map((tool, toolIndex) => (
|
{mcp.enabledTools.map((tool, toolIndex) => (
|
||||||
<div key={toolIndex} className="flex items-center gap-1 p-2 bg-zinc-50 dark:bg-zinc-800/50 rounded text-xs">
|
<div key={toolIndex} className="flex items-center gap-1 p-2 bg-zinc-50 dark:bg-zinc-800/50 rounded text-xs">
|
||||||
<Zap className="w-3 h-3 text-zinc-500 dark:text-zinc-400" />
|
<Zap className="w-3 h-3 text-zinc-500 dark:text-zinc-400" />
|
||||||
<span className="text-zinc-700 dark:text-zinc-300 truncate">
|
<span className="text-zinc-700 dark:text-zinc-300 truncate">
|
||||||
|
|
|
@ -387,8 +387,6 @@ const TOOL_DISPLAY_NAMES = new Map([
|
||||||
['web_search', 'Searching Web'],
|
['web_search', 'Searching Web'],
|
||||||
['see_image', 'Viewing Image'],
|
['see_image', 'Viewing Image'],
|
||||||
|
|
||||||
['call_mcp_tool', 'External Tool'],
|
|
||||||
|
|
||||||
['update_agent', 'Updating Agent'],
|
['update_agent', 'Updating Agent'],
|
||||||
['get_current_agent_config', 'Getting Agent Config'],
|
['get_current_agent_config', 'Getting Agent Config'],
|
||||||
['search_mcp_servers', 'Searching MCP Servers'],
|
['search_mcp_servers', 'Searching MCP Servers'],
|
||||||
|
|
Loading…
Reference in New Issue