diff --git a/backend/api.py b/backend/api.py index faacbcb5..5617530c 100644 --- a/backend/api.py +++ b/backend/api.py @@ -12,34 +12,22 @@ import asyncio from utils.logger import logger import time from collections import OrderedDict +from typing import Dict, Any -from mcp import ClientSession -from mcp.client.sse import sse_client -from mcp.client.stdio import stdio_client -from mcp import StdioServerParameters from pydantic import BaseModel # Import the agent API module from agent import api as agent_api from sandbox import api as sandbox_api from services import billing as billing_api from services import transcription as transcription_api -import concurrent.futures -from typing import Dict, Any +from services.mcp_custom import discover_custom_tools import sys -from concurrent.futures import ThreadPoolExecutor - -import os -import subprocess -import json load_dotenv() if sys.platform == "win32": asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) -# Thread pool for Windows subprocess handling -windows_executor = ThreadPoolExecutor(max_workers=4) - # Initialize managers db = DBConnection() instance_id = "single" @@ -136,19 +124,16 @@ app.add_middleware( allow_headers=["Content-Type", "Authorization"], ) -# Include the agent router with a prefix app.include_router(agent_api.router, prefix="/api") -# Include the sandbox router with a prefix app.include_router(sandbox_api.router, prefix="/api") -# Include the billing router with a prefix app.include_router(billing_api.router, prefix="/api") -# Import and include the MCP router from mcp_local import api as mcp_api + app.include_router(mcp_api.router, prefix="/api") -# Include the transcription router with a prefix + app.include_router(transcription_api.router, prefix="/api") @app.get("/api/health") @@ -162,323 +147,27 @@ async def health_check(): } class CustomMCPDiscoverRequest(BaseModel): - type: str # 'json' or 'sse' + type: str config: Dict[str, Any] -def run_mcp_stdio_sync(command, args, 
env_vars, timeout=30): - """Synchronous function to run MCP stdio connection on Windows""" - - try: - # Prepare environment - env = os.environ.copy() - env.update(env_vars) - - # Create subprocess with proper Windows handling - full_command = [command] + args - - process = subprocess.Popen( - full_command, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - text=True, - bufsize=0, - creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if sys.platform == "win32" else 0 - ) - - # MCP Initialization - init_request = { - "jsonrpc": "2.0", - "id": 1, - "method": "initialize", - "params": { - "protocolVersion": "2024-11-05", - "capabilities": {}, - "clientInfo": {"name": "mcp-client", "version": "1.0.0"} - } - } - - # Send initialization - process.stdin.write(json.dumps(init_request) + "\n") - process.stdin.flush() - - # Read initialization response - init_response_line = process.stdout.readline().strip() - if not init_response_line: - raise Exception("No response from MCP server during initialization") - - init_response = json.loads(init_response_line) - - # Send notification that initialization is complete - init_notification = { - "jsonrpc": "2.0", - "method": "notifications/initialized" - } - process.stdin.write(json.dumps(init_notification) + "\n") - process.stdin.flush() - - # Request tools list - tools_request = { - "jsonrpc": "2.0", - "id": 2, - "method": "tools/list", - "params": {} - } - - process.stdin.write(json.dumps(tools_request) + "\n") - process.stdin.flush() - - # Read tools response - tools_response_line = process.stdout.readline().strip() - if not tools_response_line: - raise Exception("No response from MCP server for tools list") - - tools_response = json.loads(tools_response_line) - - # Parse tools - tools_info = [] - if "result" in tools_response and "tools" in tools_response["result"]: - for tool in tools_response["result"]["tools"]: - tool_info = { - "name": tool["name"], - "description": tool.get("description", 
""), - "input_schema": tool.get("inputSchema", {}) - } - tools_info.append(tool_info) - - return { - "status": "connected", - "transport": "stdio", - "tools": tools_info - } - - except subprocess.TimeoutExpired: - return { - "status": "error", - "error": f"Process timeout after {timeout} seconds", - "tools": [] - } - except json.JSONDecodeError as e: - return { - "status": "error", - "error": f"Invalid JSON response: {str(e)}", - "tools": [] - } - except Exception as e: - return { - "status": "error", - "error": str(e), - "tools": [] - } - finally: - try: - if 'process' in locals(): - process.terminate() - process.wait(timeout=5) - except: - pass - -async def connect_stdio_server_windows(server_name, server_config, all_tools, timeout): - """Windows-compatible stdio connection using subprocess""" - - logger.info(f"Connecting to {server_name} using Windows subprocess method") - - command = server_config["command"] - args = server_config.get("args", []) - env_vars = server_config.get("env", {}) - - # Run in thread pool to avoid blocking - loop = asyncio.get_event_loop() - result = await loop.run_in_executor( - windows_executor, - run_mcp_stdio_sync, - command, - args, - env_vars, - timeout - ) - - all_tools[server_name] = result - - if result["status"] == "connected": - logger.info(f" {server_name}: Connected via Windows subprocess ({len(result['tools'])} tools)") - else: - logger.error(f" {server_name}: Error - {result['error']}") - -async def list_mcp_tools_mixed_windows(config, timeout=15): - """Windows-compatible version of list_mcp_tools_mixed""" - all_tools = {} - - if "mcpServers" not in config: - return all_tools - - mcp_servers = config["mcpServers"] - - for server_name, server_config in mcp_servers.items(): - logger.info(f"Connecting to MCP server: {server_name}") - if server_config.get("disabled", False): - all_tools[server_name] = {"status": "disabled", "tools": []} - logger.info(f" {server_name}: Disabled") - continue - - try: - await 
connect_stdio_server_windows(server_name, server_config, all_tools, timeout) - - except asyncio.TimeoutError: - all_tools[server_name] = { - "status": "error", - "error": f"Connection timeout after {timeout} seconds", - "tools": [] - } - logger.error(f" {server_name}: Timeout after {timeout} seconds") - except Exception as e: - error_msg = str(e) - all_tools[server_name] = { - "status": "error", - "error": error_msg, - "tools": [] - } - logger.error(f" {server_name}: Error - {error_msg}") - import traceback - logger.debug(f"Full traceback for {server_name}: {traceback.format_exc()}") - - return all_tools - -# Modified API endpoint @app.post("/api/mcp/discover-custom-tools") async def discover_custom_mcp_tools(request: CustomMCPDiscoverRequest): - """Discover tools from a custom MCP server configuration - Windows compatible.""" try: - logger.info(f"Received custom MCP discovery request: type={request.type}") - logger.debug(f"Request config: {request.config}") - - tools = [] - server_name = None - - if request.type == 'json': - try: - # Use Windows-compatible version - all_tools = await list_mcp_tools_mixed_windows(request.config, timeout=30) - - # Extract the first server name from the config - if "mcpServers" in request.config and request.config["mcpServers"]: - server_name = list(request.config["mcpServers"].keys())[0] - - # Check if the server exists in the results and has tools - if server_name in all_tools: - server_info = all_tools[server_name] - if server_info["status"] == "connected": - tools = server_info["tools"] - logger.info(f"Found {len(tools)} tools for server {server_name}") - else: - # Server had an error or was disabled - error_msg = server_info.get("error", "Unknown error") - logger.error(f"Server {server_name} failed: {error_msg}") - raise HTTPException( - status_code=400, - detail=f"Failed to connect to MCP server '{server_name}': {error_msg}" - ) - else: - logger.error(f"Server {server_name} not found in results") - raise 
HTTPException(status_code=400, detail=f"Server '{server_name}' not found in results") - else: - logger.error("No MCP servers configured") - raise HTTPException(status_code=400, detail="No MCP servers configured") - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error connecting to stdio MCP server: {e}") - import traceback - logger.error(f"Full traceback: {traceback.format_exc()}") - raise HTTPException(status_code=400, detail=f"Failed to connect to MCP server: {str(e)}") - - elif request.type == 'sse': - # SSE handling remains the same as it doesn't use subprocess - if 'url' not in request.config: - raise HTTPException(status_code=400, detail="SSE configuration must include 'url' field") - - url = request.config['url'] - headers = request.config.get('headers', {}) - - try: - async with asyncio.timeout(15): - try: - async with sse_client(url, headers=headers) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - tools_result = await session.list_tools() - tools_info = [] - for tool in tools_result.tools: - tool_info = { - "name": tool.name, - "description": tool.description, - "input_schema": tool.inputSchema - } - tools_info.append(tool_info) - - for tool_info in tools_info: - tools.append({ - "name": tool_info["name"], - "description": tool_info["description"], - "inputSchema": tool_info["input_schema"] - }) - except TypeError as e: - if "unexpected keyword argument" in str(e): - async with sse_client(url) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - tools_result = await session.list_tools() - tools_info = [] - for tool in tools_result.tools: - tool_info = { - "name": tool.name, - "description": tool.description, - "input_schema": tool.inputSchema - } - tools_info.append(tool_info) - - for tool_info in tools_info: - tools.append({ - "name": tool_info["name"], - "description": tool_info["description"], - "inputSchema": 
tool_info["input_schema"] - }) - else: - raise - except asyncio.TimeoutError: - raise HTTPException(status_code=408, detail="Connection timeout - server took too long to respond") - except Exception as e: - logger.error(f"Error connecting to SSE MCP server: {e}") - raise HTTPException(status_code=400, detail=f"Failed to connect to MCP server: {str(e)}") - else: - raise HTTPException(status_code=400, detail="Invalid server type. Must be 'json' or 'sse'") - - response_data = {"tools": tools, "count": len(tools)} - - if server_name: - response_data["serverName"] = server_name - - logger.info(f"Returning {len(tools)} tools for server {server_name}") - return response_data - + return await discover_custom_tools(request.type, request.config) except HTTPException: raise except Exception as e: logger.error(f"Error discovering custom MCP tools: {e}") raise HTTPException(status_code=500, detail=str(e)) -# Make sure to set the Windows event loop policy at app startup if __name__ == "__main__": import uvicorn - # Set Windows event loop policy if sys.platform == "win32": asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) - workers = 1 # Keep single worker for Windows compatibility + workers = 1 logger.info(f"Starting server on 0.0.0.0:8000 with {workers} workers") uvicorn.run( @@ -486,5 +175,5 @@ if __name__ == "__main__": host="0.0.0.0", port=8000, workers=workers, - loop="asyncio" # Explicitly use asyncio event loop + loop="asyncio" ) \ No newline at end of file diff --git a/backend/services/mcp_custom.py b/backend/services/mcp_custom.py new file mode 100644 index 00000000..83190ee5 --- /dev/null +++ b/backend/services/mcp_custom.py @@ -0,0 +1,291 @@ +import os +import sys +import json +import asyncio +import subprocess +from typing import Dict, Any +from concurrent.futures import ThreadPoolExecutor +from fastapi import HTTPException # type: ignore +from utils.logger import logger +from mcp import ClientSession +from mcp.client.sse import sse_client # 
# Worker threads that host the blocking stdio handshake so the event loop is
# not blocked while a server subprocess is being probed.
windows_executor = ThreadPoolExecutor(max_workers=4)


def _kill_on_timeout(process, timed_out):
    """Watchdog callback: record the timeout, then kill the server process."""
    # Set the flag first: once the kill closes the pipes the reader thread
    # wakes up on EOF and must be able to tell a timeout from a plain exit.
    timed_out.set()
    try:
        process.kill()
    except OSError:
        pass  # the process already exited on its own


def _send_jsonrpc(stream, message):
    """Write one newline-delimited JSON-RPC message and flush it."""
    stream.write(json.dumps(message) + "\n")
    stream.flush()


def _read_jsonrpc_response(stream, request_id, missing_message):
    """Return the JSON-RPC response whose ``id`` equals *request_id*.

    Blank lines, notifications and server-initiated requests that arrive
    before the response are skipped. Raises ``RuntimeError(missing_message)``
    when the stream ends first and ``json.JSONDecodeError`` on a malformed line.
    """
    while True:
        line = stream.readline()
        if not line:  # EOF: the server exited or was killed by the watchdog
            raise RuntimeError(missing_message)
        line = line.strip()
        if not line:
            continue
        message = json.loads(line)
        if (
            isinstance(message, dict)
            and message.get("id") == request_id
            and ("result" in message or "error" in message)
        ):
            return message


def _raise_on_rpc_error(response, stage):
    """Raise ``RuntimeError`` when *response* carries a JSON-RPC ``error``."""
    error = response.get("error")
    if error is not None:
        detail = error.get("message", error) if isinstance(error, dict) else error
        raise RuntimeError(f"MCP server returned an error {stage}: {detail}")


def _shutdown_process(process):
    """Best-effort teardown: close our pipe ends, terminate, then kill."""
    for pipe in (process.stdin, process.stdout):
        try:
            if pipe is not None:
                pipe.close()
        except OSError:
            pass  # e.g. broken pipe because the child is already gone
    try:
        if process.poll() is None:
            process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # The server ignored terminate(); do not leave it running.
            process.kill()
            process.wait(timeout=5)
    except (OSError, subprocess.TimeoutExpired):
        pass  # cleanup must never mask the result being returned


def run_mcp_stdio_sync(command, args, env_vars, timeout=30):
    """Spawn an MCP server over stdio and list its tools (blocking).

    Performs the ``initialize`` / ``notifications/initialized`` / ``tools/list``
    exchange as newline-delimited JSON-RPC on the child's stdin/stdout.

    Args:
        command: Executable to launch.
        args: Command-line arguments for the executable (may be empty/None).
        env_vars: Extra environment variables layered over ``os.environ``;
            keys and values are converted to ``str``.
        timeout: Overall deadline in seconds for the whole exchange. When it
            expires the child is killed and an error result is returned.

    Returns:
        ``{"status": "connected", "transport": "stdio", "tools": [...]}`` on
        success, where each tool has ``name``, ``description`` and
        ``input_schema``; otherwise ``{"status": "error", "error": <message>,
        "tools": []}``. This function does not raise.

    NOTE(review): ``command`` and ``args`` are executed as given; callers
    pass values taken from request configuration, so the route exposing this
    must be restricted to trusted users - confirm.
    """
    import threading  # local import: not part of this module's import header

    process = None
    watchdog = None
    timed_out = threading.Event()

    try:
        env = os.environ.copy()
        # JSON configs may carry numbers/booleans; Popen only accepts strings.
        env.update({str(key): str(value) for key, value in (env_vars or {}).items()})

        process = subprocess.Popen(
            [command, *(args or [])],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            # stderr is never consumed here; an unread PIPE can fill up and
            # block a server that logs heavily.
            stderr=subprocess.DEVNULL,
            env=env,
            text=True,
            # Decode as UTF-8 rather than the locale default (matters on Windows).
            encoding="utf-8",
            bufsize=0,
            creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if sys.platform == "win32" else 0,
        )

        # readline() has no timeout of its own, so a watchdog kills the child
        # at the deadline, which unblocks the read with EOF.
        watchdog = threading.Timer(timeout, _kill_on_timeout, args=(process, timed_out))
        watchdog.daemon = True
        watchdog.start()

        _send_jsonrpc(process.stdin, {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "mcp-client", "version": "1.0.0"},
            },
        })
        init_response = _read_jsonrpc_response(
            process.stdout, 1, "No response from MCP server during initialization"
        )
        _raise_on_rpc_error(init_response, "during initialization")

        _send_jsonrpc(process.stdin, {
            "jsonrpc": "2.0",
            "method": "notifications/initialized",
        })

        _send_jsonrpc(process.stdin, {
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/list",
            "params": {},
        })
        tools_response = _read_jsonrpc_response(
            process.stdout, 2, "No response from MCP server for tools list"
        )
        _raise_on_rpc_error(tools_response, "for tools list")

        listed = (tools_response.get("result") or {}).get("tools") or []
        tools_info = [
            {
                "name": tool["name"],
                "description": tool.get("description", ""),
                "input_schema": tool.get("inputSchema", {}),
            }
            for tool in listed
        ]

        return {
            "status": "connected",
            "transport": "stdio",
            "tools": tools_info,
        }

    except Exception as e:
        if timed_out.is_set():
            # Whatever surfaced (EOF, broken pipe, ...) was caused by the kill.
            message = f"Process timeout after {timeout} seconds"
        elif isinstance(e, json.JSONDecodeError):
            message = f"Invalid JSON response: {str(e)}"
        else:
            message = str(e)
        return {
            "status": "error",
            "error": message,
            "tools": [],
        }
    finally:
        if watchdog is not None:
            watchdog.cancel()
        if process is not None:
            _shutdown_process(process)


async def connect_stdio_server_windows(server_name, server_config, all_tools, timeout):
    """Windows-compatible stdio connection using subprocess.

    Runs ``run_mcp_stdio_sync`` on ``windows_executor`` and stores its result
    dict in ``all_tools[server_name]``.

    Args:
        server_name: Key under which the result is stored and logged.
        server_config: Mapping with ``command`` and optional ``args`` / ``env``.
        all_tools: Dict mutated in place with this server's result.
        timeout: Deadline in seconds forwarded to ``run_mcp_stdio_sync``.

    Raises:
        KeyError: if ``server_config`` has no ``command`` entry.
    """
    logger.info(f"Connecting to {server_name} using Windows subprocess method")

    command = server_config["command"]
    args = server_config.get("args", [])
    env_vars = server_config.get("env", {})

    # We are inside a coroutine, so the running loop is the one to use.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(
        windows_executor,
        run_mcp_stdio_sync,
        command,
        args,
        env_vars,
        timeout,
    )

    all_tools[server_name] = result

    if result["status"] == "connected":
        logger.info(f" {server_name}: Connected via Windows subprocess ({len(result['tools'])} tools)")
    else:
        logger.error(f" {server_name}: Error - {result['error']}")


async def list_mcp_tools_mixed_windows(config, timeout=15):
    """Probe every server under ``config["mcpServers"]``, one after another.

    Args:
        config: Mapping that may contain an ``mcpServers`` dict of
            ``server_name -> server_config``.
        timeout: Per-server deadline in seconds.

    Returns:
        Dict of ``server_name -> result``. Servers flagged ``disabled`` get
        ``{"status": "disabled", "tools": []}``; a failure while connecting is
        recorded as an ``error`` result instead of being raised. Empty when
        ``mcpServers`` is absent.
    """
    all_tools = {}

    if "mcpServers" not in config:
        return all_tools

    for server_name, server_config in config["mcpServers"].items():
        logger.info(f"Connecting to MCP server: {server_name}")
        if server_config.get("disabled", False):
            all_tools[server_name] = {"status": "disabled", "tools": []}
            logger.info(f" {server_name}: Disabled")
            continue

        try:
            await connect_stdio_server_windows(server_name, server_config, all_tools, timeout)
        except asyncio.TimeoutError:
            all_tools[server_name] = {
                "status": "error",
                "error": f"Connection timeout after {timeout} seconds",
                "tools": [],
            }
            logger.error(f" {server_name}: Timeout after {timeout} seconds")
        except Exception as e:
            error_msg = str(e)
            all_tools[server_name] = {
                "status": "error",
                "error": error_msg,
                "tools": [],
            }
            logger.error(f" {server_name}: Error - {error_msg}")
            import traceback
            logger.debug(f"Full traceback for {server_name}: {traceback.format_exc()}")

    return all_tools


async def _fetch_sse_tools(url, headers=None):
    """Open an SSE MCP session to *url* and return its tools.

    ``headers=None`` means "call ``sse_client`` without the keyword"; any other
    value is forwarded as ``headers=``. Each tool is returned as a dict with
    ``name``, ``description`` and ``inputSchema``.
    """
    client = sse_client(url) if headers is None else sse_client(url, headers=headers)
    async with client as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            listing = await session.list_tools()
            return [
                {
                    "name": tool.name,
                    "description": tool.description,
                    "inputSchema": tool.inputSchema,
                }
                for tool in listing.tools
            ]


async def discover_custom_tools(request_type: str, config: Dict[str, Any]):
    """Discover the tools offered by a custom MCP server configuration.

    Args:
        request_type: ``'json'`` for a stdio server described under
            ``config["mcpServers"]`` (only the first entry is used), or
            ``'sse'`` for a server reached at ``config["url"]`` with optional
            ``config["headers"]``.
        config: The server configuration for the chosen type.

    Returns:
        ``{"tools": [...], "count": <int>}`` plus ``"serverName"`` for the
        ``'json'`` type.

    Raises:
        HTTPException: 400 for an invalid type, missing configuration or a
            failed connection; 408 when the SSE server does not answer within
            15 seconds.

    NOTE(review): stdio tools carry the key ``input_schema`` while SSE tools
    carry ``inputSchema``; this is kept as-is because consumers are not visible
    from here - confirm which spelling clients expect.
    """
    logger.info(f"Received custom MCP discovery request: type={request_type}")
    # Only key names are logged: ``env`` / ``headers`` may hold credentials.
    logger.debug(f"Request config keys: {list(config)}")

    tools = []
    server_name = None

    if request_type == 'json':
        servers = config.get("mcpServers")
        if not isinstance(servers, dict) or not servers:
            logger.error("No MCP servers configured")
            raise HTTPException(status_code=400, detail="No MCP servers configured")

        # Only the first server is ever reported, so only that one is started.
        server_name = next(iter(servers))
        try:
            all_tools = await list_mcp_tools_mixed_windows(
                {"mcpServers": {server_name: servers[server_name]}}, timeout=30
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error connecting to stdio MCP server: {e}")
            import traceback
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise HTTPException(status_code=400, detail=f"Failed to connect to MCP server: {str(e)}")

        server_info = all_tools.get(server_name)
        if server_info is None:
            logger.error(f"Server {server_name} not found in results")
            raise HTTPException(status_code=400, detail=f"Server '{server_name}' not found in results")

        if server_info["status"] != "connected":
            # Covers both failed and disabled servers.
            error_msg = server_info.get("error", "Unknown error")
            logger.error(f"Server {server_name} failed: {error_msg}")
            raise HTTPException(
                status_code=400,
                detail=f"Failed to connect to MCP server '{server_name}': {error_msg}"
            )

        tools = server_info["tools"]
        logger.info(f"Found {len(tools)} tools for server {server_name}")

    elif request_type == 'sse':
        if 'url' not in config:
            raise HTTPException(status_code=400, detail="SSE configuration must include 'url' field")

        url = config['url']
        headers = config.get('headers', {})

        try:
            async with asyncio.timeout(15):
                try:
                    tools = await _fetch_sse_tools(url, headers)
                except TypeError as e:
                    # Fall back to a header-less call when this sse_client
                    # rejects the ``headers`` keyword.
                    if "unexpected keyword argument" not in str(e):
                        raise
                    tools = await _fetch_sse_tools(url)
        except asyncio.TimeoutError:
            raise HTTPException(status_code=408, detail="Connection timeout - server took too long to respond")
        except Exception as e:
            logger.error(f"Error connecting to SSE MCP server: {e}")
            raise HTTPException(status_code=400, detail=f"Failed to connect to MCP server: {str(e)}")
    else:
        raise HTTPException(status_code=400, detail="Invalid server type. Must be 'json' or 'sse'")

    response_data = {"tools": tools, "count": len(tools)}

    if server_name:
        response_data["serverName"] = server_name

    logger.info(f"Returning {len(tools)} tools for server {server_name}")
    return response_data