# suna/backend/credentials/repositories/profile_repository.py
#
# 263 lines
# 10 KiB
# Python

from abc import abstractmethod
from typing import List, Optional, Dict, Any
import base64
from datetime import datetime, timezone
from .base import Repository
from ..domain.entities import MCPCredentialProfile, ProfileRequest
from ..protocols import DatabaseConnection, Logger, EncryptionService
class ProfileRepository(Repository[MCPCredentialProfile]):
    """Abstract persistence interface for MCP credential profiles.

    Extends the generic ``Repository`` with lookups keyed by account and MCP
    qualified name, plus default-profile management and deactivation.
    Concrete subclasses must implement every method below.
    """

    @abstractmethod
    async def find_by_account_and_qualified_name(
        self, account_id: str, mcp_qualified_name: str
    ) -> List[MCPCredentialProfile]:
        """Return the profiles of ``account_id`` for the MCP server ``mcp_qualified_name``."""

    @abstractmethod
    async def find_by_account(self, account_id: str) -> List[MCPCredentialProfile]:
        """Return every profile belonging to ``account_id``."""

    @abstractmethod
    async def find_default_profile(
        self, account_id: str, mcp_qualified_name: str
    ) -> Optional[MCPCredentialProfile]:
        """Return the default profile for this account/server, or None if there is none."""

    @abstractmethod
    async def store_profile(self, request: ProfileRequest) -> str:
        """Persist the profile described by ``request`` and return its id as a string."""

    @abstractmethod
    async def set_default(self, account_id: str, profile_id: str) -> bool:
        """Mark ``profile_id`` as a default profile of ``account_id``; report success."""

    @abstractmethod
    async def deactivate_profile(self, account_id: str, profile_id: str) -> bool:
        """Deactivate ``profile_id`` owned by ``account_id``; report success."""
class SupabaseProfileRepository(ProfileRepository):
    """Supabase-backed store for encrypted MCP credential profiles.

    Rows live in the ``user_mcp_credential_profiles`` table. The
    ``encrypted_config`` column holds the base64 text of the bytes produced by
    ``EncryptionService.encrypt_config``; ``_map_to_profile`` decodes and
    decrypts rows back into ``MCPCredentialProfile`` objects. Removal is done
    by setting ``is_active`` to False rather than deleting rows.

    Most methods log failures and return an empty/falsy value; only
    ``store_profile`` re-raises.
    """

    _TABLE = 'user_mcp_credential_profiles'

    def __init__(self, db: DatabaseConnection, encryption: EncryptionService, logger: Logger):
        self._db = db
        self._encryption = encryption
        self._logger = logger

    async def find_by_id(self, profile_id: str) -> Optional[MCPCredentialProfile]:
        """Return the active profile ``profile_id`` and record that it was used.

        On success the row's ``last_used_at`` is set to the current UTC time.
        Returns None when no active row matches or loading/decrypting fails.
        """
        profile = await self._find_active(profile_id)
        if profile is None:
            return None
        try:
            client = await self._db.client
            await (
                client.table(self._TABLE)
                .update({'last_used_at': datetime.now(timezone.utc).isoformat()})
                .eq('profile_id', profile_id)
                .execute()
            )
        except Exception as e:
            # Recording usage is best-effort: the profile itself loaded fine,
            # so a failed timestamp write must not turn the lookup into a miss.
            self._logger.error(f"Failed to update last_used_at for profile {profile_id}: {str(e)}")
        return profile

    async def _find_active(self, profile_id: str) -> Optional[MCPCredentialProfile]:
        """Load and decrypt active profile ``profile_id`` without touching ``last_used_at``.

        Returns None when no active row matches or on any error (which is logged).
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .select('*')
                .eq('profile_id', profile_id)
                .eq('is_active', True)
                .execute()
            )
            if not result.data:
                return None
            return self._map_to_profile(result.data[0])
        except Exception as e:
            self._logger.error(f"Error finding profile {profile_id}: {str(e)}")
            return None

    async def save(self, profile: MCPCredentialProfile) -> MCPCredentialProfile:
        """Not supported; profiles are written through ``store_profile``."""
        raise NotImplementedError("Use store_profile method instead")

    async def delete(self, profile_id: str) -> bool:
        """Mark ``profile_id`` inactive; return True if a row was updated.

        NOTE: unlike ``deactivate_profile``, this does not check account
        ownership and does not hand the default flag to another profile.
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .update({'is_active': False})
                .eq('profile_id', profile_id)
                .execute()
            )
            return bool(result.data)
        except Exception as e:
            self._logger.error(f"Error deleting profile {profile_id}: {str(e)}")
            return False

    async def find_by_account_and_qualified_name(
        self, account_id: str, mcp_qualified_name: str
    ) -> List[MCPCredentialProfile]:
        """Return the account's active profiles for one MCP server.

        Ordered default-first, then oldest-first by ``created_at``. Rows that
        fail to decrypt are skipped; any query error yields an empty list.
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .select('*')
                .eq('account_id', account_id)
                .eq('mcp_qualified_name', mcp_qualified_name)
                .eq('is_active', True)
                .order('is_default', desc=True)
                .order('created_at', desc=False)
                .execute()
            )
            return self._map_rows(result.data)
        except Exception as e:
            self._logger.error(f"Error retrieving profiles: {str(e)}")
            return []

    async def find_by_account(self, account_id: str) -> List[MCPCredentialProfile]:
        """Return all active profiles of ``account_id``.

        Ordered by MCP qualified name, then default-first, then oldest-first.
        Rows that fail to decrypt are skipped; any query error yields [].
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .select('*')
                .eq('account_id', account_id)
                .eq('is_active', True)
                .order('mcp_qualified_name')
                .order('is_default', desc=True)
                .order('created_at', desc=False)
                .execute()
            )
            return self._map_rows(result.data)
        except Exception as e:
            self._logger.error(f"Error retrieving all profiles: {str(e)}")
            return []

    async def find_default_profile(
        self, account_id: str, mcp_qualified_name: str
    ) -> Optional[MCPCredentialProfile]:
        """Return the default profile for this account/server.

        Falls back to the first active profile when none is flagged default,
        and returns None when the account has no active profile for the server.
        """
        profiles = await self.find_by_account_and_qualified_name(account_id, mcp_qualified_name)
        default = next((p for p in profiles if p.is_default), None)
        if default is not None:
            return default
        return profiles[0] if profiles else None

    async def store_profile(self, request: ProfileRequest) -> str:
        """Create or update a profile and return its ``profile_id``.

        Rows are upserted on (account_id, mcp_qualified_name, profile_name).
        The stored profile becomes the default when ``request.is_default`` is
        set or when the account has no other active profile for this server;
        in that case every other default for the server is cleared so at most
        one default remains. ``request`` is not modified.

        Raises:
            ValueError: if the upsert returned no row.
            Exception: any encryption or database error is logged and re-raised.
        """
        try:
            encrypted_config, config_hash = self._encryption.encrypt_config(request.config)
            encoded_config = base64.b64encode(encrypted_config).decode('utf-8')
            client = await self._db.client

            # Decide locally rather than mutating the caller's request object.
            is_default = request.is_default
            if not is_default:
                existing_profiles = await self.find_by_account_and_qualified_name(
                    request.account_id, request.mcp_qualified_name
                )
                # The first profile stored for a server becomes its default.
                is_default = not existing_profiles

            result = await client.table(self._TABLE).upsert({
                'account_id': request.account_id,
                'mcp_qualified_name': request.mcp_qualified_name,
                'profile_name': request.profile_name,
                'display_name': request.display_name,
                'encrypted_config': encoded_config,
                'config_hash': config_hash,
                'is_active': True,
                'is_default': is_default,
                'updated_at': datetime.now(timezone.utc).isoformat()
            }, on_conflict='account_id,mcp_qualified_name,profile_name').execute()
            if not result.data:
                raise ValueError("Failed to store profile")
            profile_id = result.data[0]['profile_id']

            if is_default:
                # Demote any previous default so only this profile stays default.
                await self._clear_other_defaults(
                    client, request.account_id, request.mcp_qualified_name, profile_id
                )

            self._logger.info(f"Successfully stored profile {profile_id}")
            return profile_id
        except Exception as e:
            self._logger.error(f"Error storing profile: {str(e)}")
            raise

    async def set_default(self, account_id: str, profile_id: str) -> bool:
        """Make ``profile_id`` the only default profile for its MCP server.

        Returns False when the profile is missing or inactive, when it belongs
        to another account, or on any database error (which is logged).
        """
        try:
            client = await self._db.client
            profile = await self._find_active(profile_id)
            if not profile or profile.account_id != account_id:
                return False
            result = await (
                client.table(self._TABLE)
                .update({'is_default': True})
                .eq('profile_id', profile_id)
                .eq('account_id', account_id)
                .execute()
            )
            if not result.data:
                return False
            # Without this, the previous default keeps its flag and the
            # account ends up with several defaults for the same server.
            await self._clear_other_defaults(
                client, account_id, profile.mcp_qualified_name, profile_id
            )
            return True
        except Exception as e:
            self._logger.error(f"Error setting default profile: {str(e)}")
            return False

    async def deactivate_profile(self, account_id: str, profile_id: str) -> bool:
        """Mark ``profile_id`` inactive after checking it belongs to ``account_id``.

        If it was the default, the default flag first moves to the first
        remaining active profile for the same server (default-first, then
        oldest-first ordering). Returns True if the row was deactivated.
        """
        try:
            client = await self._db.client
            profile = await self._find_active(profile_id)
            if not profile or profile.account_id != account_id:
                return False
            if profile.is_default:
                siblings = await self.find_by_account_and_qualified_name(
                    account_id, profile.mcp_qualified_name
                )
                successor = next((p for p in siblings if p.profile_id != profile_id), None)
                if successor is not None:
                    await self.set_default(account_id, successor.profile_id)
            result = await (
                client.table(self._TABLE)
                .update({'is_active': False})
                .eq('profile_id', profile_id)
                .eq('account_id', account_id)
                .execute()
            )
            return bool(result.data)
        except Exception as e:
            self._logger.error(f"Error deactivating profile: {str(e)}")
            return False

    async def _clear_other_defaults(
        self, client: Any, account_id: str, mcp_qualified_name: str, keep_profile_id: str
    ) -> None:
        """Unset ``is_default`` on the account's other profiles for this server.

        Exceptions propagate to the caller.
        """
        await (
            client.table(self._TABLE)
            .update({'is_default': False})
            .eq('account_id', account_id)
            .eq('mcp_qualified_name', mcp_qualified_name)
            .eq('is_default', True)
            .neq('profile_id', keep_profile_id)
            .execute()
        )

    def _map_rows(self, rows: List[Dict[str, Any]]) -> List[MCPCredentialProfile]:
        """Map rows to profiles, logging and skipping any row that fails to map."""
        profiles: List[MCPCredentialProfile] = []
        for row in rows:
            try:
                profiles.append(self._map_to_profile(row))
            except Exception as e:
                # One unreadable row must not hide the caller's other profiles.
                # .get() so a row missing profile_id can't raise inside the handler.
                self._logger.error(f"Failed to decrypt profile {row.get('profile_id')}: {e}")
        return profiles

    def _map_to_profile(self, data: Dict[str, Any]) -> MCPCredentialProfile:
        """Decrypt one table row into an ``MCPCredentialProfile``.

        ``encrypted_config`` may arrive as base64 text (as written by
        ``store_profile``) or as raw bytes; both are accepted.
        """
        encrypted_config = data['encrypted_config']
        if isinstance(encrypted_config, str):
            encrypted_config_bytes = base64.b64decode(encrypted_config.encode('utf-8'))
        else:
            encrypted_config_bytes = encrypted_config
        config = self._encryption.decrypt_config(
            encrypted_config_bytes,
            data['config_hash']
        )
        return MCPCredentialProfile(
            profile_id=data['profile_id'],
            account_id=data['account_id'],
            mcp_qualified_name=data['mcp_qualified_name'],
            profile_name=data['profile_name'],
            display_name=data['display_name'],
            config=config,
            is_active=data['is_active'],
            is_default=data['is_default'],
            last_used_at=data.get('last_used_at'),
            created_at=data.get('created_at'),
            updated_at=data.get('updated_at')
        )