import os import re from dataclasses import dataclass, field from datetime import datetime, timezone, timedelta from typing import List, Optional, Dict, Any from enum import Enum import httpx from utils.logger import logger class AuthType(Enum): OAUTH = "oauth" API_KEY = "api_key" BASIC = "basic" NONE = "none" KEYS = "keys" CUSTOM = "custom" @classmethod def _missing_(cls, value): if isinstance(value, str): return cls.CUSTOM return super()._missing_(value) @dataclass class App: name: str slug: str description: str category: str logo_url: Optional[str] = None auth_type: AuthType = AuthType.OAUTH is_verified: bool = False url: Optional[str] = None tags: List[str] = field(default_factory=list) featured_weight: int = 0 def is_featured(self) -> bool: return self.featured_weight > 0 @dataclass class Connection: external_user_id: str app: App created_at: datetime updated_at: datetime is_active: bool = True class ConnectionServiceError(Exception): pass class AuthenticationError(ConnectionServiceError): pass class RateLimitError(ConnectionServiceError): pass class ExternalUserId: def __init__(self, value: str): if not value or not isinstance(value, str): raise ValueError("ExternalUserId must be a non-empty string") if len(value) > 255: raise ValueError("ExternalUserId must be less than 255 characters") self.value = value class AppSlug: def __init__(self, value: str): if not value or not isinstance(value, str): raise ValueError("AppSlug must be a non-empty string") if not re.match(r'^[a-z0-9_-]+$', value): raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores") self.value = value class ConnectionService: def __init__(self, logger=None): self._logger = logger or logger self.base_url = "https://api.pipedream.com/v1" self.session = None self.access_token = None self.token_expires_at = None async def _get_session(self) -> httpx.AsyncClient: if self.session is None or self.session.is_closed: self.session = httpx.AsyncClient( timeout=httpx.Timeout(30.0), headers={"User-Agent": "Suna-Pipedream-Client/1.0"} ) return self.session async def _ensure_access_token(self) -> str: if self.access_token and self.token_expires_at: if datetime.utcnow() < (self.token_expires_at - timedelta(minutes=5)): return self.access_token else: self.access_token = None self.token_expires_at = None return await self._fetch_fresh_token() async def _fetch_fresh_token(self) -> str: project_id = os.getenv("PIPEDREAM_PROJECT_ID") client_id = os.getenv("PIPEDREAM_CLIENT_ID") client_secret = os.getenv("PIPEDREAM_CLIENT_SECRET") if not all([project_id, client_id, client_secret]): raise AuthenticationError("Missing required environment variables") session = await self._get_session() try: response = await session.post( f"{self.base_url}/oauth/token", data={ "grant_type": "client_credentials", "client_id": client_id, "client_secret": client_secret } ) response.raise_for_status() data = response.json() self.access_token = data["access_token"] expires_in = data.get("expires_in", 3600) self.token_expires_at = datetime.utcnow() + timedelta(seconds=expires_in) return self.access_token except httpx.HTTPStatusError as e: if e.response.status_code == 429: raise RateLimitError("Rate limit exceeded") raise AuthenticationError(f"Failed to obtain access token: {e}") async def _make_request(self, method: str, url: str, headers: Dict[str, str] = None, params: Dict[str, Any] = None, json: Dict[str, Any] = None, retry_count: int = 0) -> Dict[str, Any]: session = await self._get_session() access_token = await self._ensure_access_token() request_headers = { "Authorization": f"Bearer {access_token}", "Content-Type": "application/json" } if headers: request_headers.update(headers) try: if method == "GET": response = await session.get(url, headers=request_headers, params=params) elif method == "POST": response = await session.post(url, headers=request_headers, json=json) else: raise ValueError(f"Unsupported HTTP method: {method}") response.raise_for_status() return response.json() except httpx.HTTPStatusError as e: if e.response.status_code == 429: raise RateLimitError("Rate limit exceeded") elif e.response.status_code == 401 and retry_count < 1: self.access_token = None self.token_expires_at = None return await self._make_request(method, url, headers=headers, params=params, json=json, retry_count=retry_count + 1) else: raise ConnectionServiceError(f"HTTP request failed: {e}") async def get_connections_for_user(self, external_user_id: ExternalUserId) -> List[Connection]: logger.debug(f"Getting connections for user: {external_user_id.value}") project_id = os.getenv("PIPEDREAM_PROJECT_ID") environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development") if not project_id: logger.error("Missing PIPEDREAM_PROJECT_ID environment variable") return [] url = f"{self.base_url}/connect/{project_id}/accounts" params = {"external_id": external_user_id.value} headers = {"X-PD-Environment": environment} try: data = await self._make_request("GET", url, headers=headers, params=params) connections = [] accounts = data.get("data", []) for account in accounts: app_data = account.get("app", {}) if app_data: try: auth_type_str = app_data.get("auth_type", "oauth") auth_type = AuthType(auth_type_str) except ValueError: logger.warning(f"Unknown auth type '{auth_type_str}', using CUSTOM") auth_type = AuthType.CUSTOM app = App( name=app_data.get("name", "Unknown"), slug=app_data.get("name_slug", ""), description=app_data.get("description", ""), category=app_data.get("category", "Other"), logo_url=app_data.get("img_src"), auth_type=auth_type, is_verified=app_data.get("verified", False), url=app_data.get("url"), tags=app_data.get("tags", []), featured_weight=app_data.get("featured_weight", 0) ) connection = Connection( external_user_id=external_user_id.value, app=app, created_at=datetime.utcnow(), updated_at=datetime.utcnow(), is_active=True ) connections.append(connection) logger.debug(f"Retrieved {len(connections)} connections for user: {external_user_id.value}") return connections except Exception as e: logger.error(f"Error getting connections: {str(e)}") return [] async def has_connection(self, external_user_id: ExternalUserId, app_slug: AppSlug) -> bool: connections = await self.get_connections_for_user(external_user_id) for connection in connections: if connection.app.slug == app_slug.value and connection.is_active: return True return False async def close(self): if self.session and not self.session.is_closed: await self.session.aclose() _connection_service = None def get_connection_service() -> ConnectionService: global _connection_service if _connection_service is None: _connection_service = ConnectionService() return _connection_service PipedreamException = ConnectionServiceError HttpClientException = ConnectionServiceError AuthenticationException = AuthenticationError RateLimitException = RateLimitError