suna/backend/pipedream/profile_service.py

440 lines
17 KiB
Python
Raw Normal View History

2025-07-30 20:27:26 +08:00
import json
import hashlib
import re
2025-07-30 22:03:43 +08:00
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any
from uuid import uuid4, UUID
2025-07-30 20:27:26 +08:00
2025-07-30 22:03:43 +08:00
from services.supabase import DBConnection
from utils.logger import logger
2025-07-30 20:27:26 +08:00
@dataclass
class Profile:
    """A Pipedream credential profile as seen by callers of ProfileService.

    Identity/bookkeeping fields mirror columns of the profiles table; the
    app-related fields (``app_slug`` .. ``enabled_tools``, ``description``,
    ``oauth_app_id``) are recovered from the decrypted ``encrypted_config``.
    """

    # Row identity and ownership.
    profile_id: str
    account_id: str
    # Namespaced identifier of the MCP server, e.g. "pipedream:<app_slug>".
    mcp_qualified_name: str
    profile_name: str
    display_name: str
    # Encrypted JSON config and the SHA-256 hex digest of its plaintext.
    encrypted_config: str
    config_hash: str
    app_slug: str
    app_name: str
    external_user_id: str
    enabled_tools: List[str] = field(default_factory=list)
    is_active: bool = True
    is_default: bool = False
    is_connected: bool = False
    # Timezone-aware UTC timestamps by default.
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_used_at: Optional[datetime] = None
    description: Optional[str] = None
    oauth_app_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the profile into a JSON-friendly dict.

        Datetimes become ISO-8601 strings (``last_used_at`` becomes ``None``
        when unset); every other value is returned as-is, so ``enabled_tools``
        is the profile's own list, not a copy. Key order follows field order.
        """
        passthrough_keys = (
            'profile_id', 'account_id', 'mcp_qualified_name', 'profile_name',
            'display_name', 'encrypted_config', 'config_hash', 'app_slug',
            'app_name', 'external_user_id', 'enabled_tools', 'is_active',
            'is_default', 'is_connected',
        )
        payload: Dict[str, Any] = {key: getattr(self, key) for key in passthrough_keys}
        payload['created_at'] = self.created_at.isoformat()
        payload['updated_at'] = self.updated_at.isoformat()
        last_used = self.last_used_at
        payload['last_used_at'] = last_used.isoformat() if last_used else None
        payload['description'] = self.description
        payload['oauth_app_id'] = self.oauth_app_id
        return payload
2025-07-30 20:27:26 +08:00
2025-07-30 22:03:43 +08:00
class ProfileServiceError(Exception):
    """Base class for every error raised by ProfileService."""


class ProfileNotFoundError(ProfileServiceError):
    """Raised when the requested profile does not exist for the account."""


class ProfileAlreadyExistsError(ProfileServiceError):
    """Raised when a profile name is already taken for the same account and app."""


class InvalidConfigError(ProfileServiceError):
    """Raised when a profile name or app slug fails validation."""


class EncryptionError(ProfileServiceError):
    """Raised when a profile config cannot be encrypted or decrypted."""
2025-07-30 20:27:26 +08:00
2025-07-30 22:03:43 +08:00
class ProfileService:
    """Create, read, update and delete Pipedream credential profiles.

    Rows live in the ``user_mcp_credential_profiles`` table. The per-app
    settings (app slug/name, external user id, enabled tools, OAuth app id,
    description) are serialised to JSON with sorted keys, encrypted into
    ``encrypted_config``, and fingerprinted by a SHA-256 ``config_hash`` of
    that plaintext JSON.
    """

    _TABLE = 'user_mcp_credential_profiles'
    _MAX_PROFILE_NAME_LENGTH = 100

    def __init__(self):
        self.db = DBConnection()
        # Built lazily by _get_connection_service().
        self._connection_service = None

    async def _get_client(self):
        """Return the Supabase client from the shared DB connection."""
        return await self.db.client

    def _get_connection_service(self):
        """Return the ConnectionService, constructing it on first use.

        The import is deferred to call time; presumably this avoids an import
        cycle with ``connection_service`` (TODO confirm).
        """
        if self._connection_service is None:
            from .connection_service import ConnectionService
            self._connection_service = ConnectionService(logger=logger)
        return self._connection_service

    async def _check_connection_status(self, external_user_id: str, app_slug: str) -> bool:
        """Return whether *external_user_id* has a connection for *app_slug*.

        Best-effort: any error from the connection service is logged and
        reported as "not connected" instead of being propagated.
        """
        try:
            from .connection_service import ExternalUserId, AppSlug
            connection_service = self._get_connection_service()
            return await connection_service.has_connection(
                ExternalUserId(external_user_id),
                AppSlug(app_slug)
            )
        except Exception as e:
            logger.error(f"Error checking connection status: {str(e)}")
            return False

    @staticmethod
    def _qualified_name(app_slug: str) -> str:
        """Return the ``mcp_qualified_name`` stored for profiles of *app_slug*."""
        return f"pipedream:{app_slug}"

    @staticmethod
    def _parse_timestamp(value: Any) -> Optional[datetime]:
        """Convert a DB timestamp value into a datetime.

        Strings are parsed as ISO-8601; a ``Z`` suffix is rewritten to
        ``+00:00`` because ``datetime.fromisoformat`` only accepts ``Z`` from
        Python 3.11. Non-string values pass through unchanged and empty
        values become ``None``.
        """
        if not value:
            return None
        if isinstance(value, str):
            return datetime.fromisoformat(value.replace('Z', '+00:00'))
        return value

    def _validate_app_slug(self, app_slug: str) -> None:
        """Raise InvalidConfigError unless *app_slug* is a non-empty ``[a-z0-9_-]+`` string."""
        if not app_slug or not isinstance(app_slug, str):
            raise InvalidConfigError("App slug must be a non-empty string")
        # fullmatch, not match with ^...$: "$" also matches before a trailing
        # newline, which would let "slug\n" through into mcp_qualified_name.
        if not re.fullmatch(r'[a-z0-9_-]+', app_slug):
            raise InvalidConfigError("App slug must contain only lowercase letters, numbers, hyphens, and underscores")

    def _validate_profile_name(self, profile_name: str) -> None:
        """Raise InvalidConfigError unless *profile_name* is a non-empty string of at most 100 chars."""
        if not profile_name or not isinstance(profile_name, str):
            raise InvalidConfigError("Profile name must be a non-empty string")
        if len(profile_name) > self._MAX_PROFILE_NAME_LENGTH:
            raise InvalidConfigError(
                f"Profile name must be at most {self._MAX_PROFILE_NAME_LENGTH} characters"
            )

    def _generate_external_user_id(self, account_id: str, app_slug: str, profile_name: str) -> str:
        """Derive a deterministic external user id from account, app and profile name.

        The same triple always yields the same ``suna_<16 hex chars>`` id.
        """
        combined = f"{account_id}:{app_slug}:{profile_name}"
        hash_value = hashlib.sha256(combined.encode()).hexdigest()[:16]
        return f"suna_{hash_value}"

    def _generate_config_hash(self, config_json: str) -> str:
        """Return the SHA-256 hex digest of the plaintext config JSON."""
        return hashlib.sha256(config_json.encode()).hexdigest()

    def _encrypt_config(self, config_json: str) -> str:
        """Encrypt *config_json*, raising EncryptionError on any failure."""
        try:
            from utils.encryption import encrypt_data
            return encrypt_data(config_json)
        except Exception as e:
            raise EncryptionError(f"Failed to encrypt config: {str(e)}") from e

    def _decrypt_config(self, encrypted_config: str) -> Dict[str, Any]:
        """Decrypt and JSON-decode a stored config, raising EncryptionError on any failure."""
        try:
            from utils.encryption import decrypt_data
            decrypted_json = decrypt_data(encrypted_config)
            return json.loads(decrypted_json)
        except Exception as e:
            raise EncryptionError(f"Failed to decrypt config: {str(e)}") from e

    def _build_config(self, app_slug: str, app_name: str, external_user_id: str,
                      enabled_tools: List[str], oauth_app_id: Optional[str] = None,
                      description: Optional[str] = None) -> Dict[str, Any]:
        """Return the plaintext config dict that is encrypted into a profile row."""
        return {
            "app_slug": app_slug,
            "app_name": app_name,
            "external_user_id": external_user_id,
            "enabled_tools": enabled_tools,
            "oauth_app_id": oauth_app_id,
            "description": description
        }

    async def _map_row_to_profile(self, row: Dict[str, Any]) -> Profile:
        """Convert a table row into a Profile, decrypting its config.

        If the config cannot be decrypted (or does not decode to a dict), the
        failure is logged and placeholder "unknown" app values are used so one
        bad row does not break listing. No connection lookup is made for such
        rows because there is no real external user id to check.
        """
        try:
            config = self._decrypt_config(row['encrypted_config'])
        except Exception as e:
            logger.warning(f"Could not decrypt config for profile {row.get('profile_id')}: {str(e)}")
            config = None

        if isinstance(config, dict):
            is_connected = await self._check_connection_status(
                config.get('external_user_id', 'unknown'),
                config.get('app_slug', 'unknown'),
            )
        else:
            config = {
                "app_slug": "unknown",
                "app_name": "Unknown App",
                "external_user_id": "unknown",
                "enabled_tools": [],
                "oauth_app_id": None,
                "description": None
            }
            is_connected = False

        return Profile(
            profile_id=row['profile_id'],
            account_id=row['account_id'],
            mcp_qualified_name=row['mcp_qualified_name'],
            profile_name=row['profile_name'],
            display_name=row['display_name'],
            encrypted_config=row['encrypted_config'],
            config_hash=row['config_hash'],
            app_slug=config.get('app_slug', 'unknown'),
            app_name=config.get('app_name', 'Unknown App'),
            external_user_id=config.get('external_user_id', 'unknown'),
            # "or []" also normalises an explicitly stored null.
            enabled_tools=config.get('enabled_tools') or [],
            is_active=row['is_active'],
            is_default=row['is_default'],
            is_connected=is_connected,
            created_at=self._parse_timestamp(row['created_at']),
            updated_at=self._parse_timestamp(row['updated_at']),
            last_used_at=self._parse_timestamp(row.get('last_used_at')),
            description=config.get('description'),
            oauth_app_id=config.get('oauth_app_id')
        )

    async def create_profile(
        self,
        account_id: str,
        profile_name: str,
        app_slug: str,
        app_name: str,
        description: Optional[str] = None,
        is_default: bool = False,
        oauth_app_id: Optional[str] = None,
        enabled_tools: Optional[List[str]] = None,
        external_user_id: Optional[str] = None
    ) -> Profile:
        """Insert a new profile for *app_slug* and return it.

        Args:
            account_id: Owner account.
            profile_name: Name unique per (account, app); also used as display name.
            app_slug: Pipedream app slug (``[a-z0-9_-]+``).
            app_name: Human-readable app name stored in the config.
            description: Optional free-text description stored in the config.
            is_default: When True, every other profile of this app for the
                account is un-defaulted first.
            oauth_app_id: Optional OAuth app id stored in the config.
            enabled_tools: Tool names to enable; defaults to none.
            external_user_id: Pipedream external user id; derived
                deterministically from account/app/name when omitted.

        Raises:
            InvalidConfigError: *app_slug* or *profile_name* is invalid.
            EncryptionError: the config could not be encrypted.
            ProfileAlreadyExistsError: the name is taken for this account/app.
            ProfileServiceError: any other storage failure.
        """
        self._validate_app_slug(app_slug)
        self._validate_profile_name(profile_name)
        if enabled_tools is None:
            enabled_tools = []
        if not external_user_id:
            external_user_id = self._generate_external_user_id(account_id, app_slug, profile_name)

        config = self._build_config(app_slug, app_name, external_user_id, enabled_tools, oauth_app_id, description)
        # sort_keys makes the JSON (and therefore config_hash) canonical.
        config_json = json.dumps(config, sort_keys=True)
        encrypted_config = self._encrypt_config(config_json)
        config_hash = self._generate_config_hash(config_json)
        mcp_qualified_name = self._qualified_name(app_slug)
        profile_id = str(uuid4())
        now = datetime.now(timezone.utc)

        client = await self._get_client()
        try:
            existing = await client.table(self._TABLE).select('profile_id').eq(
                'account_id', account_id
            ).eq('mcp_qualified_name', mcp_qualified_name).eq('profile_name', profile_name).execute()
            if existing.data:
                raise ProfileAlreadyExistsError(f"Profile '{profile_name}' already exists for app '{app_slug}'")

            if is_default:
                # NOTE: not atomic with the insert below; if the insert fails,
                # the app is left with no default profile.
                await client.table(self._TABLE).update({
                    'is_default': False
                }).eq('account_id', account_id).eq('mcp_qualified_name', mcp_qualified_name).execute()

            result = await client.table(self._TABLE).insert({
                'profile_id': profile_id,
                'account_id': account_id,
                'mcp_qualified_name': mcp_qualified_name,
                'profile_name': profile_name,
                'display_name': profile_name,
                'encrypted_config': encrypted_config,
                'config_hash': config_hash,
                'is_active': True,
                'is_default': is_default,
                'created_at': now.isoformat(),
                'updated_at': now.isoformat()
            }).execute()
            if not result.data:
                raise ProfileServiceError("Failed to create profile")
            logger.info(f"Created profile {profile_id} for app {app_slug}")

            # The external user id may already be connected (it is caller
            # supplied or deterministic), so report the real status, matching
            # what get_profile would return for this row.
            is_connected = await self._check_connection_status(external_user_id, app_slug)

            return Profile(
                profile_id=profile_id,
                account_id=account_id,
                mcp_qualified_name=mcp_qualified_name,
                profile_name=profile_name,
                display_name=profile_name,
                encrypted_config=encrypted_config,
                config_hash=config_hash,
                app_slug=app_slug,
                app_name=app_name,
                external_user_id=external_user_id,
                enabled_tools=enabled_tools,
                is_active=True,
                is_default=is_default,
                is_connected=is_connected,
                created_at=now,
                updated_at=now,
                description=description,
                oauth_app_id=oauth_app_id
            )
        except ProfileServiceError:
            raise
        except Exception as e:
            logger.error(f"Error creating profile: {str(e)}")
            raise ProfileServiceError(f"Failed to create profile: {str(e)}") from e

    async def get_profile(self, account_id: str, profile_id: str) -> Optional[Profile]:
        """Return the profile *profile_id* owned by *account_id*, or None if absent.

        Raises:
            ProfileServiceError: the lookup failed.
        """
        client = await self._get_client()
        try:
            result = await client.table(self._TABLE).select('*').eq(
                'account_id', account_id
            ).eq('profile_id', profile_id).execute()
            if not result.data:
                return None
            return await self._map_row_to_profile(result.data[0])
        except Exception as e:
            logger.error(f"Error getting profile {profile_id}: {str(e)}")
            raise ProfileServiceError(f"Failed to get profile: {str(e)}") from e

    async def get_profiles(
        self,
        account_id: str,
        app_slug: Optional[str] = None,
        is_active: Optional[bool] = None
    ) -> List[Profile]:
        """Return the account's profiles, optionally filtered by app and active flag.

        Rows are mapped concurrently (each mapping may query the connection
        service).

        Raises:
            ProfileServiceError: the lookup failed.
        """
        import asyncio

        client = await self._get_client()
        try:
            query = client.table(self._TABLE).select('*').eq('account_id', account_id)
            if app_slug:
                query = query.eq('mcp_qualified_name', self._qualified_name(app_slug))
            if is_active is not None:
                query = query.eq('is_active', is_active)
            result = await query.execute()
            rows = result.data or []
            return list(await asyncio.gather(*(self._map_row_to_profile(row) for row in rows)))
        except Exception as e:
            logger.error(f"Error getting profiles: {str(e)}")
            raise ProfileServiceError(f"Failed to get profiles: {str(e)}") from e

    async def update_profile(
        self,
        account_id: str,
        profile_id: str,
        profile_name: Optional[str] = None,
        display_name: Optional[str] = None,
        is_active: Optional[bool] = None,
        is_default: Optional[bool] = None,
        enabled_tools: Optional[List[str]] = None
    ) -> Profile:
        """Apply the non-None arguments to profile *profile_id* and return the fresh profile.

        Changing ``enabled_tools`` re-encrypts the config and refreshes
        ``config_hash``. Setting ``is_default=True`` first un-defaults the
        account's other profiles for the same app.

        Raises:
            ProfileNotFoundError: the profile does not exist for the account.
            InvalidConfigError: *profile_name* is invalid.
            ProfileAlreadyExistsError: *profile_name* is taken by another
                profile of the same account and app.
            EncryptionError: the config could not be re-encrypted.
            ProfileServiceError: any other storage failure.
        """
        existing_profile = await self.get_profile(account_id, profile_id)
        if not existing_profile:
            raise ProfileNotFoundError(f"Profile {profile_id} not found")

        client = await self._get_client()
        updates = {'updated_at': datetime.now(timezone.utc).isoformat()}
        try:
            if enabled_tools is not None:
                config = self._decrypt_config(existing_profile.encrypted_config)
                config['enabled_tools'] = enabled_tools
                config_json = json.dumps(config, sort_keys=True)
                updates['encrypted_config'] = self._encrypt_config(config_json)
                updates['config_hash'] = self._generate_config_hash(config_json)

            if profile_name is not None:
                self._validate_profile_name(profile_name)
                if profile_name != existing_profile.profile_name:
                    # Enforce the same uniqueness rule as create_profile.
                    clash = await client.table(self._TABLE).select('profile_id').eq(
                        'account_id', account_id
                    ).eq('mcp_qualified_name', existing_profile.mcp_qualified_name).eq(
                        'profile_name', profile_name
                    ).execute()
                    if clash.data:
                        raise ProfileAlreadyExistsError(
                            f"Profile '{profile_name}' already exists for app '{existing_profile.app_slug}'"
                        )
                updates['profile_name'] = profile_name

            if display_name is not None:
                updates['display_name'] = display_name

            if is_active is not None:
                updates['is_active'] = is_active

            if is_default is not None:
                if is_default:
                    # NOTE: not atomic with the update below.
                    await client.table(self._TABLE).update({
                        'is_default': False
                    }).eq('account_id', account_id).eq('mcp_qualified_name', existing_profile.mcp_qualified_name).execute()
                updates['is_default'] = is_default

            result = await client.table(self._TABLE).update(updates).eq(
                'profile_id', profile_id
            ).eq('account_id', account_id).execute()
            if not result.data:
                raise ProfileServiceError("Failed to update profile")

            logger.info(f"Updated profile {profile_id}")

            updated_profile = await self.get_profile(account_id, profile_id)
            if updated_profile is None:
                # e.g. the row was deleted concurrently between update and re-read.
                raise ProfileNotFoundError(f"Profile {profile_id} not found")
            return updated_profile
        except ProfileServiceError:
            raise
        except Exception as e:
            logger.error(f"Error updating profile {profile_id}: {str(e)}")
            raise ProfileServiceError(f"Failed to update profile: {str(e)}") from e

    async def delete_profile(self, account_id: str, profile_id: str) -> bool:
        """Delete profile *profile_id* of *account_id*; return True if a row was deleted.

        Raises:
            ProfileServiceError: the delete failed.
        """
        client = await self._get_client()
        try:
            result = await client.table(self._TABLE).delete().eq(
                'profile_id', profile_id
            ).eq('account_id', account_id).execute()
            success = bool(result.data)
            if success:
                logger.info(f"Deleted profile {profile_id}")
            return success
        except Exception as e:
            logger.error(f"Error deleting profile {profile_id}: {str(e)}")
            raise ProfileServiceError(f"Failed to delete profile: {str(e)}") from e

    async def get_profile_by_app(
        self,
        account_id: str,
        app_slug: str,
        profile_name: Optional[str] = None
    ) -> Optional[Profile]:
        """Return the named profile for *app_slug*, or the default one when no name is given.

        Returns None when nothing matches.

        Raises:
            InvalidConfigError: *app_slug* or *profile_name* is invalid.
            ProfileServiceError: the lookup failed.
        """
        # Validate up front so bad input surfaces as InvalidConfigError rather
        # than being logged and rewrapped as a storage failure.
        self._validate_app_slug(app_slug)
        if profile_name:
            self._validate_profile_name(profile_name)

        client = await self._get_client()
        try:
            query = client.table(self._TABLE).select('*').eq(
                'account_id', account_id
            ).eq('mcp_qualified_name', self._qualified_name(app_slug))
            if profile_name:
                query = query.eq('profile_name', profile_name)
            else:
                query = query.eq('is_default', True)
            result = await query.execute()
            if not result.data:
                return None
            return await self._map_row_to_profile(result.data[0])
        except Exception as e:
            logger.error(f"Error getting profile by app {app_slug}: {str(e)}")
            raise ProfileServiceError(f"Failed to get profile by app: {str(e)}") from e
2025-07-30 20:27:26 +08:00
2025-07-30 22:03:43 +08:00
# Lazily created module-wide ProfileService instance; see get_profile_service().
_profile_service: Optional["ProfileService"] = None


def get_profile_service() -> ProfileService:
    """Return the shared ProfileService, constructing it on the first call.

    Later calls return the same instance.
    """
    global _profile_service
    if _profile_service is not None:
        return _profile_service
    _profile_service = ProfileService()
    return _profile_service