suna/backend/mcp_module/services/connection_service.py

92 lines
3.9 KiB
Python

from typing import Dict, List, Optional, Any
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from ..domain.entities import MCPConnection, MCPConnectionRequest
from ..domain.exceptions import MCPConnectionError, MCPProviderError
from ..support.providers import MCPProviderFactory
from ..protocols import Logger
class ConnectionService:
def __init__(self, provider_factory: MCPProviderFactory, logger: Logger):
self._provider_factory = provider_factory
self._logger = logger
self._connections: Dict[str, MCPConnection] = {}
async def connect_server(self, request: MCPConnectionRequest) -> MCPConnection:
self._logger.info(f"Connecting to MCP server: {request.qualified_name}")
try:
provider = self._provider_factory.create_provider(request.provider)
server_url = provider.get_server_url(request.qualified_name, request.config)
headers = provider.get_headers(
request.qualified_name,
request.config,
request.external_user_id
)
async with streamablehttp_client(server_url, headers=headers) as (
read_stream, write_stream, _
):
session = ClientSession(read_stream, write_stream)
await session.initialize()
tool_result = await session.list_tools()
tools = tool_result.tools if tool_result else []
connection = MCPConnection(
qualified_name=request.qualified_name,
name=request.name,
config=request.config,
enabled_tools=request.enabled_tools,
provider=request.provider,
external_user_id=request.external_user_id,
session=session,
tools=tools
)
self._connections[request.qualified_name] = connection
self._logger.info(f"Connected to {request.qualified_name} ({len(tools)} tools available)")
return connection
except Exception as e:
self._logger.error(f"Failed to connect to {request.qualified_name}: {str(e)}")
raise MCPConnectionError(f"Failed to connect to MCP server: {str(e)}")
async def disconnect_server(self, qualified_name: str) -> None:
connection = self._connections.get(qualified_name)
if connection and connection.session:
try:
await connection.session.close()
self._logger.info(f"Disconnected from {qualified_name}")
except Exception as e:
self._logger.warning(f"Error disconnecting from {qualified_name}: {str(e)}")
self._connections.pop(qualified_name, None)
async def disconnect_all(self) -> None:
for qualified_name in list(self._connections.keys()):
await self.disconnect_server(qualified_name)
self._connections.clear()
self._logger.info("Disconnected from all MCP servers")
def get_connection(self, qualified_name: str) -> Optional[MCPConnection]:
return self._connections.get(qualified_name)
def get_all_connections(self) -> List[MCPConnection]:
return list(self._connections.values())
def is_connected(self, qualified_name: str) -> bool:
connection = self._connections.get(qualified_name)
return connection is not None and connection.session is not None
async def connect_all(self, requests: List[MCPConnectionRequest]) -> None:
for request in requests:
try:
await self.connect_server(request)
except MCPConnectionError as e:
self._logger.error(f"Failed to connect to {request.qualified_name}: {str(e)}")
continue