mirror of https://github.com/kortix-ai/suna.git
337 lines
14 KiB
Python
337 lines
14 KiB
Python
from typing import Any, Dict, List, Optional
|
|
from agentpress.tool import Tool, ToolResult, ToolSchema, SchemaType
|
|
from mcp_module import mcp_service
|
|
from utils.logger import logger
|
|
import inspect
|
|
import asyncio
|
|
import time
|
|
import hashlib
|
|
import json
|
|
from agent.tools.utils.mcp_connection_manager import MCPConnectionManager
|
|
from agent.tools.utils.custom_mcp_handler import CustomMCPHandler
|
|
from agent.tools.utils.dynamic_tool_builder import DynamicToolBuilder
|
|
from agent.tools.utils.mcp_tool_executor import MCPToolExecutor
|
|
from services import redis as redis_service
|
|
|
|
|
|
class MCPSchemaRedisCache:
|
|
def __init__(self, ttl_seconds: int = 3600, key_prefix: str = "mcp_schema:"):
|
|
self._ttl = ttl_seconds
|
|
self._key_prefix = key_prefix
|
|
self._redis_client = None
|
|
|
|
async def _ensure_redis(self):
|
|
if not self._redis_client:
|
|
try:
|
|
self._redis_client = await redis_service.get_client()
|
|
except Exception as e:
|
|
logger.warning(f"Redis not available for MCP cache: {e}")
|
|
return False
|
|
return True
|
|
|
|
def _get_cache_key(self, config: Dict[str, Any]) -> str:
|
|
config_str = json.dumps(config, sort_keys=True)
|
|
config_hash = hashlib.md5(config_str.encode()).hexdigest()
|
|
return f"{self._key_prefix}{config_hash}"
|
|
|
|
async def get(self, config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
|
if not await self._ensure_redis():
|
|
return None
|
|
|
|
try:
|
|
key = self._get_cache_key(config)
|
|
cached_data = await self._redis_client.get(key)
|
|
|
|
if cached_data:
|
|
logger.debug(f"⚡ Redis cache hit for MCP: {config.get('name', config.get('qualifiedName', 'Unknown'))}")
|
|
return json.loads(cached_data)
|
|
else:
|
|
logger.debug(f"Redis cache miss for MCP: {config.get('name', config.get('qualifiedName', 'Unknown'))}")
|
|
return None
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Error reading from Redis cache: {e}")
|
|
return None
|
|
|
|
async def set(self, config: Dict[str, Any], data: Dict[str, Any]):
|
|
if not await self._ensure_redis():
|
|
return
|
|
|
|
try:
|
|
key = self._get_cache_key(config)
|
|
serialized_data = json.dumps(data)
|
|
|
|
await self._redis_client.setex(key, self._ttl, serialized_data)
|
|
logger.debug(f"✅ Cached MCP schema in Redis for {config.get('name', config.get('qualifiedName', 'Unknown'))} (TTL: {self._ttl}s)")
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Error writing to Redis cache: {e}")
|
|
|
|
async def clear_pattern(self, pattern: Optional[str] = None):
|
|
if not await self._ensure_redis():
|
|
return
|
|
try:
|
|
if pattern:
|
|
search_pattern = f"{self._key_prefix}{pattern}*"
|
|
else:
|
|
search_pattern = f"{self._key_prefix}*"
|
|
|
|
keys = []
|
|
async for key in self._redis_client.scan_iter(match=search_pattern):
|
|
keys.append(key)
|
|
|
|
if keys:
|
|
await self._redis_client.delete(*keys)
|
|
logger.debug(f"Cleared {len(keys)} MCP schema cache entries from Redis")
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Error clearing Redis cache: {e}")
|
|
|
|
async def get_stats(self) -> Dict[str, Any]:
|
|
if not await self._ensure_redis():
|
|
return {"available": False}
|
|
try:
|
|
count = 0
|
|
async for _ in self._redis_client.scan_iter(match=f"{self._key_prefix}*"):
|
|
count += 1
|
|
|
|
return {
|
|
"available": True,
|
|
"cached_schemas": count,
|
|
"ttl_seconds": self._ttl,
|
|
"key_prefix": self._key_prefix
|
|
}
|
|
except Exception as e:
|
|
logger.warning(f"Error getting cache stats: {e}")
|
|
return {"available": False, "error": str(e)}
|
|
|
|
|
|
_redis_cache = MCPSchemaRedisCache(ttl_seconds=3600)
|
|
|
|
class MCPToolWrapper(Tool):
|
|
def __init__(self, mcp_configs: Optional[List[Dict[str, Any]]] = None, use_cache: bool = True):
|
|
self.mcp_manager = mcp_service
|
|
self.mcp_configs = mcp_configs or []
|
|
self._initialized = False
|
|
self._schemas: Dict[str, List[ToolSchema]] = {}
|
|
self._dynamic_tools = {}
|
|
self._custom_tools = {}
|
|
self.use_cache = use_cache
|
|
|
|
self.connection_manager = MCPConnectionManager()
|
|
self.custom_handler = CustomMCPHandler(self.connection_manager)
|
|
self.tool_builder = DynamicToolBuilder()
|
|
self.tool_executor = None
|
|
|
|
super().__init__()
|
|
|
|
async def _ensure_initialized(self):
|
|
if not self._initialized:
|
|
await self._initialize_servers()
|
|
await self._create_dynamic_tools()
|
|
self._initialized = True
|
|
|
|
async def _initialize_servers(self):
|
|
start_time = time.time()
|
|
|
|
standard_configs = [cfg for cfg in self.mcp_configs if not cfg.get('isCustom', False)]
|
|
custom_configs = [cfg for cfg in self.mcp_configs if cfg.get('isCustom', False)]
|
|
|
|
cached_configs = []
|
|
cached_tools_data = []
|
|
|
|
initialization_tasks = []
|
|
|
|
if standard_configs:
|
|
for config in standard_configs:
|
|
if self.use_cache:
|
|
cached_data = await _redis_cache.get(config)
|
|
if cached_data:
|
|
cached_configs.append(config.get('qualifiedName', 'Unknown'))
|
|
cached_tools_data.append(cached_data)
|
|
continue
|
|
|
|
task = self._initialize_single_standard_server(config)
|
|
initialization_tasks.append(('standard', config, task))
|
|
|
|
if custom_configs:
|
|
for config in custom_configs:
|
|
if self.use_cache:
|
|
cached_data = await _redis_cache.get(config)
|
|
if cached_data:
|
|
cached_configs.append(config.get('name', 'Unknown'))
|
|
cached_tools_data.append(cached_data)
|
|
continue
|
|
|
|
task = self._initialize_single_custom_mcp(config)
|
|
initialization_tasks.append(('custom', config, task))
|
|
|
|
if cached_tools_data:
|
|
logger.debug(f"⚡ Loaded {len(cached_configs)} MCP schemas from Redis cache: {', '.join(cached_configs)}")
|
|
for cached_data in cached_tools_data:
|
|
try:
|
|
if cached_data.get('type') == 'standard':
|
|
logger.debug("Standard MCP tools found in cache but require connection to restore")
|
|
elif cached_data.get('type') == 'custom':
|
|
custom_tools = cached_data.get('tools', {})
|
|
if custom_tools:
|
|
self.custom_handler.custom_tools.update(custom_tools)
|
|
logger.debug(f"Restored {len(custom_tools)} custom tools from cache")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to restore cached tools: {e}")
|
|
|
|
if initialization_tasks:
|
|
logger.debug(f"🚀 Initializing {len(initialization_tasks)} MCP servers in parallel (cache enabled: {self.use_cache})...")
|
|
|
|
tasks = [task for _, _, task in initialization_tasks]
|
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
successful = 0
|
|
failed = 0
|
|
|
|
for i, result in enumerate(results):
|
|
task_type, config, _ = initialization_tasks[i]
|
|
if isinstance(result, Exception):
|
|
failed += 1
|
|
config_name = config.get('name', config.get('qualifiedName', 'Unknown'))
|
|
logger.error(f"Failed to initialize MCP server '{config_name}': {result}")
|
|
else:
|
|
successful += 1
|
|
if self.use_cache and result:
|
|
await _redis_cache.set(config, result)
|
|
|
|
elapsed_time = time.time() - start_time
|
|
logger.debug(f"⚡ MCP initialization completed in {elapsed_time:.2f}s - {successful} successful, {failed} failed, {len(cached_configs)} from cache")
|
|
else:
|
|
if cached_configs:
|
|
elapsed_time = time.time() - start_time
|
|
logger.debug(f"⚡ All {len(cached_configs)} MCP schemas loaded from Redis cache in {elapsed_time:.2f}s - instant startup!")
|
|
else:
|
|
logger.debug("No MCP servers to initialize")
|
|
|
|
async def _initialize_single_standard_server(self, config: Dict[str, Any]):
|
|
try:
|
|
logger.debug(f"Connecting to standard MCP server: {config['qualifiedName']}")
|
|
await self.mcp_manager.connect_server(config)
|
|
logger.debug(f"✓ Connected to MCP server: {config['qualifiedName']}")
|
|
|
|
tools_info = self.mcp_manager.get_all_tools_openapi()
|
|
return {'tools': tools_info, 'type': 'standard', 'timestamp': time.time()}
|
|
except Exception as e:
|
|
logger.error(f"✗ Failed to connect to MCP server {config['qualifiedName']}: {e}")
|
|
raise e
|
|
|
|
async def _initialize_single_custom_mcp(self, config: Dict[str, Any]):
|
|
try:
|
|
logger.debug(f"Initializing custom MCP: {config.get('name', 'Unknown')}")
|
|
await self.custom_handler._initialize_single_custom_mcp(config)
|
|
logger.debug(f"✓ Initialized custom MCP: {config.get('name', 'Unknown')}")
|
|
|
|
custom_tools = self.custom_handler.get_custom_tools()
|
|
return {'tools': custom_tools, 'type': 'custom', 'timestamp': time.time()}
|
|
except Exception as e:
|
|
logger.error(f"✗ Failed to initialize custom MCP {config.get('name', 'Unknown')}: {e}")
|
|
raise e
|
|
|
|
async def _initialize_standard_servers(self, standard_configs: List[Dict[str, Any]]):
|
|
pass
|
|
|
|
async def _create_dynamic_tools(self):
|
|
try:
|
|
available_tools = self.mcp_manager.get_all_tools_openapi()
|
|
custom_tools = self.custom_handler.get_custom_tools()
|
|
|
|
logger.debug(f"MCPManager returned {len(available_tools)} tools")
|
|
logger.debug(f"Custom handler returned {len(custom_tools)} custom tools")
|
|
|
|
self._custom_tools = custom_tools
|
|
|
|
self.tool_executor = MCPToolExecutor(custom_tools, self)
|
|
|
|
dynamic_methods = self.tool_builder.create_dynamic_methods(
|
|
available_tools,
|
|
custom_tools,
|
|
self._execute_mcp_tool
|
|
)
|
|
|
|
self._dynamic_tools = self.tool_builder.get_dynamic_tools()
|
|
|
|
for method_name, method in dynamic_methods.items():
|
|
setattr(self, method_name, method)
|
|
|
|
self._schemas.update(self.tool_builder.get_schemas())
|
|
|
|
logger.debug(f"Created {len(self._dynamic_tools)} dynamic MCP tool methods")
|
|
|
|
self._register_schemas()
|
|
logger.debug(f"Re-registered schemas after creating dynamic tools - total: {len(self._schemas)}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error creating dynamic MCP tools: {e}")
|
|
|
|
def _register_schemas(self):
|
|
self._schemas.clear()
|
|
|
|
for name, method in inspect.getmembers(self, predicate=inspect.ismethod):
|
|
if hasattr(method, 'tool_schemas'):
|
|
self._schemas[name] = method.tool_schemas
|
|
logger.debug(f"Registered schemas for method '{name}' in {self.__class__.__name__}")
|
|
|
|
if hasattr(self, '_dynamic_tools') and self._dynamic_tools:
|
|
for tool_name, tool_data in self._dynamic_tools.items():
|
|
method_name = tool_data.get('method_name')
|
|
if method_name and method_name in self._schemas:
|
|
continue
|
|
|
|
method = tool_data.get('method')
|
|
if method and hasattr(method, 'tool_schemas'):
|
|
self._schemas[method_name] = method.tool_schemas
|
|
logger.debug(f"Registered dynamic method schemas for '{method_name}'")
|
|
|
|
logger.debug(f"Registration complete for MCPToolWrapper - total schemas: {len(self._schemas)}")
|
|
|
|
def get_schemas(self) -> Dict[str, List[ToolSchema]]:
|
|
logger.debug(f"get_schemas called - returning {len(self._schemas)} schemas")
|
|
for method_name in self._schemas:
|
|
logger.debug(f" - Schema available for: {method_name}")
|
|
return self._schemas
|
|
|
|
def __getattr__(self, name: str):
|
|
if hasattr(self, 'tool_builder') and self.tool_builder:
|
|
method = self.tool_builder.find_method_by_name(name)
|
|
if method:
|
|
return method
|
|
|
|
if hasattr(self, '_dynamic_tools') and self._dynamic_tools:
|
|
for tool_data in self._dynamic_tools.values():
|
|
if tool_data.get('method_name') == name:
|
|
return tool_data.get('method')
|
|
|
|
name_with_hyphens = name.replace('_', '-')
|
|
for tool_name, tool_data in self._dynamic_tools.items():
|
|
if tool_data.get('method_name') == name or tool_name == name_with_hyphens:
|
|
return tool_data.get('method')
|
|
|
|
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
|
|
|
async def initialize_and_register_tools(self, tool_registry=None):
|
|
await self._ensure_initialized()
|
|
if tool_registry and self._dynamic_tools:
|
|
logger.debug(f"Updating tool registry with {len(self._dynamic_tools)} MCP tools")
|
|
|
|
async def get_available_tools(self) -> List[Dict[str, Any]]:
|
|
await self._ensure_initialized()
|
|
return self.mcp_manager.get_all_tools_openapi()
|
|
|
|
async def _execute_mcp_tool(self, tool_name: str, arguments: Dict[str, Any]) -> ToolResult:
|
|
await self._ensure_initialized()
|
|
return await self.tool_executor.execute_tool(tool_name, arguments)
|
|
|
|
async def cleanup(self):
|
|
if self._initialized:
|
|
try:
|
|
await self.mcp_manager.disconnect_all()
|
|
except Exception as e:
|
|
logger.error(f"Error during MCP cleanup: {str(e)}")
|
|
finally:
|
|
self._initialized = False |