suna/backend/mcp_service/client.py

305 lines
12 KiB
Python

"""
MCP Client module for connecting to and using MCP servers
This module handles:
1. Connecting to MCP servers via Smithery
2. Converting MCP tools to OpenAPI format for LLMs
3. Executing MCP tool calls
"""
import asyncio
import json
import base64
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass
# Import MCP components according to the official SDK
from mcp import ClientSession
# Locate the streamable-HTTP client factory, trying the primary module path
# first and a package-level re-export second.
try:
    from mcp.client.streamable_http import streamablehttp_client
except ImportError:
    # Fallback import if the module structure is different
    try:
        from mcp.client import streamablehttp_client
    except ImportError as import_err:
        # Neither location worked: fail loudly with install instructions,
        # keeping the underlying import failure attached as the cause.
        raise ImportError(
            "Could not import streamablehttp_client. "
            "Make sure you have installed mcp with: pip install 'mcp[cli]'"
        ) from import_err

# Import types - these should be in mcp.types according to the docs
try:
    from mcp.types import Tool, CallToolResult as ToolResult
except ImportError:
    # Fallback to a different location if needed
    try:
        from mcp import types

        Tool = types.Tool
        ToolResult = types.CallToolResult
    except (ImportError, AttributeError):
        # ImportError: ``mcp.types`` is missing altogether.
        # AttributeError: the module exists but lacks Tool/CallToolResult.
        # Either way the names are only used in annotations, so degrade to Any.
        Tool = Any
        ToolResult = Any
from utils.logger import logger
import os

# Smithery settings: the API key is read from the environment once at import
# time (None when the variable is unset); the base URL is fixed.
SMITHERY_API_KEY = os.environ.get("SMITHERY_API_KEY")
SMITHERY_SERVER_BASE_URL = "https://server.smithery.ai"
@dataclass
class MCPConnection:
    """Record describing one configured MCP server and the tools it offers.

    Instances are created by ``MCPManager.connect_server`` and stored in
    ``MCPManager.connections`` keyed by ``qualified_name``.
    """

    # Server identifier; used as the URL path segment and as the registry key.
    qualified_name: str
    # Human-readable server name, used in tool descriptions and tool info.
    name: str
    # Server-specific settings; JSON-encoded and sent with every connection.
    config: Dict[str, Any]
    # Tool names to expose; an empty list means no filtering is applied.
    enabled_tools: List[str]
    # MCPManager always leaves this as None (no persistent session is kept).
    session: Optional[ClientSession] = None
    # Tool definitions returned by the server's list_tools call, if fetched.
    tools: Optional[List[Tool]] = None
class MCPManager:
    """Manages connections to multiple MCP servers.

    No network session is kept open between calls. ``connect_server`` opens a
    short-lived session to discover a server's tools and records the result in
    ``self.connections``; ``execute_tool`` opens a fresh session for every
    tool call.
    """

    def __init__(self):
        # Discovered servers, keyed by qualified name.
        self.connections: Dict[str, MCPConnection] = {}
        self._sessions: Dict[str, Tuple[Any, Any, Any]] = {}  # Store streams for cleanup

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_server_url(qualified_name: str, config: Dict[str, Any]) -> str:
        """Return the Smithery endpoint URL for a server and its config."""
        from urllib.parse import urlencode

        config_json = json.dumps(config)
        config_b64 = base64.b64encode(config_json.encode()).decode()
        # Standard base64 may contain '+', '/' and '=', which are not safe in
        # a raw query string ('+' is decoded as a space), so percent-encode
        # the parameters instead of interpolating them verbatim.
        query = urlencode({"config": config_b64, "api_key": SMITHERY_API_KEY})
        return f"{SMITHERY_SERVER_BASE_URL}/{qualified_name}/mcp?{query}"

    def _parse_tool_name(self, tool_name: str) -> Optional[Tuple[str, str]]:
        """Split ``mcp_{qualified_name}_{tool}`` into ``(qualified_name, tool)``.

        Returns None when the name does not have that overall shape.
        """
        prefix = "mcp_"
        if not tool_name.startswith(prefix):
            return None
        remainder = tool_name[len(prefix):]

        # Qualified names may themselves contain underscores, so a blind split
        # on the first underscore picks the wrong server. Match against the
        # servers that are actually connected, longest name first.
        candidates = sorted(
            (
                (name, remainder[len(name) + 1:])
                for name in self.connections
                if remainder.startswith(f"{name}_")
            ),
            key=lambda pair: len(pair[0]),
            reverse=True,
        )
        # Prefer a candidate whose server really advertises the tool; this
        # disambiguates e.g. server "a" + tool "b_c" from server "a_b" + tool "c".
        for name, tool in candidates:
            advertised = self.connections[name].tools or []
            if any(getattr(t, "name", None) == tool for t in advertised):
                return name, tool
        if candidates:
            return candidates[0]

        # No connected server matches: fall back to splitting on the first
        # underscore so callers still get a "not connected" style outcome.
        qualified_name, separator, original_tool_name = remainder.partition("_")
        if not separator:
            return None
        return qualified_name, original_tool_name

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten an MCP result's ``content`` into a single string."""

        def item_text(item: Any) -> str:
            if hasattr(item, "text"):
                return item.text
            if hasattr(item, "content"):
                return str(item.content)
            return str(item)

        if isinstance(content, list):
            # Usually a list of TextContent-like objects; join their text.
            return "\n".join(item_text(item) for item in content)
        return item_text(content)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def connect_server(self, mcp_config: Dict[str, Any]) -> "MCPConnection":
        """
        Connect to an MCP server using configuration

        Opens a temporary session, lists the server's tools, and stores an
        MCPConnection (without a persistent session) in ``self.connections``.
        If the server is already registered, the stored connection is returned
        without contacting the server again.

        Args:
            mcp_config: Configuration from agent's configured_mcps field
                {
                    "name": "Exa Search",
                    "qualifiedName": "exa",
                    "config": {"exaApiKey": "xxx"},
                    "enabledTools": ["web_search_exa"]
                }

        Returns:
            The MCPConnection describing the server and its tools.

        Raises:
            KeyError: If "qualifiedName" is missing from mcp_config.
            ValueError: If SMITHERY_API_KEY is not set.
            Exception: Any connection or protocol failure is logged and
                re-raised unchanged.
        """
        qualified_name = mcp_config["qualifiedName"]

        # Check if already connected
        if qualified_name in self.connections:
            logger.info(f"MCP server {qualified_name} already connected")
            return self.connections[qualified_name]

        logger.info(f"Connecting to MCP server: {qualified_name}")

        # Check if Smithery API key is available
        if not SMITHERY_API_KEY:
            raise ValueError(
                "SMITHERY_API_KEY environment variable is not set. "
                "Please set it to use MCP servers from Smithery."
            )

        try:
            url = self._build_server_url(qualified_name, mcp_config["config"])

            # Test connection and get available tools
            async with streamablehttp_client(url) as (read_stream, write_stream, _):
                async with ClientSession(read_stream, write_stream) as session:
                    # Initialize the connection
                    await session.initialize()
                    logger.info(f"MCP session initialized for {qualified_name}")

                    # List available tools; accept either a result object
                    # with a ``tools`` attribute or a bare sequence.
                    tools_result = await session.list_tools()
                    tools = tools_result.tools if hasattr(tools_result, 'tools') else tools_result
                    logger.info(f"Available tools from {qualified_name}: {[t.name for t in tools]}")

                    # Create connection object (without persistent session)
                    connection = MCPConnection(
                        qualified_name=qualified_name,
                        name=mcp_config["name"],
                        config=mcp_config["config"],
                        enabled_tools=mcp_config.get("enabledTools", []),
                        session=None,  # No persistent session
                        tools=tools,
                    )
                    self.connections[qualified_name] = connection
                    return connection
        except Exception as e:
            logger.error(f"Failed to connect to MCP server {qualified_name}: {e}")
            raise

    async def connect_all(self, mcp_configs: List[Dict[str, Any]]) -> None:
        """Connect to all MCP servers in the configuration

        Failures are logged per server and do not stop the remaining servers
        from being attempted.
        """
        for config in mcp_configs:
            try:
                await self.connect_server(config)
            except Exception as e:
                # Look the name up defensively: a missing "qualifiedName" is
                # itself one of the failures handled here, and indexing it
                # again would raise out of this handler and abort the loop.
                if isinstance(config, dict):
                    server = config.get("qualifiedName", "<unknown>")
                else:
                    server = "<unknown>"
                logger.error(f"Failed to connect to {server}: {e}")
                # Continue with other servers even if one fails

    def get_all_tools_openapi(self) -> List[Dict[str, Any]]:
        """
        Convert all connected MCP tools to OpenAPI format for LLM

        Tools not listed in a connection's ``enabled_tools`` are skipped; an
        empty ``enabled_tools`` list means every tool is exposed.

        Returns a list of tool definitions in OpenAPI format
        """
        all_tools = []
        for conn in self.connections.values():
            if not conn.tools:
                continue
            for tool in conn.tools:
                # Skip tools that are not enabled
                if conn.enabled_tools and tool.name not in conn.enabled_tools:
                    continue

                # Only "properties" and "required" are carried over from the
                # tool's input schema; everything else uses these defaults.
                properties: Dict[str, Any] = {}
                required: List[Any] = []
                schema = getattr(tool, 'inputSchema', None)
                if schema and isinstance(schema, dict):
                    properties = schema.get("properties", {})
                    required = schema.get("required", [])

                all_tools.append({
                    "name": f"mcp_{conn.qualified_name}_{tool.name}",  # Prefix to avoid conflicts
                    "description": tool.description or f"MCP tool from {conn.name}",
                    "parameters": {
                        "type": "object",
                        "properties": properties,
                        "required": required,
                    },
                })
        return all_tools

    async def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute an MCP tool call

        Args:
            tool_name: Name in format "mcp_{qualified_name}_{original_tool_name}"
            arguments: Tool arguments

        Returns:
            Tool execution result as ``{"content": str, "isError": bool}``.
            Failures during the call itself are logged and reported through
            this dict (``isError`` True) rather than raised.

        Raises:
            ValueError: If the tool name is malformed, the server is not
                connected, or SMITHERY_API_KEY is not set.
        """
        # Parse the tool name to get server and original tool name
        parsed = self._parse_tool_name(tool_name)
        if parsed is None:
            raise ValueError(f"Invalid MCP tool name format: {tool_name}")
        qualified_name, original_tool_name = parsed

        # Find the connection config
        if qualified_name not in self.connections:
            raise ValueError(f"MCP server {qualified_name} not connected")
        conn = self.connections[qualified_name]

        logger.info(f"Executing MCP tool {original_tool_name} on server {qualified_name}")

        # Check if Smithery API key is available
        if not SMITHERY_API_KEY:
            raise ValueError("SMITHERY_API_KEY environment variable is not set")

        try:
            # Create fresh connection for this tool call
            url = self._build_server_url(qualified_name, conn.config)

            # Use the documented pattern with proper context management
            async with streamablehttp_client(url) as (read_stream, write_stream, _):
                async with ClientSession(read_stream, write_stream) as session:
                    # Initialize the connection
                    await session.initialize()

                    # Call the tool
                    result = await session.call_tool(original_tool_name, arguments)

                    # Convert result to dict - handle MCP response properly
                    if hasattr(result, 'content'):
                        content_str = self._content_to_text(result.content)
                        is_error = getattr(result, 'isError', False)
                    else:
                        content_str = str(result)
                        is_error = False

                    return {
                        "content": content_str,
                        "isError": is_error,
                    }
        except Exception as e:
            logger.error(f"Error executing MCP tool {tool_name}: {e}")
            return {
                "content": f"Error executing tool: {e}",
                "isError": True,
            }

    async def disconnect_all(self):
        """Disconnect all MCP servers (clear stored configurations)"""
        # Iterate over a snapshot because entries are removed while looping.
        for qualified_name in list(self.connections):
            try:
                del self.connections[qualified_name]
                logger.info(f"Cleared MCP server configuration for {qualified_name}")
            except Exception as e:
                logger.error(f"Error clearing configuration for {qualified_name}: {e}")

        # Clear sessions dict
        self._sessions.clear()

    def get_tool_info(self, tool_name: str) -> Optional[Dict[str, Any]]:
        """Get information about a specific tool

        Args:
            tool_name: Name in format "mcp_{qualified_name}_{original_tool_name}"

        Returns:
            A dict with the server name, qualified name, original tool name,
            description and enabled flag, or None if the name is malformed,
            the server is not connected, or the tool is unknown.
        """
        parsed = self._parse_tool_name(tool_name)
        if parsed is None:
            return None
        qualified_name, original_tool_name = parsed

        conn = self.connections.get(qualified_name)
        if conn is None or not conn.tools:
            return None

        for tool in conn.tools:
            if tool.name == original_tool_name:
                return {
                    "server": conn.name,
                    "qualified_name": qualified_name,
                    "original_name": tool.name,
                    "description": tool.description,
                    "enabled": not conn.enabled_tools or tool.name in conn.enabled_tools,
                }
        return None