import asyncio
from typing import Dict, Any, List
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from utils.logger import logger


class MCPConnectionManager:
    """Probe MCP servers over SSE, streamable HTTP, or stdio and record their tools.

    Each ``connect_*`` method opens a transport, performs the MCP handshake,
    lists the server's tools, closes the transport again, and caches the
    result in ``connected_servers`` keyed by server name. Note that the
    stored ``"status": "connected"`` records a successful probe; no session
    is kept open after the method returns.
    """

    def __init__(self):
        # server name -> info dict produced by one of the connect_* methods
        self.connected_servers: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _open_url_transport(factory, url: str, headers, server_name: str, label: str):
        """Create a URL-based transport context manager, passing ``headers`` if any.

        Some ``mcp`` client versions may not accept a ``headers`` keyword.
        Assuming the factory binds its arguments when called (true for
        ``@asynccontextmanager`` functions), that mismatch surfaces as a
        ``TypeError`` from this call alone, so only this call is guarded.
        Errors raised later, during the handshake, are never mistaken for it.
        """
        if not headers:
            return factory(url)
        try:
            return factory(url, headers=headers)
        except TypeError as e:
            if "unexpected keyword argument" not in str(e):
                raise
            # Don't drop credentials silently: make the downgrade visible.
            logger.warning(
                f"{label} client does not accept custom headers; "
                f"connecting to {server_name} without them"
            )
            return factory(url)

    @staticmethod
    async def _fetch_tools(read, write) -> List[Dict[str, Any]]:
        """Run the MCP handshake on an open stream pair and describe each tool."""
        async with ClientSession(read, write) as session:
            await session.initialize()
            tools_result = await session.list_tools()
            return [
                {
                    "name": tool.name,
                    "description": tool.description,
                    "input_schema": tool.inputSchema,
                }
                for tool in tools_result.tools
            ]

    def _register(self, server_name: str, label: str, server_info: Dict[str, Any]) -> Dict[str, Any]:
        """Store ``server_info`` under ``server_name``, log the tool count, and return it."""
        self.connected_servers[server_name] = server_info
        logger.info(f"Connected to {server_name} via {label} ({len(server_info['tools'])} tools)")
        return server_info

    async def connect_sse_server(self, server_name: str, server_config: Dict[str, Any], timeout: int = 15) -> Dict[str, Any]:
        """Probe an MCP server over SSE and record its tools.

        Args:
            server_name: Key under which the result is stored.
            server_config: Must contain ``"url"``; may contain ``"headers"``.
            timeout: Seconds allowed for connecting, handshaking and listing tools.

        Returns:
            The info dict stored in ``connected_servers``.

        Raises:
            KeyError: If ``"url"`` is missing from ``server_config``.
            TimeoutError: If the probe does not finish within ``timeout``.
        """
        url = server_config["url"]
        headers = server_config.get("headers", {})
        async with asyncio.timeout(timeout):
            transport = self._open_url_transport(sse_client, url, headers, server_name, "SSE")
            async with transport as (read, write):
                tools_info = await self._fetch_tools(read, write)
        # Registered only after the transport closed cleanly.
        return self._register(
            server_name,
            "SSE",
            {"status": "connected", "transport": "sse", "url": url, "tools": tools_info},
        )

    async def connect_http_server(self, server_name: str, server_config: Dict[str, Any], timeout: int = 15) -> Dict[str, Any]:
        """Probe an MCP server over streamable HTTP and record its tools.

        Args:
            server_name: Key under which the result is stored.
            server_config: Must contain ``"url"``; may contain ``"headers"``
                (sent the same way as for SSE servers).
            timeout: Seconds allowed for connecting, handshaking and listing tools.

        Returns:
            The info dict stored in ``connected_servers``.

        Raises:
            KeyError: If ``"url"`` is missing from ``server_config``.
            TimeoutError: If the probe does not finish within ``timeout``.
        """
        url = server_config["url"]
        headers = server_config.get("headers")
        async with asyncio.timeout(timeout):
            transport = self._open_url_transport(streamablehttp_client, url, headers, server_name, "HTTP")
            # The third element (a session-id getter) is not needed for a probe.
            async with transport as (read_stream, write_stream, _):
                tools_info = await self._fetch_tools(read_stream, write_stream)
        return self._register(
            server_name,
            "HTTP",
            {"status": "connected", "transport": "http", "url": url, "tools": tools_info},
        )

    async def connect_stdio_server(self, server_name: str, server_config: Dict[str, Any], timeout: int = 15) -> Dict[str, Any]:
        """Launch an MCP server as a subprocess over stdio and record its tools.

        Args:
            server_name: Key under which the result is stored.
            server_config: Must contain ``"command"``; may contain ``"args"``
                and ``"env"``.
            timeout: Seconds allowed for launching, handshaking and listing tools.

        Returns:
            The info dict stored in ``connected_servers`` (no ``"url"`` key).

        Raises:
            KeyError: If ``"command"`` is missing from ``server_config``.
            TimeoutError: If the probe does not finish within ``timeout``.
        """
        server_params = StdioServerParameters(
            command=server_config["command"],
            args=server_config.get("args", []),
            # None lets stdio_client use its default environment; an empty
            # dict was passed straight to the subprocess by older mcp
            # releases, leaving it without PATH etc.
            env=server_config.get("env") or None,
        )
        async with asyncio.timeout(timeout):
            async with stdio_client(server_params) as (read, write):
                tools_info = await self._fetch_tools(read, write)
        return self._register(
            server_name,
            "stdio",
            {"status": "connected", "transport": "stdio", "tools": tools_info},
        )

    def get_server_info(self, server_name: str) -> Dict[str, Any]:
        """Return the stored info for ``server_name``, or ``{}`` if it was never probed."""
        return self.connected_servers.get(server_name, {})

    def get_all_servers(self) -> Dict[str, Dict[str, Any]]:
        """Return a shallow copy of the name -> info mapping."""
        return self.connected_servers.copy()