mirror of https://github.com/kortix-ai/suna.git
Merge remote-tracking branch 'upstream/main' into feat/port-fix-and-panel-ajustments
This commit is contained in:
commit
0c3b520fe6
|
@ -49,6 +49,7 @@ class AgentCreateRequest(BaseModel):
|
|||
description: Optional[str] = None
|
||||
system_prompt: str
|
||||
configured_mcps: Optional[List[Dict[str, Any]]] = []
|
||||
custom_mcps: Optional[List[Dict[str, Any]]] = []
|
||||
agentpress_tools: Optional[Dict[str, Any]] = {}
|
||||
is_default: Optional[bool] = False
|
||||
avatar: Optional[str] = None
|
||||
|
@ -59,6 +60,7 @@ class AgentUpdateRequest(BaseModel):
|
|||
description: Optional[str] = None
|
||||
system_prompt: Optional[str] = None
|
||||
configured_mcps: Optional[List[Dict[str, Any]]] = None
|
||||
custom_mcps: Optional[List[Dict[str, Any]]] = None
|
||||
agentpress_tools: Optional[Dict[str, Any]] = None
|
||||
is_default: Optional[bool] = None
|
||||
avatar: Optional[str] = None
|
||||
|
@ -71,6 +73,7 @@ class AgentResponse(BaseModel):
|
|||
description: Optional[str]
|
||||
system_prompt: str
|
||||
configured_mcps: List[Dict[str, Any]]
|
||||
custom_mcps: Optional[List[Dict[str, Any]]] = []
|
||||
agentpress_tools: Dict[str, Any]
|
||||
is_default: bool
|
||||
is_public: Optional[bool] = False
|
||||
|
@ -566,6 +569,7 @@ async def get_thread_agent(thread_id: str, user_id: str = Depends(get_current_us
|
|||
description=agent_data.get('description'),
|
||||
system_prompt=agent_data['system_prompt'],
|
||||
configured_mcps=agent_data.get('configured_mcps', []),
|
||||
custom_mcps=agent_data.get('custom_mcps', []),
|
||||
agentpress_tools=agent_data.get('agentpress_tools', {}),
|
||||
is_default=agent_data.get('is_default', False),
|
||||
is_public=agent_data.get('is_public', False),
|
||||
|
@ -1184,6 +1188,7 @@ async def get_agents(
|
|||
description=agent.get('description'),
|
||||
system_prompt=agent['system_prompt'],
|
||||
configured_mcps=agent.get('configured_mcps', []),
|
||||
custom_mcps=agent.get('custom_mcps', []),
|
||||
agentpress_tools=agent.get('agentpress_tools', {}),
|
||||
is_default=agent.get('is_default', False),
|
||||
is_public=agent.get('is_public', False),
|
||||
|
@ -1239,6 +1244,7 @@ async def get_agent(agent_id: str, user_id: str = Depends(get_current_user_id_fr
|
|||
description=agent_data.get('description'),
|
||||
system_prompt=agent_data['system_prompt'],
|
||||
configured_mcps=agent_data.get('configured_mcps', []),
|
||||
custom_mcps=agent_data.get('custom_mcps', []),
|
||||
agentpress_tools=agent_data.get('agentpress_tools', {}),
|
||||
is_default=agent_data.get('is_default', False),
|
||||
is_public=agent_data.get('is_public', False),
|
||||
|
@ -1283,6 +1289,7 @@ async def create_agent(
|
|||
"description": agent_data.description,
|
||||
"system_prompt": agent_data.system_prompt,
|
||||
"configured_mcps": agent_data.configured_mcps or [],
|
||||
"custom_mcps": agent_data.custom_mcps or [],
|
||||
"agentpress_tools": agent_data.agentpress_tools or {},
|
||||
"is_default": agent_data.is_default or False,
|
||||
"avatar": agent_data.avatar,
|
||||
|
@ -1304,6 +1311,7 @@ async def create_agent(
|
|||
description=agent.get('description'),
|
||||
system_prompt=agent['system_prompt'],
|
||||
configured_mcps=agent.get('configured_mcps', []),
|
||||
custom_mcps=agent.get('custom_mcps', []),
|
||||
agentpress_tools=agent.get('agentpress_tools', {}),
|
||||
is_default=agent.get('is_default', False),
|
||||
is_public=agent.get('is_public', False),
|
||||
|
@ -1357,6 +1365,8 @@ async def update_agent(
|
|||
update_data["system_prompt"] = agent_data.system_prompt
|
||||
if agent_data.configured_mcps is not None:
|
||||
update_data["configured_mcps"] = agent_data.configured_mcps
|
||||
if agent_data.custom_mcps is not None:
|
||||
update_data["custom_mcps"] = agent_data.custom_mcps
|
||||
if agent_data.agentpress_tools is not None:
|
||||
update_data["agentpress_tools"] = agent_data.agentpress_tools
|
||||
if agent_data.is_default is not None:
|
||||
|
@ -1396,6 +1406,7 @@ async def update_agent(
|
|||
description=agent.get('description'),
|
||||
system_prompt=agent['system_prompt'],
|
||||
configured_mcps=agent.get('configured_mcps', []),
|
||||
custom_mcps=agent.get('custom_mcps', []),
|
||||
agentpress_tools=agent.get('agentpress_tools', {}),
|
||||
is_default=agent.get('is_default', False),
|
||||
is_public=agent.get('is_public', False),
|
||||
|
|
|
@ -130,43 +130,65 @@ async def run_agent(
|
|||
if config.RAPID_API_KEY and enabled_tools.get('data_providers_tool', {}).get('enabled', False):
|
||||
thread_manager.add_tool(DataProvidersTool)
|
||||
|
||||
# Register MCP tool wrapper if agent has configured MCPs
|
||||
# Register MCP tool wrapper if agent has configured MCPs or custom MCPs
|
||||
mcp_wrapper_instance = None
|
||||
if agent_config and agent_config.get('configured_mcps'):
|
||||
logger.info(f"Registering MCP tool wrapper for {len(agent_config['configured_mcps'])} MCP servers")
|
||||
# Register the tool
|
||||
thread_manager.add_tool(MCPToolWrapper, mcp_configs=agent_config['configured_mcps'])
|
||||
if agent_config:
|
||||
# Merge configured_mcps and custom_mcps
|
||||
all_mcps = []
|
||||
|
||||
# Get the tool instance from the registry
|
||||
# The tool is registered with method names as keys
|
||||
for tool_name, tool_info in thread_manager.tool_registry.tools.items():
|
||||
if isinstance(tool_info['instance'], MCPToolWrapper):
|
||||
mcp_wrapper_instance = tool_info['instance']
|
||||
break
|
||||
# Add standard configured MCPs
|
||||
if agent_config.get('configured_mcps'):
|
||||
all_mcps.extend(agent_config['configured_mcps'])
|
||||
|
||||
# Initialize the MCP tools asynchronously
|
||||
if mcp_wrapper_instance:
|
||||
try:
|
||||
await mcp_wrapper_instance.initialize_and_register_tools()
|
||||
logger.info("MCP tools initialized successfully")
|
||||
# Add custom MCPs
|
||||
if agent_config.get('custom_mcps'):
|
||||
for custom_mcp in agent_config['custom_mcps']:
|
||||
# Transform custom MCP to standard format
|
||||
mcp_config = {
|
||||
'name': custom_mcp['name'],
|
||||
'qualifiedName': f"custom_{custom_mcp['type']}_{custom_mcp['name'].replace(' ', '_').lower()}",
|
||||
'config': custom_mcp['config'],
|
||||
'enabledTools': custom_mcp.get('enabledTools', []),
|
||||
'isCustom': True,
|
||||
'customType': custom_mcp['type']
|
||||
}
|
||||
all_mcps.append(mcp_config)
|
||||
|
||||
if all_mcps:
|
||||
logger.info(f"Registering MCP tool wrapper for {len(all_mcps)} MCP servers (including {len(agent_config.get('custom_mcps', []))} custom)")
|
||||
# Register the tool with all MCPs
|
||||
thread_manager.add_tool(MCPToolWrapper, mcp_configs=all_mcps)
|
||||
|
||||
# Get the tool instance from the registry
|
||||
# The tool is registered with method names as keys
|
||||
for tool_name, tool_info in thread_manager.tool_registry.tools.items():
|
||||
if isinstance(tool_info['instance'], MCPToolWrapper):
|
||||
mcp_wrapper_instance = tool_info['instance']
|
||||
break
|
||||
|
||||
# Initialize the MCP tools asynchronously
|
||||
if mcp_wrapper_instance:
|
||||
try:
|
||||
await mcp_wrapper_instance.initialize_and_register_tools()
|
||||
logger.info("MCP tools initialized successfully")
|
||||
|
||||
# Re-register the updated schemas with the tool registry
|
||||
# This ensures the dynamically created tools are available for function calling
|
||||
updated_schemas = mcp_wrapper_instance.get_schemas()
|
||||
for method_name, schema_list in updated_schemas.items():
|
||||
if method_name != 'call_mcp_tool': # Skip the fallback method
|
||||
# Register each dynamic tool in the registry
|
||||
for schema in schema_list:
|
||||
if schema.schema_type == SchemaType.OPENAPI:
|
||||
thread_manager.tool_registry.tools[method_name] = {
|
||||
"instance": mcp_wrapper_instance,
|
||||
"schema": schema
|
||||
}
|
||||
logger.debug(f"Registered dynamic MCP tool: {method_name}")
|
||||
|
||||
# Re-register the updated schemas with the tool registry
|
||||
# This ensures the dynamically created tools are available for function calling
|
||||
updated_schemas = mcp_wrapper_instance.get_schemas()
|
||||
for method_name, schema_list in updated_schemas.items():
|
||||
if method_name != 'call_mcp_tool': # Skip the fallback method
|
||||
# Register each dynamic tool in the registry
|
||||
for schema in schema_list:
|
||||
if schema.schema_type == SchemaType.OPENAPI:
|
||||
thread_manager.tool_registry.tools[method_name] = {
|
||||
"instance": mcp_wrapper_instance,
|
||||
"schema": schema
|
||||
}
|
||||
logger.debug(f"Registered dynamic MCP tool: {method_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize MCP tools: {e}")
|
||||
# Continue without MCP tools if initialization fails
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize MCP tools: {e}")
|
||||
# Continue without MCP tools if initialization fails
|
||||
|
||||
# Prepare system prompt
|
||||
# First, get the default system prompt
|
||||
|
@ -200,7 +222,7 @@ async def run_agent(
|
|||
logger.info("Using default system prompt only")
|
||||
|
||||
# Add MCP tool information to system prompt if MCP tools are configured
|
||||
if agent_config and agent_config.get('configured_mcps') and mcp_wrapper_instance and mcp_wrapper_instance._initialized:
|
||||
if agent_config and (agent_config.get('configured_mcps') or agent_config.get('custom_mcps')) and mcp_wrapper_instance and mcp_wrapper_instance._initialized:
|
||||
mcp_info = "\n\n--- MCP Tools Available ---\n"
|
||||
mcp_info += "You have access to external MCP (Model Context Protocol) server tools.\n"
|
||||
mcp_info += "MCP tools can be called directly using their native function names in the standard function calling format:\n"
|
||||
|
|
|
@ -11,6 +11,12 @@ from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema, ToolSc
|
|||
from mcp_local.client import MCPManager
|
||||
from utils.logger import logger
|
||||
import inspect
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp import StdioServerParameters
|
||||
import asyncio
|
||||
|
||||
|
||||
class MCPToolWrapper(Tool):
|
||||
|
@ -34,26 +40,278 @@ class MCPToolWrapper(Tool):
|
|||
self._initialized = False
|
||||
self._dynamic_tools = {}
|
||||
self._schemas: Dict[str, List[ToolSchema]] = {}
|
||||
self._custom_tools = {} # Store custom MCP tools separately
|
||||
|
||||
# Now initialize the parent class which will call _register_schemas
|
||||
super().__init__()
|
||||
|
||||
async def _ensure_initialized(self):
|
||||
"""Ensure MCP connections are initialized and dynamic tools are created."""
|
||||
if not self._initialized and self.mcp_configs:
|
||||
logger.info(f"Initializing MCP connections for {len(self.mcp_configs)} servers")
|
||||
"""Ensure MCP servers are initialized."""
|
||||
if not self._initialized:
|
||||
# Initialize standard MCP servers from Smithery
|
||||
standard_configs = [cfg for cfg in self.mcp_configs if not cfg.get('isCustom', False)]
|
||||
custom_configs = [cfg for cfg in self.mcp_configs if cfg.get('isCustom', False)]
|
||||
|
||||
# Initialize standard MCPs through MCPManager
|
||||
if standard_configs:
|
||||
for config in standard_configs:
|
||||
try:
|
||||
await self.mcp_manager.connect_server(config)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to MCP server {config['qualifiedName']}: {e}")
|
||||
|
||||
# Initialize custom MCPs directly
|
||||
if custom_configs:
|
||||
await self._initialize_custom_mcps(custom_configs)
|
||||
|
||||
# Create dynamic tools for all connected servers
|
||||
await self._create_dynamic_tools()
|
||||
self._initialized = True
|
||||
|
||||
async def _connect_sse_server(self, server_name, server_config, all_tools, timeout):
|
||||
url = server_config["url"]
|
||||
headers = server_config.get("headers", {})
|
||||
|
||||
async with asyncio.timeout(timeout):
|
||||
try:
|
||||
await self.mcp_manager.connect_all(self.mcp_configs)
|
||||
await self._create_dynamic_tools()
|
||||
self._initialized = True
|
||||
except ValueError as e:
|
||||
if "SMITHERY_API_KEY" in str(e):
|
||||
logger.error("MCP Error: SMITHERY_API_KEY environment variable is not set")
|
||||
logger.error("To use MCP tools, please:")
|
||||
logger.error("1. Get your API key from https://smithery.ai")
|
||||
logger.error("2. Set it as an environment variable: export SMITHERY_API_KEY='your-key-here'")
|
||||
logger.error("3. Or add it to your .env file: SMITHERY_API_KEY=your-key-here")
|
||||
raise
|
||||
async with sse_client(url, headers=headers) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
tools_info = []
|
||||
for tool in tools_result.tools:
|
||||
tool_info = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.inputSchema
|
||||
}
|
||||
tools_info.append(tool_info)
|
||||
|
||||
all_tools[server_name] = {
|
||||
"status": "connected",
|
||||
"transport": "sse",
|
||||
"url": url,
|
||||
"tools": tools_info
|
||||
}
|
||||
|
||||
logger.info(f" {server_name}: Connected via SSE ({len(tools_info)} tools)")
|
||||
except TypeError as e:
|
||||
if "unexpected keyword argument" in str(e):
|
||||
async with sse_client(url) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
tools_info = []
|
||||
for tool in tools_result.tools:
|
||||
tool_info = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.inputSchema
|
||||
}
|
||||
tools_info.append(tool_info)
|
||||
|
||||
all_tools[server_name] = {
|
||||
"status": "connected",
|
||||
"transport": "sse",
|
||||
"url": url,
|
||||
"tools": tools_info
|
||||
}
|
||||
logger.info(f" {server_name}: Connected via SSE ({len(tools_info)} tools)")
|
||||
else:
|
||||
raise
|
||||
|
||||
async def _connect_streamable_http_server(self, url):
|
||||
async with streamablehttp_client(url) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
tool_result = await session.list_tools()
|
||||
print(f"Connected via HTTP ({len(tool_result.tools)} tools)")
|
||||
|
||||
tools_info = []
|
||||
for tool in tool_result.tools:
|
||||
tool_info = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.inputSchema
|
||||
}
|
||||
tools_info.append(tool_info)
|
||||
|
||||
return tools_info
|
||||
|
||||
async def _connect_stdio_server(self, server_name, server_config, all_tools, timeout):
|
||||
"""Connect to a stdio-based MCP server."""
|
||||
server_params = StdioServerParameters(
|
||||
command=server_config["command"],
|
||||
args=server_config.get("args", []),
|
||||
env=server_config.get("env", {})
|
||||
)
|
||||
|
||||
async with asyncio.timeout(timeout):
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
tools_info = []
|
||||
for tool in tools_result.tools:
|
||||
tool_info = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.inputSchema
|
||||
}
|
||||
tools_info.append(tool_info)
|
||||
|
||||
all_tools[server_name] = {
|
||||
"status": "connected",
|
||||
"transport": "stdio",
|
||||
"tools": tools_info
|
||||
}
|
||||
|
||||
logger.info(f" {server_name}: Connected via stdio ({len(tools_info)} tools)")
|
||||
|
||||
async def _initialize_custom_mcps(self, custom_configs):
|
||||
"""Initialize custom MCP servers."""
|
||||
for config in custom_configs:
|
||||
try:
|
||||
logger.info(f"Initializing custom MCP: {config}")
|
||||
custom_type = config.get('customType', 'sse')
|
||||
server_config = config.get('config', {})
|
||||
enabled_tools = config.get('enabledTools', [])
|
||||
server_name = config.get('name', 'Unknown')
|
||||
|
||||
logger.info(f"Initializing custom MCP: {server_name} (type: {custom_type})")
|
||||
|
||||
if custom_type == 'sse':
|
||||
if 'url' not in server_config:
|
||||
logger.error(f"Custom MCP {server_name}: Missing 'url' in config")
|
||||
continue
|
||||
|
||||
url = server_config['url']
|
||||
logger.info(f"Initializing custom MCP {url} with SSE type")
|
||||
|
||||
try:
|
||||
# Use the working connect_sse_server method
|
||||
all_tools = {}
|
||||
await self._connect_sse_server(server_name, server_config, all_tools, 15)
|
||||
|
||||
# Process the results
|
||||
if server_name in all_tools and all_tools[server_name].get('status') == 'connected':
|
||||
tools_info = all_tools[server_name].get('tools', [])
|
||||
tools_registered = 0
|
||||
|
||||
for tool_info in tools_info:
|
||||
tool_name_from_server = tool_info['name']
|
||||
if not enabled_tools or tool_name_from_server in enabled_tools:
|
||||
tool_name = f"custom_{server_name.replace(' ', '_').lower()}_{tool_name_from_server}"
|
||||
self._custom_tools[tool_name] = {
|
||||
'name': tool_name,
|
||||
'description': tool_info['description'],
|
||||
'parameters': tool_info['input_schema'],
|
||||
'server': server_name,
|
||||
'original_name': tool_name_from_server,
|
||||
'is_custom': True,
|
||||
'custom_type': custom_type,
|
||||
'custom_config': server_config
|
||||
}
|
||||
tools_registered += 1
|
||||
logger.debug(f"Registered custom tool: {tool_name}")
|
||||
|
||||
logger.info(f"Successfully initialized custom MCP {server_name} with {tools_registered} tools")
|
||||
else:
|
||||
logger.error(f"Failed to connect to custom MCP {server_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Custom MCP {server_name}: Connection failed - {str(e)}")
|
||||
continue
|
||||
|
||||
elif custom_type == 'http':
|
||||
if 'url' not in server_config:
|
||||
logger.error(f"Custom MCP {server_name}: Missing 'url' in config")
|
||||
continue
|
||||
|
||||
url = server_config['url']
|
||||
logger.info(f"Initializing custom MCP {url} with HTTP type")
|
||||
|
||||
try:
|
||||
|
||||
tools_info = await self._connect_streamable_http_server(url)
|
||||
tools_registered = 0
|
||||
|
||||
for tool_info in tools_info:
|
||||
tool_name_from_server = tool_info['name']
|
||||
if not enabled_tools or tool_name_from_server in enabled_tools:
|
||||
tool_name = f"custom_{server_name.replace(' ', '_').lower()}_{tool_name_from_server}"
|
||||
self._custom_tools[tool_name] = {
|
||||
'name': tool_name,
|
||||
'description': tool_info['description'],
|
||||
'parameters': tool_info['inputSchema'],
|
||||
'server': server_name,
|
||||
'original_name': tool_name_from_server,
|
||||
'is_custom': True,
|
||||
'custom_type': custom_type,
|
||||
'custom_config': server_config
|
||||
}
|
||||
tools_registered += 1
|
||||
logger.debug(f"Registered custom tool: {tool_name}")
|
||||
|
||||
logger.info(f"Successfully initialized custom MCP {server_name} with {tools_registered} tools")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Custom MCP {server_name}: Connection failed - {str(e)}")
|
||||
continue
|
||||
|
||||
elif custom_type == 'json':
|
||||
if 'command' not in server_config:
|
||||
logger.error(f"Custom MCP {server_name}: Missing 'command' in config")
|
||||
continue
|
||||
|
||||
logger.info(f"Initializing custom MCP {server_name} with JSON/stdio type")
|
||||
|
||||
try:
|
||||
# Use the stdio connection method
|
||||
all_tools = {}
|
||||
await self._connect_stdio_server(server_name, server_config, all_tools, 15)
|
||||
|
||||
# Process the results
|
||||
if server_name in all_tools and all_tools[server_name].get('status') == 'connected':
|
||||
tools_info = all_tools[server_name].get('tools', [])
|
||||
tools_registered = 0
|
||||
|
||||
for tool_info in tools_info:
|
||||
tool_name_from_server = tool_info['name']
|
||||
if not enabled_tools or tool_name_from_server in enabled_tools:
|
||||
tool_name = f"custom_{server_name.replace(' ', '_').lower()}_{tool_name_from_server}"
|
||||
self._custom_tools[tool_name] = {
|
||||
'name': tool_name,
|
||||
'description': tool_info['description'],
|
||||
'parameters': tool_info['input_schema'],
|
||||
'server': server_name,
|
||||
'original_name': tool_name_from_server,
|
||||
'is_custom': True,
|
||||
'custom_type': custom_type,
|
||||
'custom_config': server_config
|
||||
}
|
||||
tools_registered += 1
|
||||
logger.debug(f"Registered custom tool: {tool_name}")
|
||||
|
||||
logger.info(f"Successfully initialized custom MCP {server_name} with {tools_registered} tools")
|
||||
else:
|
||||
logger.error(f"Failed to connect to custom MCP {server_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Custom MCP {server_name}: Connection failed - {str(e)}")
|
||||
continue
|
||||
|
||||
else:
|
||||
logger.error(f"Custom MCP {server_name}: Unsupported type '{custom_type}', supported types are 'sse', 'http' and 'json'")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize custom MCP {config.get('name', 'Unknown')}: {e}")
|
||||
continue
|
||||
|
||||
async def initialize_and_register_tools(self, tool_registry=None):
|
||||
"""Initialize MCP tools and optionally update the tool registry.
|
||||
|
@ -80,6 +338,7 @@ class MCPToolWrapper(Tool):
|
|||
async def _create_dynamic_tools(self):
|
||||
"""Create dynamic tool methods for each available MCP tool."""
|
||||
try:
|
||||
# Get standard MCP tools
|
||||
available_tools = self.mcp_manager.get_all_tools_openapi()
|
||||
|
||||
for tool_info in available_tools:
|
||||
|
@ -87,6 +346,16 @@ class MCPToolWrapper(Tool):
|
|||
if tool_name:
|
||||
# Create a dynamic method for this tool with proper OpenAI schema
|
||||
self._create_dynamic_method(tool_name, tool_info)
|
||||
|
||||
# Get custom MCP tools
|
||||
for tool_name, tool_info in self._custom_tools.items():
|
||||
# Convert custom tool info to the expected format
|
||||
openapi_tool_info = {
|
||||
"name": tool_name,
|
||||
"description": tool_info['description'],
|
||||
"parameters": tool_info['parameters']
|
||||
}
|
||||
self._create_dynamic_method(tool_name, openapi_tool_info)
|
||||
|
||||
logger.info(f"Created {len(self._dynamic_tools)} dynamic MCP tool methods")
|
||||
|
||||
|
@ -200,121 +469,175 @@ class MCPToolWrapper(Tool):
|
|||
return self.mcp_manager.get_all_tools_openapi()
|
||||
|
||||
async def _execute_mcp_tool(self, tool_name: str, arguments: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute an MCP tool call (internal implementation).
|
||||
|
||||
Args:
|
||||
tool_name: The MCP tool name (e.g., "mcp_exa_web_search_exa")
|
||||
arguments: The arguments to pass to the tool
|
||||
|
||||
Returns:
|
||||
ToolResult with the tool execution result
|
||||
"""
|
||||
"""Execute an MCP tool call."""
|
||||
await self._ensure_initialized()
|
||||
logger.info(f"Executing MCP tool {tool_name} with arguments {arguments}")
|
||||
try:
|
||||
# Ensure MCP connections are initialized
|
||||
await self._ensure_initialized()
|
||||
|
||||
logger.info(f"Executing MCP tool {tool_name} with args: {arguments}")
|
||||
|
||||
# Parse arguments if they're provided as a JSON string
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
return self.fail_response(f"Invalid JSON in arguments: {str(e)}")
|
||||
|
||||
# Execute the tool through MCP manager
|
||||
result = await self.mcp_manager.execute_tool(tool_name, arguments)
|
||||
|
||||
# Parse tool name to extract server and tool info for metadata
|
||||
parts = tool_name.split("_", 2)
|
||||
server_name = parts[1] if len(parts) > 1 else "unknown"
|
||||
original_tool_name = parts[2] if len(parts) > 2 else tool_name
|
||||
|
||||
# Check if it's an error
|
||||
if result.get("isError", False):
|
||||
error_result = {
|
||||
"mcp_metadata": {
|
||||
"server_name": server_name,
|
||||
"tool_name": original_tool_name,
|
||||
"full_tool_name": tool_name,
|
||||
"arguments_count": len(arguments) if isinstance(arguments, dict) else 0,
|
||||
"is_mcp_tool": True
|
||||
},
|
||||
"content": result.get("content", ""),
|
||||
"isError": True,
|
||||
"raw_result": result
|
||||
}
|
||||
return self.fail_response(json.dumps(error_result, indent=2))
|
||||
|
||||
# Format the result in an LLM-friendly way with content first
|
||||
actual_content = result.get("content", "")
|
||||
|
||||
# Create a clear, LLM-friendly response that puts the content first
|
||||
llm_friendly_result = f"""MCP Tool Result from {server_name.upper()}:
|
||||
|
||||
{actual_content}
|
||||
|
||||
---
|
||||
Tool Metadata: {json.dumps({
|
||||
"server": server_name,
|
||||
"tool": original_tool_name,
|
||||
"full_tool_name": tool_name,
|
||||
"arguments_used": arguments,
|
||||
"is_mcp_tool": True
|
||||
}, indent=2)}"""
|
||||
# Check if it's a custom MCP tool first
|
||||
if tool_name in self._custom_tools:
|
||||
tool_info = self._custom_tools[tool_name]
|
||||
return await self._execute_custom_mcp_tool(tool_name, arguments, tool_info)
|
||||
else:
|
||||
# Use standard MCP manager for Smithery servers
|
||||
result = await self.mcp_manager.execute_tool(tool_name, arguments)
|
||||
|
||||
# Return successful result with LLM-friendly formatting
|
||||
return self.success_response(llm_friendly_result)
|
||||
|
||||
except ValueError as e:
|
||||
# Handle specific MCP errors (like invalid tool name format)
|
||||
error_msg = str(e)
|
||||
logger.error(f"ValueError executing MCP tool {tool_name}: {error_msg}")
|
||||
|
||||
# Parse tool name for metadata even in error case
|
||||
parts = tool_name.split("_", 2) if "_" in tool_name else ["", "unknown", "unknown"]
|
||||
server_name = parts[1] if len(parts) > 1 else "unknown"
|
||||
original_tool_name = parts[2] if len(parts) > 2 else "unknown"
|
||||
|
||||
error_result = {
|
||||
"mcp_metadata": {
|
||||
"server_name": server_name,
|
||||
"tool_name": original_tool_name,
|
||||
"full_tool_name": tool_name,
|
||||
"arguments_count": len(arguments) if isinstance(arguments, dict) else 0,
|
||||
"is_mcp_tool": True
|
||||
},
|
||||
"content": error_msg,
|
||||
"isError": True,
|
||||
"error_type": "ValueError"
|
||||
}
|
||||
|
||||
return self.fail_response(json.dumps(error_result, indent=2))
|
||||
|
||||
if isinstance(result, dict):
|
||||
if result.get('isError', False):
|
||||
return self.fail_response(result.get('content', 'Tool execution failed'))
|
||||
else:
|
||||
return self.success_response(result.get('content', result))
|
||||
else:
|
||||
return self.success_response(result)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing MCP tool {tool_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.error(f"Error executing MCP tool {tool_name}: {str(e)}")
|
||||
return self.fail_response(f"Error executing tool: {str(e)}")
|
||||
|
||||
async def _execute_custom_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], tool_info: Dict[str, Any]) -> ToolResult:
|
||||
"""Execute a custom MCP tool call."""
|
||||
try:
|
||||
custom_type = tool_info['custom_type']
|
||||
custom_config = tool_info['custom_config']
|
||||
original_tool_name = tool_info['original_name']
|
||||
|
||||
# Parse tool name for metadata even in error case
|
||||
parts = tool_name.split("_", 2) if "_" in tool_name else ["", "unknown", "unknown"]
|
||||
server_name = parts[1] if len(parts) > 1 else "unknown"
|
||||
original_tool_name = parts[2] if len(parts) > 2 else "unknown"
|
||||
if custom_type == 'sse':
|
||||
# Execute SSE-based custom MCP using the same pattern as _connect_sse_server
|
||||
url = custom_config['url']
|
||||
headers = custom_config.get('headers', {})
|
||||
|
||||
async with asyncio.timeout(30): # 30 second timeout for tool execution
|
||||
try:
|
||||
# Try with headers first (same pattern as _connect_sse_server)
|
||||
async with sse_client(url, headers=headers) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
result = await session.call_tool(original_tool_name, arguments)
|
||||
|
||||
# Handle the result properly
|
||||
if hasattr(result, 'content'):
|
||||
content = result.content
|
||||
if isinstance(content, list):
|
||||
# Extract text from content list
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if hasattr(item, 'text'):
|
||||
text_parts.append(item.text)
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content_str = "\n".join(text_parts)
|
||||
elif hasattr(content, 'text'):
|
||||
content_str = content.text
|
||||
else:
|
||||
content_str = str(content)
|
||||
|
||||
return self.success_response(content_str)
|
||||
else:
|
||||
return self.success_response(str(result))
|
||||
|
||||
except TypeError as e:
|
||||
if "unexpected keyword argument" in str(e):
|
||||
# Fallback: try without headers (exact pattern from _connect_sse_server)
|
||||
async with sse_client(url) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
result = await session.call_tool(original_tool_name, arguments)
|
||||
|
||||
# Handle the result properly
|
||||
if hasattr(result, 'content'):
|
||||
content = result.content
|
||||
if isinstance(content, list):
|
||||
# Extract text from content list
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if hasattr(item, 'text'):
|
||||
text_parts.append(item.text)
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content_str = "\n".join(text_parts)
|
||||
elif hasattr(content, 'text'):
|
||||
content_str = content.text
|
||||
else:
|
||||
content_str = str(content)
|
||||
|
||||
return self.success_response(content_str)
|
||||
else:
|
||||
return self.success_response(str(result))
|
||||
else:
|
||||
raise
|
||||
|
||||
error_result = {
|
||||
"mcp_metadata": {
|
||||
"server_name": server_name,
|
||||
"tool_name": original_tool_name,
|
||||
"full_tool_name": tool_name,
|
||||
"arguments_count": len(arguments) if isinstance(arguments, dict) else 0,
|
||||
"is_mcp_tool": True
|
||||
},
|
||||
"content": error_msg,
|
||||
"isError": True,
|
||||
"error_type": "Exception"
|
||||
}
|
||||
|
||||
return self.fail_response(json.dumps(error_result, indent=2))
|
||||
elif custom_type == 'http':
|
||||
# Execute HTTP-based custom MCP
|
||||
url = custom_config['url']
|
||||
|
||||
async with asyncio.timeout(30): # 30 second timeout for tool execution
|
||||
async with streamablehttp_client(url) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
result = await session.call_tool(original_tool_name, arguments)
|
||||
|
||||
# Handle the result properly
|
||||
if hasattr(result, 'content'):
|
||||
content = result.content
|
||||
if isinstance(content, list):
|
||||
# Extract text from content list
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if hasattr(item, 'text'):
|
||||
text_parts.append(item.text)
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content_str = "\n".join(text_parts)
|
||||
elif hasattr(content, 'text'):
|
||||
content_str = content.text
|
||||
else:
|
||||
content_str = str(content)
|
||||
|
||||
return self.success_response(content_str)
|
||||
else:
|
||||
return self.success_response(str(result))
|
||||
|
||||
elif custom_type == 'json':
|
||||
# Execute stdio-based custom MCP using the same pattern as _connect_stdio_server
|
||||
server_params = StdioServerParameters(
|
||||
command=custom_config["command"],
|
||||
args=custom_config.get("args", []),
|
||||
env=custom_config.get("env", {})
|
||||
)
|
||||
|
||||
async with asyncio.timeout(30): # 30 second timeout for tool execution
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
result = await session.call_tool(original_tool_name, arguments)
|
||||
|
||||
# Handle the result properly
|
||||
if hasattr(result, 'content'):
|
||||
content = result.content
|
||||
if isinstance(content, list):
|
||||
# Extract text from content list
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if hasattr(item, 'text'):
|
||||
text_parts.append(item.text)
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content_str = "\n".join(text_parts)
|
||||
elif hasattr(content, 'text'):
|
||||
content_str = content.text
|
||||
else:
|
||||
content_str = str(content)
|
||||
|
||||
return self.success_response(content_str)
|
||||
else:
|
||||
return self.success_response(str(result))
|
||||
else:
|
||||
return self.fail_response(f"Unsupported custom MCP type: {custom_type}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return self.fail_response(f"Tool execution timeout for {tool_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing custom MCP tool {tool_name}: {str(e)}")
|
||||
return self.fail_response(f"Error executing custom tool: {str(e)}")
|
||||
|
||||
# Keep the original call_mcp_tool method as a fallback
|
||||
@openapi_schema({
|
||||
|
|
|
@ -1712,36 +1712,35 @@ class ResponseProcessor:
|
|||
}
|
||||
}
|
||||
|
||||
structured_result_v2 = {
|
||||
"tool_execution": {
|
||||
"function_name": function_name,
|
||||
"xml_tag_name": xml_tag_name,
|
||||
"tool_call_id": tool_call_id,
|
||||
"arguments": arguments,
|
||||
"result": {
|
||||
"success": result.success if hasattr(result, 'success') else True,
|
||||
"output": output, # Now properly structured for frontend
|
||||
"error": getattr(result, 'error', None) if hasattr(result, 'error') else None
|
||||
},
|
||||
"execution_details": {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"parsing_details": parsing_details
|
||||
}
|
||||
}
|
||||
}
|
||||
# STRUCTURED_OUTPUT_TOOLS = {
|
||||
# "str_replace",
|
||||
# "get_data_provider_endpoints",
|
||||
# }
|
||||
|
||||
# summary_output = result.output if hasattr(result, 'output') else str(result)
|
||||
|
||||
# For backwards compatibility with LLM, also include a human-readable summary
|
||||
# Use the original string output for the summary to avoid complex object representation
|
||||
# if xml_tag_name:
|
||||
# status = "completed successfully" if structured_result_v1["tool_execution"]["result"]["success"] else "failed"
|
||||
# summary = f"Tool '{xml_tag_name}' {status}. Output: {summary_output}"
|
||||
# else:
|
||||
# status = "completed successfully" if structured_result_v1["tool_execution"]["result"]["success"] else "failed"
|
||||
# summary = f"Function '{function_name}' {status}. Output: {summary_output}"
|
||||
|
||||
# if self.is_agent_builder:
|
||||
# return summary
|
||||
# if function_name in STRUCTURED_OUTPUT_TOOLS:
|
||||
# return structured_result_v1
|
||||
# else:
|
||||
# return summary
|
||||
|
||||
summary_output = result.output if hasattr(result, 'output') else str(result)
|
||||
success_status = structured_result_v1["tool_execution"]["result"]["success"]
|
||||
|
||||
# Create a more comprehensive summary for the LLM
|
||||
if xml_tag_name:
|
||||
# For XML tools, create a readable summary
|
||||
status = "completed successfully" if structured_result_v1["tool_execution"]["result"]["success"] else "failed"
|
||||
summary = f"Tool '{xml_tag_name}' {status}. Output: {summary_output}"
|
||||
else:
|
||||
# For native tools, create a readable summary
|
||||
status = "completed successfully" if structured_result_v1["tool_execution"]["result"]["success"] else "failed"
|
||||
summary = f"Function '{function_name}' {status}. Output: {summary_output}"
|
||||
|
||||
|
@ -1750,8 +1749,9 @@ class ResponseProcessor:
|
|||
elif function_name == "get_data_provider_endpoints":
|
||||
logger.info(f"Returning sumnary for data provider call: {summary}")
|
||||
return summary
|
||||
|
||||
else:
|
||||
return structured_result_v1
|
||||
return json.dumps(structured_result_v1)
|
||||
|
||||
def _format_xml_tool_result(self, tool_call: Dict[str, Any], result: ToolResult) -> str:
|
||||
"""Format a tool result wrapped in a <tool_result> tag.
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from fastapi import FastAPI, Request
|
||||
from fastapi import FastAPI, Request, HTTPException, Response, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
import sentry
|
||||
from contextlib import asynccontextmanager
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
|
@ -10,19 +10,24 @@ from dotenv import load_dotenv
|
|||
from utils.config import config, EnvMode
|
||||
import asyncio
|
||||
from utils.logger import logger
|
||||
import uuid
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
# Import the agent API module
|
||||
from agent import api as agent_api
|
||||
from sandbox import api as sandbox_api
|
||||
from services import billing as billing_api
|
||||
from services import transcription as transcription_api
|
||||
from services.mcp_custom import discover_custom_tools
|
||||
import sys
|
||||
|
||||
# Load environment variables (these will be available through config)
|
||||
load_dotenv()
|
||||
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
||||
|
||||
# Initialize managers
|
||||
db = DBConnection()
|
||||
instance_id = "single"
|
||||
|
@ -33,20 +38,15 @@ MAX_CONCURRENT_IPS = 25
|
|||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup
|
||||
logger.info(f"Starting up FastAPI application with instance ID: {instance_id} in {config.ENV_MODE.value} mode")
|
||||
|
||||
try:
|
||||
# Initialize database
|
||||
await db.initialize()
|
||||
|
||||
# Initialize the agent API with shared resources
|
||||
agent_api.initialize(
|
||||
db,
|
||||
instance_id
|
||||
)
|
||||
|
||||
# Initialize the sandbox API with shared resources
|
||||
sandbox_api.initialize(db)
|
||||
|
||||
# Initialize Redis connection
|
||||
|
@ -124,19 +124,16 @@ app.add_middleware(
|
|||
allow_headers=["Content-Type", "Authorization"],
|
||||
)
|
||||
|
||||
# Include the agent router with a prefix
|
||||
app.include_router(agent_api.router, prefix="/api")
|
||||
|
||||
# Include the sandbox router with a prefix
|
||||
app.include_router(sandbox_api.router, prefix="/api")
|
||||
|
||||
# Include the billing router with a prefix
|
||||
app.include_router(billing_api.router, prefix="/api")
|
||||
|
||||
# Import and include the MCP router
|
||||
from mcp_local import api as mcp_api
|
||||
|
||||
app.include_router(mcp_api.router, prefix="/api")
|
||||
# Include the transcription router with a prefix
|
||||
|
||||
app.include_router(transcription_api.router, prefix="/api")
|
||||
|
||||
@app.get("/api/health")
|
||||
|
@ -149,10 +146,28 @@ async def health_check():
|
|||
"instance_id": instance_id
|
||||
}
|
||||
|
||||
class CustomMCPDiscoverRequest(BaseModel):
|
||||
type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
@app.post("/api/mcp/discover-custom-tools")
|
||||
async def discover_custom_mcp_tools(request: CustomMCPDiscoverRequest):
|
||||
try:
|
||||
return await discover_custom_tools(request.type, request.config)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error discovering custom MCP tools: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
workers = 2
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
||||
|
||||
workers = 1
|
||||
|
||||
logger.info(f"Starting server on 0.0.0.0:8000 with {workers} workers")
|
||||
uvicorn.run(
|
||||
|
@ -160,5 +175,5 @@ if __name__ == "__main__":
|
|||
host="0.0.0.0",
|
||||
port=8000,
|
||||
workers=workers,
|
||||
# reload=True
|
||||
loop="asyncio"
|
||||
)
|
|
@ -36,4 +36,6 @@ langfuse>=2.60.5
|
|||
httpx>=0.24.0
|
||||
Pillow>=10.0.0
|
||||
sentry-sdk[fastapi]>=2.29.1
|
||||
mcp>=1.0.0
|
||||
mcp>=1.0.0
|
||||
mcp_use>=1.0.0
|
||||
aiohttp>=3.9.0
|
|
@ -0,0 +1,129 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
import asyncio
|
||||
import subprocess
|
||||
from typing import Dict, Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from fastapi import HTTPException # type: ignore
|
||||
from utils.logger import logger
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client # type: ignore
|
||||
from mcp.client.streamable_http import streamablehttp_client # type: ignore
|
||||
|
||||
async def connect_streamable_http_server(url):
|
||||
async with streamablehttp_client(url) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
tool_result = await session.list_tools()
|
||||
print(f"Connected via HTTP ({len(tool_result.tools)} tools)")
|
||||
|
||||
tools_info = []
|
||||
for tool in tool_result.tools:
|
||||
tool_info = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.inputSchema
|
||||
}
|
||||
tools_info.append(tool_info)
|
||||
|
||||
return tools_info
|
||||
|
||||
async def discover_custom_tools(request_type: str, config: Dict[str, Any]):
|
||||
logger.info(f"Received custom MCP discovery request: type={request_type}")
|
||||
logger.debug(f"Request config: {config}")
|
||||
|
||||
tools = []
|
||||
server_name = None
|
||||
|
||||
if request_type == 'http':
|
||||
if 'url' not in config:
|
||||
raise HTTPException(status_code=400, detail="HTTP configuration must include 'url' field")
|
||||
url = config['url']
|
||||
|
||||
try:
|
||||
async with asyncio.timeout(15):
|
||||
tools_info = await connect_streamable_http_server(url)
|
||||
for tool_info in tools_info:
|
||||
tools.append({
|
||||
"name": tool_info["name"],
|
||||
"description": tool_info["description"],
|
||||
"inputSchema": tool_info["inputSchema"]
|
||||
})
|
||||
except asyncio.TimeoutError:
|
||||
raise HTTPException(status_code=408, detail="Connection timeout - server took too long to respond")
|
||||
except Exception as e:
|
||||
logger.error(f"Error connecting to HTTP MCP server: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Failed to connect to MCP server: {str(e)}")
|
||||
|
||||
elif request_type == 'sse':
|
||||
if 'url' not in config:
|
||||
raise HTTPException(status_code=400, detail="SSE configuration must include 'url' field")
|
||||
|
||||
url = config['url']
|
||||
headers = config.get('headers', {})
|
||||
|
||||
try:
|
||||
async with asyncio.timeout(15):
|
||||
try:
|
||||
async with sse_client(url, headers=headers) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
tools_info = []
|
||||
for tool in tools_result.tools:
|
||||
tool_info = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.inputSchema
|
||||
}
|
||||
tools_info.append(tool_info)
|
||||
|
||||
for tool_info in tools_info:
|
||||
tools.append({
|
||||
"name": tool_info["name"],
|
||||
"description": tool_info["description"],
|
||||
"inputSchema": tool_info["input_schema"]
|
||||
})
|
||||
except TypeError as e:
|
||||
if "unexpected keyword argument" in str(e):
|
||||
async with sse_client(url) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
tools_info = []
|
||||
for tool in tools_result.tools:
|
||||
tool_info = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.inputSchema
|
||||
}
|
||||
tools_info.append(tool_info)
|
||||
|
||||
for tool_info in tools_info:
|
||||
tools.append({
|
||||
"name": tool_info["name"],
|
||||
"description": tool_info["description"],
|
||||
"inputSchema": tool_info["input_schema"]
|
||||
})
|
||||
else:
|
||||
raise
|
||||
except asyncio.TimeoutError:
|
||||
raise HTTPException(status_code=408, detail="Connection timeout - server took too long to respond")
|
||||
except Exception as e:
|
||||
logger.error(f"Error connecting to SSE MCP server: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Failed to connect to MCP server: {str(e)}")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid server type. Must be 'http' or 'sse'")
|
||||
|
||||
response_data = {"tools": tools, "count": len(tools)}
|
||||
|
||||
if server_name:
|
||||
response_data["serverName"] = server_name
|
||||
|
||||
logger.info(f"Returning {len(tools)} tools for server {server_name}")
|
||||
return response_data
|
|
@ -0,0 +1,299 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
import asyncio
|
||||
import subprocess
|
||||
from typing import Dict, Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from fastapi import HTTPException # type: ignore
|
||||
from utils.logger import logger
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client # type: ignore
|
||||
from mcp.client.streamable_http import streamablehttp_client # type: ignore
|
||||
|
||||
windows_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
# def run_mcp_stdio_sync(command, args, env_vars, timeout=30):
|
||||
# try:
|
||||
# env = os.environ.copy()
|
||||
# env.update(env_vars)
|
||||
|
||||
# full_command = [command] + args
|
||||
|
||||
# process = subprocess.Popen(
|
||||
# full_command,
|
||||
# stdin=subprocess.PIPE,
|
||||
# stdout=subprocess.PIPE,
|
||||
# stderr=subprocess.PIPE,
|
||||
# env=env,
|
||||
# text=True,
|
||||
# bufsize=0,
|
||||
# creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if sys.platform == "win32" else 0
|
||||
# )
|
||||
|
||||
# init_request = {
|
||||
# "jsonrpc": "2.0",
|
||||
# "id": 1,
|
||||
# "method": "initialize",
|
||||
# "params": {
|
||||
# "protocolVersion": "2024-11-05",
|
||||
# "capabilities": {},
|
||||
# "clientInfo": {"name": "mcp-client", "version": "1.0.0"}
|
||||
# }
|
||||
# }
|
||||
|
||||
# process.stdin.write(json.dumps(init_request) + "\n")
|
||||
# process.stdin.flush()
|
||||
|
||||
# init_response_line = process.stdout.readline().strip()
|
||||
# if not init_response_line:
|
||||
# raise Exception("No response from MCP server during initialization")
|
||||
|
||||
# init_response = json.loads(init_response_line)
|
||||
|
||||
# init_notification = {
|
||||
# "jsonrpc": "2.0",
|
||||
# "method": "notifications/initialized"
|
||||
# }
|
||||
# process.stdin.write(json.dumps(init_notification) + "\n")
|
||||
# process.stdin.flush()
|
||||
|
||||
# tools_request = {
|
||||
# "jsonrpc": "2.0",
|
||||
# "id": 2,
|
||||
# "method": "tools/list",
|
||||
# "params": {}
|
||||
# }
|
||||
|
||||
# process.stdin.write(json.dumps(tools_request) + "\n")
|
||||
# process.stdin.flush()
|
||||
|
||||
# tools_response_line = process.stdout.readline().strip()
|
||||
# if not tools_response_line:
|
||||
# raise Exception("No response from MCP server for tools list")
|
||||
|
||||
# tools_response = json.loads(tools_response_line)
|
||||
|
||||
# tools_info = []
|
||||
# if "result" in tools_response and "tools" in tools_response["result"]:
|
||||
# for tool in tools_response["result"]["tools"]:
|
||||
# tool_info = {
|
||||
# "name": tool["name"],
|
||||
# "description": tool.get("description", ""),
|
||||
# "input_schema": tool.get("inputSchema", {})
|
||||
# }
|
||||
# tools_info.append(tool_info)
|
||||
|
||||
# return {
|
||||
# "status": "connected",
|
||||
# "transport": "stdio",
|
||||
# "tools": tools_info
|
||||
# }
|
||||
|
||||
# except subprocess.TimeoutExpired:
|
||||
# return {
|
||||
# "status": "error",
|
||||
# "error": f"Process timeout after {timeout} seconds",
|
||||
# "tools": []
|
||||
# }
|
||||
# except json.JSONDecodeError as e:
|
||||
# return {
|
||||
# "status": "error",
|
||||
# "error": f"Invalid JSON response: {str(e)}",
|
||||
# "tools": []
|
||||
# }
|
||||
# except Exception as e:
|
||||
# return {
|
||||
# "status": "error",
|
||||
# "error": str(e),
|
||||
# "tools": []
|
||||
# }
|
||||
# finally:
|
||||
# try:
|
||||
# if 'process' in locals():
|
||||
# process.terminate()
|
||||
# process.wait(timeout=5)
|
||||
# except:
|
||||
# pass
|
||||
|
||||
|
||||
# async def connect_stdio_server_windows(server_name, server_config, all_tools, timeout):
|
||||
# """Windows-compatible stdio connection using subprocess"""
|
||||
|
||||
# logger.info(f"Connecting to {server_name} using Windows subprocess method")
|
||||
|
||||
# command = server_config["command"]
|
||||
# args = server_config.get("args", [])
|
||||
# env_vars = server_config.get("env", {})
|
||||
|
||||
# loop = asyncio.get_event_loop()
|
||||
# result = await loop.run_in_executor(
|
||||
# windows_executor,
|
||||
# run_mcp_stdio_sync,
|
||||
# command,
|
||||
# args,
|
||||
# env_vars,
|
||||
# timeout
|
||||
# )
|
||||
|
||||
# all_tools[server_name] = result
|
||||
|
||||
# if result["status"] == "connected":
|
||||
# logger.info(f" {server_name}: Connected via Windows subprocess ({len(result['tools'])} tools)")
|
||||
# else:
|
||||
# logger.error(f" {server_name}: Error - {result['error']}")
|
||||
|
||||
|
||||
# async def list_mcp_tools_mixed_windows(config, timeout=15):
|
||||
# all_tools = {}
|
||||
|
||||
# if "mcpServers" not in config:
|
||||
# return all_tools
|
||||
|
||||
# mcp_servers = config["mcpServers"]
|
||||
|
||||
# for server_name, server_config in mcp_servers.items():
|
||||
# logger.info(f"Connecting to MCP server: {server_name}")
|
||||
# if server_config.get("disabled", False):
|
||||
# all_tools[server_name] = {"status": "disabled", "tools": []}
|
||||
# logger.info(f" {server_name}: Disabled")
|
||||
# continue
|
||||
|
||||
# try:
|
||||
# await connect_stdio_server_windows(server_name, server_config, all_tools, timeout)
|
||||
|
||||
# except asyncio.TimeoutError:
|
||||
# all_tools[server_name] = {
|
||||
# "status": "error",
|
||||
# "error": f"Connection timeout after {timeout} seconds",
|
||||
# "tools": []
|
||||
# }
|
||||
# logger.error(f" {server_name}: Timeout after {timeout} seconds")
|
||||
# except Exception as e:
|
||||
# error_msg = str(e)
|
||||
# all_tools[server_name] = {
|
||||
# "status": "error",
|
||||
# "error": error_msg,
|
||||
# "tools": []
|
||||
# }
|
||||
# logger.error(f" {server_name}: Error - {error_msg}")
|
||||
# import traceback
|
||||
# logger.debug(f"Full traceback for {server_name}: {traceback.format_exc()}")
|
||||
|
||||
# return all_tools
|
||||
|
||||
|
||||
async def discover_custom_tools(request_type: str, config: Dict[str, Any]):
|
||||
logger.info(f"Received custom MCP discovery request: type={request_type}")
|
||||
logger.debug(f"Request config: {config}")
|
||||
|
||||
tools = []
|
||||
server_name = None
|
||||
|
||||
# if request_type == 'json':
|
||||
# try:
|
||||
# all_tools = await list_mcp_tools_mixed_windows(config, timeout=30)
|
||||
# if "mcpServers" in config and config["mcpServers"]:
|
||||
# server_name = list(config["mcpServers"].keys())[0]
|
||||
|
||||
# if server_name in all_tools:
|
||||
# server_info = all_tools[server_name]
|
||||
# if server_info["status"] == "connected":
|
||||
# tools = server_info["tools"]
|
||||
# logger.info(f"Found {len(tools)} tools for server {server_name}")
|
||||
# else:
|
||||
# error_msg = server_info.get("error", "Unknown error")
|
||||
# logger.error(f"Server {server_name} failed: {error_msg}")
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=f"Failed to connect to MCP server '{server_name}': {error_msg}"
|
||||
# )
|
||||
# else:
|
||||
# logger.error(f"Server {server_name} not found in results")
|
||||
# raise HTTPException(status_code=400, detail=f"Server '{server_name}' not found in results")
|
||||
# else:
|
||||
# logger.error("No MCP servers configured")
|
||||
# raise HTTPException(status_code=400, detail="No MCP servers configured")
|
||||
|
||||
# except HTTPException:
|
||||
# raise
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error connecting to stdio MCP server: {e}")
|
||||
# import traceback
|
||||
# logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||
# raise HTTPException(status_code=400, detail=f"Failed to connect to MCP server: {str(e)}")
|
||||
|
||||
# if request_type == 'http':
|
||||
# if 'url' not in config:
|
||||
# raise HTTPException(status_code=400, detail="HTTP configuration must include 'url' field")
|
||||
# url = config['url']
|
||||
# await connect_streamable_http_server(url)
|
||||
# tools = await connect_streamable_http_server(url)
|
||||
|
||||
# elif request_type == 'sse':
|
||||
# if 'url' not in config:
|
||||
# raise HTTPException(status_code=400, detail="SSE configuration must include 'url' field")
|
||||
|
||||
# url = config['url']
|
||||
# headers = config.get('headers', {})
|
||||
|
||||
# try:
|
||||
# async with asyncio.timeout(15):
|
||||
# try:
|
||||
# async with sse_client(url, headers=headers) as (read, write):
|
||||
# async with ClientSession(read, write) as session:
|
||||
# await session.initialize()
|
||||
# tools_result = await session.list_tools()
|
||||
# tools_info = []
|
||||
# for tool in tools_result.tools:
|
||||
# tool_info = {
|
||||
# "name": tool.name,
|
||||
# "description": tool.description,
|
||||
# "input_schema": tool.inputSchema
|
||||
# }
|
||||
# tools_info.append(tool_info)
|
||||
|
||||
# for tool_info in tools_info:
|
||||
# tools.append({
|
||||
# "name": tool_info["name"],
|
||||
# "description": tool_info["description"],
|
||||
# "inputSchema": tool_info["input_schema"]
|
||||
# })
|
||||
# except TypeError as e:
|
||||
# if "unexpected keyword argument" in str(e):
|
||||
# async with sse_client(url) as (read, write):
|
||||
# async with ClientSession(read, write) as session:
|
||||
# await session.initialize()
|
||||
# tools_result = await session.list_tools()
|
||||
# tools_info = []
|
||||
# for tool in tools_result.tools:
|
||||
# tool_info = {
|
||||
# "name": tool.name,
|
||||
# "description": tool.description,
|
||||
# "input_schema": tool.inputSchema
|
||||
# }
|
||||
# tools_info.append(tool_info)
|
||||
|
||||
# for tool_info in tools_info:
|
||||
# tools.append({
|
||||
# "name": tool_info["name"],
|
||||
# "description": tool_info["description"],
|
||||
# "inputSchema": tool_info["input_schema"]
|
||||
# })
|
||||
# else:
|
||||
# raise
|
||||
# except asyncio.TimeoutError:
|
||||
# raise HTTPException(status_code=408, detail="Connection timeout - server took too long to respond")
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error connecting to SSE MCP server: {e}")
|
||||
# raise HTTPException(status_code=400, detail=f"Failed to connect to MCP server: {str(e)}")
|
||||
# else:
|
||||
# raise HTTPException(status_code=400, detail="Invalid server type. Must be 'json' or 'sse'")
|
||||
|
||||
# response_data = {"tools": tools, "count": len(tools)}
|
||||
|
||||
# if server_name:
|
||||
# response_data["serverName"] = server_name
|
||||
|
||||
# logger.info(f"Returning {len(tools)} tools for server {server_name}")
|
||||
# return response_data
|
|
@ -0,0 +1,9 @@
|
|||
BEGIN;
|
||||
|
||||
ALTER TABLE agents ADD COLUMN IF NOT EXISTS custom_mcps JSONB DEFAULT '[]'::jsonb;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_agents_custom_mcps ON agents USING GIN (custom_mcps);
|
||||
|
||||
COMMENT ON COLUMN agents.custom_mcps IS 'Stores custom MCP server configurations added by users (JSON or SSE endpoints)';
|
||||
|
||||
COMMIT;
|
|
@ -0,0 +1,106 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for custom MCP functionality
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from agent.tools.mcp_tool_wrapper import MCPToolWrapper
|
||||
|
||||
async def test_custom_mcp():
|
||||
"""Test custom MCP configuration and tool discovery"""
|
||||
|
||||
# Example custom MCP configuration (Playwright)
|
||||
custom_mcp_config = {
|
||||
'name': 'Playwright Test',
|
||||
'qualifiedName': 'custom_json_playwright_test',
|
||||
'config': {
|
||||
'command': 'npx',
|
||||
'args': ['@modelcontextprotocol/server-playwright'],
|
||||
'env': {'DISPLAY': ':1'}
|
||||
},
|
||||
'enabledTools': ['screenshot', 'click', 'type'],
|
||||
'isCustom': True,
|
||||
'customType': 'json'
|
||||
}
|
||||
|
||||
# Example SSE custom MCP configuration
|
||||
sse_custom_mcp_config = {
|
||||
'name': 'Mem0 Test',
|
||||
'qualifiedName': 'custom_sse_mem0_test',
|
||||
'config': {
|
||||
'url': 'https://mcp.composio.dev/partner/composio/mem0/sse?customerId=test',
|
||||
'headers': {}
|
||||
},
|
||||
'enabledTools': ['add_memory', 'search_memory'],
|
||||
'isCustom': True,
|
||||
'customType': 'sse'
|
||||
}
|
||||
|
||||
print("🧪 Testing Custom MCP Tool Wrapper")
|
||||
print("=" * 50)
|
||||
|
||||
# Test with just the JSON custom MCP
|
||||
try:
|
||||
print("\n1. Testing JSON Custom MCP (Playwright)...")
|
||||
wrapper = MCPToolWrapper(mcp_configs=[custom_mcp_config])
|
||||
|
||||
# Initialize the wrapper
|
||||
await wrapper._ensure_initialized()
|
||||
|
||||
# Get available tools
|
||||
tools = await wrapper.get_available_tools()
|
||||
print(f" ✅ Found {len(tools)} tools")
|
||||
|
||||
for tool in tools:
|
||||
print(f" - {tool.get('name', 'Unknown')}: {tool.get('description', 'No description')}")
|
||||
|
||||
# Get schemas
|
||||
schemas = wrapper.get_schemas()
|
||||
print(f" ✅ Generated {len(schemas)} tool schemas")
|
||||
|
||||
for method_name, schema_list in schemas.items():
|
||||
print(f" - Method: {method_name}")
|
||||
for schema in schema_list:
|
||||
if hasattr(schema, 'schema') and 'function' in schema.schema:
|
||||
func_name = schema.schema['function'].get('name', 'Unknown')
|
||||
func_desc = schema.schema['function'].get('description', 'No description')
|
||||
print(f" Function: {func_name} - {func_desc}")
|
||||
|
||||
await wrapper.cleanup()
|
||||
print(" ✅ JSON Custom MCP test completed")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ JSON Custom MCP test failed: {e}")
|
||||
|
||||
# Test with SSE custom MCP (this might fail if the endpoint is not accessible)
|
||||
try:
|
||||
print("\n2. Testing SSE Custom MCP (Mem0)...")
|
||||
wrapper2 = MCPToolWrapper(mcp_configs=[sse_custom_mcp_config])
|
||||
|
||||
# This might timeout or fail if the endpoint is not accessible
|
||||
await asyncio.wait_for(wrapper2._ensure_initialized(), timeout=10)
|
||||
|
||||
tools = await wrapper2.get_available_tools()
|
||||
print(f" ✅ Found {len(tools)} tools")
|
||||
|
||||
schemas = wrapper2.get_schemas()
|
||||
print(f" ✅ Generated {len(schemas)} tool schemas")
|
||||
|
||||
await wrapper2.cleanup()
|
||||
print(" ✅ SSE Custom MCP test completed")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print(" ⚠️ SSE Custom MCP test timed out (expected if endpoint not accessible)")
|
||||
except Exception as e:
|
||||
print(f" ⚠️ SSE Custom MCP test failed: {e} (expected if endpoint not accessible)")
|
||||
|
||||
print("\n🎉 Custom MCP testing completed!")
|
||||
print("\nTo use custom MCPs in your agent:")
|
||||
print("1. Add custom MCPs through the frontend dialog")
|
||||
print("2. Save the agent configuration")
|
||||
print("3. Start a new agent run - custom MCP tools will be available")
|
||||
print("4. The LLM can call custom MCP tools directly by their function names")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_custom_mcp())
|
|
@ -0,0 +1,258 @@
|
|||
import asyncio
|
||||
import warnings
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp import StdioServerParameters
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
warnings.filterwarnings("ignore", category=ResourceWarning)
|
||||
|
||||
|
||||
async def list_mcp_tools_mixed(config, timeout=15):
|
||||
all_tools = {}
|
||||
|
||||
if "mcpServers" not in config:
|
||||
return all_tools
|
||||
|
||||
mcp_servers = config["mcpServers"]
|
||||
|
||||
for server_name, server_config in mcp_servers.items():
|
||||
print(f"Connecting to {server_name}...")
|
||||
if server_config.get("disabled", False):
|
||||
all_tools[server_name] = {"status": "disabled", "tools": []}
|
||||
print(f" {server_name}: Disabled")
|
||||
continue
|
||||
|
||||
try:
|
||||
if "url" in server_config:
|
||||
url = server_config["url"]
|
||||
await connect_streamable_http_server(url)
|
||||
# if "/sse" in url or server_config.get("transport") == "sse":
|
||||
# await connect_sse_server(server_name, server_config, all_tools, timeout)
|
||||
else:
|
||||
await connect_stdio_server(server_name, server_config, all_tools, timeout)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
all_tools[server_name] = {
|
||||
"status": "error",
|
||||
"error": f"Connection timeout after {timeout} seconds",
|
||||
"tools": []
|
||||
}
|
||||
print(f" {server_name}: Timeout")
|
||||
except Exception as e:
|
||||
all_tools[server_name] = {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"tools": []
|
||||
}
|
||||
print(f" {server_name}: Error - {str(e)[:50]}...")
|
||||
|
||||
return all_tools
|
||||
|
||||
|
||||
def extract_tools_from_response(data):
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
elif isinstance(data, dict):
|
||||
for key in ["tools", "data", "result", "items", "response"]:
|
||||
if key in data:
|
||||
value = data[key]
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
elif isinstance(value, dict) and "tools" in value:
|
||||
return value["tools"]
|
||||
|
||||
if "result" in data and isinstance(data["result"], dict):
|
||||
if "tools" in data["result"]:
|
||||
return data["result"]["tools"]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
async def connect_streamable_http_server(url):
|
||||
async with streamablehttp_client(url) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
tool_result = await session.list_tools()
|
||||
print(f"Connected via SSE ({len(tool_result.tools)} tools)")
|
||||
return tool_result
|
||||
|
||||
async def connect_sse_server(server_name, server_config, all_tools, timeout):
|
||||
url = server_config["url"]
|
||||
headers = server_config.get("headers", {})
|
||||
|
||||
async with asyncio.timeout(timeout):
|
||||
try:
|
||||
async with sse_client(url, headers=headers) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
tools_info = []
|
||||
for tool in tools_result.tools:
|
||||
tool_info = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.inputSchema
|
||||
}
|
||||
tools_info.append(tool_info)
|
||||
|
||||
all_tools[server_name] = {
|
||||
"status": "connected",
|
||||
"transport": "sse",
|
||||
"url": url,
|
||||
"tools": tools_info
|
||||
}
|
||||
|
||||
print(f" {server_name}: Connected via SSE ({len(tools_info)} tools)")
|
||||
except TypeError as e:
|
||||
if "unexpected keyword argument" in str(e):
|
||||
async with sse_client(url) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
tools_info = []
|
||||
for tool in tools_result.tools:
|
||||
tool_info = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.inputSchema
|
||||
}
|
||||
tools_info.append(tool_info)
|
||||
|
||||
all_tools[server_name] = {
|
||||
"status": "connected",
|
||||
"transport": "sse",
|
||||
"url": url,
|
||||
"tools": tools_info
|
||||
}
|
||||
print(f" {server_name}: Connected via SSE ({len(tools_info)} tools)")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
async def connect_stdio_server(server_name, server_config, all_tools, timeout):
|
||||
server_params = StdioServerParameters(
|
||||
command=server_config["command"],
|
||||
args=server_config.get("args", []),
|
||||
env=server_config.get("env", {})
|
||||
)
|
||||
|
||||
async with asyncio.timeout(timeout):
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
tools_info = []
|
||||
for tool in tools_result.tools:
|
||||
tool_info = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.inputSchema
|
||||
}
|
||||
tools_info.append(tool_info)
|
||||
|
||||
all_tools[server_name] = {
|
||||
"status": "connected",
|
||||
"transport": "stdio",
|
||||
"tools": tools_info
|
||||
}
|
||||
|
||||
print(f" {server_name}: Connected via stdio ({len(tools_info)} tools)")
|
||||
|
||||
|
||||
def print_mcp_tools(all_tools):
|
||||
if not all_tools:
|
||||
print("No MCP servers configured.")
|
||||
return
|
||||
|
||||
total_tools = sum(len(server_info["tools"]) for server_info in all_tools.values())
|
||||
print(f"Found {len(all_tools)} MCP server(s) with {total_tools} total tools:")
|
||||
print("=" * 60)
|
||||
|
||||
for server_name, server_info in all_tools.items():
|
||||
status = server_info["status"]
|
||||
tools = server_info["tools"]
|
||||
transport = server_info.get("transport", "unknown")
|
||||
|
||||
print(f"\nServer: {server_name}")
|
||||
print(f"Status: {status.upper()}")
|
||||
print(f"Transport: {transport.upper()}")
|
||||
|
||||
if server_info.get("url"):
|
||||
print(f"URL: {server_info['url']}")
|
||||
|
||||
if status == "error":
|
||||
print(f"Error: {server_info['error']}")
|
||||
elif status == "disabled":
|
||||
print("Server is disabled in configuration")
|
||||
elif status == "connected":
|
||||
if tools:
|
||||
print(f"Available tools ({len(tools)}):")
|
||||
for tool in tools:
|
||||
print(f" • {tool['name']}")
|
||||
if tool['description']:
|
||||
print(f" Description: {tool['description']}")
|
||||
if tool.get('input_schema'):
|
||||
schema = tool['input_schema']
|
||||
if 'properties' in schema:
|
||||
params = list(schema['properties'].keys())
|
||||
print(f" Parameters: {', '.join(params)}")
|
||||
print()
|
||||
else:
|
||||
print("No tools available")
|
||||
|
||||
print("-" * 40)
|
||||
|
||||
|
||||
async def main():
|
||||
config = {
|
||||
"mcpServers": {
|
||||
"mem0": {
|
||||
"url": "https://mcp.composio.dev/composio/server/8f56a575-1a7d-422a-a383-0e9701af9d61/mcp?useComposioHelperActions=true",
|
||||
# "transport": "sse"
|
||||
},
|
||||
# "airbnb": {
|
||||
# "command": "npx",
|
||||
# "args": ["-y", "@openbnb/mcp-server-airbnb", "--ignore-robots-txt"]
|
||||
# },
|
||||
# "playwright": {
|
||||
# "command": "npx",
|
||||
# "args": ["@playwright/mcp@latest"],
|
||||
# "env": {"DISPLAY": ":1"}
|
||||
# },
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
print("Discovering MCP tools from mixed transports (stdio, SSE, HTTP)...")
|
||||
all_tools = await list_mcp_tools_mixed(config, timeout=20)
|
||||
print("\n" + "="*60)
|
||||
print_mcp_tools(all_tools)
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted by user")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
finally:
|
||||
print("Done.")
|
||||
|
||||
|
||||
def list_tools_sync(config):
|
||||
return asyncio.run(list_mcp_tools_mixed(config))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted by user")
|
||||
finally:
|
||||
import sys
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
|
@ -393,7 +393,7 @@ export const AgentBuilderChat = React.memo(function AgentBuilderChat({
|
|||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex-shrink-0 md:pb-4 md:px-12">
|
||||
<div className="flex-shrink-0 md:pb-4 md:px-12 px-4">
|
||||
<ChatInput
|
||||
ref={chatInputRef}
|
||||
onSubmit={threadId ? handleSubmitMessage : handleSubmitFirstMessage}
|
||||
|
|
|
@ -5,17 +5,60 @@ import { MCPConfiguration } from './mcp-configuration';
|
|||
import { MCPConfigurationNew } from './mcp/mcp-configuration-new';
|
||||
|
||||
interface AgentMCPConfigurationProps {
|
||||
mcps: Array<{ name: string; qualifiedName: string; config: any; enabledTools?: string[] }>;
|
||||
onMCPsChange: (mcps: Array<{ name: string; qualifiedName: string; config: any; enabledTools?: string[] }>) => void;
|
||||
mcps: Array<{ name: string; qualifiedName: string; config: any; enabledTools?: string[]; isCustom?: boolean; customType?: 'http' | 'sse' }>;
|
||||
customMcps?: Array<{ name: string; type: 'http' | 'sse'; config: any; enabledTools: string[] }>;
|
||||
onMCPsChange: (mcps: Array<{ name: string; qualifiedName: string; config: any; enabledTools?: string[]; isCustom?: boolean; customType?: 'http' | 'sse' }>) => void;
|
||||
onCustomMCPsChange?: (customMcps: Array<{ name: string; type: 'http' | 'sse'; config: any; enabledTools: string[] }>) => void;
|
||||
onBatchMCPChange?: (updates: { configured_mcps: any[]; custom_mcps: any[] }) => void;
|
||||
}
|
||||
|
||||
export const AgentMCPConfiguration = ({ mcps, onMCPsChange }: AgentMCPConfigurationProps) => {
|
||||
export const AgentMCPConfiguration = ({ mcps, customMcps = [], onMCPsChange, onCustomMCPsChange, onBatchMCPChange }: AgentMCPConfigurationProps) => {
|
||||
const allMcps = React.useMemo(() => {
|
||||
const combined = [...mcps];
|
||||
customMcps.forEach(customMcp => {
|
||||
combined.push({
|
||||
name: customMcp.name,
|
||||
qualifiedName: `custom_${customMcp.type}_${customMcp.name.replace(' ', '_').toLowerCase()}`,
|
||||
config: customMcp.config,
|
||||
enabledTools: customMcp.enabledTools,
|
||||
isCustom: true,
|
||||
customType: customMcp.type as 'http' | 'sse'
|
||||
});
|
||||
});
|
||||
|
||||
return combined;
|
||||
}, [mcps, customMcps]);
|
||||
|
||||
const handleConfigurationChange = (updatedMcps: Array<{ name: string; qualifiedName: string; config: any; enabledTools?: string[]; isCustom?: boolean; customType?: 'http' | 'sse' }>) => {
|
||||
const standardMcps = updatedMcps.filter(mcp => !mcp.isCustom);
|
||||
const customMcpsList = updatedMcps.filter(mcp => mcp.isCustom);
|
||||
|
||||
const transformedCustomMcps = customMcpsList.map(mcp => ({
|
||||
name: mcp.name,
|
||||
type: (mcp.customType || 'http') as 'http' | 'sse',
|
||||
config: mcp.config,
|
||||
enabledTools: mcp.enabledTools || []
|
||||
}));
|
||||
|
||||
if (onBatchMCPChange) {
|
||||
onBatchMCPChange({
|
||||
configured_mcps: standardMcps,
|
||||
custom_mcps: transformedCustomMcps
|
||||
});
|
||||
} else {
|
||||
onMCPsChange(standardMcps);
|
||||
if (onCustomMCPsChange) {
|
||||
onCustomMCPsChange(transformedCustomMcps);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Card className='px-0 bg-transparent border-none shadow-none'>
|
||||
<CardContent className='px-0'>
|
||||
<MCPConfigurationNew
|
||||
configuredMCPs={mcps}
|
||||
onConfigurationChange={onMCPsChange}
|
||||
configuredMCPs={allMcps}
|
||||
onConfigurationChange={handleConfigurationChange}
|
||||
/>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
|
|
@ -17,6 +17,7 @@ interface AgentCreateRequest {
|
|||
description: string;
|
||||
system_prompt: string;
|
||||
configured_mcps: Array<{ name: string; qualifiedName: string; config: any; enabledTools?: string[] }>;
|
||||
custom_mcps?: Array<{ name: string; type: 'json' | 'sse'; config: any; enabledTools: string[] }>;
|
||||
agentpress_tools: Record<string, { enabled: boolean; description: string }>;
|
||||
is_default: boolean;
|
||||
}
|
||||
|
@ -34,6 +35,7 @@ const initialFormData: AgentCreateRequest = {
|
|||
description: '',
|
||||
system_prompt: 'Describe the agent\'s role, behavior, and expertise...',
|
||||
configured_mcps: [],
|
||||
custom_mcps: [],
|
||||
agentpress_tools: Object.fromEntries(
|
||||
Object.entries(DEFAULT_AGENTPRESS_TOOLS).map(([key, value]) => [
|
||||
key,
|
||||
|
@ -78,7 +80,17 @@ export const CreateAgentDialog = ({ isOpen, onOpenChange, onAgentCreated }: Crea
|
|||
};
|
||||
|
||||
const handleMCPConfigurationChange = (mcps: any[]) => {
|
||||
handleInputChange('configured_mcps', mcps);
|
||||
// Separate standard and custom MCPs
|
||||
const standardMcps = mcps.filter(mcp => !mcp.isCustom);
|
||||
const customMcps = mcps.filter(mcp => mcp.isCustom).map(mcp => ({
|
||||
name: mcp.name,
|
||||
type: mcp.customType as 'json' | 'sse',
|
||||
config: mcp.config,
|
||||
enabledTools: mcp.enabledTools || []
|
||||
}));
|
||||
|
||||
handleInputChange('configured_mcps', standardMcps);
|
||||
handleInputChange('custom_mcps', customMcps);
|
||||
};
|
||||
|
||||
const getSelectedToolsCount = (): number => {
|
||||
|
|
|
@ -0,0 +1,458 @@
|
|||
import React, { useState } from 'react';
|
||||
import { Dialog, DialogContent, DialogDescription, DialogFooter, DialogHeader, DialogTitle } from '@/components/ui/dialog';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Textarea } from '@/components/ui/textarea';
|
||||
import { Label } from '@/components/ui/label';
|
||||
import { RadioGroup, RadioGroupItem } from '@/components/ui/radio-group';
|
||||
import { Alert, AlertDescription } from '@/components/ui/alert';
|
||||
import { Loader2, AlertCircle, CheckCircle2, Zap, Globe, Code, ChevronRight, Sparkles, Database, Wifi, Server } from 'lucide-react';
|
||||
import { ScrollArea } from '@/components/ui/scroll-area';
|
||||
import { Checkbox } from '@/components/ui/checkbox';
|
||||
import { cn } from '@/lib/utils';
|
||||
import { createClient } from '@/lib/supabase/client';
|
||||
import { Input } from '@/components/ui/input';
|
||||
|
||||
const API_URL = process.env.NEXT_PUBLIC_BACKEND_URL || '';
|
||||
|
||||
interface CustomMCPDialogProps {
|
||||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
onSave: (config: CustomMCPConfiguration) => void;
|
||||
}
|
||||
|
||||
interface CustomMCPConfiguration {
|
||||
name: string;
|
||||
type: 'http' | 'sse';
|
||||
config: any;
|
||||
enabledTools: string[];
|
||||
}
|
||||
|
||||
interface MCPTool {
|
||||
name: string;
|
||||
description: string;
|
||||
inputSchema?: any;
|
||||
}
|
||||
|
||||
export const CustomMCPDialog: React.FC<CustomMCPDialogProps> = ({
|
||||
open,
|
||||
onOpenChange,
|
||||
onSave
|
||||
}) => {
|
||||
const [step, setStep] = useState<'setup' | 'tools'>('setup');
|
||||
const [serverType, setServerType] = useState<'http' | 'sse'>('sse');
|
||||
const [configText, setConfigText] = useState('');
|
||||
const [serverName, setServerName] = useState('');
|
||||
const [manualServerName, setManualServerName] = useState('');
|
||||
const [isValidating, setIsValidating] = useState(false);
|
||||
const [validationError, setValidationError] = useState<string | null>(null);
|
||||
const [discoveredTools, setDiscoveredTools] = useState<MCPTool[]>([]);
|
||||
const [selectedTools, setSelectedTools] = useState<Set<string>>(new Set());
|
||||
const [processedConfig, setProcessedConfig] = useState<any>(null);
|
||||
|
||||
const validateAndDiscoverTools = async () => {
|
||||
setIsValidating(true);
|
||||
setValidationError(null);
|
||||
setDiscoveredTools([]);
|
||||
|
||||
try {
|
||||
let parsedConfig: any;
|
||||
|
||||
if (serverType === 'sse' || serverType === 'http') {
|
||||
const url = configText.trim();
|
||||
if (!url) {
|
||||
throw new Error('Please enter the connection URL.');
|
||||
}
|
||||
if (!manualServerName.trim()) {
|
||||
throw new Error('Please enter a name for this connection.');
|
||||
}
|
||||
|
||||
parsedConfig = { url };
|
||||
setServerName(manualServerName.trim());
|
||||
}
|
||||
|
||||
const supabase = createClient();
|
||||
const { data: { session } } = await supabase.auth.getSession();
|
||||
|
||||
if (!session) {
|
||||
throw new Error('You must be logged in to discover tools');
|
||||
}
|
||||
|
||||
const response = await fetch(`${API_URL}/mcp/discover-custom-tools`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${session.access_token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
type: serverType,
|
||||
config: parsedConfig
|
||||
})
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.json();
|
||||
throw new Error(error.message || 'Failed to connect to the service. Please check your configuration.');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (!data.tools || data.tools.length === 0) {
|
||||
throw new Error('No tools found. Please check your configuration.');
|
||||
}
|
||||
|
||||
if (data.serverName) {
|
||||
setServerName(data.serverName);
|
||||
}
|
||||
|
||||
if (data.processedConfig) {
|
||||
setProcessedConfig(data.processedConfig);
|
||||
}
|
||||
|
||||
setDiscoveredTools(data.tools);
|
||||
setSelectedTools(new Set(data.tools.map((tool: MCPTool) => tool.name)));
|
||||
setStep('tools');
|
||||
|
||||
} catch (error: any) {
|
||||
setValidationError(error.message);
|
||||
} finally {
|
||||
setIsValidating(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSave = () => {
|
||||
if (discoveredTools.length === 0 || selectedTools.size === 0) {
|
||||
setValidationError('Please select at least one tool to continue.');
|
||||
return;
|
||||
}
|
||||
|
||||
if (!serverName.trim()) {
|
||||
setValidationError('Please provide a name for this connection.');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
let configToSave: any = { url: configText.trim() };
|
||||
|
||||
onSave({
|
||||
name: serverName,
|
||||
type: serverType,
|
||||
config: configToSave,
|
||||
enabledTools: Array.from(selectedTools)
|
||||
});
|
||||
|
||||
setConfigText('');
|
||||
setManualServerName('');
|
||||
setDiscoveredTools([]);
|
||||
setSelectedTools(new Set());
|
||||
setServerName('');
|
||||
setProcessedConfig(null);
|
||||
setValidationError(null);
|
||||
setStep('setup');
|
||||
onOpenChange(false);
|
||||
} catch (error) {
|
||||
setValidationError('Invalid configuration format.');
|
||||
}
|
||||
};
|
||||
|
||||
const handleToolToggle = (toolName: string) => {
|
||||
const newTools = new Set(selectedTools);
|
||||
if (newTools.has(toolName)) {
|
||||
newTools.delete(toolName);
|
||||
} else {
|
||||
newTools.add(toolName);
|
||||
}
|
||||
setSelectedTools(newTools);
|
||||
};
|
||||
|
||||
const handleBack = () => {
|
||||
setStep('setup');
|
||||
setValidationError(null);
|
||||
};
|
||||
|
||||
const handleReset = () => {
|
||||
setConfigText('');
|
||||
setManualServerName('');
|
||||
setDiscoveredTools([]);
|
||||
setSelectedTools(new Set());
|
||||
setServerName('');
|
||||
setProcessedConfig(null);
|
||||
setValidationError(null);
|
||||
setStep('setup');
|
||||
};
|
||||
|
||||
const exampleConfigs = {
|
||||
http: `https://server.example.com/mcp`,
|
||||
sse: `https://mcp.composio.dev/partner/composio/gmail/sse?customerId=YOUR_CUSTOMER_ID`
|
||||
};
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={(open) => {
|
||||
onOpenChange(open);
|
||||
if (!open) handleReset();
|
||||
}}>
|
||||
<DialogContent className="max-w-4xl max-h-[85vh] overflow-hidden flex flex-col">
|
||||
<DialogHeader>
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="h-8 w-8 rounded-full bg-primary/10 flex items-center justify-center">
|
||||
<Zap className="h-4 w-4 text-primary" />
|
||||
</div>
|
||||
<DialogTitle>Connect New Service</DialogTitle>
|
||||
</div>
|
||||
<DialogDescription>
|
||||
{step === 'setup'
|
||||
? 'Connect to external services to expand your capabilities with new tools and integrations.'
|
||||
: 'Choose which tools you\'d like to enable from this service connection.'
|
||||
}
|
||||
</DialogDescription>
|
||||
<div className="flex items-center gap-2 pt-2">
|
||||
<div className={cn(
|
||||
"flex items-center gap-2 text-sm font-medium",
|
||||
step === 'setup' ? "text-primary" : "text-muted-foreground"
|
||||
)}>
|
||||
<div className={cn(
|
||||
"w-6 h-6 rounded-full flex items-center justify-center text-xs",
|
||||
step === 'setup' ? "bg-primary text-primary-foreground" : "bg-muted text-muted-foreground"
|
||||
)}>
|
||||
1
|
||||
</div>
|
||||
Setup Connection
|
||||
</div>
|
||||
<ChevronRight className="h-4 w-4 text-muted-foreground" />
|
||||
<div className={cn(
|
||||
"flex items-center gap-2 text-sm font-medium",
|
||||
step === 'tools' ? "text-primary" : "text-muted-foreground"
|
||||
)}>
|
||||
<div className={cn(
|
||||
"w-6 h-6 rounded-full flex items-center justify-center text-xs",
|
||||
step === 'tools' ? "bg-primary text-primary-foreground" : "bg-muted-foreground/20 text-muted-foreground"
|
||||
)}>
|
||||
2
|
||||
</div>
|
||||
Select Tools
|
||||
</div>
|
||||
</div>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="flex-1 overflow-hidden flex flex-col">
|
||||
{step === 'setup' ? (
|
||||
<div className="space-y-6 p-1 flex-1">
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-3">
|
||||
<Label className="text-base font-medium">How would you like to connect?</Label>
|
||||
<RadioGroup
|
||||
value={serverType}
|
||||
onValueChange={(value: 'http' | 'sse') => setServerType(value)}
|
||||
className="grid grid-cols-1 gap-3"
|
||||
>
|
||||
<div className={cn(
|
||||
"flex items-start space-x-3 p-4 rounded-lg border cursor-pointer transition-all hover:bg-muted/50",
|
||||
serverType === 'http' ? "border-primary bg-primary/5" : "border-border"
|
||||
)}>
|
||||
<RadioGroupItem value="http" id="http" className="mt-1" />
|
||||
<div className="flex-1 space-y-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<Server className="h-4 w-4 text-primary" />
|
||||
<Label htmlFor="http" className="text-base font-medium cursor-pointer">
|
||||
Streamable HTTP
|
||||
</Label>
|
||||
</div>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Standard streamable HTTP connection
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className={cn(
|
||||
"flex items-start space-x-3 p-4 rounded-lg border cursor-pointer transition-all hover:bg-muted/50",
|
||||
serverType === 'sse' ? "border-primary bg-primary/5" : "border-border"
|
||||
)}>
|
||||
<RadioGroupItem value="sse" id="sse" className="mt-1" />
|
||||
<div className="flex-1 space-y-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<Wifi className="h-4 w-4 text-primary" />
|
||||
<Label htmlFor="sse" className="text-base font-medium cursor-pointer">
|
||||
SSE (Server-Sent Events)
|
||||
</Label>
|
||||
</div>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Real-time connection using Server-Sent Events for streaming updates
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</RadioGroup>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="serverName" className="text-base font-medium">
|
||||
Connection Name
|
||||
</Label>
|
||||
<input
|
||||
id="serverName"
|
||||
type="text"
|
||||
placeholder="e.g., Gmail, Slack, Customer Support Tools"
|
||||
value={manualServerName}
|
||||
onChange={(e) => setManualServerName(e.target.value)}
|
||||
className="w-full px-4 py-3 border border-input bg-background rounded-lg text-base focus:outline-none focus:ring-2 focus:ring-ring focus:border-transparent"
|
||||
/>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Give this connection a memorable name
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="config" className="text-base font-medium">
|
||||
Connection URL
|
||||
</Label>
|
||||
<Input
|
||||
id="config"
|
||||
type="url"
|
||||
placeholder={exampleConfigs[serverType]}
|
||||
value={configText}
|
||||
onChange={(e) => setConfigText(e.target.value)}
|
||||
className="w-full px-4 py-3 border border-input bg-muted rounded-lg text-base focus:outline-none focus:ring-2 focus:ring-ring focus:border-transparent font-mono"
|
||||
/>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Paste the complete connection URL provided by your service
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{validationError && (
|
||||
<Alert variant="destructive">
|
||||
<AlertCircle className="h-4 w-4" />
|
||||
<AlertDescription>{validationError}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-6 p-1 flex-1 flex flex-col">
|
||||
<Alert className="border-green-200 bg-green-50 text-green-800">
|
||||
<CheckCircle2 className="h-5 w-5 text-green-600" />
|
||||
<div className="ml-2">
|
||||
<h3 className="font-medium text-green-900 mb-1">
|
||||
Connection Successful!
|
||||
</h3>
|
||||
<p className="text-sm text-green-700">
|
||||
Found {discoveredTools.length} available tools from <strong>{serverName}</strong>
|
||||
</p>
|
||||
</div>
|
||||
</Alert>
|
||||
|
||||
<div className="space-y-4 flex-1 flex flex-col">
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<h3 className="text-base font-medium">Available Tools</h3>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Select the tools you want to enable
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
if (selectedTools.size === discoveredTools.length) {
|
||||
setSelectedTools(new Set());
|
||||
} else {
|
||||
setSelectedTools(new Set(discoveredTools.map(t => t.name)));
|
||||
}
|
||||
}}
|
||||
>
|
||||
{selectedTools.size === discoveredTools.length ? 'Deselect All' : 'Select All'}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="flex-1 min-h-0">
|
||||
<ScrollArea className="h-[400px] border border-border rounded-lg">
|
||||
<div className="space-y-3 p-4">
|
||||
{discoveredTools.map((tool) => (
|
||||
<div
|
||||
key={tool.name}
|
||||
className={cn(
|
||||
"flex items-start space-x-3 p-4 rounded-lg border transition-all cursor-pointer hover:bg-muted/50",
|
||||
selectedTools.has(tool.name)
|
||||
? "border-primary bg-primary/5"
|
||||
: "border-border"
|
||||
)}
|
||||
onClick={() => handleToolToggle(tool.name)}
|
||||
>
|
||||
<Checkbox
|
||||
id={tool.name}
|
||||
checked={selectedTools.has(tool.name)}
|
||||
onCheckedChange={() => handleToolToggle(tool.name)}
|
||||
className="mt-1"
|
||||
/>
|
||||
<div className="flex-1 space-y-2 min-w-0">
|
||||
<Label
|
||||
htmlFor={tool.name}
|
||||
className="text-base font-medium cursor-pointer block"
|
||||
>
|
||||
{tool.name.replace(/_/g, ' ').replace(/\b\w/g, l => l.toUpperCase())}
|
||||
</Label>
|
||||
{tool.description && (
|
||||
<p className="text-sm text-muted-foreground leading-relaxed">
|
||||
{tool.description}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{validationError && (
|
||||
<Alert variant="destructive">
|
||||
<AlertCircle className="h-4 w-4" />
|
||||
<AlertDescription>{validationError}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<DialogFooter className="flex-shrink-0 pt-4">
|
||||
{step === 'tools' ? (
|
||||
<>
|
||||
<Button variant="outline" onClick={handleBack}>
|
||||
Back
|
||||
</Button>
|
||||
<Button variant="outline" onClick={() => onOpenChange(false)}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleSave}
|
||||
disabled={selectedTools.size === 0}
|
||||
>
|
||||
Add Connection ({selectedTools.size} tools)
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Button variant="outline" onClick={() => onOpenChange(false)}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={validateAndDiscoverTools}
|
||||
disabled={!configText.trim() || !manualServerName.trim() || isValidating}
|
||||
>
|
||||
{isValidating ? (
|
||||
<>
|
||||
<Loader2 className="h-5 w-5 animate-spin" />
|
||||
Discovering tools...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Sparkles className="h-5 w-5" />
|
||||
Connect
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
|
@ -1,17 +1,19 @@
|
|||
import React, { useState } from 'react';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Plus, Settings, Zap } from 'lucide-react';
|
||||
import { Plus, Settings, Zap, Code2, Server } from 'lucide-react';
|
||||
import { Dialog } from '@/components/ui/dialog';
|
||||
import { MCPConfigurationProps, MCPConfiguration as MCPConfigurationType } from './types';
|
||||
import { ConfiguredMcpList } from './configured-mcp-list';
|
||||
import { BrowseDialog } from './browse-dialog';
|
||||
import { ConfigDialog } from './config-dialog';
|
||||
import { CustomMCPDialog } from './custom-mcp-dialog';
|
||||
|
||||
export const MCPConfigurationNew: React.FC<MCPConfigurationProps> = ({
|
||||
configuredMCPs,
|
||||
onConfigurationChange,
|
||||
}) => {
|
||||
const [showBrowseDialog, setShowBrowseDialog] = useState(false);
|
||||
const [showCustomDialog, setShowCustomDialog] = useState(false);
|
||||
const [configuringServer, setConfiguringServer] = useState<any>(null);
|
||||
const [editingIndex, setEditingIndex] = useState<number | null>(null);
|
||||
|
||||
|
@ -23,6 +25,12 @@ export const MCPConfigurationNew: React.FC<MCPConfigurationProps> = ({
|
|||
|
||||
const handleEditMCP = (index: number) => {
|
||||
const mcp = configuredMCPs[index];
|
||||
// Check if it's a custom MCP
|
||||
if (mcp.isCustom) {
|
||||
// For custom MCPs, we'll need to handle editing differently
|
||||
// For now, just remove and re-add
|
||||
return;
|
||||
}
|
||||
setConfiguringServer({
|
||||
qualifiedName: mcp.qualifiedName,
|
||||
displayName: mcp.name,
|
||||
|
@ -38,17 +46,38 @@ export const MCPConfigurationNew: React.FC<MCPConfigurationProps> = ({
|
|||
};
|
||||
|
||||
const handleSaveConfiguration = (config: MCPConfigurationType) => {
|
||||
const regularMCPConfig = {
|
||||
...config,
|
||||
isCustom: false,
|
||||
customType: undefined
|
||||
};
|
||||
|
||||
if (editingIndex !== null) {
|
||||
const newMCPs = [...configuredMCPs];
|
||||
newMCPs[editingIndex] = config;
|
||||
newMCPs[editingIndex] = regularMCPConfig;
|
||||
onConfigurationChange(newMCPs);
|
||||
} else {
|
||||
onConfigurationChange([...configuredMCPs, config]);
|
||||
onConfigurationChange([...configuredMCPs, regularMCPConfig]);
|
||||
}
|
||||
setConfiguringServer(null);
|
||||
setEditingIndex(null);
|
||||
};
|
||||
|
||||
const handleSaveCustomMCP = (customConfig: any) => {
|
||||
console.log('Saving custom MCP config:', customConfig);
|
||||
const mcpConfig: MCPConfigurationType = {
|
||||
name: customConfig.name,
|
||||
qualifiedName: `custom_${customConfig.type}_${Date.now()}`,
|
||||
config: customConfig.config,
|
||||
enabledTools: customConfig.enabledTools,
|
||||
isCustom: true,
|
||||
customType: customConfig.type as 'http' | 'sse'
|
||||
};
|
||||
console.log('Transformed MCP config:', mcpConfig);
|
||||
onConfigurationChange([...configuredMCPs, mcpConfig]);
|
||||
console.log('Updated MCPs list:', [...configuredMCPs, mcpConfig]);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div className="rounded-xl p-6 border">
|
||||
|
@ -73,16 +102,28 @@ export const MCPConfigurationNew: React.FC<MCPConfigurationProps> = ({
|
|||
)}
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={() => setShowBrowseDialog(true)}
|
||||
className="transition-all duration-200"
|
||||
>
|
||||
<Plus className="h-4 w-4" />
|
||||
Add Server
|
||||
</Button>
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() => setShowCustomDialog(true)}
|
||||
className="transition-all duration-200"
|
||||
>
|
||||
<Server className="h-4 w-4" />
|
||||
Custom
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={() => setShowBrowseDialog(true)}
|
||||
className="transition-all duration-200"
|
||||
>
|
||||
<Plus className="h-4 w-4" />
|
||||
Browse
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{configuredMCPs.length === 0 && (
|
||||
<div className="text-center py-12 px-6 bg-muted/30 rounded-xl border-2 border-dashed border-border">
|
||||
<div className="mx-auto w-12 h-12 bg-muted rounded-full flex items-center justify-center mb-4">
|
||||
|
@ -117,6 +158,11 @@ export const MCPConfigurationNew: React.FC<MCPConfigurationProps> = ({
|
|||
onOpenChange={setShowBrowseDialog}
|
||||
onServerSelect={handleAddMCP}
|
||||
/>
|
||||
<CustomMCPDialog
|
||||
open={showCustomDialog}
|
||||
onOpenChange={setShowCustomDialog}
|
||||
onSave={handleSaveCustomMCP}
|
||||
/>
|
||||
{configuringServer && (
|
||||
<Dialog open={!!configuringServer} onOpenChange={() => setConfiguringServer(null)}>
|
||||
<ConfigDialog
|
||||
|
|
|
@ -3,6 +3,8 @@ export interface MCPConfiguration {
|
|||
qualifiedName: string;
|
||||
config: Record<string, any>;
|
||||
enabledTools?: string[];
|
||||
isCustom?: boolean;
|
||||
customType?: 'http' | 'sse';
|
||||
}
|
||||
|
||||
export interface MCPConfigurationProps {
|
||||
|
|
|
@ -19,6 +19,7 @@ interface AgentUpdateRequest {
|
|||
description?: string;
|
||||
system_prompt?: string;
|
||||
configured_mcps?: Array<{ name: string; qualifiedName: string; config: any; enabledTools?: string[] }>;
|
||||
custom_mcps?: Array<{ name: string; type: 'json' | 'sse'; config: any; enabledTools: string[] }>;
|
||||
agentpress_tools?: Record<string, { enabled: boolean; description: string }>;
|
||||
is_default?: boolean;
|
||||
}
|
||||
|
@ -65,6 +66,7 @@ export const UpdateAgentDialog = ({ agentId, isOpen, onOpenChange, onAgentUpdate
|
|||
config: mcp.config,
|
||||
enabledTools: (mcp as any).enabledTools || []
|
||||
})),
|
||||
custom_mcps: agent.custom_mcps || [],
|
||||
agentpress_tools: agent.agentpress_tools || {},
|
||||
is_default: agent.is_default,
|
||||
});
|
||||
|
@ -100,7 +102,17 @@ export const UpdateAgentDialog = ({ agentId, isOpen, onOpenChange, onAgentUpdate
|
|||
};
|
||||
|
||||
const handleMCPConfigurationChange = (mcps: any[]) => {
|
||||
handleInputChange('configured_mcps', mcps);
|
||||
// Separate standard and custom MCPs
|
||||
const standardMcps = mcps.filter(mcp => !mcp.isCustom);
|
||||
const customMcps = mcps.filter(mcp => mcp.isCustom).map(mcp => ({
|
||||
name: mcp.name,
|
||||
type: mcp.customType as 'json' | 'sse',
|
||||
config: mcp.config,
|
||||
enabledTools: mcp.enabledTools || []
|
||||
}));
|
||||
|
||||
handleInputChange('configured_mcps', standardMcps);
|
||||
handleInputChange('custom_mcps', customMcps);
|
||||
};
|
||||
|
||||
const getAllAgentPressTools = () => {
|
||||
|
@ -349,7 +361,14 @@ export const UpdateAgentDialog = ({ agentId, isOpen, onOpenChange, onAgentUpdate
|
|||
|
||||
<TabsContent value="mcp" className="flex-1 m-0 p-6 overflow-y-auto">
|
||||
<MCPConfigurationNew
|
||||
configuredMCPs={formData.configured_mcps || []}
|
||||
configuredMCPs={[...(formData.configured_mcps || []), ...(formData.custom_mcps || []).map(customMcp => ({
|
||||
name: customMcp.name,
|
||||
qualifiedName: `custom_${customMcp.type}_${customMcp.name.replace(' ', '_').toLowerCase()}`,
|
||||
config: customMcp.config,
|
||||
enabledTools: customMcp.enabledTools,
|
||||
isCustom: true,
|
||||
customType: customMcp.type
|
||||
}))]}
|
||||
onConfigurationChange={handleMCPConfigurationChange}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
|
|
@ -42,6 +42,7 @@ export default function AgentConfigurationPage() {
|
|||
system_prompt: '',
|
||||
agentpress_tools: {},
|
||||
configured_mcps: [],
|
||||
custom_mcps: [],
|
||||
is_default: false,
|
||||
avatar: '',
|
||||
avatar_color: '',
|
||||
|
@ -55,7 +56,6 @@ export default function AgentConfigurationPage() {
|
|||
const [activeTab, setActiveTab] = useState('agent-builder');
|
||||
const accordionRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Effect to automatically close sidebar on page load
|
||||
useEffect(() => {
|
||||
if (!initialLayoutAppliedRef.current) {
|
||||
setOpen(false);
|
||||
|
@ -72,6 +72,7 @@ export default function AgentConfigurationPage() {
|
|||
system_prompt: agentData.system_prompt || '',
|
||||
agentpress_tools: agentData.agentpress_tools || {},
|
||||
configured_mcps: agentData.configured_mcps || [],
|
||||
custom_mcps: agentData.custom_mcps || [],
|
||||
is_default: agentData.is_default || false,
|
||||
avatar: agentData.avatar || '',
|
||||
avatar_color: agentData.avatar_color || '',
|
||||
|
@ -108,7 +109,8 @@ export default function AgentConfigurationPage() {
|
|||
return true;
|
||||
}
|
||||
if (JSON.stringify(newData.agentpress_tools) !== JSON.stringify(originalData.agentpress_tools) ||
|
||||
JSON.stringify(newData.configured_mcps) !== JSON.stringify(originalData.configured_mcps)) {
|
||||
JSON.stringify(newData.configured_mcps) !== JSON.stringify(originalData.configured_mcps) ||
|
||||
JSON.stringify(newData.custom_mcps) !== JSON.stringify(originalData.custom_mcps)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -158,6 +160,16 @@ export default function AgentConfigurationPage() {
|
|||
debouncedSave(newFormData);
|
||||
}, [debouncedSave]);
|
||||
|
||||
const handleBatchMCPChange = useCallback((updates: { configured_mcps: any[]; custom_mcps: any[] }) => {
|
||||
const newFormData = {
|
||||
...currentFormDataRef.current,
|
||||
configured_mcps: updates.configured_mcps,
|
||||
custom_mcps: updates.custom_mcps
|
||||
};
|
||||
|
||||
setFormData(newFormData);
|
||||
debouncedSave(newFormData);
|
||||
}, [debouncedSave]);
|
||||
|
||||
const scrollToAccordion = useCallback(() => {
|
||||
if (accordionRef.current) {
|
||||
|
@ -369,7 +381,10 @@ export default function AgentConfigurationPage() {
|
|||
<AccordionContent className="pb-4 overflow-x-hidden">
|
||||
<AgentMCPConfiguration
|
||||
mcps={formData.configured_mcps}
|
||||
onMCPsChange={(mcps) => handleFieldChange('configured_mcps', mcps)}
|
||||
customMcps={formData.custom_mcps}
|
||||
onMCPsChange={(mcps) => handleBatchMCPChange({ configured_mcps: mcps, custom_mcps: formData.custom_mcps })}
|
||||
onCustomMCPsChange={(customMcps) => handleBatchMCPChange({ configured_mcps: formData.configured_mcps, custom_mcps: customMcps })}
|
||||
onBatchMCPChange={handleBatchMCPChange}
|
||||
/>
|
||||
</AccordionContent>
|
||||
</AccordionItem>
|
||||
|
@ -398,7 +413,8 @@ export default function AgentConfigurationPage() {
|
|||
setIsPreviewOpen,
|
||||
setActiveTab,
|
||||
scrollToAccordion,
|
||||
getSaveStatusBadge
|
||||
getSaveStatusBadge,
|
||||
handleBatchMCPChange
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
|
|
|
@ -632,7 +632,7 @@ export default function ThreadPage({
|
|||
value={newMessage}
|
||||
onChange={setNewMessage}
|
||||
onSubmit={handleSubmitMessage}
|
||||
placeholder={`Ask ${agent && agent.name} anything...`}
|
||||
placeholder={`Ask ${agent ? agent.name : 'Suna'} anything...`}
|
||||
loading={isSending}
|
||||
disabled={isSending || agentStatus === 'running' || agentStatus === 'connecting'}
|
||||
isAgentRunning={agentStatus === 'running' || agentStatus === 'connecting'}
|
||||
|
|
|
@ -205,20 +205,15 @@ export function useToolCalls(
|
|||
}
|
||||
});
|
||||
|
||||
// Update the ref with the new mapping
|
||||
assistantMessageToToolIndex.current = messageIdToIndex;
|
||||
setToolCalls(historicalToolPairs);
|
||||
|
||||
// Auto-navigation logic
|
||||
if (historicalToolPairs.length > 0) {
|
||||
// If agent is running and user hasn't manually navigated, always show the last tool
|
||||
if (agentStatus === 'running' && !userNavigatedRef.current) {
|
||||
setCurrentToolIndex(historicalToolPairs.length - 1);
|
||||
} else if (isSidePanelOpen && !userClosedPanelRef.current && !userNavigatedRef.current) {
|
||||
// If panel is open and user hasn't manually navigated, jump to latest
|
||||
setCurrentToolIndex(historicalToolPairs.length - 1);
|
||||
} else if (!isSidePanelOpen && !autoOpenedPanel && !userClosedPanelRef.current) {
|
||||
// Auto-open the panel only the first time tools are detected
|
||||
setCurrentToolIndex(historicalToolPairs.length - 1);
|
||||
setIsSidePanelOpen(true);
|
||||
setAutoOpenedPanel(true);
|
||||
|
|
|
@ -655,24 +655,18 @@ export function ToolCallSidePanel({
|
|||
</div>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
{showJumpToLive && (
|
||||
<div className="flex cursor-pointer items-center gap-1.5 px-2 py-0.5 rounded-full bg-green-50 dark:bg-green-900/20 border border-green-200 dark:border-green-800" onClick={jumpToLive}>
|
||||
<div className="w-1.5 h-1.5 bg-green-500 rounded-full animate-pulse"></div>
|
||||
<span className="text-xs font-medium text-green-700 dark:text-green-400">Jump to Live</span>
|
||||
</div>
|
||||
)}
|
||||
{showJumpToLatest && (
|
||||
<div className="flex cursor-pointer items-center gap-1.5 px-2 py-0.5 rounded-full bg-neutral-50 dark:bg-neutral-900/20 border border-neutral-200 dark:border-neutral-800" onClick={jumpToLatest}>
|
||||
<div className="w-1.5 h-1.5 bg-neutral-500 rounded-full"></div>
|
||||
<span className="text-xs font-medium text-neutral-700 dark:text-neutral-400">Jump to Latest</span>
|
||||
</div>
|
||||
)}
|
||||
{isLiveMode && agentStatus === 'running' && !showJumpToLive && (
|
||||
{isLiveMode && agentStatus === 'running' && (
|
||||
<div className="flex items-center gap-1.5 px-2 py-0.5 rounded-full bg-green-50 dark:bg-green-900/20 border border-green-200 dark:border-green-800">
|
||||
<div className="w-1.5 h-1.5 bg-green-500 rounded-full animate-pulse"></div>
|
||||
<span className="text-xs font-medium text-green-700 dark:text-green-400">Live</span>
|
||||
</div>
|
||||
)}
|
||||
{!isLiveMode && agentStatus !== 'running' && (
|
||||
<div className="flex items-center gap-1.5 px-2 py-0.5 rounded-full bg-neutral-50 dark:bg-neutral-900/20 border border-neutral-200 dark:border-neutral-800">
|
||||
<div className="w-1.5 h-1.5 bg-neutral-500 rounded-full"></div>
|
||||
<span className="text-xs font-medium text-neutral-700 dark:text-neutral-400">Live</span>
|
||||
</div>
|
||||
)}
|
||||
<span className="text-xs text-zinc-500 dark:text-zinc-400 flex-shrink-0">
|
||||
Step {displayIndex + 1} of {displayTotalCalls}
|
||||
</span>
|
||||
|
@ -749,6 +743,21 @@ export function ToolCallSidePanel({
|
|||
</div>
|
||||
|
||||
<div className="relative w-full">
|
||||
{(showJumpToLive || showJumpToLatest) && (
|
||||
<div className="absolute -top-12 left-1/2 transform -translate-x-1/2 z-10">
|
||||
{showJumpToLive && (
|
||||
<Button className='rounded-full bg-red-500 hover:bg-red-400 text-white' onClick={jumpToLive}>
|
||||
Jump to Live
|
||||
</Button>
|
||||
)}
|
||||
{showJumpToLatest && (
|
||||
<Button className='rounded-full' onClick={jumpToLive}>
|
||||
Jump to Latest
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Slider
|
||||
min={0}
|
||||
max={displayTotalCalls - 1}
|
||||
|
@ -757,8 +766,6 @@ export function ToolCallSidePanel({
|
|||
onValueChange={handleSliderChange}
|
||||
className="w-full [&>span:first-child]:h-1 [&>span:first-child]:bg-zinc-200 dark:[&>span:first-child]:bg-zinc-800 [&>span:first-child>span]:bg-zinc-500 dark:[&>span:first-child>span]:bg-zinc-400 [&>span:first-child>span]:h-1"
|
||||
/>
|
||||
|
||||
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
|
|
@ -52,16 +52,6 @@ const extractFromNewFormat = (content: any): WebSearchData => {
|
|||
success: toolExecution.result?.success,
|
||||
timestamp: toolExecution.execution_details?.timestamp
|
||||
};
|
||||
|
||||
console.log('WebSearchToolView: Extracted from new format:', {
|
||||
query: extractedData.query,
|
||||
resultsCount: extractedData.results.length,
|
||||
hasAnswer: !!extractedData.answer,
|
||||
imagesCount: extractedData.images.length,
|
||||
success: extractedData.success,
|
||||
firstResult: extractedData.results[0]
|
||||
});
|
||||
|
||||
return extractedData;
|
||||
}
|
||||
|
||||
|
|
|
@ -12,6 +12,12 @@ export type Agent = {
|
|||
name: string;
|
||||
config: Record<string, any>;
|
||||
}>;
|
||||
custom_mcps?: Array<{
|
||||
name: string;
|
||||
type: 'json' | 'sse';
|
||||
config: Record<string, any>;
|
||||
enabledTools: string[];
|
||||
}>;
|
||||
agentpress_tools: Record<string, any>;
|
||||
is_default: boolean;
|
||||
is_public?: boolean;
|
||||
|
@ -62,6 +68,12 @@ export type AgentCreateRequest = {
|
|||
name: string;
|
||||
config: Record<string, any>;
|
||||
}>;
|
||||
custom_mcps?: Array<{
|
||||
name: string;
|
||||
type: 'json' | 'sse';
|
||||
config: Record<string, any>;
|
||||
enabledTools: string[];
|
||||
}>;
|
||||
agentpress_tools?: Record<string, any>;
|
||||
is_default?: boolean;
|
||||
};
|
||||
|
@ -74,6 +86,12 @@ export type AgentUpdateRequest = {
|
|||
name: string;
|
||||
config: Record<string, any>;
|
||||
}>;
|
||||
custom_mcps?: Array<{
|
||||
name: string;
|
||||
type: 'json' | 'sse';
|
||||
config: Record<string, any>;
|
||||
enabledTools: string[];
|
||||
}>;
|
||||
agentpress_tools?: Record<string, any>;
|
||||
is_default?: boolean;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue