From 773987bebcc99333f493eb2ecaac1db5ef2e8468 Mon Sep 17 00:00:00 2001 From: Saumya Date: Wed, 30 Jul 2025 19:33:43 +0530 Subject: [PATCH] refactor pipedream --- backend/agent/api.py | 8 +- backend/pipedream/__init__.py | 8 +- backend/pipedream/api.py | 82 +- backend/pipedream/connection_service.py | 249 ++--- backend/pipedream/connection_token_service.py | 177 ++-- backend/pipedream/mcp_service.py | 419 ++++---- backend/pipedream/profile_service.py | 908 ++++++++---------- 7 files changed, 790 insertions(+), 1061 deletions(-) diff --git a/backend/agent/api.py b/backend/agent/api.py index 511a29f5..aa9f6251 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -2206,15 +2206,15 @@ async def get_pipedream_tools_for_agent( try: from pipedream.mcp_service import ExternalUserId, AppSlug - external_user_id = ExternalUserId(profile.external_user_id.value) - app_slug_obj = AppSlug(profile.app_slug.value) + external_user_id = ExternalUserId(profile.external_user_id) + app_slug_obj = AppSlug(profile.app_slug) logger.info(f"Discovering servers for user {external_user_id.value} and app {app_slug_obj.value}") servers = await mcp_service.discover_servers_for_user(external_user_id, app_slug_obj) - logger.info(f"Found {len(servers)} servers: {[s.app_slug.value for s in servers]}") + logger.info(f"Found {len(servers)} servers: {[s.app_slug for s in servers]}") server = servers[0] if servers else None - logger.info(f"Selected server: {server.app_slug.value if server else 'None'} with {len(server.available_tools) if server else 0} tools") + logger.info(f"Selected server: {server.app_slug if server else 'None'} with {len(server.available_tools) if server else 0} tools") if not server: return { diff --git a/backend/pipedream/__init__.py b/backend/pipedream/__init__.py index b0e25ad5..3816d6d6 100644 --- a/backend/pipedream/__init__.py +++ b/backend/pipedream/__init__.py @@ -9,11 +9,11 @@ from .connection_token_service import ConnectionTokenService db = DBConnection() -profile_service = ProfileService(db=db, logger=logger) -connection_service = ConnectionService(logger=logger) +profile_service = ProfileService() +connection_service = ConnectionService() app_service = AppService(logger=logger) -mcp_service = MCPService(logger=logger) -connection_token_service = ConnectionTokenService(logger=logger) +mcp_service = MCPService() +connection_token_service = ConnectionTokenService() from . import api api.profile_service = profile_service diff --git a/backend/pipedream/api.py b/backend/pipedream/api.py index 5826600d..30997488 100644 --- a/backend/pipedream/api.py +++ b/backend/pipedream/api.py @@ -6,10 +6,10 @@ from datetime import datetime from utils.logger import logger from utils.auth_utils import get_current_user_id_from_jwt -from .profile_service import ProfileService, Profile +from .profile_service import ProfileService, Profile, ProfileServiceError, ProfileNotFoundError, ProfileAlreadyExistsError, InvalidConfigError, EncryptionError from .connection_service import ConnectionService from .app_service import AppService -from .mcp_service import MCPService, ConnectionStatus +from .mcp_service import MCPService, ConnectionStatus, MCPConnectionError, MCPServiceError from .connection_token_service import ConnectionTokenService import httpx @@ -119,11 +119,11 @@ class ProfileResponse(BaseModel): profile_id=profile.profile_id, account_id=profile.account_id, mcp_qualified_name=profile.mcp_qualified_name, - profile_name=profile.profile_name.value, + profile_name=profile.profile_name, display_name=profile.display_name, - app_slug=profile.app_slug.value, + app_slug=profile.app_slug, app_name=profile.app_name, - external_user_id=profile.external_user_id.value, + external_user_id=profile.external_user_id, enabled_tools=profile.enabled_tools, is_active=profile.is_active, is_default=profile.is_default, @@ -145,15 +145,15 @@ def _handle_pipedream_exception(e: Exception) -> HTTPException: return HTTPException(status_code=404, detail=str(e)) elif isinstance(e, ProfileAlreadyExistsError): return HTTPException(status_code=409, detail=str(e)) - elif isinstance(e, ValidationException): + elif isinstance(e, InvalidConfigError): return HTTPException(status_code=400, detail=str(e)) - elif isinstance(e, ConnectionNotFoundError): - return HTTPException(status_code=404, detail=str(e)) - elif isinstance(e, AppNotFoundError): - return HTTPException(status_code=404, detail=str(e)) + elif isinstance(e, EncryptionError): + return HTTPException(status_code=500, detail=str(e)) elif isinstance(e, MCPConnectionError): return HTTPException(status_code=502, detail=str(e)) - elif isinstance(e, PipedreamException): + elif isinstance(e, MCPServiceError): + return HTTPException(status_code=500, detail=str(e)) + elif isinstance(e, ProfileServiceError): return HTTPException(status_code=500, detail=str(e)) else: return HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") @@ -203,7 +203,7 @@ async def get_user_connections( for connection in connections: connection_data.append({ "name": connection.app.name, - "name_slug": connection.app.slug.value, + "name_slug": connection.app.slug, "description": connection.app.description, "category": connection.app.category, "img_src": connection.app.logo_url, @@ -251,12 +251,12 @@ async def discover_mcp_servers( }) server_data.append({ - "app_slug": server.app_slug.value, + "app_slug": server.app_slug, "app_name": server.app_name, - "server_url": server.server_url.value, + "server_url": server.server_url, "project_id": server.project_id, "environment": server.environment, - "external_user_id": server.external_user_id.value, + "external_user_id": server.external_user_id, "oauth_app_id": server.oauth_app_id, "status": server.status.value, "available_tools": tools_data, @@ -300,12 +300,12 @@ async def discover_mcp_servers_for_profile( }) server_data.append({ - "app_slug": server.app_slug.value, + "app_slug": server.app_slug, "app_name": server.app_name, - "server_url": server.server_url.value, + "server_url": server.server_url, "project_id": server.project_id, "environment": server.environment, - "external_user_id": server.external_user_id.value, + "external_user_id": server.external_user_id, "oauth_app_id": server.oauth_app_id, "status": server.status.value, "available_tools": tools_data, @@ -351,12 +351,12 @@ async def create_mcp_connection( }) mcp_config = { - "app_slug": server.app_slug.value, + "app_slug": server.app_slug, "app_name": server.app_name, - "server_url": server.server_url.value, + "server_url": server.server_url, "project_id": server.project_id, "environment": server.environment, - "external_user_id": server.external_user_id.value, + "external_user_id": server.external_user_id, "oauth_app_id": server.oauth_app_id, "status": server.status.value, "available_tools": tools_data @@ -540,12 +540,11 @@ async def create_credential_profile( logger.info(f"Creating credential profile for user: {user_id}, app: {request.app_slug}") try: - from uuid import UUID profile = await profile_service.create_profile( - account_id=UUID(user_id), - profile_name=request.profile_name, - app_slug=request.app_slug, - app_name=request.app_name, + user_id, + request.profile_name, + request.app_slug, + request.app_name, description=request.description, is_default=request.is_default, oauth_app_id=request.oauth_app_id, @@ -571,9 +570,7 @@ async def get_credential_profiles( actual_app_slug = _strip_pipedream_prefix(app_slug) try: - from uuid import UUID - profiles = await profile_service.get_profiles(UUID(user_id), actual_app_slug, is_active) - + profiles = await profile_service.get_profiles(user_id, actual_app_slug, is_active) return [ProfileResponse.from_domain(profile) for profile in profiles] except Exception as e: @@ -589,8 +586,7 @@ async def get_credential_profile( logger.info(f"Getting credential profile: {profile_id} for user: {user_id}") try: - from uuid import UUID - profile = await profile_service.get_profile(UUID(user_id), UUID(profile_id)) + profile = await profile_service.get_profile(user_id, profile_id) if not profile: from .profile_service import ProfileNotFoundError @@ -612,10 +608,9 @@ async def update_credential_profile( logger.info(f"Updating credential profile: {profile_id} for user: {user_id}") try: - from uuid import UUID profile = await profile_service.update_profile( - account_id=UUID(user_id), - profile_id=UUID(profile_id), + user_id, + profile_id, profile_name=request.profile_name, display_name=request.display_name, is_active=request.is_active, @@ -638,8 +633,7 @@ async def delete_credential_profile( logger.info(f"Deleting credential profile: {profile_id} for user: {user_id}") try: - from uuid import UUID - success = await profile_service.delete_profile(UUID(user_id), UUID(profile_id)) + success = await profile_service.delete_profile(user_id, profile_id) if not success: raise ProfileNotFoundError(profile_id) @@ -666,12 +660,12 @@ async def connect_credential_profile( from .profile_service import ProfileNotFoundError from .connection_token_service import ExternalUserId, AppSlug - profile = await profile_service.get_profile(UUID(user_id), UUID(profile_id)) + profile = await profile_service.get_profile(user_id, profile_id) if not profile: raise ProfileNotFoundError(profile_id) - external_user_id = ExternalUserId(profile.external_user_id.value) - app_slug = AppSlug(actual_app or profile.app_slug.value) + external_user_id = ExternalUserId(profile.external_user_id) + app_slug = AppSlug(actual_app or profile.app_slug) result = await connection_token_service.create(external_user_id, app_slug) return { @@ -680,8 +674,8 @@ async def connect_credential_profile( "token": result.get("token"), "expires_at": result.get("expires_at"), "profile_id": profile_id, - "external_user_id": profile.external_user_id.value, - "app": actual_app or profile.app_slug.value + "external_user_id": profile.external_user_id, + "app": actual_app or profile.app_slug } except Exception as e: @@ -701,18 +695,18 @@ async def get_profile_connections( from .profile_service import ProfileNotFoundError from .connection_service import ExternalUserId - profile = await profile_service.get_profile(UUID(user_id), UUID(profile_id)) + profile = await profile_service.get_profile(user_id, profile_id) if not profile: raise ProfileNotFoundError(profile_id) - external_user_id = ExternalUserId(profile.external_user_id.value) + external_user_id = ExternalUserId(profile.external_user_id) connections = await connection_service.get_connections_for_user(external_user_id) connection_data = [] for connection in connections: connection_data.append({ "name": connection.app.name, - "name_slug": connection.app.slug.value, + "name_slug": connection.app.slug, "description": connection.app.description, "category": connection.app.category, "img_src": connection.app.logo_url, diff --git a/backend/pipedream/connection_service.py b/backend/pipedream/connection_service.py index 4c621a3d..0461c1e6 100644 --- a/backend/pipedream/connection_service.py +++ b/backend/pipedream/connection_service.py @@ -1,31 +1,13 @@ -from typing import List, Optional, Protocol, Dict, Any -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from enum import Enum import os -import logging import re +from dataclasses import dataclass, field +from datetime import datetime, timezone, timedelta +from typing import List, Optional, Dict, Any +from enum import Enum + import httpx -from uuid import UUID -import json +from utils.logger import logger -@dataclass(frozen=True) -class ExternalUserId: - value: str - def __post_init__(self): - if not self.value or not isinstance(self.value, str): - raise ValueError("ExternalUserId must be a non-empty string") - if len(self.value) > 255: - raise ValueError("ExternalUserId must be less than 255 characters") - -@dataclass(frozen=True) -class AppSlug: - value: str - def __post_init__(self): - if not self.value or not isinstance(self.value, str): - raise ValueError("AppSlug must be a non-empty string") - if not re.match(r'^[a-z0-9_-]+$', self.value): - raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores") class AuthType(Enum): OAUTH = "oauth" @@ -41,10 +23,11 @@ class AuthType(Enum): return cls.CUSTOM return super()._missing_(value) + @dataclass class App: name: str - slug: AppSlug + slug: str description: str category: str logo_url: Optional[str] = None @@ -57,60 +40,52 @@ class App: def is_featured(self) -> bool: return self.featured_weight > 0 + @dataclass class Connection: - external_user_id: ExternalUserId + external_user_id: str app: App created_at: datetime updated_at: datetime is_active: bool = True - - def activate(self) -> None: - self.is_active = True - self.updated_at = datetime.utcnow() - - def deactivate(self) -> None: - self.is_active = False - self.updated_at = datetime.utcnow() -# Exceptions -class PipedreamException(Exception): - def __init__(self, message: str, error_code: str = None): - super().__init__(message) - self.error_code = error_code - self.message = message -class HttpClientException(PipedreamException): - def __init__(self, url: str, status_code: int, reason: str): - super().__init__(f"HTTP request to {url} failed with status {status_code}: {reason}", "HTTP_CLIENT_ERROR") - self.url = url - self.status_code = status_code - self.reason = reason +class ConnectionServiceError(Exception): + pass -class AuthenticationException(PipedreamException): - def __init__(self, reason: str): - super().__init__(f"Authentication failed: {reason}", "AUTHENTICATION_ERROR") - self.reason = reason +class AuthenticationError(ConnectionServiceError): + pass -class RateLimitException(PipedreamException): - def __init__(self, retry_after: int = None): - super().__init__("Rate limit exceeded", "RATE_LIMIT_EXCEEDED") - self.retry_after = retry_after +class RateLimitError(ConnectionServiceError): + pass -class Logger(Protocol): - def info(self, message: str) -> None: ... - def warning(self, message: str) -> None: ... - def error(self, message: str) -> None: ... - def debug(self, message: str) -> None: ... -class HttpClient: - def __init__(self): +class ExternalUserId: + def __init__(self, value: str): + if not value or not isinstance(value, str): + raise ValueError("ExternalUserId must be a non-empty string") + if len(value) > 255: + raise ValueError("ExternalUserId must be less than 255 characters") + self.value = value + + +class AppSlug: + def __init__(self, value: str): + if not value or not isinstance(value, str): + raise ValueError("AppSlug must be a non-empty string") + if not re.match(r'^[a-z0-9_-]+$', value): + raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores") + self.value = value + + +class ConnectionService: + def __init__(self, logger=None): + self._logger = logger or logger self.base_url = "https://api.pipedream.com/v1" - self.session: Optional[httpx.AsyncClient] = None - self.access_token: Optional[str] = None - self.token_expires_at: Optional[datetime] = None - self.rate_limit_token: Optional[str] = None - + self.session = None + self.access_token = None + self.token_expires_at = None + async def _get_session(self) -> httpx.AsyncClient: if self.session is None or self.session.is_closed: self.session = httpx.AsyncClient( @@ -118,7 +93,7 @@ class HttpClient: headers={"User-Agent": "Suna-Pipedream-Client/1.0"} ) return self.session - + async def _ensure_access_token(self) -> str: if self.access_token and self.token_expires_at: if datetime.utcnow() < (self.token_expires_at - timedelta(minutes=5)): @@ -126,19 +101,19 @@ class HttpClient: else: self.access_token = None self.token_expires_at = None - + return await self._fetch_fresh_token() - + async def _fetch_fresh_token(self) -> str: project_id = os.getenv("PIPEDREAM_PROJECT_ID") client_id = os.getenv("PIPEDREAM_CLIENT_ID") client_secret = os.getenv("PIPEDREAM_CLIENT_SECRET") - + if not all([project_id, client_id, client_secret]): - raise AuthenticationException("Missing required environment variables") - + raise AuthenticationError("Missing required environment variables") + session = await self._get_session() - + try: response = await session.post( f"{self.base_url}/oauth/token", @@ -149,37 +124,33 @@ class HttpClient: } ) response.raise_for_status() - + data = response.json() self.access_token = data["access_token"] - expires_in = data.get("expires_in", 3600) self.token_expires_at = datetime.utcnow() + timedelta(seconds=expires_in) - + return self.access_token - + except httpx.HTTPStatusError as e: if e.response.status_code == 429: - raise RateLimitException() - raise AuthenticationException(f"Failed to obtain access token: {e}") - - async def get(self, url: str, headers: Dict[str, str] = None, params: Dict[str, Any] = None) -> Dict[str, Any]: - return await self._make_request("GET", url, headers=headers, params=params) - + raise RateLimitError("Rate limit exceeded") + raise AuthenticationError(f"Failed to obtain access token: {e}") + async def _make_request(self, method: str, url: str, headers: Dict[str, str] = None, params: Dict[str, Any] = None, json: Dict[str, Any] = None, retry_count: int = 0) -> Dict[str, Any]: session = await self._get_session() access_token = await self._ensure_access_token() - + request_headers = { "Authorization": f"Bearer {access_token}", "Content-Type": "application/json" } - + if headers: request_headers.update(headers) - + try: if method == "GET": response = await session.get(url, headers=request_headers, params=params) @@ -187,53 +158,41 @@ class HttpClient: response = await session.post(url, headers=request_headers, json=json) else: raise ValueError(f"Unsupported HTTP method: {method}") - + response.raise_for_status() return response.json() - + except httpx.HTTPStatusError as e: if e.response.status_code == 429: - raise RateLimitException() + raise RateLimitError("Rate limit exceeded") elif e.response.status_code == 401 and retry_count < 1: - await self._invalidate_token() + self.access_token = None + self.token_expires_at = None return await self._make_request(method, url, headers=headers, params=params, json=json, retry_count=retry_count + 1) else: - raise HttpClientException(url, e.response.status_code, str(e)) - - async def _invalidate_token(self): - self.access_token = None - self.token_expires_at = None - - async def post(self, url: str, headers: Dict[str, str] = None, json: Dict[str, Any] = None) -> Dict[str, Any]: - return await self._make_request("POST", url, headers=headers, json=json) - - async def close(self) -> None: - if self.session and not self.session.is_closed: - await self.session.aclose() + raise ConnectionServiceError(f"HTTP request failed: {e}") -class ConnectionRepository: - def __init__(self, http_client: HttpClient, logger: Logger): - self._http_client = http_client - self._logger = logger + async def get_connections_for_user(self, external_user_id: ExternalUserId) -> List[Connection]: + logger.info(f"Getting connections for user: {external_user_id.value}") - async def get_by_external_user_id(self, external_user_id: ExternalUserId) -> List[Connection]: project_id = os.getenv("PIPEDREAM_PROJECT_ID") environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development") - + if not project_id: - raise HttpClientException("Missing PIPEDREAM_PROJECT_ID", 500, "Configuration error") - - url = f"{self._http_client.base_url}/connect/{project_id}/accounts" + logger.error("Missing PIPEDREAM_PROJECT_ID environment variable") + return [] + + url = f"{self.base_url}/connect/{project_id}/accounts" params = {"external_id": external_user_id.value} headers = {"X-PD-Environment": environment} - + try: - data = await self._http_client.get(url, headers=headers, params=params) - + data = await self._make_request("GET", url, headers=headers, params=params) + connections = [] accounts = data.get("data", []) - + for account in accounts: app_data = account.get("app", {}) if app_data: @@ -241,12 +200,12 @@ class ConnectionRepository: auth_type_str = app_data.get("auth_type", "oauth") auth_type = AuthType(auth_type_str) except ValueError: - self._logger.warning(f"Unknown auth type '{auth_type_str}', using CUSTOM") + logger.warning(f"Unknown auth type '{auth_type_str}', using CUSTOM") auth_type = AuthType.CUSTOM - + app = App( name=app_data.get("name", "Unknown"), - slug=AppSlug(app_data.get("name_slug", "")), + slug=app_data.get("name_slug", ""), description=app_data.get("description", ""), category=app_data.get("category", "Other"), logo_url=app_data.get("img_src"), @@ -256,45 +215,47 @@ class ConnectionRepository: tags=app_data.get("tags", []), featured_weight=app_data.get("featured_weight", 0) ) - + connection = Connection( - external_user_id=external_user_id, + external_user_id=external_user_id.value, app=app, created_at=datetime.utcnow(), updated_at=datetime.utcnow(), is_active=True ) connections.append(connection) - - self._logger.info(f"Retrieved {len(connections)} connections for user: {external_user_id.value}") + + logger.info(f"Retrieved {len(connections)} connections for user: {external_user_id.value}") return connections - + except Exception as e: - self._logger.error(f"Error getting connections: {str(e)}") + logger.error(f"Error getting connections: {str(e)}") return [] -class ConnectionService: - def __init__(self, logger: Optional[Logger] = None): - self._logger = logger or logging.getLogger(__name__) - self._http_client = HttpClient() - self._connection_repo = ConnectionRepository(self._http_client, self._logger) - - async def get_connections_for_user(self, external_user_id: ExternalUserId) -> List[Connection]: - self._logger.info(f"Getting connections for user: {external_user_id.value}") - - connections = await self._connection_repo.get_by_external_user_id(external_user_id) - - self._logger.info(f"Found {len(connections)} connections for user: {external_user_id.value}") - return connections - async def has_connection(self, external_user_id: ExternalUserId, app_slug: AppSlug) -> bool: - connections = await self._connection_repo.get_by_external_user_id(external_user_id) - + connections = await self.get_connections_for_user(external_user_id) + for connection in connections: - if connection.app.slug == app_slug and connection.is_active: + if connection.app.slug == app_slug.value and connection.is_active: return True - + return False - + async def close(self): - await self._http_client.close() \ No newline at end of file + if self.session and not self.session.is_closed: + await self.session.aclose() + + +_connection_service = None + +def get_connection_service() -> ConnectionService: + global _connection_service + if _connection_service is None: + _connection_service = ConnectionService() + return _connection_service + + +PipedreamException = ConnectionServiceError +HttpClientException = ConnectionServiceError +AuthenticationException = AuthenticationError +RateLimitException = RateLimitError \ No newline at end of file diff --git a/backend/pipedream/connection_token_service.py b/backend/pipedream/connection_token_service.py index 0b707535..fd39597a 100644 --- a/backend/pipedream/connection_token_service.py +++ b/backend/pipedream/connection_token_service.py @@ -1,65 +1,48 @@ import os -import logging import re -import httpx -from typing import Dict, Any, Optional, Protocol -from dataclasses import dataclass from datetime import datetime, timedelta +from typing import Dict, Any, Optional + +import httpx +from utils.logger import logger + + +class ConnectionTokenServiceError(Exception): + pass + +class AuthenticationError(ConnectionTokenServiceError): + pass + +class RateLimitError(ConnectionTokenServiceError): + pass + -@dataclass(frozen=True) class ExternalUserId: - value: str - def __post_init__(self): - if not self.value or not isinstance(self.value, str): + def __init__(self, value: str): + if not value or not isinstance(value, str): raise ValueError("ExternalUserId must be a non-empty string") - if len(self.value) > 255: + if len(value) > 255: raise ValueError("ExternalUserId must be less than 255 characters") + self.value = value + -@dataclass(frozen=True) class AppSlug: - value: str - def __post_init__(self): - if not self.value or not isinstance(self.value, str): + def __init__(self, value: str): + if not value or not isinstance(value, str): raise ValueError("AppSlug must be a non-empty string") - if not re.match(r'^[a-z0-9_-]+$', self.value): + if not re.match(r'^[a-z0-9_-]+$', value): raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores") + self.value = value -class PipedreamException(Exception): - def __init__(self, message: str, error_code: str = None): - super().__init__(message) - self.error_code = error_code - self.message = message -class AuthenticationException(PipedreamException): - def __init__(self, reason: str): - super().__init__(f"Authentication failed: {reason}", "AUTHENTICATION_ERROR") - self.reason = reason - -class HttpClientException(PipedreamException): - def __init__(self, url: str, status_code: int, reason: str): - super().__init__(f"HTTP request to {url} failed with status {status_code}: {reason}", "HTTP_CLIENT_ERROR") - self.url = url - self.status_code = status_code - self.reason = reason - -class RateLimitException(PipedreamException): - def __init__(self, retry_after: int = None): - super().__init__("Rate limit exceeded", "RATE_LIMIT_EXCEEDED") - self.retry_after = retry_after - -class Logger(Protocol): - def info(self, message: str) -> None: ... - def warning(self, message: str) -> None: ... - def error(self, message: str) -> None: ... - def debug(self, message: str) -> None: ... - -class HttpClient: - def __init__(self): +class ConnectionTokenService: + def __init__(self, logger=None): + self._logger = logger or logger self.base_url = "https://api.pipedream.com/v1" - self.session: Optional[httpx.AsyncClient] = None - self.access_token: Optional[str] = None - self.token_expires_at: Optional[datetime] = None - + self.session = None + self.access_token = None + self.token_expires_at = None + async def _get_session(self) -> httpx.AsyncClient: if self.session is None or self.session.is_closed: self.session = httpx.AsyncClient( @@ -67,23 +50,24 @@ class HttpClient: headers={"User-Agent": "Suna-Pipedream-Client/1.0"} ) return self.session - + async def _ensure_access_token(self) -> str: if self.access_token and self.token_expires_at: if datetime.utcnow() < (self.token_expires_at - timedelta(minutes=5)): return self.access_token + return await self._fetch_fresh_token() - + async def _fetch_fresh_token(self) -> str: project_id = os.getenv("PIPEDREAM_PROJECT_ID") client_id = os.getenv("PIPEDREAM_CLIENT_ID") client_secret = os.getenv("PIPEDREAM_CLIENT_SECRET") - + if not all([project_id, client_id, client_secret]): - raise AuthenticationException("Missing required environment variables") - + raise AuthenticationError("Missing required environment variables") + session = await self._get_session() - + try: response = await session.post( f"{self.base_url}/oauth/token", @@ -94,86 +78,93 @@ class HttpClient: } ) response.raise_for_status() - + data = response.json() self.access_token = data["access_token"] expires_in = data.get("expires_in", 3600) self.token_expires_at = datetime.utcnow() + timedelta(seconds=expires_in) - + return self.access_token - + except httpx.HTTPStatusError as e: if e.response.status_code == 429: - raise RateLimitException() - raise AuthenticationException(f"Failed to obtain access token: {e}") - - async def post(self, url: str, headers: Dict[str, str] = None, json: Dict[str, Any] = None) -> Dict[str, Any]: + raise RateLimitError("Rate limit exceeded") + raise AuthenticationError(f"Failed to obtain access token: {e}") + + async def _make_request(self, url: str, headers: Dict[str, str] = None, json: Dict[str, Any] = None) -> Dict[str, Any]: session = await self._get_session() access_token = await self._ensure_access_token() - + request_headers = { "Authorization": f"Bearer {access_token}", "Content-Type": "application/json" } - + if headers: request_headers.update(headers) - + try: response = await session.post(url, headers=request_headers, json=json) response.raise_for_status() return response.json() except httpx.HTTPStatusError as e: if e.response.status_code == 429: - raise RateLimitException() - raise HttpClientException(url, e.response.status_code, str(e)) - - async def close(self) -> None: - if self.session and not self.session.is_closed: - await self.session.aclose() - -class ConnectionTokenService: - def __init__(self, logger: Optional[Logger] = None): - self._logger = logger or logging.getLogger(__name__) - self._http_client = HttpClient() + raise RateLimitError("Rate limit exceeded") + raise ConnectionTokenServiceError(f"HTTP request failed: {e}") async def create(self, external_user_id: ExternalUserId, app: Optional[AppSlug] = None) -> Dict[str, Any]: project_id = os.getenv("PIPEDREAM_PROJECT_ID") environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development") - + if not project_id: - raise AuthenticationException("Missing PIPEDREAM_PROJECT_ID") - - url = f"{self._http_client.base_url}/connect/{project_id}/tokens" - + raise AuthenticationError("Missing PIPEDREAM_PROJECT_ID") + + url = f"{self.base_url}/connect/{project_id}/tokens" + payload = { "external_user_id": external_user_id.value } - + if app: payload["app"] = app.value - + headers = { "X-PD-Environment": environment } - - self._logger.info(f"Creating connection token for user: {external_user_id.value}") - + + logger.info(f"Creating connection token for user: {external_user_id.value}") + try: - data = await self._http_client.post(url, headers=headers, json=payload) - + data = await self._make_request(url, headers=headers, json=payload) + if app and "connect_link_url" in data: link = data["connect_link_url"] if "app=" not in link: separator = "&" if "?" in link else "?" data["connect_link_url"] = f"{link}{separator}app={app.value}" - - self._logger.info(f"Successfully created connection token for user: {external_user_id.value}") + + logger.info(f"Successfully created connection token for user: {external_user_id.value}") return data - + except Exception as e: - self._logger.error(f"Error creating connection token: {str(e)}") + logger.error(f"Error creating connection token: {str(e)}") raise - + async def close(self): - await self._http_client.close() \ No newline at end of file + if self.session and not self.session.is_closed: + await self.session.aclose() + + +_connection_token_service = None + +def get_connection_token_service() -> ConnectionTokenService: + global _connection_token_service + if _connection_token_service is None: + _connection_token_service = ConnectionTokenService() + return _connection_token_service + + +PipedreamException = ConnectionTokenServiceError +AuthenticationException = AuthenticationError +HttpClientException = ConnectionTokenServiceError +RateLimitException = RateLimitError \ No newline at end of file diff --git a/backend/pipedream/mcp_service.py b/backend/pipedream/mcp_service.py index d0386408..b069d3ba 100644 --- a/backend/pipedream/mcp_service.py +++ b/backend/pipedream/mcp_service.py @@ -1,39 +1,14 @@ -from typing import List, Optional, Protocol, Dict, Any -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from enum import Enum +import json import os -import logging import re +from dataclasses import dataclass, field +from datetime import datetime, timezone, timedelta +from typing import List, Optional, Dict, Any +from enum import Enum + import httpx -import asyncio +from utils.logger import logger -@dataclass(frozen=True) -class ExternalUserId: - value: str - def __post_init__(self): - if not self.value or not isinstance(self.value, str): - raise ValueError("ExternalUserId must be a non-empty string") - if len(self.value) > 255: - raise ValueError("ExternalUserId must be less than 255 characters") - -@dataclass(frozen=True) -class AppSlug: - value: str - def __post_init__(self): - if not self.value or not isinstance(self.value, str): - raise ValueError("AppSlug must be a non-empty string") - if not re.match(r'^[a-z0-9_-]+$', self.value): - raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores") - -@dataclass(frozen=True) -class MCPServerUrl: - value: str - def __post_init__(self): - if not self.value or not isinstance(self.value, str): - raise ValueError("MCPServerUrl must be a non-empty string") - if not self.value.startswith(('http://', 'https://')): - raise ValueError("MCPServerUrl must be a valid HTTP/HTTPS URL") class ConnectionStatus(Enum): CONNECTED = "connected" @@ -41,84 +16,76 @@ class ConnectionStatus(Enum): ERROR = "error" PENDING = "pending" + @dataclass class MCPTool: name: str description: str input_schema: Dict[str, Any] - - def is_valid(self) -> bool: - return bool(self.name and self.description and self.input_schema) + @dataclass class MCPServer: - app_slug: AppSlug + app_slug: str app_name: str - server_url: MCPServerUrl + server_url: str project_id: str environment: str - external_user_id: ExternalUserId + external_user_id: str oauth_app_id: Optional[str] = None status: ConnectionStatus = ConnectionStatus.DISCONNECTED available_tools: List[MCPTool] = field(default_factory=list) error_message: Optional[str] = None - + def is_connected(self) -> bool: return self.status == ConnectionStatus.CONNECTED - - def add_tool(self, tool: MCPTool) -> None: - if tool.is_valid(): - self.available_tools.append(tool) - + def get_tool_count(self) -> int: return len(self.available_tools) -class PipedreamException(Exception): - def __init__(self, message: str, error_code: str = None): - super().__init__(message) - self.error_code = error_code - self.message = message -class MCPServerNotAvailableError(PipedreamException): - def __init__(self, message: str = "MCP server is not available"): - super().__init__(message, "MCP_SERVER_NOT_AVAILABLE") +class MCPServiceError(Exception): + pass -class MCPConnectionError(PipedreamException): - def __init__(self, app_slug: str, reason: str): - super().__init__(f"MCP connection failed for {app_slug}: {reason}", "MCP_CONNECTION_ERROR") - self.app_slug = app_slug - self.reason = reason +class MCPConnectionError(MCPServiceError): + pass -class AuthenticationException(PipedreamException): - def __init__(self, reason: str): - super().__init__(f"Authentication failed: {reason}", "AUTHENTICATION_ERROR") - self.reason = reason +class MCPServerNotAvailableError(MCPServiceError): + pass -class HttpClientException(PipedreamException): - def __init__(self, url: str, status_code: int, reason: str): - super().__init__(f"HTTP request to {url} failed with status {status_code}: {reason}", "HTTP_CLIENT_ERROR") - self.url = url - self.status_code = status_code - self.reason = reason +class AuthenticationError(MCPServiceError): + pass -class RateLimitException(PipedreamException): - def __init__(self, retry_after: int = None): - super().__init__("Rate limit exceeded", "RATE_LIMIT_EXCEEDED") - self.retry_after = retry_after +class RateLimitError(MCPServiceError): + pass -class Logger(Protocol): - def info(self, message: str) -> None: ... - def warning(self, message: str) -> None: ... - def error(self, message: str) -> None: ... - def debug(self, message: str) -> None: ... -class HttpClient: - def __init__(self): +class ExternalUserId: + def __init__(self, value: str): + if not value or not isinstance(value, str): + raise ValueError("ExternalUserId must be a non-empty string") + if len(value) > 255: + raise ValueError("ExternalUserId must be less than 255 characters") + self.value = value + + +class AppSlug: + def __init__(self, value: str): + if not value or not isinstance(value, str): + raise ValueError("AppSlug must be a non-empty string") + if not re.match(r'^[a-z0-9_-]+$', value): + raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores") + self.value = value + + +class MCPService: + def __init__(self, logger=None): + self._logger = logger or logger self.base_url = "https://api.pipedream.com/v1" - self.session: Optional[httpx.AsyncClient] = None - self.access_token: Optional[str] = None - self.token_expires_at: Optional[datetime] = None - + self.session = None + self.access_token = None + self.token_expires_at = None + async def _get_session(self) -> httpx.AsyncClient: if self.session is None or self.session.is_closed: self.session = httpx.AsyncClient( @@ -126,23 +93,27 @@ class HttpClient: headers={"User-Agent": "Suna-Pipedream-Client/1.0"} ) return self.session - + async def _ensure_access_token(self) -> str: if self.access_token and self.token_expires_at: if datetime.utcnow() < (self.token_expires_at - timedelta(minutes=5)): return self.access_token + else: + self.access_token = None + self.token_expires_at = None + return await self._fetch_fresh_token() - + async def _fetch_fresh_token(self) -> str: project_id = os.getenv("PIPEDREAM_PROJECT_ID") client_id = os.getenv("PIPEDREAM_CLIENT_ID") client_secret = os.getenv("PIPEDREAM_CLIENT_SECRET") - + if not all([project_id, client_id, client_secret]): - raise AuthenticationException("Missing required environment variables") - + raise AuthenticationError("Missing required environment variables") + session = await self._get_session() - + try: response = await session.post( f"{self.base_url}/oauth/token", @@ -153,227 +124,193 @@ class HttpClient: } ) response.raise_for_status() - + data = response.json() self.access_token = data["access_token"] expires_in = data.get("expires_in", 3600) self.token_expires_at = datetime.utcnow() + timedelta(seconds=expires_in) - + return self.access_token - + except httpx.HTTPStatusError as e: if e.response.status_code == 429: - raise RateLimitException() - raise AuthenticationException(f"Failed to obtain access token: {e}") - - async def get(self, url: str, headers: Dict[str, str] = None, params: Dict[str, Any] = None) -> Dict[str, Any]: + raise RateLimitError("Rate limit exceeded") + raise AuthenticationError(f"Failed to obtain access token: {e}") + + async def _make_request(self, url: str, headers: Dict[str, str] = None, params: Dict[str, Any] = None) -> Dict[str, Any]: session = await self._get_session() access_token = await self._ensure_access_token() - + request_headers = { "Authorization": f"Bearer {access_token}", "Content-Type": "application/json" } - + if headers: request_headers.update(headers) - + try: response = await session.get(url, headers=request_headers, params=params) response.raise_for_status() return response.json() except httpx.HTTPStatusError as e: if e.response.status_code == 429: - raise RateLimitException() - raise HttpClientException(url, e.response.status_code, str(e)) - - async def close(self) -> None: - if self.session and not self.session.is_closed: - await self.session.aclose() + raise RateLimitError("Rate limit exceeded") + raise MCPServiceError(f"HTTP request failed: {e}") -class MCPServerRepository: - def __init__(self, http_client: HttpClient, logger: Logger): - self._http_client = http_client - self._logger = logger - - async def discover_for_user(self, external_user_id: ExternalUserId, app_slug: Optional[AppSlug] = None) -> List[MCPServer]: + async def _fetch_server_tools(self, external_user_id: str, app_slug: str) -> List[MCPTool]: project_id = os.getenv("PIPEDREAM_PROJECT_ID") environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development") - + if not project_id: - self._logger.error("Missing PIPEDREAM_PROJECT_ID environment variable") return [] - - self._logger.info(f"Discovering MCP servers for user: {external_user_id.value}, app_slug: {app_slug.value if app_slug else 'all'}") - - url = f"{self._http_client.base_url}/connect/{project_id}/accounts" + + url = f"{self.base_url}/connect/{project_id}/tools" + params = { + "app": app_slug, + "external_id": external_user_id + } + headers = {"X-PD-Environment": environment} + + try: + data = await self._make_request(url, headers=headers, params=params) + tools_data = data.get("data", []) + + tools = [] + for tool_data in tools_data: + if tool_data.get("name") or tool_data.get("key"): + tool = MCPTool( + name=tool_data.get("name") or tool_data.get("key", ""), + description=tool_data.get("description", f"Tool from {app_slug}"), + input_schema=tool_data.get("inputSchema") or tool_data.get("props", {}) + ) + tools.append(tool) + + return tools + + except Exception as e: + logger.error(f"Error fetching tools for {app_slug}: {str(e)}") + return [] + + async def discover_servers_for_user(self, external_user_id: ExternalUserId, app_slug: Optional[AppSlug] = None) -> List[MCPServer]: + project_id = os.getenv("PIPEDREAM_PROJECT_ID") + environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development") + + if not project_id: + logger.error("Missing PIPEDREAM_PROJECT_ID environment variable") + return [] + + logger.info(f"Discovering MCP servers for user: {external_user_id.value}, app_slug: {app_slug.value if app_slug else 'all'}") + + url = f"{self.base_url}/connect/{project_id}/accounts" params = {"external_id": external_user_id.value} headers = {"X-PD-Environment": environment} - + try: - data = await self._http_client.get(url, headers=headers, params=params) - + data = await self._make_request(url, headers=headers, params=params) + accounts = data.get("data", []) if not accounts: - self._logger.info(f"No connected apps found for user: {external_user_id.value}") + logger.info(f"No connected apps found for user: {external_user_id.value}") return [] - + user_apps = [account.get("app") for account in accounts if account.get("app")] - + if app_slug: user_apps = [app for app in user_apps if app.get("name_slug") == app_slug.value] - + servers = [] for app in user_apps: try: server = MCPServer( - app_slug=AppSlug(app.get("name_slug", "")), + app_slug=app.get("name_slug", ""), app_name=app.get("name", "Unknown"), - server_url=MCPServerUrl("https://remote.mcp.pipedream.net"), + server_url="https://remote.mcp.pipedream.net", project_id=project_id, environment=environment, - external_user_id=external_user_id, + external_user_id=external_user_id.value, status=ConnectionStatus.CONNECTED ) - + try: - self._logger.info(f"Attempting to fetch tools for {app.get('name_slug')}...") - tools = await self._fetch_server_tools(external_user_id, server.app_slug) - - for tool_data in tools: - tool = MCPTool( - name=tool_data.name, - description=tool_data.description or f"Tool from {app.get('name', 'Unknown')}", - input_schema=tool_data.inputSchema if hasattr(tool_data, 'inputSchema') else {} - ) - server.add_tool(tool) - - self._logger.info(f"Successfully fetched {len(tools)} tools for {app.get('name_slug')} MCP server") - except Exception as tool_error: - self._logger.error(f"Could not fetch tools for {app.get('name_slug')}: {str(tool_error)}") - import traceback - self._logger.error(f"Traceback: {traceback.format_exc()}") - + logger.info(f"Attempting to fetch tools for {app.get('name_slug')}...") + tools = await self._fetch_server_tools(external_user_id.value, server.app_slug) + server.available_tools = tools + + logger.info(f"Successfully fetched {len(tools)} tools for app: {app.get('name_slug')}") + + except Exception as e: + logger.error(f"Error fetching tools for {app.get('name_slug')}: {str(e)}") + server.available_tools = [] + servers.append(server) + except Exception as e: - self._logger.warning(f"Error creating MCP server for app {app.get('name_slug', 'unknown')}: {str(e)}") + logger.error(f"Error creating server for app {app.get('name_slug', 'unknown')}: {str(e)}") continue - - self._logger.info(f"Discovered {len(servers)} MCP servers") + + logger.info(f"Successfully discovered {len(servers)} MCP servers") return servers - + except Exception as e: - self._logger.error(f"Error discovering MCP servers: {str(e)}") + logger.error(f"Error discovering servers for user {external_user_id.value}: {str(e)}") return [] - async def _fetch_server_tools(self, external_user_id: ExternalUserId, app_slug: AppSlug) -> List: - try: - from mcp import ClientSession - from mcp.client.streamable_http import streamablehttp_client - - access_token = await self._http_client._ensure_access_token() - project_id = os.getenv("PIPEDREAM_PROJECT_ID") - environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development") - - headers = { - "Authorization": f"Bearer {access_token}", - "x-pd-project-id": project_id, - "x-pd-environment": environment, - "x-pd-external-user-id": external_user_id.value, - "x-pd-app-slug": app_slug.value, - } - - if hasattr(self._http_client, 'rate_limit_token') and self._http_client.rate_limit_token: - headers["x-pd-rate-limit"] = self._http_client.rate_limit_token - - url = "https://remote.mcp.pipedream.net" - - async with streamablehttp_client(url, headers=headers) as (read_stream, write_stream, _): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - tools_result = await session.list_tools() - tools = tools_result.tools if hasattr(tools_result, 'tools') else tools_result - self._logger.info(f"Successfully fetched {len(tools) if tools else 0} tools from MCP server") - return tools - - except Exception as e: - self._logger.error(f"Error fetching tools for {app_slug.value}: {str(e)}") - import traceback - self._logger.error(f"Full traceback: {traceback.format_exc()}") - return [] - - async def test_connection(self, server: MCPServer) -> MCPServer: + async def test_server_connection(self, server: MCPServer) -> MCPServer: + logger.info(f"Testing MCP server connection: {server.app_name}") + server.status = ConnectionStatus.CONNECTED + + if server.is_connected(): + logger.info(f"MCP server {server.app_name} connected successfully with {server.get_tool_count()} tools") + else: + logger.warning(f"MCP server {server.app_name} connection failed: {server.error_message}") + return server async def create_connection(self, external_user_id: ExternalUserId, app_slug: AppSlug, oauth_app_id: Optional[str] = None) -> MCPServer: project_id = os.getenv("PIPEDREAM_PROJECT_ID") environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development") - + if not project_id: - raise MCPConnectionError(app_slug.value, "Missing PIPEDREAM_PROJECT_ID") - + raise MCPConnectionError("Missing PIPEDREAM_PROJECT_ID") + + logger.info(f"Creating MCP connection for user: {external_user_id.value}, app: {app_slug.value}") + server = MCPServer( - app_slug=app_slug, + app_slug=app_slug.value, app_name=app_slug.value.replace('_', ' ').title(), - server_url=MCPServerUrl("https://remote.mcp.pipedream.net"), + server_url="https://remote.mcp.pipedream.net", project_id=project_id, environment=environment, - external_user_id=external_user_id, + external_user_id=external_user_id.value, oauth_app_id=oauth_app_id, status=ConnectionStatus.CONNECTED ) - - return server -class MCPService: - def __init__(self, logger: Optional[Logger] = None): - self._logger = logger or logging.getLogger(__name__) - self._http_client = HttpClient() - self._mcp_server_repo = MCPServerRepository(self._http_client, self._logger) - - async def discover_servers_for_user( - self, - external_user_id: ExternalUserId, - app_slug: Optional[AppSlug] = None - ) -> List[MCPServer]: - self._logger.info(f"Discovering MCP servers for user: {external_user_id.value}") - - servers = await self._mcp_server_repo.discover_for_user(external_user_id, app_slug) - - connected_count = sum(1 for server in servers if server.is_connected()) - self._logger.info(f"Discovered {len(servers)} MCP servers ({connected_count} connected)") - - return servers - - async def test_server_connection(self, server: MCPServer) -> MCPServer: - self._logger.info(f"Testing MCP server connection: {server.app_name}") - - tested_server = await self._mcp_server_repo.test_connection(server) - - if tested_server.is_connected(): - self._logger.info(f"MCP server {server.app_name} connected successfully with {tested_server.get_tool_count()} tools") - else: - self._logger.warning(f"MCP server {server.app_name} connection failed: {tested_server.error_message}") - - return tested_server - - async def create_connection( - self, - external_user_id: ExternalUserId, - app_slug: AppSlug, - oauth_app_id: Optional[str] = None - ) -> MCPServer: - self._logger.info(f"Creating MCP connection for user: {external_user_id.value}, app: {app_slug.value}") - - server = await self._mcp_server_repo.create_connection(external_user_id, app_slug, oauth_app_id) - if server.is_connected(): - self._logger.info(f"Successfully created MCP connection for {app_slug.value} with {server.get_tool_count()} tools") + logger.info(f"Successfully created MCP connection for {app_slug.value} with {server.get_tool_count()} tools") else: - self._logger.error(f"Failed to create MCP connection for {app_slug.value}: {server.error_message}") - + logger.error(f"Failed to create MCP connection for {app_slug.value}: {server.error_message}") + return server - + async def close(self): - await self._http_client.close() \ No newline at end of file + if self.session and not self.session.is_closed: + await self.session.aclose() + + +_mcp_service = None + +def get_mcp_service() -> MCPService: + global _mcp_service + if _mcp_service is None: + _mcp_service = MCPService() + return _mcp_service + + +PipedreamException = MCPServiceError +MCPServerNotAvailableError = MCPServerNotAvailableError +AuthenticationException = AuthenticationError +HttpClientException = MCPServiceError +RateLimitException = RateLimitError \ No newline at end of file diff --git a/backend/pipedream/profile_service.py b/backend/pipedream/profile_service.py index 8bcbd3fb..96606cb1 100644 --- a/backend/pipedream/profile_service.py +++ b/backend/pipedream/profile_service.py @@ -1,461 +1,191 @@ -from typing import List, Optional, Dict, Any, Protocol -from uuid import UUID -from dataclasses import dataclass, field -from datetime import datetime -from enum import Enum import json import hashlib import re -import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import List, Optional, Dict, Any +from uuid import uuid4, UUID -@dataclass(frozen=True) -class ExternalUserId: - value: str - def __post_init__(self): - if not self.value or not isinstance(self.value, str): - raise ValueError("ExternalUserId must be a non-empty string") - if len(self.value) > 255: - raise ValueError("ExternalUserId must be less than 255 characters") +from services.supabase import DBConnection +from utils.logger import logger -@dataclass(frozen=True) -class AppSlug: - value: str - def __post_init__(self): - if not self.value or not isinstance(self.value, str): - raise ValueError("AppSlug must be a non-empty string") - if not re.match(r'^[a-z0-9_-]+$', self.value): - raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores") - -@dataclass(frozen=True) -class ProfileName: - value: str - def __post_init__(self): - if not self.value or not isinstance(self.value, str): - raise ValueError("ProfileName must be a non-empty string") - if len(self.value) > 100: - raise ValueError("ProfileName must be less than 100 characters") - -@dataclass(frozen=True) -class EncryptedConfig: - value: str - def __post_init__(self): - if not self.value or not isinstance(self.value, str): - raise ValueError("EncryptedConfig must be a non-empty string") - -@dataclass(frozen=True) -class ConfigHash: - value: str - def __post_init__(self): - if not self.value or not isinstance(self.value, str): - raise ValueError("ConfigHash must be a non-empty string") - if len(self.value) != 64: - raise ValueError("ConfigHash must be a 64-character SHA256 hash") - - @classmethod - def from_config(cls, config: str) -> 'ConfigHash': - hash_value = hashlib.sha256(config.encode()).hexdigest() - return cls(hash_value) @dataclass class Profile: - profile_id: UUID - account_id: UUID + profile_id: str + account_id: str mcp_qualified_name: str - profile_name: ProfileName + profile_name: str display_name: str - encrypted_config: EncryptedConfig - config_hash: ConfigHash - app_slug: AppSlug + encrypted_config: str + config_hash: str + app_slug: str app_name: str - external_user_id: ExternalUserId + external_user_id: str enabled_tools: List[str] = field(default_factory=list) is_active: bool = True is_default: bool = False is_connected: bool = False - created_at: datetime = field(default_factory=datetime.utcnow) - updated_at: datetime = field(default_factory=datetime.utcnow) + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) last_used_at: Optional[datetime] = None - - def update_last_used(self) -> None: - self.last_used_at = datetime.utcnow() - self.updated_at = datetime.utcnow() - - def activate(self) -> None: - self.is_active = True - self.updated_at = datetime.utcnow() - - def deactivate(self) -> None: - self.is_active = False - self.updated_at = datetime.utcnow() - - def set_as_default(self) -> None: - self.is_default = True - self.updated_at = datetime.utcnow() - - def unset_as_default(self) -> None: - self.is_default = False - self.updated_at = datetime.utcnow() - - def update_connection_status(self, is_connected: bool) -> None: - self.is_connected = is_connected - self.updated_at = datetime.utcnow() - - def enable_tool(self, tool_name: str) -> None: - if tool_name not in self.enabled_tools: - self.enabled_tools.append(tool_name) - self.updated_at = datetime.utcnow() - - def disable_tool(self, tool_name: str) -> None: - if tool_name in self.enabled_tools: - self.enabled_tools.remove(tool_name) - self.updated_at = datetime.utcnow() - - def get_mcp_qualified_name(self) -> str: - return f"pipedream:{self.app_slug.value}" + description: Optional[str] = None + oauth_app_id: Optional[str] = None -class PipedreamException(Exception): - def __init__(self, message: str, error_code: str = None): - super().__init__(message) - self.error_code = error_code - self.message = message + def to_dict(self) -> Dict[str, Any]: + return { + 'profile_id': self.profile_id, + 'account_id': self.account_id, + 'mcp_qualified_name': self.mcp_qualified_name, + 'profile_name': self.profile_name, + 'display_name': self.display_name, + 'encrypted_config': self.encrypted_config, + 'config_hash': self.config_hash, + 'app_slug': self.app_slug, + 'app_name': self.app_name, + 'external_user_id': self.external_user_id, + 'enabled_tools': self.enabled_tools, + 'is_active': self.is_active, + 'is_default': self.is_default, + 'is_connected': self.is_connected, + 'created_at': self.created_at.isoformat(), + 'updated_at': self.updated_at.isoformat(), + 'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None, + 'description': self.description, + 'oauth_app_id': self.oauth_app_id + } -class DomainException(PipedreamException): + +class ProfileServiceError(Exception): pass -class ProfileNotFoundError(DomainException): - def __init__(self, profile_id: str): - super().__init__(f"Profile with ID {profile_id} not found", "PROFILE_NOT_FOUND") - self.profile_id = profile_id +class ProfileNotFoundError(ProfileServiceError): + pass -class ProfileAlreadyExistsError(DomainException): - def __init__(self, profile_name: str, app_slug: str): - super().__init__(f"Profile '{profile_name}' already exists for app '{app_slug}'", "PROFILE_ALREADY_EXISTS") - self.profile_name = profile_name - self.app_slug = app_slug +class ProfileAlreadyExistsError(ProfileServiceError): + pass -class DatabaseException(PipedreamException): - def __init__(self, operation: str, reason: str): - super().__init__(f"Database operation '{operation}' failed: {reason}", "DATABASE_ERROR") - self.operation = operation - self.reason = reason +class InvalidConfigError(ProfileServiceError): + pass -class EncryptionException(PipedreamException): - def __init__(self, operation: str, reason: str): - super().__init__(f"Encryption operation '{operation}' failed: {reason}", "ENCRYPTION_ERROR") - self.operation = operation - self.reason = reason +class EncryptionError(ProfileServiceError): + pass -class DatabaseConnection(Protocol): - async def client(self) -> Any: ... - -class Logger(Protocol): - def info(self, message: str) -> None: ... - def warning(self, message: str) -> None: ... - def error(self, message: str) -> None: ... - def debug(self, message: str) -> None: ... - - -class EncryptionService: - def encrypt(self, data: str) -> str: - try: - from utils.encryption import encrypt_data - return encrypt_data(data) - except Exception as e: - raise EncryptionException("encrypt", str(e)) - - def decrypt(self, encrypted_data: str) -> str: - try: - from utils.encryption import decrypt_data - return decrypt_data(encrypted_data) - except Exception as e: - raise EncryptionException("decrypt", str(e)) - -class ExternalUserIdService: - def generate(self, account_id: str, app_slug: AppSlug, profile_name: ProfileName) -> ExternalUserId: - combined = f"{account_id}:{app_slug.value}:{profile_name.value}" - hash_value = hashlib.sha256(combined.encode()).hexdigest()[:16] - external_user_id = f"suna_{hash_value}" - return ExternalUserId(external_user_id) - -class MCPQualifiedNameService: - def generate(self, app_slug: AppSlug) -> str: - return f"pipedream:{app_slug.value}" - -class ProfileConfigurationService: - def validate_config(self, config: Dict[str, Any]) -> bool: - required_keys = ["app_slug", "app_name", "external_user_id"] - return all(key in config for key in required_keys) - - def merge_config(self, existing_config: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]: - merged = existing_config.copy() - merged.update(updates) - return merged - -class ConnectionStatusService: - def __init__(self, logger: Logger): - self._logger = logger - - async def check_connection_status(self, profile: Profile) -> bool: - return True - - async def update_connection_status(self, profile: Profile) -> Profile: - try: - is_connected = await self.check_connection_status(profile) - profile.update_connection_status(is_connected) - return profile - except Exception as e: - self._logger.warning(f"Error updating connection status: {str(e)}") - profile.update_connection_status(False) - return profile - - -class ProfileRepository: - def __init__(self, db: DatabaseConnection, encryption_service: EncryptionService, logger: Logger): - self._db = db - self._encryption_service = encryption_service - self._logger = logger - - async def create(self, profile: Profile) -> Profile: - try: - client = await self._db.client - - config = { - "app_slug": profile.app_slug.value, - "app_name": profile.app_name, - "external_user_id": profile.external_user_id.value, - "enabled_tools": profile.enabled_tools, - "oauth_app_id": getattr(profile, 'oauth_app_id', None), - "description": getattr(profile, 'description', None) - } - - config_json = json.dumps(config) - encrypted_config = self._encryption_service.encrypt(config_json) - config_hash = ConfigHash.from_config(config_json) - - result = await client.table('user_mcp_credential_profiles').insert({ - 'account_id': str(profile.account_id), - 'mcp_qualified_name': profile.mcp_qualified_name, - 'profile_name': profile.profile_name.value, - 'display_name': profile.display_name, - 'encrypted_config': encrypted_config, - 'config_hash': config_hash.value, - 'is_active': profile.is_active, - 'is_default': profile.is_default, - 'created_at': profile.created_at.isoformat(), - 'updated_at': profile.updated_at.isoformat() - }).execute() - - if result.data: - profile_data = result.data[0] - return self._map_to_domain(profile_data, config) - - raise DatabaseException("create", "No data returned from insert") - - except Exception as e: - self._logger.error(f"Error creating profile: {str(e)}") - raise DatabaseException("create", str(e)) - - async def get_by_id(self, account_id: UUID, profile_id: UUID) -> Optional[Profile]: - try: - client = await self._db.client - - self._logger.debug(f"Querying profile: account_id={account_id}, profile_id={profile_id}") - - result = await client.table('user_mcp_credential_profiles').select('*').eq( - 'account_id', str(account_id) - ).eq('profile_id', str(profile_id)).single().execute() - - if result.data: - profile_data = result.data - self._logger.debug(f"Found profile: {profile_data.get('profile_name', 'unknown')}") - decrypted_config = self._encryption_service.decrypt(profile_data['encrypted_config']) - config = json.loads(decrypted_config) - return self._map_to_domain(profile_data, config) - - return None - - except Exception as e: - self._logger.error(f"Error getting profile by ID {profile_id} for user {account_id}: {str(e)}") - return None - - async def get_by_app_slug(self, account_id: UUID, app_slug: AppSlug, profile_name: Optional[ProfileName] = None) -> Optional[Profile]: - try: - client = await self._db.client - - mcp_qualified_name = f"pipedream:{app_slug.value}" - query = client.table('user_mcp_credential_profiles').select('*').eq( - 'account_id', str(account_id) - ).eq('mcp_qualified_name', mcp_qualified_name) - - if profile_name: - query = query.eq('profile_name', profile_name.value) - - result = await query.execute() - - if result.data: - if profile_name: - profile_data = result.data[0] - else: - profile_data = next((p for p in result.data if p.get('is_default')), result.data[0]) - - decrypted_config = self._encryption_service.decrypt(profile_data['encrypted_config']) - config = json.loads(decrypted_config) - return self._map_to_domain(profile_data, config) - - return None - - except Exception as e: - self._logger.error(f"Error getting profile by app slug: {str(e)}") - return None - - async def find_by_account(self, account_id: UUID, app_slug: Optional[AppSlug] = None, is_active: Optional[bool] = None) -> List[Profile]: - try: - client = await self._db.client - - query = client.table('user_mcp_credential_profiles').select('*').eq( - 'account_id', str(account_id) - ) - - if app_slug: - mcp_qualified_name = f"pipedream:{app_slug.value}" - query = query.eq('mcp_qualified_name', mcp_qualified_name) - else: - query = query.like('mcp_qualified_name', 'pipedream:%') - - if is_active is not None: - query = query.eq('is_active', is_active) - - result = await query.order('created_at', desc=True).execute() - - profiles = [] - for profile_data in result.data: - try: - decrypted_config = self._encryption_service.decrypt(profile_data['encrypted_config']) - config = json.loads(decrypted_config) - profile = self._map_to_domain(profile_data, config) - profiles.append(profile) - except Exception as e: - self._logger.error(f"Error decrypting profile config: {str(e)}") - continue - - return profiles - - except Exception as e: - self._logger.error(f"Error finding profiles by account: {str(e)}") - return [] - - async def update(self, profile: Profile) -> Profile: - try: - client = await self._db.client - - config = { - "app_slug": profile.app_slug.value, - "app_name": profile.app_name, - "external_user_id": profile.external_user_id.value, - "enabled_tools": profile.enabled_tools, - "oauth_app_id": getattr(profile, 'oauth_app_id', None), - "description": getattr(profile, 'description', None) - } - - config_json = json.dumps(config) - encrypted_config = self._encryption_service.encrypt(config_json) - config_hash = ConfigHash.from_config(config_json) - - result = await client.table('user_mcp_credential_profiles').update({ - 'profile_name': profile.profile_name.value, - 'display_name': profile.display_name, - 'encrypted_config': encrypted_config, - 'config_hash': config_hash.value, - 'is_active': profile.is_active, - 'is_default': profile.is_default, - 'updated_at': profile.updated_at.isoformat(), - 'last_used_at': profile.last_used_at.isoformat() if profile.last_used_at else None - }).eq('profile_id', str(profile.profile_id)).execute() - - if result.data: - return self._map_to_domain(result.data[0], config) - - raise DatabaseException("update", "No data returned from update") - - except Exception as e: - self._logger.error(f"Error updating profile: {str(e)}") - raise DatabaseException("update", str(e)) - - async def delete(self, account_id: UUID, profile_id: UUID) -> bool: - try: - client = await self._db.client - - result = await client.table('user_mcp_credential_profiles').delete().eq( - 'profile_id', str(profile_id) - ).eq('account_id', str(account_id)).execute() - - return len(result.data) > 0 - - except Exception as e: - self._logger.error(f"Error deleting profile: {str(e)}") - return False - - async def set_default(self, account_id: UUID, profile_id: UUID, mcp_qualified_name: str) -> None: - try: - client = await self._db.client - - await client.table('user_mcp_credential_profiles').update({ - 'is_default': False - }).eq('account_id', str(account_id)).eq('mcp_qualified_name', mcp_qualified_name).execute() - - await client.table('user_mcp_credential_profiles').update({ - 'is_default': True - }).eq('profile_id', str(profile_id)).execute() - - except Exception as e: - self._logger.error(f"Error setting default profile: {str(e)}") - raise DatabaseException("set_default", str(e)) - - def _map_to_domain(self, profile_data: dict, config: dict) -> Profile: - return Profile( - profile_id=UUID(profile_data['profile_id']), - account_id=UUID(profile_data['account_id']), - mcp_qualified_name=profile_data['mcp_qualified_name'], - profile_name=ProfileName(profile_data['profile_name']), - display_name=profile_data['display_name'], - encrypted_config=EncryptedConfig(profile_data['encrypted_config']), - config_hash=ConfigHash(profile_data['config_hash']), - app_slug=AppSlug(config['app_slug']), - app_name=config['app_name'], - external_user_id=ExternalUserId(config['external_user_id']), - enabled_tools=config.get('enabled_tools', []), - is_active=profile_data['is_active'], - is_default=profile_data['is_default'], - is_connected=False, - created_at=datetime.fromisoformat(profile_data['created_at']), - updated_at=datetime.fromisoformat(profile_data['updated_at']), - last_used_at=datetime.fromisoformat(profile_data['last_used_at']) if profile_data.get('last_used_at') else None - ) class ProfileService: - def __init__( - self, - db: Optional[DatabaseConnection] = None, - logger: Optional[Logger] = None - ): - self._logger = logger or logging.getLogger(__name__) + def __init__(self): + self.db = DBConnection() + self._connection_service = None + + async def _get_client(self): + return await self.db.client + + def _get_connection_service(self): + if self._connection_service is None: + from .connection_service import ConnectionService + from utils.logger import logger + self._connection_service = ConnectionService(logger=logger) + return self._connection_service + + async def _check_connection_status(self, external_user_id: str, app_slug: str) -> bool: + try: + from .connection_service import ExternalUserId, AppSlug + connection_service = self._get_connection_service() + return await connection_service.has_connection( + ExternalUserId(external_user_id), + AppSlug(app_slug) + ) + except Exception as e: + logger.error(f"Error checking connection status: {str(e)}") + return False + + def _validate_app_slug(self, app_slug: str) -> None: + if not app_slug or not isinstance(app_slug, str): + raise InvalidConfigError("App slug must be a non-empty string") + if not re.match(r'^[a-z0-9_-]+$', app_slug): + raise InvalidConfigError("App slug must contain only lowercase letters, numbers, hyphens, and underscores") + + def _validate_profile_name(self, profile_name: str) -> None: + if not profile_name or not isinstance(profile_name, str): + raise InvalidConfigError("Profile name must be a non-empty string") + if len(profile_name) > 100: + raise InvalidConfigError("Profile name must be less than 100 characters") + + def _generate_external_user_id(self, account_id: str, app_slug: str, profile_name: str) -> str: + combined = f"{account_id}:{app_slug}:{profile_name}" + hash_value = hashlib.sha256(combined.encode()).hexdigest()[:16] + return f"suna_{hash_value}" + + def _generate_config_hash(self, config_json: str) -> str: + return hashlib.sha256(config_json.encode()).hexdigest() + + def _encrypt_config(self, config_json: str) -> str: + try: + from utils.encryption import encrypt_data + return encrypt_data(config_json) + except Exception as e: + raise EncryptionError(f"Failed to encrypt config: {str(e)}") + + def _decrypt_config(self, encrypted_config: str) -> Dict[str, Any]: + try: + from utils.encryption import decrypt_data + decrypted_json = decrypt_data(encrypted_config) + return json.loads(decrypted_json) + except Exception as e: + raise EncryptionError(f"Failed to decrypt config: {str(e)}") + + def _build_config(self, app_slug: str, app_name: str, external_user_id: str, + enabled_tools: List[str], oauth_app_id: Optional[str] = None, + description: Optional[str] = None) -> Dict[str, Any]: + return { + "app_slug": app_slug, + "app_name": app_name, + "external_user_id": external_user_id, + "enabled_tools": enabled_tools, + "oauth_app_id": oauth_app_id, + "description": description + } + + async def _map_row_to_profile(self, row: Dict[str, Any]) -> Profile: + try: + config = self._decrypt_config(row['encrypted_config']) + except Exception: + config = { + "app_slug": "unknown", + "app_name": "Unknown App", + "external_user_id": "unknown", + "enabled_tools": [], + "oauth_app_id": None, + "description": None + } - if db is None: - from services.supabase import DBConnection - self._db = DBConnection() - else: - self._db = db + is_connected = await self._check_connection_status(config['external_user_id'], config['app_slug']) - self._encryption_service = EncryptionService() - self._profile_repo = ProfileRepository(self._db, self._encryption_service, self._logger) - self._external_user_id_service = ExternalUserIdService() - self._mcp_qualified_name_service = MCPQualifiedNameService() - self._profile_config_service = ProfileConfigurationService() - self._connection_status_service = ConnectionStatusService(self._logger) - + return Profile( + profile_id=row['profile_id'], + account_id=row['account_id'], + mcp_qualified_name=row['mcp_qualified_name'], + profile_name=row['profile_name'], + display_name=row['display_name'], + encrypted_config=row['encrypted_config'], + config_hash=row['config_hash'], + app_slug=config['app_slug'], + app_name=config['app_name'], + external_user_id=config['external_user_id'], + enabled_tools=config.get('enabled_tools', []), + is_active=row['is_active'], + is_default=row['is_default'], + is_connected=is_connected, + created_at=datetime.fromisoformat(row['created_at'].replace('Z', '+00:00')) if isinstance(row['created_at'], str) else row['created_at'], + updated_at=datetime.fromisoformat(row['updated_at'].replace('Z', '+00:00')) if isinstance(row['updated_at'], str) else row['updated_at'], + last_used_at=datetime.fromisoformat(row['last_used_at'].replace('Z', '+00:00')) if row.get('last_used_at') and isinstance(row['last_used_at'], str) else row.get('last_used_at'), + description=config.get('description'), + oauth_app_id=config.get('oauth_app_id') + ) + async def create_profile( self, - account_id: UUID, + account_id: str, profile_name: str, app_slug: str, app_name: str, @@ -465,130 +195,246 @@ class ProfileService: enabled_tools: Optional[List[str]] = None, external_user_id: Optional[str] = None ) -> Profile: - app_slug_vo = AppSlug(app_slug) - profile_name_vo = ProfileName(profile_name) + self._validate_app_slug(app_slug) + self._validate_profile_name(profile_name) - if external_user_id: - external_user_id_vo = ExternalUserId(external_user_id) - else: - external_user_id_vo = self._external_user_id_service.generate( - str(account_id), app_slug_vo, profile_name_vo + if enabled_tools is None: + enabled_tools = [] + + if not external_user_id: + external_user_id = self._generate_external_user_id(account_id, app_slug, profile_name) + + config = self._build_config(app_slug, app_name, external_user_id, enabled_tools, oauth_app_id, description) + config_json = json.dumps(config, sort_keys=True) + encrypted_config = self._encrypt_config(config_json) + config_hash = self._generate_config_hash(config_json) + + mcp_qualified_name = f"pipedream:{app_slug}" + profile_id = str(uuid4()) + now = datetime.now(timezone.utc) + + client = await self._get_client() + + try: + existing = await client.table('user_mcp_credential_profiles').select('profile_id').eq( + 'account_id', account_id + ).eq('mcp_qualified_name', mcp_qualified_name).eq('profile_name', profile_name).execute() + + if existing.data: + raise ProfileAlreadyExistsError(f"Profile '{profile_name}' already exists for app '{app_slug}'") + + if is_default: + await client.table('user_mcp_credential_profiles').update({ + 'is_default': False + }).eq('account_id', account_id).eq('mcp_qualified_name', mcp_qualified_name).execute() + + result = await client.table('user_mcp_credential_profiles').insert({ + 'profile_id': profile_id, + 'account_id': account_id, + 'mcp_qualified_name': mcp_qualified_name, + 'profile_name': profile_name, + 'display_name': profile_name, + 'encrypted_config': encrypted_config, + 'config_hash': config_hash, + 'is_active': True, + 'is_default': is_default, + 'created_at': now.isoformat(), + 'updated_at': now.isoformat() + }).execute() + + if not result.data: + raise ProfileServiceError("Failed to create profile") + + logger.info(f"Created profile {profile_id} for app {app_slug}") + + return Profile( + profile_id=profile_id, + account_id=account_id, + mcp_qualified_name=mcp_qualified_name, + profile_name=profile_name, + display_name=profile_name, + encrypted_config=encrypted_config, + config_hash=config_hash, + app_slug=app_slug, + app_name=app_name, + external_user_id=external_user_id, + enabled_tools=enabled_tools, + is_active=True, + is_default=is_default, + is_connected=False, + created_at=now, + updated_at=now, + description=description, + oauth_app_id=oauth_app_id ) - - mcp_qualified_name = self._mcp_qualified_name_service.generate(app_slug_vo) - - config = { - "app_slug": app_slug, - "app_name": app_name, - "external_user_id": external_user_id_vo.value, - "oauth_app_id": oauth_app_id, - "enabled_tools": enabled_tools or [], - "description": description - } - - if not self._profile_config_service.validate_config(config): - raise ValueError("Invalid profile configuration") - - config_json = json.dumps(config) + + except Exception as e: + if isinstance(e, (ProfileAlreadyExistsError, ProfileServiceError)): + raise + logger.error(f"Error creating profile: {str(e)}") + raise ProfileServiceError(f"Failed to create profile: {str(e)}") + + async def get_profile(self, account_id: str, profile_id: str) -> Optional[Profile]: + client = await self._get_client() - profile = Profile( - profile_id=UUID(int=0), - account_id=account_id, - mcp_qualified_name=mcp_qualified_name, - profile_name=profile_name_vo, - display_name=profile_name, - encrypted_config=EncryptedConfig("placeholder"), - config_hash=ConfigHash.from_config(config_json), - app_slug=app_slug_vo, - app_name=app_name, - external_user_id=external_user_id_vo, - enabled_tools=enabled_tools or [], - is_default=is_default - ) - - if is_default: - await self._profile_repo.set_default(account_id, profile.profile_id, mcp_qualified_name) - - created_profile = await self._profile_repo.create(profile) - self._logger.info(f"Created profile {created_profile.profile_id} for app {app_slug}") - - return created_profile - - async def get_profile(self, account_id: UUID, profile_id: UUID) -> Optional[Profile]: - profile = await self._profile_repo.get_by_id(account_id, profile_id) - if profile: - return await self._connection_status_service.update_connection_status(profile) - return None - + try: + result = await client.table('user_mcp_credential_profiles').select('*').eq( + 'account_id', account_id + ).eq('profile_id', profile_id).execute() + + if not result.data: + return None + + return await self._map_row_to_profile(result.data[0]) + + except Exception as e: + logger.error(f"Error getting profile {profile_id}: {str(e)}") + raise ProfileServiceError(f"Failed to get profile: {str(e)}") + async def get_profiles( self, - account_id: UUID, + account_id: str, app_slug: Optional[str] = None, is_active: Optional[bool] = None ) -> List[Profile]: - app_slug_vo = AppSlug(app_slug) if app_slug else None - profiles = await self._profile_repo.find_by_account(account_id, app_slug_vo, is_active) + client = await self._get_client() - updated_profiles = [] - for profile in profiles: - updated_profile = await self._connection_status_service.update_connection_status(profile) - updated_profiles.append(updated_profile) - - return updated_profiles - + try: + query = client.table('user_mcp_credential_profiles').select('*').eq('account_id', account_id) + + if app_slug: + mcp_qualified_name = f"pipedream:{app_slug}" + query = query.eq('mcp_qualified_name', mcp_qualified_name) + + if is_active is not None: + query = query.eq('is_active', is_active) + + result = await query.execute() + + import asyncio + profiles = await asyncio.gather(*[self._map_row_to_profile(row) for row in result.data]) + return profiles + + except Exception as e: + logger.error(f"Error getting profiles: {str(e)}") + raise ProfileServiceError(f"Failed to get profiles: {str(e)}") + async def update_profile( self, - account_id: UUID, - profile_id: UUID, + account_id: str, + profile_id: str, profile_name: Optional[str] = None, display_name: Optional[str] = None, is_active: Optional[bool] = None, is_default: Optional[bool] = None, enabled_tools: Optional[List[str]] = None ) -> Profile: - profile = await self._profile_repo.get_by_id(account_id, profile_id) - if not profile: - raise ProfileNotFoundError(str(profile_id)) - - if profile_name: - profile.profile_name = ProfileName(profile_name) - if display_name: - profile.display_name = display_name - if is_active is not None: - if is_active: - profile.activate() - else: - profile.deactivate() - if is_default is not None: - if is_default: - await self._profile_repo.set_default(account_id, profile_id, profile.mcp_qualified_name) - profile.set_as_default() - else: - profile.unset_as_default() - if enabled_tools is not None: - profile.enabled_tools = enabled_tools - - updated_profile = await self._profile_repo.update(profile) - self._logger.info(f"Updated profile {profile_id}") + existing_profile = await self.get_profile(account_id, profile_id) + if not existing_profile: + raise ProfileNotFoundError(f"Profile {profile_id} not found") - return updated_profile - - async def delete_profile(self, account_id: UUID, profile_id: UUID) -> bool: - success = await self._profile_repo.delete(account_id, profile_id) - if success: - self._logger.info(f"Deleted profile {profile_id}") - return success - + client = await self._get_client() + updates = {'updated_at': datetime.now(timezone.utc).isoformat()} + + try: + if enabled_tools is not None: + config = self._decrypt_config(existing_profile.encrypted_config) + config['enabled_tools'] = enabled_tools + config_json = json.dumps(config, sort_keys=True) + updates['encrypted_config'] = self._encrypt_config(config_json) + updates['config_hash'] = self._generate_config_hash(config_json) + + if profile_name is not None: + self._validate_profile_name(profile_name) + updates['profile_name'] = profile_name + + if display_name is not None: + updates['display_name'] = display_name + + if is_active is not None: + updates['is_active'] = is_active + + if is_default is not None: + if is_default: + await client.table('user_mcp_credential_profiles').update({ + 'is_default': False + }).eq('account_id', account_id).eq('mcp_qualified_name', existing_profile.mcp_qualified_name).execute() + + updates['is_default'] = is_default + + result = await client.table('user_mcp_credential_profiles').update(updates).eq( + 'profile_id', profile_id + ).eq('account_id', account_id).execute() + + if not result.data: + raise ProfileServiceError("Failed to update profile") + + logger.info(f"Updated profile {profile_id}") + + return await self.get_profile(account_id, profile_id) + + except Exception as e: + if isinstance(e, (ProfileNotFoundError, ProfileServiceError, InvalidConfigError)): + raise + logger.error(f"Error updating profile {profile_id}: {str(e)}") + raise ProfileServiceError(f"Failed to update profile: {str(e)}") + + async def delete_profile(self, account_id: str, profile_id: str) -> bool: + client = await self._get_client() + + try: + result = await client.table('user_mcp_credential_profiles').delete().eq( + 'profile_id', profile_id + ).eq('account_id', account_id).execute() + + success = bool(result.data) + if success: + logger.info(f"Deleted profile {profile_id}") + + return success + + except Exception as e: + logger.error(f"Error deleting profile {profile_id}: {str(e)}") + raise ProfileServiceError(f"Failed to delete profile: {str(e)}") + async def get_profile_by_app( self, - account_id: UUID, + account_id: str, app_slug: str, profile_name: Optional[str] = None ) -> Optional[Profile]: - app_slug_vo = AppSlug(app_slug) - profile_name_vo = ProfileName(profile_name) if profile_name else None + self._validate_app_slug(app_slug) - profile = await self._profile_repo.get_by_app_slug(account_id, app_slug_vo, profile_name_vo) - if profile: - return await self._connection_status_service.update_connection_status(profile) - return None \ No newline at end of file + client = await self._get_client() + + try: + mcp_qualified_name = f"pipedream:{app_slug}" + query = client.table('user_mcp_credential_profiles').select('*').eq( + 'account_id', account_id + ).eq('mcp_qualified_name', mcp_qualified_name) + + if profile_name: + self._validate_profile_name(profile_name) + query = query.eq('profile_name', profile_name) + else: + query = query.eq('is_default', True) + + result = await query.execute() + + if not result.data: + return None + + return await self._map_row_to_profile(result.data[0]) + + except Exception as e: + logger.error(f"Error getting profile by app {app_slug}: {str(e)}") + raise ProfileServiceError(f"Failed to get profile by app: {str(e)}") + + +_profile_service = None + +def get_profile_service() -> ProfileService: + global _profile_service + if _profile_service is None: + _profile_service = ProfileService() + return _profile_service \ No newline at end of file