# suna/backend/pipedream/repositories/profile_repository.py
#
# 230 lines
# 10 KiB
# Python

from typing import List, Optional
from uuid import UUID
import json
from datetime import datetime
from ..protocols import ProfileRepository, DatabaseConnection, EncryptionService, Logger
from ..domain.entities import Profile
from ..domain.value_objects import ExternalUserId, AppSlug, ProfileName, EncryptedConfig, ConfigHash
from ..domain.exceptions import DatabaseException, ProfileNotFoundError
class SupabaseProfileRepository:
    """Store and load Pipedream credential profiles via a Supabase-style client.

    Rows live in the ``user_mcp_credential_profiles`` table. The app-specific
    part of a profile (app slug, app name, external user id, enabled tools,
    OAuth app id, description) is serialized to JSON, encrypted with the
    injected encryption service and stored in ``encrypted_config``; a hash of
    the plaintext JSON is stored in ``config_hash``.

    Error policy, as implemented by the methods below:

    * ``create``, ``update`` and ``set_default`` log and raise
      ``DatabaseException``.
    * ``get_by_id``, ``get_by_app_slug``, ``find_by_account`` and ``delete``
      are best-effort: they log and return ``None`` / ``[]`` / ``False``.
    """

    _TABLE = 'user_mcp_credential_profiles'

    def __init__(self, db: DatabaseConnection, encryption_service: EncryptionService, logger: Logger):
        """Keep references to the database handle, encryption service and logger."""
        self._db = db
        self._encryption_service = encryption_service
        self._logger = logger

    async def create(self, profile: Profile) -> Profile:
        """Insert ``profile`` as a new row and return the stored profile.

        Raises:
            DatabaseException: if the insert fails or returns no row.
        """
        try:
            client = await self._db.client
            config = self._build_config(profile)
            encrypted_config, config_hash = self._seal_config(config)
            result = await client.table(self._TABLE).insert({
                'account_id': str(profile.account_id),
                'mcp_qualified_name': profile.mcp_qualified_name,
                'profile_name': profile.profile_name.value,
                'display_name': profile.display_name,
                'encrypted_config': encrypted_config,
                'config_hash': config_hash.value,
                'is_active': profile.is_active,
                'is_default': profile.is_default,
                'created_at': profile.created_at.isoformat(),
                'updated_at': profile.updated_at.isoformat(),
            }).execute()
            if not result.data:
                raise DatabaseException("create", "No data returned from insert")
            return self._map_to_domain(result.data[0], config)
        except DatabaseException as e:
            # Already the right type: log it, but do not wrap it a second time.
            self._logger.error(f"Error creating profile: {str(e)}")
            raise
        except Exception as e:
            self._logger.error(f"Error creating profile: {str(e)}")
            raise DatabaseException("create", str(e)) from e

    async def get_by_id(self, account_id: UUID, profile_id: UUID) -> Optional[Profile]:
        """Return the profile with ``profile_id`` owned by ``account_id``.

        Returns ``None`` when no data comes back or when any error occurs
        (query, decryption, JSON decoding or mapping); errors are logged.
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .select('*')
                .eq('account_id', str(account_id))
                .eq('profile_id', str(profile_id))
                .single()
                .execute()
            )
            if not result.data:
                return None
            # With .single() the payload is used as one row, not a list of rows.
            return self._decrypt_and_map(result.data)
        except Exception as e:
            self._logger.error(f"Error getting profile by ID: {str(e)}")
            return None

    async def get_by_app_slug(self, account_id: UUID, app_slug: AppSlug, profile_name: Optional[ProfileName] = None) -> Optional[Profile]:
        """Return one profile of ``account_id`` for the given Pipedream app.

        Rows are matched on ``mcp_qualified_name == "pipedream:<app_slug>"``.
        When ``profile_name`` is given it is added as a filter and the first
        match is returned. Otherwise the first row flagged ``is_default`` is
        preferred, falling back to the first row returned.

        Returns ``None`` when nothing matches or when any error occurs; errors
        are logged.
        """
        try:
            client = await self._db.client
            mcp_qualified_name = f"pipedream:{app_slug.value}"
            query = (
                client.table(self._TABLE)
                .select('*')
                .eq('account_id', str(account_id))
                .eq('mcp_qualified_name', mcp_qualified_name)
            )
            if profile_name:
                query = query.eq('profile_name', profile_name.value)
            result = await query.execute()
            rows = result.data
            if not rows:
                return None
            if profile_name:
                chosen = rows[0]
            else:
                chosen = next((row for row in rows if row.get('is_default')), rows[0])
            return self._decrypt_and_map(chosen)
        except Exception as e:
            self._logger.error(f"Error getting profile by app slug: {str(e)}")
            return None

    async def find_by_account(self, account_id: UUID, app_slug: Optional[AppSlug] = None, is_active: Optional[bool] = None) -> List[Profile]:
        """List the Pipedream profiles of ``account_id``, newest first.

        Args:
            account_id: owner of the profiles.
            app_slug: when given, only profiles whose ``mcp_qualified_name`` is
                ``"pipedream:<app_slug>"``; otherwise every row whose name
                matches ``pipedream:%``.
            is_active: when not ``None``, filter on the ``is_active`` column.

        Rows that cannot be decrypted, decoded or mapped are logged and
        skipped. If the query itself fails the error is logged and an empty
        list is returned.
        """
        try:
            client = await self._db.client
            query = client.table(self._TABLE).select('*').eq('account_id', str(account_id))
            if app_slug:
                query = query.eq('mcp_qualified_name', f"pipedream:{app_slug.value}")
            else:
                query = query.like('mcp_qualified_name', 'pipedream:%')
            if is_active is not None:
                query = query.eq('is_active', is_active)
            result = await query.order('created_at', desc=True).execute()

            profiles = []
            for profile_data in result.data or []:
                try:
                    profiles.append(self._decrypt_and_map(profile_data))
                except Exception as e:
                    # One unreadable row must not hide the remaining profiles.
                    self._logger.error(f"Error decrypting profile config: {str(e)}")
            return profiles
        except Exception as e:
            self._logger.error(f"Error finding profiles by account: {str(e)}")
            return []

    async def update(self, profile: Profile) -> Profile:
        """Write the mutable fields of ``profile`` back and return the stored row.

        ``account_id``, ``mcp_qualified_name`` and ``created_at`` are not
        written. The row is matched on both ``profile_id`` and ``account_id``
        so that an update can only touch a row owned by the profile's account.

        Raises:
            DatabaseException: if the update fails or matches no row.
        """
        try:
            client = await self._db.client
            config = self._build_config(profile)
            encrypted_config, config_hash = self._seal_config(config)
            last_used_at = profile.last_used_at.isoformat() if profile.last_used_at else None
            result = await (
                client.table(self._TABLE)
                .update({
                    'profile_name': profile.profile_name.value,
                    'display_name': profile.display_name,
                    'encrypted_config': encrypted_config,
                    'config_hash': config_hash.value,
                    'is_active': profile.is_active,
                    'is_default': profile.is_default,
                    'updated_at': profile.updated_at.isoformat(),
                    'last_used_at': last_used_at,
                })
                .eq('profile_id', str(profile.profile_id))
                .eq('account_id', str(profile.account_id))
                .execute()
            )
            if not result.data:
                raise DatabaseException("update", "No data returned from update")
            return self._map_to_domain(result.data[0], config)
        except DatabaseException as e:
            # Already the right type: log it, but do not wrap it a second time.
            self._logger.error(f"Error updating profile: {str(e)}")
            raise
        except Exception as e:
            self._logger.error(f"Error updating profile: {str(e)}")
            raise DatabaseException("update", str(e)) from e

    async def delete(self, account_id: UUID, profile_id: UUID) -> bool:
        """Delete the profile ``profile_id`` owned by ``account_id``.

        Returns ``True`` when the delete call reports at least one row, and
        ``False`` when it reports none or when any error occurs (logged).
        """
        try:
            client = await self._db.client
            result = await (
                client.table(self._TABLE)
                .delete()
                .eq('profile_id', str(profile_id))
                .eq('account_id', str(account_id))
                .execute()
            )
            return bool(result.data)
        except Exception as e:
            self._logger.error(f"Error deleting profile: {str(e)}")
            return False

    async def set_default(self, account_id: UUID, profile_id: UUID, mcp_qualified_name: str) -> None:
        """Make ``profile_id`` the default profile for ``mcp_qualified_name``.

        First clears ``is_default`` on every row of the account with that
        qualified name, then sets it on the requested profile. These are two
        separate statements, so a failure of the second one leaves the account
        without a default for that name.

        Raises:
            DatabaseException: if either statement fails.
        """
        try:
            client = await self._db.client
            await (
                client.table(self._TABLE)
                .update({'is_default': False})
                .eq('account_id', str(account_id))
                .eq('mcp_qualified_name', mcp_qualified_name)
                .execute()
            )
            # Scope the promotion to the account as well, so a profile id that
            # belongs to a different account can never be marked default here.
            await (
                client.table(self._TABLE)
                .update({'is_default': True})
                .eq('profile_id', str(profile_id))
                .eq('account_id', str(account_id))
                .execute()
            )
        except Exception as e:
            self._logger.error(f"Error setting default profile: {str(e)}")
            raise DatabaseException("set_default", str(e)) from e

    @staticmethod
    def _build_config(profile: Profile) -> dict:
        """Return the plain dict that is encrypted into ``encrypted_config``.

        ``oauth_app_id`` and ``description`` are read with ``getattr`` and
        default to ``None`` when the profile object does not carry them.
        """
        return {
            "app_slug": profile.app_slug.value,
            "app_name": profile.app_name,
            "external_user_id": profile.external_user_id.value,
            "enabled_tools": profile.enabled_tools,
            "oauth_app_id": getattr(profile, 'oauth_app_id', None),
            "description": getattr(profile, 'description', None),
        }

    def _seal_config(self, config: dict) -> tuple:
        """Serialize ``config`` and return ``(encrypted_config, config_hash)``.

        Both values are derived from the same JSON string; ``config_hash`` is
        the ``ConfigHash`` built by ``ConfigHash.from_config``.
        """
        config_json = json.dumps(config)
        encrypted_config = self._encryption_service.encrypt(config_json)
        config_hash = ConfigHash.from_config(config_json)
        return encrypted_config, config_hash

    def _decrypt_and_map(self, profile_data: dict) -> Profile:
        """Decrypt and decode a row's ``encrypted_config`` and build the ``Profile``."""
        decrypted_config = self._encryption_service.decrypt(profile_data['encrypted_config'])
        config = json.loads(decrypted_config)
        return self._map_to_domain(profile_data, config)

    @staticmethod
    def _parse_timestamp(value: str) -> datetime:
        """Parse an ISO-8601 timestamp string into a ``datetime``.

        ``datetime.fromisoformat`` before Python 3.11 rejects a trailing ``Z``
        and fractional seconds that are not exactly 3 or 6 digits long. Both
        forms are normalized here so parsing does not depend on the
        interpreter version.

        Raises:
            ValueError: if ``value`` is not a parseable ISO-8601 timestamp.
        """
        text = value.strip()
        if text.endswith(('Z', 'z')):
            text = text[:-1] + '+00:00'
        try:
            return datetime.fromisoformat(text)
        except ValueError:
            head, dot, tail = text.partition('.')
            digit_count = 0
            while digit_count < len(tail) and tail[digit_count].isdigit():
                digit_count += 1
            if not dot or digit_count == 0:
                raise
            # Pad or truncate the fraction to microsecond precision and retry.
            fraction = tail[:digit_count][:6].ljust(6, '0')
            return datetime.fromisoformat(f"{head}.{fraction}{tail[digit_count:]}")

    def _map_to_domain(self, profile_data: dict, config: dict) -> Profile:
        """Build a ``Profile`` from a table row plus its decrypted config dict.

        ``is_connected`` is always set to ``False`` here. ``enabled_tools``
        defaults to an empty list when missing from ``config``, and
        ``last_used_at`` is ``None`` when the row has no value for it.
        """
        last_used_raw = profile_data.get('last_used_at')
        return Profile(
            profile_id=UUID(profile_data['profile_id']),
            account_id=UUID(profile_data['account_id']),
            mcp_qualified_name=profile_data['mcp_qualified_name'],
            profile_name=ProfileName(profile_data['profile_name']),
            display_name=profile_data['display_name'],
            encrypted_config=EncryptedConfig(profile_data['encrypted_config']),
            config_hash=ConfigHash(profile_data['config_hash']),
            app_slug=AppSlug(config['app_slug']),
            app_name=config['app_name'],
            external_user_id=ExternalUserId(config['external_user_id']),
            enabled_tools=config.get('enabled_tools', []),
            is_active=profile_data['is_active'],
            is_default=profile_data['is_default'],
            is_connected=False,
            created_at=self._parse_timestamp(profile_data['created_at']),
            updated_at=self._parse_timestamp(profile_data['updated_at']),
            last_used_at=self._parse_timestamp(last_used_raw) if last_used_raw else None,
        )