suna/backend/agent/tools/mcp_tool_wrapper.py

from typing import Any, Dict, List, Optional
from agentpress.tool import Tool, ToolResult, ToolSchema, SchemaType
from mcp_module import mcp_service
from utils.logger import logger
import inspect
import asyncio
import time
import hashlib
import json
from agent.tools.utils.mcp_connection_manager import MCPConnectionManager
from agent.tools.utils.custom_mcp_handler import CustomMCPHandler
from agent.tools.utils.dynamic_tool_builder import DynamicToolBuilder
from agent.tools.utils.mcp_tool_executor import MCPToolExecutor
from services import redis as redis_service
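

# Redis-backed cache for MCP tool schemas. Each entry is keyed by the "mcp_schema:" prefix
# plus an MD5 hash of the JSON-serialized server config and expires after the configured TTL,
# so repeated agent startups can reuse recently fetched schemas instead of re-querying servers.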
class MCPSchemaRedisCache:
    def __init__(self, ttl_seconds: int = 3600, key_prefix: str = "mcp_schema:"):
        self._ttl = ttl_seconds
        self._key_prefix = key_prefix
        self._redis_client = None

    async def _ensure_redis(self):
        if not self._redis_client:
            try:
                self._redis_client = await redis_service.get_client()
            except Exception as e:
                logger.warning(f"Redis not available for MCP cache: {e}")
                return False
        return True

    def _get_cache_key(self, config: Dict[str, Any]) -> str:
        config_str = json.dumps(config, sort_keys=True)
        config_hash = hashlib.md5(config_str.encode()).hexdigest()
        return f"{self._key_prefix}{config_hash}"

    async def get(self, config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        if not await self._ensure_redis():
            return None
        try:
            key = self._get_cache_key(config)
            cached_data = await self._redis_client.get(key)
            if cached_data:
                logger.debug(f"⚡ Redis cache hit for MCP: {config.get('name', config.get('qualifiedName', 'Unknown'))}")
                return json.loads(cached_data)
            else:
                logger.debug(f"Redis cache miss for MCP: {config.get('name', config.get('qualifiedName', 'Unknown'))}")
                return None
        except Exception as e:
            logger.warning(f"Error reading from Redis cache: {e}")
            return None

    async def set(self, config: Dict[str, Any], data: Dict[str, Any]):
        if not await self._ensure_redis():
            return
        try:
            key = self._get_cache_key(config)
            serialized_data = json.dumps(data)
            await self._redis_client.setex(key, self._ttl, serialized_data)
            logger.debug(f"✅ Cached MCP schema in Redis for {config.get('name', config.get('qualifiedName', 'Unknown'))} (TTL: {self._ttl}s)")
        except Exception as e:
            logger.warning(f"Error writing to Redis cache: {e}")

    async def clear_pattern(self, pattern: Optional[str] = None):
        if not await self._ensure_redis():
            return
        try:
            if pattern:
                search_pattern = f"{self._key_prefix}{pattern}*"
            else:
                search_pattern = f"{self._key_prefix}*"
            keys = []
            async for key in self._redis_client.scan_iter(match=search_pattern):
                keys.append(key)
            if keys:
                await self._redis_client.delete(*keys)
                logger.debug(f"Cleared {len(keys)} MCP schema cache entries from Redis")
        except Exception as e:
            logger.warning(f"Error clearing Redis cache: {e}")

    async def get_stats(self) -> Dict[str, Any]:
        if not await self._ensure_redis():
            return {"available": False}
        try:
            count = 0
            async for _ in self._redis_client.scan_iter(match=f"{self._key_prefix}*"):
                count += 1
            return {
                "available": True,
                "cached_schemas": count,
                "ttl_seconds": self._ttl,
                "key_prefix": self._key_prefix
            }
        except Exception as e:
            logger.warning(f"Error getting cache stats: {e}")
            return {"available": False, "error": str(e)}
_redis_cache = MCPSchemaRedisCache(ttl_seconds=3600)
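

# Tool wrapper that exposes MCP server tools to the agent. On first use it connects the
# configured standard and custom MCP servers (consulting the Redis schema cache), has
# DynamicToolBuilder generate a method per discovered tool, and registers the resulting
# schemas so the agent can call MCP tools like any other Tool method.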
class MCPToolWrapper(Tool):
    def __init__(self, mcp_configs: Optional[List[Dict[str, Any]]] = None, use_cache: bool = True):
        self.mcp_manager = mcp_service
        self.mcp_configs = mcp_configs or []
        self._initialized = False
        self._schemas: Dict[str, List[ToolSchema]] = {}
        self._dynamic_tools = {}
        self._custom_tools = {}
        self.use_cache = use_cache

        self.connection_manager = MCPConnectionManager()
        self.custom_handler = CustomMCPHandler(self.connection_manager)
        self.tool_builder = DynamicToolBuilder()
        self.tool_executor = None

        super().__init__()

    async def _ensure_initialized(self):
        if not self._initialized:
            await self._initialize_servers()
            await self._create_dynamic_tools()
            self._initialized = True
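
    # Splits configs into standard and custom MCPs, serves whatever it can from the Redis
    # schema cache, then initializes the remaining servers concurrently with asyncio.gather
    # and writes their freshly fetched schemas back to the cache.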
    async def _initialize_servers(self):
        start_time = time.time()

        standard_configs = [cfg for cfg in self.mcp_configs if not cfg.get('isCustom', False)]
        custom_configs = [cfg for cfg in self.mcp_configs if cfg.get('isCustom', False)]

        cached_configs = []
        cached_tools_data = []
        initialization_tasks = []

        if standard_configs:
            for config in standard_configs:
                if self.use_cache:
                    cached_data = await _redis_cache.get(config)
                    if cached_data:
                        cached_configs.append(config.get('qualifiedName', 'Unknown'))
                        cached_tools_data.append(cached_data)
                        continue
                task = self._initialize_single_standard_server(config)
                initialization_tasks.append(('standard', config, task))

        if custom_configs:
            for config in custom_configs:
                if self.use_cache:
                    cached_data = await _redis_cache.get(config)
                    if cached_data:
                        cached_configs.append(config.get('name', 'Unknown'))
                        cached_tools_data.append(cached_data)
                        continue
                task = self._initialize_single_custom_mcp(config)
                initialization_tasks.append(('custom', config, task))

        if cached_tools_data:
            logger.debug(f"⚡ Loaded {len(cached_configs)} MCP schemas from Redis cache: {', '.join(cached_configs)}")
            for cached_data in cached_tools_data:
                try:
                    if cached_data.get('type') == 'standard':
                        logger.debug("Standard MCP tools found in cache but require connection to restore")
                    elif cached_data.get('type') == 'custom':
                        custom_tools = cached_data.get('tools', {})
                        if custom_tools:
                            self.custom_handler.custom_tools.update(custom_tools)
                            logger.debug(f"Restored {len(custom_tools)} custom tools from cache")
                except Exception as e:
                    logger.warning(f"Failed to restore cached tools: {e}")

        if initialization_tasks:
            logger.debug(f"🚀 Initializing {len(initialization_tasks)} MCP servers in parallel (cache enabled: {self.use_cache})...")
            tasks = [task for _, _, task in initialization_tasks]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            successful = 0
            failed = 0
            for i, result in enumerate(results):
                task_type, config, _ = initialization_tasks[i]
                if isinstance(result, Exception):
                    failed += 1
                    config_name = config.get('name', config.get('qualifiedName', 'Unknown'))
                    logger.error(f"Failed to initialize MCP server '{config_name}': {result}")
                else:
                    successful += 1
                    if self.use_cache and result:
                        await _redis_cache.set(config, result)

            elapsed_time = time.time() - start_time
            logger.debug(f"⚡ MCP initialization completed in {elapsed_time:.2f}s - {successful} successful, {failed} failed, {len(cached_configs)} from cache")
        else:
            if cached_configs:
                elapsed_time = time.time() - start_time
                logger.debug(f"⚡ All {len(cached_configs)} MCP schemas loaded from Redis cache in {elapsed_time:.2f}s - instant startup!")
            else:
                logger.debug("No MCP servers to initialize")
    async def _initialize_single_standard_server(self, config: Dict[str, Any]):
        try:
            logger.debug(f"Connecting to standard MCP server: {config['qualifiedName']}")
            await self.mcp_manager.connect_server(config)
            logger.debug(f"✓ Connected to MCP server: {config['qualifiedName']}")
            tools_info = self.mcp_manager.get_all_tools_openapi()
            return {'tools': tools_info, 'type': 'standard', 'timestamp': time.time()}
        except Exception as e:
            logger.error(f"✗ Failed to connect to MCP server {config['qualifiedName']}: {e}")
            raise e

    async def _initialize_single_custom_mcp(self, config: Dict[str, Any]):
        try:
            logger.debug(f"Initializing custom MCP: {config.get('name', 'Unknown')}")
            await self.custom_handler._initialize_single_custom_mcp(config)
            logger.debug(f"✓ Initialized custom MCP: {config.get('name', 'Unknown')}")
            custom_tools = self.custom_handler.get_custom_tools()
            return {'tools': custom_tools, 'type': 'custom', 'timestamp': time.time()}
        except Exception as e:
            logger.error(f"✗ Failed to initialize custom MCP {config.get('name', 'Unknown')}: {e}")
            raise e

    async def _initialize_standard_servers(self, standard_configs: List[Dict[str, Any]]):
        pass
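
    # Collects tool schemas from the standard MCP manager and the custom handler, asks
    # DynamicToolBuilder to generate a bound method per tool, attaches those methods to this
    # instance, and re-registers schemas so they are visible to get_schemas().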
    async def _create_dynamic_tools(self):
        try:
            available_tools = self.mcp_manager.get_all_tools_openapi()
            custom_tools = self.custom_handler.get_custom_tools()
            logger.debug(f"MCPManager returned {len(available_tools)} tools")
            logger.debug(f"Custom handler returned {len(custom_tools)} custom tools")

            self._custom_tools = custom_tools
            self.tool_executor = MCPToolExecutor(custom_tools, self)

            dynamic_methods = self.tool_builder.create_dynamic_methods(
                available_tools,
                custom_tools,
                self._execute_mcp_tool
            )
            self._dynamic_tools = self.tool_builder.get_dynamic_tools()

            for method_name, method in dynamic_methods.items():
                setattr(self, method_name, method)

            self._schemas.update(self.tool_builder.get_schemas())
            logger.debug(f"Created {len(self._dynamic_tools)} dynamic MCP tool methods")

            self._register_schemas()
            logger.debug(f"Re-registered schemas after creating dynamic tools - total: {len(self._schemas)}")
        except Exception as e:
            logger.error(f"Error creating dynamic MCP tools: {e}")
    def _register_schemas(self):
        self._schemas.clear()

        for name, method in inspect.getmembers(self, predicate=inspect.ismethod):
            if hasattr(method, 'tool_schemas'):
                self._schemas[name] = method.tool_schemas
                logger.debug(f"Registered schemas for method '{name}' in {self.__class__.__name__}")

        if hasattr(self, '_dynamic_tools') and self._dynamic_tools:
            for tool_name, tool_data in self._dynamic_tools.items():
                method_name = tool_data.get('method_name')
                if method_name and method_name in self._schemas:
                    continue
                method = tool_data.get('method')
                if method and hasattr(method, 'tool_schemas'):
                    self._schemas[method_name] = method.tool_schemas
                    logger.debug(f"Registered dynamic method schemas for '{method_name}'")

        logger.debug(f"Registration complete for MCPToolWrapper - total schemas: {len(self._schemas)}")

    def get_schemas(self) -> Dict[str, List[ToolSchema]]:
        logger.debug(f"get_schemas called - returning {len(self._schemas)} schemas")
        for method_name in self._schemas:
            logger.debug(f"  - Schema available for: {method_name}")
        return self._schemas
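
    # Attribute fallback so dynamically generated MCP tool methods (including ones whose
    # original tool names use hyphens) resolve even though they are not defined on the class.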
    def __getattr__(self, name: str):
        # Use __dict__ lookups rather than hasattr() so that attribute access on a partially
        # constructed instance cannot re-enter __getattr__ and recurse indefinitely.
        tool_builder = self.__dict__.get('tool_builder')
        if tool_builder:
            method = tool_builder.find_method_by_name(name)
            if method:
                return method

        dynamic_tools = self.__dict__.get('_dynamic_tools')
        if dynamic_tools:
            for tool_data in dynamic_tools.values():
                if tool_data.get('method_name') == name:
                    return tool_data.get('method')

            name_with_hyphens = name.replace('_', '-')
            for tool_name, tool_data in dynamic_tools.items():
                if tool_data.get('method_name') == name or tool_name == name_with_hyphens:
                    return tool_data.get('method')

        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
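
    # Public entry points: on-demand initialization, schema listing, tool execution
    # (delegated to MCPToolExecutor), and teardown of all MCP server connections.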
    async def initialize_and_register_tools(self, tool_registry=None):
        await self._ensure_initialized()
        if tool_registry and self._dynamic_tools:
            logger.debug(f"Updating tool registry with {len(self._dynamic_tools)} MCP tools")

    async def get_available_tools(self) -> List[Dict[str, Any]]:
        await self._ensure_initialized()
        return self.mcp_manager.get_all_tools_openapi()

    async def _execute_mcp_tool(self, tool_name: str, arguments: Dict[str, Any]) -> ToolResult:
        await self._ensure_initialized()
        return await self.tool_executor.execute_tool(tool_name, arguments)

    async def cleanup(self):
        if self._initialized:
            try:
                await self.mcp_manager.disconnect_all()
            except Exception as e:
                logger.error(f"Error during MCP cleanup: {str(e)}")
            finally:
                self._initialized = False