suna/backend/pipedream/connection_service.py

261 lines
8.9 KiB
Python
Raw Normal View History

2025-07-30 20:27:26 +08:00
import os
import re
2025-07-30 22:03:43 +08:00
from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from typing import List, Optional, Dict, Any
from enum import Enum
2025-07-30 20:27:26 +08:00
2025-07-30 22:03:43 +08:00
import httpx
from utils.logger import logger
2025-07-30 20:27:26 +08:00
class AuthType(Enum):
    """Authentication scheme reported for an app.

    Looking up a string that is not one of the declared values yields
    ``CUSTOM`` instead of raising, so unfamiliar scheme names can still be
    represented.
    """

    OAUTH = "oauth"
    API_KEY = "api_key"
    BASIC = "basic"
    NONE = "none"
    KEYS = "keys"
    CUSTOM = "custom"

    @classmethod
    def _missing_(cls, value):
        """Resolve a value that matches no member.

        Strings fall back to ``CUSTOM``; anything else is delegated to the
        default ``Enum`` handling, which results in ``ValueError``.
        """
        if not isinstance(value, str):
            return super()._missing_(value)
        return cls.CUSTOM
2025-07-30 22:03:43 +08:00
2025-07-30 20:27:26 +08:00
@dataclass
class App:
    """Descriptive metadata for an app that a user can connect to."""

    # Required fields.
    name: str
    slug: str
    description: str
    category: str

    # Optional fields with defaults.
    logo_url: Optional[str] = None
    auth_type: AuthType = AuthType.OAUTH
    is_verified: bool = False
    url: Optional[str] = None
    tags: List[str] = field(default_factory=list)
    featured_weight: int = 0

    def is_featured(self) -> bool:
        """Return ``True`` when ``featured_weight`` is greater than zero."""
        weight = self.featured_weight
        return weight > 0
2025-07-30 22:03:43 +08:00
2025-07-30 20:27:26 +08:00
@dataclass
class Connection:
    """Association between an external user and one connected ``App``."""

    # Identifier of the user the connection belongs to (plain string form).
    external_user_id: str
    # The app this connection refers to.
    app: App
    created_at: datetime
    updated_at: datetime
    # Connections are considered active unless stated otherwise.
    is_active: bool = True
2025-07-30 22:03:43 +08:00
class ConnectionServiceError(Exception):
    """Base class for every error raised by the connection service."""


class AuthenticationError(ConnectionServiceError):
    """Credentials are missing or an access token could not be obtained."""


class RateLimitError(ConnectionServiceError):
    """The remote API answered with HTTP 429 (rate limit exceeded)."""
class ExternalUserId:
    """Validated wrapper around an external user identifier.

    The identifier must be a non-empty ``str`` of at most ``MAX_LENGTH``
    characters; the validated string is exposed as ``.value``.

    Raises:
        ValueError: if ``value`` is not a string, is empty, or is longer
            than ``MAX_LENGTH`` characters.
    """

    # Longest identifier accepted (inclusive).
    MAX_LENGTH = 255

    def __init__(self, value: str):
        # Check the type before truthiness so objects with unusual
        # ``__bool__`` behaviour are rejected cleanly.
        if not isinstance(value, str) or not value:
            raise ValueError("ExternalUserId must be a non-empty string")
        if len(value) > self.MAX_LENGTH:
            # The limit is inclusive: exactly MAX_LENGTH characters is valid,
            # so the message says "at most" rather than "less than".
            raise ValueError(
                f"ExternalUserId must be at most {self.MAX_LENGTH} characters"
            )
        self.value = value

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.value!r})"
class AppSlug:
    """Validated wrapper around an app slug.

    A slug is a non-empty string made up only of lowercase ASCII letters,
    digits, hyphens and underscores; the validated string is exposed as
    ``.value``.

    Raises:
        ValueError: if ``value`` is not a string, is empty, or contains any
            other character (including a trailing newline).
    """

    # Compiled once; used with ``fullmatch`` so the whole string must match.
    _PATTERN = re.compile(r'[a-z0-9_-]+')

    def __init__(self, value: str):
        if not isinstance(value, str) or not value:
            raise ValueError("AppSlug must be a non-empty string")
        # ``re.match`` with a ``$`` anchor also accepts a trailing "\n"
        # (e.g. "slack\n"); ``fullmatch`` rejects it.
        if self._PATTERN.fullmatch(value) is None:
            raise ValueError("AppSlug must contain only lowercase letters, numbers, hyphens, and underscores")
        self.value = value

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.value!r})"
class ConnectionService:
    """Look up the apps a user has connected through the Pipedream API.

    The service lazily creates one ``httpx.AsyncClient`` and caches an OAuth
    client-credentials access token, refreshing it shortly before it expires.
    Configuration is read from the environment:

    * ``PIPEDREAM_PROJECT_ID``, ``PIPEDREAM_CLIENT_ID``,
      ``PIPEDREAM_CLIENT_SECRET`` - required to obtain a token.
    * ``PIPEDREAM_X_PD_ENVIRONMENT`` - sent as ``X-PD-Environment``
      (defaults to ``"development"``).
    """

    # A cached token is considered stale this long before its actual expiry.
    _TOKEN_REFRESH_MARGIN = timedelta(minutes=5)

    def __init__(self, logger=None):
        """Create the service.

        Args:
            logger: Logger to use; when omitted the module-level ``logger``
                is used.
        """
        # The parameter shadows the module-level ``logger``, so the former
        # ``logger or logger`` stored ``None`` whenever no logger was passed.
        self._logger = logger if logger is not None else self._default_logger()
        self.base_url = "https://api.pipedream.com/v1"
        self.session = None            # httpx.AsyncClient, created on demand
        self.access_token = None       # cached bearer token
        self.token_expires_at = None   # timezone-aware UTC expiry of the token

    @staticmethod
    def _default_logger():
        """Return the module-level ``logger`` (not shadowed in this scope)."""
        return logger

    async def _get_session(self) -> "httpx.AsyncClient":
        """Return the shared HTTP client, (re)creating it if absent or closed."""
        if self.session is None or self.session.is_closed:
            self.session = httpx.AsyncClient(
                timeout=httpx.Timeout(30.0),
                headers={"User-Agent": "Suna-Pipedream-Client/1.0"}
            )
        return self.session

    async def _ensure_access_token(self) -> str:
        """Return a usable access token, fetching a new one when needed.

        The cached token is reused only while the current time is more than
        ``_TOKEN_REFRESH_MARGIN`` before its expiry.
        """
        if self.access_token and self.token_expires_at:
            refresh_at = self.token_expires_at - self._TOKEN_REFRESH_MARGIN
            if datetime.now(timezone.utc) < refresh_at:
                return self.access_token

        # Missing or stale: drop the cached values and request a fresh token.
        self.access_token = None
        self.token_expires_at = None
        return await self._fetch_fresh_token()

    async def _fetch_fresh_token(self) -> str:
        """Request a new token via the client-credentials grant and cache it.

        Returns:
            The new access token.

        Raises:
            AuthenticationError: required environment variables are missing,
                the token endpoint returned an error status, or its response
                could not be interpreted.
            RateLimitError: the token endpoint answered with HTTP 429.
            ConnectionServiceError: the request could not be sent or
                completed (transport-level failure).
        """
        project_id = os.getenv("PIPEDREAM_PROJECT_ID")
        client_id = os.getenv("PIPEDREAM_CLIENT_ID")
        client_secret = os.getenv("PIPEDREAM_CLIENT_SECRET")

        if not all([project_id, client_id, client_secret]):
            raise AuthenticationError("Missing required environment variables")

        session = await self._get_session()

        try:
            response = await session.post(
                f"{self.base_url}/oauth/token",
                data={
                    "grant_type": "client_credentials",
                    "client_id": client_id,
                    "client_secret": client_secret
                }
            )
            response.raise_for_status()

            data = response.json()
            access_token = data["access_token"]
            expires_in = data.get("expires_in", 3600)
            expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 429:
                raise RateLimitError("Rate limit exceeded") from e
            raise AuthenticationError(f"Failed to obtain access token: {e}") from e
        except httpx.RequestError as e:
            raise ConnectionServiceError(f"HTTP request failed: {e}") from e
        except (KeyError, TypeError, ValueError) as e:
            # Body was not JSON, lacked "access_token", or had a bad "expires_in".
            raise AuthenticationError(f"Failed to obtain access token: malformed response ({e!r})") from e

        # Update the cache only after the whole response has been validated.
        self.access_token = access_token
        self.token_expires_at = expires_at
        return access_token

    async def _make_request(self, method: str, url: str, headers: Optional[Dict[str, str]] = None,
                            params: Optional[Dict[str, Any]] = None, json: Optional[Dict[str, Any]] = None,
                            retry_count: int = 0) -> Dict[str, Any]:
        """Send an authenticated request and return the decoded JSON body.

        Args:
            method: ``"GET"`` or ``"POST"``.
            url: Absolute URL to call.
            headers: Extra headers; they override the default
                ``Authorization`` / ``Content-Type`` headers on conflict.
            params: Query parameters (used for ``GET`` only).
            json: JSON payload (used for ``POST`` only).
            retry_count: Internal counter; a 401 is retried once with a
                freshly fetched token.

        Raises:
            ValueError: ``method`` is not supported (checked before any I/O).
            RateLimitError: the API answered with HTTP 429.
            ConnectionServiceError: any other HTTP error status, or a
                transport-level failure.
        """
        if method not in ("GET", "POST"):
            raise ValueError(f"Unsupported HTTP method: {method}")

        session = await self._get_session()
        access_token = await self._ensure_access_token()

        request_headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json"
        }
        if headers:
            request_headers.update(headers)

        try:
            if method == "GET":
                response = await session.get(url, headers=request_headers, params=params)
            else:
                response = await session.post(url, headers=request_headers, json=json)

            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            status = e.response.status_code
            if status == 429:
                raise RateLimitError("Rate limit exceeded") from e
            if status == 401 and retry_count < 1:
                # The token may have been revoked early: discard it and retry once.
                self.access_token = None
                self.token_expires_at = None
                return await self._make_request(method, url, headers=headers, params=params,
                                                json=json, retry_count=retry_count + 1)
            raise ConnectionServiceError(f"HTTP request failed: {e}") from e
        except httpx.RequestError as e:
            raise ConnectionServiceError(f"HTTP request failed: {e}") from e

    def _build_app(self, app_data: Dict[str, Any]) -> "App":
        """Convert the ``app`` object of one account entry into an ``App``."""
        auth_type_str = app_data.get("auth_type", "oauth")
        try:
            auth_type = AuthType(auth_type_str)
        except ValueError:
            # Unknown strings already resolve to CUSTOM inside AuthType; this
            # branch covers values the enum rejects outright (non-strings).
            self._logger.warning(f"Unknown auth type '{auth_type_str}', using CUSTOM")
            auth_type = AuthType.CUSTOM

        return App(
            name=app_data.get("name", "Unknown"),
            slug=app_data.get("name_slug", ""),
            description=app_data.get("description", ""),
            category=app_data.get("category", "Other"),
            logo_url=app_data.get("img_src"),
            auth_type=auth_type,
            is_verified=app_data.get("verified", False),
            url=app_data.get("url"),
            # Tolerate explicit nulls for these two fields.
            tags=app_data.get("tags") or [],
            featured_weight=app_data.get("featured_weight") or 0
        )

    async def get_connections_for_user(self, external_user_id: "ExternalUserId") -> "List[Connection]":
        """Return the connections (connected apps) of ``external_user_id``.

        This is a best-effort call: if ``PIPEDREAM_PROJECT_ID`` is unset or
        any error occurs while talking to the API, the problem is logged and
        an empty list is returned instead of raising.

        The response is not consulted for timestamps, so ``created_at`` and
        ``updated_at`` of every returned connection are set to the (naive
        UTC) time of retrieval.
        """
        self._logger.debug(f"Getting connections for user: {external_user_id.value}")

        project_id = os.getenv("PIPEDREAM_PROJECT_ID")
        environment = os.getenv("PIPEDREAM_X_PD_ENVIRONMENT", "development")

        if not project_id:
            self._logger.error("Missing PIPEDREAM_PROJECT_ID environment variable")
            return []

        url = f"{self.base_url}/connect/{project_id}/accounts"
        params = {"external_id": external_user_id.value}
        headers = {"X-PD-Environment": environment}

        try:
            data = await self._make_request("GET", url, headers=headers, params=params)

            # Naive UTC, matching what ``datetime.utcnow()`` used to produce.
            retrieved_at = datetime.now(timezone.utc).replace(tzinfo=None)

            connections = []
            for account in data.get("data") or []:
                app_data = account.get("app", {})
                if not app_data:
                    continue  # entries without app details are skipped
                connections.append(
                    Connection(
                        external_user_id=external_user_id.value,
                        app=self._build_app(app_data),
                        created_at=retrieved_at,
                        updated_at=retrieved_at,
                        is_active=True
                    )
                )

            self._logger.debug(f"Retrieved {len(connections)} connections for user: {external_user_id.value}")
            return connections

        except Exception as e:
            # Deliberate catch-all: callers get "no connections" on failure.
            self._logger.error(f"Error getting connections: {str(e)}")
            return []

    async def has_connection(self, external_user_id: "ExternalUserId", app_slug: "AppSlug") -> bool:
        """Return ``True`` if the user has an active connection to ``app_slug``.

        Because ``get_connections_for_user`` swallows errors, a failed lookup
        is indistinguishable from "not connected" and yields ``False``.
        """
        connections = await self.get_connections_for_user(external_user_id)
        return any(
            connection.app.slug == app_slug.value and connection.is_active
            for connection in connections
        )

    async def close(self):
        """Close the underlying HTTP client if one is open."""
        if self.session and not self.session.is_closed:
            await self.session.aclose()
# Module-wide ConnectionService instance, created on first request.
_connection_service = None


def get_connection_service() -> ConnectionService:
    """Return the shared ``ConnectionService``, creating it on first use."""
    global _connection_service
    if _connection_service is not None:
        return _connection_service
    _connection_service = ConnectionService()
    return _connection_service
# Alternative "...Exception" names bound to the same exception classes.
PipedreamException = HttpClientException = ConnectionServiceError
AuthenticationException = AuthenticationError
RateLimitException = RateLimitError