# Source listing: suna/backend/credentials/repositories/credential_repository.py
# (193 lines, 7.1 KiB, Python)

from abc import abstractmethod
from typing import List, Optional, Dict, Any
import base64
from datetime import datetime, timezone
from .base import Repository
from ..domain.entities import MCPCredential, CredentialRequest
from ..protocols import DatabaseConnection, Logger, EncryptionService
class CredentialRepository(Repository[MCPCredential]):
    """Abstract persistence interface for per-account MCP credentials.

    Extends the generic ``Repository`` contract with lookups keyed by account
    and MCP qualified name, plus store/deactivate operations.

    NOTE(review): ``@abstractmethod`` is only enforced at instantiation time
    if ``Repository`` uses ``ABCMeta`` -- confirm in ``.base``.
    """

    @abstractmethod
    async def find_by_account_and_qualified_name(
        self, account_id: str, mcp_qualified_name: str
    ) -> Optional[MCPCredential]:
        """Look up one credential by account and MCP qualified name.

        Args:
            account_id: Identifier of the owning account.
            mcp_qualified_name: Qualified name of the MCP server.

        Returns:
            The matching credential, or ``None`` when there is none.
        """

    @abstractmethod
    async def find_by_account(
        self, account_id: str
    ) -> List[MCPCredential]:
        """List the credentials belonging to an account.

        Args:
            account_id: Identifier of the owning account.

        Returns:
            A list of credentials (possibly empty).
        """

    @abstractmethod
    async def store_credential(self, request: CredentialRequest) -> str:
        """Persist the credential described by ``request``.

        Args:
            request: The credential data to store.

        Returns:
            A string identifier for the stored credential.
        """

    @abstractmethod
    async def deactivate_credential(
        self, account_id: str, mcp_qualified_name: str
    ) -> bool:
        """Deactivate the credential for an account / qualified-name pair.

        Args:
            account_id: Identifier of the owning account.
            mcp_qualified_name: Qualified name of the MCP server.

        Returns:
            A boolean success indicator.
        """
class SupabaseCredentialRepository(CredentialRepository):
    """Credential repository backed by the ``user_mcp_credentials`` table.

    Credential configs are encrypted through the injected encryption service
    and written base64-encoded. Rows are never physically removed: deletion
    and deactivation both set ``is_active`` to ``False``.
    """

    # Single source of truth for the table every query in this class targets.
    _TABLE = 'user_mcp_credentials'

    def __init__(self, db: DatabaseConnection, encryption: EncryptionService, logger: Logger):
        """Store the collaborators used by every query.

        Args:
            db: Connection object whose awaitable ``client`` yields the
                query client.
            encryption: Service providing ``encrypt_config`` / ``decrypt_config``.
            logger: Logger used for error and info messages.
        """
        self._db = db
        self._encryption = encryption
        self._logger = logger

    async def find_by_id(self, credential_id: str) -> Optional[MCPCredential]:
        """Return the active credential with ``credential_id``.

        Args:
            credential_id: Primary identifier of the credential row.

        Returns:
            The decrypted credential, or ``None`` if no active row matches or
            any error occurs (errors are logged, not raised).
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .select('*')
                .eq('credential_id', credential_id)
                .eq('is_active', True)
                .execute()
            )
            if not result.data:
                return None
            return self._map_to_credential(result.data[0])
        except Exception as e:
            self._logger.error(f"Error finding credential {credential_id}: {str(e)}")
            return None

    async def save(self, credential: MCPCredential) -> MCPCredential:
        """Unsupported; callers must use :meth:`store_credential`.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Use store_credential method instead")

    async def delete(self, credential_id: str) -> bool:
        """Soft-delete a credential by clearing its ``is_active`` flag.

        Args:
            credential_id: Primary identifier of the credential row.

        Returns:
            ``True`` if the update returned at least one row, ``False``
            otherwise or on error (errors are logged, not raised).
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .update({'is_active': False})
                .eq('credential_id', credential_id)
                .execute()
            )
            # Truthiness also covers ``data`` being None, not just empty.
            return bool(result.data)
        except Exception as e:
            self._logger.error(f"Error deleting credential {credential_id}: {str(e)}")
            return False

    async def find_by_account_and_qualified_name(
        self, account_id: str, mcp_qualified_name: str
    ) -> Optional[MCPCredential]:
        """Return the active credential for an account / qualified-name pair.

        On a hit, the row's ``last_used_at`` is refreshed as a best-effort
        side effect; a failure of that write is logged but does not prevent
        the credential from being returned.

        Args:
            account_id: Identifier of the owning account.
            mcp_qualified_name: Qualified name of the MCP server.

        Returns:
            The decrypted credential, or ``None`` if no active row matches or
            the lookup/decryption fails (errors are logged, not raised).
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .select('*')
                .eq('account_id', account_id)
                .eq('mcp_qualified_name', mcp_qualified_name)
                .eq('is_active', True)
                .execute()
            )
            if not result.data:
                return None
            row = result.data[0]
            credential = self._map_to_credential(row)
        except Exception as e:
            self._logger.error(f"Error finding credential: {str(e)}")
            return None

        # Usage bookkeeping happens outside the lookup's try block so that a
        # failed timestamp write cannot mask a successfully loaded credential.
        await self._touch_last_used(client, row['credential_id'])
        return credential

    async def _touch_last_used(self, client: Any, credential_id: str) -> None:
        """Best-effort update of ``last_used_at`` to the current UTC time."""
        try:
            await (
                client.table(self._TABLE)
                .update({'last_used_at': datetime.now(timezone.utc).isoformat()})
                .eq('credential_id', credential_id)
                .execute()
            )
        except Exception as e:
            self._logger.error(
                f"Failed to update last_used_at for credential {credential_id}: {str(e)}"
            )

    async def find_by_account(self, account_id: str) -> List[MCPCredential]:
        """Return all active credentials of an account, newest first.

        Rows that cannot be mapped/decrypted are logged and skipped so one
        bad row does not hide the others.

        Args:
            account_id: Identifier of the owning account.

        Returns:
            The decrypted credentials ordered by ``created_at`` descending;
            an empty list on query error (errors are logged, not raised).
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .select('*')
                .eq('account_id', account_id)
                .eq('is_active', True)
                .order('created_at', desc=True)
                .execute()
            )
            credentials = []
            for cred_data in result.data or []:
                try:
                    credentials.append(self._map_to_credential(cred_data))
                except Exception as e:
                    # ``.get`` keeps the handler itself from raising when the
                    # row is missing its id, which would abort the whole list.
                    self._logger.error(
                        f"Failed to decrypt credential {cred_data.get('credential_id')}: {e}"
                    )
            return credentials
        except Exception as e:
            self._logger.error(f"Error retrieving user credentials: {str(e)}")
            return []

    async def store_credential(self, request: CredentialRequest) -> str:
        """Encrypt and upsert a credential, returning its ``credential_id``.

        The upsert conflicts on ``(account_id, mcp_qualified_name)``, so
        storing again for the same pair overwrites the existing row and
        marks it active.

        Args:
            request: Carries ``account_id``, ``mcp_qualified_name``,
                ``display_name`` and the plaintext ``config``.

        Returns:
            The ``credential_id`` of the stored row.

        Raises:
            ValueError: If the upsert returns no row.
            Exception: Any encryption or database error, re-raised after
                being logged.
        """
        try:
            encrypted_config, config_hash = self._encryption.encrypt_config(request.config)
            client = await self._db.client
            # Ciphertext bytes are base64-encoded into text before upserting.
            encoded_config = base64.b64encode(encrypted_config).decode('utf-8')
            payload = {
                'account_id': request.account_id,
                'mcp_qualified_name': request.mcp_qualified_name,
                'display_name': request.display_name,
                'encrypted_config': encoded_config,
                'config_hash': config_hash,
                'is_active': True,
                'updated_at': datetime.now(timezone.utc).isoformat(),
            }
            result = await (
                client.table(self._TABLE)
                .upsert(payload, on_conflict='account_id,mcp_qualified_name')
                .execute()
            )
            if not result.data:
                raise ValueError("Failed to store credential")
            credential_id = result.data[0]['credential_id']
            self._logger.info(f"Successfully stored credential {credential_id}")
            return credential_id
        except Exception as e:
            self._logger.error(f"Error storing credential: {str(e)}")
            raise

    async def deactivate_credential(
        self, account_id: str, mcp_qualified_name: str
    ) -> bool:
        """Clear ``is_active`` on the row for an account / qualified-name pair.

        Args:
            account_id: Identifier of the owning account.
            mcp_qualified_name: Qualified name of the MCP server.

        Returns:
            ``True`` if the update returned at least one row, ``False``
            otherwise or on error (errors are logged, not raised).
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .update({'is_active': False})
                .eq('account_id', account_id)
                .eq('mcp_qualified_name', mcp_qualified_name)
                .execute()
            )
            return bool(result.data)
        except Exception as e:
            self._logger.error(f"Error deactivating credential: {str(e)}")
            return False

    def _map_to_credential(self, data: Dict[str, Any]) -> MCPCredential:
        """Build an ``MCPCredential`` from a database row, decrypting its config.

        Args:
            data: Row dict; must contain ``encrypted_config``, ``config_hash``,
                ``credential_id``, ``account_id``, ``mcp_qualified_name``,
                ``display_name`` and ``is_active``. Timestamp keys are optional.

        Returns:
            The populated credential entity.

        Raises:
            KeyError: If a required key is missing from ``data``.
            Exception: Whatever base64 decoding or ``decrypt_config`` raises.
        """
        encrypted_config = data['encrypted_config']
        # Text values are base64 (as written by store_credential); anything
        # else is passed through to decryption unchanged.
        if isinstance(encrypted_config, str):
            encrypted_config_bytes = base64.b64decode(encrypted_config.encode('utf-8'))
        else:
            encrypted_config_bytes = encrypted_config

        config = self._encryption.decrypt_config(
            encrypted_config_bytes,
            data['config_hash'],
        )
        return MCPCredential(
            credential_id=data['credential_id'],
            account_id=data['account_id'],
            mcp_qualified_name=data['mcp_qualified_name'],
            display_name=data['display_name'],
            config=config,
            is_active=data['is_active'],
            last_used_at=data.get('last_used_at'),
            created_at=data.get('created_at'),
            updated_at=data.get('updated_at'),
        )