mirror of https://github.com/kortix-ai/suna.git
104 lines
4.2 KiB
Python
104 lines
4.2 KiB
Python
import asyncio
|
|
from typing import Dict, Any, List
|
|
from mcp import ClientSession
|
|
from mcp.client.sse import sse_client
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|
|
|
from ..domain.entities import CustomMCPConnectionResult, ToolInfo
|
|
from ..domain.exceptions import CustomMCPError
|
|
from ..protocols import Logger
|
|
|
|
|
|
class CustomMCPDiscovery:
    """Connect to a user-supplied MCP server and list the tools it exposes.

    Two transports are supported, selected by ``request_type``:
    ``"http"`` (via ``streamablehttp_client``) and ``"sse"`` (via
    ``sse_client``). Failures while connecting or listing tools are logged
    and reported as an unsuccessful ``CustomMCPConnectionResult`` rather
    than raised.
    """

    def __init__(self, logger: Logger):
        """Create a discoverer that reports connection errors to ``logger``."""
        self._logger = logger

    async def discover_tools(self, request_type: str, config: Dict[str, Any]) -> CustomMCPConnectionResult:
        """Discover the tools of a custom MCP server.

        Args:
            request_type: Transport to use, either ``"http"`` or ``"sse"``.
            config: Connection settings; must contain a non-empty ``"url"``.

        Returns:
            A ``CustomMCPConnectionResult``. ``success`` is False (with the
            error text in ``message``) when the server could not be reached
            or its tools could not be listed.

        Raises:
            CustomMCPError: If ``request_type`` is not supported or
                ``config`` has no ``"url"``.
        """
        if request_type == "http":
            return await self._discover_http_tools(config)
        if request_type == "sse":
            return await self._discover_sse_tools(config)
        raise CustomMCPError(f"Unsupported request type: {request_type}")

    async def _discover_http_tools(self, config: Dict[str, Any]) -> CustomMCPConnectionResult:
        """Discover tools over the streamable-HTTP transport."""
        return await self._discover_via_transport(config, streamablehttp_client, "http", "HTTP")

    async def _discover_sse_tools(self, config: Dict[str, Any]) -> CustomMCPConnectionResult:
        """Discover tools over the SSE transport."""
        return await self._discover_via_transport(config, sse_client, "sse", "SSE")

    async def _discover_via_transport(
        self,
        config: Dict[str, Any],
        client_factory: Any,
        name_prefix: str,
        label: str,
    ) -> CustomMCPConnectionResult:
        """Open ``client_factory(url)``, initialise an MCP session and list its tools.

        Args:
            config: Connection settings; ``config["url"]`` is the server URL.
            client_factory: ``streamablehttp_client`` or ``sse_client``.
            name_prefix: Lower-case transport tag used in ``qualified_name``.
            label: Upper-case transport name used in user-facing messages.

        Returns:
            A successful result listing the server's tools, or an
            unsuccessful result if anything fails after validation.

        Raises:
            CustomMCPError: If ``config`` has no ``"url"``.
        """
        url = config.get("url")
        if not url:
            raise CustomMCPError(f"URL is required for {label} MCP connections")

        try:
            async with client_factory(url) as streams:
                # The HTTP client yields three values (read, write, plus one
                # unused extra) while the SSE client yields (read, write);
                # only the two streams are needed here.
                read_stream, write_stream = streams[0], streams[1]
                async with ClientSession(read_stream, write_stream) as session:
                    await session.initialize()
                    tool_result = await session.list_tools()

                    tools_info = [
                        {
                            "name": tool.name,
                            "description": tool.description,
                            "inputSchema": tool.inputSchema,
                        }
                        for tool in tool_result.tools
                    ]

                    return CustomMCPConnectionResult(
                        success=True,
                        qualified_name=f"custom_{name_prefix}_{self._url_slug(url)}",
                        display_name=f"Custom {label} MCP ({url})",
                        tools=tools_info,
                        config=config,
                        url=url,
                        message=f"Connected via {label} ({len(tools_info)} tools)",
                    )
        except Exception as e:
            # Deliberate boundary: any transport/protocol failure is reported
            # to the caller as an unsuccessful result instead of propagating.
            self._logger.error(f"Error connecting to {label} MCP server: {e}")
            return CustomMCPConnectionResult(
                success=False,
                qualified_name="",
                display_name="",
                tools=[],
                config=config,
                url=url,
                message=f"Failed to connect: {e}",
            )

    @staticmethod
    def _url_slug(url: str) -> str:
        """Return the last path segment of ``url``, ignoring trailing slashes.

        Stripping trailing slashes first means ``"https://host/mcp/"`` yields
        ``"mcp"`` rather than an empty string, so distinct servers whose URLs
        end in ``/`` do not all receive the same qualified name.
        """
        return url.rstrip("/").split("/")[-1]