mirror of https://github.com/kortix-ai/suna.git
refactor pipedream
This commit is contained in:
parent
f095affcf1
commit
773987bebc
|
@ -2206,15 +2206,15 @@ async def get_pipedream_tools_for_agent(
|
|||
|
||||
try:
|
||||
from pipedream.mcp_service import ExternalUserId, AppSlug
|
||||
external_user_id = ExternalUserId(profile.external_user_id.value)
|
||||
app_slug_obj = AppSlug(profile.app_slug.value)
|
||||
external_user_id = ExternalUserId(profile.external_user_id)
|
||||
app_slug_obj = AppSlug(profile.app_slug)
|
||||
|
||||
logger.info(f"Discovering servers for user {external_user_id.value} and app {app_slug_obj.value}")
|
||||
servers = await mcp_service.discover_servers_for_user(external_user_id, app_slug_obj)
|
||||
logger.info(f"Found {len(servers)} servers: {[s.app_slug.value for s in servers]}")
|
||||
logger.info(f"Found {len(servers)} servers: {[s.app_slug for s in servers]}")
|
||||
|
||||
server = servers[0] if servers else None
|
||||
logger.info(f"Selected server: {server.app_slug.value if server else 'None'} with {len(server.available_tools) if server else 0} tools")
|
||||
logger.info(f"Selected server: {server.app_slug if server else 'None'} with {len(server.available_tools) if server else 0} tools")
|
||||
|
||||
if not server:
|
||||
return {
|
||||
|
|
|
@ -9,11 +9,11 @@ from .connection_token_service import ConnectionTokenService
|
|||
|
||||
db = DBConnection()
|
||||
|
||||
profile_service = ProfileService(db=db, logger=logger)
|
||||
connection_service = ConnectionService(logger=logger)
|
||||
profile_service = ProfileService()
|
||||
connection_service = ConnectionService()
|
||||
app_service = AppService(logger=logger)
|
||||
mcp_service = MCPService(logger=logger)
|
||||
connection_token_service = ConnectionTokenService(logger=logger)
|
||||
mcp_service = MCPService()
|
||||
connection_token_service = ConnectionTokenService()
|
||||
|
||||
from . import api
|
||||
api.profile_service = profile_service
|
||||
|
|
|
@ -6,10 +6,10 @@ from datetime import datetime
|
|||
|
||||
from utils.logger import logger
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from .profile_service import ProfileService, Profile
|
||||
from .profile_service import ProfileService, Profile, ProfileServiceError, ProfileNotFoundError, ProfileAlreadyExistsError, InvalidConfigError, EncryptionError
|
||||
from .connection_service import ConnectionService
|
||||
from .app_service import AppService
|
||||
from .mcp_service import MCPService, ConnectionStatus
|
||||
from .mcp_service import MCPService, ConnectionStatus, MCPConnectionError, MCPServiceError
|
||||
from .connection_token_service import ConnectionTokenService
|
||||
|
||||
import httpx
|
||||
|
@ -119,11 +119,11 @@ class ProfileResponse(BaseModel):
|
|||
profile_id=profile.profile_id,
|
||||
account_id=profile.account_id,
|
||||
mcp_qualified_name=profile.mcp_qualified_name,
|
||||
profile_name=profile.profile_name.value,
|
||||
profile_name=profile.profile_name,
|
||||
display_name=profile.display_name,
|
||||
app_slug=profile.app_slug.value,
|
||||
app_slug=profile.app_slug,
|
||||
app_name=profile.app_name,
|
||||
external_user_id=profile.external_user_id.value,
|
||||
external_user_id=profile.external_user_id,
|
||||
enabled_tools=profile.enabled_tools,
|
||||
is_active=profile.is_active,
|
||||
is_default=profile.is_default,
|
||||
|
@ -145,15 +145,15 @@ def _handle_pipedream_exception(e: Exception) -> HTTPException:
|
|||
return HTTPException(status_code=404, detail=str(e))
|
||||
elif isinstance(e, ProfileAlreadyExistsError):
|
||||
return HTTPException(status_code=409, detail=str(e))
|
||||
elif isinstance(e, ValidationException):
|
||||
elif isinstance(e, InvalidConfigError):
|
||||
return HTTPException(status_code=400, detail=str(e))
|
||||
elif isinstance(e, ConnectionNotFoundError):
|
||||
return HTTPException(status_code=404, detail=str(e))
|
||||
elif isinstance(e, AppNotFoundError):
|
||||
return HTTPException(status_code=404, detail=str(e))
|
||||
elif isinstance(e, EncryptionError):
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
elif isinstance(e, MCPConnectionError):
|
||||
return HTTPException(status_code=502, detail=str(e))
|
||||
elif isinstance(e, PipedreamException):
|
||||
elif isinstance(e, MCPServiceError):
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
elif isinstance(e, ProfileServiceError):
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
else:
|
||||
return HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
@ -203,7 +203,7 @@ async def get_user_connections(
|
|||
for connection in connections:
|
||||
connection_data.append({
|
||||
"name": connection.app.name,
|
||||
"name_slug": connection.app.slug.value,
|
||||
"name_slug": connection.app.slug,
|
||||
"description": connection.app.description,
|
||||
"category": connection.app.category,
|
||||
"img_src": connection.app.logo_url,
|
||||
|
@ -251,12 +251,12 @@ async def discover_mcp_servers(
|
|||
})
|
||||
|
||||
server_data.append({
|
||||
"app_slug": server.app_slug.value,
|
||||
"app_slug": server.app_slug,
|
||||
"app_name": server.app_name,
|
||||
"server_url": server.server_url.value,
|
||||
"server_url": server.server_url,
|
||||
"project_id": server.project_id,
|
||||
"environment": server.environment,
|
||||
"external_user_id": server.external_user_id.value,
|
||||
"external_user_id": server.external_user_id,
|
||||
"oauth_app_id": server.oauth_app_id,
|
||||
"status": server.status.value,
|
||||
"available_tools": tools_data,
|
||||
|
@ -300,12 +300,12 @@ async def discover_mcp_servers_for_profile(
|
|||
})
|
||||
|
||||
server_data.append({
|
||||
"app_slug": server.app_slug.value,
|
||||
"app_slug": server.app_slug,
|
||||
"app_name": server.app_name,
|
||||
"server_url": server.server_url.value,
|
||||
"server_url": server.server_url,
|
||||
"project_id": server.project_id,
|
||||
"environment": server.environment,
|
||||
"external_user_id": server.external_user_id.value,
|
||||
"external_user_id": server.external_user_id,
|
||||
"oauth_app_id": server.oauth_app_id,
|
||||
"status": server.status.value,
|
||||
"available_tools": tools_data,
|
||||
|
@ -351,12 +351,12 @@ async def create_mcp_connection(
|
|||
})
|
||||
|
||||
mcp_config = {
|
||||
"app_slug": server.app_slug.value,
|
||||
"app_slug": server.app_slug,
|
||||
"app_name": server.app_name,
|
||||
"server_url": server.server_url.value,
|
||||
"server_url": server.server_url,
|
||||
"project_id": server.project_id,
|
||||
"environment": server.environment,
|
||||
"external_user_id": server.external_user_id.value,
|
||||
"external_user_id": server.external_user_id,
|
||||
"oauth_app_id": server.oauth_app_id,
|
||||
"status": server.status.value,
|
||||
"available_tools": tools_data
|
||||
|
@ -540,12 +540,11 @@ async def create_credential_profile(
|
|||
logger.info(f"Creating credential profile for user: {user_id}, app: {request.app_slug}")
|
||||
|
||||
try:
|
||||
from uuid import UUID
|
||||
profile = await profile_service.create_profile(
|
||||
account_id=UUID(user_id),
|
||||
profile_name=request.profile_name,
|
||||
app_slug=request.app_slug,
|
||||
app_name=request.app_name,
|
||||
user_id,
|
||||
request.profile_name,
|
||||
request.app_slug,
|
||||
request.app_name,
|
||||
description=request.description,
|
||||
is_default=request.is_default,
|
||||
oauth_app_id=request.oauth_app_id,
|
||||
|
@ -571,9 +570,7 @@ async def get_credential_profiles(
|
|||
actual_app_slug = _strip_pipedream_prefix(app_slug)
|
||||
|
||||
try:
|
||||
from uuid import UUID
|
||||
profiles = await profile_service.get_profiles(UUID(user_id), actual_app_slug, is_active)
|
||||
|
||||
profiles = await profile_service.get_profiles(user_id, actual_app_slug, is_active)
|
||||
return [ProfileResponse.from_domain(profile) for profile in profiles]
|
||||
|
||||
except Exception as e:
|
||||
|
@ -589,8 +586,7 @@ async def get_credential_profile(
|
|||
logger.info(f"Getting credential profile: {profile_id} for user: {user_id}")
|
||||
|
||||
try:
|
||||
from uuid import UUID
|
||||
profile = await profile_service.get_profile(UUID(user_id), UUID(profile_id))
|
||||
profile = await profile_service.get_profile(user_id, profile_id)
|
||||
|
||||
if not profile:
|
||||
from .profile_service import ProfileNotFoundError
|
||||
|
@ -612,10 +608,9 @@ async def update_credential_profile(
|
|||
logger.info(f"Updating credential profile: {profile_id} for user: {user_id}")
|
||||
|
||||
try:
|
||||
from uuid import UUID
|
||||
profile = await profile_service.update_profile(
|
||||
account_id=UUID(user_id),
|
||||
profile_id=UUID(profile_id),
|
||||
user_id,
|
||||
profile_id,
|
||||
profile_name=request.profile_name,
|
||||
display_name=request.display_name,
|
||||
is_active=request.is_active,
|
||||
|
@ -638,8 +633,7 @@ async def delete_credential_profile(
|
|||
logger.info(f"Deleting credential profile: {profile_id} for user: {user_id}")
|
||||
|
||||
try:
|
||||
from uuid import UUID
|
||||
success = await profile_service.delete_profile(UUID(user_id), UUID(profile_id))
|
||||
success = await profile_service.delete_profile(user_id, profile_id)
|
||||
|
||||
if not success:
|
||||
raise ProfileNotFoundError(profile_id)
|
||||
|
@ -666,12 +660,12 @@ async def connect_credential_profile(
|
|||
from .profile_service import ProfileNotFoundError
|
||||
from .connection_token_service import ExternalUserId, AppSlug
|
||||
|
||||
profile = await profile_service.get_profile(UUID(user_id), UUID(profile_id))
|
||||
profile = await profile_service.get_profile(user_id, profile_id)
|
||||
if not profile:
|
||||
raise ProfileNotFoundError(profile_id)
|
||||
|
||||
external_user_id = ExternalUserId(profile.external_user_id.value)
|
||||
app_slug = AppSlug(actual_app or profile.app_slug.value)
|
||||
external_user_id = ExternalUserId(profile.external_user_id)
|
||||
app_slug = AppSlug(actual_app or profile.app_slug)
|
||||
result = await connection_token_service.create(external_user_id, app_slug)
|
||||
|
||||
return {
|
||||
|
@ -680,8 +674,8 @@ async def connect_credential_profile(
|
|||
"token": result.get("token"),
|
||||
"expires_at": result.get("expires_at"),
|
||||
"profile_id": profile_id,
|
||||
"external_user_id": profile.external_user_id.value,
|
||||
"app": actual_app or profile.app_slug.value
|
||||
"external_user_id": profile.external_user_id,
|
||||
"app": actual_app or profile.app_slug
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
@ -701,18 +695,18 @@ async def get_profile_connections(
|
|||
from .profile_service import ProfileNotFoundError
|
||||
from .connection_service import ExternalUserId
|
||||
|
||||
profile = await profile_service.get_profile(UUID(user_id), UUID(profile_id))
|
||||
profile = await profile_service.get_profile(user_id, profile_id)
|
||||
if not profile:
|
||||
raise ProfileNotFoundError(profile_id)
|
||||
|
||||
external_user_id = ExternalUserId(profile.external_user_id.value)
|
||||
external_user_id = ExternalUserId(profile.external_user_id)
|
||||
connections = await connection_service.get_connections_for_user(external_user_id)
|
||||
|
||||
connection_data = []
|
||||
for connection in connections:
|
||||
connection_data.append({
|
||||
"name": connection.app.name,
|
||||
"name_slug": connection.app.slug.value,
|
||||
"name_slug": connection.app.slug,
|
||||
"description": connection.app.description,
|
||||
"category": connection.app.category,
|
||||
"img_src": connection.app.logo_url,
|
||||
|
|
|
@ -1,31 +1,13 @@
|
|||
from typing import List, Optional, Protocol, Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import os
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
from uuid import UUID
|
||||
import json
|
||||
from utils.logger import logger
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExternalUserId:
|
||||
value: str
|
||||
def __post_init__(self):
|
||||
if not self.value or not isinstance(self.value, str):
|
||||
raise ValueError("ExternalUserId must be a non-empty string")
|
||||
if len(self.value) > 255:
|
||||
raise ValueError("ExternalUserId must be less than 255 characters")
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AppSlug:
|
||||
value: str
|
||||
def __post_init__(self):
|
||||
if not self.value or not isinstance(self.value, str):
|
||||
raise ValueError("AppSlug must be a non-empty string")
|
||||
if not re.match(r'^[a-z0-9_-]+$', self.value):
|
||||
raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores")
|
||||
|
||||
class AuthType(Enum):
|
||||
OAUTH = "oauth"
|
||||
|
@ -41,10 +23,11 @@ class AuthType(Enum):
|
|||
return cls.CUSTOM
|
||||
return super()._missing_(value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class App:
|
||||
name: str
|
||||
slug: AppSlug
|
||||
slug: str
|
||||
description: str
|
||||
category: str
|
||||
logo_url: Optional[str] = None
|
||||
|
@ -57,60 +40,52 @@ class App:
|
|||
def is_featured(self) -> bool:
|
||||
return self.featured_weight > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class Connection:
|
||||
external_user_id: ExternalUserId
|
||||
external_user_id: str
|
||||
app: App
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
is_active: bool = True
|
||||
|
||||
def activate(self) -> None:
|
||||
self.is_active = True
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def deactivate(self) -> None:
|
||||
self.is_active = False
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
# Exceptions
|
||||
class PipedreamException(Exception):
|
||||
def __init__(self, message: str, error_code: str = None):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
self.message = message
|
||||
|
||||
class HttpClientException(PipedreamException):
|
||||
def __init__(self, url: str, status_code: int, reason: str):
|
||||
super().__init__(f"HTTP request to {url} failed with status {status_code}: {reason}", "HTTP_CLIENT_ERROR")
|
||||
self.url = url
|
||||
self.status_code = status_code
|
||||
self.reason = reason
|
||||
class ConnectionServiceError(Exception):
|
||||
pass
|
||||
|
||||
class AuthenticationException(PipedreamException):
|
||||
def __init__(self, reason: str):
|
||||
super().__init__(f"Authentication failed: {reason}", "AUTHENTICATION_ERROR")
|
||||
self.reason = reason
|
||||
class AuthenticationError(ConnectionServiceError):
|
||||
pass
|
||||
|
||||
class RateLimitException(PipedreamException):
|
||||
def __init__(self, retry_after: int = None):
|
||||
super().__init__("Rate limit exceeded", "RATE_LIMIT_EXCEEDED")
|
||||
self.retry_after = retry_after
|
||||
class RateLimitError(ConnectionServiceError):
|
||||
pass
|
||||
|
||||
class Logger(Protocol):
|
||||
def info(self, message: str) -> None: ...
|
||||
def warning(self, message: str) -> None: ...
|
||||
def error(self, message: str) -> None: ...
|
||||
def debug(self, message: str) -> None: ...
|
||||
|
||||
class HttpClient:
|
||||
def __init__(self):
|
||||
class ExternalUserId:
|
||||
def __init__(self, value: str):
|
||||
if not value or not isinstance(value, str):
|
||||
raise ValueError("ExternalUserId must be a non-empty string")
|
||||
if len(value) > 255:
|
||||
raise ValueError("ExternalUserId must be less than 255 characters")
|
||||
self.value = value
|
||||
|
||||
|
||||
class AppSlug:
|
||||
def __init__(self, value: str):
|
||||
if not value or not isinstance(value, str):
|
||||
raise ValueError("AppSlug must be a non-empty string")
|
||||
if not re.match(r'^[a-z0-9_-]+$', value):
|
||||
raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores")
|
||||
self.value = value
|
||||
|
||||
|
||||
class ConnectionService:
|
||||
def __init__(self, logger=None):
|
||||
self._logger = logger or logger
|
||||
self.base_url = "https://api.pipedream.com/v1"
|
||||
self.session: Optional[httpx.AsyncClient] = None
|
||||
self.access_token: Optional[str] = None
|
||||
self.token_expires_at: Optional[datetime] = None
|
||||
self.rate_limit_token: Optional[str] = None
|
||||
|
||||
self.session = None
|
||||
self.access_token = None
|
||||
self.token_expires_at = None
|
||||
|
||||
async def _get_session(self) -> httpx.AsyncClient:
|
||||
if self.session is None or self.session.is_closed:
|
||||
self.session = httpx.AsyncClient(
|
||||
|
@ -118,7 +93,7 @@ class HttpClient:
|
|||
headers={"User-Agent": "Suna-Pipedream-Client/1.0"}
|
||||
)
|
||||
return self.session
|
||||
|
||||
|
||||
async def _ensure_access_token(self) -> str:
|
||||
if self.access_token and self.token_expires_at:
|
||||
if datetime.utcnow() < (self.token_expires_at - timedelta(minutes=5)):
|
||||
|
@ -126,19 +101,19 @@ class HttpClient:
|
|||
else:
|
||||
self.access_token = None
|
||||
self.token_expires_at = None
|
||||
|
||||
|
||||
return await self._fetch_fresh_token()
|
||||
|
||||
|
||||
async def _fetch_fresh_token(self) -> str:
|
||||
project_id = os.getenv("PIPEDREAM_PROJECT_ID")
|
||||
client_id = os.getenv("PIPEDREAM_CLIENT_ID")
|
||||
client_secret = os.getenv("PIPEDREAM_CLIENT_SECRET")
|
||||
|
||||
|
||||
if not all([project_id, client_id, client_secret]):
|
||||
raise AuthenticationException("Missing required environment variables")
|
||||
|
||||
raise AuthenticationError("Missing required environment variables")
|
||||
|
||||
session = await self._get_session()
|
||||
|
||||
|
||||
try:
|
||||
response = await session.post(
|
||||
f"{self.base_url}/oauth/token",
|
||||
|
@ -149,37 +124,33 @@ class HttpClient:
|
|||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
data = response.json()
|
||||
self.access_token = data["access_token"]
|
||||
|
||||
expires_in = data.get("expires_in", 3600)
|
||||
self.token_expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
|
||||
return self.access_token
|
||||
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 429:
|
||||
raise RateLimitException()
|
||||
raise AuthenticationException(f"Failed to obtain access token: {e}")
|
||||
|
||||
async def get(self, url: str, headers: Dict[str, str] = None, params: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
return await self._make_request("GET", url, headers=headers, params=params)
|
||||
|
||||
raise RateLimitError("Rate limit exceeded")
|
||||
raise AuthenticationError(f"Failed to obtain access token: {e}")
|
||||
|
||||
async def _make_request(self, method: str, url: str, headers: Dict[str, str] = None,
|
||||
params: Dict[str, Any] = None, json: Dict[str, Any] = None,
|
||||
retry_count: int = 0) -> Dict[str, Any]:
|
||||
session = await self._get_session()
|
||||
access_token = await self._ensure_access_token()
|
||||
|
||||
|
||||
request_headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
|
||||
try:
|
||||
if method == "GET":
|
||||
response = await session.get(url, headers=request_headers, params=params)
|
||||
|
@ -187,53 +158,41 @@ class HttpClient:
|
|||
response = await session.post(url, headers=request_headers, json=json)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 429:
|
||||
raise RateLimitException()
|
||||
raise RateLimitError("Rate limit exceeded")
|
||||
elif e.response.status_code == 401 and retry_count < 1:
|
||||
await self._invalidate_token()
|
||||
self.access_token = None
|
||||
self.token_expires_at = None
|
||||
return await self._make_request(method, url, headers=headers, params=params,
|
||||
json=json, retry_count=retry_count + 1)
|
||||
else:
|
||||
raise HttpClientException(url, e.response.status_code, str(e))
|
||||
|
||||
async def _invalidate_token(self):
|
||||
self.access_token = None
|
||||
self.token_expires_at = None
|
||||
|
||||
async def post(self, url: str, headers: Dict[str, str] = None, json: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
return await self._make_request("POST", url, headers=headers, json=json)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.session and not self.session.is_closed:
|
||||
await self.session.aclose()
|
||||
raise ConnectionServiceError(f"HTTP request failed: {e}")
|
||||
|
||||
class ConnectionRepository:
|
||||
def __init__(self, http_client: HttpClient, logger: Logger):
|
||||
self._http_client = http_client
|
||||
self._logger = logger
|
||||
async def get_connections_for_user(self, external_user_id: ExternalUserId) -> List[Connection]:
|
||||
logger.info(f"Getting connections for user: {external_user_id.value}")
|
||||
|
||||
async def get_by_external_user_id(self, external_user_id: ExternalUserId) -> List[Connection]:
|
||||
project_id = os.getenv("PIPEDREAM_PROJECT_ID")
|
||||
environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development")
|
||||
|
||||
|
||||
if not project_id:
|
||||
raise HttpClientException("Missing PIPEDREAM_PROJECT_ID", 500, "Configuration error")
|
||||
|
||||
url = f"{self._http_client.base_url}/connect/{project_id}/accounts"
|
||||
logger.error("Missing PIPEDREAM_PROJECT_ID environment variable")
|
||||
return []
|
||||
|
||||
url = f"{self.base_url}/connect/{project_id}/accounts"
|
||||
params = {"external_id": external_user_id.value}
|
||||
headers = {"X-PD-Environment": environment}
|
||||
|
||||
|
||||
try:
|
||||
data = await self._http_client.get(url, headers=headers, params=params)
|
||||
|
||||
data = await self._make_request("GET", url, headers=headers, params=params)
|
||||
|
||||
connections = []
|
||||
accounts = data.get("data", [])
|
||||
|
||||
|
||||
for account in accounts:
|
||||
app_data = account.get("app", {})
|
||||
if app_data:
|
||||
|
@ -241,12 +200,12 @@ class ConnectionRepository:
|
|||
auth_type_str = app_data.get("auth_type", "oauth")
|
||||
auth_type = AuthType(auth_type_str)
|
||||
except ValueError:
|
||||
self._logger.warning(f"Unknown auth type '{auth_type_str}', using CUSTOM")
|
||||
logger.warning(f"Unknown auth type '{auth_type_str}', using CUSTOM")
|
||||
auth_type = AuthType.CUSTOM
|
||||
|
||||
|
||||
app = App(
|
||||
name=app_data.get("name", "Unknown"),
|
||||
slug=AppSlug(app_data.get("name_slug", "")),
|
||||
slug=app_data.get("name_slug", ""),
|
||||
description=app_data.get("description", ""),
|
||||
category=app_data.get("category", "Other"),
|
||||
logo_url=app_data.get("img_src"),
|
||||
|
@ -256,45 +215,47 @@ class ConnectionRepository:
|
|||
tags=app_data.get("tags", []),
|
||||
featured_weight=app_data.get("featured_weight", 0)
|
||||
)
|
||||
|
||||
|
||||
connection = Connection(
|
||||
external_user_id=external_user_id,
|
||||
external_user_id=external_user_id.value,
|
||||
app=app,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow(),
|
||||
is_active=True
|
||||
)
|
||||
connections.append(connection)
|
||||
|
||||
self._logger.info(f"Retrieved {len(connections)} connections for user: {external_user_id.value}")
|
||||
|
||||
logger.info(f"Retrieved {len(connections)} connections for user: {external_user_id.value}")
|
||||
return connections
|
||||
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error getting connections: {str(e)}")
|
||||
logger.error(f"Error getting connections: {str(e)}")
|
||||
return []
|
||||
|
||||
class ConnectionService:
|
||||
def __init__(self, logger: Optional[Logger] = None):
|
||||
self._logger = logger or logging.getLogger(__name__)
|
||||
self._http_client = HttpClient()
|
||||
self._connection_repo = ConnectionRepository(self._http_client, self._logger)
|
||||
|
||||
async def get_connections_for_user(self, external_user_id: ExternalUserId) -> List[Connection]:
|
||||
self._logger.info(f"Getting connections for user: {external_user_id.value}")
|
||||
|
||||
connections = await self._connection_repo.get_by_external_user_id(external_user_id)
|
||||
|
||||
self._logger.info(f"Found {len(connections)} connections for user: {external_user_id.value}")
|
||||
return connections
|
||||
|
||||
async def has_connection(self, external_user_id: ExternalUserId, app_slug: AppSlug) -> bool:
|
||||
connections = await self._connection_repo.get_by_external_user_id(external_user_id)
|
||||
|
||||
connections = await self.get_connections_for_user(external_user_id)
|
||||
|
||||
for connection in connections:
|
||||
if connection.app.slug == app_slug and connection.is_active:
|
||||
if connection.app.slug == app_slug.value and connection.is_active:
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def close(self):
|
||||
await self._http_client.close()
|
||||
if self.session and not self.session.is_closed:
|
||||
await self.session.aclose()
|
||||
|
||||
|
||||
_connection_service = None
|
||||
|
||||
def get_connection_service() -> ConnectionService:
|
||||
global _connection_service
|
||||
if _connection_service is None:
|
||||
_connection_service = ConnectionService()
|
||||
return _connection_service
|
||||
|
||||
|
||||
PipedreamException = ConnectionServiceError
|
||||
HttpClientException = ConnectionServiceError
|
||||
AuthenticationException = AuthenticationError
|
||||
RateLimitException = RateLimitError
|
|
@ -1,65 +1,48 @@
|
|||
import os
|
||||
import logging
|
||||
import re
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional, Protocol
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import httpx
|
||||
from utils.logger import logger
|
||||
|
||||
|
||||
class ConnectionTokenServiceError(Exception):
|
||||
pass
|
||||
|
||||
class AuthenticationError(ConnectionTokenServiceError):
|
||||
pass
|
||||
|
||||
class RateLimitError(ConnectionTokenServiceError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExternalUserId:
|
||||
value: str
|
||||
def __post_init__(self):
|
||||
if not self.value or not isinstance(self.value, str):
|
||||
def __init__(self, value: str):
|
||||
if not value or not isinstance(value, str):
|
||||
raise ValueError("ExternalUserId must be a non-empty string")
|
||||
if len(self.value) > 255:
|
||||
if len(value) > 255:
|
||||
raise ValueError("ExternalUserId must be less than 255 characters")
|
||||
self.value = value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AppSlug:
|
||||
value: str
|
||||
def __post_init__(self):
|
||||
if not self.value or not isinstance(self.value, str):
|
||||
def __init__(self, value: str):
|
||||
if not value or not isinstance(value, str):
|
||||
raise ValueError("AppSlug must be a non-empty string")
|
||||
if not re.match(r'^[a-z0-9_-]+$', self.value):
|
||||
if not re.match(r'^[a-z0-9_-]+$', value):
|
||||
raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores")
|
||||
self.value = value
|
||||
|
||||
class PipedreamException(Exception):
|
||||
def __init__(self, message: str, error_code: str = None):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
self.message = message
|
||||
|
||||
class AuthenticationException(PipedreamException):
|
||||
def __init__(self, reason: str):
|
||||
super().__init__(f"Authentication failed: {reason}", "AUTHENTICATION_ERROR")
|
||||
self.reason = reason
|
||||
|
||||
class HttpClientException(PipedreamException):
|
||||
def __init__(self, url: str, status_code: int, reason: str):
|
||||
super().__init__(f"HTTP request to {url} failed with status {status_code}: {reason}", "HTTP_CLIENT_ERROR")
|
||||
self.url = url
|
||||
self.status_code = status_code
|
||||
self.reason = reason
|
||||
|
||||
class RateLimitException(PipedreamException):
|
||||
def __init__(self, retry_after: int = None):
|
||||
super().__init__("Rate limit exceeded", "RATE_LIMIT_EXCEEDED")
|
||||
self.retry_after = retry_after
|
||||
|
||||
class Logger(Protocol):
|
||||
def info(self, message: str) -> None: ...
|
||||
def warning(self, message: str) -> None: ...
|
||||
def error(self, message: str) -> None: ...
|
||||
def debug(self, message: str) -> None: ...
|
||||
|
||||
class HttpClient:
|
||||
def __init__(self):
|
||||
class ConnectionTokenService:
|
||||
def __init__(self, logger=None):
|
||||
self._logger = logger or logger
|
||||
self.base_url = "https://api.pipedream.com/v1"
|
||||
self.session: Optional[httpx.AsyncClient] = None
|
||||
self.access_token: Optional[str] = None
|
||||
self.token_expires_at: Optional[datetime] = None
|
||||
|
||||
self.session = None
|
||||
self.access_token = None
|
||||
self.token_expires_at = None
|
||||
|
||||
async def _get_session(self) -> httpx.AsyncClient:
|
||||
if self.session is None or self.session.is_closed:
|
||||
self.session = httpx.AsyncClient(
|
||||
|
@ -67,23 +50,24 @@ class HttpClient:
|
|||
headers={"User-Agent": "Suna-Pipedream-Client/1.0"}
|
||||
)
|
||||
return self.session
|
||||
|
||||
|
||||
async def _ensure_access_token(self) -> str:
|
||||
if self.access_token and self.token_expires_at:
|
||||
if datetime.utcnow() < (self.token_expires_at - timedelta(minutes=5)):
|
||||
return self.access_token
|
||||
|
||||
return await self._fetch_fresh_token()
|
||||
|
||||
|
||||
async def _fetch_fresh_token(self) -> str:
|
||||
project_id = os.getenv("PIPEDREAM_PROJECT_ID")
|
||||
client_id = os.getenv("PIPEDREAM_CLIENT_ID")
|
||||
client_secret = os.getenv("PIPEDREAM_CLIENT_SECRET")
|
||||
|
||||
|
||||
if not all([project_id, client_id, client_secret]):
|
||||
raise AuthenticationException("Missing required environment variables")
|
||||
|
||||
raise AuthenticationError("Missing required environment variables")
|
||||
|
||||
session = await self._get_session()
|
||||
|
||||
|
||||
try:
|
||||
response = await session.post(
|
||||
f"{self.base_url}/oauth/token",
|
||||
|
@ -94,86 +78,93 @@ class HttpClient:
|
|||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
data = response.json()
|
||||
self.access_token = data["access_token"]
|
||||
expires_in = data.get("expires_in", 3600)
|
||||
self.token_expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
|
||||
return self.access_token
|
||||
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 429:
|
||||
raise RateLimitException()
|
||||
raise AuthenticationException(f"Failed to obtain access token: {e}")
|
||||
|
||||
async def post(self, url: str, headers: Dict[str, str] = None, json: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
raise RateLimitError("Rate limit exceeded")
|
||||
raise AuthenticationError(f"Failed to obtain access token: {e}")
|
||||
|
||||
async def _make_request(self, url: str, headers: Dict[str, str] = None, json: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
session = await self._get_session()
|
||||
access_token = await self._ensure_access_token()
|
||||
|
||||
|
||||
request_headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
|
||||
try:
|
||||
response = await session.post(url, headers=request_headers, json=json)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 429:
|
||||
raise RateLimitException()
|
||||
raise HttpClientException(url, e.response.status_code, str(e))
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.session and not self.session.is_closed:
|
||||
await self.session.aclose()
|
||||
|
||||
class ConnectionTokenService:
|
||||
def __init__(self, logger: Optional[Logger] = None):
|
||||
self._logger = logger or logging.getLogger(__name__)
|
||||
self._http_client = HttpClient()
|
||||
raise RateLimitError("Rate limit exceeded")
|
||||
raise ConnectionTokenServiceError(f"HTTP request failed: {e}")
|
||||
|
||||
async def create(self, external_user_id: ExternalUserId, app: Optional[AppSlug] = None) -> Dict[str, Any]:
|
||||
project_id = os.getenv("PIPEDREAM_PROJECT_ID")
|
||||
environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development")
|
||||
|
||||
|
||||
if not project_id:
|
||||
raise AuthenticationException("Missing PIPEDREAM_PROJECT_ID")
|
||||
|
||||
url = f"{self._http_client.base_url}/connect/{project_id}/tokens"
|
||||
|
||||
raise AuthenticationError("Missing PIPEDREAM_PROJECT_ID")
|
||||
|
||||
url = f"{self.base_url}/connect/{project_id}/tokens"
|
||||
|
||||
payload = {
|
||||
"external_user_id": external_user_id.value
|
||||
}
|
||||
|
||||
|
||||
if app:
|
||||
payload["app"] = app.value
|
||||
|
||||
|
||||
headers = {
|
||||
"X-PD-Environment": environment
|
||||
}
|
||||
|
||||
self._logger.info(f"Creating connection token for user: {external_user_id.value}")
|
||||
|
||||
|
||||
logger.info(f"Creating connection token for user: {external_user_id.value}")
|
||||
|
||||
try:
|
||||
data = await self._http_client.post(url, headers=headers, json=payload)
|
||||
|
||||
data = await self._make_request(url, headers=headers, json=payload)
|
||||
|
||||
if app and "connect_link_url" in data:
|
||||
link = data["connect_link_url"]
|
||||
if "app=" not in link:
|
||||
separator = "&" if "?" in link else "?"
|
||||
data["connect_link_url"] = f"{link}{separator}app={app.value}"
|
||||
|
||||
self._logger.info(f"Successfully created connection token for user: {external_user_id.value}")
|
||||
|
||||
logger.info(f"Successfully created connection token for user: {external_user_id.value}")
|
||||
return data
|
||||
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error creating connection token: {str(e)}")
|
||||
logger.error(f"Error creating connection token: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def close(self):
|
||||
await self._http_client.close()
|
||||
if self.session and not self.session.is_closed:
|
||||
await self.session.aclose()
|
||||
|
||||
|
||||
_connection_token_service = None
|
||||
|
||||
def get_connection_token_service() -> ConnectionTokenService:
|
||||
global _connection_token_service
|
||||
if _connection_token_service is None:
|
||||
_connection_token_service = ConnectionTokenService()
|
||||
return _connection_token_service
|
||||
|
||||
|
||||
PipedreamException = ConnectionTokenServiceError
|
||||
AuthenticationException = AuthenticationError
|
||||
HttpClientException = ConnectionTokenServiceError
|
||||
RateLimitException = RateLimitError
|
|
@ -1,39 +1,14 @@
|
|||
from typing import List, Optional, Protocol, Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
import asyncio
|
||||
from utils.logger import logger
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExternalUserId:
|
||||
value: str
|
||||
def __post_init__(self):
|
||||
if not self.value or not isinstance(self.value, str):
|
||||
raise ValueError("ExternalUserId must be a non-empty string")
|
||||
if len(self.value) > 255:
|
||||
raise ValueError("ExternalUserId must be less than 255 characters")
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AppSlug:
|
||||
value: str
|
||||
def __post_init__(self):
|
||||
if not self.value or not isinstance(self.value, str):
|
||||
raise ValueError("AppSlug must be a non-empty string")
|
||||
if not re.match(r'^[a-z0-9_-]+$', self.value):
|
||||
raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores")
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MCPServerUrl:
|
||||
value: str
|
||||
def __post_init__(self):
|
||||
if not self.value or not isinstance(self.value, str):
|
||||
raise ValueError("MCPServerUrl must be a non-empty string")
|
||||
if not self.value.startswith(('http://', 'https://')):
|
||||
raise ValueError("MCPServerUrl must be a valid HTTP/HTTPS URL")
|
||||
|
||||
class ConnectionStatus(Enum):
|
||||
CONNECTED = "connected"
|
||||
|
@ -41,84 +16,76 @@ class ConnectionStatus(Enum):
|
|||
ERROR = "error"
|
||||
PENDING = "pending"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPTool:
|
||||
name: str
|
||||
description: str
|
||||
input_schema: Dict[str, Any]
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
return bool(self.name and self.description and self.input_schema)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPServer:
|
||||
app_slug: AppSlug
|
||||
app_slug: str
|
||||
app_name: str
|
||||
server_url: MCPServerUrl
|
||||
server_url: str
|
||||
project_id: str
|
||||
environment: str
|
||||
external_user_id: ExternalUserId
|
||||
external_user_id: str
|
||||
oauth_app_id: Optional[str] = None
|
||||
status: ConnectionStatus = ConnectionStatus.DISCONNECTED
|
||||
available_tools: List[MCPTool] = field(default_factory=list)
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
return self.status == ConnectionStatus.CONNECTED
|
||||
|
||||
def add_tool(self, tool: MCPTool) -> None:
|
||||
if tool.is_valid():
|
||||
self.available_tools.append(tool)
|
||||
|
||||
|
||||
def get_tool_count(self) -> int:
|
||||
return len(self.available_tools)
|
||||
|
||||
class PipedreamException(Exception):
|
||||
def __init__(self, message: str, error_code: str = None):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
self.message = message
|
||||
|
||||
class MCPServerNotAvailableError(PipedreamException):
|
||||
def __init__(self, message: str = "MCP server is not available"):
|
||||
super().__init__(message, "MCP_SERVER_NOT_AVAILABLE")
|
||||
class MCPServiceError(Exception):
|
||||
pass
|
||||
|
||||
class MCPConnectionError(PipedreamException):
|
||||
def __init__(self, app_slug: str, reason: str):
|
||||
super().__init__(f"MCP connection failed for {app_slug}: {reason}", "MCP_CONNECTION_ERROR")
|
||||
self.app_slug = app_slug
|
||||
self.reason = reason
|
||||
class MCPConnectionError(MCPServiceError):
|
||||
pass
|
||||
|
||||
class AuthenticationException(PipedreamException):
|
||||
def __init__(self, reason: str):
|
||||
super().__init__(f"Authentication failed: {reason}", "AUTHENTICATION_ERROR")
|
||||
self.reason = reason
|
||||
class MCPServerNotAvailableError(MCPServiceError):
|
||||
pass
|
||||
|
||||
class HttpClientException(PipedreamException):
|
||||
def __init__(self, url: str, status_code: int, reason: str):
|
||||
super().__init__(f"HTTP request to {url} failed with status {status_code}: {reason}", "HTTP_CLIENT_ERROR")
|
||||
self.url = url
|
||||
self.status_code = status_code
|
||||
self.reason = reason
|
||||
class AuthenticationError(MCPServiceError):
|
||||
pass
|
||||
|
||||
class RateLimitException(PipedreamException):
|
||||
def __init__(self, retry_after: int = None):
|
||||
super().__init__("Rate limit exceeded", "RATE_LIMIT_EXCEEDED")
|
||||
self.retry_after = retry_after
|
||||
class RateLimitError(MCPServiceError):
|
||||
pass
|
||||
|
||||
class Logger(Protocol):
|
||||
def info(self, message: str) -> None: ...
|
||||
def warning(self, message: str) -> None: ...
|
||||
def error(self, message: str) -> None: ...
|
||||
def debug(self, message: str) -> None: ...
|
||||
|
||||
class HttpClient:
|
||||
def __init__(self):
|
||||
class ExternalUserId:
|
||||
def __init__(self, value: str):
|
||||
if not value or not isinstance(value, str):
|
||||
raise ValueError("ExternalUserId must be a non-empty string")
|
||||
if len(value) > 255:
|
||||
raise ValueError("ExternalUserId must be less than 255 characters")
|
||||
self.value = value
|
||||
|
||||
|
||||
class AppSlug:
|
||||
def __init__(self, value: str):
|
||||
if not value or not isinstance(value, str):
|
||||
raise ValueError("AppSlug must be a non-empty string")
|
||||
if not re.match(r'^[a-z0-9_-]+$', value):
|
||||
raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores")
|
||||
self.value = value
|
||||
|
||||
|
||||
class MCPService:
|
||||
def __init__(self, logger=None):
|
||||
self._logger = logger or logger
|
||||
self.base_url = "https://api.pipedream.com/v1"
|
||||
self.session: Optional[httpx.AsyncClient] = None
|
||||
self.access_token: Optional[str] = None
|
||||
self.token_expires_at: Optional[datetime] = None
|
||||
|
||||
self.session = None
|
||||
self.access_token = None
|
||||
self.token_expires_at = None
|
||||
|
||||
async def _get_session(self) -> httpx.AsyncClient:
|
||||
if self.session is None or self.session.is_closed:
|
||||
self.session = httpx.AsyncClient(
|
||||
|
@ -126,23 +93,27 @@ class HttpClient:
|
|||
headers={"User-Agent": "Suna-Pipedream-Client/1.0"}
|
||||
)
|
||||
return self.session
|
||||
|
||||
|
||||
async def _ensure_access_token(self) -> str:
|
||||
if self.access_token and self.token_expires_at:
|
||||
if datetime.utcnow() < (self.token_expires_at - timedelta(minutes=5)):
|
||||
return self.access_token
|
||||
else:
|
||||
self.access_token = None
|
||||
self.token_expires_at = None
|
||||
|
||||
return await self._fetch_fresh_token()
|
||||
|
||||
|
||||
async def _fetch_fresh_token(self) -> str:
|
||||
project_id = os.getenv("PIPEDREAM_PROJECT_ID")
|
||||
client_id = os.getenv("PIPEDREAM_CLIENT_ID")
|
||||
client_secret = os.getenv("PIPEDREAM_CLIENT_SECRET")
|
||||
|
||||
|
||||
if not all([project_id, client_id, client_secret]):
|
||||
raise AuthenticationException("Missing required environment variables")
|
||||
|
||||
raise AuthenticationError("Missing required environment variables")
|
||||
|
||||
session = await self._get_session()
|
||||
|
||||
|
||||
try:
|
||||
response = await session.post(
|
||||
f"{self.base_url}/oauth/token",
|
||||
|
@ -153,227 +124,193 @@ class HttpClient:
|
|||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
data = response.json()
|
||||
self.access_token = data["access_token"]
|
||||
expires_in = data.get("expires_in", 3600)
|
||||
self.token_expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
|
||||
return self.access_token
|
||||
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 429:
|
||||
raise RateLimitException()
|
||||
raise AuthenticationException(f"Failed to obtain access token: {e}")
|
||||
|
||||
async def get(self, url: str, headers: Dict[str, str] = None, params: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
raise RateLimitError("Rate limit exceeded")
|
||||
raise AuthenticationError(f"Failed to obtain access token: {e}")
|
||||
|
||||
async def _make_request(self, url: str, headers: Dict[str, str] = None, params: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
session = await self._get_session()
|
||||
access_token = await self._ensure_access_token()
|
||||
|
||||
|
||||
request_headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
|
||||
try:
|
||||
response = await session.get(url, headers=request_headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 429:
|
||||
raise RateLimitException()
|
||||
raise HttpClientException(url, e.response.status_code, str(e))
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.session and not self.session.is_closed:
|
||||
await self.session.aclose()
|
||||
raise RateLimitError("Rate limit exceeded")
|
||||
raise MCPServiceError(f"HTTP request failed: {e}")
|
||||
|
||||
class MCPServerRepository:
|
||||
def __init__(self, http_client: HttpClient, logger: Logger):
|
||||
self._http_client = http_client
|
||||
self._logger = logger
|
||||
|
||||
async def discover_for_user(self, external_user_id: ExternalUserId, app_slug: Optional[AppSlug] = None) -> List[MCPServer]:
|
||||
async def _fetch_server_tools(self, external_user_id: str, app_slug: str) -> List[MCPTool]:
|
||||
project_id = os.getenv("PIPEDREAM_PROJECT_ID")
|
||||
environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development")
|
||||
|
||||
|
||||
if not project_id:
|
||||
self._logger.error("Missing PIPEDREAM_PROJECT_ID environment variable")
|
||||
return []
|
||||
|
||||
self._logger.info(f"Discovering MCP servers for user: {external_user_id.value}, app_slug: {app_slug.value if app_slug else 'all'}")
|
||||
|
||||
url = f"{self._http_client.base_url}/connect/{project_id}/accounts"
|
||||
|
||||
url = f"{self.base_url}/connect/{project_id}/tools"
|
||||
params = {
|
||||
"app": app_slug,
|
||||
"external_id": external_user_id
|
||||
}
|
||||
headers = {"X-PD-Environment": environment}
|
||||
|
||||
try:
|
||||
data = await self._make_request(url, headers=headers, params=params)
|
||||
tools_data = data.get("data", [])
|
||||
|
||||
tools = []
|
||||
for tool_data in tools_data:
|
||||
if tool_data.get("name") or tool_data.get("key"):
|
||||
tool = MCPTool(
|
||||
name=tool_data.get("name") or tool_data.get("key", ""),
|
||||
description=tool_data.get("description", f"Tool from {app_slug}"),
|
||||
input_schema=tool_data.get("inputSchema") or tool_data.get("props", {})
|
||||
)
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching tools for {app_slug}: {str(e)}")
|
||||
return []
|
||||
|
||||
async def discover_servers_for_user(self, external_user_id: ExternalUserId, app_slug: Optional[AppSlug] = None) -> List[MCPServer]:
|
||||
project_id = os.getenv("PIPEDREAM_PROJECT_ID")
|
||||
environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development")
|
||||
|
||||
if not project_id:
|
||||
logger.error("Missing PIPEDREAM_PROJECT_ID environment variable")
|
||||
return []
|
||||
|
||||
logger.info(f"Discovering MCP servers for user: {external_user_id.value}, app_slug: {app_slug.value if app_slug else 'all'}")
|
||||
|
||||
url = f"{self.base_url}/connect/{project_id}/accounts"
|
||||
params = {"external_id": external_user_id.value}
|
||||
headers = {"X-PD-Environment": environment}
|
||||
|
||||
|
||||
try:
|
||||
data = await self._http_client.get(url, headers=headers, params=params)
|
||||
|
||||
data = await self._make_request(url, headers=headers, params=params)
|
||||
|
||||
accounts = data.get("data", [])
|
||||
if not accounts:
|
||||
self._logger.info(f"No connected apps found for user: {external_user_id.value}")
|
||||
logger.info(f"No connected apps found for user: {external_user_id.value}")
|
||||
return []
|
||||
|
||||
|
||||
user_apps = [account.get("app") for account in accounts if account.get("app")]
|
||||
|
||||
|
||||
if app_slug:
|
||||
user_apps = [app for app in user_apps if app.get("name_slug") == app_slug.value]
|
||||
|
||||
|
||||
servers = []
|
||||
for app in user_apps:
|
||||
try:
|
||||
server = MCPServer(
|
||||
app_slug=AppSlug(app.get("name_slug", "")),
|
||||
app_slug=app.get("name_slug", ""),
|
||||
app_name=app.get("name", "Unknown"),
|
||||
server_url=MCPServerUrl("https://remote.mcp.pipedream.net"),
|
||||
server_url="https://remote.mcp.pipedream.net",
|
||||
project_id=project_id,
|
||||
environment=environment,
|
||||
external_user_id=external_user_id,
|
||||
external_user_id=external_user_id.value,
|
||||
status=ConnectionStatus.CONNECTED
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
self._logger.info(f"Attempting to fetch tools for {app.get('name_slug')}...")
|
||||
tools = await self._fetch_server_tools(external_user_id, server.app_slug)
|
||||
|
||||
for tool_data in tools:
|
||||
tool = MCPTool(
|
||||
name=tool_data.name,
|
||||
description=tool_data.description or f"Tool from {app.get('name', 'Unknown')}",
|
||||
input_schema=tool_data.inputSchema if hasattr(tool_data, 'inputSchema') else {}
|
||||
)
|
||||
server.add_tool(tool)
|
||||
|
||||
self._logger.info(f"Successfully fetched {len(tools)} tools for {app.get('name_slug')} MCP server")
|
||||
except Exception as tool_error:
|
||||
self._logger.error(f"Could not fetch tools for {app.get('name_slug')}: {str(tool_error)}")
|
||||
import traceback
|
||||
self._logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
logger.info(f"Attempting to fetch tools for {app.get('name_slug')}...")
|
||||
tools = await self._fetch_server_tools(external_user_id.value, server.app_slug)
|
||||
server.available_tools = tools
|
||||
|
||||
logger.info(f"Successfully fetched {len(tools)} tools for app: {app.get('name_slug')}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching tools for {app.get('name_slug')}: {str(e)}")
|
||||
server.available_tools = []
|
||||
|
||||
servers.append(server)
|
||||
|
||||
except Exception as e:
|
||||
self._logger.warning(f"Error creating MCP server for app {app.get('name_slug', 'unknown')}: {str(e)}")
|
||||
logger.error(f"Error creating server for app {app.get('name_slug', 'unknown')}: {str(e)}")
|
||||
continue
|
||||
|
||||
self._logger.info(f"Discovered {len(servers)} MCP servers")
|
||||
|
||||
logger.info(f"Successfully discovered {len(servers)} MCP servers")
|
||||
return servers
|
||||
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error discovering MCP servers: {str(e)}")
|
||||
logger.error(f"Error discovering servers for user {external_user_id.value}: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _fetch_server_tools(self, external_user_id: ExternalUserId, app_slug: AppSlug) -> List:
|
||||
try:
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
access_token = await self._http_client._ensure_access_token()
|
||||
project_id = os.getenv("PIPEDREAM_PROJECT_ID")
|
||||
environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"x-pd-project-id": project_id,
|
||||
"x-pd-environment": environment,
|
||||
"x-pd-external-user-id": external_user_id.value,
|
||||
"x-pd-app-slug": app_slug.value,
|
||||
}
|
||||
|
||||
if hasattr(self._http_client, 'rate_limit_token') and self._http_client.rate_limit_token:
|
||||
headers["x-pd-rate-limit"] = self._http_client.rate_limit_token
|
||||
|
||||
url = "https://remote.mcp.pipedream.net"
|
||||
|
||||
async with streamablehttp_client(url, headers=headers) as (read_stream, write_stream, _):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
tools = tools_result.tools if hasattr(tools_result, 'tools') else tools_result
|
||||
self._logger.info(f"Successfully fetched {len(tools) if tools else 0} tools from MCP server")
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error fetching tools for {app_slug.value}: {str(e)}")
|
||||
import traceback
|
||||
self._logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||
return []
|
||||
|
||||
async def test_connection(self, server: MCPServer) -> MCPServer:
|
||||
async def test_server_connection(self, server: MCPServer) -> MCPServer:
|
||||
logger.info(f"Testing MCP server connection: {server.app_name}")
|
||||
|
||||
server.status = ConnectionStatus.CONNECTED
|
||||
|
||||
if server.is_connected():
|
||||
logger.info(f"MCP server {server.app_name} connected successfully with {server.get_tool_count()} tools")
|
||||
else:
|
||||
logger.warning(f"MCP server {server.app_name} connection failed: {server.error_message}")
|
||||
|
||||
return server
|
||||
|
||||
async def create_connection(self, external_user_id: ExternalUserId, app_slug: AppSlug, oauth_app_id: Optional[str] = None) -> MCPServer:
|
||||
project_id = os.getenv("PIPEDREAM_PROJECT_ID")
|
||||
environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development")
|
||||
|
||||
|
||||
if not project_id:
|
||||
raise MCPConnectionError(app_slug.value, "Missing PIPEDREAM_PROJECT_ID")
|
||||
|
||||
raise MCPConnectionError("Missing PIPEDREAM_PROJECT_ID")
|
||||
|
||||
logger.info(f"Creating MCP connection for user: {external_user_id.value}, app: {app_slug.value}")
|
||||
|
||||
server = MCPServer(
|
||||
app_slug=app_slug,
|
||||
app_slug=app_slug.value,
|
||||
app_name=app_slug.value.replace('_', ' ').title(),
|
||||
server_url=MCPServerUrl("https://remote.mcp.pipedream.net"),
|
||||
server_url="https://remote.mcp.pipedream.net",
|
||||
project_id=project_id,
|
||||
environment=environment,
|
||||
external_user_id=external_user_id,
|
||||
external_user_id=external_user_id.value,
|
||||
oauth_app_id=oauth_app_id,
|
||||
status=ConnectionStatus.CONNECTED
|
||||
)
|
||||
|
||||
return server
|
||||
|
||||
class MCPService:
|
||||
def __init__(self, logger: Optional[Logger] = None):
|
||||
self._logger = logger or logging.getLogger(__name__)
|
||||
self._http_client = HttpClient()
|
||||
self._mcp_server_repo = MCPServerRepository(self._http_client, self._logger)
|
||||
|
||||
async def discover_servers_for_user(
|
||||
self,
|
||||
external_user_id: ExternalUserId,
|
||||
app_slug: Optional[AppSlug] = None
|
||||
) -> List[MCPServer]:
|
||||
self._logger.info(f"Discovering MCP servers for user: {external_user_id.value}")
|
||||
|
||||
servers = await self._mcp_server_repo.discover_for_user(external_user_id, app_slug)
|
||||
|
||||
connected_count = sum(1 for server in servers if server.is_connected())
|
||||
self._logger.info(f"Discovered {len(servers)} MCP servers ({connected_count} connected)")
|
||||
|
||||
return servers
|
||||
|
||||
async def test_server_connection(self, server: MCPServer) -> MCPServer:
|
||||
self._logger.info(f"Testing MCP server connection: {server.app_name}")
|
||||
|
||||
tested_server = await self._mcp_server_repo.test_connection(server)
|
||||
|
||||
if tested_server.is_connected():
|
||||
self._logger.info(f"MCP server {server.app_name} connected successfully with {tested_server.get_tool_count()} tools")
|
||||
else:
|
||||
self._logger.warning(f"MCP server {server.app_name} connection failed: {tested_server.error_message}")
|
||||
|
||||
return tested_server
|
||||
|
||||
async def create_connection(
|
||||
self,
|
||||
external_user_id: ExternalUserId,
|
||||
app_slug: AppSlug,
|
||||
oauth_app_id: Optional[str] = None
|
||||
) -> MCPServer:
|
||||
self._logger.info(f"Creating MCP connection for user: {external_user_id.value}, app: {app_slug.value}")
|
||||
|
||||
server = await self._mcp_server_repo.create_connection(external_user_id, app_slug, oauth_app_id)
|
||||
|
||||
if server.is_connected():
|
||||
self._logger.info(f"Successfully created MCP connection for {app_slug.value} with {server.get_tool_count()} tools")
|
||||
logger.info(f"Successfully created MCP connection for {app_slug.value} with {server.get_tool_count()} tools")
|
||||
else:
|
||||
self._logger.error(f"Failed to create MCP connection for {app_slug.value}: {server.error_message}")
|
||||
|
||||
logger.error(f"Failed to create MCP connection for {app_slug.value}: {server.error_message}")
|
||||
|
||||
return server
|
||||
|
||||
|
||||
async def close(self):
|
||||
await self._http_client.close()
|
||||
if self.session and not self.session.is_closed:
|
||||
await self.session.aclose()
|
||||
|
||||
|
||||
_mcp_service = None
|
||||
|
||||
def get_mcp_service() -> MCPService:
|
||||
global _mcp_service
|
||||
if _mcp_service is None:
|
||||
_mcp_service = MCPService()
|
||||
return _mcp_service
|
||||
|
||||
|
||||
PipedreamException = MCPServiceError
|
||||
MCPServerNotAvailableError = MCPServerNotAvailableError
|
||||
AuthenticationException = AuthenticationError
|
||||
HttpClientException = MCPServiceError
|
||||
RateLimitException = RateLimitError
|
|
@ -1,461 +1,191 @@
|
|||
from typing import List, Optional, Dict, Any, Protocol
|
||||
from uuid import UUID
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
import json
|
||||
import hashlib
|
||||
import re
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional, Dict, Any
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExternalUserId:
|
||||
value: str
|
||||
def __post_init__(self):
|
||||
if not self.value or not isinstance(self.value, str):
|
||||
raise ValueError("ExternalUserId must be a non-empty string")
|
||||
if len(self.value) > 255:
|
||||
raise ValueError("ExternalUserId must be less than 255 characters")
|
||||
from services.supabase import DBConnection
|
||||
from utils.logger import logger
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AppSlug:
|
||||
value: str
|
||||
def __post_init__(self):
|
||||
if not self.value or not isinstance(self.value, str):
|
||||
raise ValueError("AppSlug must be a non-empty string")
|
||||
if not re.match(r'^[a-z0-9_-]+$', self.value):
|
||||
raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores")
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProfileName:
|
||||
value: str
|
||||
def __post_init__(self):
|
||||
if not self.value or not isinstance(self.value, str):
|
||||
raise ValueError("ProfileName must be a non-empty string")
|
||||
if len(self.value) > 100:
|
||||
raise ValueError("ProfileName must be less than 100 characters")
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EncryptedConfig:
|
||||
value: str
|
||||
def __post_init__(self):
|
||||
if not self.value or not isinstance(self.value, str):
|
||||
raise ValueError("EncryptedConfig must be a non-empty string")
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConfigHash:
|
||||
value: str
|
||||
def __post_init__(self):
|
||||
if not self.value or not isinstance(self.value, str):
|
||||
raise ValueError("ConfigHash must be a non-empty string")
|
||||
if len(self.value) != 64:
|
||||
raise ValueError("ConfigHash must be a 64-character SHA256 hash")
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: str) -> 'ConfigHash':
|
||||
hash_value = hashlib.sha256(config.encode()).hexdigest()
|
||||
return cls(hash_value)
|
||||
|
||||
@dataclass
|
||||
class Profile:
|
||||
profile_id: UUID
|
||||
account_id: UUID
|
||||
profile_id: str
|
||||
account_id: str
|
||||
mcp_qualified_name: str
|
||||
profile_name: ProfileName
|
||||
profile_name: str
|
||||
display_name: str
|
||||
encrypted_config: EncryptedConfig
|
||||
config_hash: ConfigHash
|
||||
app_slug: AppSlug
|
||||
encrypted_config: str
|
||||
config_hash: str
|
||||
app_slug: str
|
||||
app_name: str
|
||||
external_user_id: ExternalUserId
|
||||
external_user_id: str
|
||||
enabled_tools: List[str] = field(default_factory=list)
|
||||
is_active: bool = True
|
||||
is_default: bool = False
|
||||
is_connected: bool = False
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
last_used_at: Optional[datetime] = None
|
||||
|
||||
def update_last_used(self) -> None:
|
||||
self.last_used_at = datetime.utcnow()
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def activate(self) -> None:
|
||||
self.is_active = True
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def deactivate(self) -> None:
|
||||
self.is_active = False
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def set_as_default(self) -> None:
|
||||
self.is_default = True
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def unset_as_default(self) -> None:
|
||||
self.is_default = False
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def update_connection_status(self, is_connected: bool) -> None:
|
||||
self.is_connected = is_connected
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def enable_tool(self, tool_name: str) -> None:
|
||||
if tool_name not in self.enabled_tools:
|
||||
self.enabled_tools.append(tool_name)
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def disable_tool(self, tool_name: str) -> None:
|
||||
if tool_name in self.enabled_tools:
|
||||
self.enabled_tools.remove(tool_name)
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def get_mcp_qualified_name(self) -> str:
|
||||
return f"pipedream:{self.app_slug.value}"
|
||||
description: Optional[str] = None
|
||||
oauth_app_id: Optional[str] = None
|
||||
|
||||
class PipedreamException(Exception):
|
||||
def __init__(self, message: str, error_code: str = None):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
self.message = message
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'profile_id': self.profile_id,
|
||||
'account_id': self.account_id,
|
||||
'mcp_qualified_name': self.mcp_qualified_name,
|
||||
'profile_name': self.profile_name,
|
||||
'display_name': self.display_name,
|
||||
'encrypted_config': self.encrypted_config,
|
||||
'config_hash': self.config_hash,
|
||||
'app_slug': self.app_slug,
|
||||
'app_name': self.app_name,
|
||||
'external_user_id': self.external_user_id,
|
||||
'enabled_tools': self.enabled_tools,
|
||||
'is_active': self.is_active,
|
||||
'is_default': self.is_default,
|
||||
'is_connected': self.is_connected,
|
||||
'created_at': self.created_at.isoformat(),
|
||||
'updated_at': self.updated_at.isoformat(),
|
||||
'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None,
|
||||
'description': self.description,
|
||||
'oauth_app_id': self.oauth_app_id
|
||||
}
|
||||
|
||||
class DomainException(PipedreamException):
|
||||
|
||||
class ProfileServiceError(Exception):
|
||||
pass
|
||||
|
||||
class ProfileNotFoundError(DomainException):
|
||||
def __init__(self, profile_id: str):
|
||||
super().__init__(f"Profile with ID {profile_id} not found", "PROFILE_NOT_FOUND")
|
||||
self.profile_id = profile_id
|
||||
class ProfileNotFoundError(ProfileServiceError):
|
||||
pass
|
||||
|
||||
class ProfileAlreadyExistsError(DomainException):
|
||||
def __init__(self, profile_name: str, app_slug: str):
|
||||
super().__init__(f"Profile '{profile_name}' already exists for app '{app_slug}'", "PROFILE_ALREADY_EXISTS")
|
||||
self.profile_name = profile_name
|
||||
self.app_slug = app_slug
|
||||
class ProfileAlreadyExistsError(ProfileServiceError):
|
||||
pass
|
||||
|
||||
class DatabaseException(PipedreamException):
|
||||
def __init__(self, operation: str, reason: str):
|
||||
super().__init__(f"Database operation '{operation}' failed: {reason}", "DATABASE_ERROR")
|
||||
self.operation = operation
|
||||
self.reason = reason
|
||||
class InvalidConfigError(ProfileServiceError):
|
||||
pass
|
||||
|
||||
class EncryptionException(PipedreamException):
|
||||
def __init__(self, operation: str, reason: str):
|
||||
super().__init__(f"Encryption operation '{operation}' failed: {reason}", "ENCRYPTION_ERROR")
|
||||
self.operation = operation
|
||||
self.reason = reason
|
||||
class EncryptionError(ProfileServiceError):
|
||||
pass
|
||||
|
||||
class DatabaseConnection(Protocol):
|
||||
async def client(self) -> Any: ...
|
||||
|
||||
class Logger(Protocol):
|
||||
def info(self, message: str) -> None: ...
|
||||
def warning(self, message: str) -> None: ...
|
||||
def error(self, message: str) -> None: ...
|
||||
def debug(self, message: str) -> None: ...
|
||||
|
||||
|
||||
class EncryptionService:
|
||||
def encrypt(self, data: str) -> str:
|
||||
try:
|
||||
from utils.encryption import encrypt_data
|
||||
return encrypt_data(data)
|
||||
except Exception as e:
|
||||
raise EncryptionException("encrypt", str(e))
|
||||
|
||||
def decrypt(self, encrypted_data: str) -> str:
|
||||
try:
|
||||
from utils.encryption import decrypt_data
|
||||
return decrypt_data(encrypted_data)
|
||||
except Exception as e:
|
||||
raise EncryptionException("decrypt", str(e))
|
||||
|
||||
class ExternalUserIdService:
|
||||
def generate(self, account_id: str, app_slug: AppSlug, profile_name: ProfileName) -> ExternalUserId:
|
||||
combined = f"{account_id}:{app_slug.value}:{profile_name.value}"
|
||||
hash_value = hashlib.sha256(combined.encode()).hexdigest()[:16]
|
||||
external_user_id = f"suna_{hash_value}"
|
||||
return ExternalUserId(external_user_id)
|
||||
|
||||
class MCPQualifiedNameService:
|
||||
def generate(self, app_slug: AppSlug) -> str:
|
||||
return f"pipedream:{app_slug.value}"
|
||||
|
||||
class ProfileConfigurationService:
|
||||
def validate_config(self, config: Dict[str, Any]) -> bool:
|
||||
required_keys = ["app_slug", "app_name", "external_user_id"]
|
||||
return all(key in config for key in required_keys)
|
||||
|
||||
def merge_config(self, existing_config: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
|
||||
merged = existing_config.copy()
|
||||
merged.update(updates)
|
||||
return merged
|
||||
|
||||
class ConnectionStatusService:
|
||||
def __init__(self, logger: Logger):
|
||||
self._logger = logger
|
||||
|
||||
async def check_connection_status(self, profile: Profile) -> bool:
|
||||
return True
|
||||
|
||||
async def update_connection_status(self, profile: Profile) -> Profile:
|
||||
try:
|
||||
is_connected = await self.check_connection_status(profile)
|
||||
profile.update_connection_status(is_connected)
|
||||
return profile
|
||||
except Exception as e:
|
||||
self._logger.warning(f"Error updating connection status: {str(e)}")
|
||||
profile.update_connection_status(False)
|
||||
return profile
|
||||
|
||||
|
||||
class ProfileRepository:
|
||||
def __init__(self, db: DatabaseConnection, encryption_service: EncryptionService, logger: Logger):
|
||||
self._db = db
|
||||
self._encryption_service = encryption_service
|
||||
self._logger = logger
|
||||
|
||||
async def create(self, profile: Profile) -> Profile:
|
||||
try:
|
||||
client = await self._db.client
|
||||
|
||||
config = {
|
||||
"app_slug": profile.app_slug.value,
|
||||
"app_name": profile.app_name,
|
||||
"external_user_id": profile.external_user_id.value,
|
||||
"enabled_tools": profile.enabled_tools,
|
||||
"oauth_app_id": getattr(profile, 'oauth_app_id', None),
|
||||
"description": getattr(profile, 'description', None)
|
||||
}
|
||||
|
||||
config_json = json.dumps(config)
|
||||
encrypted_config = self._encryption_service.encrypt(config_json)
|
||||
config_hash = ConfigHash.from_config(config_json)
|
||||
|
||||
result = await client.table('user_mcp_credential_profiles').insert({
|
||||
'account_id': str(profile.account_id),
|
||||
'mcp_qualified_name': profile.mcp_qualified_name,
|
||||
'profile_name': profile.profile_name.value,
|
||||
'display_name': profile.display_name,
|
||||
'encrypted_config': encrypted_config,
|
||||
'config_hash': config_hash.value,
|
||||
'is_active': profile.is_active,
|
||||
'is_default': profile.is_default,
|
||||
'created_at': profile.created_at.isoformat(),
|
||||
'updated_at': profile.updated_at.isoformat()
|
||||
}).execute()
|
||||
|
||||
if result.data:
|
||||
profile_data = result.data[0]
|
||||
return self._map_to_domain(profile_data, config)
|
||||
|
||||
raise DatabaseException("create", "No data returned from insert")
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error creating profile: {str(e)}")
|
||||
raise DatabaseException("create", str(e))
|
||||
|
||||
async def get_by_id(self, account_id: UUID, profile_id: UUID) -> Optional[Profile]:
|
||||
try:
|
||||
client = await self._db.client
|
||||
|
||||
self._logger.debug(f"Querying profile: account_id={account_id}, profile_id={profile_id}")
|
||||
|
||||
result = await client.table('user_mcp_credential_profiles').select('*').eq(
|
||||
'account_id', str(account_id)
|
||||
).eq('profile_id', str(profile_id)).single().execute()
|
||||
|
||||
if result.data:
|
||||
profile_data = result.data
|
||||
self._logger.debug(f"Found profile: {profile_data.get('profile_name', 'unknown')}")
|
||||
decrypted_config = self._encryption_service.decrypt(profile_data['encrypted_config'])
|
||||
config = json.loads(decrypted_config)
|
||||
return self._map_to_domain(profile_data, config)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error getting profile by ID {profile_id} for user {account_id}: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_by_app_slug(self, account_id: UUID, app_slug: AppSlug, profile_name: Optional[ProfileName] = None) -> Optional[Profile]:
|
||||
try:
|
||||
client = await self._db.client
|
||||
|
||||
mcp_qualified_name = f"pipedream:{app_slug.value}"
|
||||
query = client.table('user_mcp_credential_profiles').select('*').eq(
|
||||
'account_id', str(account_id)
|
||||
).eq('mcp_qualified_name', mcp_qualified_name)
|
||||
|
||||
if profile_name:
|
||||
query = query.eq('profile_name', profile_name.value)
|
||||
|
||||
result = await query.execute()
|
||||
|
||||
if result.data:
|
||||
if profile_name:
|
||||
profile_data = result.data[0]
|
||||
else:
|
||||
profile_data = next((p for p in result.data if p.get('is_default')), result.data[0])
|
||||
|
||||
decrypted_config = self._encryption_service.decrypt(profile_data['encrypted_config'])
|
||||
config = json.loads(decrypted_config)
|
||||
return self._map_to_domain(profile_data, config)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error getting profile by app slug: {str(e)}")
|
||||
return None
|
||||
|
||||
async def find_by_account(self, account_id: UUID, app_slug: Optional[AppSlug] = None, is_active: Optional[bool] = None) -> List[Profile]:
|
||||
try:
|
||||
client = await self._db.client
|
||||
|
||||
query = client.table('user_mcp_credential_profiles').select('*').eq(
|
||||
'account_id', str(account_id)
|
||||
)
|
||||
|
||||
if app_slug:
|
||||
mcp_qualified_name = f"pipedream:{app_slug.value}"
|
||||
query = query.eq('mcp_qualified_name', mcp_qualified_name)
|
||||
else:
|
||||
query = query.like('mcp_qualified_name', 'pipedream:%')
|
||||
|
||||
if is_active is not None:
|
||||
query = query.eq('is_active', is_active)
|
||||
|
||||
result = await query.order('created_at', desc=True).execute()
|
||||
|
||||
profiles = []
|
||||
for profile_data in result.data:
|
||||
try:
|
||||
decrypted_config = self._encryption_service.decrypt(profile_data['encrypted_config'])
|
||||
config = json.loads(decrypted_config)
|
||||
profile = self._map_to_domain(profile_data, config)
|
||||
profiles.append(profile)
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error decrypting profile config: {str(e)}")
|
||||
continue
|
||||
|
||||
return profiles
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error finding profiles by account: {str(e)}")
|
||||
return []
|
||||
|
||||
async def update(self, profile: Profile) -> Profile:
|
||||
try:
|
||||
client = await self._db.client
|
||||
|
||||
config = {
|
||||
"app_slug": profile.app_slug.value,
|
||||
"app_name": profile.app_name,
|
||||
"external_user_id": profile.external_user_id.value,
|
||||
"enabled_tools": profile.enabled_tools,
|
||||
"oauth_app_id": getattr(profile, 'oauth_app_id', None),
|
||||
"description": getattr(profile, 'description', None)
|
||||
}
|
||||
|
||||
config_json = json.dumps(config)
|
||||
encrypted_config = self._encryption_service.encrypt(config_json)
|
||||
config_hash = ConfigHash.from_config(config_json)
|
||||
|
||||
result = await client.table('user_mcp_credential_profiles').update({
|
||||
'profile_name': profile.profile_name.value,
|
||||
'display_name': profile.display_name,
|
||||
'encrypted_config': encrypted_config,
|
||||
'config_hash': config_hash.value,
|
||||
'is_active': profile.is_active,
|
||||
'is_default': profile.is_default,
|
||||
'updated_at': profile.updated_at.isoformat(),
|
||||
'last_used_at': profile.last_used_at.isoformat() if profile.last_used_at else None
|
||||
}).eq('profile_id', str(profile.profile_id)).execute()
|
||||
|
||||
if result.data:
|
||||
return self._map_to_domain(result.data[0], config)
|
||||
|
||||
raise DatabaseException("update", "No data returned from update")
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error updating profile: {str(e)}")
|
||||
raise DatabaseException("update", str(e))
|
||||
|
||||
async def delete(self, account_id: UUID, profile_id: UUID) -> bool:
|
||||
try:
|
||||
client = await self._db.client
|
||||
|
||||
result = await client.table('user_mcp_credential_profiles').delete().eq(
|
||||
'profile_id', str(profile_id)
|
||||
).eq('account_id', str(account_id)).execute()
|
||||
|
||||
return len(result.data) > 0
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error deleting profile: {str(e)}")
|
||||
return False
|
||||
|
||||
async def set_default(self, account_id: UUID, profile_id: UUID, mcp_qualified_name: str) -> None:
|
||||
try:
|
||||
client = await self._db.client
|
||||
|
||||
await client.table('user_mcp_credential_profiles').update({
|
||||
'is_default': False
|
||||
}).eq('account_id', str(account_id)).eq('mcp_qualified_name', mcp_qualified_name).execute()
|
||||
|
||||
await client.table('user_mcp_credential_profiles').update({
|
||||
'is_default': True
|
||||
}).eq('profile_id', str(profile_id)).execute()
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error setting default profile: {str(e)}")
|
||||
raise DatabaseException("set_default", str(e))
|
||||
|
||||
def _map_to_domain(self, profile_data: dict, config: dict) -> Profile:
|
||||
return Profile(
|
||||
profile_id=UUID(profile_data['profile_id']),
|
||||
account_id=UUID(profile_data['account_id']),
|
||||
mcp_qualified_name=profile_data['mcp_qualified_name'],
|
||||
profile_name=ProfileName(profile_data['profile_name']),
|
||||
display_name=profile_data['display_name'],
|
||||
encrypted_config=EncryptedConfig(profile_data['encrypted_config']),
|
||||
config_hash=ConfigHash(profile_data['config_hash']),
|
||||
app_slug=AppSlug(config['app_slug']),
|
||||
app_name=config['app_name'],
|
||||
external_user_id=ExternalUserId(config['external_user_id']),
|
||||
enabled_tools=config.get('enabled_tools', []),
|
||||
is_active=profile_data['is_active'],
|
||||
is_default=profile_data['is_default'],
|
||||
is_connected=False,
|
||||
created_at=datetime.fromisoformat(profile_data['created_at']),
|
||||
updated_at=datetime.fromisoformat(profile_data['updated_at']),
|
||||
last_used_at=datetime.fromisoformat(profile_data['last_used_at']) if profile_data.get('last_used_at') else None
|
||||
)
|
||||
|
||||
class ProfileService:
|
||||
def __init__(
|
||||
self,
|
||||
db: Optional[DatabaseConnection] = None,
|
||||
logger: Optional[Logger] = None
|
||||
):
|
||||
self._logger = logger or logging.getLogger(__name__)
|
||||
def __init__(self):
|
||||
self.db = DBConnection()
|
||||
self._connection_service = None
|
||||
|
||||
async def _get_client(self):
|
||||
return await self.db.client
|
||||
|
||||
def _get_connection_service(self):
|
||||
if self._connection_service is None:
|
||||
from .connection_service import ConnectionService
|
||||
from utils.logger import logger
|
||||
self._connection_service = ConnectionService(logger=logger)
|
||||
return self._connection_service
|
||||
|
||||
async def _check_connection_status(self, external_user_id: str, app_slug: str) -> bool:
|
||||
try:
|
||||
from .connection_service import ExternalUserId, AppSlug
|
||||
connection_service = self._get_connection_service()
|
||||
return await connection_service.has_connection(
|
||||
ExternalUserId(external_user_id),
|
||||
AppSlug(app_slug)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking connection status: {str(e)}")
|
||||
return False
|
||||
|
||||
def _validate_app_slug(self, app_slug: str) -> None:
|
||||
if not app_slug or not isinstance(app_slug, str):
|
||||
raise InvalidConfigError("App slug must be a non-empty string")
|
||||
if not re.match(r'^[a-z0-9_-]+$', app_slug):
|
||||
raise InvalidConfigError("App slug must contain only lowercase letters, numbers, hyphens, and underscores")
|
||||
|
||||
def _validate_profile_name(self, profile_name: str) -> None:
|
||||
if not profile_name or not isinstance(profile_name, str):
|
||||
raise InvalidConfigError("Profile name must be a non-empty string")
|
||||
if len(profile_name) > 100:
|
||||
raise InvalidConfigError("Profile name must be less than 100 characters")
|
||||
|
||||
def _generate_external_user_id(self, account_id: str, app_slug: str, profile_name: str) -> str:
|
||||
combined = f"{account_id}:{app_slug}:{profile_name}"
|
||||
hash_value = hashlib.sha256(combined.encode()).hexdigest()[:16]
|
||||
return f"suna_{hash_value}"
|
||||
|
||||
def _generate_config_hash(self, config_json: str) -> str:
|
||||
return hashlib.sha256(config_json.encode()).hexdigest()
|
||||
|
||||
def _encrypt_config(self, config_json: str) -> str:
|
||||
try:
|
||||
from utils.encryption import encrypt_data
|
||||
return encrypt_data(config_json)
|
||||
except Exception as e:
|
||||
raise EncryptionError(f"Failed to encrypt config: {str(e)}")
|
||||
|
||||
def _decrypt_config(self, encrypted_config: str) -> Dict[str, Any]:
|
||||
try:
|
||||
from utils.encryption import decrypt_data
|
||||
decrypted_json = decrypt_data(encrypted_config)
|
||||
return json.loads(decrypted_json)
|
||||
except Exception as e:
|
||||
raise EncryptionError(f"Failed to decrypt config: {str(e)}")
|
||||
|
||||
def _build_config(self, app_slug: str, app_name: str, external_user_id: str,
|
||||
enabled_tools: List[str], oauth_app_id: Optional[str] = None,
|
||||
description: Optional[str] = None) -> Dict[str, Any]:
|
||||
return {
|
||||
"app_slug": app_slug,
|
||||
"app_name": app_name,
|
||||
"external_user_id": external_user_id,
|
||||
"enabled_tools": enabled_tools,
|
||||
"oauth_app_id": oauth_app_id,
|
||||
"description": description
|
||||
}
|
||||
|
||||
async def _map_row_to_profile(self, row: Dict[str, Any]) -> Profile:
|
||||
try:
|
||||
config = self._decrypt_config(row['encrypted_config'])
|
||||
except Exception:
|
||||
config = {
|
||||
"app_slug": "unknown",
|
||||
"app_name": "Unknown App",
|
||||
"external_user_id": "unknown",
|
||||
"enabled_tools": [],
|
||||
"oauth_app_id": None,
|
||||
"description": None
|
||||
}
|
||||
|
||||
if db is None:
|
||||
from services.supabase import DBConnection
|
||||
self._db = DBConnection()
|
||||
else:
|
||||
self._db = db
|
||||
is_connected = await self._check_connection_status(config['external_user_id'], config['app_slug'])
|
||||
|
||||
self._encryption_service = EncryptionService()
|
||||
self._profile_repo = ProfileRepository(self._db, self._encryption_service, self._logger)
|
||||
self._external_user_id_service = ExternalUserIdService()
|
||||
self._mcp_qualified_name_service = MCPQualifiedNameService()
|
||||
self._profile_config_service = ProfileConfigurationService()
|
||||
self._connection_status_service = ConnectionStatusService(self._logger)
|
||||
|
||||
return Profile(
|
||||
profile_id=row['profile_id'],
|
||||
account_id=row['account_id'],
|
||||
mcp_qualified_name=row['mcp_qualified_name'],
|
||||
profile_name=row['profile_name'],
|
||||
display_name=row['display_name'],
|
||||
encrypted_config=row['encrypted_config'],
|
||||
config_hash=row['config_hash'],
|
||||
app_slug=config['app_slug'],
|
||||
app_name=config['app_name'],
|
||||
external_user_id=config['external_user_id'],
|
||||
enabled_tools=config.get('enabled_tools', []),
|
||||
is_active=row['is_active'],
|
||||
is_default=row['is_default'],
|
||||
is_connected=is_connected,
|
||||
created_at=datetime.fromisoformat(row['created_at'].replace('Z', '+00:00')) if isinstance(row['created_at'], str) else row['created_at'],
|
||||
updated_at=datetime.fromisoformat(row['updated_at'].replace('Z', '+00:00')) if isinstance(row['updated_at'], str) else row['updated_at'],
|
||||
last_used_at=datetime.fromisoformat(row['last_used_at'].replace('Z', '+00:00')) if row.get('last_used_at') and isinstance(row['last_used_at'], str) else row.get('last_used_at'),
|
||||
description=config.get('description'),
|
||||
oauth_app_id=config.get('oauth_app_id')
|
||||
)
|
||||
|
||||
async def create_profile(
|
||||
self,
|
||||
account_id: UUID,
|
||||
account_id: str,
|
||||
profile_name: str,
|
||||
app_slug: str,
|
||||
app_name: str,
|
||||
|
@ -465,130 +195,246 @@ class ProfileService:
|
|||
enabled_tools: Optional[List[str]] = None,
|
||||
external_user_id: Optional[str] = None
|
||||
) -> Profile:
|
||||
app_slug_vo = AppSlug(app_slug)
|
||||
profile_name_vo = ProfileName(profile_name)
|
||||
self._validate_app_slug(app_slug)
|
||||
self._validate_profile_name(profile_name)
|
||||
|
||||
if external_user_id:
|
||||
external_user_id_vo = ExternalUserId(external_user_id)
|
||||
else:
|
||||
external_user_id_vo = self._external_user_id_service.generate(
|
||||
str(account_id), app_slug_vo, profile_name_vo
|
||||
if enabled_tools is None:
|
||||
enabled_tools = []
|
||||
|
||||
if not external_user_id:
|
||||
external_user_id = self._generate_external_user_id(account_id, app_slug, profile_name)
|
||||
|
||||
config = self._build_config(app_slug, app_name, external_user_id, enabled_tools, oauth_app_id, description)
|
||||
config_json = json.dumps(config, sort_keys=True)
|
||||
encrypted_config = self._encrypt_config(config_json)
|
||||
config_hash = self._generate_config_hash(config_json)
|
||||
|
||||
mcp_qualified_name = f"pipedream:{app_slug}"
|
||||
profile_id = str(uuid4())
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
existing = await client.table('user_mcp_credential_profiles').select('profile_id').eq(
|
||||
'account_id', account_id
|
||||
).eq('mcp_qualified_name', mcp_qualified_name).eq('profile_name', profile_name).execute()
|
||||
|
||||
if existing.data:
|
||||
raise ProfileAlreadyExistsError(f"Profile '{profile_name}' already exists for app '{app_slug}'")
|
||||
|
||||
if is_default:
|
||||
await client.table('user_mcp_credential_profiles').update({
|
||||
'is_default': False
|
||||
}).eq('account_id', account_id).eq('mcp_qualified_name', mcp_qualified_name).execute()
|
||||
|
||||
result = await client.table('user_mcp_credential_profiles').insert({
|
||||
'profile_id': profile_id,
|
||||
'account_id': account_id,
|
||||
'mcp_qualified_name': mcp_qualified_name,
|
||||
'profile_name': profile_name,
|
||||
'display_name': profile_name,
|
||||
'encrypted_config': encrypted_config,
|
||||
'config_hash': config_hash,
|
||||
'is_active': True,
|
||||
'is_default': is_default,
|
||||
'created_at': now.isoformat(),
|
||||
'updated_at': now.isoformat()
|
||||
}).execute()
|
||||
|
||||
if not result.data:
|
||||
raise ProfileServiceError("Failed to create profile")
|
||||
|
||||
logger.info(f"Created profile {profile_id} for app {app_slug}")
|
||||
|
||||
return Profile(
|
||||
profile_id=profile_id,
|
||||
account_id=account_id,
|
||||
mcp_qualified_name=mcp_qualified_name,
|
||||
profile_name=profile_name,
|
||||
display_name=profile_name,
|
||||
encrypted_config=encrypted_config,
|
||||
config_hash=config_hash,
|
||||
app_slug=app_slug,
|
||||
app_name=app_name,
|
||||
external_user_id=external_user_id,
|
||||
enabled_tools=enabled_tools,
|
||||
is_active=True,
|
||||
is_default=is_default,
|
||||
is_connected=False,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
description=description,
|
||||
oauth_app_id=oauth_app_id
|
||||
)
|
||||
|
||||
mcp_qualified_name = self._mcp_qualified_name_service.generate(app_slug_vo)
|
||||
|
||||
config = {
|
||||
"app_slug": app_slug,
|
||||
"app_name": app_name,
|
||||
"external_user_id": external_user_id_vo.value,
|
||||
"oauth_app_id": oauth_app_id,
|
||||
"enabled_tools": enabled_tools or [],
|
||||
"description": description
|
||||
}
|
||||
|
||||
if not self._profile_config_service.validate_config(config):
|
||||
raise ValueError("Invalid profile configuration")
|
||||
|
||||
config_json = json.dumps(config)
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, (ProfileAlreadyExistsError, ProfileServiceError)):
|
||||
raise
|
||||
logger.error(f"Error creating profile: {str(e)}")
|
||||
raise ProfileServiceError(f"Failed to create profile: {str(e)}")
|
||||
|
||||
async def get_profile(self, account_id: str, profile_id: str) -> Optional[Profile]:
|
||||
client = await self._get_client()
|
||||
|
||||
profile = Profile(
|
||||
profile_id=UUID(int=0),
|
||||
account_id=account_id,
|
||||
mcp_qualified_name=mcp_qualified_name,
|
||||
profile_name=profile_name_vo,
|
||||
display_name=profile_name,
|
||||
encrypted_config=EncryptedConfig("placeholder"),
|
||||
config_hash=ConfigHash.from_config(config_json),
|
||||
app_slug=app_slug_vo,
|
||||
app_name=app_name,
|
||||
external_user_id=external_user_id_vo,
|
||||
enabled_tools=enabled_tools or [],
|
||||
is_default=is_default
|
||||
)
|
||||
|
||||
if is_default:
|
||||
await self._profile_repo.set_default(account_id, profile.profile_id, mcp_qualified_name)
|
||||
|
||||
created_profile = await self._profile_repo.create(profile)
|
||||
self._logger.info(f"Created profile {created_profile.profile_id} for app {app_slug}")
|
||||
|
||||
return created_profile
|
||||
|
||||
async def get_profile(self, account_id: UUID, profile_id: UUID) -> Optional[Profile]:
|
||||
profile = await self._profile_repo.get_by_id(account_id, profile_id)
|
||||
if profile:
|
||||
return await self._connection_status_service.update_connection_status(profile)
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await client.table('user_mcp_credential_profiles').select('*').eq(
|
||||
'account_id', account_id
|
||||
).eq('profile_id', profile_id).execute()
|
||||
|
||||
if not result.data:
|
||||
return None
|
||||
|
||||
return await self._map_row_to_profile(result.data[0])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting profile {profile_id}: {str(e)}")
|
||||
raise ProfileServiceError(f"Failed to get profile: {str(e)}")
|
||||
|
||||
async def get_profiles(
|
||||
self,
|
||||
account_id: UUID,
|
||||
account_id: str,
|
||||
app_slug: Optional[str] = None,
|
||||
is_active: Optional[bool] = None
|
||||
) -> List[Profile]:
|
||||
app_slug_vo = AppSlug(app_slug) if app_slug else None
|
||||
profiles = await self._profile_repo.find_by_account(account_id, app_slug_vo, is_active)
|
||||
client = await self._get_client()
|
||||
|
||||
updated_profiles = []
|
||||
for profile in profiles:
|
||||
updated_profile = await self._connection_status_service.update_connection_status(profile)
|
||||
updated_profiles.append(updated_profile)
|
||||
|
||||
return updated_profiles
|
||||
|
||||
try:
|
||||
query = client.table('user_mcp_credential_profiles').select('*').eq('account_id', account_id)
|
||||
|
||||
if app_slug:
|
||||
mcp_qualified_name = f"pipedream:{app_slug}"
|
||||
query = query.eq('mcp_qualified_name', mcp_qualified_name)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.eq('is_active', is_active)
|
||||
|
||||
result = await query.execute()
|
||||
|
||||
import asyncio
|
||||
profiles = await asyncio.gather(*[self._map_row_to_profile(row) for row in result.data])
|
||||
return profiles
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting profiles: {str(e)}")
|
||||
raise ProfileServiceError(f"Failed to get profiles: {str(e)}")
|
||||
|
||||
async def update_profile(
|
||||
self,
|
||||
account_id: UUID,
|
||||
profile_id: UUID,
|
||||
account_id: str,
|
||||
profile_id: str,
|
||||
profile_name: Optional[str] = None,
|
||||
display_name: Optional[str] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
is_default: Optional[bool] = None,
|
||||
enabled_tools: Optional[List[str]] = None
|
||||
) -> Profile:
|
||||
profile = await self._profile_repo.get_by_id(account_id, profile_id)
|
||||
if not profile:
|
||||
raise ProfileNotFoundError(str(profile_id))
|
||||
|
||||
if profile_name:
|
||||
profile.profile_name = ProfileName(profile_name)
|
||||
if display_name:
|
||||
profile.display_name = display_name
|
||||
if is_active is not None:
|
||||
if is_active:
|
||||
profile.activate()
|
||||
else:
|
||||
profile.deactivate()
|
||||
if is_default is not None:
|
||||
if is_default:
|
||||
await self._profile_repo.set_default(account_id, profile_id, profile.mcp_qualified_name)
|
||||
profile.set_as_default()
|
||||
else:
|
||||
profile.unset_as_default()
|
||||
if enabled_tools is not None:
|
||||
profile.enabled_tools = enabled_tools
|
||||
|
||||
updated_profile = await self._profile_repo.update(profile)
|
||||
self._logger.info(f"Updated profile {profile_id}")
|
||||
existing_profile = await self.get_profile(account_id, profile_id)
|
||||
if not existing_profile:
|
||||
raise ProfileNotFoundError(f"Profile {profile_id} not found")
|
||||
|
||||
return updated_profile
|
||||
|
||||
async def delete_profile(self, account_id: UUID, profile_id: UUID) -> bool:
|
||||
success = await self._profile_repo.delete(account_id, profile_id)
|
||||
if success:
|
||||
self._logger.info(f"Deleted profile {profile_id}")
|
||||
return success
|
||||
|
||||
client = await self._get_client()
|
||||
updates = {'updated_at': datetime.now(timezone.utc).isoformat()}
|
||||
|
||||
try:
|
||||
if enabled_tools is not None:
|
||||
config = self._decrypt_config(existing_profile.encrypted_config)
|
||||
config['enabled_tools'] = enabled_tools
|
||||
config_json = json.dumps(config, sort_keys=True)
|
||||
updates['encrypted_config'] = self._encrypt_config(config_json)
|
||||
updates['config_hash'] = self._generate_config_hash(config_json)
|
||||
|
||||
if profile_name is not None:
|
||||
self._validate_profile_name(profile_name)
|
||||
updates['profile_name'] = profile_name
|
||||
|
||||
if display_name is not None:
|
||||
updates['display_name'] = display_name
|
||||
|
||||
if is_active is not None:
|
||||
updates['is_active'] = is_active
|
||||
|
||||
if is_default is not None:
|
||||
if is_default:
|
||||
await client.table('user_mcp_credential_profiles').update({
|
||||
'is_default': False
|
||||
}).eq('account_id', account_id).eq('mcp_qualified_name', existing_profile.mcp_qualified_name).execute()
|
||||
|
||||
updates['is_default'] = is_default
|
||||
|
||||
result = await client.table('user_mcp_credential_profiles').update(updates).eq(
|
||||
'profile_id', profile_id
|
||||
).eq('account_id', account_id).execute()
|
||||
|
||||
if not result.data:
|
||||
raise ProfileServiceError("Failed to update profile")
|
||||
|
||||
logger.info(f"Updated profile {profile_id}")
|
||||
|
||||
return await self.get_profile(account_id, profile_id)
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, (ProfileNotFoundError, ProfileServiceError, InvalidConfigError)):
|
||||
raise
|
||||
logger.error(f"Error updating profile {profile_id}: {str(e)}")
|
||||
raise ProfileServiceError(f"Failed to update profile: {str(e)}")
|
||||
|
||||
async def delete_profile(self, account_id: str, profile_id: str) -> bool:
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
result = await client.table('user_mcp_credential_profiles').delete().eq(
|
||||
'profile_id', profile_id
|
||||
).eq('account_id', account_id).execute()
|
||||
|
||||
success = bool(result.data)
|
||||
if success:
|
||||
logger.info(f"Deleted profile {profile_id}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting profile {profile_id}: {str(e)}")
|
||||
raise ProfileServiceError(f"Failed to delete profile: {str(e)}")
|
||||
|
||||
async def get_profile_by_app(
|
||||
self,
|
||||
account_id: UUID,
|
||||
account_id: str,
|
||||
app_slug: str,
|
||||
profile_name: Optional[str] = None
|
||||
) -> Optional[Profile]:
|
||||
app_slug_vo = AppSlug(app_slug)
|
||||
profile_name_vo = ProfileName(profile_name) if profile_name else None
|
||||
self._validate_app_slug(app_slug)
|
||||
|
||||
profile = await self._profile_repo.get_by_app_slug(account_id, app_slug_vo, profile_name_vo)
|
||||
if profile:
|
||||
return await self._connection_status_service.update_connection_status(profile)
|
||||
return None
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
mcp_qualified_name = f"pipedream:{app_slug}"
|
||||
query = client.table('user_mcp_credential_profiles').select('*').eq(
|
||||
'account_id', account_id
|
||||
).eq('mcp_qualified_name', mcp_qualified_name)
|
||||
|
||||
if profile_name:
|
||||
self._validate_profile_name(profile_name)
|
||||
query = query.eq('profile_name', profile_name)
|
||||
else:
|
||||
query = query.eq('is_default', True)
|
||||
|
||||
result = await query.execute()
|
||||
|
||||
if not result.data:
|
||||
return None
|
||||
|
||||
return await self._map_row_to_profile(result.data[0])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting profile by app {app_slug}: {str(e)}")
|
||||
raise ProfileServiceError(f"Failed to get profile by app: {str(e)}")
|
||||
|
||||
|
||||
_profile_service = None
|
||||
|
||||
def get_profile_service() -> ProfileService:
|
||||
global _profile_service
|
||||
if _profile_service is None:
|
||||
_profile_service = ProfileService()
|
||||
return _profile_service
|
Loading…
Reference in New Issue