mirror of https://github.com/kortix-ai/suna.git
263 lines
10 KiB
Python
263 lines
10 KiB
Python
from abc import abstractmethod
|
|
from typing import List, Optional, Dict, Any
|
|
import base64
|
|
from datetime import datetime, timezone
|
|
|
|
from .base import Repository
|
|
from ..domain.entities import MCPCredentialProfile, ProfileRequest
|
|
from ..protocols import DatabaseConnection, Logger, EncryptionService
|
|
|
|
|
|
class ProfileRepository(Repository[MCPCredentialProfile]):
    """Persistence contract for MCP credential profiles.

    Extends the generic ``Repository`` with account-scoped lookups and the
    default / active-flag management used for credential profiles. Concrete
    subclasses must implement every method declared here.
    """

    @abstractmethod
    async def find_by_account_and_qualified_name(
        self, account_id: str, mcp_qualified_name: str
    ) -> List[MCPCredentialProfile]:
        """Return the profiles of ``account_id`` for the MCP server ``mcp_qualified_name``."""
        ...

    @abstractmethod
    async def find_by_account(self, account_id: str) -> List[MCPCredentialProfile]:
        """Return the profiles of ``account_id`` across all MCP servers."""
        ...

    @abstractmethod
    async def find_default_profile(
        self, account_id: str, mcp_qualified_name: str
    ) -> Optional[MCPCredentialProfile]:
        """Return the profile to use by default for one MCP server, or ``None``."""
        ...

    @abstractmethod
    async def store_profile(self, request: ProfileRequest) -> str:
        """Persist the profile described by ``request`` and return its identifier."""
        ...

    @abstractmethod
    async def set_default(self, account_id: str, profile_id: str) -> bool:
        """Mark ``profile_id`` as the default for ``account_id``; return whether it was applied."""
        ...

    @abstractmethod
    async def deactivate_profile(self, account_id: str, profile_id: str) -> bool:
        """Deactivate ``profile_id`` for ``account_id``; return whether it was applied."""
        ...
|
|
|
|
|
|
class SupabaseProfileRepository(ProfileRepository):
    """``ProfileRepository`` backed by the Supabase ``user_mcp_credential_profiles`` table.

    Each profile's config is encrypted with the injected ``EncryptionService``.
    The ciphertext is stored base64-encoded in ``encrypted_config`` alongside
    ``config_hash``.

    Removal is soft: ``delete`` and ``deactivate_profile`` set
    ``is_active = False`` rather than deleting rows, and every read filters
    on ``is_active = True``.

    At most one profile per (account, MCP qualified name) should carry
    ``is_default``. ``store_profile`` and ``set_default`` clear the flag on
    sibling rows explicitly, so that invariant does not depend on any
    database-side enforcement.
    """

    # Single source for the table name used by every query below.
    _TABLE = 'user_mcp_credential_profiles'

    def __init__(self, db: DatabaseConnection, encryption: EncryptionService, logger: Logger):
        self._db = db
        self._encryption = encryption
        self._logger = logger

    async def find_by_id(self, profile_id: str) -> Optional[MCPCredentialProfile]:
        """Return the active profile ``profile_id`` with its config decrypted.

        A hit also stamps ``last_used_at`` with the current UTC time. That
        write is best-effort: if it fails, the error is logged and the profile
        is still returned. The returned object carries the previous
        ``last_used_at`` value.

        Returns ``None`` when no active row matches, or when the query or
        decryption fails. In that case the error is logged, not raised.
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .select('*')
                .eq('profile_id', profile_id)
                .eq('is_active', True)
                .execute()
            )
            if not result.data:
                return None
            profile = self._map_to_profile(result.data[0])
        except Exception as e:
            self._logger.error(f"Error finding profile {profile_id}: {str(e)}")
            return None

        # Bookkeeping only; never lose a successfully read profile over it.
        await self._touch_last_used(client, profile_id)
        return profile

    async def save(self, profile: MCPCredentialProfile) -> MCPCredentialProfile:
        """Not supported; profiles are written through ``store_profile``."""
        raise NotImplementedError("Use store_profile method instead")

    async def delete(self, profile_id: str) -> bool:
        """Soft-delete ``profile_id`` by setting ``is_active = False``.

        Unlike ``deactivate_profile``, this neither checks account ownership
        nor reassigns the default flag. Returns whether any row was updated.
        Errors are logged and reported as ``False``.
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .update({'is_active': False})
                .eq('profile_id', profile_id)
                .execute()
            )
            return len(result.data) > 0
        except Exception as e:
            self._logger.error(f"Error deleting profile {profile_id}: {str(e)}")
            return False

    async def find_by_account_and_qualified_name(
        self, account_id: str, mcp_qualified_name: str
    ) -> List[MCPCredentialProfile]:
        """Return the account's active profiles for one MCP server.

        Results are ordered default-first, then by oldest ``created_at``.
        Rows whose config cannot be decoded are logged and skipped. A failed
        query is logged and yields an empty list.
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .select('*')
                .eq('account_id', account_id)
                .eq('mcp_qualified_name', mcp_qualified_name)
                .eq('is_active', True)
                .order('is_default', desc=True)
                .order('created_at', desc=False)
                .execute()
            )
            return self._map_rows(result.data)
        except Exception as e:
            self._logger.error(f"Error retrieving profiles: {str(e)}")
            return []

    async def find_by_account(self, account_id: str) -> List[MCPCredentialProfile]:
        """Return all of the account's active profiles.

        Results are ordered by ``mcp_qualified_name``, then default-first,
        then by oldest ``created_at``. Rows whose config cannot be decoded are
        logged and skipped. A failed query is logged and yields an empty list.
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .select('*')
                .eq('account_id', account_id)
                .eq('is_active', True)
                .order('mcp_qualified_name')
                .order('is_default', desc=True)
                .order('created_at', desc=False)
                .execute()
            )
            return self._map_rows(result.data)
        except Exception as e:
            self._logger.error(f"Error retrieving all profiles: {str(e)}")
            return []

    async def find_default_profile(
        self, account_id: str, mcp_qualified_name: str
    ) -> Optional[MCPCredentialProfile]:
        """Return the default profile for one MCP server.

        Falls back to the first readable profile when none is flagged
        default. Returns ``None`` when the account has no readable profile
        for that server.
        """
        profiles = await self.find_by_account_and_qualified_name(account_id, mcp_qualified_name)
        fallback = profiles[0] if profiles else None
        return next((p for p in profiles if p.is_default), fallback)

    async def store_profile(self, request: ProfileRequest) -> str:
        """Encrypt ``request.config``, upsert the profile, and return its ``profile_id``.

        Rows are keyed on (account_id, mcp_qualified_name, profile_name), so
        storing an existing name overwrites that row and re-activates it.

        The first active profile for a server becomes its default even when
        ``request.is_default`` is false. When the stored profile ends up as
        the default, the flag is cleared on the account's other profiles for
        the same server. ``request`` itself is not modified.

        Raises:
            ValueError: if the upsert returns no row.
            Exception: any encryption or database error is logged and re-raised.
        """
        try:
            encrypted_config, config_hash = self._encryption.encrypt_config(request.config)
            encoded_config = base64.b64encode(encrypted_config).decode('utf-8')
            client = await self._db.client

            # Decide locally instead of writing back into the caller's request.
            is_default = bool(request.is_default)
            if not is_default:
                existing_profiles = await self.find_by_account_and_qualified_name(
                    request.account_id, request.mcp_qualified_name
                )
                is_default = not existing_profiles

            result = await (
                client.table(self._TABLE)
                .upsert({
                    'account_id': request.account_id,
                    'mcp_qualified_name': request.mcp_qualified_name,
                    'profile_name': request.profile_name,
                    'display_name': request.display_name,
                    'encrypted_config': encoded_config,
                    'config_hash': config_hash,
                    'is_active': True,
                    'is_default': is_default,
                    'updated_at': datetime.now(timezone.utc).isoformat(),
                }, on_conflict='account_id,mcp_qualified_name,profile_name')
                .execute()
            )

            if not result.data:
                raise ValueError("Failed to store profile")

            profile_id = result.data[0]['profile_id']

            # Keep a single default per (account, server).
            if is_default:
                await self._clear_other_defaults(
                    client, request.account_id, request.mcp_qualified_name, profile_id
                )

            self._logger.info(f"Successfully stored profile {profile_id}")
            return profile_id
        except Exception as e:
            self._logger.error(f"Error storing profile: {str(e)}")
            raise

    async def set_default(self, account_id: str, profile_id: str) -> bool:
        """Make ``profile_id`` the account's default profile for its MCP server.

        The account's other profiles for the same server lose their default
        flag. Ownership is checked without decrypting the stored config and
        without touching ``last_used_at``.

        Returns ``False`` when ``profile_id`` is not an active profile owned
        by ``account_id``, or when a database error occurs (logged).
        """
        try:
            client = await self._db.client
            row = await self._find_owned_active_row(client, account_id, profile_id)
            if row is None:
                return False
            return await self._mark_default(
                client, account_id, profile_id, row['mcp_qualified_name']
            )
        except Exception as e:
            self._logger.error(f"Error setting default profile: {str(e)}")
            return False

    async def deactivate_profile(self, account_id: str, profile_id: str) -> bool:
        """Soft-delete an active profile owned by ``account_id``.

        If the profile was its server's default, the flag is first handed to
        the first other profile in ``find_by_account_and_qualified_name``
        order. That hand-off is best-effort: a failure is logged and does not
        block the deactivation.

        The ownership check does not decrypt the config, so a profile whose
        config can no longer be decrypted can still be deactivated.

        Returns whether the row was deactivated. Errors are logged and
        reported as ``False``.
        """
        try:
            client = await self._db.client
            row = await self._find_owned_active_row(client, account_id, profile_id)
            if row is None:
                return False

            if row['is_default']:
                await self._promote_successor(
                    client, account_id, profile_id, row['mcp_qualified_name']
                )

            result = await (
                client.table(self._TABLE)
                .update({'is_active': False})
                .eq('profile_id', profile_id)
                .eq('account_id', account_id)
                .execute()
            )
            return len(result.data) > 0
        except Exception as e:
            self._logger.error(f"Error deactivating profile: {str(e)}")
            return False

    async def _find_owned_active_row(
        self, client: Any, account_id: str, profile_id: str
    ) -> Optional[Dict[str, Any]]:
        """Return id, server name and default flag of an active profile owned by ``account_id``.

        Only plain columns are selected, so nothing is decrypted and
        ``last_used_at`` is left alone. Returns ``None`` if no row matches.
        """
        result = await (
            client.table(self._TABLE)
            .select('profile_id,mcp_qualified_name,is_default')
            .eq('profile_id', profile_id)
            .eq('account_id', account_id)
            .eq('is_active', True)
            .execute()
        )
        return result.data[0] if result.data else None

    async def _mark_default(
        self, client: Any, account_id: str, profile_id: str, mcp_qualified_name: str
    ) -> bool:
        """Flag ``profile_id`` as default, then clear the flag on its siblings.

        The target is flagged first, so a failure while clearing siblings
        leaves an extra default rather than none. Returns whether the target
        row was updated.
        """
        result = await (
            client.table(self._TABLE)
            .update({'is_default': True})
            .eq('profile_id', profile_id)
            .eq('account_id', account_id)
            .execute()
        )
        if not result.data:
            return False
        await self._clear_other_defaults(client, account_id, mcp_qualified_name, profile_id)
        return True

    async def _clear_other_defaults(
        self, client: Any, account_id: str, mcp_qualified_name: str, keep_profile_id: str
    ) -> None:
        """Unset ``is_default`` on the account's other profiles for the same MCP server."""
        await (
            client.table(self._TABLE)
            .update({'is_default': False})
            .eq('account_id', account_id)
            .eq('mcp_qualified_name', mcp_qualified_name)
            .neq('profile_id', keep_profile_id)
            .execute()
        )

    async def _promote_successor(
        self, client: Any, account_id: str, leaving_profile_id: str, mcp_qualified_name: str
    ) -> None:
        """Best-effort: move the default flag from ``leaving_profile_id`` to a sibling.

        Errors are logged and swallowed.
        """
        try:
            siblings = await self.find_by_account_and_qualified_name(account_id, mcp_qualified_name)
            successor = next((p for p in siblings if p.profile_id != leaving_profile_id), None)
            if successor is not None:
                await self._mark_default(client, account_id, successor.profile_id, mcp_qualified_name)
        except Exception as e:
            self._logger.error(f"Error reassigning default profile: {str(e)}")

    async def _touch_last_used(self, client: Any, profile_id: str) -> None:
        """Best-effort update of ``last_used_at`` to now (UTC); errors are logged."""
        try:
            await (
                client.table(self._TABLE)
                .update({'last_used_at': datetime.now(timezone.utc).isoformat()})
                .eq('profile_id', profile_id)
                .execute()
            )
        except Exception as e:
            self._logger.error(f"Error updating last_used_at for profile {profile_id}: {str(e)}")

    def _map_rows(self, rows: List[Dict[str, Any]]) -> List[MCPCredentialProfile]:
        """Map rows to profiles, logging and skipping any row that fails to decode or decrypt."""
        profiles: List[MCPCredentialProfile] = []
        for row in rows:
            try:
                profiles.append(self._map_to_profile(row))
            except Exception as e:
                # .get(): a malformed row must not break the error path itself.
                self._logger.error(f"Failed to decrypt profile {row.get('profile_id')}: {e}")
        return profiles

    def _map_to_profile(self, data: Dict[str, Any]) -> MCPCredentialProfile:
        """Build an ``MCPCredentialProfile`` from a table row, decrypting its config.

        ``encrypted_config`` may arrive as base64 text (the form
        ``store_profile`` writes) or as raw bytes. Errors from base64 decoding
        or ``decrypt_config`` propagate to the caller.
        """
        encrypted_config = data['encrypted_config']
        if isinstance(encrypted_config, str):
            encrypted_config_bytes = base64.b64decode(encrypted_config.encode('utf-8'))
        else:
            encrypted_config_bytes = encrypted_config

        config = self._encryption.decrypt_config(
            encrypted_config_bytes,
            data['config_hash']
        )

        return MCPCredentialProfile(
            profile_id=data['profile_id'],
            account_id=data['account_id'],
            mcp_qualified_name=data['mcp_qualified_name'],
            profile_name=data['profile_name'],
            display_name=data['display_name'],
            config=config,
            is_active=data['is_active'],
            is_default=data['is_default'],
            last_used_at=data.get('last_used_at'),
            created_at=data.get('created_at'),
            updated_at=data.get('updated_at')
        )