import asyncio
from typing import Any, Callable, Dict, List

from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client

from ..domain.entities import CustomMCPConnectionResult, ToolInfo
from ..domain.exceptions import CustomMCPError
from ..protocols import Logger


class CustomMCPDiscovery:
    """Connect to a user-supplied MCP server and report the tools it exposes.

    Two transports are supported: streamable HTTP (``"http"``) and
    server-sent events (``"sse"``). Connection/listing failures are logged
    and returned as a ``CustomMCPConnectionResult`` with ``success=False``
    instead of being raised; only configuration errors raise.
    """

    def __init__(self, logger: Logger):
        self._logger = logger

    async def discover_tools(self, request_type: str, config: Dict[str, Any]) -> CustomMCPConnectionResult:
        """Discover the tools of an MCP server using the given transport.

        Args:
            request_type: ``"http"`` or ``"sse"``.
            config: Connection settings; must contain a non-empty ``"url"``.
                It is passed through unchanged into the result.

        Returns:
            A ``CustomMCPConnectionResult`` describing success or failure.

        Raises:
            CustomMCPError: If ``request_type`` is unsupported or ``config``
                has no usable ``"url"``.
        """
        if request_type == "http":
            return await self._discover_http_tools(config)
        if request_type == "sse":
            return await self._discover_sse_tools(config)
        raise CustomMCPError(f"Unsupported request type: {request_type}")

    async def _discover_http_tools(self, config: Dict[str, Any]) -> CustomMCPConnectionResult:
        """Discover tools over the streamable-HTTP transport."""
        url = config.get("url")
        if not url:
            raise CustomMCPError("URL is required for HTTP MCP connections")
        return await self._discover_with_transport("HTTP", streamablehttp_client, url, config)

    async def _discover_sse_tools(self, config: Dict[str, Any]) -> CustomMCPConnectionResult:
        """Discover tools over the SSE transport."""
        url = config.get("url")
        if not url:
            raise CustomMCPError("URL is required for SSE MCP connections")
        return await self._discover_with_transport("SSE", sse_client, url, config)

    async def _discover_with_transport(
        self,
        transport: str,
        open_streams: Callable[[str], Any],
        url: str,
        config: Dict[str, Any],
    ) -> CustomMCPConnectionResult:
        """Open ``open_streams(url)``, run an MCP session, and list its tools.

        Args:
            transport: Human-readable transport label (``"HTTP"``/``"SSE"``);
                its lowercase form is used in ``qualified_name``.
            open_streams: Transport client factory returning an async context
                manager that yields a tuple whose first two items are the
                read and write streams.
            url: Server URL.
            config: Original config, echoed back in the result.

        Returns:
            A success result listing the server's tools, or a failure result
            carrying the error text. Any ``Exception`` raised while connecting,
            initializing, listing, or tearing down is caught and reported.
        """
        try:
            async with open_streams(url) as streams:
                # streamablehttp_client yields three values and sse_client two;
                # only the read/write streams are needed here.
                read_stream, write_stream = streams[0], streams[1]
                async with ClientSession(read_stream, write_stream) as session:
                    await session.initialize()
                    tool_result = await session.list_tools()

            tools_info = [
                {
                    "name": tool.name,
                    "description": tool.description,
                    "inputSchema": tool.inputSchema,
                }
                for tool in tool_result.tools
            ]
            return CustomMCPConnectionResult(
                success=True,
                qualified_name=f"custom_{transport.lower()}_{self._url_slug(url)}",
                display_name=f"Custom {transport} MCP ({url})",
                tools=tools_info,
                config=config,
                url=url,
                message=f"Connected via {transport} ({len(tools_info)} tools)",
            )
        except Exception as e:
            reason = self._describe_error(e)
            self._logger.error(f"Error connecting to {transport} MCP server: {reason}")
            return CustomMCPConnectionResult(
                success=False,
                qualified_name="",
                display_name="",
                tools=[],
                config=config,
                url=url,
                message=f"Failed to connect: {reason}",
            )

    @staticmethod
    def _url_slug(url: str) -> str:
        """Return the last path segment of ``url``, ignoring trailing slashes.

        Without stripping, ``"https://host/mcp/"`` would yield an empty slug.
        """
        return url.rstrip("/").rsplit("/", 1)[-1]

    @staticmethod
    def _describe_error(exc: BaseException) -> str:
        """Return a readable message for ``exc``.

        Single-member exception groups are unwrapped so the underlying cause
        is reported (the transports may surface failures wrapped this way —
        NOTE(review): confirm against the mcp/anyio version in use). Falls
        back to the exception type name when the message is empty.
        """
        while True:
            inner = getattr(exc, "exceptions", None)
            if isinstance(inner, tuple) and len(inner) == 1 and isinstance(inner[0], BaseException):
                exc = inner[0]
            else:
                break
        return str(exc) or type(exc).__name__