# suna/backend/mcp_module/support/custom_discovery.py
# (listing metadata from the original paste: 104 lines, 4.2 KiB, Python)

import asyncio
from typing import Dict, Any, List
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from ..domain.entities import CustomMCPConnectionResult, ToolInfo
from ..domain.exceptions import CustomMCPError
from ..protocols import Logger
class CustomMCPDiscovery:
    """Connect to a user-supplied ("custom") MCP server and list its tools.

    Two transports are supported, selected by the ``request_type`` passed to
    :meth:`discover_tools`: ``"http"`` (via ``streamablehttp_client``) and
    ``"sse"`` (via ``sse_client``).
    """

    def __init__(self, logger: Logger):
        """Keep *logger*; it is used to report connection failures."""
        self._logger = logger

    async def discover_tools(self, request_type: str, config: Dict[str, Any]) -> CustomMCPConnectionResult:
        """Connect to a custom MCP server and return the tools it exposes.

        Args:
            request_type: Transport to use, ``"http"`` or ``"sse"``.
            config: Connection settings; must contain a non-empty ``"url"``.
                It is echoed back unchanged in the result.

        Returns:
            A ``CustomMCPConnectionResult``. Connection/protocol failures do not
            raise; they are logged and reported with ``success=False``.

        Raises:
            CustomMCPError: If ``request_type`` is not supported, or if
                ``config`` has no (or an empty) ``"url"``.
        """
        if request_type == "http":
            return await self._discover_http_tools(config)
        if request_type == "sse":
            return await self._discover_sse_tools(config)
        raise CustomMCPError(f"Unsupported request type: {request_type}")

    async def _discover_http_tools(self, config: Dict[str, Any]) -> CustomMCPConnectionResult:
        """List the tools of a streamable-HTTP MCP server at ``config["url"]``."""
        url = self._require_url(config, "HTTP")
        try:
            # The third element yielded by streamablehttp_client is not needed here.
            async with streamablehttp_client(url) as (read_stream, write_stream, _):
                tools_info = await self._list_tools(read_stream, write_stream)
        except Exception as e:
            # Deliberate best-effort boundary: any connect/handshake/teardown
            # failure becomes a success=False result instead of propagating.
            return self._failure_result(config, url, "HTTP", e)
        return self._success_result(config, url, "HTTP", tools_info)

    async def _discover_sse_tools(self, config: Dict[str, Any]) -> CustomMCPConnectionResult:
        """List the tools of an SSE MCP server at ``config["url"]``."""
        url = self._require_url(config, "SSE")
        try:
            async with sse_client(url) as (read_stream, write_stream):
                tools_info = await self._list_tools(read_stream, write_stream)
        except Exception as e:
            # Same best-effort policy as the HTTP transport.
            return self._failure_result(config, url, "SSE", e)
        return self._success_result(config, url, "SSE", tools_info)

    @staticmethod
    def _require_url(config: Dict[str, Any], transport: str) -> str:
        """Return ``config["url"]``, raising CustomMCPError if missing or empty."""
        url = config.get("url")
        if not url:
            raise CustomMCPError(f"URL is required for {transport} MCP connections")
        return url

    @staticmethod
    async def _list_tools(read_stream: Any, write_stream: Any) -> List[Dict[str, Any]]:
        """Open an MCP session on the given streams and describe each tool as a dict."""
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            tool_result = await session.list_tools()
        return [
            {
                "name": tool.name,
                "description": tool.description,
                "inputSchema": tool.inputSchema,
            }
            for tool in tool_result.tools
        ]

    @staticmethod
    def _url_suffix(url: str) -> str:
        """Return the last path segment of *url*, used to build the qualified name."""
        # Strip trailing slashes first: a plain split("/") returns "" for
        # "https://host/mcp/", which made every trailing-slash URL collapse to
        # the same qualified name ("custom_http_" / "custom_sse_").
        # NOTE(review): the segment still includes any query string (which may
        # carry credentials, e.g. "?api_key=..."); consider sanitizing it, but
        # check stored qualified names for compatibility first.
        return url.rstrip("/").split("/")[-1]

    def _success_result(
        self, config: Dict[str, Any], url: str, transport: str, tools_info: List[Dict[str, Any]]
    ) -> CustomMCPConnectionResult:
        """Build the success result for *transport* ("HTTP" or "SSE")."""
        return CustomMCPConnectionResult(
            success=True,
            qualified_name=f"custom_{transport.lower()}_{self._url_suffix(url)}",
            display_name=f"Custom {transport} MCP ({url})",
            tools=tools_info,
            config=config,
            url=url,
            message=f"Connected via {transport} ({len(tools_info)} tools)",
        )

    def _failure_result(
        self, config: Dict[str, Any], url: str, transport: str, error: Exception
    ) -> CustomMCPConnectionResult:
        """Log *error* and build the failure result for *transport*."""
        self._logger.error(f"Error connecting to {transport} MCP server: {str(error)}")
        return CustomMCPConnectionResult(
            success=False,
            qualified_name="",
            display_name="",
            tools=[],
            config=config,
            url=url,
            message=f"Failed to connect: {str(error)}",
        )