# suna/backend/pipedream/mcp_service.py
#
# 367 lines
# 14 KiB
# Python

import asyncio
import json
import os
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from typing import List, Optional, Dict, Any
from enum import Enum
import httpx
from utils.logger import logger
try:
    from mcp import ClientSession
    from mcp.client.streamable_http import streamablehttp_client
except ImportError:
    # The MCP SDK is optional: code below checks MCP_AVAILABLE before using
    # ClientSession / streamablehttp_client.
    MCP_AVAILABLE = False
    logger.warning("MCP client libraries not available")
else:
    MCP_AVAILABLE = True
class ConnectionStatus(Enum):
    """Connection state recorded on an ``MCPServer`` by this module."""

    # Set after a successful connection test.
    CONNECTED = "connected"
    # Initial state of a freshly built server, before it has been tested.
    DISCONNECTED = "disconnected"
    # Set when a connection test fails; the reason goes in ``error_message``.
    ERROR = "error"
    # Not assigned anywhere in this module.
    PENDING = "pending"
@dataclass
class MCPTool:
    """One tool advertised by an MCP server.

    Built from the SDK tool objects returned by ``session.list_tools()``.
    """

    # Tool identifier as reported by the server.
    name: str
    # Human-readable description of the tool.
    description: str
    # Copied from the SDK tool's ``inputSchema`` (presumably a JSON Schema dict).
    input_schema: Dict[str, Any]
@dataclass
class MCPServer:
    """Description of one app's MCP endpoint plus its last known status and tools."""

    app_slug: str
    app_name: str
    server_url: str
    project_id: str
    environment: str
    external_user_id: str
    oauth_app_id: Optional[str] = None
    status: ConnectionStatus = ConnectionStatus.DISCONNECTED
    # Tools discovered on the server; each instance gets its own list.
    available_tools: List[MCPTool] = field(default_factory=list)
    # Reason text, set alongside an ERROR status.
    error_message: Optional[str] = None

    def is_connected(self) -> bool:
        """Return True if the recorded status is CONNECTED."""
        return ConnectionStatus.CONNECTED == self.status

    def get_tool_count(self) -> int:
        """Return the number of tools currently stored on this server."""
        tool_total = len(self.available_tools)
        return tool_total

    def add_tool(self, tool: MCPTool):
        """Append *tool* to ``available_tools``."""
        tools = self.available_tools
        tools.append(tool)
class MCPServiceError(Exception):
    """Root of the error hierarchy raised by the Pipedream MCP service."""


class MCPConnectionError(MCPServiceError):
    """Raised when an MCP connection for a user/app pair cannot be created."""


class MCPServerNotAvailableError(MCPServiceError):
    """Raised when the MCP client libraries could not be imported."""


class AuthenticationError(MCPServiceError):
    """Raised when a Pipedream access token cannot be obtained."""


class RateLimitError(MCPServiceError):
    """Raised when the Pipedream API responds with HTTP 429."""
class ExternalUserId:
    """Validated identifier of the end user that Pipedream calls are made for.

    Args:
        value: A non-empty string of at most ``MAX_LENGTH`` (255) characters.

    Raises:
        ValueError: If *value* is empty, not a ``str``, or longer than
            ``MAX_LENGTH`` characters.
    """

    MAX_LENGTH = 255

    def __init__(self, value: str):
        if not value or not isinstance(value, str):
            raise ValueError("ExternalUserId must be a non-empty string")
        # The check accepts exactly MAX_LENGTH characters, so the message must
        # say "at most", not "less than".
        if len(value) > self.MAX_LENGTH:
            raise ValueError(f"ExternalUserId must be at most {self.MAX_LENGTH} characters")
        self.value = value

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.value!r})"
class AppSlug:
    """Validated Pipedream app slug.

    Args:
        value: Non-empty string made only of lowercase ASCII letters, digits,
            hyphens and underscores.

    Raises:
        ValueError: If *value* is empty, not a ``str``, or contains any other
            character (including a trailing newline).
    """

    # Use fullmatch rather than ``re.match(r"^...$")``: ``$`` also matches just
    # before a final "\n", which let values like "slack\n" pass validation.
    _PATTERN = re.compile(r"[a-z0-9_-]+")

    def __init__(self, value: str):
        if not value or not isinstance(value, str):
            raise ValueError("AppSlug must be a non-empty string")
        if not self._PATTERN.fullmatch(value):
            raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores")
        self.value = value

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.value!r})"
class MCPService:
    """Async client for Pipedream Connect and Pipedream's remote MCP servers.

    Lazily creates a shared ``httpx.AsyncClient``, caches an OAuth
    client-credentials access token, and discovers/tests MCP servers for the
    apps a user has connected.

    Environment variables read on each call: ``PIPEDREAM_PROJECT_ID``,
    ``PIPEDREAM_CLIENT_ID``, ``PIPEDREAM_CLIENT_SECRET`` and
    ``PIPEDREAM_X_PD_ENVIRONMENT`` (defaults to ``"development"``).
    """

    # Remote MCP endpoint used for every server this service builds.
    MCP_SERVER_URL = "https://remote.mcp.pipedream.net"
    # A cached token is refreshed once it is within this margin of its expiry.
    TOKEN_REFRESH_MARGIN = timedelta(minutes=5)
    # Upper bound, in seconds, for connecting to an MCP server and listing tools.
    CONNECTION_TIMEOUT_SECONDS = 15

    def __init__(self, logger=None):
        """Create a service instance.

        Args:
            logger: Logger to use. If omitted, the module-level ``logger`` is used.
        """
        # The parameter shadows the module-level ``logger``, so a plain
        # ``logger or logger`` would only ever see the argument (None by
        # default). Look the module default up explicitly.
        self._logger = logger if logger is not None else globals()["logger"]
        self.base_url = "https://api.pipedream.com/v1"
        self.session = None
        self.access_token = None
        self.token_expires_at = None

    async def _get_session(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it if missing or closed."""
        if self.session is None or self.session.is_closed:
            self.session = httpx.AsyncClient(
                timeout=httpx.Timeout(30.0),
                headers={"User-Agent": "Suna-Pipedream-Client/1.0"},
            )
        return self.session

    async def _ensure_access_token(self) -> str:
        """Return the cached access token, or fetch a new one if absent or near expiry."""
        if self.access_token and self.token_expires_at:
            if datetime.now(timezone.utc) < self.token_expires_at - self.TOKEN_REFRESH_MARGIN:
                return self.access_token
            # Expired or about to expire: discard before fetching a new one.
            self.access_token = None
            self.token_expires_at = None
        return await self._fetch_fresh_token()

    async def _fetch_fresh_token(self) -> str:
        """Request a client-credentials token and cache it with its expiry time.

        Returns:
            The new access token.

        Raises:
            AuthenticationError: Required environment variables are missing,
                the token endpoint returned a non-429 error status, or the
                response contained no ``access_token``.
            RateLimitError: The token endpoint returned HTTP 429.
        """
        project_id = os.getenv("PIPEDREAM_PROJECT_ID")
        client_id = os.getenv("PIPEDREAM_CLIENT_ID")
        client_secret = os.getenv("PIPEDREAM_CLIENT_SECRET")
        # project_id is not sent in the token request, but it is validated here
        # together with the credentials.
        if not all([project_id, client_id, client_secret]):
            raise AuthenticationError("Missing required environment variables")
        session = await self._get_session()
        try:
            response = await session.post(
                f"{self.base_url}/oauth/token",
                data={
                    "grant_type": "client_credentials",
                    "client_id": client_id,
                    "client_secret": client_secret,
                },
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 429:
                raise RateLimitError("Rate limit exceeded") from e
            raise AuthenticationError(f"Failed to obtain access token: {e}") from e
        data = response.json()
        access_token = data.get("access_token")
        if not access_token:
            raise AuthenticationError("Token response did not include an access_token")
        expires_in = data.get("expires_in", 3600)
        # Aware UTC timestamp; datetime.utcnow() is naive and deprecated (3.12+).
        self.token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
        self.access_token = access_token
        return self.access_token

    async def _make_request(
        self,
        url: str,
        headers: Optional[Dict[str, str]] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Send an authenticated GET to *url* and return the decoded JSON body.

        Extra *headers* override the default Authorization/Content-Type headers.

        Raises:
            RateLimitError: The API returned HTTP 429.
            MCPServiceError: The API returned any other error status.
        """
        session = await self._get_session()
        access_token = await self._ensure_access_token()
        request_headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }
        if headers:
            request_headers.update(headers)
        try:
            response = await session.get(url, headers=request_headers, params=params)
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 429:
                raise RateLimitError("Rate limit exceeded") from e
            raise MCPServiceError(f"HTTP request failed: {e}") from e
        return response.json()

    async def _list_connected_apps(
        self,
        external_user_id: ExternalUserId,
        project_id: str,
        environment: str,
    ) -> List[Dict[str, Any]]:
        """Return the non-empty ``app`` objects of the user's Connect accounts.

        Propagates any exception raised by ``_make_request``.
        """
        url = f"{self.base_url}/connect/{project_id}/accounts"
        data = await self._make_request(
            url,
            headers={"X-PD-Environment": environment},
            params={"external_id": external_user_id.value},
        )
        # A missing or null "data" field means "no accounts".
        accounts = data.get("data") or []
        return [account.get("app") for account in accounts if account.get("app")]

    async def test_connection(self, server: MCPServer) -> MCPServer:
        """Connect to *server*'s MCP endpoint and record its status and tools.

        Failures (any ``Exception``) are not raised. They are recorded on the
        server as ``ConnectionStatus.ERROR`` with ``error_message`` set. On
        success, the server's tool list is replaced (not appended to), the
        status becomes ``CONNECTED``, and any previous ``error_message`` is
        cleared.

        Returns:
            The same *server* object, mutated in place.
        """
        if not MCP_AVAILABLE:
            self._logger.warning(f"MCP client not available for testing {server.app_name}")
            server.status = ConnectionStatus.ERROR
            server.error_message = "MCP client libraries not available"
            return server
        try:
            access_token = await self._ensure_access_token()
        except Exception as e:
            self._logger.error(f"Failed to get access token for MCP connection: {str(e)}")
            server.status = ConnectionStatus.ERROR
            server.error_message = f"Authentication failed: {str(e)}"
            return server
        headers = {
            "Authorization": f"Bearer {access_token}",
            "x-pd-project-id": server.project_id,
            "x-pd-environment": server.environment,
            "x-pd-external-user-id": server.external_user_id,
            "x-pd-app-slug": server.app_slug,
        }
        self._logger.debug(f"Testing MCP connection for {server.app_name} at {server.server_url}")
        try:
            async with asyncio.timeout(self.CONNECTION_TIMEOUT_SECONDS):
                async with streamablehttp_client(server.server_url, headers=headers) as (read_stream, write_stream, _):
                    async with ClientSession(read_stream, write_stream) as session:
                        await session.initialize()
                        tools_result = await session.list_tools()
            raw_tools = tools_result.tools if hasattr(tools_result, "tools") else tools_result
            # Build the full list first, so a failure cannot leave the server
            # holding a partial tool list.
            tools = [
                MCPTool(
                    name=tool.name,
                    # description is optional in the MCP tool schema and may be
                    # None; MCPTool.description is typed str.
                    description=tool.description or "",
                    input_schema=tool.inputSchema,
                )
                for tool in raw_tools
            ]
        except asyncio.TimeoutError:
            self._logger.error(f"Timeout testing MCP connection for {server.app_name}")
            server.status = ConnectionStatus.ERROR
            server.error_message = "Connection timeout"
            return server
        except Exception as e:
            self._logger.error(f"Failed to test MCP connection for {server.app_name}: {str(e)}")
            server.status = ConnectionStatus.ERROR
            server.error_message = str(e)
            return server
        # Replace rather than append, so re-testing a server does not duplicate tools.
        server.available_tools.clear()
        for tool in tools:
            server.add_tool(tool)
        server.status = ConnectionStatus.CONNECTED
        server.error_message = None
        self._logger.debug(f"Successfully tested MCP server for {server.app_name} with {server.get_tool_count()} tools")
        return server

    async def discover_servers_for_user(self, external_user_id: ExternalUserId, app_slug: Optional[AppSlug] = None) -> List[MCPServer]:
        """Build and test an ``MCPServer`` for each app the user has connected.

        Args:
            external_user_id: The end user whose Connect accounts are listed.
            app_slug: If given, only apps with this ``name_slug`` are included.

        Returns:
            The tested servers (including ones in ERROR state). Returns an empty
            list if MCP is unavailable, ``PIPEDREAM_PROJECT_ID`` is unset, the
            user has no apps, or the account lookup fails (the error is logged).
        """
        if not MCP_AVAILABLE:
            self._logger.warning("MCP client libraries not available - returning empty server list")
            return []
        project_id = os.getenv("PIPEDREAM_PROJECT_ID")
        environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development")
        if not project_id:
            self._logger.error("Missing PIPEDREAM_PROJECT_ID environment variable")
            return []
        self._logger.debug(f"Discovering MCP servers for user: {external_user_id.value}, app_slug: {app_slug.value if app_slug else 'all'}")
        try:
            user_apps = await self._list_connected_apps(external_user_id, project_id, environment)
            if not user_apps:
                self._logger.debug(f"No connected apps found for user: {external_user_id.value}")
                return []
            if app_slug:
                user_apps = [app for app in user_apps if app.get("name_slug") == app_slug.value]
                self._logger.debug(f"Filtered to {len(user_apps)} apps for app_slug: {app_slug.value}")
            mcp_servers = []
            for app in user_apps:
                app_slug_current = app.get("name_slug")
                app_name = app.get("name")
                if not app_slug_current:
                    self._logger.warning(f"App missing name_slug: {app}")
                    continue
                self._logger.debug(f"Creating MCP server for app: {app_name} ({app_slug_current})")
                server = MCPServer(
                    app_slug=app_slug_current,
                    app_name=app_name,
                    server_url=self.MCP_SERVER_URL,
                    project_id=project_id,
                    environment=environment,
                    external_user_id=external_user_id.value,
                    status=ConnectionStatus.DISCONNECTED,
                )
                try:
                    tested_server = await self.test_connection(server)
                    mcp_servers.append(tested_server)
                    self._logger.debug(f"Successfully tested MCP server for {app_name}: {tested_server.status.value}")
                except Exception as e:
                    # Defensive: test_connection records its own failures, but
                    # one bad app must never abort discovery for the others.
                    self._logger.warning(f"Failed to test MCP server for {app_name}: {str(e)}")
                    server.status = ConnectionStatus.ERROR
                    server.error_message = str(e)
                    mcp_servers.append(server)
            self._logger.debug(f"Discovered {len(mcp_servers)} MCP servers for user: {external_user_id.value}")
            return mcp_servers
        except Exception as e:
            self._logger.error(f"Error discovering MCP servers: {str(e)}")
            return []

    async def create_connection(self, external_user_id: ExternalUserId, app_slug: AppSlug, oauth_app_id: Optional[str] = None) -> MCPServer:
        """Build and test an ``MCPServer`` for one app the user has connected.

        Returns:
            The tested server. Its status may be ERROR if the MCP test failed.

        Raises:
            MCPServerNotAvailableError: The MCP client libraries are missing.
            MCPConnectionError: ``PIPEDREAM_PROJECT_ID`` is unset, the user
                does not have *app_slug* connected, or any other error occurred
                (the original exception is chained as ``__cause__``).
        """
        if not MCP_AVAILABLE:
            raise MCPServerNotAvailableError("MCP client not available")
        project_id = os.getenv("PIPEDREAM_PROJECT_ID")
        environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development")
        if not project_id:
            raise MCPConnectionError("Missing PIPEDREAM_PROJECT_ID")
        self._logger.debug(f"Creating MCP connection for user: {external_user_id.value}, app: {app_slug.value}")
        try:
            user_apps = await self._list_connected_apps(external_user_id, project_id, environment)
            connected_app = next(
                (app for app in user_apps if app.get("name_slug") == app_slug.value),
                None,
            )
            if connected_app is None:
                raise MCPConnectionError(f"User {external_user_id.value} does not have {app_slug.value} connected")
            server = MCPServer(
                app_slug=app_slug.value,
                app_name=connected_app.get("name"),
                server_url=self.MCP_SERVER_URL,
                project_id=project_id,
                environment=environment,
                external_user_id=external_user_id.value,
                oauth_app_id=oauth_app_id,
                status=ConnectionStatus.DISCONNECTED,
            )
            tested_server = await self.test_connection(server)
        except MCPConnectionError as e:
            # Already the right type: re-raise as-is instead of re-wrapping.
            self._logger.error(f"Failed to create MCP connection for {app_slug.value}: {str(e)}")
            raise
        except Exception as e:
            self._logger.error(f"Failed to create MCP connection for {app_slug.value}: {str(e)}")
            raise MCPConnectionError(str(e)) from e
        self._logger.debug(f"Created MCP connection for {app_slug.value} with status: {tested_server.status.value}")
        return tested_server

    async def close(self):
        """Close the shared HTTP client if one is open."""
        if self.session is not None and not self.session.is_closed:
            await self.session.aclose()
# Module-level shared instance, created on first call to get_mcp_service().
_mcp_service = None


def get_mcp_service() -> MCPService:
    """Return the shared ``MCPService``, constructing it on first use."""
    global _mcp_service
    if _mcp_service is not None:
        return _mcp_service
    _mcp_service = MCPService()
    return _mcp_service
# Alternate exception names bound to the service's exception classes
# (presumably kept for callers importing the older names — verify before removing).
PipedreamException = MCPServiceError
# MCPServerNotAvailableError already exists under its own name and needs no alias.
AuthenticationException = AuthenticationError
HttpClientException = MCPServiceError
RateLimitException = RateLimitError